def test_encryptor_finalize(patch_default_backend, patch_cipher):
    """Verify that ``Encryptor.finalize`` delegates to the wrapped encryptor and returns its result."""
    # Arrange: build an Encryptor on top of the patched backend and cipher.
    encryptor = Encryptor(
        algorithm=MagicMock(),
        key=sentinel.key,
        associated_data=sentinel.aad,
        iv=sentinel.iv,
    )

    # Act
    result = encryptor.finalize()

    # Assert: exactly one argument-less call, and the inner result is passed through untouched.
    inner_finalize = encryptor._encryptor.finalize
    inner_finalize.assert_called_once_with()
    assert result is inner_finalize.return_value
def test_encryptor_init(patch_default_backend, patch_cipher):
    """Verify how ``Encryptor.__init__`` wires the algorithm, backend, cipher and associated data together."""
    algorithm = MagicMock()

    encryptor = Encryptor(
        algorithm=algorithm,
        key=sentinel.key,
        associated_data=sentinel.aad,
        iv=sentinel.iv,
    )

    # The source key is kept on the instance as-is.
    assert encryptor.source_key is sentinel.key

    # The algorithm is asked for the cipher algorithm (from the key) and the mode (from the IV).
    algorithm.encryption_algorithm.assert_called_once_with(sentinel.key)
    algorithm.encryption_mode.assert_called_once_with(sentinel.iv)

    # The cipher is built once from those two values plus the default backend.
    patch_default_backend.assert_called_once_with()
    patch_cipher.assert_called_once_with(
        algorithm.encryption_algorithm.return_value,
        algorithm.encryption_mode.return_value,
        backend=patch_default_backend.return_value,
    )

    # The cipher's encryptor is stored and primed with the associated data.
    cipher = patch_cipher.return_value
    cipher.encryptor.assert_called_once_with()
    assert encryptor._encryptor is cipher.encryptor.return_value
    encryptor._encryptor.authenticate_additional_data.assert_called_once_with(sentinel.aad)
