예제 #1
0
    def test_givenABPEmbSeq2SeqModel_whenForwardPassWithTarget_thenProperlyDoPAss(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
        encoder_mock,
        random_mock,
    ):
        # Runs BPEmbSeq2SeqModel.forward with a target and verifies the calls
        # recorded on the embedding network, encoder, decoder, lengths tensor
        # and target mocks.
        # NOTE(review): the *_mock parameters are presumably injected by patch
        # decorators that sit outside this excerpt -- confirm against the file.

        # --- Arrange ---------------------------------------------------------
        # NOTE(review): the name suggests this value keeps the random draw
        # under a threshold inside the model -- confirm which branch it selects.
        random_mock.return_value = self.a_value_lower_than_threshold

        target = MagicMock()
        to_predict, lengths_tensor = self.setup_encoder_mocks()

        # Only the hidden-state mock of the decoder setup is kept; the decoder
        # input is replaced below because the full loop is exercised.
        _, decoder_hidden = self.setUp_decoder_mocks(
            decoder_mock, attention_mechanism=False)
        decomposition_lengths = MagicMock()

        encoder_output = MagicMock()
        initial_decoder_input = MagicMock()
        new_full_mock = torch_mock.zeros().to().new_full
        new_full_mock.return_value = initial_decoder_input

        # The encoder instance returns (encoder output, decoder hidden state).
        encoder_instance = encoder_mock.__call__()
        encoder_instance.return_value = (encoder_output, decoder_hidden)

        expected_encoder_calls = None  # built once the embedded output exists
        expected_lengths_calls = [call.max().item()]
        expected_decoder_calls = [
            call()(
                initial_decoder_input,
                decoder_hidden,
                encoder_output,
                lengths_tensor,
            )
        ]
        expected_target_calls = [call.transpose(0, 1)]
        expected_embedding_calls = [call()(to_predict, decomposition_lengths)]

        with patch("deepparse.network.bpemb_seq2seq.EmbeddingNetwork"
                   ) as embedding_network_patch:
            # The embedding layer instance returns a dedicated mock output.
            embedded_output = MagicMock()
            embedding_instance = embedding_network_patch()
            embedding_instance.return_value = embedded_output
            expected_encoder_calls = [call()(embedded_output, lengths_tensor)]

            model = BPEmbSeq2SeqModel(self.a_torch_device,
                                      self.output_size, self.verbose)

            # --- Act ---------------------------------------------------------
            model.forward(
                to_predict=to_predict,
                decomposition_lengths=decomposition_lengths,
                lengths_tensor=lengths_tensor,
                target=target,
            )

            # --- Assert ------------------------------------------------------
            embedding_network_patch.assert_has_calls(expected_embedding_calls)
            encoder_mock.assert_has_calls(expected_encoder_calls)
            lengths_tensor.assert_has_calls(expected_lengths_calls)
            decoder_mock.assert_has_calls(expected_decoder_calls)
            target.assert_has_calls(expected_target_calls)
    def test_givenAFasttext2SeqAttModel_whenForwardPassWithTarget_thenProperlyDoPAss(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
        encoder_mock,
        random_mock,
    ):
        # Runs FastTextSeq2SeqModel.forward (attention_mechanism=True) with a
        # target and verifies the calls recorded on the encoder, decoder,
        # lengths tensor and target mocks.
        # NOTE(review): the *_mock parameters are presumably injected by patch
        # decorators that sit outside this excerpt -- confirm against the file.

        # --- Arrange ---------------------------------------------------------
        # NOTE(review): the name suggests this value keeps the random draw
        # under a threshold inside the model -- confirm which branch it selects.
        random_mock.return_value = self.a_value_lower_than_threshold

        target = MagicMock()
        to_predict, lengths_tensor = self.setup_encoder_mocks()

        # Only the hidden-state mock of the decoder setup is kept; the decoder
        # input is replaced below because the full loop is exercised.
        _, decoder_hidden = self.setUp_decoder_mocks(
            decoder_mock, attention_mechanism=True)

        encoder_output = MagicMock()
        initial_decoder_input = MagicMock()
        new_full_mock = torch_mock.zeros().to().new_full
        new_full_mock.return_value = initial_decoder_input

        # The encoder instance returns (encoder output, decoder hidden state).
        encoder_instance = encoder_mock.__call__()
        encoder_instance.return_value = (encoder_output, decoder_hidden)

        expected_encoder_calls = [call()(to_predict, lengths_tensor)]
        expected_lengths_calls = [call.max().item()]
        expected_decoder_calls = [
            call()(
                initial_decoder_input,
                decoder_hidden,
                encoder_output,
                lengths_tensor,
            )
        ]
        expected_target_calls = [call.transpose(0, 1)]

        model = FastTextSeq2SeqModel(
            self.a_torch_device,
            self.output_size,
            self.verbose,
            attention_mechanism=True,
        )

        # --- Act -------------------------------------------------------------
        model.forward(
            to_predict=to_predict,
            lengths_tensor=lengths_tensor,
            target=target,
        )

        # --- Assert ----------------------------------------------------------
        encoder_mock.assert_has_calls(expected_encoder_calls)
        lengths_tensor.assert_has_calls(expected_lengths_calls)
        decoder_mock.assert_has_calls(expected_decoder_calls)
        target.assert_has_calls(expected_target_calls)