from math import *
from fractions import *
import numpy as np
import matplotlib.pyplot as plt

def findBsq(m):
    """Return b^2 for the Jacobi parameter m, using b^2 = 1/(1 - m).

    b^2 is the constant in the identity dn^2 = cn^2 + sn^2/b^2 that
    Point.verifyPoint checks. This choice agrees with the standard identity
    dn^2 = 1 - m*sn^2: substituting sn^2 + cn^2 = 1 gives
    dn^2 = cn^2 + (1 - m)*sn^2 = cn^2 + sn^2/b^2 exactly when
    b^2 = 1/(1 - m). For example, m = 3/4 gives b^2 = 4.

    An alternative, b^2 = 1 - m, was previously noted as matching Wolfram
    Alpha. NOTE(review): that presumably reflects a different parameter
    convention; confirm before switching.

    Raises ZeroDivisionError when m == 1.
    """
    denominator = 1 - m
    return 1/denominator

class Point:
    """Values (cn, sn, dn) at some argument u for elliptic parameter m.

    These are intended to be the Jacobi elliptic function values cn(u),
    sn(u), dn(u) that ecAdd combines with the addition theorem. The
    constructor validates the point via verifyPoint().
    """

    # Absolute tolerance applied to both identities in verifyPoint.
    TOLERANCE = 1.0e-3

    def __init__(self, cn, sn, dn, m):
        self.cn = cn
        self.sn = sn
        self.dn = dn
        self.m = m
        self.verifyPoint()

    def __repr__(self):
        # Without this, error messages only show an object address.
        return "Point(cn=%r, sn=%r, dn=%r, m=%r)" % (
            self.cn, self.sn, self.dn, self.m)

    def verifyPoint(self):
        """Check the point's identities within TOLERANCE.

        The two identities are:
            cn^2 + sn^2 == 1
            cn^2 + sn^2/b^2 == dn^2, with b^2 = findBsq(m)

        Raises:
            ValueError: if either identity is violated. This is a subclass
                of Exception, so existing `except Exception` handlers still
                catch it.
        """
        bsq = findBsq(self.m)
        if (abs(self.cn**2 + self.sn**2 - 1) > self.TOLERANCE or
                abs(self.cn**2 + self.sn**2/bsq - self.dn**2) > self.TOLERANCE):
            raise ValueError("Invalid point %s" % self)

def ecAdd(a, b):
    """Combine two points with the Jacobi elliptic addition theorem.

    If a holds (cn, sn, dn) at argument u and b holds them at v, both for
    the same parameter m, the result holds the values at u + v:

        D       = 1 - m*sn_u^2*sn_v^2
        cn(u+v) = (cn_u*cn_v - sn_u*sn_v*dn_u*dn_v) / D
        sn(u+v) = (sn_u*cn_v*dn_v + sn_v*cn_u*dn_u) / D
        dn(u+v) = (dn_u*dn_v - m*sn_u*sn_v*cn_u*cn_v) / D

    Returns a new Point, which validates itself on construction.

    Raises:
        Exception: if a.m and b.m differ.
    """
    cu, su, du = a.cn, a.sn, a.dn
    cv, sv, dv = b.cn, b.sn, b.dn
    if a.m != b.m:
        raise Exception("Points are on different curves")
    k = a.m
    denom = 1 - k*su**2*sv**2
    return Point((cu*cv - su*sv*du*dv) / denom,
                 (su*cv*dv + sv*cu*du) / denom,
                 (du*dv - k*su*sv*cu*cv) / denom,
                 k)

def ecAddEdwards(a, b):
    """Add two points with the Edwards-curve addition law.

    Each point is mapped onto the Edwards curve

        x^2 + y^2 = 1 + d*x^2*y^2,   with d = m

    by taking x = sn and solving the curve equation for y:

        y^2 * (1 - d*x^2) = 1 - x^2
        y = +/- sqrt((1 - x^2) / (1 - d*x^2))

    The sign of y is taken from cn. For a valid point with dn > 0 this
    makes y = cn/dn. The two points are then added with

        x3 = (x1*y2 + x2*y1) / (1 + d*x1*x2*y1*y2)
        y3 = (y1*y2 - x1*x2) / (1 - d*x1*x2*y1*y2)

    Args:
        a, b: points on the same curve; only .cn, .sn and .m are read.

    Returns:
        The tuple (x3, y3). Unlike ecAdd, no Point is constructed.

    Raises:
        Exception: if a.m and b.m differ (the same check ecAdd performs).
        ValueError: from math.sqrt if a radicand is negative.
    """
    if a.m != b.m:
        raise Exception("Points are on different curves")
    d = a.m
    x1 = a.sn
    x2 = b.sn
    y1 = sqrt((1 - x1**2)/(1 - d*x1**2))
    y2 = sqrt((1 - x2**2)/(1 - d*x2**2))
    # sqrt discards the sign; restore it from the sign of cn.
    if a.cn < 0:
        y1 = -y1
    if b.cn < 0:
        y2 = -y2
    # Shared term of both denominators (same evaluation order as before).
    t = d*x1*x2*y1*y2
    x3 = (x1*y2 + x2*y1)/(1 + t)
    y3 = (y1*y2 - x1*x2)/(1 - t)
    return x3, y3

# Parameter m of the elliptic functions. With findBsq's convention
# b^2 = 1/(1 - m) (Wikipedia's; Wolfram's parameter definition differs),
# m = 3/4 gives b = 2.
m = 3.0/4.0
print("Bsq %s" % findBsq(m))

# Generator point: choose sn, then derive cn and dn from the identities
# cn^2 + sn^2 = 1 and dn^2 = cn^2 + sn^2/b^2 that Point.verifyPoint checks.
# (An alternative dn = 1/sqrt(...) had been tried; it does not satisfy
# that check for sn != 0.)
sn = 0.3
cn = sqrt(1 - sn**2)
dn = sqrt(cn**2 + sn**2/findBsq(m))
g = Point(cn, sn, dn, m)
b = sqrt(findBsq(m))

xlistEllipse = []
ylistEllipse = []
xlistEdwards = []
ylistEdwards = []
# Start at cn = 1, sn = 0, dn = 1, i.e. the values at argument 0.
p = Point(1, 0, 1, m)
for i in range(50):
    # (cn, b*sn) lies on the ellipse x^2 + y^2/b^2 = 1.
    xlistEllipse.append(p.cn)
    ylistEllipse.append(p.sn*b)
    # Edwards sum of p and g, compared below with the Jacobi sum of the
    # same two points.
    (x, y) = ecAddEdwards(p, g)
    xlistEdwards.append(x)
    ylistEdwards.append(y)
    p = ecAdd(p, g)
    print("%d: (x, y): (%f, %f)  (sn, cn): (%f, %f)" % (i, x, y, p.sn, p.cn))
    print("x - sn = %f" % (x - p.sn))

# Equal axis scaling shows each curve with its true shape, without adding
# artificial corner points to the data.
plt.scatter(xlistEllipse, ylistEllipse)
plt.axis("equal")
plt.show()
plt.scatter(xlistEdwards, ylistEdwards)
plt.axis("equal")
plt.show()
