def test_givenLocalWeightsNotLastVersion_whenInstantiatingAFastTextSeq2SeqModel_thenShouldDownloadWeights(
        self, load_state_dict_mock, torch_mock, isfile_mock, last_version_mock):
    """Weights found locally but flagged as not the latest version must be downloaded again.

    ``isfile_mock`` reports that a weights file exists, and ``last_version_mock`` reports it is
    outdated. Building the model is then expected to call ``download_weights`` with the model
    type, the root path and the verbosity flag.
    """
    isfile_mock.return_value = True
    last_version_mock.return_value = False

    with patch("deepparse.network.seq2seq.download_weights") as download_weights_mock:
        # ``verbose`` is passed by keyword: the constructor also takes an ``output_size``
        # argument (used by other tests in this file), apparently as its second positional
        # parameter. A positional ``verbose`` could therefore be bound to the wrong parameter.
        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, verbose=self.verbose)

        download_weights_mock.assert_called_with(self.model_type, self.a_root_path, verbose=self.verbose)
def test_whenForwardStep_thenStepIsOk(self):
    """Running ``forward`` without a target produces predictions of the expected dimensions."""
    device = self.a_torch_device
    self.seq2seq_model = FastTextSeq2SeqModel(device)

    # The input batch holds the same address twice:
    #   ['15 major st london ontario n5z1e1', '15 major st london ontario n5z1e1']
    self.decoder_input_setUp()

    output = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor)

    self.assert_output_is_valid_dim(output)
def test_whenForwardStepWithTarget_thenStepIsOk(self):
    """Running ``forward`` with a target produces predictions sized to ``number_of_tags``."""
    self.seq2seq_model = FastTextSeq2SeqModel(
        self.a_torch_device,
        output_size=self.number_of_tags,
    )

    # The input batch holds the same address twice:
    #   ["15 major st london ontario n5z1e1", "15 major st london ontario n5z1e1"]
    self.decoder_input_setUp()

    model_inputs = (self.to_predict_tensor, self.a_lengths_tensor, self.a_target_vector)
    output = self.seq2seq_model.forward(*model_inputs)

    self.assert_output_is_valid_dim(output, output_dim=self.number_of_tags)
def test_retrainedModel_whenForwardStepWithTarget_thenStepIsOk(self):
    """A model loaded from retrained weights runs ``forward`` with a target and yields valid output dims."""
    # ``verbose`` is passed by keyword: the constructor also takes an ``output_size``
    # argument (used by other tests in this file), apparently as its second positional
    # parameter. A positional ``verbose`` could therefore be bound to the wrong parameter.
    self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device,
                                              verbose=self.verbose,
                                              path_to_retrained_model=self.a_retrain_model)

    # Forward pass for two addresses:
    #   ['15 major st london ontario n5z1e1', '15 major st london ontario n5z1e1']
    self.decoder_input_setUp()

    predictions = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor, self.a_target_vector)

    self.assert_output_is_valid_dim(predictions)
def test_whenInstantiateASeq2SeqModel_thenEncodeIsCalledOnce(self, load_state_dict_mock, torch_mock, isfile_mock,
                                                             last_version_mock, download_weights_mock, encoder_mock):
    """``_encoder_step`` forwards the inputs and lengths to the encoder instance.

    ``call()(...)`` means "the mocked encoder class was instantiated, and the resulting instance was
    called with (to_predict, lengths)".
    """
    # ``verbose`` is passed by keyword: the constructor also takes an ``output_size``
    # argument (used by other tests in this file), apparently as its second positional
    # parameter. A positional ``verbose`` could therefore be bound to the wrong parameter.
    self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, verbose=self.verbose)
    to_predict_mock, lengths_tensor_mock = self.setup_encoder_mocks()

    self.seq2seq_model._encoder_step(to_predict_mock, lengths_tensor_mock, self.a_batch_size)

    # NOTE(review): assert_has_calls only checks that this call is present; it does not enforce
    # the "called once" promised by the test name.
    encoder_call = [call()(to_predict_mock, lengths_tensor_mock)]
    encoder_mock.assert_has_calls(encoder_call)
def test_givenNotLocalWeights_whenInstantiatingAFastTextSeq2SeqModel_thenShouldDownloadWeights(
        self, load_state_dict_mock, torch_mock, isfile_mock):
    """When no weights file exists locally, building the model downloads the weights."""
    # Simulate a missing weights file on disk.
    isfile_mock.return_value = False

    download_target = "deepparse.network.seq2seq.download_weights"
    with patch(download_target) as download_weights_mock:
        FastTextSeq2SeqModel(
            self.a_cpu_device,
            output_size=self.output_size,
            verbose=self.verbose,
        )

        download_weights_mock.assert_called_with(
            self.model_type,
            self.a_root_path,
            verbose=self.verbose,
        )
