def test_trainer_can_run_cuda(self):
    """Smoke test: train for two epochs with ``cuda_device=0``."""
    gpu_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        num_epochs=2,
        cuda_device=0,
    )
    gpu_trainer.train()
def test_trainer_can_run_multiple_gpu(self):
    """Smoke test: train for two epochs with ``cuda_device=[0, 1]``."""
    # A dedicated iterator with batch_size=4 is built and indexed with the test vocab.
    batch_iterator = BasicIterator(batch_size=4)
    batch_iterator.index_with(self.vocab)

    Trainer(
        self.model,
        self.optimizer,
        batch_iterator,
        self.instances,
        num_epochs=2,
        cuda_device=[0, 1],
    ).train()
def test_should_stop_early_with_increasing_metric(self):
    """Check ``_should_stop_early`` for a "+test" validation metric with patience 5."""
    trainer_under_test = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        validation_dataset=self.instances,
        num_epochs=3,
        serialization_dir=self.TEST_DIR,
        patience=5,
        validation_metric="+test",
    )
    should_stop = trainer_under_test._should_stop_early  # pylint: disable=protected-access

    # The largest value (.5) is the oldest entry: stopping is expected.
    history_that_stops = [.5, .3, .2, .1, .4, .4]
    # The largest value (.5) is the second most recent entry: no stopping expected.
    history_that_continues = [.3, .3, .3, .2, .5, .1]

    assert should_stop(history_that_stops)
    assert not should_stop(history_that_continues)
def test_trainer_can_log_histograms(self):
    """Smoke test: training runs with ``histogram_interval=2`` and activation logging on."""
    # enable activation logging on every module of the model
    for submodule in self.model.modules():
        submodule.should_log_activations = True

    Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        num_epochs=3,
        serialization_dir=self.TEST_DIR,
        histogram_interval=2,
    ).train()
def test_trainer_raises_on_model_with_no_loss_key(self):
    """A model whose ``forward`` returns an empty dict must lead to a RuntimeError."""

    class FakeModel(torch.nn.Module):
        """Minimal module whose output dict carries no entries at all."""

        def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
            return {}

    # Both construction and training sit inside the ``raises`` context.
    with pytest.raises(RuntimeError):
        Trainer(
            FakeModel(),
            self.optimizer,
            self.iterator,
            self.instances,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
        ).train()
def test_trainer_can_run_with_lr_scheduler(self):
    """Smoke test: train for two epochs with a ``reduce_on_plateau`` scheduler."""
    scheduler_params = Params({"type": "reduce_on_plateau"})
    scheduler = LearningRateScheduler.from_params(self.optimizer, scheduler_params)

    Trainer(
        model=self.model,
        optimizer=self.optimizer,
        iterator=self.iterator,
        learning_rate_scheduler=scheduler,
        validation_metric="-loss",
        train_dataset=self.instances,
        validation_dataset=self.instances,
        num_epochs=2,
    ).train()
def test_can_optimise_model_with_dense_and_sparse_params(self):
    """Smoke test: training runs with an optimizer of type ``dense_sparse_adam``."""
    adam_params = Params({"type": "dense_sparse_adam"})

    # Collect [name, parameter] pairs for every parameter that requires a gradient.
    trainable = []
    for param_name, param in self.model.named_parameters():
        if param.requires_grad:
            trainable.append([param_name, param])

    dense_sparse_optimizer = Optimizer.from_params(trainable, adam_params)

    small_batch_iterator = BasicIterator(2)
    small_batch_iterator.index_with(self.vocab)

    trainer = Trainer(self.model, dense_sparse_optimizer, small_batch_iterator, self.instances)
    trainer.train()
