package net.deweger.crypto;

import java.io.*;
import java.math.BigInteger;

// Basic AES-128 functionality
// Benne de Weger - TU/e - January 2021

// Counter mode uses only the first 8 bytes of the IV as nonce,
// and concatenates an 8-byte big-endian block counter (starting at 0) to it.

/**
 * Basic AES-128 block cipher with ECB, CBC, CFB, OFB and CTR chaining and
 * optional PKCS#7 or bit padding.
 *
 * <p>After construction the mode is {@link #MODE_ECB}, the padding is
 * {@link #PAD_NONE}, and neither key nor IV is set.
 *
 * <p>In CTR mode only the first 8 bytes of the IV are used as nonce; the last
 * 8 bytes of each counter block hold a big-endian block counter starting at 0.
 *
 * <p>Not thread-safe: block encryption and decryption work on the instance
 * field {@code state}, so one instance must not be used by several threads at
 * the same time without external synchronization.
 */
public class AES
{
    // valid range (inclusive) of mode values
    private static final int MODE_MIN = 0;
    public static final int MODE_ECB = 0;
    public static final int MODE_CBC = 1;
    public static final int MODE_CFB = 2;
    public static final int MODE_OFB = 3;
    public static final int MODE_CTR = 4;
    private static final int MODE_MAX = 4;

    // valid range (inclusive) of padding values
    private static final int PAD_MIN = 10;
    public static final int PAD_NONE = 10;
    public static final int PAD_PKCS7 = 11;
    public static final int PAD_BIT = 12;
    private static final int PAD_MAX = 12;

    // private copies of the key and iv; empty until set
    private byte[] key = new byte[0];
    private byte[] iv = new byte[0];

    // expanded key: 44 words of 4 bytes, four words per round key 0..10
    private byte[][] roundKeys;
    // 4x4 cipher state, indexed [row][column]
    private byte[][] state;

    private int mode = MODE_MIN;
    private int padding = PAD_MIN;

    private byte[] sBox;
    private byte[] invsBox;
    private byte[] rcon;

    Util util;

    /**
     * Creates a cipher in ECB mode without padding and decodes the S-box,
     * inverse S-box and round-constant tables.
     *
     * @throws Exception propagated from the {@code Util} helpers
     */
    public AES() throws Exception
    {
        util = new Util();
        sBox = Util.hexToBytes(
            "637c777bf26b6fc53001672bfed7ab76ca82c97dfa5947f0add4a2af9ca472c0" +
            "b7fd9326363ff7cc34a5e5f171d8311504c723c31896059a071280e2eb27b275" +
            "09832c1a1b6e5aa0523bd6b329e32f8453d100ed20fcb15b6acbbe394a4c58cf" +
            "d0efaafb434d338545f9027f503c9fa851a3408f929d38f5bcb6da2110fff3d2" +
            "cd0c13ec5f974417c4a77e3d645d197360814fdc222a908846eeb814de5e0bdb" +
            "e0323a0a4906245cc2d3ac629195e479e7c8376d8dd54ea96c56f4ea657aae08" +
            "ba78252e1ca6b4c6e8dd741f4bbd8b8a703eb5664803f60e613557b986c11d9e" +
            "e1f8981169d98e949b1e87e9ce5528df8ca1890dbfe6426841992d0fb054bb16");
        invsBox = Util.hexToBytes(
            "52096ad53036a538bf40a39e81f3d7fb7ce339829b2fff87348e4344c4dee9cb" +
            "547b9432a6c2233dee4c950b42fac34e082ea16628d924b2765ba2496d8bd125" +
            "72f8f66486689816d4a45ccc5d65b6926c704850fdedb9da5e154657a78d9d84" +
            "90d8ab008cbcd30af7e45805b8b34506d02c1e8fca3f0f02c1afbd0301138a6b" +
            "3a9111414f67dcea97f2cfcef0b4e67396ac7422e7ad3585e2f937e81c75df6e" +
            "47f11a711d29c5896fb7620eaa18be1bfc563e4bc6d279209adbc0fe78cd5af4" +
            "1fdda8338807c731b11210592780ec5f60517fa919b54a0d2de57a9f93c99cef" +
            "a0e03b4dae2af5b0c8ebbb3c83539961172b047eba77d626e169146355210c7d");
        rcon = Util.hexToBytes("01020408102040801b36");
    }