def test_givenAFasttext2SeqAttModel_whenForwardPassWithTarget_thenProperlyDoPAss(
    self,
    load_state_dict_mock,
    torch_mock,
    isfile_mock,
    last_version_mock,
    download_weights_mock,
    decoder_mock,
    encoder_mock,
    random_mock,
):
    """A full ``forward`` pass with a target on an attention model wires encoder and decoder together.

    Checks that:
    - the encoder instance receives the inputs and lengths;
    - the max length is read via ``lengths_tensor.max().item()``;
    - the decoder instance receives the first decoder input, the encoder hidden state, the
      encoder outputs and the lengths (attention signature);
    - the target is transposed.
    """
    random_mock.return_value = self.a_value_lower_than_threshold

    target_mock = MagicMock()
    to_predict_mock, lengths_tensor_mock = self.setup_encoder_mocks()
    # Only the hidden state from setUp_decoder_mocks is reused. The first decoder input is
    # expected to come from the mocked torch.zeros(...).to(...).new_full(...) chain (to_mock),
    # because this test exercises the full forward loop.
    _, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock, attention_mechanism=True)

    encoder_outputs_mock = MagicMock()
    to_mock = MagicMock()
    # Configure child mocks through ``return_value`` rather than by calling the mocks, so that
    # no spurious calls are recorded on torch_mock / encoder_mock before the model runs.
    torch_mock.zeros.return_value.to.return_value.new_full.return_value = to_mock

    # Mock the encoder output: calling the encoder instance returns (encoder outputs, hidden state).
    encoder_mock.return_value.return_value = (encoder_outputs_mock, decoder_hidden_mock)

    seq2seq_model = FastTextSeq2SeqModel(
        self.a_torch_device,
        output_size=self.output_size,
        verbose=self.verbose,
        attention_mechanism=True,
    )

    seq2seq_model.forward(
        to_predict=to_predict_mock,
        lengths_tensor=lengths_tensor_mock,
        target=target_mock,
    )

    encoder_mock.assert_has_calls([call()(to_predict_mock, lengths_tensor_mock)])
    lengths_tensor_mock.assert_has_calls([call.max().item()])
    decoder_mock.assert_has_calls([
        call()(
            to_mock,
            decoder_hidden_mock,
            encoder_outputs_mock,
            lengths_tensor_mock,
        )
    ])
    target_mock.assert_has_calls([call.transpose(0, 1)])
def test_givenRetrainedWeights_whenInstantiatingAFastTextSeq2SeqModel_thenShouldUseRetrainedWeights(
        self, load_state_dict_mock, torch_mock):
    """A retrained-weights path is loaded with ``torch.load`` and fed to ``load_state_dict``.

    Expects ``torch.load(path, map_location=device)``, followed by ``load_state_dict`` receiving
    exactly the object that ``torch.load`` returned.
    """
    all_layers_params = MagicMock()
    torch_mock.load.return_value = all_layers_params

    # ``verbose`` is passed by keyword: the constructor also takes an ``output_size``
    # argument (used by other tests in this file), apparently as its second positional
    # parameter. A positional ``verbose`` could therefore be bound to the wrong parameter.
    self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device,
                                              verbose=self.verbose,
                                              path_to_retrained_model=self.a_path_to_retrained_model)

    torch_load_call = [call.load(self.a_path_to_retrained_model, map_location=self.a_torch_device)]
    torch_mock.assert_has_calls(torch_load_call)

    load_state_dict_call = [call(all_layers_params)]
    load_state_dict_mock.assert_has_calls(load_state_dict_call)
def test_retrainedModel_whenForwardStep_thenStepIsOk(self):
    """A retrained model run without a target outputs ``re_trained_output_dim`` tags."""
    self.seq2seq_model = FastTextSeq2SeqModel(
        self.a_torch_device,
        path_to_retrained_model=self.a_retrain_model_path,
        verbose=self.verbose,
        output_size=self.re_trained_output_dim,
    )

    # The input batch holds the same address twice:
    #   ["15 major st london ontario n5z1e1", "15 major st london ontario n5z1e1"]
    self.decoder_input_setUp()

    output = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor)

    self.assert_output_is_valid_dim(output, output_dim=self.re_trained_output_dim)
def test_whenInstantiateASeq2SeqModelNoTarget_thenDecoderIsCalled(
    self,
    load_state_dict_mock,
    torch_mock,
    isfile_mock,
    last_version_mock,
    download_weights_mock,
    decoder_mock,
):
    """Without a target, ``_decoder_step`` calls the decoder ``max_length`` times.

    Each call receives ``decoder_input.view(...)`` and the decoder hidden state.
    """
    # ``verbose`` is passed by keyword: the constructor also takes an ``output_size``
    # argument (used by other tests in this file), apparently as its second positional
    # parameter. A positional ``verbose`` could therefore be bound to the wrong parameter.
    self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, verbose=self.verbose)
    decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock)
    max_length = 4  # a sequence of 4 tokens

    self.seq2seq_model._decoder_step(decoder_input_mock, decoder_hidden_mock, self.a_none_target, max_length,
                                     self.a_batch_size)

    # ``.view()`` on a MagicMock returns the same child mock regardless of its arguments.
    decoder_call = [call()(decoder_input_mock.view(), decoder_hidden_mock)] * max_length
    decoder_mock.assert_has_calls(decoder_call)
