Example #1
2
def encode(m, embits, hash_class=hashlib.sha1,
        mgf=mgf.mgf1, salt=None, s_len=None, random=random.SystemRandom):
    '''Encode a message using the PKCS#1 v2 PSS padding (EMSA-PSS-ENCODE).

       m - the message to encode
       embits - the length in bits of the padded message
       hash_class - the hash algorithm used to compute the digest of the
       message, must conform to the hashlib class interface
       mgf - the mask generation function, called as mgf(seed, length)
       salt - a fixed salt string to use; if None, a random string of length
       s_len is generated instead (useful for tests)
       s_len - the length of the generated salt string, if None the length
       of the digest is used; 0 produces an empty salt
       random - a random generator class; it is instantiated and its
       getrandbits() method is used to generate the salt

       Raises exceptions.EncodingError when embits is too small to hold the
       digest, the salt and the padding.

       Return value: the padded message
    '''
    m_hash = hash_class(m).digest()
    h_len = len(m_hash)
    if salt is not None:
        s_len = len(salt)
    else:
        if s_len is None:
            s_len = h_len
        if s_len == 0:
            # An empty salt is valid PSS; getrandbits(0) raises ValueError
            # on Python 2 and on Python 3 before 3.9, so skip the generator.
            salt = ''
        else:
            salt = primitives.i2osp(random().getrandbits(s_len * 8), s_len)
    em_len = primitives.integer_ceil(embits, 8)
    if em_len < h_len + s_len + 2:
        raise exceptions.EncodingError
    m_prime = ('\x00' * 8) + m_hash + salt
    h = hash_class(m_prime).digest()
    ps = '\x00' * (em_len - s_len - h_len - 2)
    db = ps + '\x01' + salt
    # NOTE(review): hash_class is not forwarded to mgf, so the mask uses
    # mgf's own default hash; a verifier that forwards hash_class only
    # agrees when both hashes are the same.
    db_mask = mgf(h, em_len - h_len - 1)
    masked_db = primitives.string_xor(db, db_mask)
    # Clear the leftmost 8 * em_len - embits bits of masked_db:
    # `octets` whole bytes, then the top `bits` bits of the next byte.
    octets, bits = divmod(8 * em_len - embits, 8)
    masked_db = ('\x00' * octets) + masked_db[octets:]
    new_byte = chr(ord(masked_db[octets]) & (255 >> bits))
    masked_db = masked_db[:octets] + new_byte + masked_db[octets + 1:]
    return masked_db + h + '\xbc'
Example #2
0
def decrypt(private_key,
            message,
            label='',
            hash_class=hashlib.sha1,
            mgf=mgf.mgf1):
    '''Decrypt a byte message using a RSA private key and the OAEP wrapping
       algorithm (RSAES-OAEP-DECRYPT).

       Parameters:
       private_key - an RSA private key
       message - a byte string, the ciphertext
       label - a label as per the PKCS#1 standard
       hash_class - a Python class for a message digest algorithm respecting
         the hashlib interface
       mgf - a mask generation function, called with a hash_class keyword

       Raises ValueError('decryption error') when the ciphertext length is
       wrong or the first octet of the encoded message is not zero, and
       exceptions.DecryptionError for a bad padding or label.

       Return value:
       the string before encryption (decrypted)
    '''
    hasher = hash_class()
    h_len = hasher.digest_size
    k = private_key.byte_size
    # 1. check length
    if len(message) != k or k < 2 * h_len + 2:
        raise ValueError('decryption error')
    # 2. RSA decryption
    c = primitives.os2ip(message)
    m = private_key.rsadp(c)
    em = primitives.i2osp(m, k)
    # 3. EME-OAEP decoding
    hasher.update(label)
    label_hash = hasher.digest()
    y, masked_seed, masked_db = em[0], em[1:h_len + 1], em[1 + h_len:]
    # NOTE(review): RFC 8017 7.1.2 warns that this failure must not be
    # distinguishable from the ones below (Manger's attack). It keeps its
    # ValueError type only because existing callers may catch it.
    if y != '\x00':
        raise ValueError('decryption error')
    seed_mask = mgf(masked_db, h_len, hash_class=hash_class)
    seed = primitives.string_xor(masked_seed, seed_mask)
    db_mask = mgf(seed, k - h_len - 1, hash_class=hash_class)
    db = primitives.string_xor(masked_db, db_mask)
    label_hash_prime, rest = db[:h_len], db[h_len:]
    i = rest.find('\x01')
    # Evaluate all remaining checks before deciding, so that every failure
    # goes through the same raise; the label digest is compared in
    # constant time.
    label_ok = primitives.constant_time_cmp(label_hash_prime, label_hash)
    separator_ok = i != -1
    padding_ok = separator_ok and rest[:i].strip('\x00') == ''
    if not (label_ok and separator_ok and padding_ok):
        raise exceptions.DecryptionError
    return rest[i + 1:]