def test_trainer_respects_num_serialized_models_to_keep(self):
    """After 5 epochs with ``num_serialized_models_to_keep=3`` only epochs 2-4 remain."""
    Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        num_epochs=5,
        serialization_dir=self.TEST_DIR,
        num_serialized_models_to_keep=3,
    ).train()

    # Now check the serialized files
    epoch_pattern = re.compile(r"_([0-9])\.th")
    for prefix in ('model_state_epoch_*', 'training_state_epoch_*'):
        kept_epochs = []
        for path in glob.glob(os.path.join(self.TEST_DIR, prefix)):
            # The single digit before ".th" is the epoch number in the file name.
            kept_epochs.append(int(epoch_pattern.search(path).group(1)))
        assert sorted(kept_epochs) == [2, 3, 4]
def test_trainer_can_run(self):
    """Train twice and check the keys and value types of the returned metrics.

    The first run passes no ``validation_metric``; the second run passes
    ``'+loss'``.
    """
    expected_metric_types = (
        ('best_validation_loss', float),
        ('best_validation_accuracy', float),
        ('best_validation_accuracy3', float),
        ('best_epoch', int),
    )

    def check_metrics(metrics):
        # Each expected key must be present and hold a value of the expected type.
        for key, expected_type in expected_metric_types:
            assert key in metrics
            assert isinstance(metrics[key], expected_type)

    default_trainer = Trainer(
        model=self.model,
        optimizer=self.optimizer,
        iterator=self.iterator,
        train_dataset=self.instances,
        validation_dataset=self.instances,
        num_epochs=2,
    )
    check_metrics(default_trainer.train())

    # Making sure that both increasing and decreasing validation metrics work.
    increasing_metric_trainer = Trainer(
        model=self.model,
        optimizer=self.optimizer,
        iterator=self.iterator,
        train_dataset=self.instances,
        validation_dataset=self.instances,
        validation_metric='+loss',
        num_epochs=2,
    )
    check_metrics(increasing_metric_trainer.train())
def test_should_stop_early_with_flat_lining_metric(self):
    """A constant metric history must trigger early stopping for "+test" and "-test"."""
    flatline = [.2] * 6

    # Same check for an increasing ("+test") and a decreasing ("-test") metric.
    for metric_spec in ("+test", "-test"):
        flat_trainer = Trainer(
            self.model,
            self.optimizer,
            self.iterator,
            self.instances,
            validation_dataset=self.instances,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
            patience=5,
            validation_metric=metric_spec,
        )
        assert flat_trainer._should_stop_early(flatline)  # pylint: disable=protected-access
def test_should_stop_early_with_invalid_patience(self):
    """Constructing a ``Trainer`` with an invalid ``patience`` must raise ConfigurationError.

    The values tried are zero, negative integers, a non-integer float and the
    string ``'None'``.
    """
    for patience in [0, -1, -2, 1.5, 'None']:
        with pytest.raises(ConfigurationError):
            Trainer(self.model,
                    self.optimizer,
                    self.iterator,
                    self.instances,
                    validation_dataset=self.instances,
                    num_epochs=100,
                    patience=patience,
                    validation_metric="+test")
            # Only reached when the constructor did not raise. ``pytest.fail`` is the
            # documented replacement for the ``message=`` keyword of ``pytest.raises``,
            # which was deprecated in pytest 4.1 and removed in pytest 5.0. Its outcome
            # exception is not a ConfigurationError, so it passes through the
            # surrounding ``raises`` context and fails the test with this message.
            pytest.fail('No ConfigurationError for patience={}'.format(patience))
def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
    """Check which checkpoints survive when time-based checkpoint retention is enabled."""
    # To test:
    #   Create an iterator that sleeps for 2.5 seconds per epoch, so the total training
    #       time for one epoch is slightly greater than 2.5 seconds.
    #   Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds.
    #   Check the resulting checkpoints.  Should then have models at epochs
    #       2, 4, plus the last two at 5 and 6.

    class WaitingIterator(BasicIterator):
        """Iterator that sleeps 2.5 seconds each time batches are created."""

        # pylint: disable=arguments-differ
        def _create_batches(self, *args, **kwargs):
            time.sleep(2.5)
            return super()._create_batches(*args, **kwargs)

    slow_iterator = WaitingIterator(batch_size=2)
    slow_iterator.index_with(self.vocab)

    Trainer(
        self.model,
        self.optimizer,
        slow_iterator,
        self.instances,
        num_epochs=6,
        serialization_dir=self.TEST_DIR,
        num_serialized_models_to_keep=2,
        keep_serialized_model_every_num_seconds=5,
    ).train()

    # Now check the serialized files
    epoch_pattern = re.compile(r"_([0-9])\.th")
    for prefix in ('model_state_epoch_*', 'training_state_epoch_*'):
        kept_epochs = []
        for path in glob.glob(os.path.join(self.TEST_DIR, prefix)):
            kept_epochs.append(int(epoch_pattern.search(path).group(1)))
        # epoch N has N-1 in file name
        assert sorted(kept_epochs) == [1, 3, 4, 5]
