def test_givenABPEmbSeq2SeqModel_whenForwardPassWithTarget_thenProperlyDoPass(
    self,
    load_state_dict_mock,
    torch_mock,
    isfile_mock,
    last_version_mock,
    download_weights_mock,
    decoder_mock,
    encoder_mock,
    random_mock,
):
    random_mock.return_value = self.a_value_lower_than_threshold
    target_mock = MagicMock()

    to_predict_mock, lengths_tensor_mock = self.setup_encoder_mocks()

    # The decoder input from setUp_decoder_mocks is discarded; we create our own below
    _, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock, attention_mechanism=False)
    decomposition_lengths_mock = MagicMock()

    # We don't use the decoder input from setUp_decoder_mocks since we run the full loop
    decoder_input_mock = MagicMock()
    to_mock = MagicMock()
    torch_mock.zeros().to().new_full.return_value = to_mock

    # We mock the encoder's return value (the decoder's input and hidden state)
    encoder_mock.__call__().return_value = (decoder_input_mock, decoder_hidden_mock)

    with patch("deepparse.network.bpemb_seq2seq.EmbeddingNetwork") as embedding_network_patch:
        # We mock the output of the embedding layer
        embedded_output_mock = MagicMock()
        embedding_network_patch().return_value = embedded_output_mock

        seq2seq_model = BPEmbSeq2SeqModel(self.a_torch_device, self.output_size, self.verbose)

        seq2seq_model.forward(
            to_predict=to_predict_mock,
            decomposition_lengths=decomposition_lengths_mock,
            lengths_tensor=lengths_tensor_mock,
            target=target_mock,
        )

        embedding_network_patch.assert_has_calls([call()(to_predict_mock, decomposition_lengths_mock)])
        encoder_mock.assert_has_calls([call()(embedded_output_mock, lengths_tensor_mock)])
        lengths_tensor_mock.assert_has_calls([call.max().item()])
        decoder_mock.assert_has_calls([
            call()(
                to_mock,
                decoder_hidden_mock,
                decoder_input_mock,
                lengths_tensor_mock,
            )
        ])
        target_mock.assert_has_calls([call.transpose(0, 1)])
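
# The eight mock parameters in the test signatures here imply a stack of
# @patch decorators on each test. This is a minimal sketch of that stack:
# every target path below is an assumption about the deepparse module
# layout, not the repository's actual decorators. Stacked @patch decorators
# inject their mocks bottom-up, so the bottom-most decorator maps to the
# first mock parameter (load_state_dict_mock).
import unittest
from unittest.mock import patch


class ForwardPassPatchSketch(unittest.TestCase):
    @patch("deepparse.network.seq2seq.random.random")  # random_mock (assumed path)
    @patch("deepparse.network.seq2seq.Encoder")  # encoder_mock (assumed path)
    @patch("deepparse.network.seq2seq.Decoder")  # decoder_mock (assumed path)
    @patch("deepparse.network.seq2seq.download_weights")  # download_weights_mock (assumed path)
    @patch("deepparse.network.seq2seq.latest_version")  # last_version_mock (assumed path)
    @patch("deepparse.network.seq2seq.os.path.isfile")  # isfile_mock (assumed path)
    @patch("deepparse.network.seq2seq.torch")  # torch_mock (assumed path)
    @patch("deepparse.network.seq2seq.Seq2SeqModel.load_state_dict")  # load_state_dict_mock (assumed path)
    def test_sketch(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
        encoder_mock,
        random_mock,
    ):
        ...
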
def test_givenAFastTextSeq2SeqAttModel_whenForwardPassWithTarget_thenProperlyDoPass(
    self,
    load_state_dict_mock,
    torch_mock,
    isfile_mock,
    last_version_mock,
    download_weights_mock,
    decoder_mock,
    encoder_mock,
    random_mock,
):
    random_mock.return_value = self.a_value_lower_than_threshold
    target_mock = MagicMock()

    to_predict_mock, lengths_tensor_mock = self.setup_encoder_mocks()
    _, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock, attention_mechanism=True)

    # We don't use the decoder input from setUp_decoder_mocks since we run the full loop
    decoder_input_mock = MagicMock()
    to_mock = MagicMock()
    torch_mock.zeros().to().new_full.return_value = to_mock

    # We mock the encoder's return value (the decoder's input and hidden state)
    encoder_mock.__call__().return_value = (decoder_input_mock, decoder_hidden_mock)

    seq2seq_model = FastTextSeq2SeqModel(
        self.a_torch_device,
        self.output_size,
        self.verbose,
        attention_mechanism=True,
    )

    seq2seq_model.forward(
        to_predict=to_predict_mock,
        lengths_tensor=lengths_tensor_mock,
        target=target_mock,
    )

    encoder_mock.assert_has_calls([call()(to_predict_mock, lengths_tensor_mock)])
    lengths_tensor_mock.assert_has_calls([call.max().item()])
    decoder_mock.assert_has_calls([
        call()(
            to_mock,
            decoder_hidden_mock,
            decoder_input_mock,
            lengths_tensor_mock,
        )
    ])
    target_mock.assert_has_calls([call.transpose(0, 1)])
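
# Both tests lean on two helpers from the test case's base class. The
# sketches below are assumptions inferred from how the returned mocks are
# consumed above (the real helpers may wire things differently); they are
# included only to make the mock plumbing easier to follow.
from unittest.mock import MagicMock


class Seq2SeqTestHelpersSketch:
    def setup_encoder_mocks(self):
        # One mock for the padded address tensor to predict, one for the
        # lengths tensor the encoder uses to pack sequences.
        to_predict_mock = MagicMock()
        lengths_tensor_mock = MagicMock()
        return to_predict_mock, lengths_tensor_mock

    def setUp_decoder_mocks(self, decoder_mock, attention_mechanism):
        # The decoder mock is wired to return an (input, hidden) pair; with
        # attention, a hypothetical third value stands in for the attention
        # weights the real decoder would also return.
        decoder_input_mock = MagicMock()
        decoder_hidden_mock = MagicMock()
        if attention_mechanism:
            decoder_mock().return_value = (decoder_input_mock, decoder_hidden_mock, MagicMock())
        else:
            decoder_mock().return_value = (decoder_input_mock, decoder_hidden_mock)
        return decoder_input_mock, decoder_hidden_mock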