    /**
     * Sets the 16-byte key and expands it into the round keys.
     *
     * @param k the key; a private copy is stored
     * @throws Exception if {@code k} is not exactly 16 bytes long
     */
    public void setKey(byte[] k) throws Exception
    {
        if (k.length != 16)
            throw new Exception("net.deweger.crypto.AES.setKey: illegal key length: " + k.length);
        // copy so later changes to the caller's array cannot make the stored
        // key disagree with the already expanded round keys
        key = k.clone();
        keyExpansion();
    }

    /**
     * Sets the 16-byte IV used by the CBC, CFB, OFB and CTR modes.
     *
     * @param i the IV; a private copy is stored
     * @throws Exception if {@code i} is not exactly 16 bytes long
     */
    public void setIv(byte[] i) throws Exception
    {
        if (i.length != 16)
            throw new Exception("net.deweger.crypto.AES.setIv: illegal iv length: " + i.length);
        iv = i.clone();
    }

    /**
     * Selects the chaining mode.
     *
     * @param m one of {@link #MODE_ECB}, {@link #MODE_CBC}, {@link #MODE_CFB},
     *          {@link #MODE_OFB}, {@link #MODE_CTR}
     * @throws Exception if {@code m} is not a known mode
     */
    public void setMode(int m) throws Exception
    {
        // validate the argument itself (previously the current mode was tested)
        if (m < MODE_MIN || m > MODE_MAX)
            throw new Exception("net.deweger.crypto.AES.setMode: illegal mode: " + m);
        mode = m;
    }

    /**
     * Selects the padding scheme.
     *
     * @param p one of {@link #PAD_NONE}, {@link #PAD_PKCS7}, {@link #PAD_BIT}
     * @throws Exception if {@code p} is not a known padding scheme
     */
    public void setPadding(int p) throws Exception
    {
        if (p < PAD_MIN || p > PAD_MAX)
            throw new Exception("net.deweger.crypto.AES.setPadding: illegal padding: " + p);
        padding = p;
    }

    /**
     * Generates a random 16-byte key via {@code Util.randomBytes} and sets it.
     *
     * @return the generated key
     * @throws Exception if generation fails or yields a key of the wrong length
     */
    public byte[] generateKey() throws Exception
    {
        byte[] k = util.randomBytes(16);
        setKey(k);
        return k;
    }

    /**
     * Generates a random 16-byte IV via {@code Util.randomBytes} and sets it.
     *
     * @return the generated IV
     * @throws Exception if generation fails or yields an IV of the wrong length
     */
    public byte[] generateIv() throws Exception
    {
        byte[] i = util.randomBytes(16);
        setIv(i);
        return i;
    }

    /**
     * Returns a copy of the current key.
     *
     * @throws Exception if no key has been set
     */
    public byte[] getKey() throws Exception
    {
        if (key.length != 16)
            throw new Exception("net.deweger.crypto.AES.getKey: illegal key length: " + key.length);
        return key.clone();
    }

    /**
     * Returns a copy of the current IV.
     *
     * @throws Exception if no IV has been set
     */
    public byte[] getIv() throws Exception
    {
        if (iv.length != 16)
            throw new Exception("net.deweger.crypto.AES.getIv: illegal iv length: " + iv.length);
        return iv.clone();
    }

    /**
     * Returns the current chaining mode.
     *
     * @throws Exception if the stored mode is out of range
     */
    public int getMode() throws Exception
    {
        if (mode < MODE_MIN || mode > MODE_MAX)
            throw new Exception("net.deweger.crypto.AES.getMode: illegal mode: " + mode);
        return mode;
    }

    /**
     * Returns the current padding scheme.
     *
     * @throws Exception if the stored padding value is out of range
     */
    public int getPadding() throws Exception
    {
        if (padding < PAD_MIN || padding > PAD_MAX)
            throw new Exception("net.deweger.crypto.AES.getPadding: illegal padding: " + padding);
        return padding;
    }

