Example #1
    def test_grad_clipping(self) -> None:
        training_metrics = {}
        validation_metrics = {}

        def make_workloads(tag: str) -> workload.Stream:
            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=1000, validation_freq=100)
            tm, vm = trainer.result()
            training_metrics[tag] = tm
            validation_metrics[tag] = vm

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            hparams=self.hparams,
            workloads=make_workloads("original"),
            trial_seed=self.trial_seed,
        )
        controller.run()

        updated_hparams = {"clip_grad_l2_norm": 0.0001, **self.hparams}
        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            hparams=updated_hparams,
            workloads=make_workloads("clipped_by_norm"),
            trial_seed=self.trial_seed,
        )
        controller.run()

        for idx, (original, clipped) in enumerate(
                zip(training_metrics["original"],
                    training_metrics["clipped_by_norm"])):
            if idx < 10:
                continue
            assert original["loss"] != clipped["loss"]

        updated_hparams = {"clip_grad_val": 0.0001, **self.hparams}
        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            hparams=updated_hparams,
            workloads=make_workloads("clipped_by_val"),
            trial_seed=self.trial_seed,
        )
        controller.run()

        for idx, (original, clipped) in enumerate(
                zip(training_metrics["original"],
                    training_metrics["clipped_by_val"])):
            if idx < 10:
                continue
            assert original["loss"] != clipped["loss"]
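Note: the clip_grad_l2_norm and clip_grad_val hyperparameters exercised above correspond to the two standard PyTorch clipping modes. A minimal sketch of how such clipping can be applied to a model's gradients in plain PyTorch (the helper name and hparams dict are illustrative, not part of the trial API):

import torch

def apply_grad_clipping(model: torch.nn.Module, hparams: dict) -> None:
    # Clip the global L2 norm of all gradients, matching "clip_grad_l2_norm".
    if "clip_grad_l2_norm" in hparams:
        torch.nn.utils.clip_grad_norm_(model.parameters(), hparams["clip_grad_l2_norm"])
    # Otherwise clamp each gradient element into [-clip_grad_val, clip_grad_val].
    elif "clip_grad_val" in hparams:
        torch.nn.utils.clip_grad_value_(model.parameters(), hparams["clip_grad_val"])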
Example #2
    def test_checkpointing(self, tmp_path: pathlib.Path) -> None:
        checkpoint_dir = tmp_path.joinpath("checkpoint")

        old_error = -1

        def make_workloads_1() -> workload.Stream:
            nonlocal old_error

            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10, validation_freq=10)
            training_metrics, validation_metrics = trainer.result()
            old_error = validation_metrics[-1]["binary_error"]

            yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            hparams=self.hparams,
            workloads=make_workloads_1(),
            trial_seed=self.trial_seed,
        )
        controller.run()

        # Restore the checkpoint on a new trial instance and recompute
        # validation. The validation error should be the same as it was
        # previously.
        def make_workloads_2() -> workload.Stream:
            interceptor = workload.WorkloadResponseInterceptor()

            yield from interceptor.send(workload.validation_workload(), [])
            metrics = interceptor.metrics_result()

            new_error = metrics["validation_metrics"]["binary_error"]
            assert new_error == pytest.approx(old_error)

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            hparams=self.hparams,
            workloads=make_workloads_2(),
            load_path=checkpoint_dir,
            trial_seed=self.trial_seed,
        )
        controller.run()
Example #3
    def test_xor_multi(self) -> None:
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=1000, validation_freq=100)
            training_metrics, validation_metrics = trainer.result()

            # We expect the validation error and training loss to be
            # monotonically decreasing.
            for older, newer in zip(training_metrics, training_metrics[1:]):
                assert newer["loss"] <= older["loss"]

            for older, newer in zip(validation_metrics,
                                    validation_metrics[1:]):
                assert newer["binary_error"] <= older["binary_error"]

            assert validation_metrics[-1]["binary_error"] == pytest.approx(0.0)

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            workloads=make_workloads(),
            hparams=self.hparams,
            trial_seed=self.trial_seed,
        )
        controller.run()
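For context, the XOR task these trials fit is tiny; a hypothetical sketch of the data (the actual pytorch_xor_model input pipeline is not shown here):

import torch

# The four XOR input pairs and their exclusive-or labels.
xor_data = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
xor_labels = torch.tensor([[0.0], [1.0], [1.0], [0.0]])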
Example #4
 def controller_fn(workloads: workload.Stream) -> det.TrialController:
     return utils.make_trial_controller_from_trial_implementation(
         trial_class=pytorch_xor_model.XORTrial,
         hparams=self.hparams,
         workloads=workloads,
         trial_seed=self.trial_seed,
     )
Example #5
    def test_fail_restore_invalid_checkpoint(self,
                                             tmp_path: pathlib.Path) -> None:
        # Build, train, and save a checkpoint with the normal hyperparameters.
        checkpoint_dir = tmp_path.joinpath("checkpoint")

        def make_workloads_1() -> workload.Stream:
            trainer = utils.TrainAndValidate()
            yield from trainer.send(steps=1, validation_freq=1)
            yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response
            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller1 = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialMulti,
            hparams=self.hparams,
            workloads=make_workloads_1(),
            trial_seed=self.trial_seed,
        )
        controller1.run()

        # Verify that an invalid architecture fails to load from the checkpoint.
        def make_workloads_2() -> workload.Stream:
            trainer = utils.TrainAndValidate()
            yield from trainer.send(steps=1, validation_freq=1)
            yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response
            yield workload.terminate_workload(), [], workload.ignore_workload_response

        hparams2 = {
            "hidden_size": 3,
            "learning_rate": 0.5,
            "global_batch_size": 4
        }

        with pytest.raises(RuntimeError):
            controller2 = utils.make_trial_controller_from_trial_implementation(
                trial_class=pytorch_xor_model.XORTrialMulti,
                hparams=hparams2,
                workloads=make_workloads_2(),
                load_path=checkpoint_dir,
                trial_seed=self.trial_seed,
            )
            controller2.run()
Example #6
    def test_lr_schedule_and_lr_checkpoint(self,
                                           tmp_path: pathlib.Path) -> None:
        checkpoint_dir = tmp_path.joinpath("checkpoint")
        training_metrics = []

        def make_workloads(checkpoint_dir: str = "") -> workload.Stream:
            nonlocal training_metrics

            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10,
                                    validation_freq=10,
                                    batches_per_step=1)
            tm, _ = trainer.result()
            training_metrics += tm

            if checkpoint_dir:
                yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialRestoreLR,
            hparams=self.hparams,
            workloads=make_workloads(checkpoint_dir),
            trial_seed=self.trial_seed,
        )
        controller.run()

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialRestoreLR,
            hparams=self.hparams,
            workloads=make_workloads(),
            load_path=checkpoint_dir,
            trial_seed=self.trial_seed,
        )
        controller.run()

        lrs = [metric["lr"] for metric in training_metrics]
        for i in range(1, len(lrs)):
            assert lrs[i] == lrs[i - 1] + 1
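The final loop asserts that the lr reported in the training metrics grows by exactly 1 per batch, and that this continues across the checkpoint/restore boundary. One way to get such a schedule in plain PyTorch is a LambdaLR with a base lr of 1.0; this is an illustration, not necessarily how XORTrialRestoreLR is implemented:

import torch

model = torch.nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
# With a base lr of 1.0, the lr after step k is 1.0 * k, so it increases by 1 each step.
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: step)
for _ in range(3):
    opt.step()
    sched.step()
    print(sched.get_last_lr())  # [1.0], then [2.0], then [3.0]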
Example #7
 def make_trial_controller_fn(
         workloads: workload.Stream,
         load_path: typing.Optional[str] = None) -> det.TrialController:
     return utils.make_trial_controller_from_trial_implementation(
         trial_class=pytorch_xor_model.XORTrialOptimizerState,
         hparams=self.hparams,
         workloads=workloads,
         load_path=load_path,
         trial_seed=self.trial_seed,
     )
Example #8
    def test_custom_eval(self) -> None:
        training_metrics = {}
        validation_metrics = {}

        def make_workloads(tag: str) -> workload.Stream:
            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=900, validation_freq=100)
            tm, vm = trainer.result()
            training_metrics[tag] = tm
            validation_metrics[tag] = vm

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrial,
            hparams=self.hparams,
            workloads=make_workloads("A"),
            trial_seed=self.trial_seed,
        )
        controller.run()

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialCustomEval,
            hparams=self.hparams,
            workloads=make_workloads("B"),
            trial_seed=self.trial_seed,
        )
        controller.run()

        for original, custom_eval in zip(training_metrics["A"],
                                         training_metrics["B"]):
            assert original["loss"] == custom_eval["loss"]

        for original, custom_eval in zip(validation_metrics["A"],
                                         validation_metrics["B"]):
            assert original["loss"] == custom_eval["loss"]
Example #9
 def _xor_trial_controller(
     hparams: Dict[str, Any],
     workloads: workload.Stream,
     batches_per_step: int = 1,
     load_path: Optional[str] = None,
     trial_seed: int = 0,
 ) -> det.TrialController:
     return utils.make_trial_controller_from_trial_implementation(
         request.param,
         hparams,
         workloads,
         batches_per_step=batches_per_step,
         load_path=load_path,
         trial_seed=trial_seed,
     )
Example #10
    def test_per_metric_reducers(self) -> None:
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()
            yield from trainer.send(steps=2,
                                    validation_freq=1,
                                    batches_per_step=1)
            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialPerMetricReducers,
            hparams=self.hparams,
            workloads=make_workloads(),
            trial_seed=self.trial_seed,
        )
        controller.run()
Example #11
    def test_lr_schedule_step_epoch(self, tmp_path: pathlib.Path) -> None:
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()
            yield from trainer.send(steps=10,
                                    validation_freq=10,
                                    batches_per_step=1)
            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialStepEveryEpoch,
            hparams=self.hparams,
            workloads=make_workloads(),
            trial_seed=self.trial_seed,
        )
        controller.run()
Example #12
    def test_lr_schedule_user_modify_fail(self,
                                          tmp_path: pathlib.Path) -> None:
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()
            yield from trainer.send(steps=10,
                                    validation_freq=10,
                                    batches_per_step=1)
            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialUserStepLRFail,
            hparams=self.hparams,
            workloads=make_workloads(),
            trial_seed=self.trial_seed,
        )
        with pytest.raises(check.CheckFailedError):
            controller.run()
Example #13
    def test_xor_training_metrics(self) -> None:
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10, validation_freq=10)
            training_metrics, validation_metrics = trainer.result()

            for metrics in training_metrics:
                assert "accuracy" in metrics

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class=pytorch_xor_model.XORTrialWithTrainingMetrics,
            hparams=self.hparams,
            workloads=make_workloads(),
            trial_seed=self.trial_seed,
        )
        controller.run()
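The assertion above only checks that each per-batch training metrics dict contains an "accuracy" key in addition to the loss. A hypothetical helper showing how such an extra metric could be computed for binary XOR outputs (not the actual XORTrialWithTrainingMetrics code):

import torch

def batch_metrics(loss: torch.Tensor, output: torch.Tensor, labels: torch.Tensor) -> dict:
    # Threshold the sigmoid-style output at 0.5 and compare against the 0/1 labels.
    accuracy = ((output > 0.5).float() == labels).float().mean()
    return {"loss": loss.item(), "accuracy": accuracy.item()}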
Example #14
        def _xor_trial_controller(
            hparams: Dict[str, Any],
            workloads: workload.Stream,
            batches_per_step: int = 1,
            load_path: Optional[str] = None,
            exp_config: Optional[Dict] = None,
        ) -> det.TrialController:
            if request.param == estimator_xor_model.XORTrialDataLayer:
                exp_config = utils.make_default_exp_config(
                    hparams=hparams,
                    batches_per_step=batches_per_step,
                )
                exp_config["data"] = exp_config.get("data", {})
                exp_config["data"]["skip_checkpointing_input"] = True

            return utils.make_trial_controller_from_trial_implementation(
                trial_class=request.param,
                hparams=hparams,
                workloads=workloads,
                batches_per_step=batches_per_step,
                load_path=load_path,
                exp_config=exp_config,
            )
Example #15
    def test_one_var_training(self, test_checkpointing, tmp_path):
        checkpoint_dir = tmp_path.joinpath("checkpoint")

        # In the test_checkpointing case, we will call make_workloads() twice but batches and w
        # will persist across both calls.
        batches = enumerate([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]])
        w = 0.0

        trial_class = tf_keras_one_var_model.OneVarTrial

        def make_workloads() -> workload.Stream:
            nonlocal w
            interceptor = workload.WorkloadResponseInterceptor()

            for idx, batch in batches:
                yield from interceptor.send(workload.train_workload(1), [1])
                metrics = interceptor.metrics_result()

                # Calculate what the loss should be.
                loss = trial_class.calc_loss(w, batch)

                assert metrics["avg_metrics"]["loss"] == pytest.approx(loss)

                # Update what the weight should be.
                w = w - hparams["learning_rate"] * trial_class.calc_gradient(
                    w, batch)

                if test_checkpointing and idx == 3:
                    # Checkpoint and let the next TrialController finish the work.
                    yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response
                    break

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        hparams = {
            "learning_rate": 0.001,
            "global_batch_size": 3,
            "dataset_range": 10
        }
        # TODO(DET-2436): Add a unit test for native implementation with tf dataset.
        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class,
            hparams,
            make_workloads(),
            trial_seed=self.trial_seed,
        )
        controller.run()

        # In the checkpointing case, we need to create another controller to finish training.
        if test_checkpointing:
            controller = utils.make_trial_controller_from_trial_implementation(
                trial_class,
                hparams,
                make_workloads(),
                load_path=checkpoint_dir,
                trial_seed=self.trial_seed,
            )
            controller.run()
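The update verified batch by batch above is plain SGD on a single scalar weight, w <- w - lr * dL/dw, using the trial's own calc_loss and calc_gradient. A self-contained numeric sketch with a hypothetical squared-error loss (OneVarTrial's exact loss function may differ):

def calc_loss(w, batch):
    # Hypothetical loss: predict y = x with y_hat = w * x, averaged squared error.
    return sum((w * x - x) ** 2 for x in batch) / len(batch)

def calc_gradient(w, batch):
    # Derivative of the loss above with respect to w.
    return sum(2 * (w * x - x) * x for x in batch) / len(batch)

w, lr = 0.0, 0.001
for batch in [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]:
    w = w - lr * calc_gradient(w, batch)  # same update the test uses to track the expected weight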