Ejemplo n.º 1
0
def test_decryptor_finalize(patch_default_backend, patch_cipher):
    """Decryptor.finalize() must delegate to the wrapped cipher context's finalize().

    The fixtures patch the default backend and cipher so ``_decryptor`` is a mock.
    The assertions check that finalize() was invoked exactly once, with no
    arguments, and that its return value is passed straight through.
    """
    decryptor = Decryptor(
        algorithm=MagicMock(),
        key=sentinel.key,
        associated_data=sentinel.aad,
        iv=sentinel.iv,
        tag=sentinel.tag,
    )

    result = decryptor.finalize()

    inner_finalize = decryptor._decryptor.finalize
    inner_finalize.assert_called_once_with()
    assert result is inner_finalize.return_value
Ejemplo n.º 2
0
class StreamDecryptor(_EncryptionStream):  # pylint: disable=too-many-instance-attributes
    """Provides a streaming decryptor for decrypting a stream source.
    Behaves as a standard file-like object.

    .. note::
        Take care when decrypting framed messages with large frame length and large non-framed
        messages.  See :class:`aws_encryption_sdk.stream` for more details.

    .. note::
        If config is provided, all other parameters are ignored.

    :param config: Client configuration object (config or individual parameters required)
    :type config: aws_encryption_sdk.streaming_client.DecryptorConfig
    :param source: Source ciphertext data to decrypt
    :type source: str, bytes, io.IOBase, or file
    :param materials_manager: `CryptoMaterialsManager` from which to obtain cryptographic materials
        (either `materials_manager` or `key_provider` required)
    :type materials_manager: aws_encryption_sdk.materials_managers.base.CryptoMaterialsManager
    :param key_provider: `MasterKeyProvider` from which to obtain data keys for decryption
        (either `materials_manager` or `key_provider` required)
    :type key_provider: aws_encryption_sdk.key_providers.base.MasterKeyProvider
    :param int source_length: Length of source data (optional)

        .. note::
            If source_length is not provided and read() is called, will attempt to seek()
            to the end of the stream and tell() to find the length of source data.

    :param int max_body_length: Maximum frame size (or content length for non-framed messages)
        in bytes to read from ciphertext message.
    """

    _config_class = DecryptorConfig

    def __init__(self, **kwargs):  # pylint: disable=unused-argument,super-init-not-called
        """Prepares necessary initial values."""
        # super().__init__ is deliberately not called here; presumably config/stream setup
        # happens in the _EncryptionStream base class -- confirm there.
        # Sequence number of the last frame decrypted; frames must arrive as 1, 2, 3, ...
        self.last_sequence_number = 0

    def _prep_message(self):
        """Performs initial message setup.

        Reads and validates the message header. For non-framed messages it also prepares
        the body decryptor. Marks the message as prepped.
        """
        self._header, self.header_auth = self._read_header()
        if self._header.content_type == ContentType.NO_FRAMING:
            self._prep_non_framed()
        self._message_prepped = True

    def _read_header(self):
        """Reads the message header from the input stream.

        Also obtains decryption materials from the configured materials manager. It sets
        up ``self.verifier`` (None when no verification key is provided) and derives the
        data encryption key into ``self._derived_data_key``.

        :returns: tuple containing deserialized header and header_auth objects
        :rtype: tuple of aws_encryption_sdk.structures.MessageHeader
            and aws_encryption_sdk.internal.structures.MessageHeaderAuthentication
        :raises CustomMaximumValueExceeded: if frame length is greater than the custom max value
        """
        header_start = self.source_stream.tell()
        header = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header(
            self.source_stream)

        if (self.config.max_body_length is not None
                and header.content_type == ContentType.FRAMED_DATA
                and header.frame_length > self.config.max_body_length):
            raise CustomMaximumValueExceeded(
                'Frame Size in header found larger than custom value: {found} > {custom}'
                .format(found=header.frame_length,
                        custom=self.config.max_body_length))

        header_end = self.source_stream.tell()
        decrypt_materials_request = DecryptionMaterialsRequest(
            encrypted_data_keys=header.encrypted_data_keys,
            algorithm=header.algorithm,
            encryption_context=header.encryption_context)
        decryption_materials = self.config.materials_manager.decrypt_materials(
            request=decrypt_materials_request)
        if decryption_materials.verification_key is None:
            self.verifier = None
        else:
            self.verifier = Verifier.from_key_bytes(
                algorithm=header.algorithm,
                key_bytes=decryption_materials.verification_key)
        if self.verifier is not None:
            # Re-read the raw header bytes so the signature verifier covers them.
            # Reading (header_end - header_start) bytes from header_start leaves the
            # stream positioned at header_end again.
            self.source_stream.seek(header_start)
            self.verifier.update(
                self.source_stream.read(header_end - header_start))

        header_auth = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header_auth(
            stream=self.source_stream,
            algorithm=header.algorithm,
            verifier=self.verifier)
        self._derived_data_key = derive_data_encryption_key(
            source_key=decryption_materials.data_key.data_key,
            algorithm=header.algorithm,
            message_id=header.message_id)
        aws_encryption_sdk.internal.formatting.deserialize.validate_header(
            header=header,
            header_auth=header_auth,
            stream=self.source_stream,
            header_start=header_start,
            header_end=header_end,
            data_key=self._derived_data_key)
        return header, header_auth

    def _prep_non_framed(self):
        """Prepare the opening data for a non-framed message.

        Reads the body IV, tag and length and enforces ``max_body_length``. It builds the
        body ``Decryptor`` and records the ``body_start``/``body_end`` stream offsets.

        :raises CustomMaximumValueExceeded: if the body length exceeds the custom max value
        """
        iv, tag, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
            stream=self.source_stream,
            header=self._header,
            verifier=self.verifier)

        if self.config.max_body_length is not None and self.body_length > self.config.max_body_length:
            raise CustomMaximumValueExceeded(
                'Non-framed message content length found larger than custom value: {found} > {custom}'
                .format(found=self.body_length,
                        custom=self.config.max_body_length))

        # A non-framed body is authenticated as a single, final "frame" with sequence number 1.
        aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
            content_type=self._header.content_type, is_final_frame=True)
        associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
            message_id=self._header.message_id,
            aad_content_string=aad_content_string,
            seq_num=1,
            length=self.body_length)
        self.decryptor = Decryptor(algorithm=self._header.algorithm,
                                   key=self._derived_data_key,
                                   associated_data=associated_data,
                                   iv=iv,
                                   tag=tag)
        self.body_start = self.source_stream.tell()
        self.body_end = self.body_start + self.body_length

    def _read_bytes_from_non_framed_body(self, b):
        """Reads the requested number of bytes from a streaming non-framed message body.

        The entire body is always read and decrypted, regardless of ``b``. The footer is
        then read and the source stream is closed.

        :param int b: Number of bytes to read (only logged; the whole body is read)
        :returns: Decrypted bytes from source stream
        :rtype: bytes
        :raises SerializationError: if the stream holds fewer body bytes than declared
        """
        _LOGGER.debug('starting non-framed body read')
        # Always read the entire message for non-framed message bodies.
        bytes_to_read = self.body_end - self.source_stream.tell()
        _LOGGER.debug('%s bytes requested; reading %s bytes', b, bytes_to_read)
        ciphertext = self.source_stream.read(bytes_to_read)
        # Compare ciphertext against the declared ciphertext body length only.
        # output_buffer holds *plaintext* and must not be mixed into this check.
        if len(ciphertext) < self.body_length:
            raise SerializationError(
                'Total message body contents less than specified in body description'
            )
        if self.verifier is not None:
            self.verifier.update(ciphertext)
        plaintext = self.decryptor.update(ciphertext)
        plaintext += self.decryptor.finalize()
        aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
            stream=self.source_stream,
            header=self._header,
            verifier=self.verifier)
        self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer(
            stream=self.source_stream, verifier=self.verifier)
        self.source_stream.close()
        return plaintext

    def _read_bytes_from_framed_body(self, b):
        """Reads the requested number of bytes from a streaming framed message body.

        Frames are decrypted whole until at least ``b`` plaintext bytes are collected or
        the final frame is reached. After the final frame, the footer is read and the
        source stream is closed.

        :param int b: Number of bytes to read
        :returns: Bytes read from source stream and decrypted
        :rtype: bytes
        :raises SerializationError: if frames arrive out of sequence order
        """
        # Collect per-frame plaintext and join once; repeated bytes += is quadratic.
        chunks = []
        collected = 0
        final_frame = False
        _LOGGER.debug('collecting %s bytes', b)
        while collected < b and not final_frame:
            _LOGGER.debug('Reading frame')
            frame_data, final_frame = aws_encryption_sdk.internal.formatting.deserialize.deserialize_frame(
                stream=self.source_stream,
                header=self._header,
                verifier=self.verifier)
            _LOGGER.debug('Read complete for frame %s',
                          frame_data.sequence_number)
            if frame_data.sequence_number != self.last_sequence_number + 1:
                raise SerializationError(
                    'Malformed message: frames out of order')
            self.last_sequence_number += 1
            aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
                content_type=self._header.content_type,
                is_final_frame=frame_data.final_frame)
            associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
                message_id=self._header.message_id,
                aad_content_string=aad_content_string,
                seq_num=frame_data.sequence_number,
                length=len(frame_data.ciphertext))
            frame_plaintext = decrypt(algorithm=self._header.algorithm,
                                      key=self._derived_data_key,
                                      encrypted_data=frame_data,
                                      associated_data=associated_data)
            chunks.append(frame_plaintext)
            collected += len(frame_plaintext)
            _LOGGER.debug('bytes collected: %s', collected)
        if final_frame:
            _LOGGER.debug('Reading footer')
            self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer(
                stream=self.source_stream, verifier=self.verifier)
            self.source_stream.close()
        return b''.join(chunks)

    def _read_bytes(self, b):
        """Reads the requested number of bytes from a streaming message body.

        Decrypted data is appended to ``self.output_buffer``. Nothing is read if the
        source stream is already closed, or if the buffer already holds ``b`` or more bytes.

        :param int b: Number of bytes to read
        :raises NotSupportedError: if content type is not supported
        """
        if self.source_stream.closed:
            _LOGGER.debug('Source stream closed')
            return

        if b <= len(self.output_buffer):
            _LOGGER.debug(
                '%s bytes requested less than or equal to current output buffer size %s',
                b, len(self.output_buffer))
            return

        if self._header.content_type == ContentType.FRAMED_DATA:
            self.output_buffer += self._read_bytes_from_framed_body(b)
        elif self._header.content_type == ContentType.NO_FRAMING:
            self.output_buffer += self._read_bytes_from_non_framed_body(b)
        else:
            raise NotSupportedError('Unsupported content type')

    def close(self):
        """Closes out the stream.

        :raises SerializationError: if the message footer has not been read yet
        """
        _LOGGER.debug('Closing stream')
        if not hasattr(self, 'footer'):
            raise SerializationError('Footer not read')
        super(StreamDecryptor, self).close()
