def test_givenABPEmbAddressParser_whenRetrain_thenTrainingOccur(self):
    """Retrain a BPEmb parser and check that ``retrain`` returns a result."""
    bpemb_parser = AddressParser(
        model_type=self.a_bpemb_model_type,
        device=self.a_torch_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
    }

    retrain_result = bpemb_parser.retrain(
        self.training_container, self.a_train_ratio, **retrain_options
    )

    # Only the presence of a result is verified, not its content.
    self.assertIsNotNone(retrain_result)
def test_givenAFasttextLightAddressParser_whenRetrain_thenTrainingDoesNotOccur(self):
    """Check that calling ``retrain`` on a fasttext-light parser raises ``ValueError``."""
    light_parser = AddressParser(
        model_type=self.a_fasttext_light_model_type,
        device=self.a_torch_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
    }

    with self.assertRaises(ValueError):
        light_parser.retrain(
            self.training_container, self.a_train_ratio, **retrain_options
        )
def test_givenAFasttextAddressParser_whenTestWithFasttextCkpt_thenTestOccur(self):
    """Run ``test`` on a fasttext parser with ``checkpoint="fasttext"`` and expect a result."""
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_torch_device,
        verbose=self.verbose,
    )
    test_options = {
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
        "checkpoint": "fasttext",
    }

    test_result = fasttext_parser.test(self.test_container, **test_options)

    self.assertIsNotNone(test_result)
def test_givenAFasttextAddressParser_whenTestMultipleEpochs_thenTestOccurCorrectly(self):
    """Train a fasttext parser through ``self.training`` and then check ``test`` returns a result."""
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_torch_device,
        verbose=self.verbose,
    )

    # A retraining pass is run first so that ``test`` has something to evaluate.
    self.training(fasttext_parser)

    test_options = {
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
    }
    test_result = fasttext_parser.test(self.test_container, **test_options)

    self.assertIsNotNone(test_result)
def test_givenABPEmbAddressParser_whenTestWithStrCkpt_thenTestOccur(self):
    """Train a BPEmb parser, then run ``test`` with ``self.bpemb_local_path`` as the checkpoint."""
    bpemb_parser = AddressParser(
        model_type=self.a_bpemb_model_type,
        device=self.a_torch_device,
        verbose=self.verbose,
    )

    self.training(bpemb_parser)

    test_options = {
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
        "checkpoint": self.bpemb_local_path,
    }
    test_result = bpemb_parser.test(self.test_container, **test_options)

    self.assertIsNotNone(test_result)
def test_givenAFasttextAddressParser_whenTestWithNumWorkerAt0_thenTestOccur(self):
    """Check that ``test`` returns a result when run with zero data-loading workers.

    The parser is first retrained through ``self.training`` with
    ``self.a_zero_number_of_workers``; the same worker count is then used for
    the ``test`` call, which is the scenario named by this test.
    """
    address_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )

    self.training(address_parser, self.training_container, self.a_zero_number_of_workers)

    performance_after_test = address_parser.test(
        self.test_container,
        batch_size=self.a_batch_size,
        # The scenario under test is "num_workers at 0": use the zero-worker
        # fixture here too, otherwise the zero-worker test path is never run.
        num_workers=self.a_zero_number_of_workers,
    )

    self.assertIsNotNone(performance_after_test)
def test_givenAFasttextAddressParser_whenRetrainMultipleEpochs_thenTrainingOccurCorrectly(self):
    """Retrain a fasttext parser with ``self.a_three_epoch`` epochs and expect a result."""
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_torch_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_three_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
    }

    retrain_result = fasttext_parser.retrain(
        self.training_container, self.a_train_ratio, **retrain_options
    )

    self.assertIsNotNone(retrain_result)
