Example #1
0
 def test_int_fromto_bytes(self):
     """
     Checks that util.int_to_bytes() followed by util.int_from_bytes()
     gives back the starting integer, for signed values -300..300 and for
     unsigned values 0..300.
     """
     round_trips = (
         (True, range(-300, 301)),
         (False, range(0, 301)),
     )
     for signed, numbers in round_trips:
         for number in numbers:
             encoded = util.int_to_bytes(number, signed)
             decoded = util.int_from_bytes(encoded, signed)
             self.assertEqual(number, decoded)
Example #2
0
    def serial_number(self, value):
        """
        Validates and stores the certificate serial number.

        :param value:
            An int representable in 160 bits or less - must uniquely identify
            this certificate when combined with the issuer name.

        :raises:
            TypeError - when value is not an integer
            ValueError - when value is negative, or when its byte encoding is
            longer than 20 bytes
        """

        if not isinstance(value, int_types):
            wrong_type = _pretty_message(
                '''
                serial_number must be an integer, not %s
                ''',
                _type_name(value)
            )
            raise TypeError(wrong_type)

        if value < 0:
            negative = _pretty_message(
                '''
                serial_number must be a non-negative integer, not %s
                ''',
                repr(value)
            )
            raise ValueError(negative)

        # 160 bits is 20 bytes - measure the encoded length once and reuse it
        # for both the check and the error message
        encoded_length = len(int_to_bytes(value))
        if encoded_length > 20:
            too_large = _pretty_message(
                '''
                serial_number must be an integer that can be represented by a
                160-bit number, specified requires %s
                ''',
                encoded_length * 8
            )
            raise ValueError(too_large)

        self._serial_number = value
Example #3
0
    def serial_number(self, value):
        """
        Sets the serial number of the certificate after checking it.

        :param value:
            An int representable in 160 bits or less - must uniquely identify
            this certificate when combined with the issuer name.

        :raises:
            TypeError - when value is not an integer
            ValueError - when value is below zero, or when it takes more than
            20 bytes to encode
        """

        is_integer = isinstance(value, int_types)
        if not is_integer:
            message = _pretty_message(
                '''
                serial_number must be an integer, not %s
                ''',
                _type_name(value)
            )
            raise TypeError(message)

        is_negative = value < 0
        if is_negative:
            message = _pretty_message(
                '''
                serial_number must be a non-negative integer, not %s
                ''',
                repr(value)
            )
            raise ValueError(message)

        # A 160-bit number fits in 20 bytes
        size_in_bytes = len(int_to_bytes(value))
        if size_in_bytes <= 20:
            self._serial_number = value
            return

        message = _pretty_message(
            '''
                serial_number must be an integer that can be represented by a
                160-bit number, specified requires %s
                ''',
            size_in_bytes * 8
        )
        raise ValueError(message)
Example #4
0
def raw_rsa_public_crypt(certificate_or_public_key, data):
    """
    Applies the raw RSA public-key operation (modular exponentiation with the
    public exponent) to a byte string. This is a low-level primitive and is
    prone to disastrous results if used incorrectly.

    :param certificate_or_public_key:
        An oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
        object

    :param data:
        A byte string of the signature when verifying, or padded plaintext when
        encrypting. Must be less than or equal to the length of the public key.
        When verifying, padding will need to be removed afterwards. When
        encrypting, padding must be applied before.

    :raises:
        SystemError - when the backend in use is not "winlegacy"
        TypeError - when either parameter is of the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    if hasattr(certificate_or_public_key, 'asn1'):
        is_supported = isinstance(
            certificate_or_public_key.asn1,
            (PublicKeyInfo, Certificate)
        )
    else:
        is_supported = False

    if not is_supported:
        raise TypeError(pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        ))

    key_asn1 = certificate_or_public_key.asn1

    key_algorithm = key_asn1['algorithm']['algorithm'].native
    if key_algorithm != 'rsa':
        raise ValueError(pretty_message(
            '''
            certificate_or_public_key must be an RSA key, not %s
            ''',
            key_algorithm.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    public_numbers = key_asn1['public_key'].parsed
    input_int = int_from_bytes(data)
    exponent = public_numbers['public_exponent'].native
    modulus = public_numbers['modulus'].native

    # output = input ^ e mod n, encoded at the byte size reported by the key
    output_int = pow(input_int, exponent, modulus)
    return int_to_bytes(output_int, width=key_asn1.byte_size)
Example #5
0
def raw_rsa_private_crypt(private_key, data):
    """
    Applies the raw RSA private-key operation (modular exponentiation with
    the private exponent) to a byte string. This is a low-level primitive and
    is prone to disastrous results if used incorrectly.

    :param private_key:
        An oscrypto.asymmetric.PrivateKey object

    :param data:
        A byte string of the plaintext to be signed or ciphertext to be
        decrypted. Must be less than or equal to the length of the private key.
        In the case of signing, padding must already be applied. In the case of
        decryption, padding must be removed afterward.

    :raises:
        SystemError - when the backend in use is not "winlegacy"
        TypeError - when either parameter is of the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    wraps_private_key = (
        hasattr(private_key, 'asn1')
        and isinstance(private_key.asn1, PrivateKeyInfo)
    )
    if not wraps_private_key:
        bad_key = pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        )
        raise TypeError(bad_key)

    key_asn1 = private_key.asn1

    key_algorithm = key_asn1['private_key_algorithm']['algorithm'].native
    if key_algorithm != 'rsa':
        not_rsa = pretty_message(
            '''
            private_key must be an RSA key, not %s
            ''',
            key_algorithm.upper()
        )
        raise ValueError(not_rsa)

    if not isinstance(data, byte_cls):
        bad_data = pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        )
        raise TypeError(bad_data)

    private_numbers = key_asn1['private_key'].parsed
    input_int = int_from_bytes(data)
    exponent = private_numbers['private_exponent'].native
    modulus = private_numbers['modulus'].native

    # output = input ^ d mod n, encoded at the byte size reported by the key
    output_int = pow(input_int, exponent, modulus)
    return int_to_bytes(output_int, width=key_asn1.byte_size)
