예제 #1
0
def _parse_hello_extensions(data):
    """
    Creates a generator returning tuples of information about each extension
    from a byte string of extension data contained in a ServerHello ores
    ClientHello message

    :param data:
        A byte string of a extension data from a TLS ServerHello or ClientHello
        message

    :return:
        A generator that yields 2-element tuples:
        [0] Byte string of extension type
        [1] Byte string of extension data
    """

    if data == b'':
        return

    extentions_length = int_from_bytes(data[0:2])
    extensions_start = 2
    extensions_end = 2 + extentions_length

    pointer = extensions_start
    while pointer < extensions_end:
        extension_type = int_from_bytes(data[pointer:pointer + 2])
        extension_length = int_from_bytes(data[pointer + 2:pointer + 4])
        yield (
            extension_type,
            data[pointer + 4:pointer + 4 + extension_length]
        )
        pointer += 4 + extension_length
예제 #2
0
def get_dh_params_length(server_handshake_bytes):
    """
    Locates the first non-empty ServerKeyExchange message in the server
    handshake and reports the size of the DH params it carries

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None or an integer of the bit size of the DH parameters
    """

    key_exchange = None

    for content_type, _, payload in _parse_tls_records(server_handshake_bytes):
        # Only handshake records (0x16) can hold the message being sought
        if content_type == b'\x16':
            for handshake_type, body in _parse_handshake_messages(payload):
                # 0x0c is the message type searched for (ServerKeyExchange)
                if handshake_type == b'\x0c':
                    key_exchange = body
                    break
        if key_exchange:
            break

    if not key_exchange:
        return None

    # The leading 2 bytes are read as a byte count and converted to bits
    return int_from_bytes(key_exchange[0:2]) * 8
예제 #3
0
def _parse_tls_records(data):
    """
    Creates a generator returning tuples of information about each record
    in a byte string of data from a TLS client or server. Stops as soon as it
    find a ChangeCipherSpec message since all data from then on is encrypted.

    :param data:
        A byte string of TLS records

    :return:
        A generator that yields 3-element tuples:
        [0] Byte string of record type
        [1] Byte string of protocol version
        [2] Byte string of record data
    """

    pointer = 0
    data_len = len(data)
    while pointer < data_len:
        # Don't try to parse any more once the ChangeCipherSpec is found
        if data[pointer:pointer + 1] == b'\x14':
            break
        length = int_from_bytes(data[pointer + 3:pointer + 5])
        yield (
            data[pointer:pointer + 1],
            data[pointer + 1:pointer + 3],
            data[pointer + 5:pointer + 5 + length]
        )
        pointer += 5 + length
예제 #4
0
def ec_generate_pair(curve):
    """
    Creates a new EC private key on one of the supported named curves and
    derives the matching public key

    :param curve:
        A unicode string. Valid values include "secp256r1", "secp384r1" and
        "secp521r1".

    :raises:
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type

    :return:
        A 2-element tuple of (asn1crypto.keys.PublicKeyInfo,
        asn1crypto.keys.PrivateKeyInfo)
    """

    if curve not in {'secp256r1', 'secp384r1', 'secp521r1'}:
        raise ValueError(pretty_message(
            '''
            curve must be one of "secp256r1", "secp384r1", "secp521r1", not %s
            ''',
            repr(curve)
        ))

    num_bytes = CURVE_BYTES[curve]
    base_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve]

    # Rejection sampling: draw random integers until one lands strictly
    # between 0 and the order of the base point
    secret = 0
    while not (0 < secret < base_point.order):
        secret = int_from_bytes(rand_bytes(num_bytes), signed=False)

    algorithm = keys.PrivateKeyAlgorithm({
        'algorithm': 'ec',
        'parameters': keys.ECDomainParameters(name='named', value=curve),
    })
    ec_private_key = keys.ECPrivateKey({
        'version': 'ecPrivkeyVer1',
        'private_key': secret,
    })
    private_key_info = keys.PrivateKeyInfo({
        'version': 0,
        'private_key_algorithm': algorithm,
        'private_key': ec_private_key,
    })

    # Store the derived public key inside the private key structure
    derived_public_key = private_key_info.public_key
    private_key_info['private_key'].parsed['public_key'] = derived_public_key

    return (private_key_info.public_key_info, private_key_info)
