"""
We can create addition operators for group operations easily using the following
format:

    a @ b = Finv(F(a) + F(b))

This is a group operator if there is a 0 element, normally Finv(0).  The '+' can
be a '*' if desired, but it may be slower.

Associativity is easy to prove:

    (a @ b) @ c = Finv(F(Finv(F(a) + F(b))) + F(c)) = Finv(F(a) + F(b) + F(c))

It's obviously commutative (given that '+' is).  Be careful to prove there is an
identity element, and that an inverse exists for every element.

Try using this for something simple:

  let F(x) = 1/x for x != 0, and 0 for x = 0
  Finv(x)  = 1/x for x != 0, and 0 for x = 0
  Addition rule: 1/(1/a + 1/b); the result is b if a == 0, a if b == 0, and 0 if 1/a + 1/b == 0

We need a "0" element, which is why we say F(0) = 0.  Now check for inverse:

  inv(a) = Finv(-F(a)) = 1/(-1/a) = -a

We can define this on the open interval (-1, 1), which excludes 1, and -1, which
map to themselves, which would cause problems.

Make it faster for modular arithmetic:
We only have to do the modular inverse at the end of the computation if we track
the numerator and denominator separately.  Writing a = an/ad and b = bn/bd:

  1/(ad/an + bd/bn) = 1/((ad*bn + an*bd)/(an*bn)) = an*bn/(ad*bn + an*bd)
"""
import random

