#
# OpenID.py
#    an OpenID (http://www.danga.com/openid) server library for python
#    (now at http://openid.net)
#
#
#
#  Usage:  
#
#    s = OpenID.Server(privkey, pubkey, current_user, is_user, is_trusted, 'http://setup/url')
#    if 'openid.mode' in cgi_get_vars():
#        type, data = s.handle_page(cgi_get_vars())
#        if type == 'setup':
#            # 'data' is a dict of arguments for your own setup (login/trust) page
#        elif type == 'redirect':
#            # redirect to the url in 'data'
#        else:
#            print "Content-type: %s\n\n%s\n" % (type, data)
#


class Server(object):
    """Minimal OpenID identity server.

    The server itself is framework-agnostic: the caller feeds the request's
    query arguments to :meth:`handle_page` and acts on the ``(type, data)``
    pair that comes back.
    """

    def __init__(self, private_key, public_key, get_user_sub, is_identity_sub, is_trusted_sub, setup_url, setup_map=None):
        """Store the keys, the application callbacks and the setup URL.

        private_key     -- handed unchanged to ``_dsaSign`` when signing.
        public_key      -- returned verbatim for ``openid.mode=getpubkey``.
        get_user_sub    -- called with no arguments; returns the current user.
        is_identity_sub -- called as ``(user, identity)``; truthy when the
                           user owns the requested identity.
        is_trusted_sub  -- called as ``(user, trust_root, is_identity)``;
                           truthy when the user trusts that root.
        setup_url       -- URL of the application's setup page.
        setup_map       -- optional dict renaming the setup argument names
                           ('trust_root', 'return_to', 'post_grant',
                           'is_identity'); missing names keep their default.
        """
        self.private_key = private_key
        self.public_key = public_key
        self.get_user_sub = get_user_sub
        self.is_identity_sub = is_identity_sub
        self.is_trusted_sub = is_trusted_sub
        self.setup_url = setup_url
        self.setup_map = {
            'trust_root': 'trust_root',
            'return_to': 'return_to',
            'post_grant': 'post_grant',
            'is_identity': 'is_identity',
        }
        if setup_map is not None:
            # Merge rather than replace so a partial map cannot cause a
            # KeyError later in handle_page.
            self.setup_map.update(setup_map)

    def handle_page(self, args):
        """Handle one OpenID request.

        ``args`` is a dict-like object holding the request's query
        arguments.  Returns a ``(type, data)`` tuple where ``type`` is:

        'text/plain' -- ``data`` is a response body (the public key, or an
                        ``Error: ...`` message).
        'redirect'   -- ``data`` is the URL to send the user agent to.
        'setup'      -- ``data`` is a dict of arguments (named according to
                        ``setup_map``) for the application's setup page.
        """
        mode = args.get('openid.mode')
        if mode == 'getpubkey':
            return ('text/plain', self.public_key)

        if mode not in ('checkid_immediate', 'checkid_setup'):
            return self._fail('unknown mode')

        return_to = args.get('openid.return_to', '')
        if not return_to.startswith(('http://', 'https://')):
            return self._fail('no_return_to')

        trust_root = args.get('openid.trust_root', return_to)
        if not self._url_is_under(trust_root, return_to):
            return self._fail('invalid_trust_root')

        # Drop any query string.  split() leaves the value untouched when
        # there is no '?', whereas slicing on find()'s -1 would silently
        # chop off the last character.
        trust_root = trust_root.split('?', 1)[0]

        identity = args.get('openid.is_identity')
        user = self.get_user_sub()
        is_identity = self.is_identity_sub(user, identity)
        is_trusted = self.is_trusted_sub(user, trust_root, is_identity)

        if is_identity and is_trusted:
            now = self._w3c_time()
            plain = '::'.join([now, 'assert_identity', identity, return_to])
            sig64 = self._b64_encode(self._dsaSign(plain, self.private_key))
            return_url = self._set_getarg(return_to, {
                'openid.mode': 'id_res',
                'openid.assert_identity': identity,
                'openid.sig': sig64,
                'openid.timestamp': now,
                'openid.return_to': return_to,
            })
            return ('redirect', return_url)

        # assertion could not be made, so user requires setup (login/trust.. something)
        # two ways that can happen:  caller might have asked us for an immediate return
        # with a setup URL (the default), or explicitly said that we're in control of
        # the user-agent's full window, and we can do whatever we want with them now.
        newargs = {
            self.setup_map['trust_root']: trust_root,
            self.setup_map['return_to']: return_to,
            self.setup_map['is_identity']: identity,
        }

        if mode == 'checkid_setup':
            return ('setup', newargs)

        # checkid_immediate: go straight back to the caller and tell it
        # where the user can complete setup.
        user_setup_url = self._set_getarg(self.setup_url, newargs)
        return_url = self._set_getarg(return_to, {
            'openid.mode': 'id_res',
            'openid.user_setup_url': user_setup_url,
        })
        return ('redirect', return_url)

    def _fail(self, error):
        """Return a plain-text error response for ``error``."""
        return ('text/plain', 'Error: %s\n' % error)

    def _url_is_under(self, root, test):
        """Return True when URL ``test`` lies under trust root ``root``.

        Both must be http(s) URLs with the same scheme, host and port, and
        the path of ``test`` must equal the root's path or sit below it on
        a '/' boundary.  Wildcard hosts are not supported.
        """
        try:
            from urllib.parse import urlsplit
        except ImportError:  # Python 2
            from urlparse import urlsplit

        default_ports = {'http': 80, 'https': 443}
        try:
            root_parts = urlsplit(root)
            test_parts = urlsplit(test)
            scheme = root_parts.scheme.lower()
            if scheme not in default_ports or scheme != test_parts.scheme.lower():
                return False
            root_host = root_parts.hostname or ''
            test_host = test_parts.hostname or ''
            root_port = root_parts.port or default_ports[scheme]
            test_port = test_parts.port or default_ports[scheme]
        except (AttributeError, TypeError, ValueError):
            # Not a string, or carries an unparsable port.
            return False

        if not root_host or root_host != test_host or root_port != test_port:
            return False

        root_path = root_parts.path or '/'
        test_path = test_parts.path or '/'
        # A '..' segment could climb back out of the trusted subtree.
        if '..' in test_path.split('/'):
            return False
        if test_path == root_path:
            return True
        if not root_path.endswith('/'):
            root_path += '/'
        return test_path.startswith(root_path)

    def _w3c_time(self):
        """Return the current UTC time in W3C format.

        The value ends in "Z" to mark UTC, e.g. "2005-05-15T17:11:51Z".
        """
        import time
        return time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())

    def _b64_encode(self, raw):
        """Return the base64 encoding of ``raw`` as a native string.

        Text input is encoded as UTF-8 before being base64-encoded.
        """
        import base64
        if not isinstance(raw, (bytes, bytearray)):
            raw = raw.encode('utf-8')
        encoded = base64.b64encode(raw)
        if not isinstance(encoded, str):
            encoded = encoded.decode('ascii')
        return encoded

    @staticmethod
    def _set_getarg(url, paramdict):
        """Return ``url`` with the query arguments in ``paramdict`` set.

        Arguments already present in the URL under the same name are
        replaced; other existing arguments keep their order.  New
        arguments are appended sorted by name so the result is
        deterministic.  Entries whose value is None are skipped.
        """
        try:
            from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
        except ImportError:  # Python 2
            from urllib import urlencode
            from urlparse import parse_qsl, urlsplit, urlunsplit

        updates = dict((name, value) for name, value in paramdict.items()
                       if value is not None)
        parts = urlsplit(url)
        pairs = [(name, value)
                 for name, value in parse_qsl(parts.query, keep_blank_values=True)
                 if name not in updates]
        pairs.extend(sorted(updates.items()))
        return urlunsplit((parts.scheme, parts.netloc, parts.path,
                           urlencode(pairs), parts.fragment))

    def _dsaSign(self, plaintext, key):
        """Sign ``plaintext`` with ``key``; must be supplied by a subclass.

        No signing code ships with this class, so fail loudly instead of
        handing back None and letting an unsigned assertion go out.
        """
        raise NotImplementedError(
            'DSA signing is not implemented; override Server._dsaSign')

    # from here down taken from http://adam.bregenzer.net/python/typekey/TypeKey.py

    def _dsaVerify(self, message, sig, key):
        """Verify a DSA signature.

        ``sig`` is "r:s" with both halves base64-encoded; ``key`` is a
        mapping with the integers 'p', 'q', 'g' and 'pub_key'.  The
        message is hashed with SHA-1.  Returns True when the signature
        matches, False otherwise (including for malformed signatures).
        """
        import base64
        import binascii
        import hashlib

        if not isinstance(message, bytes):
            message = message.encode('utf-8')
        if isinstance(sig, bytes) and not isinstance(sig, str):
            sig = sig.decode('ascii', 'replace')

        halves = sig.split(':')
        if len(halves) != 2:
            return False
        try:
            r_sig, s_sig = [
                int(binascii.hexlify(base64.b64decode(half)), 16)
                for half in halves
            ]
        except (TypeError, ValueError):
            return False

        q = key['q']
        # Out-of-range values must be rejected: with s == 0 (mod q) there is
        # no inverse and r == 1 would otherwise verify for any message.
        if not (0 < r_sig < q and 0 < s_sig < q):
            return False

        hash_m = int(hashlib.sha1(message).hexdigest(), 16)
        w = self._invert(s_sig, q)
        if w is False:
            return False

        u1 = (hash_m * w) % q
        u2 = (r_sig * w) % q

        v = ((self._powerModulus(key['g'], u1, key['p']) *
              self._powerModulus(key['pub_key'], u2, key['p'])
              ) % key['p']
             ) % q

        return v == r_sig

    def _invert(self, x, y):
        """Return the inverse of x modulo y, or False when none exists."""
        if y <= 0:
            return False
        coefficient, _, divisor = self._gcd(x % y, y)
        if divisor != 1:
            return False
        return coefficient % y

    def _gcd(self, x, y):
        """Run the extended Euclidean algorithm on x and y.

        Returns ``(a, b, g)`` where ``g`` is the greatest common divisor
        and ``a * x + b * y == g``.
        """
        (a, a_last) = (1, 0)
        (b, b_last) = (0, 1)

        while y > 0:
            q = x // y  # floor division keeps everything an integer
            (x, y) = (y, x % y)
            (a, a_last) = (a_last, a - (q * a_last))
            (b, b_last) = (b_last, b - (q * b_last))

        return (a, b, x)

    def _powerModulus(self, base, exp, mod):
        """Return the result of (base ** exp) % mod."""
        # Built-in modular pow also copes with exp == 0.
        return pow(base, exp, mod)

