/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.common.session.filters;

import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.cipher.Cipher;
import org.apache.sshd.common.cipher.CipherNone;
import org.apache.sshd.common.filter.InputHandler;
import org.apache.sshd.common.filter.IoFilter;
import org.apache.sshd.common.filter.OutputHandler;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.mac.Mac;
import org.apache.sshd.common.random.Random;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.session.filters.CryptStatisticsProvider;
import org.apache.sshd.common.session.filters.ThreadLocalRandom;
import org.apache.sshd.common.util.Readable;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link IoFilter} that applies SSH binary-packet framing, encryption and integrity protection.
 * <p>
 * Outgoing buffers are framed as {@code uint32 packet_length || byte padding_length || payload || padding},
 * then encrypted and authenticated according to the current output {@link Settings}. Incoming bytes are
 * accumulated until a complete packet is available; the packet is then verified, decrypted, stripped of
 * its framing and passed on to the owner with the buffer positioned at the payload.
 * </p>
 * <p>
 * Three modes are handled:
 * <ul>
 * <li>plain: the MAC is computed over the plaintext, and the whole packet including the length field is encrypted;</li>
 * <li>encrypt-then-MAC ("EtM"): the length field stays in clear, and the MAC covers the ciphertext;</li>
 * <li>AEAD: the cipher authenticates the packet itself, and the length field is fed as additional authenticated data.</li>
 * </ul>
 * </p>
 */
public class CryptFilter extends IoFilter implements CryptStatisticsProvider {

    /** Maximum number of padding bytes this filter adds to an outgoing packet. */
    public static final int MAX_PADDING = 127;

    /** Presumably the largest supported MAC/AEAD tag length in bytes; not referenced within this class. */
    public static final int MAX_TAG_LENGTH = 64;

    private static final Logger LOG = LoggerFactory.getLogger(CryptFilter.class);

    /** Smallest acceptable value of an incoming {@code packet_length} field. */
    private static final int MIN_PACKET_LENGTH = 8;

    /** Largest acceptable value of an incoming {@code packet_length} field (256 KiB). */
    private static final int MAX_PACKET_LENGTH = 262144;

    /** Marker: the length of the current incoming packet has not been read yet. */
    private static final int UNKNOWN_PACKET_LENGTH = -1;

    private final AtomicReference<Settings> decryption = new AtomicReference<>();
    private final AtomicReference<Settings> encryption = new AtomicReference<>();
    private final AtomicReference<Counters> inCounts = new AtomicReference<>();
    private final AtomicReference<Counters> outCounts = new AtomicReference<>();
    private final DecryptionHandler input = new DecryptionHandler();
    private final EncryptionHandler output = new EncryptionHandler();
    private final CopyOnWriteArrayList<EncryptionListener> listeners = new CopyOnWriteArrayList<>();
    private Random random = ThreadLocalRandom.INSTANCE;
    /** Used only to identify the session in log messages. */
    private Session session;

    /** Creates a filter that initially neither encrypts nor authenticates, with zeroed counters. */
    public CryptFilter() {
        this.decryption.set(new Settings(null, null));
        this.encryption.set(new Settings(null, null));
        this.inCounts.set(new Counters());
        this.outCounts.set(new Counters());
    }

    /**
     * Sets the random source used for padding bytes, padding lengths and the discard length
     * applied after an invalid incoming packet length.
     *
     * @param random the random source
     */
    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Sets the session this filter belongs to; it is used only in log messages.
     *
     * @param session the owning session
     */
    public void setSession(Session session) {
        this.session = session;
    }

    @Override
    public InputHandler in() {
        return this.input;
    }

    @Override
    public OutputHandler out() {
        return this.output;
    }

    /** Replaces the inbound statistics with a fresh, zeroed {@link Counters} instance. */
    public void resetInputCounters() {
        this.inCounts.set(new Counters());
    }

    /** Replaces the outbound statistics with a fresh, zeroed {@link Counters} instance. */
    public void resetOutputCounters() {
        this.outCounts.set(new Counters());
    }

    @Override
    public Counters getInputCounters() {
        return this.inCounts.get();
    }

    @Override
    public Counters getOutputCounters() {
        return this.outCounts.get();
    }

