// RSA
// Benne de Weger - TU/e - February 2021

// FAECTOR Workshop Cryptographic Programming
// Assignment 6, 7, 8, 10, 11, 12 - API:
//    // constructors
//    public RSA() throws Exception
//    public RSA(String id) throws Exception
//    // generate RSA key pair,
//    //     b = number of bits in the modulus
//    //     fermat4: e = 65537 or big random
//    //     crt: use CRT private key or not
//    public void generateRSAKeyPair(int b, boolean fermat4, boolean crt)
//    // write RSA public key to file fn
//    public void writeRSAPublicKey(String fn) throws Exception
//    // write RSA private key to file fn
//    public void writeRSAPrivateKey(String fn) throws Exception
//    // write encrypted RSA private key to file fn
//    public void secureWriteRSAPrivateKey(String password, String fn) throws Exception
//    // read RSA public key from file fn
//    public void readRSAPublicKey(String fn) throws Exception
//    // read RSA private key from file fn
//    public void readRSAPrivateKey(String fn) throws Exception
//    // read encrypted RSA private key from file fn
//    public void secureReadRSAPrivateKey(String password, String fn) throws Exception
//    // RSA encryption with PKCS1v1.5-padding
//    public byte[] encryptKey(byte[] plaintext) throws Exception
//    // RSA decryption with PKCS1v1.5-padding
//    public byte[] decryptKey(byte[] ciphertext) throws Exception
//    // RSA signature generation
//    public byte[] generateSignature(byte[] message) throws Exception
//    // RSA signature verification
//    public boolean verifySignature(byte[] message, byte[] signature) throws Exception
//    // computes p and q from n, d and e, obviously only in the non-CRT case
//    public void factorFromKeyPair() throws Exception

import net.deweger.crypto.*;
import java.math.BigInteger;
import java.util.List;
import java.util.ArrayList;
import java.nio.file.*;
import java.nio.charset.StandardCharsets;

public class RSA
{
    // DER-encoded DigestInfo prefix for SHA-256, used in PKCS#1 v1.5 signatures
    private static final String SHA256_DIGEST_INFO = "3031300d060960864801650304020105000420";
    // Fermat number F4 = 65537, the standard small public exponent
    private static final BigInteger F4 = BigInteger.valueOf(65537);
    // separator written after every serialized line inside an encrypted private key file
    private static final String LINE_SEPARATOR = "@#";
    // PBE iteration count for encrypted private key files
    // NOTE(review): 1000 iterations with the identity as salt is weak by current standards, but both
    //   are part of the file format; changing them would make existing key files unreadable.
    private static final int PBE_ITERATION_COUNT = 1000;
    // smallest modulus size accepted by generateRSAKeyPair; for smaller sizes two distinct primes
    //   whose product has exactly b bits may not exist, and generation would never terminate
    private static final int MIN_MODULUS_BITS = 16;

    private String id = "";
    private BigInteger p = BigInteger.ONE;
    private BigInteger q = BigInteger.ONE;
    private BigInteger n = BigInteger.ONE;
    private BigInteger e = BigInteger.ZERO;
    private BigInteger d = BigInteger.ZERO;
    private BigInteger dp = BigInteger.ZERO;
    private BigInteger dq = BigInteger.ZERO;
    private BigInteger u = BigInteger.ZERO;
    boolean crt = false;
    int modulusBitSize = 0;

    Util util;
    SHA256 sha256;
    PBE pbe;

    // constructor: RSA object with an empty identity
    public RSA() throws Exception
    {
        this("");
    }

    // constructor: RSA object bound to identity id; key files read later must carry the same identity
    public RSA(String id) throws Exception
    {
        this.id = id;
        util = new Util();
        sha256 = new SHA256();
        pbe = new PBE();
    }

