def test_trainer_saves_models_at_specified_interval(self):
    """Check that ``model_save_interval`` yields restorable mid-epoch checkpoints.

    Trains for two epochs with a tiny save interval, then verifies that both
    timestamped (mid-epoch) and end-of-epoch model files were written. It then
    deletes the end-of-epoch model/training-state files and ``best.th``, so a
    fresh trainer can only restore from the timestamped checkpoints.
    """
    batch_iterator = BasicIterator(batch_size=4)
    batch_iterator.index_with(self.vocab)
    trainer = Trainer(self.model, self.optimizer, batch_iterator,
                      self.instances, num_epochs=2,
                      serialization_dir=self.TEST_DIR,
                      model_save_interval=0.0001)
    trainer.train()

    # Pull the epoch tag (e.g. "0", "1" or "<epoch>.<timestamp>") out of
    # every serialized model file name, in sorted file-name order.
    model_glob = os.path.join(self.TEST_DIR, 'model_state_epoch_*')
    epoch_tags = []
    for path in sorted(glob.glob(model_glob)):
        found = re.search(r"_([0-9\.\-]+)\.th", path)
        epoch_tags.append(found.group(1))

    # Expect one timestamped and one end-of-epoch checkpoint per epoch, e.g.
    # [0.timestamp, 0, 1.timestamp, 1].
    assert len(epoch_tags) == 4
    assert epoch_tags[-1] == '1'
    assert '.' in epoch_tags[0]

    # Remove the end-of-epoch checkpoints for epochs 0 and 1 (plus best.th)
    # so that restoring has to fall back on the timestamped checkpoints.
    for epoch_index in range(2):
        for template in ('model_state_epoch_{}.th', 'training_state_epoch_{}.th'):
            os.remove(os.path.join(self.TEST_DIR, template.format(epoch_index)))
    os.remove(os.path.join(self.TEST_DIR, 'best.th'))

    restore_trainer = Trainer(self.model, self.optimizer, self.iterator,
                              self.instances, num_epochs=2,
                              serialization_dir=self.TEST_DIR,
                              model_save_interval=0.0001)
    restored_epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
    assert restored_epoch == 2
    # One batch per epoch, so two batches in total across both epochs.
    assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
def test_trainer_can_resume_training(self):
    """Train for one epoch, then check that a new trainer resumes from it.

    A second trainer configured for three epochs restores the checkpoint
    written by the first run. The test asserts the restored epoch counter and
    the single recorded validation metric, then continues training from there.
    """
    def make_trainer(epochs):
        # Both runs share the model, optimizer, data and serialization dir;
        # only the epoch budget differs.
        return Trainer(self.model, self.optimizer, self.iterator,
                       self.instances, validation_dataset=self.instances,
                       num_epochs=epochs, serialization_dir=self.TEST_DIR)

    first_run = make_trainer(1)
    first_run.train()

    resumed = make_trainer(3)
    start_epoch, val_metrics = resumed._restore_checkpoint()  # pylint: disable=protected-access
    assert start_epoch == 1
    assert len(val_metrics) == 1
    assert isinstance(val_metrics[0], float)
    assert val_metrics[0] != 0.
    resumed.train()