コード例 #1
0
    def test_from_checkpoint(self):
        """
        Round-trip a model through a checkpoint and rebuild it with
        ``ClassyModel.from_checkpoint``.

        The scenario runs once with ``use_head=True`` and once with
        ``use_head=False``. Each run zeroes the task's model parameters,
        saves a checkpoint at the end of a train phase through
        ``CheckpointHook``, loads it back and checks that the rebuilt model
        is a ``MyTestModel`` whose parameters are all zero.
        """
        config = get_test_task_config()
        for use_head in [True, False]:
            # subTest tags any failure with the variant that triggered it and
            # lets the remaining variant still run.
            with self.subTest(use_head=use_head):
                config["model"] = self.get_model_config(use_head)
                task = build_task(config)
                task.prepare()

                checkpoint_folder = f"{self.base_dir}/{use_head}/"
                input_args = {"config": config}

                # Simulate training by setting the model parameters to zero
                for param in task.model.parameters():
                    param.data.zero_()

                checkpoint_hook = CheckpointHook(
                    checkpoint_folder, input_args, phase_types=["train"]
                )

                # Create checkpoint dir, save checkpoint
                os.mkdir(checkpoint_folder)
                checkpoint_hook.on_start(task)

                # The hook is restricted to "train" phases, so mark the task
                # as training before signalling the end of the phase.
                task.train = True
                checkpoint_hook.on_phase_end(task)

                # Model should be checkpointed. load and compare
                checkpoint = load_checkpoint(checkpoint_folder)

                model = ClassyModel.from_checkpoint(checkpoint)
                # assertIsInstance reports the actual type on failure, unlike
                # assertTrue(isinstance(...)).
                self.assertIsInstance(model, MyTestModel)

                # All parameters must be zero
                for param in model.parameters():
                    self.assertTrue(torch.all(param.data == 0))
コード例 #2
0
    def test_state_checkpointing(self) -> None:
        """
        Exercise the checkpoint hook with a missing directory, a non-train
        phase and a train phase.

        A checkpoint must only show up once the checkpoint directory exists
        and the phase that just ended is a train phase; the ``input_args``
        stored in it must equal the ones handed to the hook.
        """
        task = build_task(get_test_task_config())
        task.prepare()

        hook_args = {"foo": "bar"}
        phase_locals = {}
        ckpt_dir = self.base_dir + "/checkpoint_end_test/"

        hook = CheckpointHook(ckpt_dir, hook_args, phase_types=["train"])

        # The directory has not been created yet: both hook entry points are
        # expected to raise, and there is nothing to load.
        with self.assertRaises(FileNotFoundError):
            hook.on_start(task)
        with self.assertRaises(AssertionError):
            hook.on_phase_end(task, phase_locals)
        self.assertIsNone(load_checkpoint(ckpt_dir))

        # Once the directory exists, on_start must run without raising.
        os.mkdir(ckpt_dir)
        hook.on_start(task)

        # Ending a non-train phase must not write a checkpoint.
        task.train = False
        hook.on_phase_end(task, phase_locals)
        self.assertIsNone(load_checkpoint(ckpt_dir))

        # Ending a train phase must leave a loadable checkpoint behind.
        task.train = True
        hook.on_phase_end(task, phase_locals)
        saved = load_checkpoint(ckpt_dir)
        self.assertIsNotNone(saved)
        self.assertIn("input_args", saved)
        self.assertIn("classy_state_dict", saved)
        # Only input_args is compared for equality here; classy_state_dict
        # equality is left to a separate test.
        self.assertDictEqual(saved["input_args"], hook_args)
コード例 #3
0
    def test_checkpoint_period(self) -> None:
        """
        Check that a checkpoint only appears after ``checkpoint_period``
        phases whose type is listed in ``phase_types``.

        The scenario is repeated for ``["train"]`` and ``["train", "test"]``,
        each time with a freshly created checkpoint directory.
        """
        task = build_task(get_test_task_config())
        task.prepare()

        period = 10
        phase_locals = {}
        ckpt_dir = self.base_dir + "/checkpoint_end_test/"

        for phase_types in (["train"], ["train", "test"]):
            hook = CheckpointHook(
                ckpt_dir, {}, phase_types=phase_types, checkpoint_period=period
            )

            # The directory is created per iteration and removed at the end.
            os.mkdir(ckpt_dir)
            hook.on_start(task)

            # Alternate task.train between True and False; until period - 1
            # phases listed in phase_types have ended, no checkpoint may exist.
            phases_run = 0
            matching_phases = 0
            while matching_phases < period - 1:
                task.train = phases_run % 2 == 0
                hook.on_phase_end(task, phase_locals)
                self.assertIsNone(load_checkpoint(ckpt_dir))
                if task.phase_type in phase_types:
                    matching_phases += 1
                phases_run += 1

            # One more train phase completes the period, so a checkpoint
            # must now be loadable.
            task.train = True
            hook.on_phase_end(task, phase_locals)
            self.assertIsNotNone(load_checkpoint(ckpt_dir))

            shutil.rmtree(ckpt_dir)