Beispiel #1
0
def get_dh_params_length(server_handshake_bytes):
    """
    Determines the length of the DH params from the ServerKeyExchange

    Every TLS record up to (but not including) the first ChangeCipherSpec
    record is examined, and every handshake message inside each handshake
    record. Several handshake messages are commonly packed into a single
    record, so the ServerKeyExchange is not necessarily the first message of
    its record.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None if no ServerKeyExchange message was found, otherwise an integer
        of the bit size of the DH parameters (the 2-byte length at the start
        of the ServerKeyExchange body, multiplied by 8)
    """

    data_len = len(server_handshake_bytes)
    record_pointer = 0
    while record_pointer < data_len:
        record_type = server_handshake_bytes[record_pointer:record_pointer + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if record_type == b'\x14':
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        record_length = int_from_bytes(
            server_handshake_bytes[record_pointer + 3:record_pointer + 5])
        record_start = record_pointer + 5
        record_end = record_start + record_length
        record_pointer = record_end

        if record_type != b'\x16':
            continue

        record_data = server_handshake_bytes[record_start:record_end]
        message_pointer = 0
        while message_pointer < len(record_data):
            # Handshake message header: 1 byte type, 3 byte length
            message_type = record_data[message_pointer:message_pointer + 1]
            message_length = int_from_bytes(
                record_data[message_pointer + 1:message_pointer + 4])
            message_start = message_pointer + 4
            message_end = message_start + message_length
            if message_type == b'\x0c':
                message_data = record_data[message_start:message_end]
                return int_from_bytes(message_data[0:2]) * 8
            message_pointer = message_end

    return None
Beispiel #2
0
def _parse_hello_extensions(data):
    """
    Creates a generator returning tuples of information about each extension
    from a byte string of extension data contained in a ServerHello ores
    ClientHello message

    :param data:
        A byte string of a extension data from a TLS ServerHello or ClientHello
        message

    :return:
        A generator that yields 2-element tuples:
        [0] Byte string of extension type
        [1] Byte string of extension data
    """

    if data == b'':
        return

    extentions_length = int_from_bytes(data[0:2])
    extensions_start = 2
    extensions_end = 2 + extentions_length

    pointer = extensions_start
    while pointer < extensions_end:
        extension_type = int_from_bytes(data[pointer:pointer + 2])
        extension_length = int_from_bytes(data[pointer + 2:pointer + 4])
        yield (extension_type,
               data[pointer + 4:pointer + 4 + extension_length])
        pointer += 4 + extension_length
Beispiel #3
0
def _parse_hello_extensions(data):
    """
    Creates a generator returning tuples of information about each extension
    from a byte string of extension data contained in a ServerHello ores
    ClientHello message

    :param data:
        A byte string of a extension data from a TLS ServerHello or ClientHello
        message

    :return:
        A generator that yields 2-element tuples:
        [0] Byte string of extension type
        [1] Byte string of extension data
    """

    if data == b'':
        return

    extentions_length = int_from_bytes(data[0:2])
    extensions_start = 2
    extensions_end = 2 + extentions_length

    pointer = extensions_start
    while pointer < extensions_end:
        extension_type = int_from_bytes(data[pointer:pointer + 2])
        extension_length = int_from_bytes(data[pointer + 2:pointer + 4])
        yield (
            extension_type,
            data[pointer + 4:pointer + 4 + extension_length]
        )
        pointer += 4 + extension_length
Beispiel #4
0
def get_dh_params_length(server_handshake_bytes):
    """
    Determines the length of the DH params from the ServerKeyExchange

    Every TLS record up to (but not including) the first ChangeCipherSpec
    record is examined, and every handshake message inside each handshake
    record, since several handshake messages commonly share one record.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None if no ServerKeyExchange message was found, otherwise an integer
        of the bit size of the DH parameters (the 2-byte length at the start
        of the ServerKeyExchange body, multiplied by 8)
    """

    data_len = len(server_handshake_bytes)
    record_pointer = 0
    while record_pointer < data_len:
        record_type = server_handshake_bytes[record_pointer : record_pointer + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if record_type == b"\x14":
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        record_length = int_from_bytes(server_handshake_bytes[record_pointer + 3 : record_pointer + 5])
        record_start = record_pointer + 5
        record_end = record_start + record_length
        record_pointer = record_end

        if record_type != b"\x16":
            continue

        record_data = server_handshake_bytes[record_start:record_end]
        message_pointer = 0
        while message_pointer < len(record_data):
            # Handshake message header: 1 byte type, 3 byte length
            message_type = record_data[message_pointer : message_pointer + 1]
            message_length = int_from_bytes(record_data[message_pointer + 1 : message_pointer + 4])
            message_start = message_pointer + 4
            message_end = message_start + message_length
            if message_type == b"\x0c":
                message_data = record_data[message_start:message_end]
                return int_from_bytes(message_data[0:2]) * 8
            message_pointer = message_end

    return None
Beispiel #5
0
 def test_int_fromto_bytes(self):
     # Round-trip every value through int_to_bytes()/int_from_bytes(): first
     # signed over [-300, 300], then unsigned over [0, 300]
     for signed, low in ((True, -300), (False, 0)):
         for value in range(low, 301):
             encoded = util.int_to_bytes(value, signed)
             self.assertEqual(value, util.int_from_bytes(encoded, signed))
Beispiel #6
0
 def test_int_from_bytes(self):
     # Each entry is (input bytes, signed flag, expected integer)
     expectations = [
         (b'', False, 0),
         (b'', True, 0),
         (b'\x00', False, 0),
         (b'\x00', True, 0),
         (b'\x80', False, 128),
         (b'\x80', True, -128),
         (b'\xff', False, 255),
         (b'\xff', True, -1),
         (b'\xbc\x61\x4e', False, 12345678),
         (b'\xbc\x61\x4e', True, 12345678 - 2 ** 24),
     ]
     for raw, signed, expected in expectations:
         self.assertEqual(util.int_from_bytes(raw, signed), expected)
Beispiel #7
0
def get_dh_params_length(server_handshake_bytes):
    """
    Finds the server's ServerKeyExchange handshake message and reports the
    DH parameter size it announces

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None if no ServerKeyExchange message with a non-empty body was found,
        otherwise an integer of the bit size of the DH parameters (the 2-byte
        length at the start of the message body, multiplied by 8)
    """

    dh_params_bytes = None

    for record_type, _, record_data in parse_tls_records(server_handshake_bytes):
        if record_type == b'\x16':
            for message_type, message_data in parse_handshake_messages(record_data):
                if message_type == b'\x0c':
                    dh_params_bytes = message_data
                    break
            # Stop scanning records once a non-empty body has been captured
            if dh_params_bytes:
                break

    if not dh_params_bytes:
        return None
    return int_from_bytes(dh_params_bytes[0:2]) * 8
Beispiel #8
0
def get_dh_params_length(server_handshake_bytes):
    """
    Reports the DH parameter size announced in the server's
    ServerKeyExchange handshake message

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None if no ServerKeyExchange message with a non-empty body was
        located, otherwise an integer of the bit size of the DH parameters
    """

    def _locate_server_key_exchange():
        # Body of the first ServerKeyExchange in a handshake record; later
        # records are only consulted while the body found so far is empty
        located = None
        for content_type, _, payload in parse_tls_records(
                server_handshake_bytes):
            if content_type != b'\x16':
                continue
            for handshake_type, handshake_body in parse_handshake_messages(
                    payload):
                if handshake_type == b'\x0c':
                    located = handshake_body
                    break
            if located:
                return located
        return located

    params = _locate_server_key_exchange()
    if params:
        return int_from_bytes(params[0:2]) * 8
    return None
Beispiel #9
0
def parse_tls_records(data):
    """
    Creates a generator returning tuples of information about each record
    in a byte string of data from a TLS client or server. Stops as soon as it
    finds a ChangeCipherSpec record, since all data from then on is
    encrypted.

    :param data:
        A byte string of TLS records

    :return:
        A generator that yields 3-element tuples:
        [0] Byte string of record type
        [1] Byte string of protocol version
        [2] Byte string of record data
    """

    offset = 0
    total = len(data)
    while offset < total:
        content_type = data[offset:offset + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if content_type == b'\x14':
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        header_end = offset + 5
        body_length = int_from_bytes(data[offset + 3:header_end])
        version = data[offset + 1:offset + 3]
        body_end = header_end + body_length
        yield (content_type, version, data[header_end:body_end])
        offset = body_end
Beispiel #10
0
def parse_tls_records(data):
    """
    Lazily splits a byte string of TLS traffic into its records

    Iteration ends at the end of the data or at the first ChangeCipherSpec
    record, whichever comes first, since all data after a ChangeCipherSpec
    is encrypted.

    :param data:
        A byte string of TLS records

    :return:
        A generator that yields 3-element tuples:
        [0] Byte string of record type
        [1] Byte string of protocol version
        [2] Byte string of record data
    """

    position = 0
    while position < len(data):
        header = data[position:position + 5]
        if header[0:1] == b'\x14':
            return
        payload_start = position + 5
        payload_end = payload_start + int_from_bytes(header[3:5])
        yield header[0:1], header[1:3], data[payload_start:payload_end]
        position = payload_end
Beispiel #11
0
def ec_generate_pair(curve):
    """
    Generates an EC public/private key pair on one of the supported named
    curves

    :param curve:
        A unicode string. Valid values include "secp256r1", "secp384r1" and
        "secp521r1".

    :raises:
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type

    :return:
        A 2-element tuple of (asn1crypto.keys.PublicKeyInfo,
        asn1crypto.keys.PrivateKeyInfo)
    """

    valid_curves = set(['secp256r1', 'secp384r1', 'secp521r1'])
    if curve not in valid_curves:
        raise ValueError(pretty_message(
            '''
            curve must be one of "secp256r1", "secp384r1", "secp521r1", not %s
            ''',
            repr(curve)
        ))

    num_bytes = CURVE_BYTES[curve]
    base_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve]

    # Draw random scalars until one falls strictly between 0 and the order
    private_int = 0
    while not (0 < private_int < base_point.order):
        private_int = int_from_bytes(rand_bytes(num_bytes), signed=False)

    algorithm = keys.PrivateKeyAlgorithm({
        'algorithm': 'ec',
        'parameters': keys.ECDomainParameters(name='named', value=curve),
    })
    ec_private_key = keys.ECPrivateKey({
        'version': 'ecPrivkeyVer1',
        'private_key': private_int,
    })
    private_key_info = keys.PrivateKeyInfo({
        'version': 0,
        'private_key_algorithm': algorithm,
        'private_key': ec_private_key,
    })

    public_point = ec_compute_public_key_point(private_key_info)
    # Store a copy of the public point inside the ECPrivateKey structure
    private_key_info['private_key'].parsed['public_key'] = public_point.copy()

    return (ec_public_key_info(public_point, curve), private_key_info)
Beispiel #12
0
def ec_generate_pair(curve):
    """
    Creates a fresh EC key pair on a supported named curve

    :param curve:
        A unicode string. Valid values include "secp256r1", "secp384r1" and
        "secp521r1".

    :raises:
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type

    :return:
        A 2-element tuple of (asn1crypto.keys.PublicKeyInfo,
        asn1crypto.keys.PrivateKeyInfo)
    """

    if curve not in set(['secp256r1', 'secp384r1', 'secp521r1']):
        raise ValueError(pretty_message(
            '''
            curve must be one of "secp256r1", "secp384r1", "secp521r1", not %s
            ''',
            repr(curve)
        ))

    byte_len = CURVE_BYTES[curve]
    generator_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve]

    # Rejection sampling: retry until the scalar lies in (0, order)
    while True:
        scalar = int_from_bytes(rand_bytes(byte_len), signed=False)
        if 0 < scalar < generator_point.order:
            break

    domain_params = keys.ECDomainParameters(name='named', value=curve)
    private_key_info = keys.PrivateKeyInfo({
        'version': 0,
        'private_key_algorithm': keys.PrivateKeyAlgorithm({
            'algorithm': 'ec',
            'parameters': domain_params,
        }),
        'private_key': keys.ECPrivateKey({
            'version': 'ecPrivkeyVer1',
            'private_key': scalar,
        }),
    })

    # Embed the derived public key into the ECPrivateKey structure
    derived_public_key = private_key_info.public_key
    private_key_info['private_key'].parsed['public_key'] = derived_public_key

    return (private_key_info.public_key_info, private_key_info)
Beispiel #13
0
def raw_rsa_public_crypt(certificate_or_public_key, data):
    """
    Performs a raw RSA algorithm in a byte string using a certificate or
    public key. This is a low-level primitive and is prone to disastrous results
    if used incorrectly.

    :param certificate_or_public_key:
        An oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
        object

    :param data:
        A byte string of the signature when verifying, or padded plaintext when
        encrypting. Must be less than or equal to the length of the public key.
        When verifying, padding will need to be removed afterwards. When
        encrypting, padding must be applied before.

    :raises:
        SystemError - when the backend is not "winlegacy"
        TypeError - when certificate_or_public_key or data is of the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data, sized to the public key
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    has_asn1 = hasattr(certificate_or_public_key, 'asn1')
    valid_types = (PublicKeyInfo, Certificate)
    if not has_asn1 or not isinstance(certificate_or_public_key.asn1, valid_types):
        raise TypeError(pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        ))

    # A certificate's ASN.1 structure has no top-level "algorithm",
    # "public_key" or byte_size; the key material lives in its subject
    # PublicKeyInfo, so unwrap it before reading any key fields
    public_key_info = certificate_or_public_key.asn1
    if isinstance(public_key_info, Certificate):
        public_key_info = public_key_info.public_key

    algo = public_key_info['algorithm']['algorithm'].native
    if algo != 'rsa':
        raise ValueError(pretty_message(
            '''
            certificate_or_public_key must be an RSA key, not %s
            ''',
            algo.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_public_key = public_key_info['public_key'].parsed
    transformed_int = pow(
        int_from_bytes(data),
        rsa_public_key['public_exponent'].native,
        rsa_public_key['modulus'].native
    )
    return int_to_bytes(transformed_int, width=public_key_info.byte_size)
Beispiel #14
0
def extract_chain(server_handshake_bytes):
    """
    Extracts the X.509 certificates from the server handshake bytes for use
    when debugging

    Every TLS record up to the first ChangeCipherSpec record is examined, and
    every handshake message inside each handshake record, since the
    Certificate message frequently shares a record with other handshake
    messages.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A list of asn1crypto.x509.Certificate objects
    """

    chain_bytes = None

    data_len = len(server_handshake_bytes)
    record_pointer = 0
    while chain_bytes is None and record_pointer < data_len:
        record_type = server_handshake_bytes[record_pointer:record_pointer + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if record_type == b'\x14':
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        record_length = int_from_bytes(
            server_handshake_bytes[record_pointer + 3:record_pointer + 5])
        record_start = record_pointer + 5
        record_end = record_start + record_length
        record_pointer = record_end

        if record_type != b'\x16':
            continue

        record_data = server_handshake_bytes[record_start:record_end]
        message_pointer = 0
        while message_pointer < len(record_data):
            # Handshake message header: 1 byte type, 3 byte length
            message_type = record_data[message_pointer:message_pointer + 1]
            message_length = int_from_bytes(
                record_data[message_pointer + 1:message_pointer + 4])
            message_start = message_pointer + 4
            message_end = message_start + message_length
            if message_type == b'\x0b':
                # Only the Certificate message body, not the rest of the record
                chain_bytes = record_data[message_start:message_end]
                break
            message_pointer = message_end

    output = []
    if chain_bytes:
        # The first 3 bytes of the body are the cert chain length; each cert
        # is then prefixed by its own 3 byte length
        pointer = 3
        while pointer + 3 <= len(chain_bytes):
            cert_length = int_from_bytes(chain_bytes[pointer:pointer + 3])
            cert_start = pointer + 3
            cert_end = cert_start + cert_length
            pointer = cert_end
            output.append(Certificate.load(chain_bytes[cert_start:cert_end]))

    return output
Beispiel #15
0
def raw_rsa_private_crypt(private_key, data):
    """
    Applies the raw RSA private-key operation to a byte string. The data is
    read as an integer, raised to the private exponent modulo the key's
    modulus, and re-encoded at the key's byte size. No padding is added or
    removed, so misuse can have disastrous results.

    :param private_key:
        An oscrypto.asymmetric.PrivateKey object

    :param data:
        A byte string of the plaintext to be signed or ciphertext to be
        decrypted. Must be less than or equal to the length of the private key.
        In the case of signing, padding must already be applied. In the case of
        decryption, padding must be removed afterward.

    :raises:
        SystemError - when the backend is not "winlegacy"
        TypeError - when private_key or data is of the wrong type
        ValueError - when private_key is not an RSA key

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    is_key_info = hasattr(private_key, 'asn1') and isinstance(private_key.asn1, PrivateKeyInfo)
    if not is_key_info:
        raise TypeError(pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        ))

    key_info = private_key.asn1
    algorithm_name = key_info['private_key_algorithm']['algorithm'].native
    if algorithm_name != 'rsa':
        raise ValueError(pretty_message(
            '''
            private_key must be an RSA key, not %s
            ''',
            algorithm_name.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_numbers = key_info['private_key'].parsed
    data_int = int_from_bytes(data)
    result_int = pow(data_int, rsa_numbers['private_exponent'].native, rsa_numbers['modulus'].native)
    return int_to_bytes(result_int, width=key_info.byte_size)
Beispiel #16
0
def extract_chain(server_handshake_bytes):
    """
    Extracts the X.509 certificates from the server handshake bytes for use
    when debugging

    Every TLS record up to the first ChangeCipherSpec record is examined, and
    every handshake message inside each handshake record, since the
    Certificate message frequently shares a record with other handshake
    messages.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A list of asn1crypto.x509.Certificate objects
    """

    chain_bytes = None

    data_len = len(server_handshake_bytes)
    record_pointer = 0
    while chain_bytes is None and record_pointer < data_len:
        record_type = server_handshake_bytes[record_pointer : record_pointer + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if record_type == b"\x14":
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        record_length = int_from_bytes(server_handshake_bytes[record_pointer + 3 : record_pointer + 5])
        record_start = record_pointer + 5
        record_end = record_start + record_length
        record_pointer = record_end

        if record_type != b"\x16":
            continue

        record_data = server_handshake_bytes[record_start:record_end]
        message_pointer = 0
        while message_pointer < len(record_data):
            # Handshake message header: 1 byte type, 3 byte length
            message_type = record_data[message_pointer : message_pointer + 1]
            message_length = int_from_bytes(record_data[message_pointer + 1 : message_pointer + 4])
            message_start = message_pointer + 4
            message_end = message_start + message_length
            if message_type == b"\x0b":
                # Only the Certificate message body, not the rest of the record
                chain_bytes = record_data[message_start:message_end]
                break
            message_pointer = message_end

    output = []
    if chain_bytes:
        # The first 3 bytes of the body are the cert chain length; each cert
        # is then prefixed by its own 3 byte length
        pointer = 3
        while pointer + 3 <= len(chain_bytes):
            cert_length = int_from_bytes(chain_bytes[pointer : pointer + 3])
            cert_start = pointer + 3
            cert_end = cert_start + cert_length
            pointer = cert_end
            output.append(Certificate.load(chain_bytes[cert_start:cert_end]))

    return output
Beispiel #17
0
def parse_alert(server_handshake_bytes):
    """
    Looks for the first TLS alert record in the server's handshake data

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None if there is no alert record, or if the first one is not exactly
        2 bytes long; otherwise a 2-element tuple of integers:
         0: 1 (warning) or 2 (fatal)
         1: The alert description (see https://tools.ietf.org/html/rfc5246#section-7.2)
    """

    for record_type, _, record_data in parse_tls_records(server_handshake_bytes):
        if record_type == b'\x15':
            # Only the first alert record is inspected
            if len(record_data) == 2:
                level = int_from_bytes(record_data[0:1])
                description = int_from_bytes(record_data[1:2])
                return (level, description)
            return None
    return None
Beispiel #18
0
def parse_alert(server_handshake_bytes):
    """
    Extracts the level and description of the first TLS alert record sent by
    the server

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        None when no well-formed (2-byte) first alert record exists, else a
        2-element tuple of integers:
         0: 1 (warning) or 2 (fatal)
         1: The alert description (see https://tools.ietf.org/html/rfc5246#section-7.2)
    """

    alert_payloads = (
        record_data
        for record_type, _, record_data in parse_tls_records(
            server_handshake_bytes)
        if record_type == b'\x15'
    )
    first_alert = next(alert_payloads, None)
    if first_alert is None or len(first_alert) != 2:
        return None
    return (int_from_bytes(first_alert[0:1]),
            int_from_bytes(first_alert[1:2]))
Beispiel #19
0
def raw_rsa_private_crypt(private_key, data):
    """
    Runs the unpadded RSA private-key primitive over a byte string: the data
    as an integer is raised to the private exponent modulo the modulus, and
    the result is encoded to the key's byte size. No padding is handled
    here, so this is prone to disastrous results if used incorrectly.

    :param private_key:
        An oscrypto.asymmetric.PrivateKey object

    :param data:
        A byte string of the plaintext to be signed or ciphertext to be
        decrypted. Must be less than or equal to the length of the private key.
        In the case of signing, padding must already be applied. In the case of
        decryption, padding must be removed afterward.

    :raises:
        SystemError - when the backend is not "winlegacy"
        TypeError - when private_key or data is of the wrong type
        ValueError - when private_key is not an RSA key

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    if not (hasattr(private_key, 'asn1')
            and isinstance(private_key.asn1, PrivateKeyInfo)):
        raise TypeError(pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        ))

    key_asn1 = private_key.asn1
    key_algorithm = key_asn1['private_key_algorithm']['algorithm'].native
    if key_algorithm != 'rsa':
        raise ValueError(pretty_message(
            '''
            private_key must be an RSA key, not %s
            ''',
            key_algorithm.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_key = key_asn1['private_key'].parsed
    input_int = int_from_bytes(data)
    private_exponent = rsa_key['private_exponent'].native
    modulus = rsa_key['modulus'].native
    output_int = pow(input_int, private_exponent, modulus)
    return int_to_bytes(output_int, width=key_asn1.byte_size)
Beispiel #20
0
def parse_handshake_messages(data):
    """
    Creates a generator returning tuples of information about each message in
    a byte string of data from a TLS handshake record

    :param data:
        A byte string of a TLS handshake record data

    :return:
        A generator that yields 2-element tuples:
        [0] Byte string of message type
        [1] Byte string of message data
    """

    offset = 0
    while offset < len(data):
        # Message header: 1 byte type followed by a 3 byte body length
        body_start = offset + 4
        body_end = body_start + int_from_bytes(data[offset + 1:body_start])
        yield (data[offset:offset + 1], data[body_start:body_end])
        offset = body_end
Beispiel #21
0
def extract_chain(server_handshake_bytes):
    """
    Collects the X.509 certificates sent in the server's Certificate
    handshake message, for debugging purposes

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A list of asn1crypto.x509.Certificate objects (empty when no
        Certificate message with a non-empty body was found)
    """

    def _certificate_message_body():
        # Body of the first Certificate message in a handshake record; later
        # records are only consulted while the body found so far is empty
        body = None
        for record_type, _, record_data in parse_tls_records(
                server_handshake_bytes):
            if record_type != b'\x16':
                continue
            for message_type, message_data in parse_handshake_messages(
                    record_data):
                if message_type == b'\x0b':
                    body = message_data
                    break
            if body:
                break
        return body

    def _split_certificates(chain):
        # Skip the 3 byte length of the whole list, then walk the 3 byte
        # length-prefixed certificate entries
        offset = 3
        while offset < len(chain):
            start = offset + 3
            offset = start + int_from_bytes(chain[offset:start])
            yield chain[start:offset]

    chain_bytes = _certificate_message_body()
    if not chain_bytes:
        return []
    return [Certificate.load(cert_bytes)
            for cert_bytes in _split_certificates(chain_bytes)]
Beispiel #22
0
def extract_chain(server_handshake_bytes):
    """
    Pulls the X.509 certificate chain out of the server's handshake data,
    for use when debugging

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A list of asn1crypto.x509.Certificate objects
    """

    chain_bytes = None
    for record_type, _, record_data in parse_tls_records(server_handshake_bytes):
        if record_type != b'\x16':
            continue
        certificate_bodies = (
            message_data
            for message_type, message_data in parse_handshake_messages(record_data)
            if message_type == b'\x0b'
        )
        # Keep the previous value when this record has no Certificate message
        chain_bytes = next(certificate_bodies, chain_bytes)
        if chain_bytes:
            break

    certificates = []
    if chain_bytes:
        # Entries start after the 3 byte total length of the certificate list
        offset = 3
        while offset < len(chain_bytes):
            length = int_from_bytes(chain_bytes[offset:offset + 3])
            start = offset + 3
            offset = start + length
            certificates.append(Certificate.load(chain_bytes[start:offset]))
    return certificates
Beispiel #23
0
def parse_handshake_messages(data):
    """
    Lazily splits the payload of a TLS handshake record into its individual
    handshake messages

    :param data:
        A byte string of a TLS handshake record data

    :return:
        A generator that yields 2-element tuples:
        [0] Byte string of message type
        [1] Byte string of message data
    """

    cursor = 0
    end = len(data)
    while cursor < end:
        msg_type = data[cursor:cursor + 1]
        msg_len = int_from_bytes(data[cursor + 1:cursor + 4])
        payload_start = cursor + 4
        payload = data[payload_start:payload_start + msg_len]
        yield msg_type, payload
        cursor = payload_start + msg_len
Beispiel #24
0
def detect_client_auth_request(server_handshake_bytes):
    """
    Determines if a CertificateRequest message is sent from the server asking
    the client for a certificate

    Every TLS record up to the first ChangeCipherSpec record is examined, and
    every handshake message inside each handshake record, since the
    CertificateRequest usually shares a record with earlier handshake
    messages.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A boolean - if a client certificate request was found
    """

    data_len = len(server_handshake_bytes)
    record_pointer = 0
    while record_pointer < data_len:
        record_type = server_handshake_bytes[record_pointer:record_pointer + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if record_type == b'\x14':
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        record_length = int_from_bytes(
            server_handshake_bytes[record_pointer + 3:record_pointer + 5])
        record_start = record_pointer + 5
        record_end = record_start + record_length
        record_pointer = record_end

        if record_type != b'\x16':
            continue

        record_data = server_handshake_bytes[record_start:record_end]
        message_pointer = 0
        while message_pointer < len(record_data):
            if record_data[message_pointer:message_pointer + 1] == b'\x0d':
                return True
            # Skip the 4 byte message header plus its 3 byte declared length
            message_length = int_from_bytes(
                record_data[message_pointer + 1:message_pointer + 4])
            message_pointer += 4 + message_length

    return False
Beispiel #25
0
def detect_client_auth_request(server_handshake_bytes):
    """
    Determines if a CertificateRequest message is sent from the server asking
    the client for a certificate

    Every TLS record up to the first ChangeCipherSpec record is examined, and
    every handshake message inside each handshake record, since the
    CertificateRequest usually shares a record with earlier handshake
    messages.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :return:
        A boolean - if a client certificate request was found
    """

    data_len = len(server_handshake_bytes)
    record_pointer = 0
    while record_pointer < data_len:
        record_type = server_handshake_bytes[record_pointer : record_pointer + 1]
        # Everything after a ChangeCipherSpec record is encrypted
        if record_type == b"\x14":
            break
        # Record header: 1 byte type, 2 byte version, 2 byte length
        record_length = int_from_bytes(server_handshake_bytes[record_pointer + 3 : record_pointer + 5])
        record_start = record_pointer + 5
        record_end = record_start + record_length
        record_pointer = record_end

        if record_type != b"\x16":
            continue

        record_data = server_handshake_bytes[record_start:record_end]
        message_pointer = 0
        while message_pointer < len(record_data):
            if record_data[message_pointer : message_pointer + 1] == b"\x0d":
                return True
            # Skip the 4 byte message header plus its 3 byte declared length
            message_length = int_from_bytes(record_data[message_pointer + 1 : message_pointer + 4])
            message_pointer += 4 + message_length

    return False
Beispiel #26
0
def add_pss_padding(hash_algorithm, salt_length, key_length, message):
    """
    Pads a byte string using the EMSA-PSS-Encode operation described in PKCS#1
    v2.2.

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :raises:
        SystemError - when the backend is not "winlegacy" and the platform is
        not OS X
        TypeError - when message is not a byte string, or salt_length or
        key_length is not an integer
        ValueError - when salt_length is negative, key_length is below 512,
        hash_algorithm is not one of the supported names, or the key is too
        short for the chosen hash_algorithm and salt_length

    :return:
        The encoded (padded) message, as a byte string
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(
            pretty_message('''
            Pure-python RSA PSS signature padding addition code is only for
            Windows XP/2003 and OS X
            '''))

    if not isinstance(message, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            message must be a byte string, not %s
            ''', type_name(message)))

    if not isinstance(salt_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            salt_length must be an integer, not %s
            ''', type_name(salt_length)))

    if salt_length < 0:
        raise ValueError(
            pretty_message(
                '''
            salt_length must be 0 or more - is %s
            ''', repr(salt_length)))

    if not isinstance(key_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            key_length must be an integer, not %s
            ''', type_name(key_length)))

    if key_length < 512:
        raise ValueError(
            pretty_message(
                '''
            key_length must be 512 or more - is %s
            ''', repr(key_length)))

    if hash_algorithm not in set(
        ['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''', repr(hash_algorithm)))

    hash_func = getattr(hashlib, hash_algorithm)

    # emBits is one less than the key length in bits (modBits - 1 in
    # EMSA-PSS-ENCODE), which keeps the encoded message, read as an integer,
    # below the RSA modulus. emLen is emBits rounded up to whole bytes.
    # NOTE(review): the ceil relies on "/" being true division; presumably the
    # module enables `from __future__ import division` for Python 2 - confirm
    em_bits = key_length - 1
    em_len = int(math.ceil(em_bits / 8))

    # mHash = Hash(M)
    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        raise ValueError(
            pretty_message('''
            Key is not long enough to use with specified hash_algorithm and
            salt_length
            '''))

    # Random salt; an empty salt when salt_length is 0
    if salt_length > 0:
        salt = os.urandom(salt_length)
    else:
        salt = b''

    # M' = eight zero bytes || mHash || salt
    m_prime = (b'\x00' * 8) + message_digest + salt

    # H = Hash(M')
    m_prime_digest = hash_func(m_prime).digest()

    # PS: zero bytes filling DB out to em_len - hash_length - 1 bytes
    padding = b'\x00' * (em_len - salt_length - hash_length - 2)

    # DB = PS || 0x01 || salt
    db = padding + b'\x01' + salt

    # Mask derived from H; presumably _mgf1 is the MGF1 mask generation
    # function over hash_algorithm - confirm against its definition
    db_mask = _mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    # XOR DB with the mask as integers; converting back can drop leading zero
    # bytes, so the result is widened again to the mask's length
    # (fill_width presumably left-pads with zero bytes - verify)
    masked_db = int_to_bytes(int_from_bytes(db) ^ int_from_bytes(db_mask))
    masked_db = fill_width(masked_db, len(db_mask))

    # Clear the leftmost (8 * em_len - em_bits) bits of the first byte
    zero_bits = (8 * em_len) - em_bits
    left_bit_mask = ('0' * zero_bits) + ('1' * (8 - zero_bits))
    left_int_mask = int(left_bit_mask, 2)

    if left_int_mask != 255:
        masked_db = chr_cls(left_int_mask
                            & ord(masked_db[0:1])) + masked_db[1:]

    # EM = maskedDB || H || 0xBC
    return masked_db + m_prime_digest + b'\xBC'
Beispiel #27
0
def verify_pss_padding(hash_algorithm, salt_length, key_length, message,
                       signature):
    """
    Verifies the EMSA-PSS encoding (RFC 8017 section 9.1.2) of a message

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm. May be 0.

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message that was signed

    :param signature:
        A byte string of the encoded message recovered from the signature
        (i.e. the output of the raw RSA public-key operation). Leading zero
        bytes beyond the encoded-message length are ignored, and a shorter
        value is treated as if it were left-padded with zero bytes.

    :raises:
        SystemError - when not running on the "winlegacy" backend or OS X
        TypeError - when message, signature or salt_length has the wrong type
        ValueError - when salt_length is negative or hash_algorithm is unknown

    :return:
        A boolean - True if the encoding is valid for message, False otherwise
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(
            pretty_message('''
            Pure-python RSA PSS signature padding verification code is only for
            Windows XP/2003 and OS X
            '''))

    if not isinstance(message, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            message must be a byte string, not %s
            ''', type_name(message)))

    if not isinstance(signature, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            signature must be a byte string, not %s
            ''', type_name(signature)))

    if not isinstance(salt_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            salt_length must be an integer, not %s
            ''', type_name(salt_length)))

    if salt_length < 0:
        raise ValueError(
            pretty_message(
                '''
            salt_length must be 0 or more - is %s
            ''', repr(salt_length)))

    if hash_algorithm not in set(
        ['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''', repr(hash_algorithm)))

    hash_func = getattr(hashlib, hash_algorithm)

    # The encoded message holds one bit less than the modulus
    em_bits = key_length - 1
    # Integer ceiling division - does not depend on "/" being true division
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        return False

    # EM must be exactly em_len bytes. A raw RSA result sized to the modulus
    # is one byte longer when (key_length - 1) is a multiple of 8; any such
    # extra leading bytes must be zero. A shorter value represents the same
    # integer once left-padded with zeros.
    extra_length = len(signature) - em_len
    if extra_length > 0:
        if signature[0:extra_length] != b'\x00' * extra_length:
            return False
        signature = signature[extra_length:]
    elif extra_length < 0:
        signature = (b'\x00' * -extra_length) + signature

    if signature[-1:] != b'\xBC':
        return False

    zero_bits = (8 * em_len) - em_bits

    masked_db_length = em_len - hash_length - 1
    masked_db = signature[0:masked_db_length]

    # The leftmost zero_bits bits of the encoded message must be zero
    first_byte = ord(masked_db[0:1])
    bits_that_should_be_zero = first_byte >> (8 - zero_bits)
    if bits_that_should_be_zero != 0:
        return False

    m_prime_digest = signature[masked_db_length:masked_db_length + hash_length]

    db_mask = _mgf1(hash_algorithm, m_prime_digest, masked_db_length)

    # Clear the same leftmost bits in the mask so DB's top bits stay zero
    left_int_mask = 0xFF >> zero_bits
    if left_int_mask != 255:
        db_mask = chr_cls(left_int_mask & ord(db_mask[0:1])) + db_mask[1:]

    db = int_to_bytes(int_from_bytes(masked_db) ^ int_from_bytes(db_mask))
    if len(db) < len(masked_db):
        db = (b'\x00' * (len(masked_db) - len(db))) + db

    zero_length = em_len - hash_length - salt_length - 2
    zero_string = b'\x00' * zero_length
    if not constant_compare(db[0:zero_length], zero_string):
        return False

    if db[zero_length:zero_length + 1] != b'\x01':
        return False

    # Slice from an explicit start index: db[-0:] would return all of db,
    # which broke verification when salt_length is 0
    salt = db[len(db) - salt_length:]

    m_prime = (b'\x00' * 8) + message_digest + salt

    h_prime = hash_func(m_prime).digest()

    return constant_compare(m_prime_digest, h_prime)
Beispiel #28
0
def parse_session_info(server_handshake_bytes, client_handshake_bytes):
    """
    Summarizes a TLS handshake by reading the ServerHello and ClientHello:
    the negotiated protocol version, the cipher suite, whether compression
    was selected, and whether the session ID and/or session ticket is new
    or being reused.

    Only the first TLS record of each byte string is examined. It must be a
    handshake record (content type 0x16) that starts with a ServerHello
    (0x02) or ClientHello (0x01) message respectively; otherwise that side
    contributes nothing.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :param client_handshake_bytes:
        A byte string of the handshake data sent to the server

    :return:
        A dict with the following keys:
         - "protocol": unicode string or None
         - "cipher_suite": unicode string or None
         - "compression": boolean
         - "session_id": "new", "reused" or None
         - "session_ticket": "new", "reused" or None
    """

    def _leading_handshake_record(handshake_bytes):
        # Body of the first TLS record when it is a handshake record, else None
        if handshake_bytes[0:1] != b'\x16':
            return None
        body_length = int_from_bytes(handshake_bytes[3:5])
        return handshake_bytes[5:5 + body_length]

    def _has_session_ticket(record, length_offset):
        # Scans the extension list whose 2-byte length prefix sits at
        # length_offset for extension type 35 (SessionTicket)
        if length_offset >= len(record):
            return False
        position = length_offset + 2
        end = position + int_from_bytes(
            record[length_offset:length_offset + 2])
        while position < end:
            if int_from_bytes(record[position:position + 2]) == 35:
                return True
            position += 4 + int_from_bytes(record[position + 2:position + 4])
        return False

    protocol = None
    cipher_suite = None
    compression = False
    session_ticket = None
    server_session_id = None
    client_session_id = None

    server_record = _leading_handshake_record(server_handshake_bytes)
    if server_record is not None and server_record[0:1] == b'\x02':
        # Handshake header (4) + version (2) + random (32) precede the
        # session ID length byte at offset 38
        protocol = {
            b'\x03\x00': "SSLv3",
            b'\x03\x01': "TLSv1",
            b'\x03\x02': "TLSv1.1",
            b'\x03\x03': "TLSv1.2",
            b'\x03\x04': "TLSv1.3",
        }[server_record[4:6]]

        id_length = int_from_bytes(server_record[38:39])
        if id_length > 0:
            server_session_id = server_record[39:39 + id_length]

        offset = 39 + id_length
        cipher_suite = CIPHER_SUITE_MAP[server_record[offset:offset + 2]]
        offset += 2

        compression = server_record[offset:offset + 1] != b'\x00'
        offset += 1

        if _has_session_ticket(server_record, offset):
            session_ticket = "new"

    client_record = _leading_handshake_record(client_handshake_bytes)
    if client_record is not None and client_record[0:1] == b'\x01':
        id_length = int_from_bytes(client_record[38:39])
        if id_length > 0:
            client_session_id = client_record[39:39 + id_length]

        # Skip the length-prefixed cipher suite and compression method lists
        offset = 39 + id_length
        offset += 2 + int_from_bytes(client_record[offset:offset + 2])
        offset += 1 + int_from_bytes(client_record[offset:offset + 1])

        # On subsequent requests, the session ticket will only be seen
        # in the ClientHello message
        if server_session_id is None and session_ticket is None:
            if _has_session_ticket(client_record, offset):
                session_ticket = "reused"

    if server_session_id is None:
        session_id = None
    elif client_session_id == server_session_id:
        session_id = "reused"
    else:
        session_id = "new"

    return {
        "protocol": protocol,
        "cipher_suite": cipher_suite,
        "compression": compression,
        "session_id": session_id,
        "session_ticket": session_ticket,
    }
Beispiel #29
0
def parse_session_info(server_handshake_bytes, client_handshake_bytes):
    """
    Summarizes a TLS handshake by inspecting the ServerHello and ClientHello
    messages: the negotiated protocol version, the cipher suite, whether
    compression was selected, and whether the session ID and/or session
    ticket is new or being reused.

    Every handshake record (content type 0x16) is walked; within each record
    only the first ServerHello (0x02) or ClientHello (0x01) is examined.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :param client_handshake_bytes:
        A byte string of the handshake data sent to the server

    :return:
        A dict with the following keys:
         - "protocol": unicode string or None
         - "cipher_suite": unicode string or None
         - "compression": boolean
         - "session_id": "new", "reused" or None
         - "session_ticket": "new", "reused" or None
    """

    def _hellos(handshake_bytes, hello_type):
        # Yields the body of the first hello_type message found in each
        # handshake record of handshake_bytes
        for content_type, _, fragment in parse_tls_records(handshake_bytes):
            if content_type != b'\x16':
                continue
            for msg_type, msg_body in parse_handshake_messages(fragment):
                if msg_type == hello_type:
                    yield msg_body
                    break

    def _offers_session_ticket(extensions_bytes):
        # Extension type 35 is SessionTicket
        return any(
            ext_type == 35
            for ext_type, _ in _parse_hello_extensions(extensions_bytes)
        )

    version_names = {
        b'\x03\x00': "SSLv3",
        b'\x03\x01': "TLSv1",
        b'\x03\x02': "TLSv1.1",
        b'\x03\x03': "TLSv1.2",
        b'\x03\x04': "TLSv1.3",
    }

    protocol = None
    cipher_suite = None
    compression = False
    session_ticket = None
    server_session_id = None
    client_session_id = None

    for hello in _hellos(server_handshake_bytes, b'\x02'):
        # Version (2) + random (32) precede the session ID length byte
        protocol = version_names[hello[0:2]]

        id_length = int_from_bytes(hello[34:35])
        if id_length > 0:
            server_session_id = hello[35:35 + id_length]

        offset = 35 + id_length
        cipher_suite = CIPHER_SUITE_MAP[hello[offset:offset + 2]]
        offset += 2

        compression = hello[offset:offset + 1] != b'\x00'
        offset += 1

        if _offers_session_ticket(hello[offset:]):
            session_ticket = "new"

    for hello in _hellos(client_handshake_bytes, b'\x01'):
        id_length = int_from_bytes(hello[34:35])
        if id_length > 0:
            client_session_id = hello[35:35 + id_length]

        # Skip the length-prefixed cipher suite and compression method lists
        offset = 35 + id_length
        offset += 2 + int_from_bytes(hello[offset:offset + 2])
        offset += 1 + int_from_bytes(hello[offset:offset + 1])

        # On subsequent requests, the session ticket will only be seen
        # in the ClientHello message
        if server_session_id is None and session_ticket is None:
            if _offers_session_ticket(hello[offset:]):
                session_ticket = "reused"

    if server_session_id is None:
        session_id = None
    elif client_session_id == server_session_id:
        session_id = "reused"
    else:
        session_id = "new"

    return {
        "protocol": protocol,
        "cipher_suite": cipher_suite,
        "compression": compression,
        "session_id": session_id,
        "session_ticket": session_ticket,
    }
Beispiel #30
0
    def build_mpc(self, signing_private_key, orq_ip, orq_port):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it with an MPC engine reached over HTTP

        :param signing_private_key:
            An integer identifier for the private key to sign the certificate
            with. This identifier permits the MPC engine to find the private
            key associated with the public key exported in the key generation
            process

        :param orq_ip:
            Host name or IP address of the orchestrator the signing request
            is sent to

        :param orq_port:
            Port of the orchestrator, as an integer or a string

        :raises:
            ValueError - when the certificate is neither self-signed nor has an
            issuer, when a CA-only extension is set on a non-CA certificate,
            or when the hash algorithm is not "sha256"

        :return:
            An asn1crypto.x509.Certificate object of the newly signed
            certificate
        """
        def rsa_mpc_sign(signing_private_key, tbs_cert_dump, orq_ip,
                         orq_port):
            # Calculate hash corresponding to tbs_cert_dump [Only SHA256]
            hasher = hashes.Hash(hashes.SHA256(), backend=default_backend())
            hasher.update(tbs_cert_dump)
            # hexlify() returns bytes; decode so the digest can be joined with
            # str and sent in the query string as plain hex, not "b'...'"
            digest = hexlify(hasher.finalize()).decode('ascii')
            print("[Client] Hash: " + digest)

            # Send HTTP GET request to the orchestrator for signing. str()
            # lets orq_ip/orq_port be passed as non-string values (e.g. int)
            target = ("http://" + str(orq_ip) + ":" + str(orq_port) +
                      "/signMessage/" + str(signing_private_key))
            r = requests.get(target, params={'message': digest})

            response = dict(json.loads(r.text))

            # Sign format must be bytearray
            print("Firma: " + str(response["sign"]))
            return unhexlify(response["sign"])

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(
                _pretty_message('''
                Certificate must be self-signed, or an issuer must be specified
                '''))

        # rsa_mpc_sign always hashes with SHA-256; declaring any other hash in
        # the signature algorithm would produce a certificate that never
        # verifies
        if self._hash_algo != 'sha256':
            raise ValueError(
                _pretty_message(
                    '''
                    Only the "sha256" hash algorithm is supported for MPC
                    signing, not %s
                    ''', repr(self._hash_algo)))

        if self._self_signed:
            self._issuer = self._subject

        if self._serial_number is None:
            time_part = int_to_bytes(int(time.time()))
            random_part = util.rand_bytes(4)
            self._serial_number = int_from_bytes(time_part + random_part)

        if self._begin_date is None:
            self._begin_date = datetime.now(timezone.utc)

        if self._end_date is None:
            self._end_date = self._begin_date + timedelta(365)

        if not self.ca:
            for ca_only_extension in set([
                    'policy_mappings', 'policy_constraints',
                    'inhibit_any_policy'
            ]):
                if ca_only_extension in self._other_extensions:
                    raise ValueError(
                        _pretty_message(
                            '''
                        Extension %s is only valid for CA certificates
                        ''', ca_only_extension))

        # The algorithm is forced to be 'rsa'
        signature_algo = 'rsa'

        # Hash algorithm limited to SHA256 (checked above)
        signature_algorithm_id = '%s_%s' % (self._hash_algo, signature_algo)

        def _make_extension(name, value):
            return {
                'extn_id': name,
                'critical': self._determine_critical(name),
                'extn_value': value
            }

        extensions = []
        for name in sorted(self._special_extensions):
            value = getattr(self, '_%s' % name)
            if name == 'ocsp_no_check':
                value = core.Null() if value else None
            if value is not None:
                extensions.append(_make_extension(name, value))

        for name in sorted(self._other_extensions.keys()):
            extensions.append(
                _make_extension(name, self._other_extensions[name]))

        tbs_cert = x509.TbsCertificate({
            'version': 'v3',
            'serial_number': self._serial_number,
            'signature': {
                'algorithm': signature_algorithm_id
            },
            'issuer': self._issuer,
            'validity': {
                'not_before': x509.Time(name='utc_time',
                                        value=self._begin_date),
                'not_after': x509.Time(name='utc_time', value=self._end_date),
            },
            'subject': self._subject,
            'subject_public_key_info': self._subject_public_key,
            'extensions': extensions
        })

        signature = rsa_mpc_sign(signing_private_key, tbs_cert.dump(),
                                 orq_ip, orq_port)

        return x509.Certificate({
            'tbs_certificate': tbs_cert,
            'signature_algorithm': {
                'algorithm': signature_algorithm_id
            },
            'signature_value': signature
        })
Beispiel #31
0
def parse_session_info(server_handshake_bytes, client_handshake_bytes):
    """
    Summarizes a TLS handshake by reading the ServerHello and ClientHello:
    the negotiated protocol version, the cipher suite, whether compression
    was selected, and whether the session ID and/or session ticket is new
    or being reused.

    Only the first TLS record of each byte string is examined. It must be a
    handshake record (content type 0x16) that starts with a ServerHello
    (0x02) or ClientHello (0x01) message respectively; otherwise that side
    contributes nothing.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :param client_handshake_bytes:
        A byte string of the handshake data sent to the server

    :return:
        A dict with the following keys:
         - "protocol": unicode string or None
         - "cipher_suite": unicode string or None
         - "compression": boolean
         - "session_id": "new", "reused" or None
         - "session_ticket": "new", "reused" or None
    """

    def _leading_handshake_record(handshake_bytes):
        # Body of the first TLS record when it is a handshake record, else None
        if handshake_bytes[0:1] != b'\x16':
            return None
        body_length = int_from_bytes(handshake_bytes[3:5])
        return handshake_bytes[5:5 + body_length]

    def _has_session_ticket(record, length_offset):
        # Scans the extension list whose 2-byte length prefix sits at
        # length_offset for extension type 35 (SessionTicket)
        if length_offset >= len(record):
            return False
        position = length_offset + 2
        end = position + int_from_bytes(
            record[length_offset:length_offset + 2])
        while position < end:
            if int_from_bytes(record[position:position + 2]) == 35:
                return True
            position += 4 + int_from_bytes(record[position + 2:position + 4])
        return False

    protocol = None
    cipher_suite = None
    compression = False
    session_ticket = None
    server_session_id = None
    client_session_id = None

    server_record = _leading_handshake_record(server_handshake_bytes)
    if server_record is not None and server_record[0:1] == b'\x02':
        # Handshake header (4) + version (2) + random (32) precede the
        # session ID length byte at offset 38
        protocol = {
            b'\x03\x00': "SSLv3",
            b'\x03\x01': "TLSv1",
            b'\x03\x02': "TLSv1.1",
            b'\x03\x03': "TLSv1.2",
            b'\x03\x04': "TLSv1.3",
        }[server_record[4:6]]

        id_length = int_from_bytes(server_record[38:39])
        if id_length > 0:
            server_session_id = server_record[39:39 + id_length]

        offset = 39 + id_length
        cipher_suite = CIPHER_SUITE_MAP[server_record[offset:offset + 2]]
        offset += 2

        compression = server_record[offset:offset + 1] != b'\x00'
        offset += 1

        if _has_session_ticket(server_record, offset):
            session_ticket = "new"

    client_record = _leading_handshake_record(client_handshake_bytes)
    if client_record is not None and client_record[0:1] == b'\x01':
        id_length = int_from_bytes(client_record[38:39])
        if id_length > 0:
            client_session_id = client_record[39:39 + id_length]

        # Skip the length-prefixed cipher suite and compression method lists
        offset = 39 + id_length
        offset += 2 + int_from_bytes(client_record[offset:offset + 2])
        offset += 1 + int_from_bytes(client_record[offset:offset + 1])

        # On subsequent requests, the session ticket will only be seen
        # in the ClientHello message
        if server_session_id is None and session_ticket is None:
            if _has_session_ticket(client_record, offset):
                session_ticket = "reused"

    if server_session_id is None:
        session_id = None
    elif client_session_id == server_session_id:
        session_id = "reused"
    else:
        session_id = "new"

    return {
        "protocol": protocol,
        "cipher_suite": cipher_suite,
        "compression": compression,
        "session_id": session_id,
        "session_ticket": session_ticket,
    }
Beispiel #32
0
def raw_rsa_public_crypt(certificate_or_public_key, data):
    """
    Performs a raw RSA algorithm in a byte string using a certificate or
    public key. This is a low-level primitive and is prone to disastrous results
    if used incorrectly.

    :param certificate_or_public_key:
        An oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
        object

    :param data:
        A byte string of the signature when verifying, or padded plaintext when
        encrypting. Must be less than or equal to the length of the public key.
        When verifying, padding will need to be removed afterwards. When
        encrypting, padding must be applied before.

    :raises:
        SystemError - when not running on the "winlegacy" backend
        TypeError - when certificate_or_public_key or data has the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data, as wide as the key's modulus
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    has_asn1 = hasattr(certificate_or_public_key, 'asn1')
    valid_types = (PublicKeyInfo, Certificate)
    if not has_asn1 or not isinstance(certificate_or_public_key.asn1, valid_types):
        raise TypeError(pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        ))

    # A certificate's ASN.1 structure has no top-level "algorithm" or
    # "public_key" fields (nor a byte_size); those belong to the embedded
    # subject PublicKeyInfo, so unwrap it before reading the key
    public_key_info = certificate_or_public_key.asn1
    if isinstance(public_key_info, Certificate):
        public_key_info = public_key_info.public_key

    algo = public_key_info['algorithm']['algorithm'].native
    if algo != 'rsa':
        raise ValueError(pretty_message(
            '''
            certificate_or_public_key must be an RSA key, not %s
            ''',
            algo.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    rsa_public_key = public_key_info['public_key'].parsed
    transformed_int = pow(
        int_from_bytes(data),
        rsa_public_key['public_exponent'].native,
        rsa_public_key['modulus'].native
    )
    return int_to_bytes(
        transformed_int,
        width=public_key_info.byte_size
    )
Beispiel #33
0
 def value(self):
     """
     Returns self.contents interpreted as a signed integer by int_from_bytes
     """
     raw_bytes = self.contents
     decoded = int_from_bytes(raw_bytes, signed=True)
     return decoded
Beispiel #34
0
def ecdsa_sign(private_key, data, hash_algorithm):
    """
    Generates an ECDSA signature in pure Python (thus slow), deriving the
    per-signature nonce deterministically following RFC 6979

    :param private_key:
        The PrivateKey to generate the signature with

    :param data:
        A byte string of the data the signature is for

    :param hash_algorithm:
        A unicode string of "sha1", "sha224", "sha256", "sha384" or "sha512"

    :raises:
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type
        OSError - when an error is returned by the OS crypto library

    :return:
        A byte string of the signature - the dumped DSASignature structure
        holding r and s
    """

    if not hasattr(private_key, 'asn1') or not isinstance(private_key.asn1, keys.PrivateKeyInfo):
        raise TypeError(pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        ))

    curve_name = private_key.curve
    if curve_name not in set(['secp256r1', 'secp384r1', 'secp521r1']):
        raise ValueError(pretty_message(
            '''
            private_key does not use one of the named curves secp256r1,
            secp384r1 or secp521r1
            '''
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    ec_private_key = private_key.asn1['private_key'].parsed
    private_key_bytes = ec_private_key['private_key'].contents
    private_key_int = ec_private_key['private_key'].native

    curve_base_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve_name]

    n = curve_base_point.order
    order_bits = n.bit_length()

    def bits2int(octets):
        # RFC 6979 section 2.3.2: an integer from the leftmost order_bits bits
        value = int_from_bytes(octets, signed=False)
        excess_bits = len(octets) * 8 - order_bits
        if excess_bits > 0:
            value >>= excess_bits
        return value

    # RFC 6979 section 3.2

    # a. h = bits2int(H(m)) mod n (RFC 6979 section 2.4). The digest must be
    # truncated to the bit length of n when the hash is longer than the
    # curve order (e.g. sha512 with secp256r1); reducing the whole digest
    # mod n produces signatures that standard ECDSA verifiers reject.
    digest = hash_func(data).digest()
    hash_length = len(digest)

    h = bits2int(digest) % n

    # b.
    V = b'\x01' * hash_length

    # c.
    K = b'\x00' * hash_length

    # d. NOTE: RFC 6979 feeds bits2octets(h1) here; the raw digest is kept so
    # the nonces for existing curve/hash pairings stay unchanged. The nonce
    # still depends on both the private key and the message digest.
    K = hmac.new(K, V + b'\x00' + private_key_bytes + digest, hash_func).digest()

    # e.
    V = hmac.new(K, V, hash_func).digest()

    # f.
    K = hmac.new(K, V + b'\x01' + private_key_bytes + digest, hash_func).digest()

    # g.
    V = hmac.new(K, V, hash_func).digest()

    # h.
    while True:
        # h. 1
        T = b''

        # h. 2 - gather at least order_bits bits of output
        while len(T) * 8 < order_bits:
            V = hmac.new(K, V, hash_func).digest()
            T += V

        # h. 3
        k = bits2int(T)
        if 0 < k < n:
            # Calculate the signature in the loop in case we need a new k
            r = (curve_base_point * k).x % n
            if r != 0:
                s = (inverse_mod(k, n) * (h + (private_key_int * r) % n)) % n
                if s != 0:
                    break

        # The candidate k (or the r/s it produced) is unusable: refresh K and
        # V before drawing the next candidate, per RFC 6979 step h.3
        K = hmac.new(K, V + b'\x00', hash_func).digest()
        V = hmac.new(K, V, hash_func).digest()

    return DSASignature({'r': r, 's': s}).dump()
Beispiel #35
0
def parse_session_info(server_handshake_bytes, client_handshake_bytes):
    """
    Inspects both directions of a TLS handshake and summarizes the negotiated
    session: protocol version, cipher suite, compression, and whether the
    session id and session ticket were new or reused.

    Only records of type 0x16 are examined. From each ServerHello (handshake
    message type 0x02) the protocol, cipher suite, compression flag, server
    session id and presence of extension 35 (session ticket) are read. From
    each ClientHello (handshake message type 0x01) the client session id is
    read, and - only while the server has supplied neither a session id nor
    a session ticket extension - whether the client sent extension 35.

    :param server_handshake_bytes:
        A byte string of the handshake data received from the server

    :param client_handshake_bytes:
        A byte string of the handshake data sent to the server

    :return:
        A dict with the following keys:
         - "protocol": unicode string
         - "cipher_suite": unicode string
         - "compression": boolean
         - "session_id": "new", "reused" or None
         - "session_ticket": "new", "reused" or None
    """

    # The first two bytes of the ServerHello body identify the protocol
    protocol_names = {
        b'\x03\x00': "SSLv3",
        b'\x03\x01': "TLSv1",
        b'\x03\x02': "TLSv1.1",
        b'\x03\x03': "TLSv1.2",
        b'\x03\x04': "TLSv1.3",
    }
    session_ticket_ext = 35

    protocol = None
    cipher_suite = None
    compression = False
    session_ticket = None
    server_session_id = None
    client_session_id = None

    for content_type, _, fragment in parse_tls_records(server_handshake_bytes):
        if content_type != b'\x16':
            continue
        # Only the first ServerHello within each record is considered
        server_hello = next(
            (body for msg_type, body in parse_handshake_messages(fragment)
             if msg_type == b'\x02'),
            None
        )
        if server_hello is None:
            continue

        protocol = protocol_names[server_hello[0:2]]

        # The session id length byte sits at offset 34 of the message body
        id_len = int_from_bytes(server_hello[34:35])
        if id_len > 0:
            server_session_id = server_hello[35:35 + id_len]

        # The ServerHello carries exactly one 2-byte cipher suite and one
        # compression method byte, followed by the extensions
        suite_at = 35 + id_len
        cipher_suite = CIPHER_SUITE_MAP[server_hello[suite_at:suite_at + 2]]

        method_at = suite_at + 2
        compression = server_hello[method_at:method_at + 1] != b'\x00'

        server_extensions = _parse_hello_extensions(server_hello[method_at + 1:])
        if any(ext_type == session_ticket_ext for ext_type, _ in server_extensions):
            session_ticket = "new"

    for content_type, _, fragment in parse_tls_records(client_handshake_bytes):
        if content_type != b'\x16':
            continue
        # Only the first ClientHello within each record is considered
        client_hello = next(
            (body for msg_type, body in parse_handshake_messages(fragment)
             if msg_type == b'\x01'),
            None
        )
        if client_hello is None:
            continue

        id_len = int_from_bytes(client_hello[34:35])
        if id_len > 0:
            client_session_id = client_hello[35:35 + id_len]

        # The ClientHello lists cipher suites (2-byte length prefix) and
        # compression methods (1-byte length prefix) before the extensions
        suites_at = 35 + id_len
        suites_len = int_from_bytes(client_hello[suites_at:suites_at + 2])

        methods_at = suites_at + 2 + suites_len
        methods_len = int_from_bytes(client_hello[methods_at:methods_at + 1])

        # On subsequent requests the session ticket is only visible in the
        # ClientHello, so look there when the server offered nothing
        if server_session_id is None and session_ticket is None:
            client_extensions = _parse_hello_extensions(
                client_hello[methods_at + 1 + methods_len:])
            if any(ext_type == session_ticket_ext for ext_type, _ in client_extensions):
                session_ticket = "reused"

    session_id = None
    if server_session_id is not None:
        session_id = "reused" if client_session_id == server_session_id else "new"

    return {
        "protocol": protocol,
        "cipher_suite": cipher_suite,
        "compression": compression,
        "session_id": session_id,
        "session_ticket": session_ticket,
    }
Beispiel #36
0
def ecdsa_verify(certificate_or_public_key, signature, data, hash_algorithm):
    """
    Verifies an ECDSA signature in pure Python (thus slow)

    The message hash is converted to an integer by keeping only its leftmost
    bit_length(n) bits, where n is the curve order, as specified for ECDSA in
    FIPS 186-4 section 6.4. This matters when the digest is longer than the
    order, e.g. secp256r1 with sha384 or sha512.

    :param certificate_or_public_key:
        A Certificate or PublicKey instance to verify the signature with

    :param signature:
        A byte string of the ASN.1-encoded (r, s) signature to verify

    :param data:
        A byte string of the data the signature is for

    :param hash_algorithm:
        A unicode string of "sha1", "sha224", "sha256", "sha384" or "sha512"

    :raises:
        oscrypto.errors.SignatureError - when the signature is determined to be invalid
        ValueError - when any of the parameters contain an invalid value
        TypeError - when any of the parameters are of the wrong type
    """

    has_asn1 = hasattr(certificate_or_public_key, 'asn1')
    if not has_asn1 or not isinstance(certificate_or_public_key.asn1, (keys.PublicKeyInfo, Certificate)):
        raise TypeError(pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        ))

    curve_name = certificate_or_public_key.curve
    if curve_name not in set(['secp256r1', 'secp384r1', 'secp521r1']):
        raise ValueError(pretty_message(
            '''
            certificate_or_public_key does not use one of the named curves
            secp256r1, secp384r1 or secp521r1
            '''
        ))

    if not isinstance(signature, byte_cls):
        raise TypeError(pretty_message(
            '''
            signature must be a byte string, not %s
            ''',
            type_name(signature)
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    asn1 = certificate_or_public_key.asn1
    if isinstance(asn1, Certificate):
        asn1 = asn1.public_key

    curve_base_point = {
        'secp256r1': SECP256R1_BASE_POINT,
        'secp384r1': SECP384R1_BASE_POINT,
        'secp521r1': SECP521R1_BASE_POINT,
    }[curve_name]

    x, y = asn1['public_key'].to_coords()
    n = curve_base_point.order

    # Constructing the point is relied upon to validate it (presumably
    # PrimePoint rejects coordinates that are not on the curve)
    public_key_point = PrimePoint(curve_base_point.curve, x, y, n)

    try:
        signature = DSASignature.load(signature)
        r = signature['r'].native
        s = signature['s'].native
    except ValueError:
        raise SignatureError('Signature is invalid')

    invalid = 0

    # Check r is valid
    invalid |= r < 1
    invalid |= r >= n

    # Check s is valid
    invalid |= s < 1
    invalid |= s >= n

    if invalid:
        raise SignatureError('Signature is invalid')

    hash_func = getattr(hashlib, hash_algorithm)

    digest = hash_func(data).digest()

    # Use only the leftmost bit_length(n) bits of the digest. Reducing the
    # whole digest mod n yields a different z whenever the digest is longer
    # than the order, which rejects standards-conforming signatures.
    z = int_from_bytes(digest, signed=False)
    excess_bits = len(digest) * 8 - n.bit_length()
    if excess_bits > 0:
        z >>= excess_bits
    z %= n

    w = inverse_mod(s, n)
    u1 = (z * w) % n
    u2 = (r * w) % n
    hash_point = (curve_base_point * u1) + (public_key_point * u2)
    if r != (hash_point.x % n):
        raise SignatureError('Signature is invalid')
Beispiel #37
0
    def build(self, signing_private_key):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it

        Values that have not been set are generated and stored on the builder:
        the serial number (current Unix time followed by 4 random bytes), the
        begin date (now, in UTC) and the end date (begin date plus 365 days).
        For a self-signed certificate the issuer is set to the subject.

        :param signing_private_key:
            An asn1crypto.keys.PrivateKeyInfo or oscrypto.asymmetric.PrivateKey
            object for the private key to sign the certificate with. If the key
            is self-signed, this should be the private key that matches the
            public key, otherwise it needs to be the issuer's private key.

        :raises:
            TypeError - when signing_private_key is not one of the supported types
            ValueError - when the certificate is neither self-signed nor has an
            issuer, when a CA-only extension is set on a non-CA certificate, or
            when the key algorithm is not "rsa", "dsa" or "ec"

        :return:
            An asn1crypto.x509.Certificate object of the newly signed
            certificate
        """

        is_oscrypto = isinstance(signing_private_key, asymmetric.PrivateKey)
        if not isinstance(signing_private_key, keys.PrivateKeyInfo) and not is_oscrypto:
            raise TypeError(_pretty_message(
                '''
                signing_private_key must be an instance of
                asn1crypto.keys.PrivateKeyInfo or
                oscrypto.asymmetric.PrivateKey, not %s
                ''',
                _type_name(signing_private_key)
            ))

        # Choose the signing function before any builder state is modified,
        # so an unsupported key algorithm fails with a clear ValueError rather
        # than an unbound sign_func after the TBS structure is built
        key_algorithm = signing_private_key.algorithm
        if key_algorithm == 'rsa':
            sign_func = asymmetric.rsa_pkcs1v15_sign
        elif key_algorithm == 'dsa':
            sign_func = asymmetric.dsa_sign
        elif key_algorithm == 'ec':
            sign_func = asymmetric.ecdsa_sign
        else:
            raise ValueError(_pretty_message(
                '''
                signing_private_key must use the "rsa", "dsa" or "ec"
                algorithm, not %s
                ''',
                repr(key_algorithm)
            ))

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(_pretty_message(
                '''
                Certificate must be self-signed, or an issuer must be specified
                '''
            ))

        if self._self_signed:
            self._issuer = self._subject

        if self._serial_number is None:
            time_part = int_to_bytes(int(time.time()))
            random_part = util.rand_bytes(4)
            self._serial_number = int_from_bytes(time_part + random_part)

        if self._begin_date is None:
            self._begin_date = datetime.now(timezone.utc)

        if self._end_date is None:
            self._end_date = self._begin_date + timedelta(365)

        if not self.ca:
            for ca_only_extension in set(['policy_mappings', 'policy_constraints', 'inhibit_any_policy']):
                if ca_only_extension in self._other_extensions:
                    raise ValueError(_pretty_message(
                        '''
                        Extension %s is only valid for CA certificates
                        ''',
                        ca_only_extension
                    ))

        # The signature algorithm id names EC signatures "ecdsa"
        signature_algo = 'ecdsa' if key_algorithm == 'ec' else key_algorithm
        signature_algorithm_id = '%s_%s' % (self._hash_algo, signature_algo)

        # RFC 3280 4.1.2.5 - UTCTime for dates before 2050, GeneralizedTime after
        def _make_validity_time(dt):
            if dt < datetime(2050, 1, 1, tzinfo=timezone.utc):
                value = x509.Time(name='utc_time', value=dt)
            else:
                value = x509.Time(name='general_time', value=dt)

            return value

        def _make_extension(name, value):
            return {
                'extn_id': name,
                'critical': self._determine_critical(name),
                'extn_value': value
            }

        extensions = []
        for name in sorted(self._special_extensions):
            value = getattr(self, '_%s' % name)
            if name == 'ocsp_no_check':
                value = core.Null() if value else None
            if value is not None:
                extensions.append(_make_extension(name, value))

        for name in sorted(self._other_extensions.keys()):
            extensions.append(_make_extension(name, self._other_extensions[name]))

        tbs_cert = x509.TbsCertificate({
            'version': 'v3',
            'serial_number': self._serial_number,
            'signature': {
                'algorithm': signature_algorithm_id
            },
            'issuer': self._issuer,
            'validity': {
                'not_before': _make_validity_time(self._begin_date),
                'not_after': _make_validity_time(self._end_date),
            },
            'subject': self._subject,
            'subject_public_key_info': self._subject_public_key,
            'extensions': extensions
        })

        if not is_oscrypto:
            signing_private_key = asymmetric.load_private_key(signing_private_key)
        signature = sign_func(signing_private_key, tbs_cert.dump(), self._hash_algo)

        return x509.Certificate({
            'tbs_certificate': tbs_cert,
            'signature_algorithm': {
                'algorithm': signature_algorithm_id
            },
            'signature_value': signature
        })
Beispiel #38
0
    def build(self, signing_private_key):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it

        When signing_private_key is None nothing is signed: the unsigned
        asn1crypto.x509.TbsCertificate is returned instead, and its signature
        algorithm id is "<hash_algo>_rsa" because RSA is assumed.

        Values that have not been set are generated and stored on the builder:
        the serial number (current Unix time followed by 4 random bytes), the
        begin date (now, in UTC) and the end date (begin date plus 365 days).
        For a self-signed certificate the issuer is set to the subject.

        :param signing_private_key:
            An asn1crypto.keys.PrivateKeyInfo or oscrypto.asymmetric.PrivateKey
            object for the private key to sign the certificate with, or None to
            only build the to-be-signed structure. If the key is self-signed,
            this should be the private key that matches the public key,
            otherwise it needs to be the issuer's private key.

        :raises:
            TypeError - when signing_private_key is not None and not one of the
            supported types
            ValueError - when the certificate is neither self-signed nor has an
            issuer, when a CA-only extension is set on a non-CA certificate, or
            when the key algorithm is not "rsa", "dsa" or "ec"

        :return:
            An asn1crypto.x509.Certificate object of the newly signed
            certificate, or an asn1crypto.x509.TbsCertificate when
            signing_private_key is None
        """

        is_oscrypto = isinstance(signing_private_key, asymmetric.PrivateKey)
        if not isinstance(
                signing_private_key, keys.PrivateKeyInfo
        ) and not is_oscrypto and signing_private_key is not None:
            raise TypeError(_pretty_message(
                '''
                signing_private_key must be an instance of
                asn1crypto.keys.PrivateKeyInfo or
                oscrypto.asymmetric.PrivateKey, not %s
                ''',
                _type_name(signing_private_key)
            ))

        # Choose the signing function before any builder state is modified,
        # so an unsupported key algorithm fails with a clear ValueError rather
        # than an unbound sign_func after the TBS structure is built
        key_algorithm = None
        sign_func = None
        if signing_private_key is not None:
            key_algorithm = signing_private_key.algorithm
            if key_algorithm == 'rsa':
                sign_func = asymmetric.rsa_pkcs1v15_sign
            elif key_algorithm == 'dsa':
                sign_func = asymmetric.dsa_sign
            elif key_algorithm == 'ec':
                sign_func = asymmetric.ecdsa_sign
            else:
                raise ValueError(_pretty_message(
                    '''
                    signing_private_key must use the "rsa", "dsa" or "ec"
                    algorithm, not %s
                    ''',
                    repr(key_algorithm)
                ))

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(_pretty_message(
                '''
                Certificate must be self-signed, or an issuer must be specified
                '''
            ))

        if self._self_signed:
            self._issuer = self._subject

        if self._serial_number is None:
            time_part = int_to_bytes(int(time.time()))
            random_part = util.rand_bytes(4)
            self._serial_number = int_from_bytes(time_part + random_part)

        if self._begin_date is None:
            self._begin_date = datetime.now(timezone.utc)

        if self._end_date is None:
            self._end_date = self._begin_date + timedelta(365)

        if not self.ca:
            for ca_only_extension in set([
                    'policy_mappings', 'policy_constraints',
                    'inhibit_any_policy'
            ]):
                if ca_only_extension in self._other_extensions:
                    raise ValueError(_pretty_message(
                        '''
                        Extension %s is only valid for CA certificates
                        ''',
                        ca_only_extension
                    ))

        if signing_private_key is not None:
            # The signature algorithm id names EC signatures "ecdsa"
            signature_algo = 'ecdsa' if key_algorithm == 'ec' else key_algorithm
        else:
            # No key to derive the algorithm from, so RSA is assumed
            signature_algo = 'rsa'
        signature_algorithm_id = '%s_%s' % (self._hash_algo, signature_algo)

        # RFC 3280 4.1.2.5 - UTCTime for dates before 2050, GeneralizedTime after
        def _make_validity_time(dt):
            if dt < datetime(2050, 1, 1, tzinfo=timezone.utc):
                value = x509.Time(name='utc_time', value=dt)
            else:
                value = x509.Time(name='general_time', value=dt)

            return value

        def _make_extension(name, value):
            return {
                'extn_id': name,
                'critical': self._determine_critical(name),
                'extn_value': value
            }

        extensions = []
        for name in sorted(self._special_extensions):
            value = getattr(self, '_%s' % name)
            if name == 'ocsp_no_check':
                value = core.Null() if value else None
            if value is not None:
                extensions.append(_make_extension(name, value))

        for name in sorted(self._other_extensions.keys()):
            extensions.append(
                _make_extension(name, self._other_extensions[name]))

        tbs_cert = x509.TbsCertificate({
            'version': 'v3',
            'serial_number': self._serial_number,
            'signature': {
                'algorithm': signature_algorithm_id
            },
            'issuer': self._issuer,
            'validity': {
                'not_before': _make_validity_time(self._begin_date),
                'not_after': _make_validity_time(self._end_date),
            },
            'subject': self._subject,
            'subject_public_key_info': self._subject_public_key,
            'extensions': extensions
        })

        if signing_private_key is None:
            return tbs_cert

        if not is_oscrypto:
            signing_private_key = asymmetric.load_private_key(
                signing_private_key)
        signature = sign_func(signing_private_key, tbs_cert.dump(),
                              self._hash_algo)

        return x509.Certificate({
            'tbs_certificate': tbs_cert,
            'signature_algorithm': {
                'algorithm': signature_algorithm_id
            },
            'signature_value': signature
        })
Beispiel #39
0
def add_pss_padding(hash_algorithm, salt_length, key_length, message):
    """
    Pads a byte string using the EMSA-PSS-Encode operation described in PKCS#1
    v2.2.

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :raises:
        TypeError - when message is not a byte string, or salt_length or
        key_length is not an integer
        ValueError - when salt_length is negative, key_length is less than 512,
        hash_algorithm is not supported, or the key is too short for the hash
        and salt lengths

    :return:
        The encoded (padded) message as a byte string of
        ceil((key_length - 1) / 8) bytes
    """

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 512:
        raise ValueError(pretty_message(
            '''
            key_length must be 512 or more - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # emBits is one less than the modulus bit length so that the encoded
    # message, read as an integer, is always smaller than the RSA modulus
    em_bits = key_length - 1
    # Integer ceiling division; does not depend on true-division semantics
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        raise ValueError(pretty_message(
            '''
            Key is not long enough to use with specified hash_algorithm and
            salt_length
            '''
        ))

    if salt_length > 0:
        salt = os.urandom(salt_length)
    else:
        salt = b''

    m_prime = (b'\x00' * 8) + message_digest + salt

    m_prime_digest = hash_func(m_prime).digest()

    padding = b'\x00' * (em_len - salt_length - hash_length - 2)

    db = padding + b'\x01' + salt

    db_mask = mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    masked_db = int_to_bytes(int_from_bytes(db) ^ int_from_bytes(db_mask))
    masked_db = fill_width(masked_db, len(db_mask))

    # Clear the leftmost (8 * em_len - em_bits) bits of the first octet
    zero_bits = (8 * em_len) - em_bits
    left_int_mask = 0xFF >> zero_bits

    if left_int_mask != 0xFF:
        masked_db = chr_cls(left_int_mask & ord(masked_db[0:1])) + masked_db[1:]

    return masked_db + m_prime_digest + b'\xBC'
Beispiel #40
0
def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A UTF-8 byte string of the password to use an input to the KDF

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when password or salt is not a byte string, or iterations
        or key_length is not an integer
        ValueError - when iterations or key_length is less than 1, or
        hash_algorithm or id_ is not one of the supported values

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            password must be a byte string, not %s
            ''', type_name(password)))

    if not isinstance(salt, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            salt must be a byte string, not %s
            ''', type_name(salt)))

    if not isinstance(iterations, int_types):
        raise TypeError(
            pretty_message(
                '''
            iterations must be an integer, not %s
            ''', type_name(iterations)))

    if iterations < 1:
        raise ValueError(
            pretty_message(
                '''
            iterations must be greater than 0 - is %s
            ''', repr(iterations)))

    if not isinstance(key_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            key_length must be an integer, not %s
            ''', type_name(key_length)))

    if key_length < 1:
        raise ValueError(
            pretty_message(
                '''
            key_length must be greater than 0 - is %s
            ''', repr(key_length)))

    if hash_algorithm not in set(
        ['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''', repr(hash_algorithm)))

    if id_ not in set([1, 2, 3]):
        raise ValueError(
            pretty_message(
                '''
            id_ must be one of 1, 2, 3, not %s
            ''', repr(id_)))

    # The password is decoded as UTF-8 and re-encoded as big-endian UTF-16
    # with a two-byte NUL terminator
    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC): u is the digest
    # size and v the hash block size
    u = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    if hash_algorithm in ['sha384', 'sha512']:
        v = 128
    else:
        v = 64

    # Step 1
    d = chr_cls(id_) * v

    # Step 2
    s = b''
    if salt != b'':
        s_len = v * int(math.ceil(float(len(salt)) / v))
        while len(s) < s_len:
            s += salt
        s = s[0:s_len]

    # Step 3
    p = b''
    if utf16_password != b'':
        p_len = v * int(math.ceil(float(len(utf16_password)) / v))
        while len(p) < p_len:
            p += utf16_password
        p = p[0:p_len]

    # Step 4
    i = s + p

    # Step 5
    c = int(math.ceil(float(key_length) / u))

    a = b'\x00' * (c * u)

    for num in range(1, c + 1):
        # Step 6A
        a2 = algo(d + i).digest()
        for _ in range(2, iterations + 1):
            a2 = algo(a2).digest()

        if num < c:
            # Step 6B
            b = b''
            while len(b) < v:
                b += a2

            b = int_from_bytes(b[0:v]) + 1

            # Step 6C - I_j = (I_j + B + 1) mod 2^v for each v-byte block.
            # The sum must be re-encoded as exactly v bytes: drop a carry byte
            # on overflow, and restore leading zero bytes that int_to_bytes
            # may omit. Otherwise I shrinks, later blocks lose alignment and
            # the derived key is wrong.
            for num2 in range(0, len(i) // v):
                start = num2 * v
                end = (num2 + 1) * v

                i_num2 = int_to_bytes(int_from_bytes(i[start:end]) + b)
                i_num2 = (b'\x00' * v + i_num2)[-v:]

                i = i[0:start] + i_num2 + i[end:]

        # Step 7 (one piece at a time)
        begin = (num - 1) * u
        to_copy = min(key_length, u)
        a = a[0:begin] + a2[0:to_copy] + a[begin + to_copy:]

    return a[0:key_length]
Beispiel #41
0
def pbkdf2(hash_algorithm, password, salt, iterations, key_length):
    """
    Implements PBKDF2 from PKCS#5 v2.2 in pure Python

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use an input to the KDF

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :raises:
        TypeError - when password or salt is not a byte string, or iterations
        or key_length is not an integer
        ValueError - when iterations or key_length is less than 1, or
        hash_algorithm is not supported

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    algo = getattr(hashlib, hash_algorithm)

    hash_length = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    # Integer ceiling division: the number of hash-sized blocks needed to
    # cover key_length, independent of true/floor division semantics
    blocks = (key_length + hash_length - 1) // hash_length

    original_hmac = hmac.new(password, None, algo)

    int_pack = struct.Struct(b'>I').pack

    output_blocks = []
    for block in range(1, blocks + 1):
        # U_1 = PRF(P, S || INT(block))
        prf = original_hmac.copy()
        prf.update(salt + int_pack(block))
        last = prf.digest()
        u = int_from_bytes(last)
        # U_j = PRF(P, U_{j-1}); T_block is the XOR of all U_j
        for _ in range(2, iterations + 1):
            prf = original_hmac.copy()
            prf.update(last)
            last = prf.digest()
            u ^= int_from_bytes(last)
        # int_to_bytes may return fewer than hash_length bytes when T_block
        # starts with zero bytes; left-pad so every block is full width and
        # later blocks are not shifted in the output
        t = int_to_bytes(u)
        output_blocks.append(b'\x00' * (hash_length - len(t)) + t)

    return b''.join(output_blocks)[0:key_length]
Beispiel #42
0
def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A UTF-8 byte string of the password to use an input to the KDF

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when password or salt is not a byte string, or iterations
        or key_length is not an integer
        ValueError - when iterations or key_length is less than 1, or
        hash_algorithm or id_ is not one of the supported values

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    if id_ not in set([1, 2, 3]):
        raise ValueError(pretty_message(
            '''
            id_ must be one of 1, 2, 3, not %s
            ''',
            repr(id_)
        ))

    # The password is decoded as UTF-8 and re-encoded as big-endian UTF-16
    # with a two-byte NUL terminator
    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC): u is the digest
    # size and v the hash block size
    u = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    if hash_algorithm in ['sha384', 'sha512']:
        v = 128
    else:
        v = 64

    # Step 1
    d = chr_cls(id_) * v

    # Step 2
    s = b''
    if salt != b'':
        s_len = v * int(math.ceil(float(len(salt)) / v))
        while len(s) < s_len:
            s += salt
        s = s[0:s_len]

    # Step 3
    p = b''
    if utf16_password != b'':
        p_len = v * int(math.ceil(float(len(utf16_password)) / v))
        while len(p) < p_len:
            p += utf16_password
        p = p[0:p_len]

    # Step 4
    i = s + p

    # Step 5
    c = int(math.ceil(float(key_length) / u))

    a = b'\x00' * (c * u)

    for num in range(1, c + 1):
        # Step 6A
        a2 = algo(d + i).digest()
        for _ in range(2, iterations + 1):
            a2 = algo(a2).digest()

        if num < c:
            # Step 6B
            b = b''
            while len(b) < v:
                b += a2

            b = int_from_bytes(b[0:v]) + 1

            # Step 6C - I_j = (I_j + B + 1) mod 2^v for each v-byte block.
            # The sum must be re-encoded as exactly v bytes: drop a carry byte
            # on overflow, and restore leading zero bytes that int_to_bytes
            # may omit. Otherwise I shrinks, later blocks lose alignment and
            # the derived key is wrong.
            for num2 in range(0, len(i) // v):
                start = num2 * v
                end = (num2 + 1) * v

                i_num2 = int_to_bytes(int_from_bytes(i[start:end]) + b)
                i_num2 = (b'\x00' * v + i_num2)[-v:]

                i = i[0:start] + i_num2 + i[end:]

        # Step 7 (one piece at a time)
        begin = (num - 1) * u
        to_copy = min(key_length, u)
        a = a[0:begin] + a2[0:to_copy] + a[begin + to_copy:]

    return a[0:key_length]
Beispiel #43
0
def verify_pss_padding(hash_algorithm, salt_length, key_length, message, signature):
    """
    Verifies the PSS padding on an encoded message, following the
    EMSA-PSS-VERIFY operation of RFC 8017 section 9.1.2

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm. May be 0.

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message whose signature is being verified

    :param signature:
        A byte string of the encoded message (EM) to verify - presumably the
        output of the RSA public key operation on the signature (confirm with
        callers)

    :raises:
        TypeError - when message, signature or salt_length is of the wrong type
        ValueError - when salt_length is negative or hash_algorithm is not
        supported

    :return:
        A boolean - True if the PSS padding (and thus the signature) is valid,
        False otherwise
    """

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(signature, byte_cls):
        raise TypeError(pretty_message(
            '''
            signature must be a byte string, not %s
            ''',
            type_name(signature)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # emLen = ceil(emBits / 8); integer ceiling division is exact regardless
    # of whether "/" is true or floor division in this module
    em_bits = key_length - 1
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    # EM must have room for maskedDB's 0x01 marker, the salt, H and 0xBC
    if em_len < hash_length + salt_length + 2:
        return False

    # A shorter EM cannot be split into maskedDB || H || 0xBC; reject it
    # rather than failing later on an empty slice
    if len(signature) < em_len:
        return False

    # The trailer byte must be 0xBC
    if signature[-1:] != b'\xBC':
        return False

    # Number of high-order bits of EM's first byte that must be zero
    zero_bits = (8 * em_len) - em_bits

    # EM = maskedDB || H || 0xBC
    masked_db_length = em_len - hash_length - 1
    masked_db = signature[0:masked_db_length]
    m_prime_digest = signature[masked_db_length:masked_db_length + hash_length]

    first_byte = ord(masked_db[0:1])
    if first_byte >> (8 - zero_bits) != 0:
        return False

    db_mask = mgf1(hash_algorithm, m_prime_digest, masked_db_length)

    # Clear the leftmost zero_bits of the mask so the leftmost bits of DB
    # come out zero (maskedDB's were already verified to be zero above)
    left_int_mask = 0xFF >> zero_bits
    if left_int_mask != 0xFF:
        db_mask = chr_cls(left_int_mask & ord(db_mask[0:1])) + db_mask[1:]

    db = int_to_bytes(int_from_bytes(masked_db) ^ int_from_bytes(db_mask))
    # int_to_bytes drops leading zero bytes; restore DB to its full length
    if len(db) < len(masked_db):
        db = (b'\x00' * (len(masked_db) - len(db))) + db

    # DB = PS (zero bytes) || 0x01 || salt
    zero_length = em_len - hash_length - salt_length - 2
    zero_string = b'\x00' * zero_length
    if not constant_compare(db[0:zero_length], zero_string):
        return False

    if db[zero_length:zero_length + 1] != b'\x01':
        return False

    # Everything after the 0x01 marker is the salt. Slicing from the marker
    # (rather than db[-salt_length:]) gives b'' when salt_length == 0.
    salt = db[zero_length + 1:]

    m_prime = (b'\x00' * 8) + message_digest + salt

    h_prime = hash_func(m_prime).digest()

    return constant_compare(m_prime_digest, h_prime)
Beispiel #44
0
    def build(self, signing_private_key_path, debug=False):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it

        :param signing_private_key_path:
            Path to a .pem file with the private key used to sign the
            certificate

        :param debug:
            When True, prints the signed bytes and the resulting signature
            (hex encoded) to stdout

        :raises:
            ValueError - when the certificate is not self-signed and no issuer
            has been specified

        :return:
            An m2m.Certificate object of the newly signed
            certificate
        """
        import os

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(_pretty_message(
                '''
                Certificate must be self-signed, or an issuer must be specified
                '''
            ))

        if self._self_signed:
            self._issuer = self._subject

        if self.serial_number is None:
            # Serial = current Unix time followed by 24 random bits. The random
            # part must contain at least 20 randomly generated BITS, so it is
            # drawn from the OS CSPRNG rather than the non-cryptographic
            # `random` module.
            time_part = int_to_bytes(int(time.time()))
            random_part = os.urandom(3)
            self.serial_number = int_from_bytes(time_part + random_part)

        # Mandatory fields are always present in this dict
        properties = {
            'version': Integer(value=self._version),
            'serialNumber': OctetString(value=self.serial_number.to_bytes(20, byteorder='big')),
            'subject': self.subject,
        }

        # Optional fields as (field name, value, wrapper). A field is only
        # added when its value is not None; when a wrapper is given the value
        # is stored as wrapper(value=...), otherwise it is stored as-is.
        optional_fields = (
            ('cAAlgorithm', self.ca_algorithm, None),
            ('cAAlgParams', self.ca_algorithm_parameters, OctetString),
            ('issuer', self.issuer, None),
            ('validFrom', self.valid_from, None),
            ('validDuration', self.valid_duration, None),
            ('pKAlgorithm', self.pk_algorithm, None),
            ('pKAlgParams', self.pk_algorithm_parameters, OctetString),
            ('pubKey', self.public_key, OctetString),
            ('authKeyId', self.authkey_id, None),
            ('subjKeyId', self.subject_key_id, OctetString),
            ('keyUsage', self.key_usage, OctetString),
            ('basicConstraints', self.basic_constraints, Integer),
            ('certificatePolicy', self.certificate_policy, None),
            ('subjectAltName', self.subject_alternative_name, None),
            ('issuerAltName', self.issuer_alternative_name, None),
            ('extendedKeyUsage', self.extended_key_usage, None),
            ('authInfoAccessOCSP', self.auth_info_access_ocsp, None),
            ('cRLDistribPointURI', self.crl_distribution_point_uri, None),
            ('x509extensions', self.x509_extensions, None),
        )
        for field_name, value, wrapper in optional_fields:
            if value is None:
                continue
            properties[field_name] = value if wrapper is None else wrapper(value=value)

        tbs_cert = TBSCertificate(properties)

        bytes_to_sign = tbs_cert.dump()
        signature = generate_signature(bytes_to_sign, signing_private_key_path)

        if debug:
            print("Build  - Signed_bytes ({len}): {content}".format(len=len(bytes_to_sign), content=hexlify(bytes_to_sign)))
            print("Build  - Signature ({len}): {content}".format(len=len(signature), content=hexlify(signature)))

        return Certificate({
            'tbsCertificate': tbs_cert,
            'cACalcValue': signature
        })