// FW_RSA
// Benne de Weger - TU/e - February-April 2021

// FAECTOR Workshop Cryptographic Programming
// Assignments 6, 8, 10 - API (public methods):
//
// constructors
//    public FW_RSA() throws Exception
//    public FW_RSA(String id) throws Exception
//
// set / get id
//    public void setId(String id)
//    public String getId()
//
// generate RSA key pair
// ------------------------------------------------------------------------------------
// b = number of bits in the modulus
// fermat4: true -> e = 65537, false -> e = big random
//    public void generateRSAKeyPair(int b, boolean fermat4)
//
// key input / output
// ------------------------------------------------------------------------------------
// write RSA public key to file fn
//    public void writeRSAPublicKey(String fn) throws Exception
// write RSA private key to file fn
//    public void writeRSAPrivateKey(String fn) throws Exception
// read RSA public key from file fn
//    public void readRSAPublicKey(String fn) throws Exception
// read RSA private key from file fn
//    public void readRSAPrivateKey(String fn) throws Exception
//
// encryption and decryption of symmetric keys
// ------------------------------------------------------------------------------------
// RSA encryption with PKCS1v1.5-padding
//    public byte[] encryptKey(byte[] plaintext) throws Exception
// RSA decryption with PKCS1v1.5-padding
//    public byte[] decryptKey(byte[] ciphertext) throws Exception
//
// signatures
// ------------------------------------------------------------------------------------
// RSA signature generation
//    public byte[] generateSignature(byte[] message) throws Exception
// RSA signature verification
//    public boolean verifySignature(byte[] message, byte[] signature) throws Exception

import net.deweger.crypto.*;
import java.math.BigInteger;
import java.util.List;
import java.util.ArrayList;
import java.nio.file.*;
import java.nio.charset.StandardCharsets;

// RSA key generation, key file input / output, PKCS#1 v1.5 encryption of
// (symmetric) keys and PKCS#1 v1.5 signatures with SHA-256.
//
// One object holds at most one key: the modulus n together with the public
// exponent e and / or the private exponent d. The exponents are zero as long
// as the corresponding key half has not been generated or read.
public class FW_RSA
{
    // the public exponent 65537
    private static final BigInteger F4 = BigInteger.valueOf(65537);

    // smallest modulus size accepted by generateRSAKeyPair; for some smaller sizes
    //   no two distinct primes of the required length have a product of exactly
    //   that size, so the prime search would never end
    private static final int MIN_MODULUS_BITS = 8;

    // minimal number of padding bytes in a PKCS#1 v1.5 block
    private static final int MIN_PADDING_LENGTH = 8;

    // length in bytes of a SHA-256 hash
    private static final int HASH_LENGTH = 32;

    // DER encoded DigestInfo prefix for SHA-256 (RFC 8017, section 9.2),
    //   hex 3031300d060960864801650304020105000420
    private static final byte[] SHA256_DIGEST_INFO =
    {
        (byte)0x30, (byte)0x31, (byte)0x30, (byte)0x0d, (byte)0x06, (byte)0x09, (byte)0x60,
        (byte)0x86, (byte)0x48, (byte)0x01, (byte)0x65, (byte)0x03, (byte)0x04, (byte)0x02,
        (byte)0x01, (byte)0x05, (byte)0x00, (byte)0x04, (byte)0x20
    };

    // minimal modulus length in bytes for signatures:
    //   0x00 0x01 | at least 8 bytes 0xff | 0x00 | DigestInfo prefix | hash  (= 62)
    private static final int MIN_SIGNATURE_MODULUS_BYTES =
        3 + MIN_PADDING_LENGTH + SHA256_DIGEST_INFO.length + HASH_LENGTH;

    private String id = "";
    private BigInteger p = BigInteger.ONE;
    private BigInteger q = BigInteger.ONE;
    private BigInteger n = BigInteger.ONE;
    private BigInteger e = BigInteger.ZERO;
    private BigInteger d = BigInteger.ZERO;
    int modulusBitSize = 0;

    Util util;
    SHA256 sha256;

    // constructors
    public FW_RSA() throws Exception
    {
        this("");
    }

    public FW_RSA(String id) throws Exception
    {
        this.id = id;
        util = new Util();
        sha256 = new SHA256();
    }

    // set or get identity
    public void setId(String id)
    {
        this.id = id;
    }

    public String getId()
    {
        return id;
    }