Example #3
0
def encrypt(public_key,
            message,
            label='',
            hash_class=hashlib.sha1,
            mgf=mgf.mgf1,
            seed=None,
            rnd=default_crypto_random):
    '''Encrypt a byte message using a RSA public key and the OAEP wrapping
       algorithm (RSAES-OAEP-ENCRYPT).

       Parameters:
       public_key - an RSA public key
       message - a byte string
       label - a label as per the PKCS#1 standard
       hash_class - a Python class for a message digest algorithm respecting
         the hashlib interface
       mgf - a mask generation function, called with a hash_class keyword
       seed - a seed to use instead of generating it using a random
         generator; it must be exactly as long as the digest
       rnd - a random generator, respecting the random generator interface
         from the random module; if seed is None (or empty), it is used to
         generate it.

       Raises exceptions.MessageTooLong when the message does not fit the
       key, and ValueError when a seed of the wrong length is supplied.

       Return value:
       the encrypted string of the same length as the public key
    '''
    hasher = hash_class()
    h_len = hasher.digest_size
    k = public_key.byte_size
    max_message_length = k - 2 * h_len - 2
    if len(message) > max_message_length:
        raise exceptions.MessageTooLong
    hasher.update(label)
    label_hash = hasher.digest()
    ps = '\x00' * (max_message_length - len(message))
    db = ''.join((label_hash, ps, '\x01', message))
    if not seed:
        seed = primitives.i2osp(rnd.getrandbits(h_len * 8), h_len)
    elif len(seed) != h_len:
        # RFC 8017 7.1.1 step 2d: the seed is exactly hLen octets long.
        raise ValueError('seed must be %d bytes long' % h_len)
    db_mask = mgf(seed, k - h_len - 1, hash_class=hash_class)
    masked_db = primitives.string_xor(db, db_mask)
    seed_mask = mgf(masked_db, h_len, hash_class=hash_class)
    masked_seed = primitives.string_xor(seed, seed_mask)
    em = ''.join(('\x00', masked_seed, masked_db))
    m = primitives.os2ip(em)
    c = public_key.rsaep(m)
    return primitives.i2osp(c, k)