    /**
     * Installs new settings for decrypting incoming packets.
     *
     * @param settings      the new settings, not {@code null}
     * @param resetSequence whether to reset the incoming sequence number to zero
     * @throws NullPointerException if {@code settings} is {@code null}
     */
    public void setInput(Settings settings, boolean resetSequence) {
        this.decryption.set(Objects.requireNonNull(settings));
        if (resetSequence) {
            this.input.sequenceNumber.set(0);
        }
    }

    /**
     * Installs new settings for encrypting outgoing packets.
     *
     * @param settings      the new settings, not {@code null}
     * @param resetSequence whether to reset the outgoing sequence number to zero
     * @throws NullPointerException if {@code settings} is {@code null}
     */
    public void setOutput(Settings settings, boolean resetSequence) {
        this.encryption.set(Objects.requireNonNull(settings));
        if (resetSequence) {
            this.output.sequenceNumber.set(0);
        }
    }

    /** @return the settings currently used to decrypt incoming packets */
    public Settings getInputSettings() {
        return this.decryption.get();
    }

    /** @return the settings currently used to encrypt outgoing packets */
    public Settings getOutputSettings() {
        return this.encryption.get();
    }

    /**
     * Returns the sequence number of the most recently received packet, as an unsigned 32-bit value.
     * Before any packet has been received this wraps to {@code 0xFFFFFFFF}.
     */
    @Override
    public long getLastInputSequenceNumber() {
        return Integer.toUnsignedLong(this.input.sequenceNumber.get() - 1);
    }

    /** Returns the sequence number expected for the next incoming packet, as an unsigned 32-bit value. */
    @Override
    public long getInputSequenceNumber() {
        return Integer.toUnsignedLong(this.input.sequenceNumber.get());
    }

    /** Returns the sequence number the next outgoing packet will use, as an unsigned 32-bit value. */
    @Override
    public long getOutputSequenceNumber() {
        return Integer.toUnsignedLong(this.output.sequenceNumber.get());
    }

    /** Returns {@code true} only if both directions use a real cipher and some integrity tag. */
    @Override
    public boolean isSecure() {
        return this.decryption.get().isSecure() && this.encryption.get().isSecure();
    }

    /**
     * Registers a listener notified before each outgoing packet is encrypted; duplicates are ignored.
     *
     * @param listener the listener, not {@code null}
     * @throws NullPointerException if {@code listener} is {@code null}
     */
    public void addEncryptionListener(EncryptionListener listener) {
        this.listeners.addIfAbsent(Objects.requireNonNull(listener));
    }

    /**
     * Unregisters a listener; {@code null} is silently ignored.
     *
     * @param listener the listener to remove
     */
    public void removeEncryptionListener(EncryptionListener listener) {
        if (listener != null) {
            this.listeners.remove(listener);
        }
    }

    /** Reassembles, verifies and decrypts incoming packets, then passes each payload on to the owner. */
    private class DecryptionHandler extends WithSequenceNumber implements InputHandler {
        /** Accumulates received bytes; always starts at the beginning of the current packet. */
        private final Buffer buffer;
        /** Length of the current packet, or {@link #UNKNOWN_PACKET_LENGTH} if not yet decoded. */
        private int packetLength;
        /** Non-null once an invalid length was seen; thrown after the discard length has been consumed. */
        private SshException discarding;

        DecryptionHandler() {
            this.buffer = new ByteArrayBuffer();
            this.packetLength = UNKNOWN_PACKET_LENGTH;
        }