    // Assignment 6
    // generate RSA key pair, b = number of bits in the modulus
    // fermat4: true -> e = 65537, false -> e = big random
    // every call generates a fresh key pair; the fields are only updated
    //   once the complete key pair has been computed
    // throws IllegalArgumentException when b is smaller than 8
    public void generateRSAKeyPair(int b, boolean fermat4)
    {
        if (b < MIN_MODULUS_BITS)
            throw new IllegalArgumentException("RSA.generateRSAKeyPair: modulus size (" + b +
                                               " bits) too small, need at least " + MIN_MODULUS_BITS + " bits");

        // primes p, q, modulus n and phi(n)
        // retry until p and q differ, n has exactly b bits, n is not a multiple of
        //   65537 and, when e = 65537 is requested, e is invertible modulo phi(n)
        int primeBits = (b + 1)/2;
        BigInteger newP, newQ, newN, phi;
        do
        {
            newP = BigInteger.probablePrime(primeBits, util.secureRandom);
            newQ = BigInteger.probablePrime(primeBits, util.secureRandom);
            newN = newP.multiply(newQ);
            phi = newP.subtract(BigInteger.ONE).multiply(newQ.subtract(BigInteger.ONE));
        }
        while (newP.equals(newQ)
               || newN.bitLength() != b
               || newN.mod(F4).signum() == 0
               || (fermat4 && !phi.gcd(F4).equals(BigInteger.ONE)));

        // public exponent e
        BigInteger newE;
        if (fermat4)
            newE = F4;
        else
        {
            // random e with 1 < e < phi(n) and gcd(e, phi(n)) = 1
            do
                newE = new BigInteger(b, util.secureRandom);
            while (newE.compareTo(BigInteger.ONE) <= 0
                   || newE.compareTo(phi) >= 0
                   || !newE.gcd(phi).equals(BigInteger.ONE));
        }

        // private exponent d; the inverse exists because gcd(e, phi(n)) = 1
        BigInteger newD = newE.modInverse(phi);

        modulusBitSize = b;
        p = newP;
        q = newQ;
        n = newN;
        e = newE;
        d = newD;
    }

    // write a tag-value structure to file fn (UTF-8, an existing file is overwritten)
    // every entry is a pair { tag, value }, written as the tag line followed by the value
    private void writeKeyFile(String fn, String[][] entries) throws Exception
    {
        List<String> lines = new ArrayList<String>();
        for (String[] entry: entries)
            lines.add(entry[0] + "\n" + entry[1]);
        Files.write(Paths.get(fn), lines, StandardCharsets.UTF_8, StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING);
    }

    // Assignment 6
    // write RSA public key (identity, modulus, public exponent) to file fn
    public void writeRSAPublicKey(String fn) throws Exception
    {
        writeKeyFile(fn, new String[][]
        {
            { "[id] (identity)", id },
            { "[n] (modulus)", Util.bigIntegerToHex64(n) },
            { "[e] (public exponent)", Util.bigIntegerToHex64(e) }
        });
    }

    // Assignment 6
    // write RSA private key (identity, modulus, private exponent) to file fn
    public void writeRSAPrivateKey(String fn) throws Exception
    {
        writeKeyFile(fn, new String[][]
        {
            { "[id] (identity)", id },
            { "[n] (modulus)", Util.bigIntegerToHex64(n) },
            { "[d] (private exponent)", Util.bigIntegerToHex64(d) }
        });
    }

    // read file fn and parse it into a tag-value structure
    // a line starting with "[" is a tag line, the trimmed lines following it are
    //   concatenated into its value; lines before the first tag line are ignored
    private List<String[]> readKeyFile(String fn) throws Exception
    {
        List<String[]> entries = new ArrayList<String[]>();
        String tag = null;
        StringBuilder value = new StringBuilder();
        for (String line: Files.readAllLines(Paths.get(fn), StandardCharsets.UTF_8))
        {
            if (line.startsWith("["))
            {
                if (tag != null)
                    entries.add(new String[] { tag, value.toString() });
                tag = line;
                value.setLength(0);
            }
            else if (tag != null)
                value.append(line.trim());
        }
        if (tag != null)
            entries.add(new String[] { tag, value.toString() });
        return entries;
    }

    // find the hexadecimal value of the entry whose tag line starts with tag
    //   (the last one when it occurs more than once)
    // throws an Exception when there is no such entry
    private static BigInteger requiredHexValue(List<String[]> entries, String tag, String caller) throws Exception
    {
        String value = null;
        for (String[] entry: entries)
        {
            if (entry[0].startsWith(tag))
                value = entry[1];
        }
        if (value == null)
            throw new Exception(caller + ": no " + tag + " entry found in key file");
        return new BigInteger(value, 16);
    }