Example #6
0
def raw_rsa_private_crypt(private_key, data):
    """
    Runs the textbook RSA private-key transform on a byte string. This is a
    low-level primitive and is prone to disastrous results if used
    incorrectly.

    :param private_key:
        An oscrypto.asymmetric.PrivateKey object

    :param data:
        A byte string of the plaintext to be signed or ciphertext to be
        decrypted. Must be less than or equal to the length of the private key.
        In the case of signing, padding must already be applied. In the case of
        decryption, padding must be removed afterward.

    :raises:
        SystemError - when the backend in use is not "winlegacy"
        TypeError - when either parameter is of the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    if hasattr(private_key, 'asn1'):
        key_is_valid = isinstance(private_key.asn1, PrivateKeyInfo)
    else:
        key_is_valid = False

    if not key_is_valid:
        raise TypeError(pretty_message(
            '''
            private_key must be an instance of the
            oscrypto.asymmetric.PrivateKey class, not %s
            ''',
            type_name(private_key)
        ))

    info = private_key.asn1

    algorithm_name = info['private_key_algorithm']['algorithm'].native
    if algorithm_name != 'rsa':
        raise ValueError(pretty_message(
            '''
            private_key must be an RSA key, not %s
            ''',
            algorithm_name.upper()
        ))

    if not isinstance(data, byte_cls):
        raise TypeError(pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        ))

    parsed_key = info['private_key'].parsed
    base = int_from_bytes(data)
    private_exponent = parsed_key['private_exponent'].native
    modulus = parsed_key['modulus'].native

    # Modular exponentiation, then encode at the key's reported byte size
    return int_to_bytes(
        pow(base, private_exponent, modulus),
        width=info.byte_size
    )
Example #7
0
def add_pss_padding(hash_algorithm, salt_length, key_length, message):
    """
    Pads a byte string using the EMSA-PSS-Encode operation described in PKCS#1
    v2.2.

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :raises:
        SystemError - when the backend is not "winlegacy" and the platform is
        not "darwin"
        TypeError - when message is not a byte string, or salt_length or
        key_length is not an integer
        ValueError - when salt_length is negative, key_length is below 512,
        hash_algorithm is not supported, or the key is too short for the
        chosen hash and salt

    :return:
        The encoded (padded) message
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(pretty_message(
            '''
            Pure-python RSA PSS signature padding addition code is only for
            Windows XP/2003 and OS X
            '''
        ))

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 512:
        raise ValueError(pretty_message(
            '''
            key_length must be 512 or more - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # The maximal bit size of a non-negative integer is one less than the bit
    # size of the key since the first bit is used to store sign
    em_bits = key_length - 1
    # Integer ceiling division - exact for any integer and independent of
    # whether "/" performs true or floor division in this module
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        raise ValueError(pretty_message(
            '''
            Key is not long enough to use with specified hash_algorithm and
            salt_length
            '''
        ))

    if salt_length > 0:
        salt = os.urandom(salt_length)
    else:
        salt = b''

    # M' = eight zero bytes || mHash || salt
    m_prime = (b'\x00' * 8) + message_digest + salt

    m_prime_digest = hash_func(m_prime).digest()

    padding = b'\x00' * (em_len - salt_length - hash_length - 2)

    # DB = PS || 0x01 || salt
    db = padding + b'\x01' + salt

    db_mask = _mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    # XOR via integers drops leading zero bytes, so the result is widened
    # back to the length of the mask
    masked_db = int_to_bytes(int_from_bytes(db) ^ int_from_bytes(db_mask))
    masked_db = fill_width(masked_db, len(db_mask))

    # The leftmost (8 * em_len - em_bits) bits of the first byte must be zero
    zero_bits = (8 * em_len) - em_bits
    left_int_mask = 0xFF >> zero_bits

    if left_int_mask != 255:
        masked_db = chr_cls(left_int_mask & ord(masked_db[0:1])) + masked_db[1:]

    return masked_db + m_prime_digest + b'\xBC'
Example #8
0
def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use an input to the KDF. It is
        decoded as UTF-8 and re-encoded as null-terminated UTF-16BE.

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when password or salt is not a byte string, or iterations
        or key_length is not an integer
        ValueError - when iterations or key_length is less than 1, or
        hash_algorithm or id_ is not one of the supported values

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    if id_ not in set([1, 2, 3]):
        raise ValueError(pretty_message(
            '''
            id_ must be one of 1, 2, 3, not %s
            ''',
            repr(id_)
        ))

    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC)
    u = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    if hash_algorithm in ['sha384', 'sha512']:
        v = 128
    else:
        v = 64

    # Step 1
    d = chr_cls(id_) * v

    # Step 2
    s = b''
    if salt != b'':
        s_len = v * int(math.ceil(float(len(salt)) / v))
        while len(s) < s_len:
            s += salt
        s = s[0:s_len]

    # Step 3
    p = b''
    if utf16_password != b'':
        p_len = v * int(math.ceil(float(len(utf16_password)) / v))
        while len(p) < p_len:
            p += utf16_password
        p = p[0:p_len]

    # Step 4
    i = s + p

    # Step 5
    c = int(math.ceil(float(key_length) / u))

    # Each round contributes one u-byte digest to the output
    pieces = []

    for num in range(1, c + 1):
        # Step 6A
        a2 = algo(d + i).digest()
        for _ in range(2, iterations + 1):
            a2 = algo(a2).digest()

        if num < c:
            # Step 6B
            b = b''
            while len(b) < v:
                b += a2

            b = int_from_bytes(b[0:v]) + 1

            # Step 6C
            for num2 in range(0, len(i) // v):
                start = num2 * v
                end = (num2 + 1) * v
                i_num2 = i[start:end]

                i_num2 = int_to_bytes(int_from_bytes(i_num2) + b)

                # Every block must stay exactly v bytes long (the sum is taken
                # modulo 2^(8v)): drop overflow bytes from the front, and
                # restore leading zero bytes lost by the integer round-trip so
                # that i does not shrink and misalign later blocks
                i_num2_l = len(i_num2)
                if i_num2_l > v:
                    i_num2 = i_num2[i_num2_l - v:]
                elif i_num2_l < v:
                    i_num2 = (b'\x00' * (v - i_num2_l)) + i_num2

                i = i[0:start] + i_num2 + i[end:]

        # Step 7 (one piece at a time)
        pieces.append(a2)

    # Step 8 - the concatenated digests, truncated to the requested length
    return b''.join(pieces)[0:key_length]
Example #9
0
    def test_int_to_bytes(self):
        """
        Checks util.int_to_bytes() against a table of (arguments, expected
        bytes) pairs covering unsigned and signed values with and without an
        explicit width, then checks that values too large for the requested
        width raise OverflowError.
        """
        expectations = [
            ((0, False, 0), b''),
            ((0, False), b'\x00'),
            ((0, False, 3), b'\x00\x00\x00'),
            ((0, True, 0), b''),
            ((0, True), b'\x00'),
            ((0, True, 3), b'\x00\x00\x00'),

            ((128, False), b'\x80'),
            ((128, False, 3), b'\x00\x00\x80'),
            ((-128, True), b'\x80'),
            ((-128, True, 3), b'\xff\xff\x80'),

            ((255, False), b'\xff'),
            ((255, False, 3), b'\x00\x00\xff'),
            ((-1, True), b'\xff'),
            ((-1, True, 3), b'\xff\xff\xff'),

            ((12345678, False), b'\xbc\x61\x4e'),
            ((12345678, False, 3), b'\xbc\x61\x4e'),
            ((12345678, False, 5), b'\x00\x00\xbc\x61\x4e'),
            ((12345678 - 2**24, True), b'\xbc\x61\x4e'),
            ((12345678 - 2**24, True, 3), b'\xbc\x61\x4e'),
            ((12345678 - 2**24, True, 5), b'\xff\xff\xbc\x61\x4e'),
        ]
        for arguments, expected in expectations:
            self.assertEqual(util.int_to_bytes(*arguments), expected)

        overflowing_calls = [
            ((123456789,), {'width': 3}),
            ((50000,), {'signed': True, 'width': 2}),
        ]
        for positional, keywords in overflowing_calls:
            with self.assertRaises(OverflowError):
                util.int_to_bytes(*positional, **keywords)
Example #10
0
def add_pss_padding(hash_algorithm, salt_length, key_length, message):
    """
    Pads a byte string using the EMSA-PSS-Encode operation described in PKCS#1
    v2.2.

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :raises:
        TypeError - when message is not a byte string, or salt_length or
        key_length is not an integer
        ValueError - when salt_length is negative, key_length is below 512,
        hash_algorithm is not supported, or the key is too short for the
        chosen hash and salt

    :return:
        The encoded (padded) message
    """

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 512:
        raise ValueError(pretty_message(
            '''
            key_length must be 512 or more - is %s
            ''',
            repr(key_length)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    # The maximal bit size of a non-negative integer is one less than the bit
    # size of the key since the first bit is used to store sign
    em_bits = key_length - 1
    # Integer ceiling division - exact for any integer and independent of
    # whether "/" performs true or floor division in this module
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        raise ValueError(pretty_message(
            '''
            Key is not long enough to use with specified hash_algorithm and
            salt_length
            '''
        ))

    if salt_length > 0:
        salt = os.urandom(salt_length)
    else:
        salt = b''

    # M' = eight zero bytes || mHash || salt
    m_prime = (b'\x00' * 8) + message_digest + salt

    m_prime_digest = hash_func(m_prime).digest()

    padding = b'\x00' * (em_len - salt_length - hash_length - 2)

    # DB = PS || 0x01 || salt
    db = padding + b'\x01' + salt

    db_mask = mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    # XOR via integers drops leading zero bytes, so the result is widened
    # back to the length of the mask
    masked_db = int_to_bytes(int_from_bytes(db) ^ int_from_bytes(db_mask))
    masked_db = fill_width(masked_db, len(db_mask))

    # The leftmost (8 * em_len - em_bits) bits of the first byte must be zero
    zero_bits = (8 * em_len) - em_bits
    left_int_mask = 0xFF >> zero_bits

    if left_int_mask != 255:
        masked_db = chr_cls(left_int_mask & ord(masked_db[0:1])) + masked_db[1:]

    return masked_db + m_prime_digest + b'\xBC'
Example #11
0
def verify_pss_padding(hash_algorithm, salt_length, key_length, message, signature):
    """
    Verifies the PSS padding on an encoded message

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message to pad

    :param signature:
        The signature to verify

    :raises:
        TypeError - when message or signature is not a byte string, or
        salt_length is not an integer
        ValueError - when salt_length is negative or hash_algorithm is not
        supported

    :return:
        A boolean - False when any structural check on the padding fails,
        otherwise the result of comparing the embedded digest with the
        recomputed one
    """

    if not isinstance(message, byte_cls):
        raise TypeError(pretty_message(
            '''
            message must be a byte string, not %s
            ''',
            type_name(message)
        ))

    if not isinstance(signature, byte_cls):
        raise TypeError(pretty_message(
            '''
            signature must be a byte string, not %s
            ''',
            type_name(signature)
        ))

    if not isinstance(salt_length, int_types):
        raise TypeError(pretty_message(
            '''
            salt_length must be an integer, not %s
            ''',
            type_name(salt_length)
        ))

    if salt_length < 0:
        raise ValueError(pretty_message(
            '''
            salt_length must be 0 or more - is %s
            ''',
            repr(salt_length)
        ))

    if hash_algorithm not in set(['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    hash_func = getattr(hashlib, hash_algorithm)

    em_bits = key_length - 1
    # Integer ceiling division - exact for any integer and independent of
    # whether "/" performs true or floor division in this module
    em_len = (em_bits + 7) // 8

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    if em_len < hash_length + salt_length + 2:
        return False

    if signature[-1:] != b'\xBC':
        return False

    zero_bits = (8 * em_len) - em_bits

    masked_db_length = em_len - hash_length - 1
    masked_db = signature[0:masked_db_length]

    # The leftmost zero_bits bits of the first byte must not be set
    first_byte = ord(masked_db[0:1])
    bits_that_should_be_zero = first_byte >> (8 - zero_bits)
    if bits_that_should_be_zero != 0:
        return False

    m_prime_digest = signature[masked_db_length:masked_db_length + hash_length]

    db_mask = mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    # Clear the same leftmost bits in the mask before unmasking
    left_int_mask = 0xFF >> zero_bits

    if left_int_mask != 255:
        db_mask = chr_cls(left_int_mask & ord(db_mask[0:1])) + db_mask[1:]

    # XOR via integers drops leading zero bytes, so pad back to full width
    db = int_to_bytes(int_from_bytes(masked_db) ^ int_from_bytes(db_mask))
    if len(db) < len(masked_db):
        db = (b'\x00' * (len(masked_db) - len(db))) + db

    zero_length = em_len - hash_length - salt_length - 2
    zero_string = b'\x00' * zero_length
    if not constant_compare(db[0:zero_length], zero_string):
        return False

    if db[zero_length:zero_length + 1] != b'\x01':
        return False

    # The salt is the last salt_length bytes of DB. Slicing from an explicit
    # start index matters when salt_length is 0: db[-0:] would be all of DB
    # rather than an empty salt.
    salt = db[len(db) - salt_length:]

    m_prime = (b'\x00' * 8) + message_digest + salt

    h_prime = hash_func(m_prime).digest()

    return constant_compare(m_prime_digest, h_prime)
Example #12
0
def raw_rsa_public_crypt(certificate_or_public_key, data):
    """
    Runs the textbook RSA public-key transform on a byte string, using a
    certificate or public key. This is a low-level primitive and is prone to
    disastrous results if used incorrectly.

    :param certificate_or_public_key:
        An oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
        object

    :param data:
        A byte string of the signature when verifying, or padded plaintext when
        encrypting. Must be less than or equal to the length of the public key.
        When verifying, padding will need to be removed afterwards. When
        encrypting, padding must be applied before.

    :raises:
        SystemError - when the backend in use is not "winlegacy"
        TypeError - when either parameter is of the wrong type
        ValueError - when the key is not an RSA key

    :return:
        A byte string of the transformed data
    """

    if _backend != 'winlegacy':
        raise SystemError('Pure-python RSA crypt is only for Windows XP/2003')

    accepted_asn1_types = (PublicKeyInfo, Certificate)
    wraps_public_key = (
        hasattr(certificate_or_public_key, 'asn1')
        and isinstance(certificate_or_public_key.asn1, accepted_asn1_types)
    )
    if not wraps_public_key:
        bad_key = pretty_message(
            '''
            certificate_or_public_key must be an instance of the
            oscrypto.asymmetric.PublicKey or oscrypto.asymmetric.Certificate
            classes, not %s
            ''',
            type_name(certificate_or_public_key)
        )
        raise TypeError(bad_key)

    info = certificate_or_public_key.asn1

    algorithm_name = info['algorithm']['algorithm'].native
    if algorithm_name != 'rsa':
        not_rsa = pretty_message(
            '''
            certificate_or_public_key must be an RSA key, not %s
            ''',
            algorithm_name.upper()
        )
        raise ValueError(not_rsa)

    if not isinstance(data, byte_cls):
        bad_data = pretty_message(
            '''
            data must be a byte string, not %s
            ''',
            type_name(data)
        )
        raise TypeError(bad_data)

    parsed_key = info['public_key'].parsed
    base = int_from_bytes(data)
    public_exponent = parsed_key['public_exponent'].native
    modulus = parsed_key['modulus'].native

    # Modular exponentiation, then encode at the key's reported byte size
    result = pow(base, public_exponent, modulus)
    return int_to_bytes(result, width=info.byte_size)
Example #13
0
    def build_mpc(self, signing_private_key, orq_ip, orq_port):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it with an MPC engine reached through an HTTP
        orchestrator

        :param signing_private_key:
            An integer identifier for the private key to sign the certificate with.
            This identifier permits the mpc engine to find the private key associated
            with the public key exported in the key generation process

        :param orq_ip:
            The host name or IP address of the orchestrator that receives the
            signing request

        :param orq_port:
            The port of the orchestrator - it is converted with str() before
            being placed in the URL

        :raises:
            ValueError - when the certificate is neither self-signed nor has
            an issuer, or when a CA-only extension is set on a non-CA
            certificate

        :return:
            An asn1crypto.x509.Certificate object of the newly signed
            certificate
        """
        def rsa_mpc_sign(signing_private_key, tbs_cert_dump, hash_algo, orq_ip,
                         orq_port):
            # NOTE(review): hash_algo is accepted but not used - the digest is
            # always SHA-256, while the certificate's declared signature
            # algorithm is derived from self._hash_algo. Confirm callers never
            # configure a different hash.

            # Calculate the SHA-256 hash of tbs_cert_dump
            digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
            digest.update(tbs_cert_dump)
            # hexlify() returns a byte string; decode it so it can be joined
            # with text (concatenating str and bytes raises TypeError on
            # Python 3, and str() of bytes would embed "b'...'" in the URL)
            hex_digest = hexlify(digest.finalize()).decode('ascii')
            print("[Client] Hash: " + hex_digest)

            # Send HTTP GET request to the orchestrator for signing
            target = "http://" + str(orq_ip) + ":" + str(orq_port) + \
                "/signMessage/" + str(signing_private_key) + \
                "?message=" + hex_digest
            r = requests.get(target)

            response = dict(json.loads(r.text))

            # The signature arrives hex-encoded and must be returned as bytes
            print("Firma: " + str(response["sign"]))
            firma = unhexlify(response["sign"])
            return firma

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(
                _pretty_message('''
                Certificate must be self-signed, or an issuer must be specified
                '''))

        if self._self_signed:
            self._issuer = self._subject

        if self._serial_number is None:
            # Default serial number: current timestamp bytes followed by four
            # random bytes
            time_part = int_to_bytes(int(time.time()))
            random_part = util.rand_bytes(4)
            self._serial_number = int_from_bytes(time_part + random_part)

        if self._begin_date is None:
            self._begin_date = datetime.now(timezone.utc)

        if self._end_date is None:
            # Default validity period of 365 days
            self._end_date = self._begin_date + timedelta(365)

        if not self.ca:
            for ca_only_extension in set([
                    'policy_mappings', 'policy_constraints',
                    'inhibit_any_policy'
            ]):
                if ca_only_extension in self._other_extensions:
                    raise ValueError(
                        _pretty_message(
                            '''
                        Extension %s is only valid for CA certificates
                        ''', ca_only_extension))

        # The algorithm is forced to be 'rsa'
        signature_algo = 'rsa'

        # The hash algorithm is limited to SHA256 (internally)
        signature_algorithm_id = '%s_%s' % (self._hash_algo, signature_algo)

        def _make_extension(name, value):
            return {
                'extn_id': name,
                'critical': self._determine_critical(name),
                'extn_value': value
            }

        extensions = []
        for name in sorted(self._special_extensions):
            value = getattr(self, '_%s' % name)
            if name == 'ocsp_no_check':
                # Present as a Null value when enabled, omitted otherwise
                value = core.Null() if value else None
            if value is not None:
                extensions.append(_make_extension(name, value))

        for name in sorted(self._other_extensions.keys()):
            extensions.append(
                _make_extension(name, self._other_extensions[name]))

        tbs_cert = x509.TbsCertificate({
            'version': 'v3',
            'serial_number': self._serial_number,
            'signature': {
                'algorithm': signature_algorithm_id
            },
            'issuer': self._issuer,
            'validity': {
                'not_before': x509.Time(name='utc_time',
                                        value=self._begin_date),
                'not_after': x509.Time(name='utc_time', value=self._end_date),
            },
            'subject': self._subject,
            'subject_public_key_info': self._subject_public_key,
            'extensions': extensions
        })

        print(orq_ip)
        print(orq_port)
        signature = rsa_mpc_sign(signing_private_key, tbs_cert.dump(),
                                 self._hash_algo, orq_ip, orq_port)

        return x509.Certificate({
            'tbs_certificate': tbs_cert,
            'signature_algorithm': {
                'algorithm': signature_algorithm_id
            },
            'signature_value': signature
        })
Example #14
0
    def _handshake(self):
        """
        Perform an initial TLS handshake

        Creates a Secure Transport session context, configures it (I/O
        callbacks, connection id, peer domain name, protocol versions, enabled
        cipher suites and peer id) and calls SSLHandshake() until it stops
        reporting errSSLWouldBlock. When explicit validation is required, the
        peer trust object is evaluated here using an SSL policy plus OCSP and
        CRL policies. On success the session context, negotiated protocol,
        cipher suite, compression, session id and session ticket are stored
        on self.

        :raises:
            Whatever handle_sec_error(), handle_cf_error() or one of the
            raise_*() helpers raises (judging by their names, the helpers raise
            the matching TLS error). When an OSError or socket_.error
            propagates, the session context is released and self.close() is
            called before the exception is re-raised.
        """

        # Everything that the except/finally clauses may need to release is
        # initialized up front so that cleanup can test each name even when
        # the failure happens early
        session_context = None
        ssl_policy_ref = None
        crl_search_ref = None
        crl_policy_ref = None
        ocsp_search_ref = None
        ocsp_policy_ref = None
        policy_array_ref = None

        try:
            # Before OS X 10.8 the context is created through an out-pointer
            # with SSLNewContext(), newer versions use SSLCreateContext()
            if osx_version_info < (10, 8):
                session_context_pointer = new(Security, 'SSLContextRef *')
                result = Security.SSLNewContext(False, session_context_pointer)
                handle_sec_error(result)
                session_context = unwrap(session_context_pointer)

            else:
                session_context = Security.SSLCreateContext(
                    null(),
                    SecurityConst.kSSLClientSide,
                    SecurityConst.kSSLStreamType
                )

            result = Security.SSLSetIOFuncs(
                session_context,
                _read_callback_pointer,
                _write_callback_pointer
            )
            handle_sec_error(result)

            # An integer id (kept below 2147483647) is registered in the
            # module-level maps and handed to Secure Transport as the
            # connection. Presumably the read/write callbacks use it to find
            # this object and its socket - verify against the callbacks.
            self._connection_id = id(self) % 2147483647
            _connection_refs[self._connection_id] = self
            _socket_refs[self._connection_id] = self._socket
            result = Security.SSLSetConnection(session_context, self._connection_id)
            handle_sec_error(result)

            utf8_domain = self._hostname.encode('utf-8')
            result = Security.SSLSetPeerDomainName(
                session_context,
                utf8_domain,
                len(utf8_domain)
            )
            handle_sec_error(result)

            # disable_auto_validation controls whether the built-in
            # certificate verification is switched off/interrupted below,
            # explicit_validation controls whether this method evaluates the
            # peer trust itself after the handshake
            if osx_version_info >= (10, 10):
                disable_auto_validation = self._session._manual_validation or self._session._extra_trust_roots
                explicit_validation = (not self._session._manual_validation) and self._session._extra_trust_roots
            else:
                disable_auto_validation = True
                explicit_validation = not self._session._manual_validation

            # Ensure requested protocol support is set for the session
            if osx_version_info < (10, 8):
                # Each protocol is individually enabled or disabled
                for protocol in ['SSLv2', 'SSLv3', 'TLSv1']:
                    protocol_const = _PROTOCOL_STRING_CONST_MAP[protocol]
                    enabled = protocol in self._session._protocols
                    result = Security.SSLSetProtocolVersionEnabled(
                        session_context,
                        protocol_const,
                        enabled
                    )
                    handle_sec_error(result)

                if disable_auto_validation:
                    result = Security.SSLSetEnableCertVerify(session_context, False)
                    handle_sec_error(result)

            else:
                # A min/max range is set from the constants of the requested
                # protocols
                protocol_consts = [_PROTOCOL_STRING_CONST_MAP[protocol] for protocol in self._session._protocols]
                min_protocol = min(protocol_consts)
                max_protocol = max(protocol_consts)
                result = Security.SSLSetProtocolVersionMin(
                    session_context,
                    min_protocol
                )
                handle_sec_error(result)
                result = Security.SSLSetProtocolVersionMax(
                    session_context,
                    max_protocol
                )
                handle_sec_error(result)

                if disable_auto_validation:
                    result = Security.SSLSetSessionOption(
                        session_context,
                        SecurityConst.kSSLSessionOptionBreakOnServerAuth,
                        True
                    )
                    handle_sec_error(result)

            # Disable all sorts of bad cipher suites
            # The number of supported ciphers is fetched first so that a
            # buffer of 4 bytes per entry (uint32_t) can be passed to
            # SSLGetSupportedCiphers()
            supported_ciphers_pointer = new(Security, 'size_t *')
            result = Security.SSLGetNumberSupportedCiphers(session_context, supported_ciphers_pointer)
            handle_sec_error(result)

            supported_ciphers = deref(supported_ciphers_pointer)

            cipher_buffer = buffer_from_bytes(supported_ciphers * 4)
            supported_cipher_suites_pointer = cast(Security, 'uint32_t *', cipher_buffer)
            result = Security.SSLGetSupportedCiphers(
                session_context,
                supported_cipher_suites_pointer,
                supported_ciphers_pointer
            )
            handle_sec_error(result)

            supported_ciphers = deref(supported_ciphers_pointer)
            supported_cipher_suites = array_from_pointer(
                Security,
                'uint32_t',
                supported_cipher_suites_pointer,
                supported_ciphers
            )
            # A cipher suite is kept only when its name (or its raw 2-byte id
            # when CIPHER_SUITE_MAP has no entry) does not match the blacklist
            # regex
            good_ciphers = []
            for supported_cipher_suite in supported_cipher_suites:
                cipher_suite = int_to_bytes(supported_cipher_suite, width=2)
                cipher_suite_name = CIPHER_SUITE_MAP.get(cipher_suite, cipher_suite)
                good_cipher = _cipher_blacklist_regex.search(cipher_suite_name) is None
                if good_cipher:
                    good_ciphers.append(supported_cipher_suite)

            num_good_ciphers = len(good_ciphers)
            good_ciphers_array = new(Security, 'uint32_t[]', num_good_ciphers)
            array_set(good_ciphers_array, good_ciphers)
            good_ciphers_pointer = cast(Security, 'uint32_t *', good_ciphers_array)
            result = Security.SSLSetEnabledCiphers(
                session_context,
                good_ciphers_pointer,
                num_good_ciphers
            )
            handle_sec_error(result)

            # Set a peer id from the session to allow for session reuse, the hostname
            # is appended to prevent a bug on OS X 10.7 where it tries to reuse a
            # connection even if the hostnames are different.
            peer_id = self._session._peer_id + self._hostname.encode('utf-8')
            result = Security.SSLSetPeerID(session_context, peer_id, len(peer_id))
            handle_sec_error(result)

            # SSLHandshake() is called again for as long as it reports
            # errSSLWouldBlock
            handshake_result = Security.SSLHandshake(session_context)
            while handshake_result == SecurityConst.errSSLWouldBlock:
                handshake_result = Security.SSLHandshake(session_context)

            # On 10.7 explicit validation is triggered by a successful (0)
            # handshake result, on all other versions by
            # errSSLServerAuthCompleted
            if osx_version_info < (10, 8) and osx_version_info >= (10, 7):
                do_validation = explicit_validation and handshake_result == 0
            else:
                do_validation = explicit_validation and handshake_result == SecurityConst.errSSLServerAuthCompleted

            if do_validation:
                trust_ref_pointer = new(Security, 'SecTrustRef *')
                result = Security.SSLCopyPeerTrust(
                    session_context,
                    trust_ref_pointer
                )
                handle_sec_error(result)
                trust_ref = unwrap(trust_ref_pointer)

                # The hostname string is only needed to create the SSL policy
                # and is released right after
                cf_string_hostname = CFHelpers.cf_string_from_unicode(self._hostname)
                ssl_policy_ref = Security.SecPolicyCreateSSL(True, cf_string_hostname)
                result = CoreFoundation.CFRelease(cf_string_hostname)
                handle_cf_error(result)

                # Create a new policy for OCSP checking to disable it
                ocsp_oid_pointer = struct(Security, 'CSSM_OID')
                ocsp_oid = unwrap(ocsp_oid_pointer)
                ocsp_oid.Length = len(SecurityConst.APPLE_TP_REVOCATION_OCSP)
                ocsp_oid_buffer = buffer_from_bytes(SecurityConst.APPLE_TP_REVOCATION_OCSP)
                ocsp_oid.Data = cast(Security, 'char *', ocsp_oid_buffer)

                ocsp_search_ref_pointer = new(Security, 'SecPolicySearchRef *')
                result = Security.SecPolicySearchCreate(
                    SecurityConst.CSSM_CERT_X_509v3,
                    ocsp_oid_pointer,
                    null(),
                    ocsp_search_ref_pointer
                )
                handle_sec_error(result)
                ocsp_search_ref = unwrap(ocsp_search_ref_pointer)

                ocsp_policy_ref_pointer = new(Security, 'SecPolicyRef *')
                result = Security.SecPolicySearchCopyNext(ocsp_search_ref, ocsp_policy_ref_pointer)
                handle_sec_error(result)
                ocsp_policy_ref = unwrap(ocsp_policy_ref_pointer)

                # The OCSP options are serialized to bytes and attached to the
                # policy wrapped in a CSSM_DATA struct
                ocsp_struct_pointer = struct(Security, 'CSSM_APPLE_TP_OCSP_OPTIONS')
                ocsp_struct = unwrap(ocsp_struct_pointer)
                ocsp_struct.Version = SecurityConst.CSSM_APPLE_TP_OCSP_OPTS_VERSION
                ocsp_struct.Flags = (
                    SecurityConst.CSSM_TP_ACTION_OCSP_DISABLE_NET |
                    SecurityConst.CSSM_TP_ACTION_OCSP_CACHE_READ_DISABLE
                )
                ocsp_struct_bytes = struct_bytes(ocsp_struct_pointer)

                cssm_data_pointer = struct(Security, 'CSSM_DATA')
                cssm_data = unwrap(cssm_data_pointer)
                cssm_data.Length = len(ocsp_struct_bytes)
                ocsp_struct_buffer = buffer_from_bytes(ocsp_struct_bytes)
                cssm_data.Data = cast(Security, 'char *', ocsp_struct_buffer)

                result = Security.SecPolicySetValue(ocsp_policy_ref, cssm_data_pointer)
                handle_sec_error(result)

                # Create a new policy for CRL checking to disable it
                crl_oid_pointer = struct(Security, 'CSSM_OID')
                crl_oid = unwrap(crl_oid_pointer)
                crl_oid.Length = len(SecurityConst.APPLE_TP_REVOCATION_CRL)
                crl_oid_buffer = buffer_from_bytes(SecurityConst.APPLE_TP_REVOCATION_CRL)
                crl_oid.Data = cast(Security, 'char *', crl_oid_buffer)

                crl_search_ref_pointer = new(Security, 'SecPolicySearchRef *')
                result = Security.SecPolicySearchCreate(
                    SecurityConst.CSSM_CERT_X_509v3,
                    crl_oid_pointer,
                    null(),
                    crl_search_ref_pointer
                )
                handle_sec_error(result)
                crl_search_ref = unwrap(crl_search_ref_pointer)

                crl_policy_ref_pointer = new(Security, 'SecPolicyRef *')
                result = Security.SecPolicySearchCopyNext(crl_search_ref, crl_policy_ref_pointer)
                handle_sec_error(result)
                crl_policy_ref = unwrap(crl_policy_ref_pointer)

                # Same serialization as for the OCSP options, with the CRL
                # flags set to 0
                crl_struct_pointer = struct(Security, 'CSSM_APPLE_TP_CRL_OPTIONS')
                crl_struct = unwrap(crl_struct_pointer)
                crl_struct.Version = SecurityConst.CSSM_APPLE_TP_CRL_OPTS_VERSION
                crl_struct.CrlFlags = 0
                crl_struct_bytes = struct_bytes(crl_struct_pointer)

                cssm_data_pointer = struct(Security, 'CSSM_DATA')
                cssm_data = unwrap(cssm_data_pointer)
                cssm_data.Length = len(crl_struct_bytes)
                crl_struct_buffer = buffer_from_bytes(crl_struct_bytes)
                cssm_data.Data = cast(Security, 'char *', crl_struct_buffer)

                result = Security.SecPolicySetValue(crl_policy_ref, cssm_data_pointer)
                handle_sec_error(result)

                # All three policies are applied to the peer trust object
                policy_array_ref = CFHelpers.cf_array_from_list([
                    ssl_policy_ref,
                    crl_policy_ref,
                    ocsp_policy_ref
                ])

                result = Security.SecTrustSetPolicies(trust_ref, policy_array_ref)
                handle_sec_error(result)

                if self._session._extra_trust_roots:
                    # ca_certs holds the loaded certificate objects alongside
                    # the refs taken from them; presumably this keeps the refs
                    # valid while they are in use - TODO confirm
                    ca_cert_refs = []
                    ca_certs = []
                    for cert in self._session._extra_trust_roots:
                        ca_cert = load_certificate(cert)
                        ca_certs.append(ca_cert)
                        ca_cert_refs.append(ca_cert.sec_certificate_ref)

                    result = Security.SecTrustSetAnchorCertificatesOnly(trust_ref, False)
                    handle_sec_error(result)

                    # NOTE(review): array_ref is not released anywhere in this
                    # method - confirm whether that is intentional
                    array_ref = CFHelpers.cf_array_from_list(ca_cert_refs)
                    result = Security.SecTrustSetAnchorCertificates(trust_ref, array_ref)
                    handle_sec_error(result)

                result_pointer = new(Security, 'SecTrustResultType *')
                result = Security.SecTrustEvaluate(trust_ref, result_pointer)
                handle_sec_error(result)

                trust_result_code = deref(result_pointer)
                # NOTE: despite its name, this set holds the trust results
                # that are accepted - any other result marks the chain as
                # invalid, while an accepted one resumes the handshake
                invalid_chain_error_codes = set([
                    SecurityConst.kSecTrustResultProceed,
                    SecurityConst.kSecTrustResultUnspecified
                ])
                if trust_result_code not in invalid_chain_error_codes:
                    handshake_result = SecurityConst.errSSLXCertChainInvalid
                else:
                    handshake_result = Security.SSLHandshake(session_context)
                    while handshake_result == SecurityConst.errSSLWouldBlock:
                        handshake_result = Security.SSLHandshake(session_context)

            # Set before the result is inspected, so it is True even when one
            # of the checks below ends up raising
            self._done_handshake = True

            handshake_error_codes = set([
                SecurityConst.errSSLXCertChainInvalid,
                SecurityConst.errSSLCertExpired,
                SecurityConst.errSSLCertNotYetValid,
                SecurityConst.errSSLUnknownRootCert,
                SecurityConst.errSSLNoRootCert,
                SecurityConst.errSSLHostNameMismatch
            ])

            # In testing, only errSSLXCertChainInvalid was ever returned for
            # all of these different situations, however we include the others
            # for completeness. To get the real reason we have to use the
            # certificate from the handshake and use the deprecated function
            # SecTrustGetCssmResultCode().
            if handshake_result in handshake_error_codes:
                trust_ref_pointer = new(Security, 'SecTrustRef *')
                result = Security.SSLCopyPeerTrust(
                    session_context,
                    trust_ref_pointer
                )
                handle_sec_error(result)
                trust_ref = unwrap(trust_ref_pointer)

                # NOTE(review): the status returned here is not passed to
                # handle_sec_error(), only the result code is read
                result_code_pointer = new(Security, 'OSStatus *')
                result = Security.SecTrustGetCssmResultCode(trust_ref, result_code_pointer)
                result_code = deref(result_code_pointer)

                chain = extract_chain(self._server_hello)

                self_signed = False
                expired = False
                not_yet_valid = False
                no_issuer = False
                cert = None
                bad_hostname = False

                # The first certificate of the chain extracted from the server
                # hello is the one the failure reason is reported for
                if chain:
                    cert = chain[0]
                    oscrypto_cert = load_certificate(cert)
                    self_signed = oscrypto_cert.self_signed
                    no_issuer = not self_signed and result_code == SecurityConst.CSSMERR_TP_NOT_TRUSTED
                    expired = result_code == SecurityConst.CSSMERR_TP_CERT_EXPIRED
                    not_yet_valid = result_code == SecurityConst.CSSMERR_TP_CERT_NOT_VALID_YET
                    bad_hostname = result_code == SecurityConst.CSSMERR_APPLETP_HOSTNAME_MISMATCH

                if chain and chain[0].hash_algo in set(['md5', 'md2']):
                    raise_weak_signature(chain[0])

                # The most specific reason is reported first
                if bad_hostname:
                    raise_hostname(cert, self._hostname)

                elif expired or not_yet_valid:
                    raise_expired_not_yet_valid(cert)

                elif no_issuer:
                    raise_no_issuer(cert)

                elif self_signed:
                    raise_self_signed(cert)

                if detect_client_auth_request(self._server_hello):
                    raise_client_auth()

                # Fallback when none of the specific reasons applied
                raise_verification(cert)

            if handshake_result == SecurityConst.errSSLPeerHandshakeFail:
                if detect_client_auth_request(self._server_hello):
                    raise_client_auth()
                raise_handshake()

            if handshake_result == SecurityConst.errSSLWeakPeerEphemeralDHKey:
                raise_dh_params()

            if handshake_result in set([SecurityConst.errSSLRecordOverflow, SecurityConst.errSSLProtocol]):
                self._server_hello += _read_remaining(self._socket)
                raise_protocol_error(self._server_hello)

            if handshake_result in set([SecurityConst.errSSLClosedNoNotify, SecurityConst.errSSLClosedAbort]):
                # NOTE(review): _done_handshake was set to True above, so this
                # condition looks like it can never hold here - confirm
                if not self._done_handshake:
                    self._server_hello += _read_remaining(self._socket)
                if detect_other_protocol(self._server_hello):
                    raise_protocol_error(self._server_hello)
                raise_disconnection()

            # Before 10.10 the DH parameter length is read from the server
            # hello and anything below 1024 is rejected here
            if osx_version_info < (10, 10):
                dh_params_length = get_dh_params_length(self._server_hello)
                if dh_params_length is not None and dh_params_length < 1024:
                    raise_dh_params()

            # Any remaining result is handed to handle_sec_error() with
            # TLSError as the exception class
            if handshake_result != SecurityConst.errSSLWouldBlock:
                handle_sec_error(handshake_result, TLSError)

            self._session_context = session_context

            protocol_const_pointer = new(Security, 'SSLProtocol *')
            result = Security.SSLGetNegotiatedProtocolVersion(
                session_context,
                protocol_const_pointer
            )
            handle_sec_error(result)
            protocol_const = deref(protocol_const_pointer)

            self._protocol = _PROTOCOL_CONST_STRING_MAP[protocol_const]

            cipher_int_pointer = new(Security, 'SSLCipherSuite *')
            result = Security.SSLGetNegotiatedCipher(
                session_context,
                cipher_int_pointer
            )
            handle_sec_error(result)
            cipher_int = deref(cipher_int_pointer)

            # Falls back to the raw 2-byte id when the cipher suite is not in
            # CIPHER_SUITE_MAP
            cipher_bytes = int_to_bytes(cipher_int, width=2)
            self._cipher_suite = CIPHER_SUITE_MAP.get(cipher_bytes, cipher_bytes)

            session_info = parse_session_info(
                self._server_hello,
                self._client_hello
            )
            self._compression = session_info['compression']
            self._session_id = session_info['session_id']
            self._session_ticket = session_info['session_ticket']

        except (OSError, socket_.error):
            # The context is released with the API matching the one that
            # created it, then the connection is closed and the original
            # exception re-raised
            if session_context:
                if osx_version_info < (10, 8):
                    result = Security.SSLDisposeContext(session_context)
                    handle_sec_error(result)
                else:
                    result = CoreFoundation.CFRelease(session_context)
                    handle_cf_error(result)

            self._session_context = None
            self.close()

            raise

        finally:
            # Trying to release crl_search_ref or ocsp_search_ref results in
            # a segmentation fault, so we do not do that

            # The policy refs and the policy array are released on both the
            # success and the failure path
            if ssl_policy_ref:
                result = CoreFoundation.CFRelease(ssl_policy_ref)
                handle_cf_error(result)
                ssl_policy_ref = None

            if crl_policy_ref:
                result = CoreFoundation.CFRelease(crl_policy_ref)
                handle_cf_error(result)
                crl_policy_ref = None

            if ocsp_policy_ref:
                result = CoreFoundation.CFRelease(ocsp_policy_ref)
                handle_cf_error(result)
                ocsp_policy_ref = None

            if policy_array_ref:
                result = CoreFoundation.CFRelease(policy_array_ref)
                handle_cf_error(result)
                policy_array_ref = None
Example #15
0
    def build(self, signing_private_key):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it

        :param signing_private_key:
            An asn1crypto.keys.PrivateKeyInfo or oscrypto.asymmetric.PrivateKey
            object for the private key to sign the certificate with. If the key
            is self-signed, this should be the private key that matches the
            public key, otherwise it needs to be the issuer's private key.
            None may be passed to skip signing - the unsigned structure is
            then returned, using an RSA signature algorithm identifier.

        :raises:
            TypeError - when signing_private_key is of an unsupported type
            ValueError - when the certificate is neither self-signed nor has
            an issuer, when a CA-only extension is set on a non-CA
            certificate, or when the algorithm of the key is not one of
            "rsa", "dsa" or "ec"

        :return:
            An asn1crypto.x509.Certificate object of the newly signed
            certificate, or an asn1crypto.x509.TbsCertificate object when
            signing_private_key is None
        """

        is_oscrypto = isinstance(signing_private_key, asymmetric.PrivateKey)
        is_key_info = isinstance(signing_private_key, keys.PrivateKeyInfo)
        if (signing_private_key is not None
                and not is_key_info and not is_oscrypto):
            raise TypeError(
                _pretty_message(
                    '''
                signing_private_key must be an instance of
                asn1crypto.keys.PrivateKeyInfo or
                oscrypto.asymmetric.PrivateKey, not %s
                ''', _type_name(signing_private_key)))

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(
                _pretty_message('''
                Certificate must be self-signed, or an issuer must be specified
                '''))

        if self._self_signed:
            self._issuer = self._subject

        # Default serial number: the current time followed by 4 random bytes
        if self._serial_number is None:
            time_part = int_to_bytes(int(time.time()))
            random_part = util.rand_bytes(4)
            self._serial_number = int_from_bytes(time_part + random_part)

        if self._begin_date is None:
            self._begin_date = datetime.now(timezone.utc)

        # Default validity period is 365 days from the begin date
        if self._end_date is None:
            self._end_date = self._begin_date + timedelta(365)

        if not self.ca:
            # A tuple keeps the order of the checks (and therefore the
            # extension named in the error) deterministic
            ca_only_extensions = (
                'policy_mappings',
                'policy_constraints',
                'inhibit_any_policy',
            )
            for ca_only_extension in ca_only_extensions:
                if ca_only_extension in self._other_extensions:
                    raise ValueError(
                        _pretty_message(
                            '''
                        Extension %s is only valid for CA certificates
                        ''', ca_only_extension))

        # The signing function is selected together with the algorithm name so
        # that an unsupported key algorithm is reported with a clear error
        # instead of failing later on an unbound sign_func
        sign_func = None
        if signing_private_key is None:
            # making rsa assumption for ease
            signature_algo = 'rsa'
        else:
            signature_algo = signing_private_key.algorithm
            if signature_algo == 'rsa':
                sign_func = asymmetric.rsa_pkcs1v15_sign
            elif signature_algo == 'dsa':
                sign_func = asymmetric.dsa_sign
            elif signature_algo == 'ec':
                sign_func = asymmetric.ecdsa_sign
                signature_algo = 'ecdsa'
            else:
                raise ValueError(
                    _pretty_message(
                        '''
                    signing_private_key must use one of the algorithms "rsa",
                    "dsa", "ec", not %s
                    ''', repr(signature_algo)))

        signature_algorithm_id = '%s_%s' % (self._hash_algo, signature_algo)

        # RFC 3280 4.1.2.5
        def _make_validity_time(dt):
            # Dates before 2050 are encoded as utc_time, later ones as
            # general_time
            if dt < datetime(2050, 1, 1, tzinfo=timezone.utc):
                return x509.Time(name='utc_time', value=dt)
            return x509.Time(name='general_time', value=dt)

        def _make_extension(name, value):
            return {
                'extn_id': name,
                'critical': self._determine_critical(name),
                'extn_value': value
            }

        # Special extensions are read from the matching "_name" attribute and
        # skipped when unset; ocsp_no_check is a flag encoded as a Null value
        extensions = []
        for name in sorted(self._special_extensions):
            value = getattr(self, '_%s' % name)
            if name == 'ocsp_no_check':
                value = core.Null() if value else None
            if value is not None:
                extensions.append(_make_extension(name, value))

        for name in sorted(self._other_extensions.keys()):
            extensions.append(
                _make_extension(name, self._other_extensions[name]))

        tbs_cert = x509.TbsCertificate({
            'version': 'v3',
            'serial_number': self._serial_number,
            'signature': {
                'algorithm': signature_algorithm_id
            },
            'issuer': self._issuer,
            'validity': {
                'not_before': _make_validity_time(self._begin_date),
                'not_after': _make_validity_time(self._end_date),
            },
            'subject': self._subject,
            'subject_public_key_info': self._subject_public_key,
            'extensions': extensions
        })

        # Without a key there is nothing to sign with, so the to-be-signed
        # structure is handed back as-is
        if signing_private_key is None:
            return tbs_cert

        if not is_oscrypto:
            signing_private_key = asymmetric.load_private_key(
                signing_private_key)
        signature = sign_func(signing_private_key, tbs_cert.dump(),
                              self._hash_algo)

        return x509.Certificate({
            'tbs_certificate': tbs_cert,
            'signature_algorithm': {
                'algorithm': signature_algorithm_id
            },
            'signature_value': signature
        })
Example #16
0
def pbkdf2(hash_algorithm, password, salt, iterations, key_length):
    """
    Implements PBKDF2 from PKCS#5 v2.2 in pure Python

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use an input to the KDF

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The numbers of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :raises:
        TypeError - when any of the parameters are of the wrong type
        ValueError - when any of the parameters contain an invalid value

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(pretty_message(
            '''
            password must be a byte string, not %s
            ''',
            type_name(password)
        ))

    if not isinstance(salt, byte_cls):
        raise TypeError(pretty_message(
            '''
            salt must be a byte string, not %s
            ''',
            type_name(salt)
        ))

    if not isinstance(iterations, int_types):
        raise TypeError(pretty_message(
            '''
            iterations must be an integer, not %s
            ''',
            type_name(iterations)
        ))

    if iterations < 1:
        raise ValueError(pretty_message(
            '''
            iterations must be greater than 0 - is %s
            ''',
            repr(iterations)
        ))

    if not isinstance(key_length, int_types):
        raise TypeError(pretty_message(
            '''
            key_length must be an integer, not %s
            ''',
            type_name(key_length)
        ))

    if key_length < 1:
        raise ValueError(pretty_message(
            '''
            key_length must be greater than 0 - is %s
            ''',
            repr(key_length)
        ))

    # Digest size in bytes of every supported hash algorithm
    hash_lengths = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }

    if hash_algorithm not in hash_lengths:
        raise ValueError(pretty_message(
            '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''',
            repr(hash_algorithm)
        ))

    algo = getattr(hashlib, hash_algorithm)
    hash_length = hash_lengths[hash_algorithm]

    # Integer ceiling division - gives the right number of blocks no matter
    # which division semantics are in effect for "/"
    blocks = -(-key_length // hash_length)

    # The keyed HMAC is created once and copied for every PRF invocation
    original_hmac = hmac.new(password, None, algo)

    int_pack = struct.Struct(b'>I').pack

    derived_blocks = []
    for block in range(1, blocks + 1):
        # U_1 = PRF(password, salt || INT(block))
        prf = original_hmac.copy()
        prf.update(salt + int_pack(block))
        last = prf.digest()
        u = int_from_bytes(last)
        # U_j = PRF(password, U_{j-1}), all of them XORed together
        for _ in range(2, iterations + 1):
            prf = original_hmac.copy()
            prf.update(last)
            last = prf.digest()
            u ^= int_from_bytes(last)
        # int_to_bytes() emits the shortest encoding, so leading zero bytes of
        # the XOR result would be dropped and every following byte of the key
        # shifted. Each block must be exactly hash_length bytes long.
        t = int_to_bytes(u).rjust(hash_length, b'\x00')
        derived_blocks.append(t)

    return b''.join(derived_blocks)[0:key_length]
Example #17
0
def pkcs12_kdf(hash_algorithm, password, salt, iterations, key_length, id_):
    """
    KDF from RFC7292 appendix b.2 - https://tools.ietf.org/html/rfc7292#page-19

    :param hash_algorithm:
        The string name of the hash algorithm to use: "md5", "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param password:
        A byte string of the password to use as input to the KDF. It is
        decoded as UTF-8 and re-encoded as UTF-16BE with a two-byte null
        terminator before being fed to the KDF.

    :param salt:
        A cryptographic random byte string

    :param iterations:
        The number of iterations to use when deriving the key

    :param key_length:
        The length of the desired key in bytes

    :param id_:
        The ID of the usage - 1 for key, 2 for iv, 3 for mac

    :raises:
        TypeError - when any of the parameters are of the wrong type
        ValueError - when any of the parameters contain an invalid value

    :return:
        The derived key as a byte string
    """

    if not isinstance(password, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            password must be a byte string, not %s
            ''', type_name(password)))

    if not isinstance(salt, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            salt must be a byte string, not %s
            ''', type_name(salt)))

    if not isinstance(iterations, int_types):
        raise TypeError(
            pretty_message(
                '''
            iterations must be an integer, not %s
            ''', type_name(iterations)))

    if iterations < 1:
        raise ValueError(
            pretty_message(
                '''
            iterations must be greater than 0 - is %s
            ''', repr(iterations)))

    if not isinstance(key_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            key_length must be an integer, not %s
            ''', type_name(key_length)))

    if key_length < 1:
        raise ValueError(
            pretty_message(
                '''
            key_length must be greater than 0 - is %s
            ''', repr(key_length)))

    if hash_algorithm not in set(
        ['md5', 'sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "md5", "sha1", "sha224", "sha256",
            "sha384", "sha512", not %s
            ''', repr(hash_algorithm)))

    if id_ not in set([1, 2, 3]):
        raise ValueError(
            pretty_message(
                '''
            id_ must be one of 1, 2, 3, not %s
            ''', repr(id_)))

    utf16_password = password.decode('utf-8').encode('utf-16be') + b'\x00\x00'

    algo = getattr(hashlib, hash_algorithm)

    # u and v values are bytes (not bits as in the RFC)
    u = {
        'md5': 16,
        'sha1': 20,
        'sha224': 28,
        'sha256': 32,
        'sha384': 48,
        'sha512': 64
    }[hash_algorithm]

    if hash_algorithm in ['sha384', 'sha512']:
        v = 128
    else:
        v = 64

    # Step 1
    d = chr_cls(id_) * v

    # Step 2
    s = b''
    if salt != b'':
        s_len = v * int(math.ceil(float(len(salt)) / v))
        while len(s) < s_len:
            s += salt
        s = s[0:s_len]

    # Step 3
    p = b''
    if utf16_password != b'':
        p_len = v * int(math.ceil(float(len(utf16_password)) / v))
        while len(p) < p_len:
            p += utf16_password
        p = p[0:p_len]

    # Step 4
    i = s + p

    # Step 5
    c = int(math.ceil(float(key_length) / u))

    # Every v-byte block of i is updated modulo 2^(8 * v) in step 6C
    block_modulus = 1 << (8 * v)

    pieces = []
    for num in range(1, c + 1):
        # Step 6A
        a2 = algo(d + i).digest()
        for _ in range(2, iterations + 1):
            a2 = algo(a2).digest()

        if num < c:
            # Step 6B
            b = b''
            while len(b) < v:
                b += a2

            b = int_from_bytes(b[0:v]) + 1

            # Step 6C
            for num2 in range(0, len(i) // v):
                start = num2 * v
                end = start + v

                block_sum = (int_from_bytes(i[start:end]) + b) % block_modulus
                i_num2 = int_to_bytes(block_sum)

                # int_to_bytes() drops leading zero bytes, so the result has
                # to be left-padded back to exactly v bytes - otherwise i
                # would shrink and all of the following blocks would shift
                if len(i_num2) < v:
                    i_num2 = (b'\x00' * (v - len(i_num2))) + i_num2

                i = i[0:start] + i_num2 + i[end:]

        # Step 7 (one piece at a time)
        pieces.append(a2)

    return b''.join(pieces)[0:key_length]
Example #18
0
    def build(self, signing_private_key):
        """
        Checks the configured certificate fields, assembles the ASN.1
        TbsCertificate structure and signs it with the supplied key

        :param signing_private_key:
            An asn1crypto.keys.PrivateKeyInfo or oscrypto.asymmetric.PrivateKey
            object for the private key to sign the certificate with. For a
            self-signed certificate this is the private key matching the
            public key, otherwise it is the issuer's private key.

        :raises:
            TypeError - when signing_private_key is not one of the accepted types
            ValueError - when no issuer is set on a non-self-signed
            certificate, or a CA-only extension is set on a non-CA certificate

        :return:
            An asn1crypto.x509.Certificate object of the newly signed
            certificate
        """

        key_is_oscrypto = isinstance(signing_private_key, asymmetric.PrivateKey)
        if not (isinstance(signing_private_key, keys.PrivateKeyInfo) or key_is_oscrypto):
            raise TypeError(_pretty_message(
                '''
                signing_private_key must be an instance of
                asn1crypto.keys.PrivateKeyInfo or
                oscrypto.asymmetric.PrivateKey, not %s
                ''',
                _type_name(signing_private_key)
            ))

        if self._self_signed is not True:
            if self._issuer is None:
                raise ValueError(_pretty_message(
                    '''
                Certificate must be self-signed, or an issuer must be specified
                '''
                ))

        if self._self_signed:
            self._issuer = self._subject

        # Default serial number: current unix timestamp followed by 4 random bytes
        if self._serial_number is None:
            timestamp_bytes = int_to_bytes(int(time.time()))
            nonce_bytes = util.rand_bytes(4)
            self._serial_number = int_from_bytes(timestamp_bytes + nonce_bytes)

        # Default validity: from now, for 365 days
        if self._begin_date is None:
            self._begin_date = datetime.now(timezone.utc)

        if self._end_date is None:
            self._end_date = self._begin_date + timedelta(365)

        if not self.ca:
            for ca_only_extension in set(['policy_mappings', 'policy_constraints', 'inhibit_any_policy']):
                if ca_only_extension not in self._other_extensions:
                    continue
                raise ValueError(_pretty_message(
                    '''
                        Extension %s is only valid for CA certificates
                        ''',
                    ca_only_extension
                ))

        key_algo = signing_private_key.algorithm
        signature_algorithm_id = '%s_%s' % (
            self._hash_algo,
            'ecdsa' if key_algo == 'ec' else key_algo
        )

        # RFC 3280 4.1.2.5 - dates before 2050 use UTCTime, later ones
        # use GeneralizedTime
        utc_time_limit = datetime(2050, 1, 1, tzinfo=timezone.utc)

        def _validity_time(moment):
            choice = 'utc_time' if moment < utc_time_limit else 'general_time'
            return x509.Time(name=choice, value=moment)

        def _extension(ext_name, ext_value):
            return {
                'extn_id': ext_name,
                'critical': self._determine_critical(ext_name),
                'extn_value': ext_value
            }

        extension_list = []
        for ext_name in sorted(self._special_extensions):
            ext_value = getattr(self, '_%s' % ext_name)
            if ext_name == 'ocsp_no_check':
                # The flag is encoded as an ASN.1 NULL when set
                ext_value = core.Null() if ext_value else None
            if ext_value is None:
                continue
            extension_list.append(_extension(ext_name, ext_value))

        for ext_name in sorted(self._other_extensions.keys()):
            extension_list.append(_extension(ext_name, self._other_extensions[ext_name]))

        validity = {
            'not_before': _validity_time(self._begin_date),
            'not_after': _validity_time(self._end_date),
        }

        tbs_cert = x509.TbsCertificate({
            'version': 'v3',
            'serial_number': self._serial_number,
            'signature': {
                'algorithm': signature_algorithm_id
            },
            'issuer': self._issuer,
            'validity': validity,
            'subject': self._subject,
            'subject_public_key_info': self._subject_public_key,
            'extensions': extension_list
        })

        if key_algo == 'rsa':
            sign_func = asymmetric.rsa_pkcs1v15_sign
        elif key_algo == 'dsa':
            sign_func = asymmetric.dsa_sign
        elif key_algo == 'ec':
            sign_func = asymmetric.ecdsa_sign

        # An asn1crypto key has to be loaded into an oscrypto key before signing
        if not key_is_oscrypto:
            signing_private_key = asymmetric.load_private_key(signing_private_key)
        signature = sign_func(signing_private_key, tbs_cert.dump(), self._hash_algo)

        return x509.Certificate({
            'tbs_certificate': tbs_cert,
            'signature_algorithm': {
                'algorithm': signature_algorithm_id
            },
            'signature_value': signature
        })
Example #19
0
    def _handshake(self):
        """
        Perform an initial TLS handshake

        Creates a Secure Transport session context, wires up the read/write
        callbacks, sets the peer hostname, restricts protocol versions and
        cipher suites according to the session, then runs SSLHandshake(). When
        the session has extra trust roots (and manual validation is off), the
        peer trust is evaluated explicitly against those roots. Handshake
        failure codes are translated via the raise_*() helpers. On success the
        negotiated protocol, cipher suite, compression, session id and session
        ticket are stored on self.

        If an OSError or socket error escapes, the session context is
        released, the connection is closed and the exception is re-raised.
        """

        session_context = None

        try:
            # OS X before 10.8 creates the context through an out-pointer with
            # SSLNewContext(); later versions return it from SSLCreateContext()
            if osx_version_info < (10, 8):
                session_context_pointer = new(Security, 'SSLContextRef *')
                result = Security.SSLNewContext(False, session_context_pointer)
                handle_sec_error(result)
                session_context = unwrap(session_context_pointer)

            else:
                session_context = Security.SSLCreateContext(
                    null(), SecurityConst.kSSLClientSide,
                    SecurityConst.kSSLStreamType)

            result = Security.SSLSetIOFuncs(session_context,
                                            _read_callback_pointer,
                                            _write_callback_pointer)
            handle_sec_error(result)

            # Register this object and its socket under an integer id, and hand
            # that id to Secure Transport as the connection reference -
            # presumably so the I/O callbacks can look them up; verify against
            # the callback implementations
            self._connection_id = id(self) % 2147483647
            _connection_refs[self._connection_id] = self
            _socket_refs[self._connection_id] = self._socket
            result = Security.SSLSetConnection(session_context,
                                               self._connection_id)
            handle_sec_error(result)

            utf8_domain = self._hostname.encode('utf-8')
            result = Security.SSLSetPeerDomainName(session_context,
                                                   utf8_domain,
                                                   len(utf8_domain))
            handle_sec_error(result)

            # Automatic validation is turned off both for manual validation and
            # when extra trust roots are configured; in the latter case (without
            # manual validation) the trust is evaluated explicitly further down
            disable_auto_validation = self._session._manual_validation or self._session._extra_trust_roots
            explicit_validation = (not self._session._manual_validation
                                   ) and self._session._extra_trust_roots

            # Ensure requested protocol support is set for the session
            if osx_version_info < (10, 8):
                # Pre-10.8: each protocol is toggled individually
                for protocol in ['SSLv2', 'SSLv3', 'TLSv1']:
                    protocol_const = _PROTOCOL_STRING_CONST_MAP[protocol]
                    enabled = protocol in self._session._protocols
                    result = Security.SSLSetProtocolVersionEnabled(
                        session_context, protocol_const, enabled)
                    handle_sec_error(result)

                if disable_auto_validation:
                    result = Security.SSLSetEnableCertVerify(
                        session_context, False)
                    handle_sec_error(result)

            else:
                # 10.8+: a min/max range is set from the requested protocols
                protocol_consts = [
                    _PROTOCOL_STRING_CONST_MAP[protocol]
                    for protocol in self._session._protocols
                ]
                min_protocol = min(protocol_consts)
                max_protocol = max(protocol_consts)
                result = Security.SSLSetProtocolVersionMin(
                    session_context, min_protocol)
                handle_sec_error(result)
                result = Security.SSLSetProtocolVersionMax(
                    session_context, max_protocol)
                handle_sec_error(result)

                if disable_auto_validation:
                    result = Security.SSLSetSessionOption(
                        session_context,
                        SecurityConst.kSSLSessionOptionBreakOnServerAuth, True)
                    handle_sec_error(result)

            # Disable all sorts of bad cipher suites
            supported_ciphers_pointer = new(Security, 'size_t *')
            result = Security.SSLGetNumberSupportedCiphers(
                session_context, supported_ciphers_pointer)
            handle_sec_error(result)

            supported_ciphers = deref(supported_ciphers_pointer)

            # The buffer is sized at 4 bytes per suite and read back as uint32_t
            cipher_buffer = buffer_from_bytes(supported_ciphers * 4)
            supported_cipher_suites_pointer = cast(Security, 'uint32_t *',
                                                   cipher_buffer)
            result = Security.SSLGetSupportedCiphers(
                session_context, supported_cipher_suites_pointer,
                supported_ciphers_pointer)
            handle_sec_error(result)

            # The count is re-read after the call since the same pointer was
            # passed in again
            supported_ciphers = deref(supported_ciphers_pointer)
            supported_cipher_suites = array_from_pointer(
                Security, 'uint32_t', supported_cipher_suites_pointer,
                supported_ciphers)
            good_ciphers = []
            for supported_cipher_suite in supported_cipher_suites:
                # Suites are looked up by their 2-byte form; ones missing from
                # CIPHER_SUITE_MAP fall back to the raw bytes
                cipher_suite = int_to_bytes(supported_cipher_suite, width=2)
                cipher_suite_name = CIPHER_SUITE_MAP.get(
                    cipher_suite, cipher_suite)
                good_cipher = _cipher_blacklist_regex.search(
                    cipher_suite_name) is None
                if good_cipher:
                    good_ciphers.append(supported_cipher_suite)

            num_good_ciphers = len(good_ciphers)
            good_ciphers_array = new(Security, 'uint32_t[]', num_good_ciphers)
            array_set(good_ciphers_array, good_ciphers)
            good_ciphers_pointer = cast(Security, 'uint32_t *',
                                        good_ciphers_array)
            result = Security.SSLSetEnabledCiphers(session_context,
                                                   good_ciphers_pointer,
                                                   num_good_ciphers)
            handle_sec_error(result)

            # Set a peer id from the session to allow for session reuse
            peer_id = self._session._peer_id
            result = Security.SSLSetPeerID(session_context, peer_id,
                                           len(peer_id))
            handle_sec_error(result)

            # Keep calling SSLHandshake() for as long as it reports it would block
            handshake_result = Security.SSLHandshake(session_context)
            while handshake_result == SecurityConst.errSSLWouldBlock:
                handshake_result = Security.SSLHandshake(session_context)

            # The handshake paused at server auth: evaluate the peer trust with
            # the session's extra trust roots set as the anchor certificates
            if explicit_validation and handshake_result == SecurityConst.errSSLServerAuthCompleted:
                trust_ref_pointer = new(Security, 'SecTrustRef *')
                result = Security.SSLCopyPeerTrust(session_context,
                                                   trust_ref_pointer)
                handle_sec_error(result)
                trust_ref = unwrap(trust_ref_pointer)

                ca_cert_refs = []
                ca_certs = []
                for cert in self._session._extra_trust_roots:
                    ca_cert = load_certificate(cert)
                    ca_certs.append(ca_cert)
                    ca_cert_refs.append(ca_cert.sec_certificate_ref)

                array_ref = CFHelpers.cf_array_from_list(ca_cert_refs)
                result = Security.SecTrustSetAnchorCertificates(
                    trust_ref, array_ref)
                handle_sec_error(result)

                result_pointer = new(Security, 'SecTrustResultType *')
                result = Security.SecTrustEvaluate(trust_ref, result_pointer)
                handle_sec_error(result)

                trust_result_code = deref(result_pointer)
                # NOTE: despite the name, this set holds the trust results that
                # are accepted - any other result is treated as an invalid chain
                invalid_chain_error_codes = set([
                    SecurityConst.kSecTrustResultProceed,
                    SecurityConst.kSecTrustResultUnspecified
                ])
                if trust_result_code not in invalid_chain_error_codes:
                    handshake_result = SecurityConst.errSSLXCertChainInvalid
                else:
                    # Trust accepted - resume the handshake
                    handshake_result = Security.SSLHandshake(session_context)
                    while handshake_result == SecurityConst.errSSLWouldBlock:
                        handshake_result = Security.SSLHandshake(
                            session_context)

            # Set before the result code is inspected, so it is True even when
            # one of the checks below raises
            self._done_handshake = True

            handshake_error_codes = set([
                SecurityConst.errSSLXCertChainInvalid,
                SecurityConst.errSSLCertExpired,
                SecurityConst.errSSLCertNotYetValid,
                SecurityConst.errSSLUnknownRootCert,
                SecurityConst.errSSLNoRootCert,
                SecurityConst.errSSLHostNameMismatch
            ])

            # In testing, only errSSLXCertChainInvalid was ever returned for
            # all of these different situations, however we include the others
            # for completeness. To get the real reason we have to use the
            # certificate from the handshake and use the deprecated function
            # SecTrustGetCssmResultCode().
            if handshake_result in handshake_error_codes:
                trust_ref_pointer = new(Security, 'SecTrustRef *')
                result = Security.SSLCopyPeerTrust(session_context,
                                                   trust_ref_pointer)
                handle_sec_error(result)
                trust_ref = unwrap(trust_ref_pointer)

                # NOTE(review): unlike the other Security calls, the status
                # returned here is not passed to handle_sec_error() - confirm
                # this is intentional
                result_code_pointer = new(Security, 'OSStatus *')
                result = Security.SecTrustGetCssmResultCode(
                    trust_ref, result_code_pointer)
                result_code = deref(result_code_pointer)

                chain = extract_chain(self._server_hello)

                self_signed = False
                expired = False
                not_yet_valid = False
                no_issuer = False
                cert = None
                bad_hostname = False

                # Classify the failure using the first certificate of the chain
                # and the CSSM result code
                if chain:
                    cert = chain[0]
                    oscrypto_cert = load_certificate(cert)
                    self_signed = oscrypto_cert.self_signed
                    no_issuer = not self_signed and result_code == SecurityConst.CSSMERR_TP_NOT_TRUSTED
                    expired = result_code == SecurityConst.CSSMERR_TP_CERT_EXPIRED
                    not_yet_valid = result_code == SecurityConst.CSSMERR_TP_CERT_NOT_VALID_YET
                    bad_hostname = result_code == SecurityConst.CSSMERR_APPLETP_HOSTNAME_MISMATCH

                # A weak signature hash is checked ahead of the other reasons
                if chain and chain[0].hash_algo in set(['md5', 'md2']):
                    raise_weak_signature(chain[0])

                if bad_hostname:
                    raise_hostname(cert, self._hostname)

                elif expired or not_yet_valid:
                    raise_expired_not_yet_valid(cert)

                elif no_issuer:
                    raise_no_issuer(cert)

                elif self_signed:
                    raise_self_signed(cert)

                if detect_client_auth_request(self._server_hello):
                    raise_client_auth()

                # Fallback when no more specific reason applied - presumably
                # each raise_*() helper above raises and never returns; verify
                raise_verification(cert)

            if handshake_result == SecurityConst.errSSLPeerHandshakeFail:
                if detect_client_auth_request(self._server_hello):
                    raise_client_auth()
                raise_handshake()

            if handshake_result == SecurityConst.errSSLWeakPeerEphemeralDHKey:
                raise_dh_params()

            if handshake_result == SecurityConst.errSSLRecordOverflow:
                raise_protocol_error(self._server_hello)

            if handshake_result in set([
                    SecurityConst.errSSLClosedNoNotify,
                    SecurityConst.errSSLClosedAbort
            ]):
                if detect_other_protocol(self._server_hello):
                    raise_protocol_error(self._server_hello)
                raise_disconnection()

            # Any remaining status is run through the generic error handler,
            # with TLSError passed as the exception class
            if handshake_result != SecurityConst.errSSLWouldBlock:
                handle_sec_error(handshake_result, TLSError)

            self._session_context = session_context

            # Record the negotiated protocol version
            protocol_const_pointer = new(Security, 'SSLProtocol *')
            result = Security.SSLGetNegotiatedProtocolVersion(
                session_context, protocol_const_pointer)
            handle_sec_error(result)
            protocol_const = deref(protocol_const_pointer)

            self._protocol = _PROTOCOL_CONST_STRING_MAP[protocol_const]

            # Record the negotiated cipher suite, falling back to the raw
            # 2-byte id when it is not present in CIPHER_SUITE_MAP
            cipher_int_pointer = new(Security, 'SSLCipherSuite *')
            result = Security.SSLGetNegotiatedCipher(session_context,
                                                     cipher_int_pointer)
            handle_sec_error(result)
            cipher_int = deref(cipher_int_pointer)

            cipher_bytes = int_to_bytes(cipher_int, width=2)
            self._cipher_suite = CIPHER_SUITE_MAP.get(cipher_bytes,
                                                      cipher_bytes)

            # The remaining session details come from the captured hello messages
            session_info = parse_session_info(self._server_hello,
                                              self._client_hello)
            self._compression = session_info['compression']
            self._session_id = session_info['session_id']
            self._session_ticket = session_info['session_ticket']

        except (OSError, socket_.error):
            # Release the context with the call matching how it was created,
            # then reset state, close the connection and re-raise
            if session_context:
                if osx_version_info < (10, 8):
                    result = Security.SSLDisposeContext(session_context)
                    handle_sec_error(result)
                else:
                    result = CoreFoundation.CFRelease(session_context)
                    handle_cf_error(result)

            self._session_context = None
            self.close()

            raise
Example #20
0
def verify_pss_padding(hash_algorithm, salt_length, key_length, message,
                       signature):
    """
    Verifies the PSS padding on an encoded message

    :param hash_algorithm:
        The string name of the hash algorithm to use: "sha1", "sha224",
        "sha256", "sha384", "sha512"

    :param salt_length:
        The length of the salt as an integer - typically the same as the length
        of the output from the hash_algorithm. May be 0.

    :param key_length:
        The length of the RSA key, in bits

    :param message:
        A byte string of the message that was signed

    :param signature:
        A byte string of the encoded message to check the padding of

    :raises:
        SystemError - when called on a backend/platform other than winlegacy or OS X
        TypeError - when any of the parameters are of the wrong type
        ValueError - when any of the parameters contain an invalid value

    :return:
        A boolean - True if the padding and digest are valid for the message,
        False otherwise
    """

    if _backend != 'winlegacy' and sys.platform != 'darwin':
        raise SystemError(
            pretty_message('''
            Pure-python RSA PSS signature padding verification code is only for
            Windows XP/2003 and OS X
            '''))

    if not isinstance(message, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            message must be a byte string, not %s
            ''', type_name(message)))

    if not isinstance(signature, byte_cls):
        raise TypeError(
            pretty_message(
                '''
            signature must be a byte string, not %s
            ''', type_name(signature)))

    if not isinstance(salt_length, int_types):
        raise TypeError(
            pretty_message(
                '''
            salt_length must be an integer, not %s
            ''', type_name(salt_length)))

    if salt_length < 0:
        raise ValueError(
            pretty_message(
                '''
            salt_length must be 0 or more - is %s
            ''', repr(salt_length)))

    if hash_algorithm not in set(
        ['sha1', 'sha224', 'sha256', 'sha384', 'sha512']):
        raise ValueError(
            pretty_message(
                '''
            hash_algorithm must be one of "sha1", "sha224", "sha256", "sha384",
            "sha512", not %s
            ''', repr(hash_algorithm)))

    hash_func = getattr(hashlib, hash_algorithm)

    em_bits = key_length - 1
    em_len = int(math.ceil(em_bits / 8))

    message_digest = hash_func(message).digest()
    hash_length = len(message_digest)

    # The encoded message must have room for the digest, the salt and the
    # two fixed bytes (0x01 separator and 0xBC trailer)
    if em_len < hash_length + salt_length + 2:
        return False

    if signature[-1:] != b'\xBC':
        return False

    # Number of unused high bits in the first byte of the encoded message
    zero_bits = (8 * em_len) - em_bits

    masked_db_length = em_len - hash_length - 1
    masked_db = signature[0:masked_db_length]

    first_byte = ord(masked_db[0:1])
    bits_that_should_be_zero = first_byte >> (8 - zero_bits)
    if bits_that_should_be_zero != 0:
        return False

    m_prime_digest = signature[masked_db_length:masked_db_length + hash_length]

    db_mask = _mgf1(hash_algorithm, m_prime_digest, em_len - hash_length - 1)

    # Clear the same unused high bits in the first byte of the mask
    left_int_mask = 0xFF >> zero_bits

    if left_int_mask != 255:
        db_mask = chr_cls(left_int_mask & ord(db_mask[0:1])) + db_mask[1:]

    # int_to_bytes() drops leading zero bytes, so pad back to the full length
    db = int_to_bytes(int_from_bytes(masked_db) ^ int_from_bytes(db_mask))
    if len(db) < len(masked_db):
        db = (b'\x00' * (len(masked_db) - len(db))) + db

    zero_length = em_len - hash_length - salt_length - 2
    zero_string = b'\x00' * zero_length
    if not constant_compare(db[0:zero_length], zero_string):
        return False

    if db[zero_length:zero_length + 1] != b'\x01':
        return False

    # A negative-index slice can not express "the last 0 bytes" - db[-0:] is
    # the whole of db - so an empty salt has to be handled explicitly
    if salt_length:
        salt = db[0 - salt_length:]
    else:
        salt = b''

    m_prime = (b'\x00' * 8) + message_digest + salt

    h_prime = hash_func(m_prime).digest()

    return constant_compare(m_prime_digest, h_prime)
Example #21
0
    def build(self, signing_private_key_path, debug=False):
        """
        Validates the certificate information, constructs the ASN.1 structure
        and then signs it

        :param signing_private_key_path:
            Path to a .pem file with the private key used to sign the
            certificate

        :param debug:
            If True, prints the signed bytes and the signature (hex-encoded)

        :raises:
            ValueError - when the certificate is not self-signed and no issuer
            has been specified

        :return:
            An m2m.Certificate object of the newly signed
            certificate
        """
        # Imported locally so the block does not depend on a file-level import
        import os

        if self._self_signed is not True and self._issuer is None:
            raise ValueError(_pretty_message(
                '''
                Certificate must be self-signed, or an issuer must be specified
                '''
            ))

        if self._self_signed:
            self._issuer = self._subject

        if self.serial_number is None:
            time_part = int_to_bytes(int(time.time()))
            # Must contain at least 20 randomly generated BITS. They are taken
            # from the OS CSPRNG since the random module is not suitable for
            # security purposes.
            random_part = os.urandom(3)
            self.serial_number = int_from_bytes(time_part + random_part)

        # Only the non-optional fields are always in this dict
        properties = {
            'version': Integer(value=self._version),
            'serialNumber': OctetString(value=self.serial_number.to_bytes(20, byteorder='big')),
            'subject': self.subject,
        }

        # (ASN.1 field name, attribute name, wrapper type or None), in the
        # order the fields are added to the structure
        optional_fields = (
            ('cAAlgorithm', 'ca_algorithm', None),
            ('cAAlgParams', 'ca_algorithm_parameters', OctetString),
            ('issuer', 'issuer', None),
            ('validFrom', 'valid_from', None),
            ('validDuration', 'valid_duration', None),
            ('pKAlgorithm', 'pk_algorithm', None),
            ('pKAlgParams', 'pk_algorithm_parameters', OctetString),
            ('pubKey', 'public_key', OctetString),
            ('authKeyId', 'authkey_id', None),
            ('subjKeyId', 'subject_key_id', OctetString),
            ('keyUsage', 'key_usage', OctetString),
            ('basicConstraints', 'basic_constraints', Integer),
            ('certificatePolicy', 'certificate_policy', None),
            ('subjectAltName', 'subject_alternative_name', None),
            ('issuerAltName', 'issuer_alternative_name', None),
            ('extendedKeyUsage', 'extended_key_usage', None),
            ('authInfoAccessOCSP', 'auth_info_access_ocsp', None),
            ('cRLDistribPointURI', 'crl_distribution_point_uri', None),
            ('x509extensions', 'x509_extensions', None),
        )

        # Optional fields are only added if they're not None
        for field_name, attr_name, wrapper in optional_fields:
            field_value = getattr(self, attr_name)
            if field_value is None:
                continue
            if wrapper is not None:
                field_value = wrapper(value=field_value)
            properties[field_name] = field_value

        tbs_cert = TBSCertificate(properties)

        bytes_to_sign = tbs_cert.dump()
        signature = generate_signature(bytes_to_sign, signing_private_key_path)

        if debug:
            print("Build  - Signed_bytes ({len}): {content}".format(len=len(bytes_to_sign), content=hexlify(bytes_to_sign)))
            print("Build  - Signature ({len}): {content}".format(len=len(signature), content=hexlify(signature)))

        return Certificate({
            'tbsCertificate': tbs_cert,
            'cACalcValue': signature
        })