Code example #1
 def test_givenLocalWeightsNotLastVersion_whenInstantiatingAFastTextSeq2SeqModel_thenShouldDownloadWeights(
         self, load_state_dict_mock, torch_mock, isfile_mock, last_version_mock):
     """Local weights exist but are outdated, so instantiation must re-download them."""
     last_version_mock.return_value = False
     isfile_mock.return_value = True
     with patch("deepparse.network.seq2seq.download_weights") as weights_download_mock:
         self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, self.verbose)
         weights_download_mock.assert_called_with(self.model_type, self.a_root_path, verbose=self.verbose)
Code example #2
    def test_whenForwardStep_thenStepIsOk(self):
        """A plain forward pass (no target) yields output of the expected dimensions."""
        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device)
        # Forward pass for two addresses:
        # ["15 major st london ontario n5z1e1", "15 major st london ontario n5z1e1"]
        self.decoder_input_setUp()

        preds = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor)

        self.assert_output_is_valid_dim(preds)
    def test_whenForwardStepWithTarget_thenStepIsOk(self):
        """A forward pass with a target yields output sized to the number of tags."""
        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, output_size=self.number_of_tags)
        # Forward pass for two addresses:
        # ["15 major st london ontario n5z1e1", "15 major st london ontario n5z1e1"]
        self.decoder_input_setUp()

        preds = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor, self.a_target_vector)

        self.assert_output_is_valid_dim(preds, output_dim=self.number_of_tags)
Code example #4
    def test_retrainedModel_whenForwardStepWithTarget_thenStepIsOk(self):
        """A retrained model forwarding with a target still produces valid-dimension output."""
        self.seq2seq_model = FastTextSeq2SeqModel(
            self.a_torch_device,
            self.verbose,
            path_to_retrained_model=self.a_retrain_model,
        )
        # Forward pass for two addresses:
        # ["15 major st london ontario n5z1e1", "15 major st london ontario n5z1e1"]
        self.decoder_input_setUp()

        preds = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor, self.a_target_vector)

        self.assert_output_is_valid_dim(preds)
Code example #5
    def test_whenInstantiateASeq2SeqModel_thenEncodeIsCalledOnce(self, load_state_dict_mock, torch_mock, isfile_mock,
                                                                 last_version_mock, download_weights_mock,
                                                                 encoder_mock):
        """_encoder_step must delegate to the encoder with the input tensors."""
        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, self.verbose)

        to_predict_mock, lengths_tensor_mock = self.setup_encoder_mocks()
        self.seq2seq_model._encoder_step(to_predict_mock, lengths_tensor_mock, self.a_batch_size)

        expected_encoder_calls = [call()(to_predict_mock, lengths_tensor_mock)]
        encoder_mock.assert_has_calls(expected_encoder_calls)
Code example #6
 def test_givenNotLocalWeights_whenInstantiatingAFastTextSeq2SeqModel_thenShouldDownloadWeights(
         self, load_state_dict_mock, torch_mock, isfile_mock):
     """Without local weights on disk, instantiation downloads them."""
     isfile_mock.return_value = False
     with patch("deepparse.network.seq2seq.download_weights") as weights_download_mock:
         FastTextSeq2SeqModel(self.a_cpu_device, output_size=self.output_size, verbose=self.verbose)
         weights_download_mock.assert_called_with(self.model_type, self.a_root_path, verbose=self.verbose)