def training(
    self,
    address_parser: "AddressParser",
    data_container: "Optional[DatasetContainer]" = None,
    num_workers: "Optional[int]" = None,
    prediction_tags=None,
):
    """Retrain ``address_parser`` with the shared test-fixture settings.

    Args:
        address_parser: Parser whose ``retrain`` method is called.
        data_container: Training data passed to ``retrain``. When ``None``,
            ``self.training_container`` is used, so callers may invoke
            ``self.training(address_parser)`` with the parser only.
        num_workers: Worker count forwarded to ``retrain``. When ``None``,
            ``self.a_number_of_workers`` is used. ``0`` is a real value and
            is forwarded unchanged.
        prediction_tags: Forwarded unchanged to ``retrain``.

    The train ratio, epochs, batch size and logging path always come from the
    fixture attributes on ``self``.
    """
    # ``is None`` (not truthiness) so that an explicit ``num_workers=0`` is kept.
    if data_container is None:
        data_container = self.training_container
    if num_workers is None:
        num_workers = self.a_number_of_workers

    address_parser.retrain(
        data_container,
        self.a_train_ratio,
        epochs=self.a_single_epoch,
        batch_size=self.a_batch_size,
        num_workers=num_workers,
        logging_path=self.a_checkpoints_saving_dir,
        prediction_tags=prediction_tags,
    )
def test_givenABPEmbAddressParser_whenTestWithStrCkpt_thenTestOccur(self):
    """Train a BPEmb parser on CPU, then check that ``test`` returns a result.

    NOTE(review): the name mentions a string checkpoint, but no ``checkpoint``
    argument is passed to ``test`` in this body.
    """
    bpemb_parser = AddressParser(
        model_type=self.a_bpemb_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )

    worker_count = self.a_number_of_workers
    self.training(bpemb_parser, self.training_container, worker_count)

    test_result = bpemb_parser.test(
        self.test_container,
        batch_size=self.a_batch_size,
        num_workers=self.a_number_of_workers,
    )

    self.assertIsNotNone(test_result)
def test_givenAddressParser_whenRetrainNewTagsNoEOS_thenTrainingDoesNotOccur(self):
    """Check that retraining with the tag mapping ``{"ATag": 0}`` raises ``ValueError``.

    NOTE(review): the parser is built with the fasttext-light model type;
    confirm that the ``ValueError`` comes from the tag mapping and not from
    the model type itself.
    """
    light_parser = AddressParser(
        model_type=self.a_fasttext_light_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
        "prediction_tags": {"ATag": 0},
    }

    with self.assertRaises(ValueError):
        light_parser.retrain(
            self.new_prediction_data_container,
            self.a_train_ratio,
            **retrain_options,
        )
def test_givenDecoderBPEmbToFreeze_thenFreezeEmbeddingsLayer(self):
    """Retrain a BPEmb parser with ``layers_to_freeze="decoder"`` and check the embedding network is frozen."""
    bpemb_parser = AddressParser(
        model_type=self.a_bpemb_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
        "layers_to_freeze": "decoder",
    }

    bpemb_parser.retrain(
        self.training_container, self.a_train_ratio, **retrain_options
    )

    embedding_layer = bpemb_parser.model.embedding_network
    self.assert_layer_frozen(embedding_layer)
def test_givenAFasttextMagnitudeModel_whenRetrain_thenRaiseError(
    self,
    download_weights_mock,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    mock_model,
    experiment_mock,
):
    """Check that the retrain helper raises ``ValueError`` for a fasttext-light parser.

    The mock parameters are not used in the body; they are presumably injected
    by patch decorators that are outside this view.
    """
    parser_options = {
        "model_type": self.a_fasttext_light_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    with self.assertRaises(ValueError):
        self.address_parser_retrain_call()
def test_integration_parsing_with_retrain_bpemb(self):
    """Load a retrained BPEmb model and check the parser reports ``"bpemb"`` as its model type."""
    expected_model_type = "bpemb"

    retrained_parser = AddressParser(
        model_type=expected_model_type,
        path_to_retrained_model=self.path_to_retrain_bpemb,
    )

    self.assertEqual(expected_model_type, retrained_parser.model_type)
def test_givenAFasttextAddressParser_whenRetrainNewTags_thenTrainingOccur(self):
    """Retrain a fasttext parser with ``self.with_new_prediction_tags`` and expect a result."""
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
        "prediction_tags": self.with_new_prediction_tags,
    }

    retrain_result = fasttext_parser.retrain(
        self.new_prediction_data_container,
        self.a_train_ratio,
        **retrain_options,
    )

    self.assertIsNotNone(retrain_result)