def test_whenInstantiateASeq2SeqModelWithTarget_thenDecoderIsCalled(
    self,
    load_state_dict_mock,
    torch_mock,
    isfile_mock,
    last_version_mock,
    download_weights_mock,
    decoder_mock,
    random_mock,
):
    """With a target and teacher forcing, each decoder step is fed the matching target token.

    ``random_mock`` is set below the threshold so the target is used. The decoder is then
    expected to receive ``a_transpose_target_vector[idx].view(1, batch, 1)`` at every step.
    """
    # NOTE(review): another method with this exact name appears right after this one. If both
    # live in the same class, the later definition shadows this one and it never runs — TODO
    # rename one of them.
    random_mock.return_value = self.a_value_lower_than_threshold

    seq2seq_model = FastTextSeq2SeqModel(self.a_cpu_device, output_size=self.output_size, verbose=self.verbose)
    decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock, attention_mechanism=False)

    lengths_tensor_mock = MagicMock()
    max_length = 4  # a sequence of 4 tokens
    lengths_tensor_mock.max().item.return_value = max_length
    encoder_outputs = MagicMock()

    # The target vector must be passed here: the expected calls below are built from it, and
    # the random mock only matters when a target is available for teacher forcing.
    seq2seq_model._decoder_step(
        decoder_input_mock,
        decoder_hidden_mock,
        encoder_outputs,
        self.a_target_vector,
        lengths_tensor_mock,
        self.a_batch_size,
    )

    decoder_call = [
        call()(
            self.a_transpose_target_vector[idx].view(1, self.a_batch_size, 1),
            decoder_hidden_mock,
        ) for idx in range(max_length)
    ]

    self.assert_has_calls_tensor_equals(decoder_mock, decoder_call)
def test_whenInstantiateASeq2SeqModelWithTarget_thenDecoderIsCalled(self, load_state_dict_mock, torch_mock,
                                                                    isfile_mock, last_version_mock,
                                                                    download_weights_mock, decoder_mock, random_mock):
    """With a target and teacher forcing, each decoder step is fed the matching target token.

    ``random_mock`` is set below the threshold so the target is used. The decoder is then
    expected to receive ``a_transpose_target_vector[idx].view(1, batch, 1)`` at every step.
    """
    # NOTE(review): a method with this exact name appears immediately before this one. If both
    # live in the same class, this definition shadows the earlier one — TODO rename one of them.
    random_mock.return_value = self.a_value_lower_than_threshold

    # ``verbose`` is passed by keyword: the constructor also takes an ``output_size``
    # argument (used by other tests in this file), apparently as its second positional
    # parameter. A positional ``verbose`` could therefore be bound to the wrong parameter.
    self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, verbose=self.verbose)
    decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock)
    max_length = 4  # a sequence of 4 tokens

    self.seq2seq_model._decoder_step(decoder_input_mock, decoder_hidden_mock, self.a_target_vector, max_length,
                                     self.a_batch_size)

    decoder_call = []
    for idx in range(max_length):
        decoder_call.append(call()(self.a_transpose_target_vector[idx].view(1, self.a_batch_size, 1),
                                   decoder_hidden_mock))

    self.assert_has_calls_tensor_equals(decoder_mock, decoder_call)
def test_whenInstantiateASeq2SeqAttModelNoTarget_thenDecoderIsCalled(
    self,
    load_state_dict_mock,
    torch_mock,
    isfile_mock,
    last_version_mock,
    download_weights_mock,
    decoder_mock,
):
    """Without a target, the attention decoder is called once per token with identical arguments.

    Each call receives ``decoder_input.view(...)``, the hidden state, the encoder outputs and the
    lengths tensor.
    """
    seq2seq_model = FastTextSeq2SeqModel(
        self.a_cpu_device,
        output_size=self.output_size,
        verbose=self.verbose,
        attention_mechanism=True,
    )
    decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock, attention_mechanism=True)

    # Whatever ``decoder_input.view(...)`` is called with, it yields this mock.
    reshaped_input = MagicMock()
    decoder_input_mock.view.return_value = reshaped_input

    lengths_tensor_mock = MagicMock()
    token_count = 4  # the sequence has 4 tokens
    lengths_tensor_mock.max().item.return_value = token_count
    encoder_outputs = MagicMock()

    seq2seq_model._decoder_step(
        decoder_input_mock,
        decoder_hidden_mock,
        encoder_outputs,
        self.a_none_target,
        lengths_tensor_mock,
        self.a_batch_size,
    )

    one_step = call()(reshaped_input, decoder_hidden_mock, encoder_outputs, lengths_tensor_mock)
    expected_calls = [one_step for _ in range(token_count)]
    decoder_mock.assert_has_calls(expected_calls)