# Based on the recursive extended Euclid from Stack Exchange, unrolled into
# loops so the depth is not bounded by the interpreter's recursion limit.
def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns (g, x, y) with a*x + b*y == g, where g is the last nonzero value
    of Euclid's remainder sequence (b itself when a == 0).  For non-negative
    inputs g is gcd(a, b); for negative inputs it may carry a sign.

    The coefficients are exactly those of the classic recursive version
    egcd(a, b) = (g, t - (b // a)*s, s) where (g, s, t) = egcd(b % a, a):
    the quotients are recorded on the way down and folded back in reverse.
    """
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a
    g, x, y = b, 0, 1
    # Replay the recursion's unwinding, innermost quotient first.
    for q in reversed(quotients):
        x, y = y - q * x, x
    return (g, x, y)

def modinv(a, m):
    """Return the multiplicative inverse of a modulo m (reduced via x % m).

    Raises ValueError (a subclass of Exception) if gcd(a, m) != 1, i.e. no
    inverse exists; raises ZeroDivisionError if m == 0.
    """
    # Reduce a first: egcd on a negative `a` can report a gcd of -1
    # (e.g. a=-3, m=7), which the g != 1 test would wrongly reject.
    g, x, _ = egcd(a % m, m)
    if g != 1:
        raise ValueError('modular inverse for %d does not exist mod %d' % (a, m))
    return x % m

def toInt(an, ad, p):
    """Collapse the fraction an/ad into a single residue modulo p.

    This is where the deferred modular inverse is finally paid for; any error
    from modinv (denominator not invertible mod p) propagates to the caller.
    """
    inverse = modinv(ad, p)
    product = an * inverse
    return product % p

def add(an, ad, bn, bd, p):
    """Add two rational values mod p, without doing the inverse.

    The operands are a = an/ad and b = bn/bd, and the group law is the
    "parallel sum" a @ b = 1/(1/a + 1/b) with 0 as the identity.  Keeping
    numerator and denominator separate lets the caller postpone the modular
    inverse until the end (see toInt).

    Returns (rn, rd) where
        rn = an*bn           (mod p)
        rd = ad*bn + an*bd   (mod p)
    except that the other operand is returned unchanged when one operand is
    0 mod p, and (0, 1) is returned when 1/a + 1/b == 0 (mod p).

    Raises ValueError if either denominator is 0 mod p.
    """
    if ad % p == 0 or bd % p == 0:
        raise ValueError("Bad denominator")
    # 0 is the identity (F(0) = 0): 0 @ b = b and a @ 0 = a.  Compare modulo
    # p so an unreduced zero numerator (e.g. an == p) is recognised as well.
    if an % p == 0:
        return bn, bd
    if bn % p == 0:
        return an, ad
    rn = (an * bn) % p
    rd = (ad * bn + an * bd) % p
    if rd == 0:
        # 1/a + 1/b == 0, e.g. when a == -b; the sum is then Finv(0) = 0.
        return 0, 1
    return rn, rd


def double(an, ad, p):
    """Return a @ a for a = an/ad, as a (numerator, denominator) pair.

    F(a @ a) = 2*F(a) = 2*ad/an, so only the denominator changes: the result
    is an/(2*ad).  Reduction is a single conditional subtraction of p, which
    only fully reduces the denominator when 0 <= ad < p on input.
    """
    twice = ad << 1
    return an, (twice - p if twice >= p else twice)

def mul(m, an, ad, p):
    """Return m*a, the m-fold group sum a @ a @ ... @ a, as a (num, den) pair.

    a = an/ad mod p.  Uses double-and-add over the bits of m, so it makes
    O(log |m|) calls to add/double.  m == 0 yields the identity (0, 1).
    A negative m is computed as |m|*(-a), because -a is the inverse of a:
    a @ -a = Finv(1/a - 1/a) = Finv(0) = 0.
    """
    if m < 0:
        # Right-shifting a negative int converges to -1, never 0, so the loop
        # below would not terminate; switch to the inverse element instead.
        m, an = -m, (-an) % p
    rn, rd = 0, 1
    while m != 0:
        if m & 1:
            rn, rd = add(rn, rd, an, ad, p)
        m >>= 1
        an, ad = double(an, ad, p)
    return rn, rd

def printGroup(g, p):
    """Print the cycle visited by repeatedly scaling by the integer g.

    Starting from r = g, each step replaces r with the residue of mul(g, r)
    (the g-fold group sum of r; with F(x) = 1/x this is r/g mod p).  The walk
    stops when it returns to g.  Output is one line: "g => g r1 r2 ...".

    Raises ValueError if g is not in [0, p).
    """
    if not 0 <= g < p:
        # Every value produced below is reduced into [0, p), so an unreduced
        # g would never be seen again and the walk would never terminate.
        raise ValueError("g must be in [0, p), got %d" % g)
    cycle = [g]
    while True:
        rn, rd = mul(g, cycle[-1], 1, p)
        r = toInt(rn, rd, p)
        if r == g:
            break
        cycle.append(r)
    # A single print of one string behaves the same on Python 2 and 3.
    print(" ".join([str(g), "=>"] + [str(r) for r in cycle]))

p = 23

for i in range(1, p):
    printGroup(i, p)

# Toy Diffie-Hellman over the "parallel sum" group modulo 2**255 - 19.
# NOTE: by the group law above, m*x = Finv(m*F(x)) = x/m (mod p), so with
# generator 1 the public value A is simply 1/sa and anyone can recover
# sa = 1/A.  This only demonstrates that both sides agree; it is not secure.
p = pow(2, 255) - 19
gn, gd = 1, 1
# Floor division keeps the bounds integral on Python 3 (p/2 would be a float).
sa = random.randrange(p // 2, p)
sb = random.randrange(p // 2, p)
print("sa: %d" % sa)
print("sb: %d" % sb)
An, Ad = mul(sa, gn, gd, p)
A = toInt(An, Ad, p)
Bn, Bd = mul(sb, gn, gd, p)
B = toInt(Bn, Bd, p)
print("A: %d" % A)
print("B: %d" % B)
aliceMSn, aliceMSd = mul(sa, B, 1, p)
bobMSn, bobMSd = mul(sb, A, 1, p)
aliceMS = toInt(aliceMSn, aliceMSd, p)
bobMS = toInt(bobMSn, bobMSd, p)
print("Alice: %d" % aliceMS)
print("Bob: %d" % bobMS)
assert aliceMS == bobMS