    // Expands the 16-byte key into 44 four-byte words: words 0..3 are the key
    // itself; each word whose index is a multiple of 4 applies RotWord, SubWord
    // and the round constant to the previous word before xoring with word i-4.
    private void keyExpansion()
    {
        roundKeys = new byte[44][4];
        for (int i = 0; i <= 3; i++)
            for (int j = 0; j <= 3; j++)
                roundKeys[i][j] = key[4*i+j];
        for (int i = 4; i <= 43; i++)
        {
            if (i % 4 == 0)
            {
                roundKeys[i][0] = (byte)(sBox[Byte.toUnsignedInt(roundKeys[i-1][1])] ^ roundKeys[i-4][0] ^ rcon[i/4-1]);
                roundKeys[i][1] = (byte)(sBox[Byte.toUnsignedInt(roundKeys[i-1][2])] ^ roundKeys[i-4][1]);
                roundKeys[i][2] = (byte)(sBox[Byte.toUnsignedInt(roundKeys[i-1][3])] ^ roundKeys[i-4][2]);
                roundKeys[i][3] = (byte)(sBox[Byte.toUnsignedInt(roundKeys[i-1][0])] ^ roundKeys[i-4][3]);
            }
            else
            {
                for (int j = 0; j <= 3; j++)
                    roundKeys[i][j] = (byte)(roundKeys[i-1][j] ^ roundKeys[i-4][j]);
            }
        }
    }

    // substitutes every state byte through the S-box
    private void subBytes()
    {
        for (int r = 0; r <= 3; r++)
            for (int c = 0; c <= 3; c++)
                state[r][c] = sBox[Byte.toUnsignedInt(state[r][c])];
    }

    // substitutes every state byte through the inverse S-box
    private void subBytesInverse()
    {
        for (int r = 0; r <= 3; r++)
            for (int c = 0; c <= 3; c++)
                state[r][c] = invsBox[Byte.toUnsignedInt(state[r][c])];
    }

    // rotates row r of the state left by r positions
    private void shiftRows()
    {
        byte aux = state[1][0];
        state[1][0] = state[1][1];
        state[1][1] = state[1][2];
        state[1][2] = state[1][3];
        state[1][3] = aux;
        aux = state[2][0];
        state[2][0] = state[2][2];
        state[2][2] = aux;
        aux = state[2][1];
        state[2][1] = state[2][3];
        state[2][3] = aux;
        aux = state[3][0];
        state[3][0] = state[3][3];
        state[3][3] = state[3][2];
        state[3][2] = state[3][1];
        state[3][1] = aux;
    }

    // rotates row r of the state right by r positions (undoes shiftRows)
    private void shiftRowsInverse()
    {
        byte aux = state[1][0];
        state[1][0] = state[1][3];
        state[1][3] = state[1][2];
        state[1][2] = state[1][1];
        state[1][1] = aux;
        aux = state[2][0];
        state[2][0] = state[2][2];
        state[2][2] = aux;
        aux = state[2][1];
        state[2][1] = state[2][3];
        state[2][3] = aux;
        aux = state[3][0];
        state[3][0] = state[3][1];
        state[3][1] = state[3][2];
        state[3][2] = state[3][3];
        state[3][3] = aux;
    }

    // Multiplies b by n in GF(2^8) reduced by x^8+x^4+x^3+x+1 (0x1b).
    // Only the low 4 bits of n (non-negative) are used; the constants 0x1b,
    // 0x36 and 0x6c reduce the bits shifted out by 1, 2 and 3 positions.
    private byte mul(byte b, int n)
    {
        byte x = 0;
        if (n % 2 == 1)
            x = b;
        if ((n/2) % 2 == 1)
        {
            x = (byte)(x ^ (b << 1));
            if ((b & 0x80) != 0)
                x = (byte)(x ^ 0x1b);
        }
        if ((n/4) % 2 == 1)
        {
            x = (byte)(x ^ (b << 2));
            if ((b & 0x80) != 0)
                x = (byte)(x ^ 0x36);
            if ((b & 0x40) != 0)
                x = (byte)(x ^ 0x1b);
        }
        if ((n/8) % 2 == 1)
        {
            x = (byte)(x ^ (b << 3));
            if ((b & 0x80) != 0)
                x = (byte)(x ^ 0x6c);
            if ((b & 0x40) != 0)
                x = (byte)(x ^ 0x36);
            if ((b & 0x20) != 0)
                x = (byte)(x ^ 0x1b);
        }
        return x;
    }

    // multiplies each state column by the circulant matrix (2 3 1 1)
    private void mixColumns()
    {
        byte s[] = new byte[4];
        for (int c = 0; c <= 3; c++)
        {
            s[0] = (byte)(mul(state[0][c],2) ^ mul(state[1][c],3) ^ state[2][c] ^ state[3][c]);
            s[1] = (byte)(state[0][c] ^ mul(state[1][c],2) ^ mul(state[2][c],3) ^ state[3][c]);
            s[2] = (byte)(state[0][c] ^ state[1][c] ^ mul(state[2][c],2) ^ mul(state[3][c],3));
            s[3] = (byte)(mul(state[0][c],3) ^ state[1][c] ^ state[2][c] ^ mul(state[3][c],2));
            state[0][c] = s[0];
            state[1][c] = s[1];
            state[2][c] = s[2];
            state[3][c] = s[3];
        }
    }

