/*
 * Decompiled with CFR 0.152.
 */
package org.apache.gobblin.crypto;

import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import javax.xml.bind.DatatypeConverter;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.io.IOUtils;
import org.apache.gobblin.codec.Base64Codec;
import org.apache.gobblin.codec.StreamCodec;
import org.apache.gobblin.crypto.CredentialStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RotatingAESCodec
implements StreamCodec {
    private static final Logger log = LoggerFactory.getLogger(RotatingAESCodec.class);
    private static final int AES_KEY_LEN = 16;
    private static final String TAG = "aes_rotating";
    private final Random random;
    private final CredentialStore credentialStore;
    private volatile Map<Integer, KeyRecord> keyRecords_cache;
    private volatile KeyRecord[] keyRecords_cache_arr;

    public RotatingAESCodec(CredentialStore credentialStore) {
        this.credentialStore = credentialStore;
        this.random = new Random();
    }

    public OutputStream encodeOutputStream(OutputStream origStream) throws IOException {
        return new EncodingStreamInstance(this.selectRandomKey(), origStream).wrapOutputStream();
    }

    public InputStream decodeInputStream(InputStream origStream) throws IOException {
        return new DecodingStreamInstance(origStream).wrapInputStream();
    }

    private synchronized KeyRecord getKey(Integer key) {
        this.fillKeyRecords();
        return this.keyRecords_cache.get(key);
    }

    private synchronized KeyRecord selectRandomKey() {
        KeyRecord[] keyRecords = this.getKeyRecords();
        if (keyRecords.length == 0) {
            throw new IllegalStateException("Couldn't find any valid keys in store!");
        }
        return keyRecords[this.random.nextInt(keyRecords.length)];
    }

    private synchronized KeyRecord[] getKeyRecords() {
        this.fillKeyRecords();
        return this.keyRecords_cache_arr;
    }

    private synchronized void fillKeyRecords() {
        if (this.keyRecords_cache == null) {
            this.keyRecords_cache = new HashMap<Integer, KeyRecord>();
            for (Map.Entry entry : this.credentialStore.getAllEncodedKeys().entrySet()) {
                if (((byte[])entry.getValue()).length != 16) {
                    log.debug("Skipping keyId {} because it is length {}; expected {}", new Object[]{entry.getKey(), ((byte[])entry.getValue()).length, 16});
                    continue;
                }
                try {
                    Integer keyId = Integer.parseInt((String)entry.getKey());
                    SecretKeySpec key = new SecretKeySpec((byte[])entry.getValue(), "AES");
                    this.keyRecords_cache.put(keyId, new KeyRecord(keyId, key));
                }
                catch (NumberFormatException e) {
                    log.debug("Skipping keyId {} because this algorithm can only use numeric key ids", entry.getKey());
                }
            }
            this.keyRecords_cache_arr = this.keyRecords_cache.values().toArray(new KeyRecord[this.keyRecords_cache.size()]);
        }
    }

    public String getTag() {
        return TAG;
    }

    private class DecodingStreamInstance {
        private final InputStream origStream;
        private final byte[] buffer = new byte[32];
        private final Cipher cipher;

        DecodingStreamInstance(InputStream origStream) throws IOException {
            this.origStream = origStream;
            Integer keyId = this.readKey();
            KeyRecord key = RotatingAESCodec.this.getKey(keyId);
            if (key == null) {
                throw new IOException("Cannot load key " + String.valueOf(keyId) + " which is specified in input stream");
            }
            try {
                byte[] iv = this.readIv();
                this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
                if (iv != null) {
                    IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
                    this.cipher.init(2, (Key)key.getSecretKey(), ivParameterSpec);
                } else {
                    this.cipher.init(2, key.getSecretKey());
                }
            }
            catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
                throw new IllegalStateException("Failed to load AES which should never happen", e);
            }
            catch (InvalidKeyException e) {
                throw new IllegalStateException("Failed to parse key from keystore", e);
            }
            catch (InvalidAlgorithmParameterException e) {
                throw new IllegalStateException("Failed to initialize IV", e);
            }
        }

        InputStream wrapInputStream() throws IOException {
            InputStream base64Decoder = new Base64Codec().decodeInputStream(this.origStream);
            return new CipherInputStream(base64Decoder, this.cipher);
        }

        private Integer readKey() throws IOException {
            IOUtils.readFully((InputStream)this.origStream, (byte[])this.buffer, (int)0, (int)4);
            try {
                return Integer.valueOf(new String(this.buffer, 0, 4, StandardCharsets.UTF_8));
            }
            catch (NumberFormatException e) {
                throw new IOException("Expected to be able to parse first 4 bytes of stream as an ASCII keyId");
            }
        }

        private byte[] readIv() throws IOException {
            Integer ivLen;
            IOUtils.readFully((InputStream)this.origStream, (byte[])this.buffer, (int)0, (int)3);
            try {
                ivLen = Integer.valueOf(new String(this.buffer, 0, 3, StandardCharsets.UTF_8));
            }
            catch (NumberFormatException e) {
                throw new IOException("Expected to parse next 3 bytes of stream as an IV len");
            }
            if (ivLen < 0 || ivLen > this.buffer.length) {
                throw new IOException("Corrupted data suspected; expected IVLen to be between 0 and " + String.valueOf(this.buffer.length) + ", read " + String.valueOf(ivLen));
            }
            if (ivLen == 0) {
                return null;
            }
            byte[] ivBuffer = new byte[ivLen.intValue()];
            IOUtils.readFully((InputStream)this.origStream, (byte[])ivBuffer, (int)0, (int)ivBuffer.length);
            return Base64.decodeBase64((byte[])ivBuffer);
        }
    }

    static class EncodingStreamInstance {
        private final OutputStream origStream;
        private final KeyRecord secretKey;
        private Cipher cipher;
        private String base64Iv;
        private boolean headerWritten = false;

        EncodingStreamInstance(KeyRecord secretKey, OutputStream origStream) {
            this.secretKey = secretKey;
            this.origStream = origStream;
        }

        OutputStream wrapOutputStream() throws IOException {
            this.initCipher();
            OutputStream base64OutputStream = this.getBase64Stream(this.origStream);
            final CipherOutputStream encryptedStream = new CipherOutputStream(base64OutputStream, this.cipher);
            return new FilterOutputStream(this.origStream){

                @Override
                public void write(int b) throws IOException {
                    this.writeHeaderIfNecessary();
                    encryptedStream.write(b);
                }

                @Override
                public void write(byte[] b) throws IOException {
                    this.writeHeaderIfNecessary();
                    encryptedStream.write(b);
                }

                @Override
                public void write(byte[] b, int off, int len) throws IOException {
                    this.writeHeaderIfNecessary();
                    encryptedStream.write(b, off, len);
                }

                @Override
                public void close() throws IOException {
                    encryptedStream.close();
                }
            };
        }

        private OutputStream getBase64Stream(OutputStream origStream) throws IOException {
            return new Base64Codec().encodeOutputStream(origStream);
        }

        private void initCipher() {
            if (this.origStream == null) {
                throw new IllegalStateException("Can't initCipher stream before encodeOutputStream() has been called!");
            }
            try {
                this.cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
                this.cipher.init(1, this.secretKey.getSecretKey());
                byte[] iv = this.cipher.getIV();
                this.base64Iv = DatatypeConverter.printBase64Binary((byte[])iv);
                this.headerWritten = false;
            }
            catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
                throw new IllegalStateException("Error creating AES algorithm? Should always exist in JRE");
            }
            catch (InvalidKeyException e) {
                throw new IllegalStateException("Key " + this.secretKey.getKeyId() + " is illegal - please check credential store");
            }
        }

        private void writeHeaderIfNecessary() throws IOException {
            if (!this.headerWritten) {
                String header = String.format("%04d%03d%s", this.secretKey.getKeyId(), this.base64Iv.length(), this.base64Iv);
                this.origStream.write(header.getBytes(StandardCharsets.UTF_8));
                this.headerWritten = true;
            }
        }
    }

    static class KeyRecord {
        private final int keyId;
        private final SecretKey secretKey;

        KeyRecord(int keyId, SecretKey secretKey) {
            this.keyId = keyId;
            this.secretKey = secretKey;
        }

        int getKeyId() {
            return this.keyId;
        }

        SecretKey getSecretKey() {
            return this.secretKey;
        }
    }
}