    // Assignment 6, 12
    // Generate a fresh RSA key pair and store it in this object.
    //   b       = number of bits in the modulus n (at least MIN_MODULUS_BITS)
    //   fermat4 = true: e = 65537; false: e is random with 1 < e < phi(n) and gcd(e, phi(n)) = 1
    //   crt     = true: also compute the CRT private components dp, dq and u = q^{-1} mod p
    // Throws IllegalArgumentException when b is too small.
    public void generateRSAKeyPair(int b, boolean fermat4, boolean crt)
    {
        if (b < MIN_MODULUS_BITS)
            throw new IllegalArgumentException("RSA.generateRSAKeyPair: modulus size " + b +
                                               " bits too small (minimum " + MIN_MODULUS_BITS + ")");
        modulusBitSize = b;
        this.crt = crt;
        int primeBits = (b + 1)/2;

        // primes, modulus and phi(n)
        // Always draw new primes (do-while), so a repeated call does not silently keep the old key.
        // Reject p == q. For e = 65537, also reject primes where 65537 divides p-1 or q-1,
        // because then e has no inverse modulo phi(n).
        BigInteger phi;
        do
        {
            p = BigInteger.probablePrime(primeBits, util.secureRandom);
            q = BigInteger.probablePrime(primeBits, util.secureRandom);
            n = p.multiply(q);
            phi = p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));
        }
        while (n.bitLength() != b || p.equals(q) || (fermat4 && phi.mod(F4).signum() == 0));

        // public exponent e
        if (fermat4)
            e = F4;
        else
        {
            do
            {
                e = new BigInteger(b, util.secureRandom);
            }
            while (e.compareTo(BigInteger.ONE) <= 0 || e.compareTo(phi) >= 0 ||
                   !e.gcd(phi).equals(BigInteger.ONE));
        }

        // private exponent d, and the CRT components if requested
        d = e.modInverse(phi);
        if (crt)
        {
            dp = d.mod(p.subtract(BigInteger.ONE));
            dq = d.mod(q.subtract(BigInteger.ONE));
            u = q.modInverse(p);
        }
    }

    // Assignment 6
    // Write the RSA public key (identity, n, e) to file fn, replacing any existing content.
    public void writeRSAPublicKey(String fn) throws Exception
    {
        List<String[]> key = new ArrayList<String[]>();
        key.add(new String[] {"[id] (identity)", id});
        key.add(new String[] {"[n] (modulus)", Util.bigIntegerToHex64(n)});
        key.add(new String[] {"[e] (public exponent)", Util.bigIntegerToHex64(e)});
        writeLines(fn, toLines(key));
    }

    // Assignment 6, 12
    // Build the lines of the private key file.
    // CRT key: identity, p, q, dp, dq, u. Non-CRT key: identity, n, d.
    private List<String> prepareRSAPrivateKey() throws Exception
    {
        List<String[]> key = new ArrayList<String[]>();
        key.add(new String[] {"[id] (identity)", id});
        if (crt)
        {
            key.add(new String[] {"[p] (first prime)", Util.bigIntegerToHex64(p)});
            key.add(new String[] {"[q] (second prime)", Util.bigIntegerToHex64(q)});
            key.add(new String[] {"[dp] (private exponent modulo p-1)", Util.bigIntegerToHex64(dp)});
            key.add(new String[] {"[dq] (private exponent modulo q-1)", Util.bigIntegerToHex64(dq)});
            key.add(new String[] {"[u] (inverse of q modulo p)", Util.bigIntegerToHex64(u)});
        }
        else
        {
            key.add(new String[] {"[n] (modulus)", Util.bigIntegerToHex64(n)});
            key.add(new String[] {"[d] (private exponent)", Util.bigIntegerToHex64(d)});
        }
        return toLines(key);
    }

    // Assignment 6, 12
    // Write the RSA private key (unencrypted) to file fn, replacing any existing content.
    public void writeRSAPrivateKey(String fn) throws Exception
    {
        writeLines(fn, prepareRSAPrivateKey());
    }

    // Assignment 7, 12
    // Write the RSA private key to file fn, encrypted under password.
    // The key lines are joined, each followed by "@#", then PBE-encrypted and written as hex.
    public void secureWriteRSAPrivateKey(String password, String fn) throws Exception
    {
        StringBuilder serialized = new StringBuilder();
        for (String line: prepareRSAPrivateKey())
            serialized.append(line).append(LINE_SEPARATOR);

        // encrypt key under password (UTF-8, matching the decoding in secureReadRSAPrivateKey)
        configurePbe(password);
        byte[] encrypted = pbe.encrypt(serialized.toString().getBytes(StandardCharsets.UTF_8));
        List<String> encryptedLines = new ArrayList<String>();
        for (String line: Util.bytesToHex64(encrypted).split("\n"))
            encryptedLines.add(line);

        // write encrypted key
        writeLines(fn, encryptedLines);
    }

    // Assignment 8
    // Read an RSA public key (n, e) from file fn.
    // Throws if the file's identity differs from this object's identity; nothing is changed then.
    public void readRSAPublicKey(String fn) throws Exception
    {
        BigInteger newN = n;
        BigInteger newE = e;
        for (String[] s: parseTagValues(Files.readAllLines(Paths.get(fn))))
        {
            String tag = s[0];
            if (tag.startsWith("[id]"))
            {
                if (!id.equals(s[1]))
                    throw new Exception("RSA.readRSAPublicKey: identities do not match");
            }
            else if (tag.startsWith("[n]"))
                newN = new BigInteger(s[1], 16);
            else if (tag.startsWith("[e]"))
                newE = new BigInteger(s[1], 16);
        }
        n = newN;
        e = newE;
        modulusBitSize = n.bitLength();
    }

    // Assignment 8, 12
    // Parse the lines of a private key (plain or decrypted) and install it in this object.
    // A key with p > 1, q > 1 and dp, dq, u >= 1 is a CRT key, and n is recomputed as p q.
    // Otherwise the non-CRT fields n and d are used.
    // All private components are reset first, so values from an earlier key cannot mix into this one.
    // Throws on an identity mismatch, before any field is changed.
    private void processRSAPrivateKey(List<String> lines) throws Exception
    {
        BigInteger newN = n;
        BigInteger newD = BigInteger.ZERO;
        BigInteger newP = BigInteger.ONE;
        BigInteger newQ = BigInteger.ONE;
        BigInteger newDp = BigInteger.ZERO;
        BigInteger newDq = BigInteger.ZERO;
        BigInteger newU = BigInteger.ZERO;

        for (String[] s: parseTagValues(lines))
        {
            String tag = s[0];
            String value = s[1];
            if (tag.startsWith("[id]"))
            {
                if (!id.equals(value))
                    throw new Exception("RSA.processRSAPrivateKey: identities do not match");
            }
            else if (tag.startsWith("[n]"))
                newN = new BigInteger(value, 16);
            else if (tag.startsWith("[d]"))
                newD = new BigInteger(value, 16);
            else if (tag.startsWith("[p]"))
                newP = new BigInteger(value, 16);
            else if (tag.startsWith("[q]"))
                newQ = new BigInteger(value, 16);
            else if (tag.startsWith("[dp]"))
                newDp = new BigInteger(value, 16);
            else if (tag.startsWith("[dq]"))
                newDq = new BigInteger(value, 16);
            else if (tag.startsWith("[u]"))
                newU = new BigInteger(value, 16);
        }

        // determine CRT: proper primes (> 1) and all three CRT components (>= 1)
        boolean newCrt = newP.compareTo(BigInteger.ONE) > 0 && newQ.compareTo(BigInteger.ONE) > 0 &&
                         newDp.signum() > 0 && newDq.signum() > 0 && newU.signum() > 0;

        p = newP;
        q = newQ;
        d = newD;
        dp = newDp;
        dq = newDq;
        u = newU;
        crt = newCrt;
        n = crt ? p.multiply(q) : newN;
        modulusBitSize = n.bitLength();
    }

    // Assignment 8
    // Read an unencrypted RSA private key from file fn.
    public void readRSAPrivateKey(String fn) throws Exception
    {
        processRSAPrivateKey(Files.readAllLines(Paths.get(fn)));
    }

    // Assignment 8
    // Read an RSA private key from file fn that was written by secureWriteRSAPrivateKey under password.
    public void secureReadRSAPrivateKey(String password, String fn) throws Exception
    {
        // read encrypted key (hex, possibly spread over several lines)
        StringBuilder encrypted = new StringBuilder();
        for (String line: Files.readAllLines(Paths.get(fn)))
            encrypted.append(line);

        // decrypt key under password, then undo the "@#" serialization
        configurePbe(password);
        String serialized = new String(pbe.decrypt(Util.hexToBytes(encrypted.toString())), StandardCharsets.UTF_8);
        List<String> lines = new ArrayList<String>();
        for (String entry: serialized.split(LINE_SEPARATOR))
            for (String line: entry.split("\n"))
                lines.add(line);

        // process key
        processRSAPrivateKey(lines);
    }

    // configure the PBE engine for key-file encryption: password, identity (UTF-8) as salt, fixed iteration count
    private void configurePbe(String password) throws Exception
    {
        pbe.setPassword(password);
        pbe.setSalt(id.getBytes(StandardCharsets.UTF_8));
        pbe.setIterationCount(PBE_ITERATION_COUNT);
    }

    // Parse key-file lines into [tag, value] pairs.
    // A line starting with "[" opens a new entry whose tag is that whole line.
    // Following lines are trimmed and appended to the entry's value.
    // Lines before the first tag are ignored; an empty input yields an empty list.
    private static List<String[]> parseTagValues(List<String> lines)
    {
        List<String[]> key = new ArrayList<String[]>();
        String[] current = null;
        for (String line: lines)
        {
            if (line.startsWith("["))
            {
                current = new String[] {line, ""};
                key.add(current);
            }
            else if (current != null)
                current[1] += line.trim();
        }
        return key;
    }

    // flatten [tag, value] pairs into file lines of the form "tag\nvalue"
    private static List<String> toLines(List<String[]> key)
    {
        List<String> lines = new ArrayList<String>();
        for (String[] item: key)
            lines.add(item[0] + "\n" + item[1]);
        return lines;
    }

    // write lines to file fn in UTF-8, creating the file or truncating an existing one
    private static void writeLines(String fn, List<String> lines) throws Exception
    {
        Files.write(Paths.get(fn), lines, StandardCharsets.UTF_8, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
    }

    // number of bytes needed to hold the modulus
    private int modulusByteLength()
    {
        return (modulusBitSize + 7)/8;
    }

    // Left-pad the unsigned big-endian encoding of x with zero bytes to the modulus byte length.
    // This prevents two's complement problems in the PKCS#1 v1.5 situation.
    private byte[] toModulusLength(BigInteger x) throws Exception
    {
        byte[] result = new byte[modulusByteLength()];
        byte[] raw = Util.bigIntegerToByteArray(x);
        System.arraycopy(raw, 0, result, result.length - raw.length, raw.length);
        return result;
    }

    // Assignment 8
    // Raw RSA encryption / signature verification: m^e mod n, padded to the modulus byte length.
    // Throws if the input (read as an unsigned integer) is not smaller than n (RFC 8017 RSAEP/RSAVP1).
    private byte[] rawPublicKeyOperation(byte[] m) throws Exception
    {
        BigInteger x = new BigInteger(1, m);
        if (x.compareTo(n) >= 0)
            throw new Exception("RSA.rawPublicKeyOperation: input out of range (not smaller than modulus)");
        return toModulusLength(x.modPow(e, n));
    }

    // Assignment 8, 12
    // Raw RSA decryption / signature generation: c^d mod n, padded to the modulus byte length.
    // Uses CRT when a CRT key is loaded.
    // Throws if the input (read as an unsigned integer) is not smaller than n (RFC 8017 RSADP/RSASP1).
    private byte[] rawPrivateKeyOperation(byte[] c) throws Exception
    {
        BigInteger x = new BigInteger(1, c);
        if (x.compareTo(n) >= 0)
            throw new Exception("RSA.rawPrivateKeyOperation: input out of range (not smaller than modulus)");
        BigInteger result;
        if (crt)
        {
            // Chinese Remainder Theorem (Garner): m = mq + q * (((mp - mq) u) mod p), with 0 <= m < n
            BigInteger mp = x.modPow(dp, p);
            BigInteger mq = x.modPow(dq, q);
            result = mp.subtract(mq).multiply(u).mod(p).multiply(q).add(mq);
        }
        else
            result = x.modPow(d, n);
        return toModulusLength(result);
    }

    // Assignment 8
    // RSA encryption with PKCS#1 v1.5 padding: 0x00 || 0x02 || PS (nonzero random, >= 8 bytes) || 0x00 || plaintext.
    // Throws if the plaintext is longer than (modulus bytes - 11).
    public byte[] encryptKey(byte[] plaintext) throws Exception
    {
        int nb = modulusByteLength();
        if (plaintext.length > nb - 11)
            throw new Exception("RSA.encryptKey: plaintext (" + plaintext.length +
                                " bytes) too long for modulus (" + modulusBitSize + " bits)");

        // create padding: random bytes, each zero byte replaced by a fresh nonzero one
        byte[] ps = util.randomBytes(nb - plaintext.length - 3);
        for (int i = 0; i < ps.length; i++)
        {
            while (ps[i] == (byte)0x00)
                ps[i] = util.randomBytes(1)[0];
        }

        // build padded plaintext
        byte[] paddedPlaintext = new byte[nb];
        paddedPlaintext[0] = (byte)0x00;
        paddedPlaintext[1] = (byte)0x02;
        System.arraycopy(ps, 0, paddedPlaintext, 2, ps.length);
        paddedPlaintext[ps.length + 2] = (byte)0x00;
        System.arraycopy(plaintext, 0, paddedPlaintext, ps.length + 3, plaintext.length);

        // raw encryption
        return rawPublicKeyOperation(paddedPlaintext);
    }

    // Assignment 8
    // RSA decryption with PKCS#1 v1.5 padding; returns the plaintext.
    // Throws if the ciphertext is too long or the padding is malformed.
    public byte[] decryptKey(byte[] ciphertext) throws Exception
    {
        int nb = modulusByteLength();
        if (ciphertext.length > nb)
            throw new Exception("RSA.decryptKey: ciphertext (" + ciphertext.length +
                                " bytes) too long for modulus (" + modulusBitSize + " bits)");

        // raw decryption
        byte[] paddedPlaintext = rawPrivateKeyOperation(ciphertext);

        // check and remove padding 0x00 || 0x02 || PS || 0x00 || plaintext
        // NOTE(review): distinct error messages form a padding oracle (Bleichenbacher); callers must
        //   not reveal to an attacker which check failed.
        if (paddedPlaintext[0] != (byte)0x00)
            throw new Exception("RSA.decryptKey: wrong padding (first byte)");
        if (paddedPlaintext[1] != (byte)0x02)
            throw new Exception("RSA.decryptKey: wrong padding (second byte)");
        int separator = -1;
        for (int i = 2; i < paddedPlaintext.length; i++)
        {
            if (paddedPlaintext[i] == (byte)0x00)
            {
                separator = i;
                break;
            }
        }
        if (separator == -1)
            throw new Exception("RSA.decryptKey: wrong padding (end of padding not found)");
        if (separator - 2 < 8)
            throw new Exception("RSA.decryptKey: wrong padding (padding too short)");
        byte[] plaintext = new byte[paddedPlaintext.length - separator - 1];
        System.arraycopy(paddedPlaintext, separator + 1, plaintext, 0, plaintext.length);
        return plaintext;
    }

    // Assignment 10
    // RSA signature generation: SHA-256 with PKCS#1 v1.5 padding.
    // Padded hash = 0x00 || 0x01 || 0xff...(nb - 54 bytes) || 0x00 || DigestInfo(19) || hash(32).
    // Throws if the modulus is shorter than 62 bytes (the padding needs at least 8 bytes of 0xff).
    public byte[] generateSignature(byte[] message) throws Exception
    {
        byte[] hash = sha256.hash(message);
        int nb = modulusByteLength();
        if (nb < 62)
            throw new Exception("RSA.generateSignature: modulus too short");

        // build padded hash
        byte[] paddedHash = new byte[nb];
        paddedHash[0] = (byte)0x00;
        paddedHash[1] = (byte)0x01;
        for (int i = 2; i < nb - 52; i++)
            paddedHash[i] = (byte)0xff;
        paddedHash[nb - 52] = (byte)0x00;
        System.arraycopy(Util.hexToBytes(SHA256_DIGEST_INFO), 0, paddedHash, nb - 51, 19);
        System.arraycopy(hash, 0, paddedHash, nb - 32, 32);

        // raw signature generation
        return rawPrivateKeyOperation(paddedHash);
    }

    // Assignment 10
    // RSA signature verification: SHA-256 with PKCS#1 v1.5 padding.
    // Returns whether the embedded hash equals the hash of message.
    // Throws on an oversize or out-of-range signature, or on malformed padding.
    public boolean verifySignature(byte[] message, byte[] signature) throws Exception
    {
        byte[] hash = sha256.hash(message);
        int nb = modulusByteLength();
        if (nb < 62)
            throw new Exception("RSA.verifySignature: modulus too short");
        if (signature.length > nb)
            throw new Exception("RSA.verifySignature: signature (" + signature.length +
                                " bytes) too long for modulus (" + modulusBitSize + " bits)");

        // raw signature verification
        byte[] paddedHash = rawPublicKeyOperation(signature);

        // check padding
        if (paddedHash[0] != (byte)0x00)
            throw new Exception("RSA.verifySignature: wrong padding (first byte)");
        if (paddedHash[1] != (byte)0x01)
            throw new Exception("RSA.verifySignature: wrong padding (second byte)");
        for (int i = 0; i < nb - 54; i++)
        {
            if (paddedHash[i + 2] != (byte)0xff)
                throw new Exception("RSA.verifySignature: wrong padding (byte " + i +" unequal to 0xff)");
        }
        if (paddedHash[nb - 52] != (byte)0x00)
            throw new Exception("RSA.verifySignature: wrong padding (delimiter byte nonzero)");
        byte[] t = new byte[19];
        System.arraycopy(paddedHash, nb - 51, t, 0, 19);
        if (!Util.bytesToHex(t).equals(SHA256_DIGEST_INFO))
            throw new Exception("RSA.verifySignature: wrong padding (algorithm identifier)");

        // compare hashes byte by byte, without an early exit
        if (hash.length != 32)
            return false;
        int difference = 0;
        for (int i = 0; i < 32; i++)
            difference |= hash[i] ^ paddedHash[nb - 32 + i];
        return difference == 0;
    }

    // Assignment 11
    // Recover p and q from n, e and d. This applies only to a non-CRT key; for a CRT key nothing is done.
    // Method: write e d - 1 = 2^s t with t odd. For small primes a, square a^t modulo n until the
    // next square is 1. If the last value x before 1 is neither 1 nor -1 mod n, it is a nontrivial
    // square root of 1, and gcd(n, x+1), gcd(n, x-1) are the factors.
    // Prints progress and "SUCCEEDED"/"FAILED" to standard output.
    // Throws if n, e, d are unset or inconsistent (a^(e d - 1) != 1 mod n).
    public void factorFromKeyPair() throws Exception
    {
        if (crt)
            return;
        if (n.compareTo(BigInteger.ONE) <= 0 || e.signum() <= 0 || d.signum() <= 0)
            throw new Exception("RSA.factorFromKeyPair: n, e and d must be set");

        // e d - 1 = 2^s t, t odd
        BigInteger k = e.multiply(d).subtract(BigInteger.ONE);
        if (k.signum() <= 0)
            throw new Exception("RSA.factorFromKeyPair: e d - 1 must be positive");
        int s = k.getLowestSetBit();
        BigInteger t = k.shiftRight(s);
        BigInteger minusOne = n.subtract(BigInteger.ONE);

        boolean found = false;
        BigInteger a = BigInteger.ONE;
        while (!found)
        {
            // for a choose small primes
            a = a.nextProbablePrime();
            if (a.compareTo(n) >= 0)
                break; // no candidates left: give up
            System.out.println("trying a = " + a.toString());

            // if a shares a factor with n, that factor splits n directly
            BigInteger g = a.gcd(n);
            if (!g.equals(BigInteger.ONE))
            {
                p = g;
                q = n.divide(g);
                found = true;
                break;
            }

            // compute sequence of squares until 1 is reached; for a consistent key at most s squarings are needed
            BigInteger x = a.modPow(t, n);
            BigInteger x2 = x.multiply(x).mod(n);
            int squarings = 1;
            while (!x2.equals(BigInteger.ONE))
            {
                if (squarings >= s)
                    throw new Exception("RSA.factorFromKeyPair: inconsistent key (a^(e d - 1) is not 1 mod n)");
                x = x2;
                x2 = x.multiply(x).mod(n);
                squarings++;
            }

            // check if nontrivial factor found: x^2 = 1 and x != +-1 mod n
            if (!x.equals(BigInteger.ONE) && !x.equals(minusOne))
            {
                p = n.gcd(x.add(BigInteger.ONE));
                q = n.gcd(x.subtract(BigInteger.ONE));
                found = true;
            }
        }
        System.out.println((found && (p.compareTo(BigInteger.ONE) > 0) && (q.compareTo(BigInteger.ONE) > 0) &&
                            p.multiply(q).equals(n)) ? "SUCCEEDED" : "FAILED");
    }
}