    // multiplies each state column by the inverse matrix (14 11 13 9)
    private void mixColumnsInverse()
    {
        byte s[] = new byte[4];
        for (int c = 0; c <= 3; c++)
        {
            s[0] = (byte)(mul(state[0][c],14) ^ mul(state[1][c],11) ^ mul(state[2][c],13) ^ mul(state[3][c],9));
            s[1] = (byte)(mul(state[0][c],9) ^ mul(state[1][c],14) ^ mul(state[2][c],11) ^ mul(state[3][c],13));
            s[2] = (byte)(mul(state[0][c],13) ^ mul(state[1][c],9) ^ mul(state[2][c],14) ^ mul(state[3][c],11));
            s[3] = (byte)(mul(state[0][c],11) ^ mul(state[1][c],13) ^ mul(state[2][c],9) ^ mul(state[3][c],14));
            state[0][c] = s[0];
            state[1][c] = s[1];
            state[2][c] = s[2];
            state[3][c] = s[3];
        }
    }

    // xors round key `round` (words 4*round .. 4*round+3, one per column) into the state
    private void addRoundKey(int round)
    {
        for (int r = 0; r <= 3; r++)
            for (int c = 0; c <= 3; c++)
                state[r][c] = (byte)(state[r][c] ^ roundKeys[4*round+c][r]);
    }

    // loads a 16-byte block into the state column by column
    private void loadState(byte[] b)
    {
        state = new byte[4][4];
        for (int r = 0; r <= 3; r++)
            for (int c = 0; c <= 3; c++)
                state[r][c] = b[r+4*c];
    }

    // returns the state as a new 16-byte block, column by column
    private byte[] storeState()
    {
        byte[] out = new byte[16];
        for (int r = 0; r <= 3; r++)
            for (int c = 0; c <= 3; c++)
                out[r+4*c] = state[r][c];
        return out;
    }

    // encrypts one 16-byte block with the 10-round AES-128 cipher; the input is not modified
    private byte[] encryptBlock(byte[] b) throws Exception
    {
        if (b.length != 16)
            throw new Exception("net.deweger.crypto.AES.encryptBlock: illegal block size: " + b.length);
        if (key.length != 16)
            throw new Exception("net.deweger.crypto.AES.encryptBlock: illegal key size: " + key.length);
        loadState(b);
        addRoundKey(0);
        for (int i = 1; i <= 10; i++)
        {
            subBytes();
            shiftRows();
            // the final round omits MixColumns
            if (i < 10)
                mixColumns();
            addRoundKey(i);
        }
        return storeState();
    }

    // decrypts one 16-byte block (inverse of encryptBlock); the input is not modified
    private byte[] decryptBlock(byte[] b) throws Exception
    {
        if (b.length != 16)
            throw new Exception("net.deweger.crypto.AES.decryptBlock: illegal block size: " + b.length);
        if (key.length != 16)
            throw new Exception("net.deweger.crypto.AES.decryptBlock: illegal key size: " + key.length);
        loadState(b);
        addRoundKey(10);
        for (int i = 10; i >= 1; i--)
        {
            shiftRowsInverse();
            subBytesInverse();
            addRoundKey(i-1);
            // the first inverted round (the final encryption round) has no MixColumns
            if (i > 1)
                mixColumnsInverse();
        }
        return storeState();
    }

    // throws unless a 16-byte IV has been set; `method` names the caller in the message
    private void checkIv(String method) throws Exception
    {
        if (iv.length != 16)
            throw new Exception("net.deweger.crypto.AES." + method + ": illegal iv length: " + iv.length);
    }