def test_trainer_saves_models_at_specified_interval(self):
    """Check checkpoints written via ``model_save_interval`` and restoring from them."""
    batch_iterator = BasicIterator(batch_size=4)
    batch_iterator.index_with(self.vocab)

    Trainer(
        self.model,
        self.optimizer,
        batch_iterator,
        self.instances,
        num_epochs=2,
        serialization_dir=self.TEST_DIR,
        model_save_interval=0.0001,
    ).train()

    # Now check the serialized files for models saved during the epoch.
    checkpoint_paths = sorted(glob.glob(os.path.join(self.TEST_DIR, 'model_state_epoch_*')))
    checkpoint_ids = []
    for path in checkpoint_paths:
        checkpoint_ids.append(re.search(r"_([0-9\.\-]+)\.th", path).group(1))

    # We should have checkpoints at the end of each epoch and during each, e.g.
    # [0.timestamp, 0, 1.timestamp, 1]
    assert len(checkpoint_ids) == 4
    assert checkpoint_ids[3] == '1'
    assert '.' in checkpoint_ids[0]

    # Now make certain we can restore from timestamped checkpoint.
    # To do so, remove the end-of-epoch checkpoints (file names 0 and 1), so
    # that we are forced to restore from the timestamped checkpoints.
    for k in range(2):
        for template in ('model_state_epoch_{}.th', 'training_state_epoch_{}.th'):
            os.remove(os.path.join(self.TEST_DIR, template.format(k)))
    os.remove(os.path.join(self.TEST_DIR, 'best.th'))

    restore_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        num_epochs=2,
        serialization_dir=self.TEST_DIR,
        model_save_interval=0.0001,
    )
    restored_epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
    assert restored_epoch == 2
    # One batch per epoch.
    assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
def test_trainer_can_resume_training(self):
    """Train for one epoch, then restore from the same directory and keep training."""
    first_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        validation_dataset=self.instances,
        num_epochs=1,
        serialization_dir=self.TEST_DIR,
    )
    first_trainer.train()

    # A second trainer on the same serialization directory, with more epochs allowed.
    resumed_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        validation_dataset=self.instances,
        num_epochs=3,
        serialization_dir=self.TEST_DIR,
    )

    restored_epoch, restored_metrics = resumed_trainer._restore_checkpoint()  # pylint: disable=protected-access
    assert restored_epoch == 1
    assert len(restored_metrics) == 1
    first_metric = restored_metrics[0]
    assert isinstance(first_metric, float)
    assert first_metric != 0.

    resumed_trainer.train()
def test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_decreasing_metric(self):
    """Check ``_is_best_so_far`` for a "-test" validation metric."""
    decreasing_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        validation_dataset=self.instances,
        num_epochs=3,
        serialization_dir=self.TEST_DIR,
        patience=5,
        validation_metric="-test",
    )
    is_best = decreasing_trainer._is_best_so_far  # pylint: disable=protected-access

    # Each call receives its own fresh list built from this tuple.
    earlier_values = (.3, .3, .3, .2, .5, .1)

    # when it is the only metric it should be considered the best
    assert is_best(1, [])
    # when it is the same as one before it it is not considered the best
    assert not is_best(.3, list(earlier_values))
    # when it is the best it is considered the best
    assert is_best(.013, list(earlier_values))
    # when it is not the best it is not considered the best
    assert not is_best(13.00, list(earlier_values))
