示例#1
0
def raw_rsa_public_crypt(certificate_or_public_key, data):
    """
    Performs a raw RSA algorithm in a byte string using a certificate or
    public key. This is a low-level primitive and is prone to disastrous results
    if used incorrectly.

    The value computed is data ^ e mod n, where e and n are the public exponent
    and modulus of the key. No padding is added or removed.

    :param certificate_or_public_key:
        An oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
        object

    :param data:
        A byte string of the signature when verifying, or padded plaintext when
        encrypting. Must be less than or equal to the length of the public key.
        When verifying, padding will need to be removed afterwards. When
        encrypting, padding must be applied before.

    :raises:
        SystemError - when not running with the winlegacy backend
        TypeError - when certificate_or_public_key or data is of the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data, encoded to the byte length of
        the public key
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    has_asn1 = hasattr(certificate_or_public_key, 'asn1')
    valid_types = (PublicKeyInfo, Certificate)
    if not has_asn1 or not isinstance(certificate_or_public_key.asn1,
                                      valid_types):
        raise TypeError(
            pretty_message(
                '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''', type_name(certificate_or_public_key)))

    # The "algorithm" and "public_key" fields (and byte_size) used below belong
    # to a PublicKeyInfo. A certificate nests its PublicKeyInfo inside the
    # tbs_certificate, so indexing the certificate directly would fail -
    # unwrap it first.
    public_key_info = certificate_or_public_key.asn1
    if isinstance(public_key_info, Certificate):
        public_key_info = public_key_info.public_key

    algo = public_key_info['algorithm']['algorithm'].native
    if algo != 'rsa':
        raise ValueError(
            pretty_message(
                '''
            certificate_or_public_key must be an RSA key, not %s
            ''', algo.upper()))

    if not isinstance(data, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            data must be a byte string, not %s
            ''', type_name(data)))

    rsa_public_key = public_key_info['public_key'].parsed
    transformed_int = pow(int_from_bytes(data),
                          rsa_public_key['public_exponent'].native,
                          rsa_public_key['modulus'].native)
    return int_to_bytes(transformed_int, width=public_key_info.byte_size)
示例#2
0
def raw_rsa_private_crypt(private_key, data):
    """
    Applies the raw RSA private-key operation, data ^ d mod n, to a byte
    string. Nothing is padded or unpadded here, so misuse is dangerous: callers
    are fully responsible for correct padding.

    :param private_key:
        An oscrypto.asymmetric.PrivateKey object

    :param data:
        A byte string of the plaintext to be signed or ciphertext to be
        decrypted. Must be less than or equal to the length of the private key.
        For signing, padding must already have been applied; for decryption,
        padding must be stripped from the result afterwards.

    :raises:
        SystemError - when not running with the winlegacy backend
        TypeError - when private_key or data is of the wrong type
        ValueError - when private_key is not an RSA key

    :return:
        A byte string of the transformed data, encoded to the byte length of
        the private key
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    is_private_key_info = (
        hasattr(private_key, 'asn1')
        and isinstance(private_key.asn1, PrivateKeyInfo)
    )
    if not is_private_key_info:
        raise TypeError(pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        ))

    key_info = private_key.asn1

    key_algorithm = key_info['private_key_algorithm']['algorithm'].native
    if key_algorithm != 'rsa':
        raise ValueError(pretty_message(
            '''
            private_key must be an RSA key, not %s
            ''',
            key_algorithm.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_key = key_info['private_key'].parsed
    modulus = rsa_key['modulus'].native
    private_exponent = rsa_key['private_exponent'].native

    result = pow(int_from_bytes(data), private_exponent, modulus)
    return int_to_bytes(result, width=key_info.byte_size)
示例#3
0
def pbkdf2(hash_algorithm, password, salt, iterations, key_length):
    """
    Implements PBKDF2 from PKCS#5 v2.2 in pure Python, using HMAC with the
    selected hash as the PRF

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use as input to the KDF

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :raises:
        TypeError - when password, salt, iterations or key_length is of the
        wrong type
        ValueError - when iterations or key_length is less than 1, or
        hash_algorithm is not one of the supported names

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            password must be a byte string, not %s
            ''', type_name(password)))

    if not isinstance(salt, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            salt must be a byte string, not %s
            ''', type_name(salt)))

    if not isinstance(iterations, int_types):
        raise TypeError(
            pretty_message(
                '''
            iterations must be an integer, not %s
            ''', type_name(iterations)))

    if iterations < 1:
        raise ValueError(
            pretty_message(
                '''
            iterations must be greater than 0 - is %s
            ''', repr(iterations)))

    if not isinstance(key_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            key_length must be an integer, not %s
            ''', type_name(key_length)))

    if key_length < 1:
        raise ValueError(
            pretty_message(
                '''
            key_length must be greater than 0 - is %s
            ''', repr(key_length)))

    if hash_algorithm not in set(
        ['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''', repr(hash_algorithm)))

    algo = getattr(hashlib, hash_algorithm)

    hash_length = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    # Integer ceiling division: exact, and independent of whether "/" is
    # true division in this module
    blocks = (key_length + hash_length - 1) // hash_length

    original_hmac = hmac.new(password, None, algo)

    int_pack = struct.Struct(b'>I').pack

    output = []
    for block in range(1, blocks + 1):
        # U_1 = PRF(P, S || INT(i))
        prf = original_hmac.copy()
        prf.update(salt + int_pack(block))
        last = prf.digest()
        u = int_from_bytes(last)
        # U_j = PRF(P, U_{j-1}); T_i = U_1 xor U_2 xor ... xor U_c
        for _ in range(2, iterations + 1):
            prf = original_hmac.copy()
            prf.update(last)
            last = prf.digest()
            u ^= int_from_bytes(last)
        # T_i is always exactly hash_length bytes. Without an explicit width a
        # T_i with leading zero bytes would be shortened, shifting every
        # following byte of the derived key.
        output.append(int_to_bytes(u, width=hash_length))

    return b''.join(output)[0:key_length]
示例#4
0
def add_pss_padding(hash_algorithm, salt_length, key_length, message):
    """
    Builds the EMSA-PSS encoded message (EMSA-PSS-ENCODE from PKCS#1 v2.2) for
    a message. The result is meant to be fed to a raw RSA private-key
    operation.

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :raises:
        SystemError - when the backend is not winlegacy and the platform is not
        OS X
        TypeError - when message, salt_length or key_length is of the wrong type
        ValueError - when salt_length is negative, key_length is below 512,
        hash_algorithm is unsupported, or the key is too short for the chosen
        hash and salt lengths

    :return:
        A byte string of the encoded (padded) message: maskedDB || H || 0xBC
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(pretty_message(
            '''
            Pure-python RSA PSS signature padding addition code is only for
            Windows XP/2003 and OS X
            '''
        ))

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 512:
        raise ValueError(pretty_message(
            '''
            key_length must be 512 or more - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in {'sha1', 'sha224', 'sha256', 'sha384', 'sha512'}:
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # emBits is one less than the modulus bit length, so the encoded message
    # read as an integer is always smaller than the modulus
    em_bits = key_length - 1
    em_len = int(math.ceil(em_bits / 8))

    m_hash = hash_func(message).digest()
    h_len = len(m_hash)

    if em_len < h_len + salt_length + 2:
        raise ValueError(pretty_message(
            '''
            Key is not long enough to use with specified hash_algorithm and
            salt_length
            '''
        ))

    salt = os.urandom(salt_length) if salt_length > 0 else b''

    # H = Hash(0x00 * 8 || mHash || salt)
    h = hash_func(b'\x00' * 8 + m_hash + salt).digest()

    # DB = PS || 0x01 || salt, where PS is the zero bytes that fill DB to db_len
    db_len = em_len - h_len - 1
    db = b'\x00' * (db_len - salt_length - 1) + b'\x01' + salt

    db_mask = _mgf1(hash_algorithm, h, db_len)
    masked_db = fill_width(
        int_to_bytes(int_from_bytes(db) ^ int_from_bytes(db_mask)),
        len(db_mask)
    )

    # Clear the leftmost (8 * em_len - em_bits) bits of the first byte
    zero_bits = 8 * em_len - em_bits
    first_byte_mask = 0xFF >> zero_bits
    if first_byte_mask != 0xFF:
        masked_db = chr_cls(first_byte_mask & ord(masked_db[0:1])) + masked_db[1:]

    return masked_db + h + b'\xBC'
示例#5
0
def verify_pss_padding(hash_algorithm, salt_length, key_length, message,
                       signature):
    """
    Verifies the PSS padding on an encoded message (EMSA-PSS-VERIFY from
    PKCS#1 v2.2)

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message that was signed

    :param signature:
        A byte string of the encoded message to check - e.g. the output of the
        raw RSA public-key operation applied to a signature

    :raises:
        SystemError - when the backend is not winlegacy and the platform is not
        OS X
        TypeError - when message, signature, salt_length or key_length is of
        the wrong type
        ValueError - when salt_length is negative or hash_algorithm is
        unsupported

    :return:
        A boolean - True if signature is a valid PSS encoding of message,
        False otherwise
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(
            pretty_message('''
            Pure-python RSA PSS signature padding verification code is only for
            Windows XP/2003 and OS X
            '''))

    if not isinstance(message, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            message must be a byte string, not %s
            ''', type_name(message)))

    if not isinstance(signature, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            signature must be a byte string, not %s
            ''', type_name(signature)))

    if not isinstance(salt_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            salt_length must be an integer, not %s
            ''', type_name(salt_length)))

    if salt_length < 0:
        raise ValueError(
            pretty_message(
                '''
            salt_length must be 0 or more - is %s
            ''', repr(salt_length)))

    # Same validation as add_pss_padding(); key_length is used in integer
    # arithmetic below
    if not isinstance(key_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            key_length must be an integer, not %s
            ''', type_name(key_length)))

    if hash_algorithm not in set(
        ['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''', repr(hash_algorithm)))

    hash_func = getattr(hashlib, hash_algorithm)

    em_bits = key_length - 1
    # Integer ceiling division of em_bits by 8
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        return False

    # Normalize the input to exactly em_len bytes, i.e. EM = I2OSP(m, emLen).
    # When the modulus bit length is one more than a multiple of 8, em_len is
    # one byte shorter than the modulus, so full-width raw RSA output carries
    # an extra leading zero byte; an input that dropped leading zero bytes is
    # shorter than em_len. A non-zero byte beyond em_len means the value does
    # not fit in em_len bytes, so the signature is invalid.
    signature_length = len(signature)
    if signature_length > em_len:
        extra = signature_length - em_len
        if signature[0:extra] != b'\x00' * extra:
            return False
        signature = signature[extra:]
    elif signature_length < em_len:
        signature = (b'\x00' * (em_len - signature_length)) + signature

    if signature[-1:] != b'\xBC':
        return False

    zero_bits = (8 * em_len) - em_bits

    masked_db_length = em_len - hash_length - 1
    masked_db = signature[0:masked_db_length]

    # The leftmost (8 * em_len - em_bits) bits of maskedDB must be zero
    first_byte = ord(masked_db[0:1])
    bits_that_should_be_zero = first_byte >> (8 - zero_bits)
    if bits_that_should_be_zero != 0:
        return False

    m_prime_digest = signature[masked_db_length:masked_db_length + hash_length]

    db_mask = _mgf1(hash_algorithm, m_prime_digest, masked_db_length)

    # Clearing the same top bits of the mask leaves them zero in DB after XOR
    left_int_mask = 0xFF >> zero_bits
    if left_int_mask != 0xFF:
        db_mask = chr_cls(left_int_mask & ord(db_mask[0:1])) + db_mask[1:]

    db = int_to_bytes(int_from_bytes(masked_db) ^ int_from_bytes(db_mask))
    if len(db) < len(masked_db):
        db = (b'\x00' * (len(masked_db) - len(db))) + db

    # DB must be PS (all zero bytes) || 0x01 || salt
    zero_length = em_len - hash_length - salt_length - 2
    zero_string = b'\x00' * zero_length
    if not constant_compare(db[0:zero_length], zero_string):
        return False

    if db[zero_length:zero_length + 1] != b'\x01':
        return False

    # The salt is everything after the 0x01 separator (exactly salt_length
    # bytes). A negative-start slice (db[-salt_length:]) would wrongly return
    # the whole of DB when salt_length is 0.
    salt = db[zero_length + 1:]

    m_prime = (b'\x00' * 8) + message_digest + salt

    h_prime = hash_func(m_prime).digest()

    return constant_compare(m_prime_digest, h_prime)
示例#6
0
def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use as input to the KDF. It is decoded
        as UTF-8 and re-encoded as null-terminated UTF-16BE.

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when password, salt, iterations or key_length is of the
        wrong type
        ValueError - when iterations or key_length is less than 1, or
        hash_algorithm or id_ is not one of the supported values

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    if id_ not in set([1, 2, 3]):
        raise ValueError(pretty_message(
            '''
            id_ must be one of 1, 2, 3, not %s
            ''',
            repr(id_)
        ))

    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC)
    u = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    if hash_algorithm in ['sha384', 'sha512']:
        v = 128
    else:
        v = 64

    # Step 1
    d = chr_cls(id_) * v

    # Step 2
    s = b''
    if salt != b'':
        s_len = v * int(math.ceil(float(len(salt)) / v))
        while len(s) < s_len:
            s += salt
        s = s[0:s_len]

    # Step 3
    p = b''
    if utf16_password != b'':
        p_len = v * int(math.ceil(float(len(utf16_password)) / v))
        while len(p) < p_len:
            p += utf16_password
        p = p[0:p_len]

    # Step 4 - I is always a whole number of v-byte blocks
    i = s + p

    # Step 5
    c = int(math.ceil(float(key_length) / u))

    # Each v-byte block I_j is treated as an integer modulo 2^(8v)
    i_modulus = 1 << (8 * v)

    blocks = []
    for num in range(1, c + 1):
        # Step 6A
        a_i = algo(d + i).digest()
        for _ in range(2, iterations + 1):
            a_i = algo(a_i).digest()

        # Steps 6B and 6C only feed the next block, so skip them for the last
        if num < c:
            # Step 6B - B is A_i repeated to v bytes; keep B + 1 for step 6C
            b_bytes = b''
            while len(b_bytes) < v:
                b_bytes += a_i
            b_plus_one = int_from_bytes(b_bytes[0:v]) + 1

            # Step 6C - I_j = (I_j + B + 1) mod 2^(8v). Each sum is re-encoded
            # at exactly v bytes: a result with leading zero bytes must not
            # shorten I, and a carry out of the top byte is dropped.
            i = b''.join(
                int_to_bytes(
                    (int_from_bytes(i[start:start + v]) + b_plus_one) % i_modulus,
                    width=v
                )
                for start in range(0, len(i), v)
            )

        # Step 7 (one piece at a time)
        blocks.append(a_i)

    return b''.join(blocks)[0:key_length]