    /**
     * Encrypts {@code plaintext} of any length with the current key, mode and
     * padding.
     *
     * <p>PKCS#7 and bit padding always add 1..16 bytes (a full block when the
     * input is already block-aligned); with {@link #PAD_NONE} the plaintext
     * length must be a multiple of 16.
     *
     * @param plaintext the data to encrypt (not modified)
     * @return the ciphertext, as long as the padded plaintext
     * @throws Exception if the padded length is not a multiple of 16, or the
     *                   key or (for non-ECB modes) the IV is not set
     */
    public byte[] encrypt(byte[] plaintext) throws Exception
    {
        // padding length: 1..16 bytes for PKCS#7 / bit padding, none otherwise
        int len = plaintext.length;
        int padlen = (padding == PAD_NONE) ? 0 : 16 - (len % 16);
        int paddedlength = len + padlen;

        // apply padding
        byte[] paddedplaintext = new byte[paddedlength];
        System.arraycopy(plaintext, 0, paddedplaintext, 0, len);
        if (padding == PAD_PKCS7)
        {
            for (int i = 1; i <= padlen; i++)
                paddedplaintext[paddedlength - i] = (byte)padlen;
        }
        else if (padding == PAD_BIT)
        {
            // a single 0x80 byte; the trailing zero bytes are already present
            // because new Java arrays are zero-initialised
            paddedplaintext[len] = (byte)0x80;
        }

        if (paddedlength % 16 != 0)
            throw new Exception("net.deweger.crypto.AES.encrypt: illegal padded plaintext length: " + paddedlength);

        // do the chaining
        byte[] ciphertext = new byte[paddedlength];
        byte[] plaintextBlock = new byte[16];
        if (mode == MODE_ECB)
        {
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                System.arraycopy(paddedplaintext, offset, plaintextBlock, 0, 16);
                System.arraycopy(encryptBlock(plaintextBlock), 0, ciphertext, offset, 16);
            }
        }
        else if (mode == MODE_CBC)
        {
            checkIv("encrypt");
            byte[] chainingBlock = iv.clone();
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                for (int i = 0; i < 16; i++)
                    plaintextBlock[i] = (byte)(paddedplaintext[offset + i] ^ chainingBlock[i]);
                chainingBlock = encryptBlock(plaintextBlock);
                System.arraycopy(chainingBlock, 0, ciphertext, offset, 16);
            }
        }
        else if (mode == MODE_CFB)
        {
            checkIv("encrypt");
            // the previous ciphertext block (initially the IV) is encrypted
            // to form the keystream for the next block
            byte[] chainingBlock = iv.clone();
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                chainingBlock = encryptBlock(chainingBlock);
                for (int i = 0; i < 16; i++)
                    chainingBlock[i] ^= paddedplaintext[offset + i];
                System.arraycopy(chainingBlock, 0, ciphertext, offset, 16);
            }
        }
        else if (mode == MODE_OFB)
        {
            checkIv("encrypt");
            byte[] chainingBlock = iv.clone();
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                chainingBlock = encryptBlock(chainingBlock);
                for (int i = 0; i < 16; i++)
                    ciphertext[offset + i] = (byte)(paddedplaintext[offset + i] ^ chainingBlock[i]);
            }
        }
        else if (mode == MODE_CTR)
        {
            checkIv("encrypt");
            // counter block = 8-byte nonce from the IV || big-endian counter
            byte[] counterBlock = new byte[16];
            System.arraycopy(iv, 0, counterBlock, 0, 8);
            BigInteger counter = BigInteger.ZERO;
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                byte[] ctr = counter.toByteArray();
                System.arraycopy(ctr, 0, counterBlock, 16 - ctr.length, ctr.length);
                byte[] keystream = encryptBlock(counterBlock);
                for (int i = 0; i < 16; i++)
                    ciphertext[offset + i] = (byte)(paddedplaintext[offset + i] ^ keystream[i]);
                counter = counter.add(BigInteger.ONE);
            }
        }
        else
            throw new Exception("net.deweger.crypto.AES.encrypt: illegal mode: " + mode);

        return ciphertext;
    }

    /**
     * Decrypts {@code ciphertext} with the current key, mode and padding and
     * strips the padding.
     *
     * @param ciphertext the data to decrypt (not modified); its length must be
     *                   a multiple of 16
     * @return the plaintext with padding removed
     * @throws Exception if the ciphertext length is not a multiple of 16, the
     *                   key or (for non-ECB modes) the IV is not set, or the
     *                   padding is malformed
     */
    public byte[] decrypt(byte[] ciphertext) throws Exception
    {
        int paddedlength = ciphertext.length;
        if (paddedlength % 16 != 0)
            throw new Exception("net.deweger.crypto.AES.decrypt: illegal ciphertext length: " + paddedlength);

        // do the chaining
        byte[] paddedplaintext = new byte[paddedlength];
        byte[] ciphertextBlock = new byte[16];
        if (mode == MODE_ECB)
        {
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                System.arraycopy(ciphertext, offset, ciphertextBlock, 0, 16);
                System.arraycopy(decryptBlock(ciphertextBlock), 0, paddedplaintext, offset, 16);
            }
        }
        else if (mode == MODE_CBC)
        {
            checkIv("decrypt");
            byte[] chainingBlock = iv.clone();
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                System.arraycopy(ciphertext, offset, ciphertextBlock, 0, 16);
                byte[] plaintextBlock = decryptBlock(ciphertextBlock);
                for (int i = 0; i < 16; i++)
                    paddedplaintext[offset + i] = (byte)(plaintextBlock[i] ^ chainingBlock[i]);
                System.arraycopy(ciphertextBlock, 0, chainingBlock, 0, 16);
            }
        }
        else if (mode == MODE_CFB)
        {
            checkIv("decrypt");
            byte[] chainingBlock = iv.clone();
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                byte[] keystream = encryptBlock(chainingBlock);
                for (int i = 0; i < 16; i++)
                    paddedplaintext[offset + i] = (byte)(keystream[i] ^ ciphertext[offset + i]);
                // the next keystream block is derived from this ciphertext block
                System.arraycopy(ciphertext, offset, chainingBlock, 0, 16);
            }
        }
        else if (mode == MODE_OFB)
        {
            checkIv("decrypt");
            byte[] chainingBlock = iv.clone();
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                chainingBlock = encryptBlock(chainingBlock);
                for (int i = 0; i < 16; i++)
                    paddedplaintext[offset + i] = (byte)(chainingBlock[i] ^ ciphertext[offset + i]);
            }
        }
        else if (mode == MODE_CTR)
        {
            checkIv("decrypt");
            // counter block = 8-byte nonce from the IV || big-endian counter
            byte[] counterBlock = new byte[16];
            System.arraycopy(iv, 0, counterBlock, 0, 8);
            BigInteger counter = BigInteger.ZERO;
            for (int offset = 0; offset < paddedlength; offset += 16)
            {
                byte[] ctr = counter.toByteArray();
                System.arraycopy(ctr, 0, counterBlock, 16 - ctr.length, ctr.length);
                byte[] keystream = encryptBlock(counterBlock);
                for (int i = 0; i < 16; i++)
                    paddedplaintext[offset + i] = (byte)(keystream[i] ^ ciphertext[offset + i]);
                counter = counter.add(BigInteger.ONE);
            }
        }
        else
            throw new Exception("net.deweger.crypto.AES.decrypt: illegal mode: " + mode);

        // find padded length
        // NOTE(review): distinct padding exceptions can act as a padding oracle
        // in CBC mode if they are observable by an attacker
        int padlen = 0;
        if (padding == PAD_PKCS7)
        {
            // an empty message cannot carry PKCS#7 padding (avoid index -1)
            if (paddedlength == 0)
                throw new Exception("net.deweger.crypto.AES.decrypt: illegal padding (PKCS7)");
            // with 16-byte blocks the pad length is 1..16 and every pad byte equals it
            padlen = paddedplaintext[paddedlength - 1];
            if (padlen < 1 || padlen > 16)
                throw new Exception("net.deweger.crypto.AES.decrypt: illegal padding (PKCS7)");
            for (int i = 1; i <= padlen; i++)
                if (paddedplaintext[paddedlength - i] != padlen)
                    throw new Exception("net.deweger.crypto.AES.decrypt: illegal padding (PKCS7)");
        }
        else if (padding == PAD_BIT)
        {
            // an empty message cannot carry bit padding (avoid index -1)
            if (paddedlength == 0)
                throw new Exception("net.deweger.crypto.AES.decrypt: illegal padding (BIT)");
            // skip at most 15 trailing zero bytes, then expect the 0x80 marker
            padlen = 1;
            while (padlen < 16 && paddedplaintext[paddedlength - padlen] == 0)
                padlen++;
            if (paddedplaintext[paddedlength - padlen] != (byte)0x80)
                throw new Exception("net.deweger.crypto.AES.decrypt: illegal padding (BIT)");
        }

        // remove padding
        byte[] plaintext = new byte[paddedlength - padlen];
        System.arraycopy(paddedplaintext, 0, plaintext, 0, plaintext.length);
        return plaintext;
    }
}