        @Override
        public synchronized void received(Readable message) throws Exception {
            buffer.putBuffer(message);
            // A single chunk may contain several packets (or a partial one), so loop while at
            // least a length field's worth of data is buffered.
            while (buffer.available() >= 4) {
                Settings settings = CryptFilter.this.decryption.get();
                Cipher cipher = settings.getCipher();
                boolean isAead = cipher != null && settings.isAead();
                boolean isEtm = settings.isEtm();
                int cipherSize = cipher == null ? 8 : cipher.getCipherBlockSize();

                if (packetLength < 0) {
                    assert buffer.rpos() == 0;
                    // A plain block cipher hides the length, so its whole first block must be decrypted.
                    // In EtM and AEAD modes the 4-byte length field is enough.
                    int need = 4;
                    if (cipher != null && !isEtm && !isAead) {
                        need = cipherSize;
                    }
                    if (buffer.available() < need) {
                        break;
                    }
                    if (cipher != null) {
                        byte[] data = buffer.array();
                        if (isAead) {
                            cipher.updateAAD(data, 0, 4);
                        } else if (!isEtm) {
                            cipher.update(data, 0, need);
                        }
                    }
                    packetLength = buffer.getInt();

                    boolean lengthOK = true;
                    if (packetLength < MIN_PACKET_LENGTH || packetLength > MAX_PACKET_LENGTH) {
                        LOG.warn("received({}) Error decoding packet (invalid length): {}", CryptFilter.this.session, packetLength);
                        lengthOK = false;
                    } else if (((packetLength + (isAead || isEtm ? 0 : 4)) & (cipherSize - 1)) != 0) {
                        // The encrypted part must be a whole number of cipher blocks.
                        LOG.warn("received({}) Error decoding packet(padding; not multiple of {}): {}",
                                CryptFilter.this.session, cipherSize, packetLength);
                        lengthOK = false;
                    }
                    if (!lengthOK) {
                        // Do not fail immediately. Instead, pretend the packet spans a random number of
                        // further cipher blocks, consume that much data, and fail afterwards. Presumably this
                        // avoids revealing through timing that the length check failed.
                        // 2 == SSH_DISCONNECT_PROTOCOL_ERROR (RFC 4253, section 11.1).
                        discarding = new SshException(2, "Invalid packet length: " + packetLength);
                        packetLength = buffer.available() + (2 + CryptFilter.this.random.random(20)) * cipherSize;
                        packetLength = (packetLength + (cipherSize - 1)) & ~(cipherSize - 1);
                        if (!isAead && !isEtm) {
                            packetLength -= 4;
                        }
                        LOG.warn("received({}) Invalid packet length; requesting {} bytes before disconnecting",
                                CryptFilter.this.session, packetLength - buffer.available());
                    }
                }

                assert buffer.rpos() == 4;
                int tagSize = settings.getTagSize();
                if (buffer.available() < packetLength + tagSize) {
                    break; // wait for the rest of the packet and its tag
                }

                byte[] data = buffer.array();
                int bytes;
                if (isAead) {
                    bytes = packetLength;
                    // Presumably the AEAD cipher also verifies the tag that follows the input.
                    cipher.update(data, 4, bytes);
                } else if (isEtm) {
                    bytes = packetLength;
                    // EtM: verify the MAC over the clear length and the ciphertext, then decrypt.
                    checkMac(data, 0, bytes + 4, settings.getMac());
                    if (cipher != null) {
                        cipher.update(data, 4, bytes);
                    }
                } else {
                    bytes = packetLength + 4;
                    // The first block was already decrypted while reading the length.
                    if (cipher != null) {
                        cipher.update(data, cipherSize, bytes - cipherSize);
                    }
                    checkMac(data, 0, bytes, settings.getMac());
                }
                // NOTE(review): in discard mode with a MAC configured, checkMac normally fails first, so
                // a MAC error rather than the deferred 'discarding' exception is reported.
                if (discarding != null) {
                    throw discarding;
                }
                CryptFilter.this.inCounts.get().update(bytes / cipherSize, bytes);
                sequenceNumber.incrementAndGet();

                int endOfDataReceived = buffer.wpos();
                int afterPacket = packetLength + 4 + tagSize;
                int padding = buffer.getUByte();
                if (padding < 4) {
                    throw new SshException(2, "Invalid packet padding, must have at least 4 padding bytes according to RFC 4253, got " + padding);
                }
                int endOfPayload = packetLength + 4 - padding;
                if (endOfPayload <= buffer.rpos()) {
                    throw new SshException(2, "Invalid packet payload length " + (buffer.rpos() - endOfPayload));
                }
                // Expose exactly the payload to the next handler, then drop this packet from the buffer.
                buffer.wpos(endOfPayload);
                CryptFilter.this.owner().passOn(buffer);
                buffer.rpos(afterPacket);
                buffer.wpos(endOfDataReceived);
                buffer.compact();
                packetLength = UNKNOWN_PACKET_LENGTH;
            }
        }

        /**
         * Verifies the MAC stored right after {@code data[offset .. offset+length)}.
         * This does nothing if {@code mac} is {@code null}.
         *
         * @throws SshException with reason 5 (SSH_DISCONNECT_MAC_ERROR) on mismatch
         */
        private void checkMac(byte[] data, int offset, int length, Mac mac) throws Exception {
            if (mac != null) {
                mac.updateUInt(sequenceNumber.get());
                mac.update(data, offset, length);
                byte[] x = mac.doFinal();
                if (!Mac.equals(x, 0, data, offset + length, x.length)) {
                    throw new SshException(5, "MAC error");
                }
            }
        }
    }