class StreamEncryptor(_EncryptionStream):  # pylint: disable=too-many-instance-attributes
    """Provides a streaming encryptor for encrypting a stream source.

    Behaves as a standard file-like object.

    .. note::
        Take care when encrypting framed messages with large frame length and large non-framed
        messages. See :class:`aws_encryption_sdk.stream` for more details.

    .. note::
        If config is provided, all other parameters are ignored.

    :param config: Client configuration object (config or individual parameters required)
    :type config: aws_encryption_sdk.streaming_client.EncryptorConfig
    :param source: Source data to encrypt or decrypt
    :type source: str, bytes, io.IOBase, or file
    :param materials_manager: `CryptoMaterialsManager` from which to obtain cryptographic materials
        (either `materials_manager` or `key_provider` required)
    :type materials_manager: aws_encryption_sdk.materials_manager.base.CryptoMaterialsManager
    :param key_provider: `MasterKeyProvider` from which to obtain data keys for encryption
        (either `materials_manager` or `key_provider` required)
    :type key_provider: aws_encryption_sdk.key_providers.base.MasterKeyProvider
    :param int source_length: Length of source data (optional)

        .. note::
            If source_length is not provided and unframed message is being written or read() is called,
            will attempt to seek() to the end of the stream and tell() to find the length of source data.

        .. note::
            .. versionadded:: 1.3.0

            If `source_length` and `materials_manager` are both provided, the total plaintext bytes
            encrypted will not be allowed to exceed `source_length`. To maintain backwards compatibility,
            this is not enforced if a `key_provider` is provided.

    :param dict encryption_context: Dictionary defining encryption context
    :param algorithm: Algorithm to use for encryption
    :type algorithm: aws_encryption_sdk.identifiers.Algorithm
    :param int frame_length: Frame length in bytes
    """

    _config_class = EncryptorConfig

    def __init__(self, **kwargs):  # pylint: disable=unused-argument,super-init-not-called
        """Set up the per-message counters, content type and internal buffers.

        :raises SerializationError: if a non-framed message is requested (frame length of 0) and the
            configured source length is larger than ``MAX_NON_FRAMED_SIZE``
        """
        self.sequence_number = 1
        self.content_type = aws_encryption_sdk.internal.utils.content_type(self.config.frame_length)
        self._bytes_encrypted = 0

        # Reject an oversized non-framed message up front when the caller declared its length.
        if self.config.frame_length == 0:
            declared_length = self.config.source_length
            if declared_length is not None and declared_length > MAX_NON_FRAMED_SIZE:
                raise SerializationError("Source too large for non-framed message")

        self.__unframed_plaintext_cache = io.BytesIO()
        self.__message_complete = False

    def ciphertext_length(self):
        """Compute the length in bytes of the ciphertext message that will be produced.

        :rtype: int
        """
        header = self.header
        plaintext_length = self.stream_length
        return aws_encryption_sdk.internal.formatting.ciphertext_length(
            header=header, plaintext_length=plaintext_length
        )

    def _prep_message(self):
        """Performs initial message setup.

        Obtains encryption materials, derives the data key, then builds and writes the header
        (and the non-framed body opening, when applicable) to the output buffer.

        :raises MasterKeyProviderError: if primary master key is not a member of supplied MasterKeyProvider
        :raises MasterKeyProviderError: if no Master Keys are returned from key_provider
        :raises ActionNotAllowedError: if the materials manager returns an algorithm suite other than
            the one explicitly requested
        """
        config = self.config
        validate_commitment_policy_on_encrypt(config.commitment_policy, config.algorithm)

        # The plaintext length is optional in the request; not every source can report it.
        try:
            plaintext_length = self.stream_length
        except NotSupportedError:
            plaintext_length = None

        materials_request = EncryptionMaterialsRequest(
            algorithm=config.algorithm,
            encryption_context=config.encryption_context.copy(),
            frame_length=config.frame_length,
            plaintext_rostream=aws_encryption_sdk.internal.utils.streams.ROStream(self.source_stream),
            plaintext_length=plaintext_length,
            commitment_policy=config.commitment_policy,
        )
        materials = config.materials_manager.get_encryption_materials(request=materials_request)
        self._encryption_materials = materials

        requested_algorithm = config.algorithm
        if requested_algorithm is not None and materials.algorithm != requested_algorithm:
            raise ActionNotAllowedError(
                (
                    "Cryptographic materials manager provided algorithm suite"
                    " differs from algorithm suite in request.\n"
                    "Required: {requested}\n"
                    "Provided: {provided}"
                ).format(requested=requested_algorithm, provided=materials.algorithm)
            )

        # A signer is only created when the materials carry a signing key.
        signing_key = materials.signing_key
        if signing_key is not None:
            self.signer = Signer.from_key_bytes(algorithm=materials.algorithm, key_bytes=signing_key)
        else:
            self.signer = None

        aws_encryption_sdk.internal.utils.validate_frame_length(
            frame_length=config.frame_length, algorithm=materials.algorithm
        )

        message_id = aws_encryption_sdk.internal.utils.message_id(materials.algorithm.message_id_length())

        self._derived_data_key = derive_data_encryption_key(
            source_key=materials.data_encryption_key.data_key,
            algorithm=materials.algorithm,
            message_id=message_id,
        )

        self._header = self.generate_header(message_id)
        self._write_header()

        if self.content_type == ContentType.NO_FRAMING:
            self._prep_non_framed()

        self._message_prepped = True

    def generate_header(self, message_id):
        """Build the message header object for the current encryption materials.

        :param message_id: The randomly generated id for the message
        :type message_id: bytes
        :returns: Header describing this message
        :rtype: MessageHeader
        """
        materials = self._encryption_materials
        algorithm = materials.algorithm

        # Message format version 2 algorithms use the V2 serialization; everything else uses the default.
        version = SerializationVersion.V2 if algorithm.message_format_version == 0x02 else VERSION

        header_fields = {
            "version": version,
            "algorithm": algorithm,
            "message_id": message_id,
            "encryption_context": materials.encryption_context,
            "encrypted_data_keys": materials.encrypted_data_keys,
            "content_type": self.content_type,
            "frame_length": self.config.frame_length,
        }

        if algorithm.is_committing():
            header_fields["commitment_key"] = calculate_commitment_key(
                source_key=materials.data_encryption_key.data_key,
                algorithm=algorithm,
                message_id=message_id,
            )

        # These fields are only supplied for V1 headers.
        if version == SerializationVersion.V1:
            header_fields.update(type=TYPE, content_aad_length=0, header_iv_length=algorithm.iv_len)

        return MessageHeader(**header_fields)

    def _write_header(self):
        """Serialize the header followed by its authentication data into the output buffer."""
        self.output_buffer += serialize_header(header=self._header, signer=self.signer)

        # The header authentication is computed over what has been written to the buffer so far.
        header_auth = serialize_header_auth(
            version=self._header.version,
            algorithm=self._encryption_materials.algorithm,
            header=self.output_buffer,
            data_encryption_key=self._derived_data_key,
            signer=self.signer,
        )
        self.output_buffer += header_auth

    def _prep_non_framed(self):
        """Write the opening of a non-framed body and create the body encryptor."""
        try:
            plaintext_length = self.stream_length
            self.__unframed_plaintext_cache = self.source_stream
        except NotSupportedError:
            # The plaintext length must be known before any body data can be processed.
            # When the source cannot report it, pull the whole source into memory instead.
            buffered_source = io.BytesIO()
            self.__unframed_plaintext_cache = buffered_source
            buffered_source.write(self.source_stream.read())
            plaintext_length = buffered_source.tell()
            buffered_source.seek(0)

        algorithm = self._encryption_materials.algorithm

        aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
            content_type=self.content_type, is_final_frame=True
        )
        associated_data = assemble_content_aad(
            message_id=self._header.message_id,
            aad_content_string=aad_content_string,
            seq_num=1,
            length=plaintext_length,
        )
        self.encryptor = Encryptor(
            algorithm=algorithm,
            key=self._derived_data_key,
            associated_data=associated_data,
            iv=non_framed_body_iv(algorithm),
        )
        self.output_buffer += serialize_non_framed_open(
            algorithm=algorithm,
            iv=self.encryptor.iv,
            plaintext_length=plaintext_length,
            signer=self.signer,
        )

    def _read_bytes_to_non_framed_body(self, b):
        """Encrypt up to ``b`` bytes of source data into a streaming non-framed message body.

        A read that returns fewer bytes than requested is treated as the end of the source: the
        encryptor is finalized and the body close (plus the footer, when signing) is appended.

        :param int b: Number of bytes to read
        :returns: Encrypted bytes from source stream
        :rtype: bytes
        :raises SerializationError: if the message would exceed ``MAX_NON_FRAMED_SIZE``
        """
        _LOGGER.debug("Reading %d bytes", b)
        plaintext = self.__unframed_plaintext_cache.read(b)
        plaintext_length = len(plaintext)

        if self.tell() + plaintext_length > MAX_NON_FRAMED_SIZE:
            raise SerializationError("Source too large for non-framed message")

        ciphertext = self.encryptor.update(plaintext)
        self._bytes_encrypted += plaintext_length

        signer = self.signer
        if signer is not None:
            signer.update(ciphertext)

        # A full read means there may be more data; hand back what we have.
        if plaintext_length >= b:
            return ciphertext

        _LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b)
        closing = self.encryptor.finalize()
        if signer is not None:
            signer.update(closing)

        closing += serialize_non_framed_close(tag=self.encryptor.tag, signer=signer)
        if signer is not None:
            closing += serialize_footer(signer)

        self.__message_complete = True
        return ciphertext + closing

    def _read_bytes_to_framed_body(self, b):
        """Encrypt source data into whole frames of a streaming framed message body.

        A positive request is rounded up to a multiple of the frame length. A negative request,
        or a short read from the source, finalizes the message on this pass.

        :param int b: Number of bytes to read
        :returns: Bytes read from source stream, encrypted, and serialized
        :rtype: bytes
        """
        _LOGGER.debug("collecting %d bytes", b)
        requested = b
        frame_length = self.config.frame_length

        if b > 0:
            # Round the request up to a whole number of frames.
            frame_count = math.ceil(b / float(frame_length))
            b = int(frame_count * frame_length)
        _LOGGER.debug("%d bytes requested; reading %d bytes after normalizing to frame length", requested, b)

        plaintext = self.source_stream.read(b)
        plaintext_length = len(plaintext)
        _LOGGER.debug("%d bytes read from source", plaintext_length)

        finalize = b < 0 or plaintext_length < b
        if finalize:
            _LOGGER.debug("Final plaintext read from source")

        output = b""
        final_frame_written = False

        # Not finalizing: stop once the plaintext is exhausted.
        # Finalizing: keep going until the final frame has been written.
        while (plaintext if not finalize else not final_frame_written):
            remaining = len(plaintext)
            is_final_frame = finalize and remaining < frame_length
            bytes_in_frame = min(remaining, frame_length)
            _LOGGER.debug(
                "Writing %d bytes into%s frame %d",
                bytes_in_frame,
                " final" if is_final_frame else "",
                self.sequence_number,
            )
            self._bytes_encrypted += bytes_in_frame

            # serialize_frame hands back the serialized frame and the plaintext still to be framed.
            ciphertext, plaintext = serialize_frame(
                algorithm=self._encryption_materials.algorithm,
                plaintext=plaintext,
                message_id=self._header.message_id,
                data_encryption_key=self._derived_data_key,
                frame_length=frame_length,
                sequence_number=self.sequence_number,
                is_final_frame=is_final_frame,
                signer=self.signer,
            )
            final_frame_written = is_final_frame
            output += ciphertext
            self.sequence_number += 1

        if finalize:
            _LOGGER.debug("Writing footer")
            if self.signer is not None:
                output += serialize_footer(self.signer)
            self.__message_complete = True

        return output

    def _read_bytes(self, b):
        """Fill the output buffer with encrypted data for a read of ``b`` bytes.

        :param int b: Number of bytes to read
        :raises NotSupportedError: if content type is not supported
        :raises CustomMaximumValueExceeded: if more bytes were encrypted than the stated source length
            (only enforced when no key provider is configured)
        """
        _LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type)

        buffer_satisfies_request = 0 <= b <= len(self.output_buffer)
        if buffer_satisfies_request or self.__message_complete:
            _LOGGER.debug("No need to read from source stream or source stream closed")
            return

        content_type = self.content_type
        if content_type == ContentType.FRAMED_DATA:
            _LOGGER.debug("Reading to framed body")
            self.output_buffer += self._read_bytes_to_framed_body(b)
        elif content_type == ContentType.NO_FRAMING:
            _LOGGER.debug("Reading to non-framed body")
            self.output_buffer += self._read_bytes_to_non_framed_body(b)
        else:
            raise NotSupportedError("Unsupported content type")

        # To maintain backwards compatibility, only enforce this if a CMM is provided by the caller.
        if self.config.key_provider is not None or self.config.source_length is None:
            return

        # When the caller stated a source length, the total bytes encrypted must not exceed it.
        if self._bytes_encrypted > self.config.source_length:
            raise CustomMaximumValueExceeded(
                "Bytes encrypted has exceeded stated source length estimate:\n{actual:d} > {estimated:d}".format(
                    actual=self._bytes_encrypted, estimated=self.config.source_length
                )
            )

    def close(self):
        """Log the close and delegate to the parent stream's ``close``."""
        _LOGGER.debug("Closing stream")
        super(StreamEncryptor, self).close()