def test_givenAFasttextAddressParser_whenTestMultipleEpochs_thenTestOccurCorrectly(self):
    """Train a fasttext parser on CPU through ``self.training``, then check ``test`` returns a result."""
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )

    worker_count = self.a_number_of_workers
    self.training(fasttext_parser, self.training_container, worker_count)

    test_result = fasttext_parser.test(
        self.test_container,
        batch_size=self.a_batch_size,
        num_workers=self.a_number_of_workers,
    )

    self.assertIsNotNone(test_result)
def test_givenSeq2SeqToFreeze_thenFreezeLayer(self):
    """Retrain with ``layers_to_freeze="seq2seq"`` and check which layers end up frozen.

    The encoder and decoder LSTMs are expected to be frozen, while the
    decoder's linear layer is expected to stay trainable.
    """
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "logging_path": self.a_checkpoints_saving_dir,
        "layers_to_freeze": "seq2seq",
    }

    fasttext_parser.retrain(
        self.training_container, self.a_train_ratio, **retrain_options
    )

    self.assert_layer_frozen(fasttext_parser.model.encoder.lstm)
    self.assert_layer_frozen(fasttext_parser.model.decoder.lstm)
    self.assert_layer_not_frozen(fasttext_parser.model.decoder.linear)
def test_givenAFasttextAddressParser_whenRetrainWithConfig_thenTrainingOccur(self):
    """Retrain a fasttext parser with an explicit learning rate and expect a result."""
    fasttext_parser = AddressParser(
        model_type=self.a_fasttext_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )
    retrain_options = {
        "epochs": self.a_single_epoch,
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
        "learning_rate": self.a_learning_rate,
        "logging_path": self.a_checkpoints_saving_dir,
    }

    retrain_result = fasttext_parser.retrain(
        self.training_container, self.a_train_ratio, **retrain_options
    )

    self.assertIsNotNone(retrain_result)