    // Assignment 8
    // read RSA public key (modulus and public exponent) from file fn
    // the identity stored in the file is not compared with, or copied to, the identity of this object
    // the key of this object is only changed when both values have been read successfully
    public void readRSAPublicKey(String fn) throws Exception
    {
        List<String[]> key = readKeyFile(fn);
        BigInteger newN = requiredHexValue(key, "[n]", "RSA.readRSAPublicKey");
        BigInteger newE = requiredHexValue(key, "[e]", "RSA.readRSAPublicKey");

        n = newN;
        e = newE;
        modulusBitSize = n.bitLength();
    }

    // Assignment 8
    // read RSA private key (modulus and private exponent) from file fn
    // the identity stored in the file is not compared with, or copied to, the identity of this object
    // the key of this object is only changed when both values have been read successfully
    public void readRSAPrivateKey(String fn) throws Exception
    {
        List<String[]> key = readKeyFile(fn);
        BigInteger newN = requiredHexValue(key, "[n]", "RSA.readRSAPrivateKey");
        BigInteger newD = requiredHexValue(key, "[d]", "RSA.readRSAPrivateKey");

        n = newN;
        d = newD;
        modulusBitSize = n.bitLength();
    }

    // raw RSA operation input^exponent mod n
    // the input must represent a number smaller than the modulus
    // result will be padded with zero bytes to the modulus bytelength
    //   to prevent two's complement problems in the PKCS#1v1.5 situation
    private byte[] rawOperation(byte[] input, BigInteger exponent, String caller) throws Exception
    {
        BigInteger x = new BigInteger(1, input);
        if (x.compareTo(n) >= 0)
            throw new Exception(caller + ": input out of range (not smaller than the modulus)");

        byte[] output = new byte[(modulusBitSize + 7)/8];
        byte[] raw = Util.bigIntegerToByteArray(x.modPow(exponent, n));
        // the result is smaller than n, so bytes in excess of the modulus
        //   bytelength (if any) can only be leading zero bytes
        int length = Math.min(raw.length, output.length);
        System.arraycopy(raw, raw.length - length, output, output.length - length, length);
        return output;
    }

    // Assignment 8
    // raw RSA public key (encryption / signature verification) operation
    // throws an Exception when no public exponent is available
    private byte[] rawPublicKeyOperation(byte[] m) throws Exception
    {
        if (e.signum() == 0)
            throw new Exception("RSA.rawPublicKeyOperation: no public key available");
        return rawOperation(m, e, "RSA.rawPublicKeyOperation");
    }

    // Assignment 8
    // raw RSA private key (decryption / signature generation) operation
    // throws an Exception when no private exponent is available
    private byte[] rawPrivateKeyOperation(byte[] c) throws Exception
    {
        if (d.signum() == 0)
            throw new Exception("RSA.rawPrivateKeyOperation: no private key available");
        return rawOperation(c, d, "RSA.rawPrivateKeyOperation");
    }

    // Assignment 8
    // RSA encryption with PKCS1v1.5-padding
    // padded plaintext: 0x00 0x02 | nonzero random bytes (at least 8) | 0x00 | plaintext
    public byte[] encryptKey(byte[] plaintext) throws Exception
    {
        int nb = (modulusBitSize + 7)/8;
        if (plaintext.length > nb - 11)
            throw new Exception("RSA.encryptKey: plaintext (" + plaintext.length +
                                " bytes) too long for modulus (" + modulusBitSize + " bits)");

        // create padding: random bytes, zero bytes are redrawn until they are nonzero
        byte[] ps = util.randomBytes(nb - plaintext.length - 3);
        for (int i = 0; i < ps.length; i++)
        {
            while (ps[i] == (byte)0x00)
                ps[i] = util.randomBytes(1)[0];
        }

        // build padded plaintext (first byte and delimiter byte stay 0x00)
        byte[] paddedPlaintext = new byte[nb];
        paddedPlaintext[1] = (byte)0x02;
        System.arraycopy(ps, 0, paddedPlaintext, 2, ps.length);
        System.arraycopy(plaintext, 0, paddedPlaintext, ps.length + 3, plaintext.length);

        // raw encryption
        return rawPublicKeyOperation(paddedPlaintext);
    }