Example #4
0
def decrypt(private_key, message, label='', hash_class=hashlib.sha1,
        mgf=mgf.mgf1):
    '''Decrypt a byte message using a RSA private key and the OAEP wrapping
       algorithm (RSAES-OAEP-DECRYPT).

       Parameters:
       private_key - an RSA private key
       message - a byte string, the ciphertext
       label - a label as per the PKCS#1 standard
       hash_class - a Python class for a message digest algorithm respecting
         the hashlib interface; it is also forwarded to mgf
       mgf - a mask generation function, called with a hash_class keyword

       Raises ValueError('decryption error') when the ciphertext length is
       wrong or the first octet of the encoded message is not zero, and
       exceptions.DecryptionError for a bad padding or label.

       Return value:
       the string before encryption (decrypted)
    '''
    hasher = hash_class()
    h_len = hasher.digest_size
    k = private_key.byte_size
    # 1. check length
    if len(message) != k or k < 2 * h_len + 2:
        raise ValueError('decryption error')
    # 2. RSA decryption
    c = primitives.os2ip(message)
    m = private_key.rsadp(c)
    em = primitives.i2osp(m, k)
    # 3. EME-OAEP decoding
    hasher.update(label)
    label_hash = hasher.digest()
    y, masked_seed, masked_db = em[0], em[1:h_len + 1], em[1 + h_len:]
    # NOTE(review): RFC 8017 7.1.2 warns that this failure must not be
    # distinguishable from the ones below (Manger's attack). It keeps its
    # ValueError type only because existing callers may catch it.
    if y != '\x00':
        raise ValueError('decryption error')
    # The masks must use the same hash as the encrypting side, which passes
    # hash_class to mgf.
    seed_mask = mgf(masked_db, h_len, hash_class=hash_class)
    seed = primitives.string_xor(masked_seed, seed_mask)
    db_mask = mgf(seed, k - h_len - 1, hash_class=hash_class)
    db = primitives.string_xor(masked_db, db_mask)
    label_hash_prime, rest = db[:h_len], db[h_len:]
    i = rest.find('\x01')
    # Evaluate all remaining checks before deciding, so that every failure
    # goes through the same raise; the label digest is compared in
    # constant time.
    label_ok = primitives.constant_time_cmp(label_hash_prime, label_hash)
    separator_ok = i != -1
    padding_ok = separator_ok and rest[:i].strip('\x00') == ''
    if not (label_ok and separator_ok and padding_ok):
        raise exceptions.DecryptionError
    return rest[i + 1:]
Example #5
0
def verify(m, em, embits, hash_class=hashlib.sha1, mgf=mgf.mgf1, s_len=None):
    '''
       Verify that a message padded using the PKCS#1 v2 PSS algorithm matched a
       given message string (EMSA-PSS-VERIFY).

       m - the message to match
       em - the padded message, expected to be integer_ceil(embits, 8) bytes
       embits - the length in bits of the padded message
       hash_class - the hash algorithm used to compute the digest of the message
       mgf - the mask generation function, called with a hash_class keyword
       s_len - the length of the salt string, if None the length of the digest is used.

       Return: True if the message matches, False otherwise.
    '''
    # 1. cannot verify, does not know the max input length of hash_class
    # 2.
    m_hash = hash_class(m).digest()
    h_len = len(m_hash)
    if s_len is None:
        s_len = h_len
    em_len = primitives.integer_ceil(embits, 8)
    # 3.
    if em_len < h_len + s_len + 2:
        return False
    # Every slice below assumes em is exactly em_len bytes; any other length
    # (including an empty string) cannot be a valid encoding.
    if len(em) != em_len:
        return False
    # 4.
    if em[-1] != '\xbc':
        return False
    # 5.
    db_len = em_len - h_len - 1
    masked_db, h = em[:db_len], em[db_len:-1]
    # 6. the leftmost 8 * em_len - embits bits of masked_db must be zero
    octets, bits = divmod(8 * em_len - embits, 8)
    zero = masked_db[:octets] + chr(ord(masked_db[octets]) & ~(255 >> bits))
    if any(c != '\x00' for c in zero):
        return False
    # 7.
    db_mask = mgf(h, db_len, hash_class=hash_class)
    # 8.
    db = primitives.string_xor(masked_db, db_mask)
    # 9.
    new_byte = chr(ord(db[octets]) & (255 >> bits))
    db = ('\x00' * octets) + new_byte + db[octets + 1:]
    # 10.
    ps_len = em_len - h_len - s_len - 2
    if any(c != '\x00' for c in db[:ps_len]):
        return False
    if db[ps_len] != '\x01':
        return False
    # 11. last s_len octets; db[-s_len:] would wrongly return all of db
    # when s_len == 0.
    salt = db[len(db) - s_len:]
    # 12.
    m_prime = ('\x00' * 8) + m_hash + salt
    # 13.
    h_prime = hash_class(m_prime).digest()
    # 14.
    return primitives.constant_time_cmp(h_prime, h)