def test_givenABPEmbModel_whenTest_thenInstantiateExperimentProperly(
    self,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
    dataloader_mock,
):
    """Run the test helper on a BPEmb parser and verify the experiment via ``assert_experiment_test``.

    Only ``experiment_mock`` and ``model_mock`` are used in the body; the other
    mock parameters are presumably injected by patch decorators outside this view.
    """
    parser_options = {
        "model_type": self.a_bpemb_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    self.address_parser_test_call()

    self.assert_experiment_test(experiment_mock, model_mock)
def test_givenABPEmbModel_whenTestVerboseTrue_thenInstantiateWithVerbose(
    self,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
    data_loader_mock,
):
    """Build a BPEmb parser with ``verbose=True`` and check the test call is made with that verbosity."""
    is_verbose = True
    parser_options = {
        "model_type": self.a_bpemb_model_type,
        "device": self.a_device,
        "verbose": is_verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    self.address_parser_test_call()

    self.assert_experiment_test_method_is_call(
        data_loader_mock,
        experiment_mock,
        verbose=is_verbose,
    )
def test_givenABPEmbModel_whenRetrain_thenInstantiateDataLoaderAndTrainProperly(
    self,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
    dataloader_mock,
):
    """Run the retrain helper on a BPEmb parser and verify the train call via the shared assertion."""
    parser_options = {
        "model_type": self.a_bpemb_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    self.address_parser_retrain_call()

    self.assert_experiment_train_method_is_call(dataloader_mock, experiment_mock)
def test_givenABPEmbModel_whenRetrain_thenRaiseError(
    self,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
):
    """Run the retrain helper on a BPEmb parser and check how the optimizer was instantiated.

    NOTE(review): despite the "thenRaiseError" suffix, no exception is expected
    in this body; it asserts that the optimizer is called with the model
    parameters and ``self.a_learning_rate``.
    """
    parser_options = {
        "model_type": self.a_bpemb_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    self.address_parser_retrain_call()

    expected_parameters = model_mock().parameters()
    optimizer_mock.assert_called_with(expected_parameters, self.a_learning_rate)
def test_givenAFasttextModel_whenRetrainWithNewParamsAndNewTags_thenSaveNewParamsDictAndParams(
    self,
    download_weights_mock,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
    data_loader_mock,
    torch_save_mock,
):
    """Retrain with new tags and seq2seq params, then check what is handed to the save mock.

    The expected saved dictionary holds the network state dict, the model type,
    the seq2seq params and the prediction tags; the expected path is
    ``self.saving_template_path`` formatted with the model type.
    """
    fasttext_type = self.a_fasttext_model_type
    self.address_parser = AddressParser(
        model_type=fasttext_type,
        device=self.a_device,
        verbose=self.verbose,
    )

    self.address_parser_retrain_call(
        prediction_tags=self.address_components,
        seq2seq_params=self.seq2seq_params,
    )

    expected_path = self.saving_template_path.format(fasttext_type)
    expected_payload = {
        "address_tagger_model": experiment_mock().model.network.state_dict(),
        "model_type": fasttext_type,
        "seq2seq_params": self.seq2seq_params,
        "prediction_tags": self.address_components,
    }

    torch_save_mock.assert_has_calls([call(expected_payload, expected_path)])
def test_givenABPEmbAddressParser_whenTestWithBPEmbCkptNewTags_thenTestOccur(self):
    """Train a BPEmb parser with new prediction tags, then check ``test`` returns a result.

    The same ``self.new_prediction_data_container`` is used for both the
    training step and the test step.
    """
    bpemb_parser = AddressParser(
        model_type=self.a_bpemb_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )

    new_tags = self.with_new_prediction_tags
    self.training(
        bpemb_parser,
        self.new_prediction_data_container,
        self.a_number_of_workers,
        prediction_tags=new_tags,
    )

    test_options = {
        "batch_size": self.a_batch_size,
        "num_workers": self.a_number_of_workers,
    }
    test_result = bpemb_parser.test(
        self.new_prediction_data_container, **test_options
    )

    self.assertIsNotNone(test_result)
def test_givenAFasttextModel_whenRetrain_thenInstantiateOptimizer(
    self,
    download_weights_mock,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
):
    """Run the retrain helper on a fasttext parser and check the optimizer call arguments."""
    parser_options = {
        "model_type": self.a_fasttext_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    self.address_parser_retrain_call()

    # The optimizer must receive the model parameters and the fixture learning rate.
    expected_parameters = model_mock().parameters()
    optimizer_mock.assert_called_with(expected_parameters, self.a_learning_rate)
def test_givenNotTrainingDataContainer_thenRaiseValueError(
    self,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_patch,
):
    """Check that ``test`` raises ``ValueError`` for a container built with ``is_training_container=False``."""
    parser_options = {
        "model_type": self.a_bpemb_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    non_training_container = ADataContainer(is_training_container=False)

    with self.assertRaises(ValueError):
        # The batch size is passed positionally, as the second argument.
        self.address_parser.test(
            non_training_container,
            self.a_batch_size,
            num_workers=self.a_number_of_workers,
            callbacks=self.a_callbacks_list,
            seed=self.a_seed,
        )
def test_givenWrongFreezeLayersName_thenRaiseValueError(
    self,
    download_weights_mock,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
    data_loader_mock,
    torch_save_mock,
):
    """Check that the retrain helper raises ``ValueError`` for ``layers_to_freeze="error_in_name"``."""
    parser_options = {
        "model_type": self.a_fasttext_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    with self.assertRaises(ValueError):
        self.address_parser_retrain_call(layers_to_freeze="error_in_name")
def main(args):
    """Retrain an address parser on the training dataset, then test it on the test dataset.

    Reads ``model_type``, ``train_dataset_path``, ``test_dataset_path``,
    ``epochs``, ``batch_size`` and ``learning_rate`` from ``args``. The parser
    is created on device ``0`` and the test step uses the ``"best"`` checkpoint.
    """
    parser = AddressParser(model_type=args.model_type, device=0)

    # NOTE(review): the "chekpoints" spelling is a runtime path and is kept as-is.
    checkpoints_dir = f"./chekpoints/{args.model_type}"

    training_data = PickleDatasetContainer(args.train_dataset_path)
    scheduler = StepLR(step_size=20)
    parser.retrain(
        training_data,
        0.8,
        epochs=args.epochs,
        batch_size=args.batch_size,
        num_workers=6,
        learning_rate=args.learning_rate,
        callbacks=[scheduler],
        logging_path=checkpoints_dir,
    )

    testing_data = PickleDatasetContainer(args.test_dataset_path)
    parser.test(
        testing_data,
        batch_size=args.batch_size,
        num_workers=4,
        logging_path=checkpoints_dir,
        checkpoint="best",
    )
def main(args):
    """Retrain and/or test an address parser depending on ``args.mode``.

    ``"train"`` runs only the retraining step, ``"test"`` only the test step
    and ``"both"`` runs them in sequence. In ``"test"`` mode the checkpoint is
    whatever ``handle_pre_trained_checkpoint`` returns for the model type;
    otherwise the ``"best"`` checkpoint is used. Any other mode only creates
    the parser.
    """
    parser = AddressParser(model_type=args.model_type, device=0)

    # NOTE(review): the "chekpoints" spelling is a runtime path and is kept as-is.
    checkpoints_dir = f"./chekpoints/{args.model_type}"
    mode = args.mode

    if mode in ("train", "both"):
        training_data = PickleDatasetContainer(args.train_dataset_path)
        scheduler = StepLR(step_size=20)
        parser.retrain(
            training_data,
            0.8,
            epochs=100,
            batch_size=1024,
            num_workers=6,
            learning_rate=0.001,
            callbacks=[scheduler],
            logging_path=checkpoints_dir,
        )

    if mode in ("test", "both"):
        testing_data = PickleDatasetContainer(args.test_dataset_path)
        selected_checkpoint = (
            handle_pre_trained_checkpoint(args.model_type) if mode == "test" else "best"
        )
        parser.test(
            testing_data,
            batch_size=2048,
            num_workers=4,
            logging_path=checkpoints_dir,
            checkpoint=selected_checkpoint,
        )
def test_givenABPEmbAddressParser_whenTestWithConfigWithCallbacksNewTags_thenCallbackAreUse(self):
    """Train a BPEmb parser with new tags, run ``test`` with a mock callback and check its hooks.

    The callback must receive ``on_test_begin({})`` and an ``on_test_end`` call
    carrying the time, test loss and test accuracy reported by ``test``.
    """
    bpemb_parser = AddressParser(
        model_type=self.a_bpemb_model_type,
        device=self.a_cpu_device,
        verbose=self.verbose,
    )

    self.training(
        bpemb_parser,
        self.new_prediction_data_container,
        self.a_number_of_workers,
        prediction_tags=self.with_new_prediction_tags,
    )

    spy_callback = MagicMock()
    test_result = bpemb_parser.test(
        self.new_prediction_data_container,
        batch_size=self.a_batch_size,
        num_workers=self.a_number_of_workers,
        callbacks=[spy_callback],
    )

    self.assertIsNotNone(test_result)

    spy_callback.assert_has_calls([call.on_test_begin({})])

    expected_end_logs = {
        "time": ANY,
        "test_loss": test_result["test_loss"],
        "test_accuracy": test_result["test_accuracy"],
    }
    spy_callback.assert_has_calls([call.on_test_end(expected_end_logs)])

    # Calls to child methods (on_test_begin / on_test_end) do not mark the
    # parent mock as called, so this checks the callback object itself was
    # never invoked directly.
    spy_callback.assert_not_called()
def test_givenAFasttextModel_whenTest_thenInstantiateDataLoaderAndTestProperly(
    self,
    download_weights_mock,
    embeddings_model_mock,
    vectorizer_model_mock,
    data_padding_mock,
    model_mock,
    data_transform_mock,
    optimizer_mock,
    experiment_mock,
    dataloader_mock,
):
    """Run the test helper on a fasttext parser and verify the test call via the shared assertion."""
    parser_options = {
        "model_type": self.a_fasttext_model_type,
        "device": self.a_device,
        "verbose": self.verbose,
    }
    self.address_parser = AddressParser(**parser_options)

    self.address_parser_test_call()

    self.assert_experiment_test_method_is_call(dataloader_mock, experiment_mock)