def test_checkpointing_and_restoring(self, tmp_path: pathlib.Path) -> None:
        """Round-trip a checkpoint for the XOR trial with an LR scheduler."""

        def build_controller(
            workloads: workload.Stream,
            checkpoint_dir: typing.Optional[str] = None,
            latest_checkpoint: typing.Optional[typing.Dict[str,
                                                           typing.Any]] = None,
            steps_completed: int = 0,
        ) -> det.TrialController:
            # Seed the hparams with the per-batch LR step mode, then layer the
            # suite's hparams on top.
            # NOTE(review): self.hparams wins on a key clash, so an existing
            # "lr_scheduler_step_mode" entry would override this value —
            # confirm that precedence is intended.
            merged_hparams = {
                "lr_scheduler_step_mode":
                pytorch.LRScheduler.StepMode.STEP_EVERY_BATCH.value,
            }
            merged_hparams.update(self.hparams)
            return utils.make_trial_controller_from_trial_implementation(
                trial_class=pytorch_xor_model.XORTrialWithLRScheduler,
                hparams=merged_hparams,
                workloads=workloads,
                trial_seed=self.trial_seed,
                checkpoint_dir=checkpoint_dir,
                latest_checkpoint=latest_checkpoint,
                steps_completed=steps_completed,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
    def test_checkpoint_save_load_hooks(self, tmp_path: pathlib.Path) -> None:
        """Verify a key written by on_save_checkpoint reaches on_load_checkpoint."""

        class OneVarLM(la_model.OneVarLM):
            def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
                # Stash a marker for the load hook to find after restore.
                checkpoint["test"] = True

            def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
                # Fails with AssertionError if the marker did not survive.
                assert checkpoint.get("test") is True

        class OneVarLA(la_model.OneVarTrial):
            def __init__(self, context):
                super().__init__(context, OneVarLM)

        def build_controller(
            workloads: workload.Stream,
            load_path: typing.Optional[str] = None,
        ) -> det.TrialController:
            return utils.make_trial_controller_from_trial_implementation(
                trial_class=OneVarLA,
                hparams=self.hparams,
                workloads=workloads,
                load_path=load_path,
                trial_seed=self.trial_seed,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
    def test_checkpoint_load_hook(self, tmp_path: pathlib.Path) -> None:
        """A load hook demanding a key that no save hook wrote must fail."""

        class OneVarLM(la_model.OneVarLM):
            def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
                # Nothing ever writes "test", so restoring should raise here.
                assert "test" in checkpoint

        class OneVarLA(la_model.OneVarTrial):
            def __init__(self, context):
                super().__init__(context, OneVarLM)

        def build_controller(
            workloads: workload.Stream,
            checkpoint_dir: typing.Optional[str] = None,
            latest_checkpoint: typing.Optional[typing.Dict[str,
                                                           typing.Any]] = None,
            steps_completed: int = 0,
        ) -> det.TrialController:
            return utils.make_trial_controller_from_trial_implementation(
                trial_class=OneVarLA,
                hparams=self.hparams,
                workloads=workloads,
                trial_seed=self.trial_seed,
                checkpoint_dir=checkpoint_dir,
                latest_checkpoint=latest_checkpoint,
                steps_completed=steps_completed,
            )

        # The restore path runs the failing load hook, so the whole round
        # trip is expected to surface an AssertionError.
        with pytest.raises(AssertionError):
            utils.checkpointing_and_restoring_test(build_controller, tmp_path)
    def test_optimizer_state(self, tmp_path: Path,
                             xor_trial_controller: Callable) -> None:
        """Check that Adam optimizer state survives a checkpoint round-trip."""

        def build_controller(
                workloads: workload.Stream,
                load_path: Optional[str] = None) -> det.TrialController:
            # Use Adam so the checkpoint carries non-trivial optimizer state.
            return xor_trial_controller(
                {**self.hparams, "optimizer": "adam"},
                workloads,
                load_path=load_path,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
# ---- Example #5 ----
    def test_optimizer_state(self, tmp_path: Path, xor_trial_controller: Callable) -> None:
        """Optimizer state should survive a save/restore cycle."""

        def build_controller(
            workloads: workload.Stream, load_path: Optional[str] = None
        ) -> det.TrialController:
            controller_kwargs = dict(
                scheduling_unit=100,
                load_path=load_path,
                trial_seed=self.trial_seed,
            )
            return xor_trial_controller(self.hparams, workloads, **controller_kwargs)

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
    def test_checkpointing_and_restoring(self, tmp_path: pathlib.Path) -> None:
        """Round-trip a OneVarTrial checkpoint through save and restore."""

        def build_controller(
                workloads: workload.Stream,
                load_path: typing.Optional[str] = None) -> det.TrialController:
            return utils.make_trial_controller_from_trial_implementation(
                trial_class=la_model.OneVarTrial,
                hparams=self.hparams,
                workloads=workloads,
                trial_seed=self.trial_seed,
                load_path=load_path,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
    def test_optimizer_state(self, tmp_path: Path,
                             xor_trial_controller: Callable) -> None:
        """Adam optimizer state must be restored from the latest checkpoint."""

        def build_controller(
            workloads: workload.Stream,
            checkpoint_dir: Optional[str] = None,
            latest_checkpoint: Optional[Dict[str, Any]] = None,
            steps_completed: int = 0,
        ) -> det.TrialController:
            # Use Adam so the checkpoint carries real optimizer state.
            adam_hparams = dict(self.hparams)
            adam_hparams["optimizer"] = "adam"
            return xor_trial_controller(
                adam_hparams,
                workloads,
                checkpoint_dir=checkpoint_dir,
                latest_checkpoint=latest_checkpoint,
                steps_completed=steps_completed,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
# ---- Example #8 ----
    def test_checkpointing_and_restoring(self, tmp_path: pathlib.Path) -> None:
        """Round-trip a checkpoint for the pipeline-parallel DeepSpeed trial."""

        def build_controller(
            workloads: workload.Stream,
            checkpoint_dir: Optional[str] = None,
            latest_checkpoint: Optional[Dict[str, Any]] = None,
            steps_completed: int = 0,
        ) -> determined.TrialController:
            return utils.make_trial_controller_from_trial_implementation(
                trial_class=deepspeed_linear_model.LinearPipelineEngineTrial,
                hparams=self.hparams,
                workloads=workloads,
                trial_seed=self.trial_seed,
                checkpoint_dir=checkpoint_dir,
                latest_checkpoint=latest_checkpoint,
                steps_completed=steps_completed,
                # presumably required so the DeepSpeed engine can see GPUs —
                # confirm against make_trial_controller_from_trial_implementation.
                expose_gpus=True,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)
    def test_checkpointing_and_restoring(self, tmp_path: pathlib.Path) -> None:
        """OneVarTrial state should survive a checkpoint/restore cycle."""

        def build_controller(
            workloads: workload.Stream,
            checkpoint_dir: typing.Optional[str] = None,
            latest_checkpoint: typing.Optional[
                typing.Dict[str, typing.Any]] = None,
            steps_completed: int = 0,
        ) -> det.TrialController:
            return utils.make_trial_controller_from_trial_implementation(
                trial_class=la_model.OneVarTrial,
                hparams=self.hparams,
                workloads=workloads,
                trial_seed=self.trial_seed,
                checkpoint_dir=checkpoint_dir,
                latest_checkpoint=latest_checkpoint,
                steps_completed=steps_completed,
            )

        utils.checkpointing_and_restoring_test(build_controller, tmp_path)