class StreamDecryptor(_EncryptionStream):  # pylint: disable=too-many-instance-attributes
    """Provides a streaming decryptor for decrypting a stream source.
    Behaves as a standard file-like object.

    .. note::
        Take care when decrypting framed messages with large frame length and large non-framed
        messages.  See :class:`aws_encryption_sdk.stream` for more details.

    .. note::
        If config is provided, all other parameters are ignored.

    .. versionadded:: 1.5.0
       The *keyring* parameter.

    :param config: Client configuration object (config or individual parameters required)
    :type config: aws_encryption_sdk.streaming_client.DecryptorConfig
    :param source: Source ciphertext data to decrypt
    :type source: str, bytes, io.IOBase, or file
    :param CryptoMaterialsManager materials_manager:
        Cryptographic materials manager to use for decryption
        (either ``materials_manager``, ``keyring``, ``key_provider`` required)
    :param Keyring keyring: Keyring to use for decryption
        (either ``materials_manager``, ``keyring``, ``key_provider`` required)
    :param MasterKeyProvider key_provider:
        Master key provider to use for decryption
        (either ``materials_manager``, ``keyring``, ``key_provider`` required)
    :param int source_length: Length of source data (optional)

        .. note::
            If source_length is not provided and read() is called, will attempt to seek()
            to the end of the stream and tell() to find the length of source data.

    :param int max_body_length: Maximum frame size (or content length for non-framed messages)
        in bytes to read from ciphertext message.
    """

    _config_class = DecryptorConfig

    def __init__(self, **kwargs):  # pylint: disable=unused-argument,super-init-not-called
        """Prepares necessary initial values."""
        # super().__init__ is deliberately not called here; presumably config/stream setup
        # happens in the _EncryptionStream base class -- confirm there.
        # Sequence number of the last frame decrypted; frames must arrive as 1, 2, 3, ...
        self.last_sequence_number = 0
        # Running count of pre-body bytes consumed, used to compute _body_start/_body_end.
        self.__unframed_bytes_read = 0

    def _prep_message(self):
        """Performs initial message setup.

        Reads and validates the message header. For non-framed messages it also reads
        the body IV and length. Marks the message as prepped.
        """
        self._header, self.header_auth = self._read_header()
        if self._header.content_type == ContentType.NO_FRAMING:
            self._prep_non_framed()
        self._message_prepped = True

    def _read_header(self):
        """Reads the message header from the input stream.

        Also obtains decryption materials from the configured materials manager and
        records their keyring trace. It sets up ``self.verifier`` (None when no
        verification key is provided) and derives the data encryption key into
        ``self._derived_data_key``.

        :returns: tuple containing deserialized header and header_auth objects
        :rtype: tuple of aws_encryption_sdk.structures.MessageHeader
            and aws_encryption_sdk.internal.structures.MessageHeaderAuthentication
        :raises CustomMaximumValueExceeded: if frame length is greater than the custom max value
        """
        header, raw_header = deserialize_header(self.source_stream)
        self.__unframed_bytes_read += len(raw_header)

        if (
            self.config.max_body_length is not None
            and header.content_type == ContentType.FRAMED_DATA
            and header.frame_length > self.config.max_body_length
        ):
            raise CustomMaximumValueExceeded(
                "Frame Size in header found larger than custom value: {found:d} > {custom:d}".format(
                    found=header.frame_length, custom=self.config.max_body_length
                )
            )

        decrypt_materials_request = DecryptionMaterialsRequest(
            encrypted_data_keys=header.encrypted_data_keys,
            algorithm=header.algorithm,
            encryption_context=header.encryption_context,
        )
        decryption_materials = self.config.materials_manager.decrypt_materials(request=decrypt_materials_request)
        self.keyring_trace = decryption_materials.keyring_trace

        if decryption_materials.verification_key is None:
            self.verifier = None
        else:
            self.verifier = Verifier.from_key_bytes(
                algorithm=header.algorithm, key_bytes=decryption_materials.verification_key
            )
        if self.verifier is not None:
            # The signature covers the raw serialized header bytes.
            self.verifier.update(raw_header)

        header_auth = deserialize_header_auth(
            stream=self.source_stream, algorithm=header.algorithm, verifier=self.verifier
        )
        # NOTE(review): bytes consumed by deserialize_header_auth are not added to
        # __unframed_bytes_read, so _body_start/_body_end computed in _prep_non_framed are
        # not absolute stream offsets if that call reads from the stream -- confirm intent.
        self._derived_data_key = derive_data_encryption_key(
            source_key=decryption_materials.data_key.data_key, algorithm=header.algorithm, message_id=header.message_id
        )
        validate_header(header=header, header_auth=header_auth, raw_header=raw_header, data_key=self._derived_data_key)
        return header, header_auth

    def _prep_non_framed(self):
        """Prepare the opening data for a non-framed message.

        Reads the body IV and length and enforces ``max_body_length``. It records the
        ``_body_start``/``_body_end`` offsets from the running byte counter.

        :raises CustomMaximumValueExceeded: if the body length exceeds the custom max value
        """
        self._unframed_body_iv, self.body_length = deserialize_non_framed_values(
            stream=self.source_stream, header=self._header, verifier=self.verifier
        )

        if self.config.max_body_length is not None and self.body_length > self.config.max_body_length:
            raise CustomMaximumValueExceeded(
                "Non-framed message content length found larger than custom value: {found:d} > {custom:d}".format(
                    found=self.body_length, custom=self.config.max_body_length
                )
            )

        self.__unframed_bytes_read += self._header.algorithm.iv_len
        self.__unframed_bytes_read += 8  # encrypted content length field
        self._body_start = self.__unframed_bytes_read
        self._body_end = self._body_start + self.body_length

    def _read_bytes_from_non_framed_body(self, b):
        """Reads the requested number of bytes from a streaming non-framed message body.

        The entire body is always read, regardless of ``b``. Then the tag is read, the body
        is decrypted as a single final "frame" (sequence number 1), and the footer is read.
        The source stream is not closed here.

        :param int b: Number of bytes to read (only logged; the whole body is read)
        :returns: Decrypted bytes from source stream
        :rtype: bytes
        :raises SerializationError: if the stream holds fewer body bytes than declared
        """
        _LOGGER.debug("starting non-framed body read")
        # Always read the entire message for non-framed message bodies.
        bytes_to_read = self.body_length

        _LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read)
        ciphertext = self.source_stream.read(bytes_to_read)

        # Compare ciphertext against the declared ciphertext body length only.
        # output_buffer holds *plaintext* and must not be mixed into this check.
        if len(ciphertext) < self.body_length:
            raise SerializationError("Total message body contents less than specified in body description")

        if self.verifier is not None:
            self.verifier.update(ciphertext)

        # The authentication tag follows the body ciphertext in the stream.
        tag = deserialize_tag(stream=self.source_stream, header=self._header, verifier=self.verifier)

        aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
            content_type=self._header.content_type, is_final_frame=True
        )
        associated_data = assemble_content_aad(
            message_id=self._header.message_id,
            aad_content_string=aad_content_string,
            seq_num=1,
            length=self.body_length,
        )
        self.decryptor = Decryptor(
            algorithm=self._header.algorithm,
            key=self._derived_data_key,
            associated_data=associated_data,
            iv=self._unframed_body_iv,
            tag=tag,
        )

        plaintext = self.decryptor.update(ciphertext)
        plaintext += self.decryptor.finalize()

        self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier)
        return plaintext

    def _read_bytes_from_framed_body(self, b):
        """Reads the requested number of bytes from a streaming framed message body.

        Frames are decrypted whole until at least ``b`` plaintext bytes are collected or
        the final frame is reached. After the final frame, the footer is read.

        :param int b: Number of bytes to read
        :returns: Bytes read from source stream and decrypted
        :rtype: bytes
        :raises SerializationError: if frames arrive out of sequence order
        """
        # Collect per-frame plaintext and join once; repeated bytes += is quadratic.
        chunks = []
        collected = 0
        final_frame = False
        _LOGGER.debug("collecting %d bytes", b)
        while collected < b and not final_frame:
            _LOGGER.debug("Reading frame")
            frame_data, final_frame = deserialize_frame(
                stream=self.source_stream, header=self._header, verifier=self.verifier
            )
            _LOGGER.debug("Read complete for frame %d", frame_data.sequence_number)
            if frame_data.sequence_number != self.last_sequence_number + 1:
                raise SerializationError("Malformed message: frames out of order")
            self.last_sequence_number += 1
            aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
                content_type=self._header.content_type, is_final_frame=frame_data.final_frame
            )
            associated_data = assemble_content_aad(
                message_id=self._header.message_id,
                aad_content_string=aad_content_string,
                seq_num=frame_data.sequence_number,
                length=len(frame_data.ciphertext),
            )
            frame_plaintext = decrypt(
                algorithm=self._header.algorithm,
                key=self._derived_data_key,
                encrypted_data=frame_data,
                associated_data=associated_data,
            )
            chunks.append(frame_plaintext)
            collected += len(frame_plaintext)
            _LOGGER.debug("bytes collected: %d", collected)
        if final_frame:
            _LOGGER.debug("Reading footer")
            self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier)

        return b"".join(chunks)

    def _read_bytes(self, b):
        """Reads the requested number of bytes from a streaming message body.

        Decrypted data is appended to ``self.output_buffer``. Nothing is read once the
        footer has been read. Nothing is read either when ``0 <= b`` and the buffer
        already holds at least ``b`` bytes.

        :param int b: Number of bytes to read
        :raises NotSupportedError: if content type is not supported
        """
        # The footer attribute is set only after the whole message body has been processed.
        if hasattr(self, "footer"):
            _LOGGER.debug("Source stream processing complete")
            return

        buffer_length = len(self.output_buffer)
        if 0 <= b <= buffer_length:
            _LOGGER.debug("%d bytes requested less than or equal to current output buffer size %d", b, buffer_length)
            return

        if self._header.content_type == ContentType.FRAMED_DATA:
            self.output_buffer += self._read_bytes_from_framed_body(b)
        elif self._header.content_type == ContentType.NO_FRAMING:
            self.output_buffer += self._read_bytes_from_non_framed_body(b)
        else:
            raise NotSupportedError("Unsupported content type")

    def close(self):
        """Closes out the stream.

        :raises SerializationError: if the message footer has not been read yet
        """
        _LOGGER.debug("Closing stream")
        if not hasattr(self, "footer"):
            raise SerializationError("Footer not read")
        super(StreamDecryptor, self).close()