Code example #7
    def test_givenAFasttext2SeqAttModel_whenForwardPassWithTarget_thenProperlyDoPAss(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
        encoder_mock,
        random_mock,
    ):
        """End-to-end forward pass of an attention seq2seq model with a target.

        Verifies the encoder is called with the input tensors, the decoder is
        called with the attention-signature arguments, and the target is
        transposed (teacher forcing path).
        """
        # A random draw below the threshold keeps the teacher-forcing branch active.
        random_mock.return_value = self.a_value_lower_than_threshold

        target_mock = MagicMock()
        to_predict_mock, lengths_tensor_mock = self.setup_encoder_mocks()

        _, decoder_hidden_mock = self.setUp_decoder_mocks(
            decoder_mock, attention_mechanism=True)

        # we don't use the one of the setUp_decoder_mocks since we do the full loop
        decoder_input_mock = MagicMock()
        to_mock = MagicMock()
        # The first decoder input is built via torch.zeros(...).to(...).new_full(...);
        # capture its result so the expected decoder call below can reference it.
        torch_mock.zeros().to().new_full.return_value = to_mock

        # We mock the return of the decoder output
        # NOTE(review): setting __call__().return_value does not affect plain
        # encoder_mock(...) invocations on a MagicMock — confirm this wiring is
        # actually exercised by the forward pass.
        encoder_mock.__call__().return_value = (decoder_input_mock,
                                                decoder_hidden_mock)

        seq2seq_model = FastTextSeq2SeqModel(
            self.a_torch_device,
            self.output_size,
            self.verbose,
            attention_mechanism=True,
        )

        seq2seq_model.forward(
            to_predict=to_predict_mock,
            lengths_tensor=lengths_tensor_mock,
            target=target_mock,
        )

        # The encoder must receive the padded input and its lengths tensor.
        encoder_mock.assert_has_calls(
            [call()(to_predict_mock, lengths_tensor_mock)])
        # The decoding loop length is read from the lengths tensor.
        lengths_tensor_mock.assert_has_calls([call.max().item()])
        # Attention decoder signature: (input, hidden, encoder outputs, lengths).
        decoder_mock.assert_has_calls([
            call()(
                to_mock,
                decoder_hidden_mock,
                decoder_input_mock,
                lengths_tensor_mock,
            )
        ])
        # Teacher forcing transposes the target — presumably to (seq_len, batch);
        # confirm against the model implementation.
        target_mock.assert_has_calls([call.transpose(0, 1)])
Code example #8
    def test_givenRetrainedWeights_whenInstantiatingAFastTextSeq2SeqModel_thenShouldUseRetrainedWeights(
            self, load_state_dict_mock, torch_mock):
        """A path_to_retrained_model must trigger torch.load and load_state_dict with its weights."""
        retrained_params_mock = MagicMock()
        torch_mock.load.return_value = retrained_params_mock

        self.seq2seq_model = FastTextSeq2SeqModel(
            self.a_torch_device,
            self.verbose,
            path_to_retrained_model=self.a_path_to_retrained_model,
        )

        # The weights file must be loaded onto the model's device.
        torch_mock.assert_has_calls([call.load(self.a_path_to_retrained_model, map_location=self.a_torch_device)])
        # The loaded parameters must be applied to the model.
        load_state_dict_mock.assert_has_calls([call(retrained_params_mock)])
    def test_retrainedModel_whenForwardStep_thenStepIsOk(self):
        """A retrained model's forward pass outputs the retrained output dimension."""
        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device,
                                                  output_size=self.re_trained_output_dim,
                                                  verbose=self.verbose,
                                                  path_to_retrained_model=self.a_retrain_model_path)
        # Forward pass for two addresses:
        # ["15 major st london ontario n5z1e1", "15 major st london ontario n5z1e1"]
        self.decoder_input_setUp()

        preds = self.seq2seq_model.forward(self.to_predict_tensor, self.a_lengths_tensor)

        self.assert_output_is_valid_dim(preds, output_dim=self.re_trained_output_dim)
Code example #10
    def test_whenInstantiateASeq2SeqModelNoTarget_thenDecoderIsCalled(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
    ):
        """Without a target, _decoder_step must invoke the decoder once per token."""
        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, self.verbose)

        decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock)
        sequence_length = 4  # a sequence of 4 tokens
        self.seq2seq_model._decoder_step(decoder_input_mock, decoder_hidden_mock, self.a_none_target,
                                         sequence_length, self.a_batch_size)

        # One identical expected decoder call per decoding step.
        expected_decoder_calls = [
            call()(decoder_input_mock.view(), decoder_hidden_mock) for _ in range(sequence_length)
        ]
        decoder_mock.assert_has_calls(expected_decoder_calls)
