// RSA Wiener
// Benne de Weger - TU/e - March 2021

// FAECTOR Workshop Cryptographic Programming
// Assignment 17, 18 - application

import net.deweger.crypto.*;

import java.math.BigInteger;

// Demonstration of a combined Wiener (small private exponent) / Fermat (small prime
// difference) attack on RSA via 2-dimensional lattice reduction.
public class RSAWiener
{
    // key material of the generated weak key pair (p, q, d secret; n, e public)
    private BigInteger p = BigInteger.ONE;
    private BigInteger q = BigInteger.ONE;
    private BigInteger n = BigInteger.ONE;
    private BigInteger e = BigInteger.ZERO;
    private BigInteger d = BigInteger.ZERO;
    Util util;

    // default coordinate bound for the shell search in the Wiener attack
    private static final int DEFAULT_KBOUND = 500;

    // constructor
    public RSAWiener()
    {
        util = new Util();
    }

    // Assignment 17, 18
    // Generate a weak RSA key pair and store it in the fields p, q, n, d, e.
    // b            = number of bits in the modulus n (exact)
    // delta_wiener = number of bits in the private exponent d (upper bound, must be >= 2)
    // delta_fermat = number of bits in the prime difference q - p (upper bound, must be >= 2)
    // throws IllegalArgumentException for a too small delta_wiener or delta_fermat
    private void generateWeakRSAKeyPair(int b, int delta_wiener, int delta_fermat)
    {
        if (delta_wiener < 2)
            throw new IllegalArgumentException("delta_wiener must be at least 2, got " + delta_wiener);
        if (delta_fermat < 2)
            throw new IllegalArgumentException("delta_fermat must be at least 2, got " + delta_fermat);

        // primes and modulus: p is a random prime, q is the first probable prime found
        // when stepping down (by 2) from p + 2r, r a random (delta_fermat - 1)-bit number,
        // so that 0 <= q - p < 2^delta_fermat; retry when q == p or n has the wrong size
        do
        {
            p = BigInteger.probablePrime((b+1)/2, util.secureRandom);
            q = p.add(new BigInteger(delta_fermat-1, util.secureRandom).shiftLeft(1));
            while (!q.isProbablePrime(100))
                q = q.subtract(BigInteger.TWO);
            n = p.multiply(q);
        }
        while (q.equals(p) || n.bitLength() != b);

        // phi(n)
        BigInteger phi = p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));

        // private exponent d: random of at most delta_wiener bits, d > 1, invertible mod phi
        do
        {
            d = new BigInteger(delta_wiener, util.secureRandom);
        }
        while (d.compareTo(BigInteger.ONE) <= 0 || !phi.gcd(d).equals(BigInteger.ONE));

        // public exponent e
        e = d.modInverse(phi);
    }

    // Lagrange (Gauss) reduction of a 2-dimensional lattice basis.
    // b = basis given as the two row vectors b[0], b[1]; b itself is not modified.
    // Returns a basis {u, v} of the same lattice with |u| <= |v| and
    // -|u|^2/2 <= <u,v> < |u|^2/2.
    private BigInteger[][] reduce(BigInteger[][] b)
    {
        BigInteger[] u = b[0].clone();
        BigInteger[] v = b[1].clone();
        if (normSquared(u).compareTo(normSquared(v)) > 0)
        {
            BigInteger[] aux = u;
            u = v;
            v = aux;
        }
        boolean notYetDone = true;
        while (notYetDone)
        {
            BigInteger uu = normSquared(u);
            // k = round(<u,v> / <u,u>) with halves rounded up, computed as
            // floor((2<u,v> + <u,u>) / (2<u,v>)); a floor division (not BigInteger's
            // truncating one) is needed so that negative inner products round correctly
            BigInteger k = floorDiv(dot(u, v).shiftLeft(1).add(uu), uu.shiftLeft(1));
            v = new BigInteger[] { v[0].subtract(u[0].multiply(k)), v[1].subtract(u[1].multiply(k)) };
            notYetDone = k.signum() != 0;
            if (uu.compareTo(normSquared(v)) > 0)
            {
                BigInteger[] aux = u;
                u = v;
                v = aux;
            }
        }
        return new BigInteger[][] { u, v };
    }

    // inner product of two 2-dimensional vectors
    private static BigInteger dot(BigInteger[] u, BigInteger[] v)
    {
        return u[0].multiply(v[0]).add(u[1].multiply(v[1]));
    }

    // squared Euclidean norm of a 2-dimensional vector
    private static BigInteger normSquared(BigInteger[] v)
    {
        return dot(v, v);
    }

    // floor(a / m) for m > 0 (BigInteger.divide truncates towards zero instead)
    private static BigInteger floorDiv(BigInteger a, BigInteger m)
    {
        BigInteger[] qr = a.divideAndRemainder(m);
        return qr[1].signum() < 0 ? qr[0].subtract(BigInteger.ONE) : qr[0];
    }

    // Test the lattice vector x*br[0] + y*br[1] of the reduced basis br.
    // t = transformation matrix with br = t * b, so (x, y) * t are the coordinates
    // of the vector w.r.t. the original basis b, i.e. the candidates for d and -k.
    // Returns true (after printing the recovered key) iff the candidate factors n.
    private boolean testxy(int x, int y, BigInteger[][] t)
    {
        // compute candidates for d and k, normalised so that d >= 0
        BigInteger bx = BigInteger.valueOf(x);
        BigInteger by = BigInteger.valueOf(y);
        BigInteger candd = t[0][0].multiply(bx).add(t[1][0].multiply(by));
        BigInteger candk = t[0][1].multiply(bx).add(t[1][1].multiply(by));
        if (candd.signum() < 0)
            candd = candd.negate();
        else
            candk = candk.negate();

        // first test: k > 0
        if (candk.signum() <= 0)
            return false;

        BigInteger ed = e.multiply(candd);

        // second test: e d = 1 (mod k)
        if (!ed.mod(candk).equals(BigInteger.ONE))
            return false;
        System.out.println("Wiener second test passed for x = " + x + ", y = " + y);

        // candidate phi = (e d - 1) / k gives p + q = n + 1 - phi, and p, q are the
        // roots of z^2 - (p + q) z + n with discriminant (p + q)^2 - 4 n
        BigInteger pplusq = n.add(BigInteger.ONE).subtract(ed.subtract(BigInteger.ONE).divide(candk));
        BigInteger discr2 = pplusq.multiply(pplusq).subtract(n.shiftLeft(2));

        // third test: p + q > 0 and the discriminant is a non-negative perfect square;
        // the sign must be checked first because BigInteger.sqrt throws on negative input
        if (pplusq.signum() <= 0 || discr2.signum() < 0)
            return false;
        BigInteger discr = discr2.sqrt();
        if (!discr.multiply(discr).equals(discr2))
            return false;
        System.out.println("Wiener third test passed for x = " + x + ", y = " + y);

        BigInteger candp = pplusq.subtract(discr).shiftRight(1);
        BigInteger candq = pplusq.add(discr).shiftRight(1);

        // fourth test: p q = n
        if (!candp.multiply(candq).equals(n))
            return false;
        // print the RECOVERED values, not the secrets stored by the key generator
        System.out.println("Wiener fourth test passed for x = " + x + ", y = " + y + ", private key values:");
        System.out.println("p = " + candp.toString());
        System.out.println("q = " + candq.toString());
        System.out.println("d = " + candd.toString());
        return true;
    }

    // Wiener attack with the default coordinate bound for the shell search
    private void wienerAttack(int delta_fermat)
    {
        wienerAttack(delta_fermat, DEFAULT_KBOUND);
    }

    // Wiener attack on the public key (n, e): build a lattice basis, reduce it and
    // search small combinations of the reduced basis vectors for the private key.
    // delta_fermat = assumed bitsize of the prime difference q - p
    // kbound       = largest shell (max(|x|,|y|)) to search
    private void wienerAttack(int delta_fermat, int kbound)
    {
        // make basis of lattice
        // first coordinate weight 2^(2*delta_fermat - bits(n)/2); clamped to at least 1,
        // since shiftLeft with a negative count shifts right, which would give weight 0
        // and a singular basis (det = 0)
        int shift = Math.max(0, 2*delta_fermat - n.bitLength()/2);
        BigInteger[][] b = new BigInteger[2][2];
        b[0][0] = BigInteger.ONE.shiftLeft(shift);
        b[0][1] = e;
        b[1][0] = BigInteger.ZERO;
        // n - floor(2 sqrt(n)) approximates phi(n) = n + 1 - (p + q) when p is close to q
        b[1][1] = n.subtract(n.shiftLeft(2).sqrt());
        printMatrix("basis: ", b);

        // reduce basis
        BigInteger[][] br = reduce(b);
        BigInteger[][] t = transformation(b, br);
        printMatrix("reduced basis: ", br);
        printMatrix("transformation: ", t);

        if (searchShells(t, kbound))
            System.out.println("Wiener attack succeeded.");
        else
            System.out.println("Wiener attack up to coordinate bound " + kbound + " failed.");
    }

    // Transformation matrix t with br = t * b (rows are vectors), i.e. t = br * b^-1,
    // using the adjugate of b; the divisions are exact because br and b span the same lattice.
    private static BigInteger[][] transformation(BigInteger[][] b, BigInteger[][] br)
    {
        BigInteger det = b[0][0].multiply(b[1][1]).subtract(b[0][1].multiply(b[1][0]));
        BigInteger[][] t = new BigInteger[2][2];
        for (int i = 0; i < 2; i++)
        {
            t[i][0] = b[1][1].multiply(br[i][0]).subtract(b[1][0].multiply(br[i][1])).divide(det);
            t[i][1] = b[0][0].multiply(br[i][1]).subtract(b[0][1].multiply(br[i][0])).divide(det);
        }
        return t;
    }

    // print a title line followed by the rows of a 2x2 matrix as "(a, b)"
    private static void printMatrix(String title, BigInteger[][] m)
    {
        System.out.println(title);
        for (BigInteger[] row : m)
            System.out.println("(" + row[0].toString() + ", " + row[1].toString() + ")");
    }

    // Test in shells: shell k has all (x,y) with max(|x|,|y|) = k. Only those with x >= 0
    // are tested, as (-x,-y) merely negates the lattice vector.
    // Returns true as soon as some (x,y) yields the private key.
    private boolean searchShells(BigInteger[][] t, int kbound)
    {
        for (int k = 1; k <= kbound; k++)
        {
            // x = k
            for (int y = -k; y <= k; y++)
                if (testxy(k, y, t))
                    return true;
            // y = k, -k
            for (int x = 0; x <= k-1; x++)
                if (testxy(x, k, t) || testxy(x, -k, t))
                    return true;
        }
        return false;
    }

    // Generate a weak key, print its public part and run the Wiener attack on it.
    // bn           = bitsize of modulus
    // delta_wiener = bitsize of private exponent
    // delta_fermat = bitsize of prime difference
    private static void test(int bn, int delta_wiener, int delta_fermat)
    {
        RSAWiener rsaWiener = new RSAWiener();

        // generate weak key
        rsaWiener.generateWeakRSAKeyPair(bn, delta_wiener, delta_fermat);
        System.out.println("public key values:");
        System.out.println("n = " + rsaWiener.n.toString());
        System.out.println("e = " + rsaWiener.e.toString());

        // do Wiener attack
        rsaWiener.wienerAttack(delta_fermat);
    }

    // main, test different bitsizes
    public static void main(String[] args)
    {
        // with 2048 bit n, Fermat would work for prime difference bitsize 512 + epsilon,
        // and Wiener would also work for private key exponent bitsize 512 + epsilon,
        // combined we can get way further, to their sum 1536 + epsilon
        test(2048, 772, 772);
    }
}
