IMPLEMENTATION MODULE RSAGenerate;

        (********************************************************)
        (*                                                      *)
        (*           The RSA public key cryptosystem            *)
        (*                                                      *)
        (*  Programmer:         P. Moylan                       *)
        (*  Last edited:        14 July 2023                    *)
        (*  Status:             Working                         *)
        (*                                                      *)
        (*      In case you're wondering what RSA stands for,   *)
        (*      it's Rivest-Shamir-Adleman, after the           *)
        (*      inventors of the algorithm.                     *)
        (*                                                      *)
        (********************************************************)


FROM RSAKeys IMPORT
    (* type *)  RSAKeyType;

IMPORT BigNum;

FROM BigNum IMPORT
    (* type *)  BN,
    (* proc *)  ShowBNUsage;

FROM Primes IMPORT
    (* proc *)  RandomPrime;

FROM STextIO IMPORT
    (* proc *)  WriteChar, WriteString, WriteLn;

(************************************************************************)

(* Compile-time switch for diagnostic output.  It is tested by an      *)
(* IF TESTING statement in RSA_Generate; the statements guarded there   *)
(* are currently commented out, so changing this value has no visible   *)
(* effect until they are re-enabled.                                    *)

CONST TESTING = FALSE;

(************************************************************************)

PROCEDURE WriteCard (N: CARDINAL);

    (* Writes N in decimal, one character at a time, with no leading    *)
    (* zeros and no padding.  Zero is written as the single digit '0'.  *)

    VAR scale: CARDINAL;

    BEGIN
        (* Find the power of ten that selects the leading digit.  The   *)
        (* multiplication cannot overflow, because it is only done      *)
        (* when 10*scale <= N.                                          *)

        scale := 1;
        WHILE N DIV scale >= 10 DO
            scale := 10*scale;
        END (*WHILE*);

        (* Emit the digits, most significant first. *)

        REPEAT
            WriteChar (CHR(ORD('0') + N DIV scale));
            N := N MOD scale;
            scale := scale DIV 10;
        UNTIL scale = 0;
    END WriteCard;

(************************************************************************)
(*                           MODULAR ARITHMETIC                         *)
(************************************************************************)

PROCEDURE gcd (a, b: BN): BN;

    (* Greatest common divisor, computed by Euclid's algorithm.         *)
    (* The arguments are not modified; the result is a newly created    *)
    (* number that the caller is responsible for discarding.            *)

    VAR larger, smaller, quotient, remainder: BN;

    BEGIN
        (*WriteString ("Entering gcd, ");  ShowBNUsage;*)

        (* Trivial case: gcd(a,a) = a. *)

        IF BigNum.Cmp (a, b) = 0 THEN
            RETURN BigNum.CopyBN (a);
        END (*IF*);

        (* Work on private copies, arranged so that larger > smaller.   *)

        IF BigNum.Cmp (b, a) > 0 THEN
            larger := BigNum.CopyBN (b);
            smaller := BigNum.CopyBN (a);
        ELSE
            larger := BigNum.CopyBN (a);
            smaller := BigNum.CopyBN (b);
        END (*IF*);

        (* Repeatedly replace (larger, smaller) by                      *)
        (* (smaller, larger MOD smaller) until the remainder is zero.   *)
        (* The quotient is never needed.                                *)

        WHILE NOT BigNum.IsZero (smaller) DO
            BigNum.Divide (larger, smaller, quotient, remainder);
            BigNum.Discard (quotient);
            BigNum.Discard (larger);
            larger := smaller;
            smaller := remainder;
        END (*WHILE*);

        (* What is left in 'smaller' is the final zero remainder. *)

        BigNum.Discard (smaller);
        (*WriteString ("Leaving gcd, ");  ShowBNUsage;*)
        RETURN larger;
    END gcd;

(************************************************************************)

PROCEDURE exgcd (a, m: BN): BN;

    (* The extended Euclid algorithm calculates x and y such that       *)
    (* ax+my = gcd(a,m).  This is a variant where we return x but don't *)
    (* care about the solution for y.                                   *)
    (* This is of interest mostly in the case where a and m are         *)
    (* relatively prime, because in that case x = inverse(a) mod m.     *)
    (* The returned x can be negative; see procedure Inverse for the    *)
    (* adjustment into the nonnegative range.                           *)
    (* The arguments are not modified.  The result is a new number      *)
    (* that the caller must eventually discard.                         *)

    VAR Q, oldR, R, newR, temp: BN;
        xold, x, xnew: BN;

    BEGIN
        (*WriteString ("Entering exgcd, ");  ShowBNUsage;*)

        (* Remainder sequence starts with (a, m); the matching          *)
        (* coefficient sequence for a starts with (1, 0).               *)

        oldR := BigNum.CopyBN(a);  R := BigNum.CopyBN(m);
        xold := BigNum.MakeBignum(1);  x := BigNum.MakeBignum(0);

        LOOP
            BigNum.Divide (oldR, R, Q, newR);

            (* A zero remainder means R is the gcd and x is its         *)
            (* coefficient.  Stop before computing the next             *)
            (* coefficient, which would only be thrown away.            *)

            IF BigNum.IsZero(newR) THEN
                BigNum.Discard (Q);
                EXIT(*LOOP*);
            END (*IF*);

            (* xnew = xold - Q*x *)

            temp := BigNum.Prod (Q, x);
            xnew := BigNum.Diff (xold, temp);
            BigNum.Discard (temp);
            BigNum.Discard (Q);

            (* Shift both sequences along by one step. *)

            BigNum.Discard (oldR);  oldR := R;  R := newR;
            BigNum.Discard (xold);  xold := x;  x := xnew;
        END (*LOOP*);

        (* Q has already been released on the exit path, and xnew (if   *)
        (* it was ever assigned) is now owned by x, so neither is       *)
        (* discarded again here.                                        *)

        BigNum.Discard (oldR);
        BigNum.Discard (R);  BigNum.Discard (newR);
        BigNum.Discard (xold);
        (*WriteString ("Leaving exgcd, ");  ShowBNUsage;*)
        RETURN x;
    END exgcd;

(************************************************************************)

PROCEDURE Inverse (a, m: BN): BN;

    (* Returns the multiplicative inverse of a, modulo m.               *)
    (* Assumption: a and m are relatively prime.                        *)
    (* The raw coefficient produced by exgcd may be negative; in that   *)
    (* case m is added once to it.                                      *)

    VAR coeff, shifted: BN;

    BEGIN
        coeff := exgcd (a, m);

        (* Nonnegative already: nothing more to do. *)

        IF BigNum.Sign(coeff) >= 0 THEN
            RETURN coeff;
        END (*IF*);

        (* Negative: replace it by coeff + m. *)

        shifted := BigNum.Sum (coeff, m);
        BigNum.Discard (coeff);
        RETURN shifted;
    END Inverse;

(************************************************************************)

PROCEDURE lcm (a, b: BN): BN;

    (* Least common multiple, calculated as (a*b) DIV gcd(a,b).         *)
    (* The arguments are not modified; the result is a new number.      *)

    VAR divisor, product, answer, leftover: BN;

    BEGIN
        IF BigNum.Cmp (a, b) # 0 THEN
            divisor := gcd (a, b);
            product := BigNum.Prod (a, b);
            BigNum.Divide (product, divisor, answer, leftover);

            (* Only the quotient is wanted; release everything else. *)

            BigNum.Discard (leftover);
            BigNum.Discard (divisor);
            BigNum.Discard (product);
        ELSE
            (* lcm(a,a) = a, so the arithmetic can be skipped. *)

            answer := BigNum.CopyBN (a);
        END (*IF*);
        RETURN answer;
    END lcm;

(************************************************************************)
(*                           RSA KEY GENERATION                         *)
(************************************************************************)

PROCEDURE MakeMask (nbits: CARDINAL): CARDINAL;

    (* Returns a word whose rightmost nbits bits are ones and whose     *)
    (* remaining bits are zeros.                                        *)
    (* We assume 0 < nbits <= 32.                                       *)

    VAR width, result: CARDINAL;

    BEGIN
        (* Start with all 32 bits set, then drop the top bit until      *)
        (* only nbits of them remain.                                   *)

        result := 0FFFFFFFFH;
        width := 32;
        LOOP
            IF width <= nbits THEN EXIT(*LOOP*) END(*IF*);
            result := result DIV 2;
            DEC (width);
        END (*LOOP*);
        RETURN result;
    END MakeMask;

(************************************************************************)

PROCEDURE RSA_Generate (keylength: CARDINAL): RSAKeyType;

    (* Generates a new RSA key.                                         *)
    (*                                                                  *)
    (* Parameter keylength is the desired key length in bits.           *)
    (* For security, it should really be at least 1024, and             *)
    (* preferably 4096 bits.  A value below 64 is rejected: an error    *)
    (* message is written and the program halts.                        *)
    (*                                                                  *)
    (* The returned record holds the modulus n = p*q, the public        *)
    (* exponent (fixed at 65537), the private exponent d, the two       *)
    (* primes p and q, and the values dp = d mod (p-1),                 *)
    (* dq = d mod (q-1), qinv = inverse(q) mod p.                       *)
    (*                                                                  *)
    (* Progress messages are written as the calculation proceeds.  If   *)
    (* 65537 turns out not to be coprime to lambda, a message is        *)
    (* written and the program halts.                                   *)

    VAR p, q, pm1, qm1, n, lambda, e, d, temp, Q: BN;
        bitsover, L, mask: CARDINAL;
        key: RSAKeyType;

    BEGIN
        IF keylength < 64 THEN
            WriteString ("ERROR: specified key size is too small.");  WriteLn;
            HALT;
        ELSE
            WriteString ("Desired key length is ");  WriteCard(keylength);
            WriteString (" bits");  WriteLn;
        END (*IF*);

        (* Split the length into whole 32-bit words plus leftover bits. *)

        bitsover := keylength MOD 32;
        keylength := keylength DIV 32;      (* now in words *)

        (* Choose two distinct prime numbers p and q.  Let n = p*q *)
        (* p takes half of the whole words.  q takes the remaining      *)
        (* words, plus one partial word masked down to bitsover bits    *)
        (* when the requested length is not a multiple of 32.           *)

        WriteString ("Looking for first prime");  WriteLn;
        p := RandomPrime(keylength DIV 2, 0FFFFFFFFH, TRUE);
        WriteString ("Looking for second prime");  WriteLn;
        L := keylength - keylength DIV 2;
        IF bitsover = 0 THEN
            mask := 0FFFFFFFFH;
        ELSE
            INC (L);
            mask := MakeMask (bitsover);
        END (*IF*);

        (* The two primes must be different.  If q = p then n would be  *)
        (* a perfect square and q would have no inverse modulo p, so    *)
        (* in that (very unlikely) case throw q away and draw again.    *)

        LOOP
            q := RandomPrime(L, mask, TRUE);
            IF BigNum.Cmp (p, q) = 0 THEN
                BigNum.Discard (q);
            ELSE
                EXIT (*LOOP*);
            END (*IF*);
        END (*LOOP*);

        (*
        IF TESTING THEN
            WriteString ("We have chosen two random primes,");  WriteLn;
            WriteString ("p = ");  BigNum.WriteBignum (p);  WriteLn;
            WriteString ("q = ");  BigNum.WriteBignum (q);  WriteLn;
        END (*IF*);
        *)
        n := BigNum.Prod (p, q);

        (* lambda = lcm(p-1, q-1) *)

        pm1 := BigNum.CopyBN(p);
        BigNum.Decr (pm1);
        qm1 := BigNum.CopyBN(q);
        BigNum.Decr (qm1);

        lambda := lcm (pm1, qm1);

        IF TESTING THEN
            (*WriteString ("lambda = ");  BigNum.WriteBignum (lambda);  WriteLn;*)
        END (*IF*);

        (* Choose e such that 1 < e < lambda that is coprime to lambda. *)
        (* Making e prime simplifies the coprime check.                 *)
        (* The check is done by testing whether gcd(e,lambda) - 1 = 0.  *)

        e := BigNum.MakeBignum (65537);
        temp := gcd (e, lambda);
        BigNum.Decr (temp);
        IF NOT BigNum.IsZero (temp) THEN
            WriteString ("Poorly chosen e. Try again.");  WriteLn;
            HALT;
        END (*IF*);
        BigNum.Discard (temp);

        (* Solve d*e mod lambda = 1 for d. *)
        (* NOTE: standards might require 0 < d < lambda, so might need  *)
        (* to reduce d modulo lambda here.  Procedure Inverse already   *)
        (* adds lambda once when its raw result is negative.            *)

        d := Inverse (e, lambda);
        BigNum.Discard (lambda);

        (* Fill in the result record.  The record takes over ownership  *)
        (* of n, e, d, p, and q, so those are not discarded here.       *)

        key.n := n;
        key.public := e;
        key.private := d;
        key.p := p;
        key.q := q;

        (* dp and dq are the remainders of d on division by p-1 and     *)
        (* q-1 respectively; the quotients are not needed.              *)

        BigNum.Divide (d, pm1, Q, key.dp);
        BigNum.Discard (Q);
        BigNum.Divide (d, qm1, Q, key.dq);
        BigNum.Discard (Q);

        key.qinv := Inverse (q, p);

        BigNum.Discard (pm1);
        BigNum.Discard (qm1);

        RETURN key;

    END RSA_Generate;

(************************************************************************)

END RSAGenerate.