    /** Frames, pads, authenticates and encrypts outgoing packets before handing them to the owner. */
    private class EncryptionHandler extends WithSequenceNumber implements OutputHandler {
        EncryptionHandler() {
        }

        /**
         * Notifies the listeners, encodes {@code message} in place and forwards it. A {@code null}
         * message is forwarded unchanged. Non-I/O failures during encoding are wrapped in an
         * {@link IOException}.
         */
        @Override
        public synchronized IoWriteFuture send(int cmd, Buffer message) throws IOException {
            Buffer encrypted = message;
            if (encrypted != null) {
                try {
                    CryptFilter.this.listeners.forEach(
                            listener -> listener.aboutToEncrypt(message, CryptFilter.this.getOutputSequenceNumber()));
                    encrypted = encode(cmd, message);
                } catch (IOException e) {
                    throw e;
                } catch (Exception e) {
                    throw new IOException(e.getMessage(), e);
                }
            }
            return CryptFilter.this.owner().send(cmd, encrypted);
        }

        /**
         * Encodes the payload {@code packet[rpos .. wpos)} in place into a complete SSH binary packet.
         * The buffer must have at least 5 spare bytes before {@code rpos} for the length and
         * padding-length header. On return, {@code rpos} is positioned at the start of the packet.
         */
        private Buffer encode(int cmd, Buffer packet) throws Exception {
            Settings settings = CryptFilter.this.encryption.get();
            Cipher cipher = settings.getCipher();
            boolean isAead = cipher != null && settings.isAead();
            boolean isEtm = settings.isEtm();
            int cipherSize = cipher == null ? 8 : cipher.getCipherBlockSize();
            int rpos = packet.rpos();
            int length = packet.available();
            int start = rpos - 5;
            if (start < 0) {
                throw new IllegalArgumentException("Message is not an SSH packet buffer; need 5 spare bytes at the front");
            }
            int pad = paddingLength(cmd, length, cipherSize, !isAead && !isEtm);
            if (pad < 4 || pad > MAX_PADDING) {
                throw new IllegalStateException("Invalid padding length computed: " + pad + " not in range [4.." + MAX_PADDING + "]");
            }
            // Header: packet_length (excludes itself and the tag), then padding_length.
            packet.wpos(start);
            packet.putUInt(1L + length + pad);
            packet.putByte((byte) pad);
            int tagSize = settings.getTagSize();
            packet.wpos(packet.wpos() + length + pad + tagSize);
            byte[] data = packet.array();
            CryptFilter.this.random.fill(data, packet.wpos() - tagSize - pad, pad);

            int bytes;
            if (isAead) {
                cipher.updateAAD(data, start, 4);
                bytes = length + pad + 1;
                cipher.update(data, start + 4, bytes);
            } else if (isEtm) {
                // EtM: encrypt everything after the length field, then MAC the result.
                bytes = length + pad + 1;
                if (cipher != null) {
                    cipher.update(data, start + 4, bytes);
                }
                appendMac(data, start, packet.wpos() - tagSize, settings.getMac());
            } else {
                // Plain: MAC the plaintext, then encrypt the whole packet including the length field.
                appendMac(data, start, packet.wpos() - tagSize, settings.getMac());
                bytes = length + pad + 5;
                if (cipher != null) {
                    cipher.update(data, start, bytes);
                }
            }
            CryptFilter.this.outCounts.get().update(bytes / cipherSize, bytes);
            sequenceNumber.incrementAndGet();
            packet.rpos(start);
            return packet;
        }

        /**
         * Computes the number of padding bytes so that the encrypted part is a multiple of
         * {@code blockSize} and at least 4 bytes of padding are used.
         *
         * @param cmd                  the SSH message code
         * @param payloadLength        payload length in bytes
         * @param blockSize            cipher block size; assumed to be a power of two
         * @param includePacketLength  whether the 4-byte length field is part of the encrypted data
         * @return the padding length, normally in {@code [4, MAX_PADDING]}
         */
        private int paddingLength(int cmd, int payloadLength, int blockSize, boolean includePacketLength) {
            int toEncrypt = payloadLength + 1; // payload plus the padding_length byte
            if (includePacketLength) {
                toEncrypt += 4;
            }
            // Message codes 60..66 get a larger minimum padding. Presumably this obscures the length of
            // user-authentication data.
            int minPadding = 4;
            if (cmd >= 60 && cmd <= 66) {
                minPadding = 64;
            }
            int pad = minPadding;
            if (cmd >= 20) {
                pad += CryptFilter.this.random.random(MAX_PADDING + 1 - minPadding);
            }
            // Round the total down to a block multiple, adding one block back if that leaves too little padding.
            pad = ((toEncrypt + pad) & ~(blockSize - 1)) - toEncrypt;
            if (pad < 4) {
                pad += blockSize;
            }
            return pad;
        }