Example #6
0
def verify(m, em, embits, hash_class=hashlib.sha1, mgf=mgf.mgf1, s_len=None):
    '''
       Verify that a message padded using the PKCS#1 v2 PSS algorithm matched a
       given message string (EMSA-PSS-VERIFY).

       m - the message to match
       em - the padded message, expected to be integer_ceil(embits, 8) bytes
       embits - the length in bits of the padded message
       hash_class - the hash algorithm used to compute the digest of the message
       mgf - the mask generation function, called as mgf(seed, length)
       s_len - the length of the salt string, if None the length of the digest is used.

       Return: True if the message matches, False otherwise.
    '''
    # 1. cannot verify, does not know the max input length of hash_class
    # 2.
    m_hash = hash_class(m).digest()
    h_len = len(m_hash)
    if s_len is None:
        s_len = h_len
    em_len = primitives.integer_ceil(embits, 8)
    # 3.
    if em_len < h_len + s_len + 2:
        return False
    # Every slice below assumes em is exactly em_len bytes; any other length
    # (including an empty string) cannot be a valid encoding.
    if len(em) != em_len:
        return False
    # 4.
    if em[-1] != '\xbc':
        return False
    # 5.
    db_len = em_len - h_len - 1
    masked_db, h = em[:db_len], em[db_len:-1]
    # 6. the leftmost 8 * em_len - embits bits of masked_db must be zero
    octets, bits = divmod(8 * em_len - embits, 8)
    zero = masked_db[:octets] + chr(ord(masked_db[octets]) & ~(255 >> bits))
    if any(c != '\x00' for c in zero):
        return False
    # 7. NOTE(review): hash_class is not forwarded to mgf, so the mask uses
    # mgf's own default hash; it must match whatever the encoder used.
    db_mask = mgf(h, db_len)
    # 8.
    db = primitives.string_xor(masked_db, db_mask)
    # 9.
    new_byte = chr(ord(db[octets]) & (255 >> bits))
    db = ('\x00' * octets) + new_byte + db[octets + 1:]
    # 10.
    ps_len = em_len - h_len - s_len - 2
    if any(c != '\x00' for c in db[:ps_len]):
        return False
    if db[ps_len] != '\x01':
        return False
    # 11. last s_len octets; db[-s_len:] would wrongly return all of db
    # when s_len == 0.
    salt = db[len(db) - s_len:]
    # 12.
    m_prime = ('\x00' * 8) + m_hash + salt
    # 13.
    h_prime = hash_class(m_prime).digest()
    # 14.
    return primitives.constant_time_cmp(h_prime, h)
Example #7
0
def encrypt(public_key, message, label='', hash_class=hashlib.sha1,
        mgf=mgf.mgf1, seed=None, random=random.SystemRandom):
    '''Encrypt a byte message using a RSA public key and the OAEP wrapping
       algorithm (RSAES-OAEP-ENCRYPT).

       Parameters:
       public_key - an RSA public key
       message - a byte string
       label - a label as per the PKCS#1 standard
       hash_class - a Python class for a message digest algorithm respecting
         the hashlib interface
       mgf - a mask generation function, called with a hash_class keyword
       seed - a seed to use instead of generating it using a random
         generator; it must be exactly as long as the digest
       random - a random generator class, respecting the random generator
         interface from the random module; if seed is None (or empty), it is
         instantiated and used to generate it.

       Raises exceptions.MessageTooLong when the message does not fit the
       key, and ValueError when a seed of the wrong length is supplied.

       Return value:
       the encrypted string of the same length as the public key
    '''
    hasher = hash_class()
    h_len = hasher.digest_size
    k = public_key.byte_size
    max_message_length = k - 2 * h_len - 2
    if len(message) > max_message_length:
        raise exceptions.MessageTooLong
    hasher.update(label)
    label_hash = hasher.digest()
    ps = '\x00' * (max_message_length - len(message))
    db = ''.join((label_hash, ps, '\x01', message))
    if not seed:
        seed = primitives.i2osp(random().getrandbits(h_len * 8), h_len)
    elif len(seed) != h_len:
        # RFC 8017 7.1.1 step 2d: the seed is exactly hLen octets long.
        raise ValueError('seed must be %d bytes long' % h_len)
    db_mask = mgf(seed, k - h_len - 1, hash_class=hash_class)
    masked_db = primitives.string_xor(db, db_mask)
    seed_mask = mgf(masked_db, h_len, hash_class=hash_class)
    masked_seed = primitives.string_xor(seed, seed_mask)
    em = ''.join(('\x00', masked_seed, masked_db))
    m = primitives.os2ip(em)
    c = public_key.rsaep(m)
    return primitives.i2osp(c, k)