Code example #11
    def test_whenInstantiateASeq2SeqModelWithTarget_thenDecoderIsCalled(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
        random_mock,
    ):
        """_decoder_step (signature taking encoder outputs and a lengths tensor)
        must call the decoder once per token of the sequence.
        """
        # A draw below the threshold selects the teacher-forcing branch.
        random_mock.return_value = self.a_value_lower_than_threshold

        seq2seq_model = FastTextSeq2SeqModel(self.a_cpu_device,
                                             output_size=self.output_size,
                                             verbose=self.verbose)

        decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(
            decoder_mock, attention_mechanism=False)

        lengths_tensor_mock = MagicMock()
        max_length = 4  # a sequence of 4 tokens
        # The decoding loop length is read via lengths_tensor.max().item().
        lengths_tensor_mock.max().item.return_value = max_length
        encoder_outputs = MagicMock()
        # NOTE(review): the test name says "WithTarget" but a_none_target is
        # passed here, while the expected calls below are built from the
        # transposed target vector — confirm this combination is intentional.
        seq2seq_model._decoder_step(
            decoder_input_mock,
            decoder_hidden_mock,
            encoder_outputs,
            self.a_none_target,
            lengths_tensor_mock,
            self.a_batch_size,
        )

        decoder_call = []

        # One expected decoder call per step, fed with the target token reshaped
        # to (1, batch_size, 1).
        for idx in range(max_length):
            decoder_call.append(call()(
                self.a_transpose_target_vector[idx].view(
                    1, self.a_batch_size, 1),
                decoder_hidden_mock,
            ))

        self.assert_has_calls_tensor_equals(decoder_mock, decoder_call)
Code example #12
    def test_whenInstantiateASeq2SeqModelWithTarget_thenDecoderIsCalled(self, load_state_dict_mock, torch_mock,
                                                                        isfile_mock, last_version_mock,
                                                                        download_weights_mock, decoder_mock,
                                                                        random_mock):
        """With teacher forcing active, the decoder receives each target token in turn."""
        random_mock.return_value = self.a_value_lower_than_threshold

        self.seq2seq_model = FastTextSeq2SeqModel(self.a_torch_device, self.verbose)

        decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(decoder_mock)
        sequence_length = 4  # a sequence of 4 tokens
        self.seq2seq_model._decoder_step(decoder_input_mock, decoder_hidden_mock, self.a_target_vector,
                                         sequence_length, self.a_batch_size)

        # One expected call per step: the target token reshaped to (1, batch_size, 1).
        expected_decoder_calls = [
            call()(self.a_transpose_target_vector[step].view(1, self.a_batch_size, 1), decoder_hidden_mock)
            for step in range(sequence_length)
        ]

        self.assert_has_calls_tensor_equals(decoder_mock, expected_decoder_calls)
Code example #13
    def test_whenInstantiateASeq2SeqAttModelNoTarget_thenDecoderIsCalled(
        self,
        load_state_dict_mock,
        torch_mock,
        isfile_mock,
        last_version_mock,
        download_weights_mock,
        decoder_mock,
    ):
        """With attention and no target, each decode step passes the reshaped
        input, the hidden state, the encoder outputs and the lengths tensor."""
        seq2seq_model = FastTextSeq2SeqModel(
            self.a_cpu_device,
            output_size=self.output_size,
            verbose=self.verbose,
            attention_mechanism=True,
        )

        decoder_input_mock, decoder_hidden_mock = self.setUp_decoder_mocks(
            decoder_mock, attention_mechanism=True)

        # Capture the result of decoder_input.view(...) for the expected calls.
        reshaped_input_mock = MagicMock()
        decoder_input_mock.view.return_value = reshaped_input_mock

        sequence_length = 4  # a sequence of 4 tokens
        lengths_tensor_mock = MagicMock()
        lengths_tensor_mock.max().item.return_value = sequence_length
        encoder_outputs_mock = MagicMock()

        seq2seq_model._decoder_step(
            decoder_input_mock,
            decoder_hidden_mock,
            encoder_outputs_mock,
            self.a_none_target,
            lengths_tensor_mock,
            self.a_batch_size,
        )

        # One identical expected call (attention signature) per decoding step.
        expected_decoder_calls = [
            call()(reshaped_input_mock, decoder_hidden_mock, encoder_outputs_mock, lengths_tensor_mock)
            for _ in range(sequence_length)
        ]
        decoder_mock.assert_has_calls(expected_decoder_calls)