        /**
         * Computes the MAC over {@code data[start .. end)}, prefixed by the sequence number, and writes
         * it at {@code end}. This does nothing if {@code mac} is {@code null}.
         */
        private void appendMac(byte[] data, int start, int end, Mac mac) throws Exception {
            if (mac != null) {
                mac.updateUInt(sequenceNumber.get());
                mac.update(data, start, end - start);
                mac.doFinal(data, end);
            }
        }
    }

    /**
     * Immutable pairing of a cipher and a MAC for one direction. A cipher with a non-zero
     * authentication tag is treated as AEAD and must not be combined with a MAC.
     */
    public static class Settings {
        private final Cipher cipher;
        private final Mac mac;
        private final int tagSize;
        private final boolean etm;
        private final boolean aead;

        /**
         * @param cipher the cipher, or {@code null} for none
         * @param mac    the MAC, or {@code null} for none
         * @throws IllegalStateException if an AEAD cipher is combined with a MAC
         */
        public Settings(Cipher cipher, Mac mac) {
            this.cipher = cipher;
            this.mac = mac;
            int tagSz = cipher == null ? 0 : cipher.getAuthenticationTagSize();
            this.aead = tagSz > 0;
            if (this.aead && mac != null) {
                throw new IllegalStateException("AEAD cipher " + cipher + " must not have a MAC: " + mac);
            }
            if (mac != null) {
                tagSz += mac.getBlockSize();
            }
            this.tagSize = tagSz;
            this.etm = mac != null && mac.isEncryptThenMac();
        }

        /** @return the cipher, or {@code null} */
        public Cipher getCipher() {
            return this.cipher;
        }

        /** @return the MAC, or {@code null} */
        public Mac getMac() {
            return this.mac;
        }

        /** @return number of trailing tag bytes per packet (AEAD tag or MAC output), 0 if none */
        public int getTagSize() {
            return this.tagSize;
        }

        /** @return whether a MAC is configured and uses encrypt-then-MAC */
        public boolean isEtm() {
            return this.etm;
        }

        /** @return whether the cipher provides its own authentication tag */
        public boolean isAead() {
            return this.aead;
        }

        /** @return whether a non-{@link CipherNone} cipher and some integrity tag are configured */
        public boolean isSecure() {
            return this.cipher != null && !(this.cipher instanceof CipherNone) && this.tagSize > 0;
        }
    }

    /** Thread-safe running totals of bytes, cipher blocks and packets processed in one direction. */
    public static class Counters implements CryptStatisticsProvider.Counters {
        private final AtomicLong bytes = new AtomicLong();
        private final AtomicLong blocks = new AtomicLong();
        private final AtomicLong packets = new AtomicLong();

        Counters() {
        }

        /** Adds one packet of the given size to the totals. */
        public void update(int blocks, int bytes) {
            this.blocks.addAndGet(blocks);
            this.bytes.addAndGet(bytes);
            this.packets.incrementAndGet();
        }

        @Override
        public long getBytes() {
            return this.bytes.get();
        }

        @Override
        public long getBlocks() {
            return this.blocks.get();
        }

        @Override
        public long getPackets() {
            return this.packets.get();
        }
    }

    /** Callback invoked for each outgoing message just before it is framed and encrypted. */
    public interface EncryptionListener {
        /**
         * @param buffer         the outgoing message buffer, positioned at its payload
         * @param sequenceNumber the output sequence number the packet will be sent with
         */
        void aboutToEncrypt(Readable buffer, long sequenceNumber);
    }

    /** Base for the two handlers; holds a 32-bit packet sequence number that wraps around. */
    private abstract static class WithSequenceNumber {
        final AtomicInteger sequenceNumber = new AtomicInteger();

        WithSequenceNumber() {
        }
    }
}

