def test_checkpointing(self):
    # make checkpoint directory
    checkpoint_folder = self.base_dir + "/checkpoint/"
    os.mkdir(checkpoint_folder)

    config = get_fast_test_task_config()
    cuda_available = torch.cuda.is_available()
    task = build_task(config)

    task.prepare(use_gpu=cuda_available)

    # create a checkpoint hook
    checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"])

    # call the on end phase function
    checkpoint_hook.on_phase_end(task)

    # we should be able to train a task using the checkpoint on all available
    # devices
    for use_gpu in {False, cuda_available}:
        # load the checkpoint
        checkpoint = load_checkpoint(checkpoint_folder)

        # create a new task
        task = build_task(config)

        # set the checkpoint
        task.set_checkpoint(checkpoint)
        task.prepare(use_gpu=use_gpu)

        # we should be able to run the trainer using the checkpoint
        trainer = LocalTrainer(use_gpu=use_gpu)
        trainer.train(task)
def test_from_checkpoint(self):
    config = get_test_task_config()
    for use_head in [True, False]:
        config["model"] = self.get_model_config(use_head)
        task = build_task(config)
        task.prepare()
        checkpoint_folder = f"{self.base_dir}/{use_head}/"
        input_args = {"config": config}

        # Simulate training by setting the model parameters to zero
        for param in task.model.parameters():
            param.data.zero_()

        checkpoint_hook = CheckpointHook(
            checkpoint_folder, input_args, phase_types=["train"]
        )

        # Create checkpoint dir, save checkpoint
        os.mkdir(checkpoint_folder)
        checkpoint_hook.on_start(task)

        task.train = True
        checkpoint_hook.on_phase_end(task)

        # Model should be checkpointed. load and compare
        checkpoint = load_checkpoint(checkpoint_folder)

        model = ClassyModel.from_checkpoint(checkpoint)
        self.assertIsInstance(model, MyTestModel)

        # All parameters must be zero
        for param in model.parameters():
            self.assertTrue(torch.all(param.data == 0))
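# test_from_checkpoint relies on a get_model_config() helper (and the
# MyTestModel fixture it refers to) defined elsewhere in this module. A
# hypothetical sketch of the helper is below: the "my_test_model" name and the
# head config fields are illustrative assumptions, not the repo's actual
# fixture values.
def get_model_config(self, use_head):
    # base config selecting the registered test model
    config = {"name": "my_test_model"}
    if use_head:
        # optionally attach a single head; block/head names and sizes here
        # are placeholders chosen for illustration only
        config["heads"] = [
            {
                "name": "fully_connected",
                "unique_id": "default_head",
                "num_classes": 2,
                "fork_block": "block",
                "in_plane": 10,
            }
        ]
    return config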
def test_checkpoint_period(self) -> None:
    """
    Test that the checkpoint_period works as expected.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
    checkpoint_period = 10

    for phase_types in [["train"], ["train", "test"]]:
        # create a checkpoint hook
        checkpoint_hook = CheckpointHook(
            checkpoint_folder,
            {},
            phase_types=phase_types,
            checkpoint_period=checkpoint_period,
        )

        # create checkpoint dir
        os.mkdir(checkpoint_folder)

        # call the on start function
        checkpoint_hook.on_start(task)

        # shouldn't create any checkpoints until there are checkpoint_period
        # phases which are in phase_types
        count = 0
        valid_phase_count = 0
        while valid_phase_count < checkpoint_period - 1:
            task.train = count % 2 == 0

            # call the on end phase function
            checkpoint_hook.on_phase_end(task)
            checkpoint = load_checkpoint(checkpoint_folder)
            self.assertIsNone(checkpoint)

            valid_phase_count += 1 if task.phase_type in phase_types else 0
            count += 1

        # create a phase which is in phase_types
        task.train = True

        # call the on end phase function
        checkpoint_hook.on_phase_end(task)

        # model should be checkpointed. load and compare
        checkpoint = load_checkpoint(checkpoint_folder)
        self.assertIsNotNone(checkpoint)

        # delete the checkpoint dir
        shutil.rmtree(checkpoint_folder)
def test_failure(self) -> None:
    self.assertFalse(PathManager.exists("test://foo"))
    PathManager.register_handler(TestPathHandler())

    # make sure that TestPathHandler is being used
    self.assertTrue(PathManager.exists("test://foo"))

    checkpoint_folder = "test://root"
    checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"])

    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    # we should raise an exception while trying to save the checkpoint
    with self.assertRaises(TestException):
        checkpoint_hook.on_phase_end(task)
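# test_failure relies on a TestException / TestPathHandler pair defined
# elsewhere in this module. A minimal sketch is below, assuming PathHandler is
# imported alongside PathManager (e.g. from fvcore.common.file_io); the exact
# set of overridden methods is an assumption for illustration.
class TestException(Exception):
    pass


class TestPathHandler(PathHandler):
    # claim any path under the "test://" prefix
    def _get_supported_prefixes(self):
        return ["test://"]

    # report that every "test://" path exists, so the hook proceeds to save
    def _exists(self, path, **kwargs):
        return True

    # fail on any attempt to open a file for writing the checkpoint
    def _open(self, path, mode="r", **kwargs):
        raise TestException

    # fail on directory creation as well
    def _mkdirs(self, path, **kwargs):
        raise TestException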
def test_state_checkpointing(self) -> None:
    """
    Test that the state gets checkpointed without any errors, but only
    on the right phase_type and only if the checkpoint directory exists.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
    input_args = {"foo": "bar"}

    # create a checkpoint hook
    checkpoint_hook = CheckpointHook(
        checkpoint_folder, input_args, phase_types=["train"]
    )

    # checkpoint directory doesn't exist
    # call the on start function
    with self.assertRaises(FileNotFoundError):
        checkpoint_hook.on_start(task)

    # call the on end phase function
    with self.assertRaises(AssertionError):
        checkpoint_hook.on_phase_end(task)

    # try loading a non-existent checkpoint
    checkpoint = load_checkpoint(checkpoint_folder)
    self.assertIsNone(checkpoint)

    # create checkpoint dir, verify on_start hook runs
    os.mkdir(checkpoint_folder)
    checkpoint_hook.on_start(task)

    # phase_type is test, expect no checkpoint
    task.train = False

    # call the on end phase function
    checkpoint_hook.on_phase_end(task)
    checkpoint = load_checkpoint(checkpoint_folder)
    self.assertIsNone(checkpoint)

    task.train = True

    # call the on end phase function
    checkpoint_hook.on_phase_end(task)

    # model should be checkpointed. load and compare
    checkpoint = load_checkpoint(checkpoint_folder)
    self.assertIsNotNone(checkpoint)
    for key in ["input_args", "classy_state_dict"]:
        self.assertIn(key, checkpoint)

    # not testing for equality of classy_state_dict, that is tested in
    # a separate test
    self.assertDictEqual(checkpoint["input_args"], input_args)
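# The tests above assume a self.base_dir scratch directory. If the test class
# does not already define fixtures for it, a setUp/tearDown pair along these
# lines would provide one; this is a minimal sketch using a throwaway temp
# directory (assuming tempfile is imported at module level, shutil is already
# used above), not necessarily the repo's actual fixture.
def setUp(self):
    # create a fresh scratch directory for each test
    self.base_dir = tempfile.mkdtemp()

def tearDown(self):
    # remove the scratch directory and everything the hooks wrote into it
    shutil.rmtree(self.base_dir, ignore_errors=True)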