    // Assignment 8
    // RSA decryption with PKCS1v1.5-padding
    // NOTE(review): the distinct padding error messages can act as a padding oracle
    //   when they are reported back to whoever supplied the ciphertext
    public byte[] decryptKey(byte[] ciphertext) throws Exception
    {
        int nb = (modulusBitSize + 7)/8;
        if (nb < 3 + MIN_PADDING_LENGTH)
            throw new Exception("RSA.decryptKey: modulus too short");
        if (ciphertext.length > nb)
            throw new Exception("RSA.decryptKey: ciphertext (" + ciphertext.length +
                                " bytes) too long for modulus (" + modulusBitSize + " bits)");

        // raw decryption
        byte[] paddedPlaintext = rawPrivateKeyOperation(ciphertext);

        // remove padding
        if (paddedPlaintext[0] != (byte)0x00)
            throw new Exception("RSA.decryptKey: wrong padding (first byte)");
        if (paddedPlaintext[1] != (byte)0x02)
            throw new Exception("RSA.decryptKey: wrong padding (second byte)");
        // the first zero byte after the two header bytes ends the padding
        int delimiter = -1;
        for (int i = 2; i < paddedPlaintext.length; i++)
        {
            if (paddedPlaintext[i] == (byte)0x00)
            {
                delimiter = i;
                break;
            }
        }
        if (delimiter == -1)
            throw new Exception("RSA.decryptKey: wrong padding (end of padding not found)");
        if (delimiter - 2 < MIN_PADDING_LENGTH)
            throw new Exception("RSA.decryptKey: wrong padding (padding too short)");

        return java.util.Arrays.copyOfRange(paddedPlaintext, delimiter + 1, paddedPlaintext.length);
    }

    // Assignment 10
    // RSA signature generation (PKCS1v1.5-padding, SHA-256)
    // padded hash: 0x00 0x01 | bytes 0xff (at least 8) | 0x00 | DigestInfo prefix | hash
    public byte[] generateSignature(byte[] message) throws Exception
    {
        byte[] hash = sha256.hash(message);
        int nb = (modulusBitSize + 7)/8;
        if (nb < MIN_SIGNATURE_MODULUS_BYTES)
            throw new Exception("RSA.generateSignature: modulus too short");

        // build padded hash (first byte and delimiter byte stay 0x00)
        int hashStart = nb - HASH_LENGTH;
        int prefixStart = hashStart - SHA256_DIGEST_INFO.length;
        byte[] paddedHash = new byte[nb];
        paddedHash[1] = (byte)0x01;
        java.util.Arrays.fill(paddedHash, 2, prefixStart - 1, (byte)0xff);
        System.arraycopy(SHA256_DIGEST_INFO, 0, paddedHash, prefixStart, SHA256_DIGEST_INFO.length);
        System.arraycopy(hash, 0, paddedHash, hashStart, HASH_LENGTH);

        // raw signature generation
        return rawPrivateKeyOperation(paddedHash);
    }

    // Assignment 10
    // RSA signature verification (PKCS1v1.5-padding, SHA-256)
    // returns whether the hash in the signature equals the hash of the message
    // throws an Exception when the signature is too long or out of range for the
    //   modulus, or when the padding of the recovered block is wrong
    public boolean verifySignature(byte[] message, byte[] signature) throws Exception
    {
        byte[] hash = sha256.hash(message);
        int nb = (modulusBitSize + 7)/8;
        if (nb < MIN_SIGNATURE_MODULUS_BYTES)
            throw new Exception("RSA.verifySignature: modulus too short");
        if (signature.length > nb)
            throw new Exception("RSA.verifySignature: signature (" + signature.length +
                                " bytes) too long for modulus (" + modulusBitSize + " bits)");

        // raw signature verification
        byte[] paddedHash = rawPublicKeyOperation(signature);

        // remove padding
        int hashStart = nb - HASH_LENGTH;
        int prefixStart = hashStart - SHA256_DIGEST_INFO.length;
        int psLength = prefixStart - 3;
        if (paddedHash[0] != (byte)0x00)
            throw new Exception("RSA.verifySignature: wrong padding (first byte)");
        if (paddedHash[1] != (byte)0x01)
            throw new Exception("RSA.verifySignature: wrong padding (second byte)");
        for (int i = 0; i < psLength; i++)
        {
            if (paddedHash[i+2] != (byte)0xff)
                throw new Exception("RSA.verifySignature: wrong padding (byte " + i +" unequal to 0xff)");
        }
        if (paddedHash[prefixStart - 1] != (byte)0x00)
            throw new Exception("RSA.verifySignature: wrong padding (delimiter byte nonzero)");
        // compare bytes directly, independent of the letter case used by hex conversion
        byte[] prefix = java.util.Arrays.copyOfRange(paddedHash, prefixStart, hashStart);
        if (!java.util.Arrays.equals(prefix, SHA256_DIGEST_INFO))
            throw new Exception("RSA.verifySignature: wrong padding (algorithm identifier)");
        byte[] signedHash = java.util.Arrays.copyOfRange(paddedHash, hashStart, nb);

        // compare hashes
        return java.util.Arrays.equals(hash, signedHash);
    }
}