Example #8
0
def encode(m,
           embits,
           hash_class=hashlib.sha1,
           mgf=mgf.mgf1,
           salt=None,
           s_len=None,
           rnd=default_crypto_random):
    '''Encode a message using the PKCS v2 PSS padding (EMSA-PSS-ENCODE).

       m - the message to encode
       embits - the length in bits of the padded message
       mgf - a mask generating function, default is mgf1 the mask generating
       function proposed in the PKCS#1 v2 standard; called as mgf(seed, length).
       hash_class - the hash algorithm to use to compute the digest of the
       message, must conform to the hashlib class interface.
       salt - a fixed salt string to use, if None, a random string of length
       s_len is used instead, necessary for tests,
       s_len - the length of the salt string when using a random generator to
       create it, if None the length of the digest is used; 0 produces an
       empty salt.
       rnd - the random generator used to compute the salt string

       Raises exceptions.EncodingError when embits is too small to hold the
       digest, the salt and the padding.

       Return value: the padded message
    '''
    m_hash = hash_class(m).digest()
    h_len = len(m_hash)
    if salt is not None:
        s_len = len(salt)
    else:
        if s_len is None:
            s_len = h_len
        if s_len == 0:
            # An empty salt is valid PSS; random.SystemRandom().getrandbits(0)
            # raises ValueError on Python 2 and on Python 3 before 3.9, so
            # skip the generator.
            salt = ''
        else:
            salt = primitives.i2osp(rnd.getrandbits(s_len * 8), s_len)
    em_len = primitives.integer_ceil(embits, 8)
    if em_len < h_len + s_len + 2:
        raise exceptions.EncodingError
    m_prime = ('\x00' * 8) + m_hash + salt
    h = hash_class(m_prime).digest()
    ps = '\x00' * (em_len - s_len - h_len - 2)
    db = ps + '\x01' + salt
    # NOTE(review): hash_class is not forwarded to mgf, so the mask uses
    # mgf's own default hash; a verifier that forwards hash_class only
    # agrees when both hashes are the same.
    db_mask = mgf(h, em_len - h_len - 1)
    masked_db = primitives.string_xor(db, db_mask)
    # Clear the leftmost 8 * em_len - embits bits of masked_db:
    # `octets` whole bytes, then the top `bits` bits of the next byte.
    octets, bits = divmod(8 * em_len - embits, 8)
    masked_db = ('\x00' * octets) + masked_db[octets:]
    new_byte = chr(ord(masked_db[octets]) & (255 >> bits))
    masked_db = masked_db[:octets] + new_byte + masked_db[octets + 1:]
    return masked_db + h + '\xbc'
