# Public domain. No warranty or anything.
import sys
from time import time
from Crypto.PublicKey import RSA
from Crypto.PublicKey.RSA import error
from random import randint, seed, random

# Module-level flag. It is not referenced by any function visible in this part
# of the file; presumably a default for hex-formatted input/output -- verify
# against callers before changing it.
HEX = False

def r(l):
    """Return *l* random bytes suitable for use as key-generation randomness.

    This function is passed to ``RSA.generate`` as its random-byte source, so
    the bytes must be unpredictable.  They are drawn from the operating
    system's CSPRNG rather than from the ``random`` module: seeding ``random``
    with the current time makes the output guessable, and
    ``randint(1, 0xFF)`` could never yield a zero byte.

    Parameters:
        l -- number of bytes to produce (0 yields an empty string).

    Returns:
        A byte string of length *l*.
    """
    # Imported locally so this block does not depend on edits to the
    # module-level import list.
    import os
    return os.urandom(l)

def sign(n,e,d,message):
    """Sign *message* with the RSA key (n, e, d).

    Parameters:
        n, e, d -- key components, each either a hexadecimal string or an
                   already-numeric value.
        message -- the data handed to ``key.sign``.

    Returns:
        A tuple ``(n, e, d, signature)`` where n, e and d are the (possibly
        converted) key components and *signature* is the signature rendered
        as space-separated two-digit hex bytes, e.g. ``'0x0a 0xbc'``.
    """
    def as_long(value):
        # Hex strings are parsed to a long.  Values that are not strings
        # (TypeError) or not valid hex (ValueError) are passed through
        # unchanged, which keeps the original best-effort behaviour without
        # swallowing unrelated exceptions such as KeyboardInterrupt.
        try:
            return long(value, 16)
        except (TypeError, ValueError):
            return value

    n = as_long(n)
    e = as_long(e)
    d = as_long(d)

    key = RSA.construct((n,e,d))
    signature = key.sign(message,0)[0]
    # '%x' never appends the Python 2 'L' suffix, so no digit is lost when the
    # value happens to be a plain int instead of a long.
    digits = '%x' % signature
    # Pad to an even number of digits so every chunk below is a whole byte;
    # otherwise 'abc' would be rendered as the bytes 0xab, 0x0c.
    if len(digits) % 2:
        digits = '0' + digits
    signature = ' '.join('0x' + digits[i:i+2]
                         for i in range(0, len(digits), 2))
    return n,e,d,signature

def encrypt(n,e,d,plaintext,algorithm='RSA'):
    """Encrypt *plaintext* with the public key (n, e).

    Returns the tuple ``(n, e, d, cryptogram)``; the cryptogram is the
    ciphertext rendered by ``cgramtoinput``.  *d* is only passed through to
    the result and *algorithm* is not used by this function.
    """
    public_key = RSA.construct((n, e))
    # The second argument to encrypt() is a fresh random() value.
    raw = public_key.encrypt(plaintext, random())[0]
    return (n, e, d, cgramtoinput(raw))

def generate_key(n,e,d,passthrough,algorithm='RSA'):
    """Generate a fresh 1024-bit RSA key using ``r`` as the random source.

    The incoming *n*, *e* and *d* are ignored and replaced by the new key's
    components; *algorithm* is not used.  Returns
    ``(n, e, d, passthrough)`` with *passthrough* handed back untouched.
    """
    fresh = RSA.generate(1024, r)
    return fresh.n, fresh.e, fresh.d, passthrough

def decrypt(n,e,d,cryptogram,hexa,algorithm='RSA'):
    """Decrypt *cryptogram* with the RSA key (n, e, d).

    Parameters:
        n, e, d    -- key components passed to ``RSA.construct``.
        cryptogram -- the ciphertext.  When *hexa* is false it is first
                      converted with ``inputtocgram``; when *hexa* is true it
                      is handed to the key unchanged.
        hexa       -- when true, the decrypted result is rendered with
                      ``cgramtoinput`` before being returned.
        algorithm  -- not used by this function.

    Returns:
        The tuple ``(n, e, d, plaintext)``.

    The key material and ciphertext are deliberately not printed: earlier
    debugging output wrote the private exponent *d* to stdout.
    """
    if not hexa:
        cryptogram = inputtocgram(cryptogram)
    theKey = RSA.construct((n,e,d))
    plaintext = theKey.decrypt(cryptogram)
    if hexa:
        plaintext = cgramtoinput(plaintext)
    return (n,e,d,plaintext)

def inputtocgram(plaintext):
    """Convert hex text into the string of characters it encodes.

    Accepted forms:
      * several whitespace-separated hex values (``'0x41 0x42'`` -> ``'AB'``);
      * a single ``0x``-prefixed value (``'0x41'`` -> ``'A'``), which is what
        ``cgramtoinput`` produces for a one-character input;
      * any other text without internal whitespace, where every character is
        read as one hex digit (``'12'`` -> ``'\\x01\\x02'``).

    Raises:
        ValueError -- if a piece is not valid hexadecimal.
    """
    tokens = plaintext.split()
    if len(tokens) > 1:
        pieces = tokens
    elif len(tokens) == 1 and tokens[0][:2].lower() == '0x':
        # A lone prefixed byte used to be walked character by character and
        # failed on the 'x'; treat it as the single value it is.
        pieces = tokens
    else:
        pieces = plaintext
    return ''.join(chr(int(piece, 16)) for piece in pieces)

def cgramtoinput(cryptogram):
    """Render each character of *cryptogram* as ``hex(ord(ch))``, joined by spaces."""
    codes = [hex(ord(ch)) for ch in cryptogram]
    return ' '.join(codes)

def prettify(text,width=80):
    """Break *text* into lines of at most *width* characters.

    The chunks are joined with ``'\\n'``; no newline is added after the last
    chunk, whether or not the text length is an exact multiple of *width*.
    Text no longer than *width* is returned unchanged.

    Parameters:
        text  -- the string to wrap.
        width -- maximum line length; must be positive for non-empty text.

    Raises:
        ValueError -- if *text* is non-empty and *width* is not positive.
    """
    if not text:
        return text
    if width <= 0:
        raise ValueError('width must be a positive integer')
    # Slicing in steps of `width` covers the final partial chunk as well, so
    # no separate remainder handling is needed.
    return '\n'.join(text[start:start + width]
                     for start in range(0, len(text), width))

def unprettify(text):
    """Return *text* with every run of whitespace (spaces, newlines, tabs) removed."""
    pieces = text.split()
    return ''.join(pieces)


def unhexify(text):
    """Turn whitespace-separated hex numbers in *text* into the characters they encode.

    Each token is parsed as a hexadecimal integer and emitted big-endian
    (most significant byte first), using at least one character per token.
    Tokens are emitted in the order they appear, so ``'0x4142 0x43'`` gives
    ``'ABC'``.  Leading zero digits do not add characters: ``'0x0041'`` is
    the single character ``'A'``.

    Raises:
        ValueError -- if a token is not valid hexadecimal or is negative.
    """
    chars = []
    for token in text.split():
        value = int(token, 16)
        if value < 0:
            # A negative value would never shift down to zero below.
            raise ValueError('negative hex value: %r' % token)
        # Collect this token's bytes least-significant first; the first byte
        # is taken unconditionally so a zero token still yields '\x00'.
        token_chars = [chr(value & 0xFF)]
        value >>= 8
        while value:
            token_chars.append(chr(value & 0xFF))
            value >>= 8
        # Reverse per token (not over the whole output) so the order of the
        # tokens themselves is preserved.
        token_chars.reverse()
        chars.extend(token_chars)
    return ''.join(chars)