예제 #5
0
def raw_rsa_public_crypt(certificate_or_public_key, data):
    """
    Performs a raw RSA algorithm in a byte string using a certificate or
    public key. This is a low-level primitive and is prone to disastrous results
    if used incorrectly.

    :param certificate_or_public_key:
        An oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
        object

    :param data:
        A byte string of the signature when verifying, or padded plaintext when
        encrypting. Must be less than or equal to the length of the public key.
        When verifying, padding will need to be removed afterwards. When
        encrypting, padding must be applied before.

    :raises:
        SystemError - when the backend is not "winlegacy"
        ValueError - when the key is not an RSA key
        TypeError - when any of the parameters are of the wrong type

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    has_asn1 = hasattr(certificate_or_public_key, 'asn1')
    valid_types = (PublicKeyInfo, Certificate)
    if not has_asn1 or not isinstance(certificate_or_public_key.asn1, valid_types):
        raise TypeError(pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        ))

    # The "algorithm" and "public_key" fields and the byte_size attribute
    # read below belong to the public key structure, so when a certificate
    # is passed in, work with the public key it contains
    asn1 = certificate_or_public_key.asn1
    if isinstance(asn1, Certificate):
        asn1 = asn1.public_key

    algo = asn1['algorithm']['algorithm'].native
    if algo != 'rsa':
        raise ValueError(pretty_message(
            '''
            certificate_or_public_key must be an RSA key, not %s
            ''',
            algo.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_public_key = asn1['public_key'].parsed

    # Textbook RSA: data ^ public_exponent mod modulus
    transformed_int = pow(
        int_from_bytes(data),
        rsa_public_key['public_exponent'].native,
        rsa_public_key['modulus'].native
    )

    # Left-pad the result to the byte size of the key
    return int_to_bytes(transformed_int, width=asn1.byte_size)
예제 #6
0
def raw_rsa_private_crypt(private_key, data):
    """
    Applies the raw RSA private key operation to a byte string. This is a
    low-level primitive and is prone to disastrous results if used
    incorrectly.

    :param private_key:
        An oscrypto.asymmetric.PrivateKey object

    :param data:
        A byte string of the plaintext to be signed or ciphertext to be
        decrypted. Must be less than or equal to the length of the private key.
        In the case of signing, padding must already be applied. In the case of
        decryption, padding must be removed afterward.

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    is_private_key = (
        hasattr(private_key, 'asn1')
        and isinstance(private_key.asn1, PrivateKeyInfo)
    )
    if not is_private_key:
        raise TypeError(pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        ))

    key_info = private_key.asn1

    algo = key_info['private_key_algorithm']['algorithm'].native
    if algo != 'rsa':
        raise ValueError(pretty_message(
            '''
            private_key must be an RSA key, not %s
            ''',
            algo.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_key = key_info['private_key'].parsed

    # Textbook RSA: data ^ private_exponent mod modulus
    message_int = int_from_bytes(data)
    exponent = rsa_key['private_exponent'].native
    modulus = rsa_key['modulus'].native
    result_int = pow(message_int, exponent, modulus)

    # Left-pad the result to the byte size of the key
    return int_to_bytes(result_int, width=key_info.byte_size)
예제 #7
0
def extract_chain(server_handshake_bytes):
    """
    Pulls the X.509 certificates out of the server handshake bytes, for use
    when debugging

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A list of asn1crypto.x509.Certificate objects
    """

    certificates = []
    certificate_message = None

    for content_type, _, payload in _parse_tls_records(server_handshake_bytes):
        # Only handshake records (0x16) can hold the message being sought
        if content_type == b'\x16':
            for handshake_type, body in _parse_handshake_messages(payload):
                # 0x0b is the message type searched for (Certificate)
                if handshake_type == b'\x0b':
                    certificate_message = body
                    break
        if certificate_message:
            break

    if not certificate_message:
        return certificates

    # Skip the 3-byte length of the whole chain, then read each entry as a
    # 3-byte length followed by that many bytes of certificate
    length_size = 3
    total = len(certificate_message)
    offset = length_size
    while offset < total:
        body_start = offset + length_size
        body_end = body_start + int_from_bytes(certificate_message[offset:body_start])
        certificates.append(Certificate.load(certificate_message[body_start:body_end]))
        offset = body_end

    return certificates
예제 #8
0
def _parse_handshake_messages(data):
    """
    Creates a generator returning tuples of information about each message in
    a byte string of data from a TLS handshake record

    :param data:
        A byte string of a TLS handshake record data

    :return:
        A generator that yields 2-element tuples:
        [0] Byte string of message type
        [1] Byte string of message data
    """

    pointer = 0
    data_len = len(data)
    while pointer < data_len:
        length = int_from_bytes(data[pointer + 1:pointer + 4])
        yield (
            data[pointer:pointer + 1],
            data[pointer + 4:pointer + 4 + length]
        )
        pointer += 4 + length
예제 #9
0
def pbkdf2(hash_algorithm, password, salt, iterations, key_length):
    """
    Implements PBKDF2 from PKCS#5 v2.2 in pure Python

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use an input to the KDF

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :raises:
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    algo = getattr(hashlib, hash_algorithm)

    hash_length = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    # Number of hash-sized blocks needed to cover key_length. Integer
    # ceiling division gives the same answer whether or not "/" performs
    # true division.
    blocks = (key_length + hash_length - 1) // hash_length

    # The HMAC keyed with the password is created once and copied for every
    # PRF invocation
    original_hmac = hmac.new(password, None, algo)

    int_pack = struct.Struct(b'>I').pack

    chunks = []
    for block in range(1, blocks + 1):
        # U_1 = PRF(password, salt || INT(block))
        prf = original_hmac.copy()
        prf.update(salt + int_pack(block))
        last = prf.digest()
        u = int_from_bytes(last)

        # U_i = PRF(password, U_{i-1}), XORed together into the block value
        for _ in range(2, iterations + 1):
            prf = original_hmac.copy()
            prf.update(last)
            last = prf.digest()
            u ^= int_from_bytes(last)

        # Every block must be exactly hash_length bytes. Converting the
        # integer back without a width can produce fewer bytes when the
        # XOR result starts with zero bytes, which would shift all
        # following key material, so left-pad with zeros.
        chunks.append(int_to_bytes(u).rjust(hash_length, b'\x00'))

    return b''.join(chunks)[0:key_length]
예제 #10
0
def parse_session_info(server_handshake_bytes, client_handshake_bytes):
    """
    Inspects the ServerHello and ClientHello messages of a TLS handshake to
    determine the protocol and cipher suite selected, if compression is
    enabled, and whether a session id or session ticket is new or reused.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :param client_handshake_bytes:
        A byte string of the handshake data sent to the server

    :return:
        A dict with the following keys:
         - "protocol": unicode string
         - "cipher_suite": unicode string
         - "compression": boolean
         - "session_id": "new", "reused" or None
         - "session_ticket": "new", "reused" or None
    """

    protocol = None
    cipher_suite = None
    compression = False
    session_id = None
    session_ticket = None

    server_session_id = None
    client_session_id = None

    protocol_names = {
        b'\x03\x00': "SSLv3",
        b'\x03\x01': "TLSv1",
        b'\x03\x02': "TLSv1.1",
        b'\x03\x03': "TLSv1.2",
        b'\x03\x04': "TLSv1.3",
    }

    # Both hello messages begin with 2 bytes of version and 32 bytes of
    # random, so the session id length byte sits at offset 34
    session_id_length_offset = 34
    session_id_offset = 35

    # Extension type 35 is the one treated as the session ticket
    session_ticket_extension = 35

    for content_type, _, payload in _parse_tls_records(server_handshake_bytes):
        if content_type != b'\x16':
            continue
        for handshake_type, hello in _parse_handshake_messages(payload):
            # Skip anything that is not a ServerHello message
            if handshake_type != b'\x02':
                continue

            protocol = protocol_names[hello[0:2]]

            id_length = int_from_bytes(hello[session_id_length_offset:session_id_offset])
            suite_offset = session_id_offset + id_length
            if id_length > 0:
                server_session_id = hello[session_id_offset:suite_offset]

            # The server selects exactly one 2-byte cipher suite
            compression_offset = suite_offset + 2
            cipher_suite = CIPHER_SUITE_MAP[hello[suite_offset:compression_offset]]

            # A single compression method byte; 0x00 means no compression
            extensions_offset = compression_offset + 1
            compression = hello[compression_offset:extensions_offset] != b'\x00'

            server_extensions = _parse_hello_extensions(hello[extensions_offset:])
            if any(ext_type == session_ticket_extension for ext_type, _ in server_extensions):
                session_ticket = "new"

            # Only the first ServerHello of a record is examined
            break

    for content_type, _, payload in _parse_tls_records(client_handshake_bytes):
        if content_type != b'\x16':
            continue
        for handshake_type, hello in _parse_handshake_messages(payload):
            # Skip anything that is not a ClientHello message
            if handshake_type != b'\x01':
                continue

            id_length = int_from_bytes(hello[session_id_length_offset:session_id_offset])
            suite_offset = session_id_offset + id_length
            if id_length > 0:
                client_session_id = hello[session_id_offset:suite_offset]

            # The client sends a 2-byte length followed by its list of
            # cipher suites
            suites_length = int_from_bytes(hello[suite_offset:suite_offset + 2])

            # ... then a 1-byte length followed by its compression methods
            compression_offset = suite_offset + 2 + suites_length
            compression_length = int_from_bytes(hello[compression_offset:compression_offset + 1])

            # On subsequent requests, the session ticket will only be seen
            # in the ClientHello message
            if server_session_id is None and session_ticket is None:
                extensions_offset = compression_offset + 1 + compression_length
                client_extensions = _parse_hello_extensions(hello[extensions_offset:])
                if any(ext_type == session_ticket_extension for ext_type, _ in client_extensions):
                    session_ticket = "reused"

            # Only the first ClientHello of a record is examined
            break

    if server_session_id is not None:
        # The session is reused only if the client offered the very id
        # that the server sent back
        if client_session_id is None or client_session_id != server_session_id:
            session_id = "new"
        else:
            session_id = "reused"

    return {
        "protocol": protocol,
        "cipher_suite": cipher_suite,
        "compression": compression,
        "session_id": session_id,
        "session_ticket": session_ticket,
    }
예제 #11
0
def ecdsa_verify(certificate_or_public_key, signature, data, hash_algorithm):
    """
    Checks an ECDSA signature using a pure Python implementation (thus slow).
    Returns nothing - an invalid signature is reported via an exception.

    :param certificate_or_public_key:
        A Certificate or PublicKey instance to verify the signature with

    :param signature:
        A byte string of the signature to verify

    :param data:
        A byte string of the data the signature is for

    :param hash_algorithm:
        A unicode string of "sha1", "sha224", "sha256", "sha384" or "sha512"

    :raises:
        oscrypto.errors.SignatureError - when the signature is determined to be invalid
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type
        OSError - when an error is returned by the OS crypto library
    """

    is_key_or_cert = (
        hasattr(certificate_or_public_key, 'asn1')
        and isinstance(certificate_or_public_key.asn1, (keys.PublicKeyInfo, Certificate))
    )
    if not is_key_or_cert:
        raise TypeError(pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        ))

    curve_name = certificate_or_public_key.curve
    if curve_name not in {'secp256r1', 'secp384r1', 'secp521r1'}:
        raise ValueError(pretty_message(
            '''
            certificate_or_public_key does not use one of the named curves
            secp256r1, secp384r1 or secp521r1
            '''
        ))

    if not isinstance(signature, byte_cls):
        raise TypeError(pretty_message(
            '''
            signature must be a byte string, not %s
            ''',
            type_name(signature)
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    if hash_algorithm not in {'sha1', 'sha224', 'sha256', 'sha384', 'sha512'}:
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    # For a certificate, the key material lives in its public key structure
    key_info = certificate_or_public_key.asn1
    if isinstance(key_info, Certificate):
        key_info = key_info.public_key

    base_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve_name]

    x, y = key_info['public_key'].to_coords()
    n = base_point.order

    # Validates that the point is valid
    public_point = PrimePoint(base_point.curve, x, y, n)

    try:
        parsed_signature = DSASignature.load(signature)
        r = parsed_signature['r'].native
        s = parsed_signature['s'].native
    except ValueError:
        raise SignatureError('Signature is invalid')

    # Both r and s must be in the range [1, n - 1]. All four comparisons
    # are always evaluated.
    out_of_range = (r < 1) | (r >= n) | (s < 1) | (s >= n)
    if out_of_range:
        raise SignatureError('Signature is invalid')

    digest = getattr(hashlib, hash_algorithm)(data).digest()

    z = int_from_bytes(digest, signed=False) % n
    w = inverse_mod(s, n)
    u1 = (z * w) % n
    u2 = (r * w) % n

    # The signature is valid when the x coordinate of u1*G + u2*Q equals r
    candidate_point = (base_point * u1) + (public_point * u2)
    if r != (candidate_point.x % n):
        raise SignatureError('Signature is invalid')
예제 #12
0
def ecdsa_sign(private_key, data, hash_algorithm):
    """
    Generates an ECDSA signature in pure Python (thus slow). The per-signature
    secret k is derived deterministically from the private key and the digest
    of the data using an HMAC construction that follows the steps of
    RFC 6979 section 3.2.

    :param private_key:
        The PrivateKey to generate the signature with

    :param data:
        A byte string of the data the signature is for

    :param hash_algorithm:
        A unicode string of "sha1", "sha224", "sha256", "sha384" or "sha512"

    :raises:
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type
        OSError - when an error is returned by the OS crypto library

    :return:
        A byte string of the signature
    """

    if not hasattr(private_key, 'asn1') or not isinstance(
            private_key.asn1, keys.PrivateKeyInfo):
        raise TypeError(
            pretty_message(
                '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''', type_name(private_key)))

    curve_name = private_key.curve
    if curve_name not in set(['secp256r1', 'secp384r1', 'secp521r1']):
        raise ValueError(
            pretty_message('''
            private_key does not use one of the named curves secp256r1,
            secp384r1 or secp521r1
            '''))

    if not isinstance(data, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            data must be a byte string, not %s
            ''', type_name(data)))

    if hash_algorithm not in set(
        ['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''', repr(hash_algorithm)))

    hash_func = getattr(hashlib, hash_algorithm)

    # The private key is needed in two forms: as bytes to feed the HMAC
    # chain below, and as an integer for the signature arithmetic
    ec_private_key = private_key.asn1['private_key'].parsed
    private_key_bytes = ec_private_key['private_key'].contents
    private_key_int = ec_private_key['private_key'].native

    curve_num_bytes = CURVE_BYTES[curve_name]
    curve_base_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve_name]

    n = curve_base_point.order

    # RFC 6979 section 3.2

    # a.
    digest = hash_func(data).digest()
    hash_length = len(digest)

    # NOTE(review): the whole digest is reduced mod n here, rather than
    # first being truncated to the bit length of n - confirm this is the
    # intended behavior for digests longer than the curve order
    h = int_from_bytes(digest, signed=False) % n

    # b.
    V = b'\x01' * hash_length

    # c.
    K = b'\x00' * hash_length

    # d.
    # NOTE(review): the raw private key bytes and the raw digest are used as
    # the HMAC input - presumably equivalent to the RFC's int2octets() and
    # bits2octets() for the supported curve/hash pairs, TODO confirm
    K = hmac.new(K, V + b'\x00' + private_key_bytes + digest,
                 hash_func).digest()

    # e.
    V = hmac.new(K, V, hash_func).digest()

    # f.
    K = hmac.new(K, V + b'\x01' + private_key_bytes + digest,
                 hash_func).digest()

    # g.
    V = hmac.new(K, V, hash_func).digest()

    # h.
    r = 0
    s = 0
    while True:
        # h. 1
        T = b''

        # h. 2
        # Keep appending HMAC output until there is enough for a candidate k
        while len(T) < curve_num_bytes:
            V = hmac.new(K, V, hash_func).digest()
            T += V

        # h. 3
        # NOTE(review): on each "continue" below K is left unchanged and only
        # V advances (through step h. 2 of the next pass) - RFC 6979 describes
        # updating both K and V before retrying, confirm this is intended
        k = int_from_bytes(T[0:curve_num_bytes], signed=False)
        if k == 0 or k >= n:
            continue

        # Calculate the signature in the loop in case we need a new k
        r = (curve_base_point * k).x % n
        if r == 0:
            continue

        s = (inverse_mod(k, n) * (h + (private_key_int * r) % n)) % n
        if s == 0:
            continue

        break

    # Serialize the pair of integers as a DSASignature structure
    return DSASignature({'r': r, 's': s}).dump()
예제 #13
0
def add_pss_padding(hash_algorithm, salt_length, key_length, message):
    """
    Encodes a message with the EMSA-PSS-Encode operation described in PKCS#1
    v2.2, producing the padded block that is then run through the raw RSA
    private key operation.

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :return:
        The encoded (padded) message
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(pretty_message(
            '''
            Pure-python RSA PSS signature padding addition code is only for
            Windows XP/2003 and OS X
            '''
        ))

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 512:
        raise ValueError(pretty_message(
            '''
            key_length must be 512 or more - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in {'sha1', 'sha224', 'sha256', 'sha384', 'sha512'}:
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # The encoded message is one bit shorter than the key
    em_bits = key_length - 1
    em_len = int(math.ceil(em_bits / 8))

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        raise ValueError(pretty_message(
            '''
            Key is not long enough to use with specified hash_algorithm and
            salt_length
            '''
        ))

    salt = os.urandom(salt_length) if salt_length > 0 else b''

    # H = Hash(8 zero bytes || Hash(message) || salt)
    m_prime_digest = hash_func((b'\x00' * 8) + message_digest + salt).digest()

    # DB = zero padding || 0x01 || salt
    zero_padding = b'\x00' * (em_len - salt_length - hash_length - 2)
    db = zero_padding + b'\x01' + salt

    db_mask = _mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    # XOR via integers, then restore any leading zero bytes that the
    # integer round trip dropped
    xored = int_to_bytes(int_from_bytes(db) ^ int_from_bytes(db_mask))
    masked_db = fill_width(xored, len(db_mask))

    # Clear the leftmost bits of the first byte that lie beyond em_bits
    zero_bits = (8 * em_len) - em_bits
    left_int_mask = (1 << (8 - zero_bits)) - 1
    if left_int_mask != 255:
        first_byte = chr_cls(left_int_mask & ord(masked_db[0:1]))
        masked_db = first_byte + masked_db[1:]

    return masked_db + m_prime_digest + b'\xBC'
예제 #14
0
def verify_pss_padding(hash_algorithm, salt_length, key_length, message,
                       signature):
    """
    Verifies the PSS padding on an encoded message, following the
    EMSA-PSS-Verify operation described in PKCS#1 v2.2

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message the signature is for

    :param signature:
        A byte string of the encoded (padded) message to verify, i.e. the
        signature after the raw RSA public key operation has been applied

    :raises:
        SystemError - when not running on the "winlegacy" backend or OS X
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type

    :return:
        A boolean - True if the padding is valid for the message, False
        otherwise
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(pretty_message(
            '''
            Pure-python RSA PSS signature padding verification code is only for
            Windows XP/2003 and OS X
            '''
        ))

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(signature, byte_cls):
        raise TypeError(pretty_message(
            '''
            signature must be a byte string, not %s
            ''',
            type_name(signature)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # The encoded message is one bit shorter than the key
    em_bits = key_length - 1
    em_len = int(math.ceil(em_bits / 8))

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        return False

    # The encoded message must end with the 0xBC trailer byte
    if signature[-1:] != b'\xBC':
        return False

    zero_bits = (8 * em_len) - em_bits

    # Layout of the encoded message: masked DB || H || 0xBC
    masked_db_length = em_len - hash_length - 1
    masked_db = signature[0:masked_db_length]

    # The leftmost bits that lie beyond em_bits must all be zero
    first_byte = ord(masked_db[0:1])
    bits_that_should_be_zero = first_byte >> (8 - zero_bits)
    if bits_that_should_be_zero != 0:
        return False

    m_prime_digest = signature[masked_db_length:masked_db_length + hash_length]

    db_mask = _mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    left_bit_mask = ('0' * zero_bits) + ('1' * (8 - zero_bits))
    left_int_mask = int(left_bit_mask, 2)

    if left_int_mask != 255:
        db_mask = chr_cls(left_int_mask & ord(db_mask[0:1])) + db_mask[1:]

    # XOR via integers, then restore any leading zero bytes that the integer
    # round trip dropped
    db = int_to_bytes(int_from_bytes(masked_db) ^ int_from_bytes(db_mask))
    if len(db) < len(masked_db):
        db = (b'\x00' * (len(masked_db) - len(db))) + db

    # DB must be: zero padding || 0x01 || salt
    zero_length = em_len - hash_length - salt_length - 2
    zero_string = b'\x00' * zero_length
    if not constant_compare(db[0:zero_length], zero_string):
        return False

    if db[zero_length:zero_length + 1] != b'\x01':
        return False

    # A slice of db[-0:] would return all of db rather than an empty
    # string, so a zero-length salt has to be handled explicitly
    if salt_length > 0:
        salt = db[0 - salt_length:]
    else:
        salt = b''

    m_prime = (b'\x00' * 8) + message_digest + salt

    h_prime = hash_func(m_prime).digest()

    return constant_compare(m_prime_digest, h_prime)
예제 #15
0
def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use as input to the KDF. It is
        decoded as UTF-8 and re-encoded as big-endian UTF-16 with a two-byte
        null terminator before being hashed.

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when password or salt is not a byte string, or when
            iterations or key_length is not an integer
        ValueError - when iterations or key_length is less than 1, when
            hash_algorithm is not one of the supported names, or when id_ is
            not 1, 2 or 3
        UnicodeDecodeError - when password is not valid UTF-8

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    if id_ not in set([1, 2, 3]):
        raise ValueError(pretty_message(
            '''
            id_ must be one of 1, 2, 3, not %s
            ''',
            repr(id_)
        ))

    # The password is hashed as big-endian UTF-16 followed by a null terminator
    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC)
    u = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    if hash_algorithm in ['sha384', 'sha512']:
        v = 128
    else:
        v = 64

    # Step 1
    d = chr_cls(id_) * v

    # Step 2 - repeat the salt to fill a whole number of v-byte blocks
    s = b''
    if salt != b'':
        s_len = v * ((len(salt) + v - 1) // v)
        while len(s) < s_len:
            s += salt
        s = s[0:s_len]

    # Step 3 - repeat the password to fill a whole number of v-byte blocks
    p = b''
    if utf16_password != b'':
        p_len = v * ((len(utf16_password) + v - 1) // v)
        while len(p) < p_len:
            p += utf16_password
        p = p[0:p_len]

    # Step 4
    i = s + p

    # Step 5 - number of hash-sized chunks needed to cover key_length
    c = (key_length + u - 1) // u

    chunks = []

    for num in range(1, c + 1):
        # Step 6A
        a2 = algo(d + i).digest()
        for _ in range(2, iterations + 1):
            a2 = algo(a2).digest()

        # Step 7 (one piece at a time)
        chunks.append(a2)

        # i only needs to be updated when another chunk will be derived
        if num < c:
            # Step 6B
            b = b''
            while len(b) < v:
                b += a2

            b = int_from_bytes(b[0:v]) + 1

            # Step 6C - replace every v-byte block I_j of i with
            # (I_j + B + 1) mod 2^(8 * v)
            blocks = []
            for start in range(0, len(i), v):
                block = int_to_bytes(int_from_bytes(i[start:start + v]) + b)

                # Ensure the new block is exactly v bytes. A longer result
                # means the addition overflowed, so the high-order bytes are
                # dropped. A shorter result means the sum has leading zero
                # bytes, which must be restored so that i keeps its length
                # and the remaining blocks stay aligned.
                block_len = len(block)
                if block_len > v:
                    block = block[block_len - v:]
                elif block_len < v:
                    block = (b'\x00' * (v - block_len)) + block

                blocks.append(block)

            i = b''.join(blocks)

    return b''.join(chunks)[0:key_length]