Example #9
0
def verify(m, em, embits, hash_class=hashlib.sha1, mgf=mgf.mgf1, s_len=None):
    '''
       Verify that a message padded using the PKCS#1 v2 PSS algorithm matched a
       given message string (EMSA-PSS-VERIFY).

       m - the message to match
       em - the padded message, expected to be integer_ceil(embits, 8) bytes
       embits - the length in bits of the padded message
       hash_class - the hash algorithm used to compute the digest of the message
       mgf - the mask generation function, called as mgf(seed, length)
       s_len - the length of the salt string, if None the length of the digest is used.

       Return: True if the message matches, False otherwise.
    '''
    # 1. cannot verify, does not know the max input length of hash_class
    # 2.
    m_hash = hash_class(m).digest()
    h_len = len(m_hash)
    if s_len is None:
        s_len = h_len
    em_len = primitives.integer_ceil(embits, 8)
    # 3.
    if em_len < h_len + s_len + 2:
        return False
    # Every slice below assumes em is exactly em_len bytes; any other length
    # (including an empty string) cannot be a valid encoding.
    if len(em) != em_len:
        return False
    # 4.
    if em[-1] != '\xbc':
        return False
    # 5.
    db_len = em_len - h_len - 1
    masked_db, h = em[:db_len], em[db_len:-1]
    # 6. the leftmost 8 * em_len - embits bits of masked_db must be zero
    octets, bits = divmod(8 * em_len - embits, 8)
    zero = masked_db[:octets] + chr(ord(masked_db[octets]) & ~(255 >> bits))
    if any(c != '\x00' for c in zero):
        return False
    # 7. NOTE(review): hash_class is not forwarded to mgf, so the mask uses
    # mgf's own default hash; it must match whatever the encoder used.
    db_mask = mgf(h, db_len)
    # 8.
    db = primitives.string_xor(masked_db, db_mask)
    # 9.
    new_byte = chr(ord(db[octets]) & (255 >> bits))
    db = ('\x00' * octets) + new_byte + db[octets + 1:]
    # 10.
    ps_len = em_len - h_len - s_len - 2
    if any(c != '\x00' for c in db[:ps_len]):
        return False
    if db[ps_len] != '\x01':
        return False
    # 11. last s_len octets; db[-s_len:] would wrongly return all of db
    # when s_len == 0.
    salt = db[len(db) - s_len:]
    # 12.
    m_prime = ('\x00' * 8) + m_hash + salt
    # 13.
    h_prime = hash_class(m_prime).digest()
    # 14.
    return primitives.constant_time_cmp(h_prime, h)
Example #10
0
def encode(m, embits, hash_class=hashlib.sha1,
        mgf=mgf.mgf1, salt=None, s_len=None, rnd=default_crypto_random):
    '''Encode a message using the PKCS v2 PSS padding (EMSA-PSS-ENCODE).

       m - the message to encode
       embits - the length in bits of the padded message
       mgf - a mask generating function, default is mgf1 the mask generating
       function proposed in the PKCS#1 v2 standard; called with a hash_class
       keyword argument.
       hash_class - the hash algorithm to use to compute the digest of the
       message, must conform to the hashlib class interface; it is also
       forwarded to mgf.
       salt - a fixed salt string to use, if None, a random string of length
       s_len is used instead, necessary for tests,
       s_len - the length of the salt string when using a random generator to
       create it, if None the length of the digest is used; 0 produces an
       empty salt.
       rnd - the random generator used to compute the salt string

       Raises exceptions.EncodingError when embits is too small to hold the
       digest, the salt and the padding.

       Return value: the padded message
    '''
    m_hash = hash_class(m).digest()
    h_len = len(m_hash)
    if salt is not None:
        s_len = len(salt)
    else:
        if s_len is None:
            s_len = h_len
        if s_len == 0:
            # An empty salt is valid PSS; random.SystemRandom().getrandbits(0)
            # raises ValueError on Python 2 and on Python 3 before 3.9, so
            # skip the generator.
            salt = ''
        else:
            salt = primitives.i2osp(rnd.getrandbits(s_len * 8), s_len)
    em_len = primitives.integer_ceil(embits, 8)
    if em_len < h_len + s_len + 2:
        raise exceptions.EncodingError
    m_prime = ('\x00' * 8) + m_hash + salt
    h = hash_class(m_prime).digest()
    ps = '\x00' * (em_len - s_len - h_len - 2)
    db = ps + '\x01' + salt
    db_mask = mgf(h, em_len - h_len - 1, hash_class=hash_class)
    masked_db = primitives.string_xor(db, db_mask)
    # Clear the leftmost 8 * em_len - embits bits of masked_db:
    # `octets` whole bytes, then the top `bits` bits of the next byte.
    octets, bits = divmod(8 * em_len - embits, 8)
    masked_db = ('\x00' * octets) + masked_db[octets:]
    new_byte = chr(ord(masked_db[octets]) & (255 >> bits))
    masked_db = masked_db[:octets] + new_byte + masked_db[octets + 1:]
    return masked_db + h + '\xbc'