import sys
from Crypto import Random
import binascii

def int2bytes(i, length):
    """Return integer *i* as a big-endian byte string of at least *length* bytes.

    The value is rendered in hexadecimal and left-padded with zeros up to
    ``2 * length`` hex digits.  Values that need more than *length* bytes are
    not truncated; the result is simply longer.
    """
    digits = format(i, 'x')
    padded = digits.zfill(2 * length)
    return binascii.unhexlify(padded)

def bytes2int(str):
    """Interpret a byte string as a big-endian unsigned integer.

    Uses ``binascii.hexlify`` instead of the Python-2-only ``'hex'`` codec so
    the function works on byte strings under both Python 2 and Python 3.
    Text (unicode) input is encoded as UTF-8 first.

    Raises:
        ValueError: if the input is empty (there are no hex digits to parse).
    """
    # The parameter name shadows the builtin; it is kept for interface
    # compatibility and aliased immediately.
    raw = str
    if isinstance(raw, type(u'')):
        raw = raw.encode('utf-8')
    return int(binascii.hexlify(raw), 16)

def E(key, data):
    """Transform a 512-bit integer block *data* under a 256-bit integer *key*.

    The block is split into two 256-bit halves that are alternately XORed
    with ``H(key, other_half)``: four (low, high) step pairs followed by one
    final low step.  Every step is its own inverse and the step sequence
    reads the same forwards and backwards, so calling E twice with the same
    key yields the original block.

    Raises:
        Exception: if *key* is not below 2**256 or *data* is not below 2**512.
    """
    if not (key < (1 << 256) and data < (1 << 512)):
        raise Exception("E input data too large")
    half_mask = (1 << 256) - 1
    low = data & half_mask
    high = data >> 256
    pairs_done = 0
    while pairs_done < 4:
        low = low ^ H(key, high)
        high = high ^ H(key, low)
        pairs_done += 1
    # Final unpaired step on the low half.
    low = low ^ H(key, high)
    return low | (high << 256)

def H(key, data):
    """Mix *key* and *data* and return the low 256 bits of the result.

    The key is placed above bit 256, OR-ed with the data, and the combined
    value is passed through 64 applications of ``doRound``.
    """
    ROUNDS = 64
    mixed = (key << 256) | data
    for _ in range(ROUNDS):
        mixed = doRound(mixed)
    low_mask = (1 << 256) - 1
    return low_mask & mixed

def doRound(state):
    """Apply one mixing round to *state* and return the new state.

    The state is split at bit 256.  The low part is multiplied (modulo
    2**256) by the high part with its lowest bit forced to 1, the high part
    is XORed with that product, and both parts are passed through
    ``rotateBits`` before being recombined.

    Note: the high part is everything above bit 256 and is not masked to
    256 bits here.
    """
    mask = (1 << 256) - 1
    upper = state >> 256
    lower = state & mask
    # "| 1" forces the multiplier to be odd.
    lower = (lower * (upper | 1)) & mask
    upper = upper ^ lower
    rotated_lower = rotateBits(lower, 39)
    rotated_upper = rotateBits(upper, 209)
    return rotated_lower | (rotated_upper << 256)

def rotateBits(data, dist):
    """Combine ``data >> dist`` with the low bits of *data* moved upwards.

    NOTE(review): despite the name this is not a true 256-bit rotation.  The
    mask keeps the low ``256 - dist`` bits (not the low ``dist`` bits) before
    shifting them left by ``256 - dist``, so the result can be wider than
    256 bits.  The cipher's output depends on this exact arithmetic, so it
    must not be "corrected" without breaking existing encrypted files.
    """
    span = 256 - dist
    shifted_down = data >> dist
    low_mask = (1 << span) - 1
    moved_up = (data & low_mask) << span
    return shifted_down | moved_up

def blocksFromFile(f):
    """Yield the contents of *f* as integers, one per 64-byte (512-bit) chunk.

    *f* only needs a ``read(n)`` method; it is presumably a file opened in
    binary mode.  Each chunk is converted with ``bytes2int``.  The final
    chunk may be shorter than 64 bytes, in which case it is converted as-is
    (no padding is added here).  Iteration stops at the first empty read.
    """
    # Floor division keeps the size an int on Python 3, where 512/8 is the
    # float 64.0 and read() rejects it.
    block_size = 512 // 8
    while True:
        chunk = f.read(block_size)
        if not chunk:
            break
        yield bytes2int(chunk)

def testE():
    """Run E twice on a fixed 512-bit block and print both results in hex.

    The second call feeds the first call's output back into E with the same
    key; its output is printed as the "decrypted" data.
    """
    data = 0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef
    key = 0xaaaabbbbccccddddeeee
    encData = E(key, data)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print("Encrypted data %x" % encData)
    decData = E(key, encData)
    print("Decrypted data %x" % decData)

def encryptFile(filename, key):
    """Encrypt or decrypt *filename* with integer *key*, chosen by extension.

    A name ending in ".enc" is decrypted into the same name without that
    suffix; any other name is encrypted into ``filename + ".enc"``.

    Encrypted files start with a 32-byte random nonce.  The working key is
    ``key ^ nonce`` and is incremented (modulo 2**256) after every 64-byte
    block; each block is passed through ``E`` and written as 64 bytes.

    NOTE: a final input chunk shorter than 64 bytes is still written as a
    full 64-byte block, and decryption does not strip the resulting leading
    zero bytes.

    Raises:
        ValueError: when decrypting a file too short to hold the nonce.
        Exception: from ``E`` if the working key is not below 2**256.
    """
    encrypting = not filename.endswith(".enc")
    if encrypting:
        outFilename = filename + ".enc"
    else:
        outFilename = filename[:-4]

    with open(filename, "rb") as infile:
        if encrypting:
            nonce_bytes = Random.new().read(32)
        else:
            nonce_bytes = infile.read(32)
            # Check the header before creating the output file, so a
            # truncated input does not leave an empty output behind or run
            # with a short nonce.
            if len(nonce_bytes) != 32:
                raise ValueError(
                    "Encrypted file is too short to contain the 32-byte nonce")
        nonce = bytes2int(nonce_bytes)

        with open(outFilename, "wb") as outfile:
            if encrypting:
                # The raw nonce is exactly 32 bytes, so it is stored as-is.
                outfile.write(nonce_bytes)
            key ^= nonce
            for data in blocksFromFile(infile):
                encData = E(key, data)
                outfile.write(int2bytes(encData, 64))
                key = (key + 1) & ((1 << 256) - 1)

# Command-line entry point: "python nucrypt.py password filename".
# The guard keeps an import of this module from running the CLI or exiting.
if __name__ == "__main__":
    if len(sys.argv) != 3:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("Usage: python nucrypt.py password filename")
        sys.exit(1)
    # The password bytes are interpreted directly as the integer key.
    key = bytes2int(sys.argv[1])
    filename = sys.argv[2]
    encryptFile(filename, key)
