def test_train_model_can_instantiate_from_params(self):
    params = Params.from_file(self.FIXTURES_ROOT / "simple_tagger" / "experiment.json")

    # Can instantiate from base class params
    TrainModel.from_params(
        params=params, serialization_dir=self.TEST_DIR, local_rank=0, batch_weight_key=""
    )
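# For context, a hypothetical illustration (not the actual fixture contents) of the
# kind of config that experiment.json provides, expressed as a Params dict. The field
# values below are assumptions; only the overall shape matters for the test above.
# Params is already imported at the top of this module (it is used above).
_example_simple_tagger_params = Params(
    {
        "dataset_reader": {"type": "sequence_tagging"},
        "train_data_path": "test_fixtures/data/sequence_tagging.tsv",
        "model": {
            "type": "simple_tagger",
            "text_field_embedder": {
                "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}}
            },
            "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7},
        },
        "data_loader": {"batch_size": 2},
        "trainer": {"num_epochs": 1, "optimizer": "adam", "cuda_device": -1},
    }
)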
def test_pretrained_configs(self, path):
    params = Params.from_file(
        path,
        params_overrides="{"
        "'trainer.cuda_device': -1, "
        "'trainer.use_amp': false, "
        "'trainer.num_epochs': 2, "
        "}",
    )

    # Patch max_instances in the multitask case
    patch_dataset_reader(params["dataset_reader"])
    if "validation_dataset_reader" in params:
        # Unclear why this doesn't work for biattentive_classification_network
        if "biattentive_classification_network" not in path:
            patch_dataset_reader(params["validation_dataset_reader"])

    # Patch any pretrained glove files with smaller fixtures.
    patch_glove(params)

    # Patch image_dir and feature_cache_dir keys so they point at our test fixtures instead.
    patch_image_dir(params)

    # Remove unnecessary keys.
    for key in ("random_seed", "numpy_seed", "pytorch_seed", "distributed"):
        if key in params:
            del params[key]

    # Just make sure the train loop can be instantiated.
    TrainModel.from_params(params=params, serialization_dir=self.TEST_DIR, local_rank=0)
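# The patch_* helpers are defined elsewhere in this module and are not reproduced
# here. A minimal sketch of what patch_dataset_reader might look like, assuming it
# only caps max_instances and recurses into the per-task readers of a "multitask"
# reader (the real helper may differ):
def _patch_dataset_reader_sketch(reader_params, max_instances: int = 4) -> None:
    if reader_params["type"] == "multitask":
        # Each task has its own reader nested under "readers".
        for task_reader in reader_params["readers"].values():
            task_reader["max_instances"] = max_instances
    else:
        reader_params["max_instances"] = max_instances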
def test_pretrained_configs(self, path):
    params = Params.from_file(
        path,
        params_overrides="{"
        "'trainer.cuda_device': -1, "
        "'trainer.use_amp': false, "
        "'trainer.num_epochs': 2, "
        "'dataset_reader.max_instances': 4, "
        "'dataset_reader.lazy': false, "
        "}",
    )

    # Patch any pretrained glove files with smaller fixtures.
    patch_glove(params)

    # Remove unnecessary keys.
    for key in ("random_seed", "numpy_seed", "pytorch_seed", "distributed"):
        if key in params:
            del params[key]

    # Just make sure the train loop can be instantiated.
    TrainModel.from_params(params=params, serialization_dir=self.TEST_DIR, local_rank=0)
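# A rough sketch of the glove-patching idea: walk the nested config and swap any
# pretrained glove path for a small fixture file. This version operates on a plain
# dict for simplicity; the real patch_glove helper and the fixture path below are
# not shown here and may differ.
_GLOVE_FIXTURE = "test_fixtures/embeddings/glove.sample.txt.gz"  # hypothetical path

def _patch_glove_sketch(config: dict) -> None:
    for key, value in config.items():
        if key == "pretrained_file" and isinstance(value, str) and "glove" in value:
            config[key] = _GLOVE_FIXTURE
        elif isinstance(value, dict):
            _patch_glove_sketch(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    _patch_glove_sketch(item)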
def test_train_can_fine_tune_model_from_archive(self):
    params = Params.from_file(
        self.FIXTURES_ROOT / "basic_classifier" / "experiment_from_archive.jsonnet"
    )
    train_loop = TrainModel.from_params(
        params=params, serialization_dir=self.TEST_DIR, local_rank=0, batch_weight_key=""
    )
    train_loop.run()

    model = Model.from_archive(
        self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    )

    # This is checking that the vocabulary actually got extended. The data that we're using for
    # training is different from the data we used to produce the model archive, and we set
    # parameters such that the vocab should have been extended.
    assert train_loop.model.vocab.get_vocab_size() > model.vocab.get_vocab_size()
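# For reference, fine-tuning from an archive is driven entirely by the config: the
# model is loaded via the "from_archive" model type and the vocabulary is extended
# from the new training data. A hypothetical sketch of the relevant pieces; the
# actual experiment_from_archive.jsonnet fixture may differ.
_example_fine_tune_params = Params(
    {
        "dataset_reader": {"type": "text_classification_json"},
        "train_data_path": "test_fixtures/data/text_classification_json/other_corpus.jsonl",
        "vocabulary": {
            "type": "extend",
            "directory": "test_fixtures/basic_classifier/serialization/vocabulary",
        },
        "model": {
            "type": "from_archive",
            "archive_file": "test_fixtures/basic_classifier/serialization/model.tar.gz",
        },
        "data_loader": {"batch_size": 2},
        "trainer": {"num_epochs": 1, "optimizer": "adam", "cuda_device": -1},
    }
)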
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--action", type=Action, choices=list(Action), required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--serialization-dir", required=True)
    parser.add_argument("--batch-count", type=int, default=0)
    parser.add_argument("--assume-multiprocess-types", action="store_true")
    args = parser.parse_args()

    params = Params.from_file(args.config)
    # serialization_dir must be passed by keyword: the positional parameters of
    # from_params after `params` are not the constructor arguments.
    train_model = TrainModel.from_params(
        params=params,
        serialization_dir=args.serialization_dir,
        local_rank=0,
        batch_weight_key="",
    )
    trainer = train_model.trainer
    raw_generator = trainer.iterator(trainer.train_data, num_epochs=1, shuffle=True)

    if args.action is Action.log:
        log_iterable(raw_generator, args.assume_multiprocess_types)
    elif args.action is Action.time:
        time_iterable(raw_generator, args.batch_count)
    elif args.action is Action.first:
        time_to_first(raw_generator)
    else:
        raise Exception(f"Unaccounted for action {args.action}")
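# The Action enum and the log_iterable / time_iterable / time_to_first helpers are
# defined earlier in this script and are not reproduced here. A rough sketch of the
# shape they likely take (everything beyond the names referenced above is an
# assumption), kept in comments so it does not shadow the real definitions:
#
#     class Action(Enum):
#         log = "log"
#         time = "time"
#         first = "first"
#
#         def __str__(self):
#             return self.value
#
#     def time_to_first(iterable) -> None:
#         """Time how long the data pipeline takes to yield its first batch."""
#         start = time.perf_counter()
#         next(iter(iterable))
#         print(f"seconds to first batch: {time.perf_counter() - start:.3f}")
#
# Example invocation (paths are placeholders):
#
#     python benchmark_iter.py --action time --config my_experiment.jsonnet \
#         --serialization-dir /tmp/benchmark-run --batch-count 100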