def test_should_stop_early_with_early_stopping_disabled(self):
    """With ``patience=None`` a steadily worsening history must not stop training."""
    # Increasing metric, fed a history that falls from 19.0 to 0.0.
    increasing_metric_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        validation_dataset=self.instances,
        num_epochs=100,
        patience=None,
        validation_metric="+test",
    )
    falling_history = list(map(float, range(19, -1, -1)))
    assert not increasing_metric_trainer._should_stop_early(falling_history)  # pylint: disable=protected-access

    # Decreasing metric, fed a history that rises from 0.0 to 19.0.
    decreasing_metric_trainer = Trainer(
        self.model,
        self.optimizer,
        self.iterator,
        self.instances,
        validation_dataset=self.instances,
        num_epochs=100,
        patience=None,
        validation_metric="-test",
    )
    rising_history = list(map(float, range(20)))
    assert not decreasing_metric_trainer._should_stop_early(rising_history)  # pylint: disable=protected-access
def train_model(params: Params,
                serialization_dir: str,
                file_friendly_logging: bool = False,
                recover: bool = False) -> Model:
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results in ``serialization_dir``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory.  This is only intended for use when something actually crashed during the middle
        of a run.  For continuing training a model on new data, see the ``fine-tune`` command.

    Returns
    -------
    best_model: ``Model``
        The model with the best epoch weights.

    Raises
    ------
    ConfigurationError
        If ``datasets_for_vocab_creation`` names a dataset that was not loaded.
    KeyboardInterrupt
        Re-raised after attempting to archive the current best weights, if any exist.
    """
    prepare_environment(params)

    create_serialization_dir(params, serialization_dir, recover)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    check_for_gpu(params.get('trainer').get('cuda_device', -1))

    # Save the configuration before the ``pop`` calls below consume it.
    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    all_datasets = datasets_from_params(params)
    # Defaults to every loaded dataset (iterating the dict yields its keys).
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab = Vocabulary.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation)
    )

    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    model = Model.from_params(vocab=vocab, params=params.pop('model'))
    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)
    validation_iterator_params = params.pop("validation_iterator", None)
    if validation_iterator_params:
        validation_iterator = DataIterator.from_params(validation_iterator_params)
        validation_iterator.index_with(vocab)
    else:
        validation_iterator = None

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    # Freeze every parameter whose name matches one of the "no_grad" regexes.
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
                   get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  validation_data,
                                  trainer_params,
                                  validation_iterator=validation_iterator)

    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    # Every key must have been consumed by now; leftovers indicate a configuration mistake.
    params.assert_empty('base train command')

    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            # Use the module-level logger, like the rest of this function, instead of
            # the root logger.
            logger.info("Training interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
        raise

    # Now tar up results
    archive_model(serialization_dir, files_to_archive=params.files_to_archive)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    # ``best_model`` is the same object as ``model``; its weights are overwritten in place.
    best_model = model
    best_model.load_state_dict(best_model_state)

    if test_data and evaluate_on_test:
        logger.info("The model will be evaluated using the best epoch weights.")
        test_metrics = evaluate(
                best_model, test_data, validation_iterator or iterator,
                cuda_device=trainer._cuda_devices[0]  # pylint: disable=protected-access
        )
        for key, value in test_metrics.items():
            metrics["test_" + key] = value

    elif test_data:
        logger.info("To evaluate on the test set after training, pass the "
                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

    metrics_json = json.dumps(metrics, indent=2)
    with open(os.path.join(serialization_dir, "metrics.json"), "w") as metrics_file:
        metrics_file.write(metrics_json)
    logger.info("Metrics: %s", metrics_json)

    return best_model
def fine_tune_model(model: Model,
                    params: Params,
                    serialization_dir: str,
                    extend_vocab: bool = False,
                    file_friendly_logging: bool = False) -> Model:
    """
    Fine tunes the given model, using a set of parameters that is largely identical to those used
    for :func:`~allennlp.commands.train.train_model`, except that the ``model`` section is ignored,
    if it is present (as we are already given a ``Model`` here).

    The main difference between the logic done here and the logic done in ``train_model`` is that
    here we do not worry about vocabulary construction or creating the model object.  Everything
    else is the same.

    Parameters
    ----------
    model : ``Model``
        The model to fine tune.  Its own vocabulary (``model.vocab``) is used.
    params : ``Params``
        The experiment configuration.  A ``model`` section and a vocabulary ``directory_path``
        are ignored with a warning.
    serialization_dir : ``str``
        The directory in which to save results and logs.  It must not already exist with
        contents.
    extend_vocab: ``bool``, optional (default=False)
        If ``True``, we use the new instances to extend your vocabulary.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.

    Returns
    -------
    model : ``Model``
        The same object that was passed in, in whatever state ``trainer.train()`` leaves it.
        This function does not reload ``best.th`` itself.

    Raises
    ------
    ConfigurationError
        If ``serialization_dir`` exists and is not empty, or if ``extend_vocab`` is set and
        ``datasets_for_vocab_creation`` names a dataset that was not loaded.
    KeyboardInterrupt
        Re-raised after attempting to archive the current best weights, if any exist.
    """
    prepare_environment(params)
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        raise ConfigurationError(f"Serialization directory ({serialization_dir}) "
                                 f"already exists and is not empty.")

    os.makedirs(serialization_dir, exist_ok=True)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    # Dump a copy of the configuration before the ``pop`` calls below consume it.
    serialization_params = deepcopy(params).as_dict(quiet=True)
    with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
        json.dump(serialization_params, param_file, indent=4)

    if params.pop('model', None):
        logger.warning("You passed parameters for the model in your configuration file, but we "
                       "are ignoring them, using instead the model parameters in the archive.")

    vocabulary_params = params.pop('vocabulary', {})
    if vocabulary_params.get('directory_path', None):
        logger.warning("You passed `directory_path` in parameters for the vocabulary in "
                       "your configuration file, but it will be ignored. ")

    all_datasets = datasets_from_params(params)
    vocab = model.vocab

    if extend_vocab:
        # Defaults to every loaded dataset (iterating the dict yields its keys).
        datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))

        for dataset in datasets_for_vocab_creation:
            if dataset not in all_datasets:
                raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

        logger.info("Extending model vocabulary using %s data.", ", ".join(datasets_for_vocab_creation))
        vocab.extend_from_instances(vocabulary_params,
                                    (instance for key, dataset in all_datasets.items()
                                     for instance in dataset
                                     if key in datasets_for_vocab_creation))

    # Written in every case: the directory is archived below whether or not we extended.
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(model.vocab)

    train_data = all_datasets['train']
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    trainer_params = params.pop("trainer")
    # Freeze every parameter whose name matches one of the "no_grad" regexes.
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
                   get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer = Trainer.from_params(model,
                                  serialization_dir,
                                  iterator,
                                  train_data,
                                  validation_data,
                                  trainer_params)

    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    # Every key must have been consumed by now; leftovers indicate a configuration mistake.
    params.assert_empty('base train command')
    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            # Use the module-level logger, like the rest of this function, instead of
            # the root logger.
            logger.info("Fine-tuning interrupted by the user. Attempting to create "
                        "a model archive using the current best epoch weights.")
            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
        raise

    # Now tar up results
    archive_model(serialization_dir, files_to_archive=params.files_to_archive)

    if test_data and evaluate_on_test:
        test_metrics = evaluate(model, test_data, iterator, cuda_device=trainer._cuda_devices[0])  # pylint: disable=protected-access
        for key, value in test_metrics.items():
            metrics["test_" + key] = value

    elif test_data:
        logger.info("To evaluate on the test set after training, pass the "
                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

    metrics_json = json.dumps(metrics, indent=2)
    with open(os.path.join(serialization_dir, "metrics.json"), "w") as metrics_file:
        metrics_file.write(metrics_json)
    logger.info("Metrics: %s", metrics_json)

    return model