Example #1
    def test_callbacks(self):
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=15,
                                    validation_freq=4,
                                    scheduling_unit=5)
            training_metrics, validation_metrics = trainer.result()

        hparams = {
            "learning_rate": 0.001,
            "global_batch_size": 3,
            "dataset_range": 10,
            # 15 steps * 5 batches per step * 3 records per batch // 12 records per epoch = 18
            "epochs": 15 * 5 * 3 // 12,
            # steps // validation_freq = 15 // 4
            "validations": 3,
        }
        exp_config = utils.make_default_exp_config(hparams,
                                                   scheduling_unit=100,
                                                   searcher_metric="val_loss")
        exp_config["records_per_epoch"] = 12

        controller = utils.make_trial_controller_from_trial_implementation(
            tf_keras_one_var_model.OneVarTrial,
            hparams,
            make_workloads(),
            exp_config=exp_config,
        )
        controller.run()

    def test_one_var_training(self, test_checkpointing, tmp_path):
        checkpoint_dir = tmp_path.joinpath("checkpoint")

        # In the test_checkpointing case, we will call make_workloads() twice but batches and w
        # will persist across both calls.
        batches = enumerate([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]])
        w = 0.0

        trial_class = tf_keras_one_var_model.OneVarTrial

        def make_workloads() -> workload.Stream:
            nonlocal w
            interceptor = workload.WorkloadResponseInterceptor()

            for idx, batch in batches:
                yield from interceptor.send(workload.train_workload(1), [])
                metrics = interceptor.metrics_result()

                # Calculate what the loss should be.
                loss = trial_class.calc_loss(w, batch)

                epsilon = 0.0001
                assert abs(metrics["metrics"]["avg_metrics"]["loss"] - loss) < epsilon

                # Update what the weight should be.
                w = w - hparams["learning_rate"] * trial_class.calc_gradient(w, batch)

                if test_checkpointing and idx == 3:
                    # Checkpoint and let the next TrialController finish the work.
                    yield workload.checkpoint_workload(), [
                        checkpoint_dir
                    ], workload.ignore_workload_response
                    break

            yield workload.terminate_workload(), [], workload.ignore_workload_response

        hparams = {"learning_rate": 0.001, "global_batch_size": 3, "dataset_range": 10}
        exp_config = utils.make_default_exp_config(hparams, scheduling_unit=100)
        exp_config["records_per_epoch"] = 100
        # TODO(DET-2436): Add a unit test for native implementation with tf dataset.
        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class,
            hparams,
            make_workloads(),
            exp_config=exp_config,
            trial_seed=self.trial_seed,
        )
        controller.run()

        # In the checkpointing case, we need to create another controller to finish training.
        if test_checkpointing:
            controller = utils.make_trial_controller_from_trial_implementation(
                trial_class,
                hparams,
                make_workloads(),
                exp_config=exp_config,
                load_path=checkpoint_dir,
                trial_seed=self.trial_seed,
            )
            controller.run()
Example #3
def make_mock_cluster_info(container_addrs: List[str], container_rank: int,
                           num_slots: int) -> det.ClusterInfo:
    config = utils.make_default_exp_config({}, 100, "loss", None)
    trial_info_mock = det.TrialInfo(
        trial_id=1,
        experiment_id=1,
        trial_seed=0,
        hparams={},
        config=config,
        steps_completed=0,
        trial_run_id=0,
        debug=False,
        unique_port_offset=0,
        inter_node_network_interface=None,
    )
    rendezvous_info_mock = det.RendezvousInfo(container_addrs=container_addrs,
                                              container_rank=container_rank)
    cluster_info_mock = det.ClusterInfo(
        master_url="localhost",
        cluster_id="clusterId",
        agent_id="agentId",
        slot_ids=list(range(num_slots)),
        task_id="taskId",
        allocation_id="allocationId",
        session_token="sessionToken",
        task_type="TRIAL",
        rendezvous_info=rendezvous_info_mock,
        trial_info=trial_info_mock,
    )
    return cluster_info_mock
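
A minimal usage sketch for make_mock_cluster_info (not from the original source; the public attribute names asserted on are assumptions based on the constructor keywords shown above):

def test_make_mock_cluster_info_fields() -> None:
    # Illustrative only: build a two-container, two-slot mock and check the
    # fields the helper derives from its arguments (attribute names assumed
    # to match the det.ClusterInfo constructor keywords).
    info = make_mock_cluster_info(["10.0.0.1", "10.0.0.2"], container_rank=1, num_slots=2)
    assert info.master_url == "localhost"
    assert info.slot_ids == [0, 1]
    assert info.task_type == "TRIAL"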
Example #4
def test_reject_nonscalar_searcher_metric() -> None:
    metric_name = "validation_error"

    hparams = {"global_batch_size": 64}
    experiment_config = utils.make_default_exp_config(hparams, 1)
    experiment_config["searcher"] = {"metric": metric_name}
    env = utils.make_default_env_context(hparams=hparams,
                                         experiment_config=experiment_config)
    rendezvous_info = utils.make_default_rendezvous_info()
    storage_manager = NoopStorageManager(os.devnull)
    tensorboard_manager = NoopTensorboardManager()
    metric_writer = NoopBatchMetricWriter()

    def make_workloads() -> workload.Stream:
        yield workload.train_workload(1, num_batches=100), [], workload.ignore_workload_response
        yield workload.validation_workload(), [], workload.ignore_workload_response

    # Normal Python numbers and NumPy scalars are acceptable; other values are
    # not (a sketch of this distinction follows the example).
    cases = [
        (True, 17),
        (True, 0.17),
        (True, np.float64(0.17)),
        (True, np.float32(0.17)),
        (False, "foo"),
        (False, [0.17]),
        (False, {}),
    ]
    for is_valid, metric_value in cases:
        workload_manager = layers.build_workload_manager(
            env,
            make_workloads(),
            rendezvous_info,
            storage_manager,
            tensorboard_manager,
            metric_writer,
        )

        trial_controller = NoopTrialController(
            iter(workload_manager),
            validation_metrics={metric_name: metric_value})
        if is_valid:
            trial_controller.run()
        else:
            with pytest.raises(AssertionError, match="non-scalar"):
                trial_controller.run()
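
The split between accepted and rejected values above corresponds to a scalar check along these lines; this is a hedged sketch of the idea, not code copied from the harness:

import numbers

import numpy as np


def is_scalar_metric(value) -> bool:
    # Python ints/floats/bools and NumPy scalar types count as scalars for the
    # searcher metric; strings, lists, and dicts are rejected as non-scalar.
    return isinstance(value, (numbers.Number, np.generic))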
Example #5
        def _xor_trial_controller(
            hparams: Dict[str, Any],
            workloads: workload.Stream,
            batches_per_step: int = 1,
            load_path: Optional[str] = None,
            exp_config: Optional[Dict] = None,
        ) -> det.TrialController:
            if request.param == estimator_xor_model.XORTrialDataLayer:
                exp_config = utils.make_default_exp_config(
                    hparams=hparams, batches_per_step=batches_per_step)
                exp_config["data"] = exp_config.get("data", {})
                exp_config["data"]["skip_checkpointing_input"] = True

            return utils.make_trial_controller_from_trial_implementation(
                trial_class=request.param,
                hparams=hparams,
                workloads=workloads,
                batches_per_step=batches_per_step,
                load_path=load_path,
                exp_config=exp_config,
            )
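
A hypothetical call site for the helper above, shown only to illustrate its signature; the fixture name, hparams, and the trivial workload stream are placeholders rather than code from the original test module:

def test_xor_trial_smoke(xor_trial_controller) -> None:
    # Illustrative only: assumes the fixture exposes _xor_trial_controller
    # under the name xor_trial_controller.
    hparams = {"learning_rate": 0.1, "global_batch_size": 4}

    def make_workloads() -> workload.Stream:
        yield workload.terminate_workload(), [], workload.ignore_workload_response

    controller = xor_trial_controller(hparams, make_workloads(), batches_per_step=1)
    controller.run()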
Example #6
    def test_one_var_training(self, test_checkpointing, tmp_path):
        checkpoint_dir = str(tmp_path.joinpath("checkpoint"))
        latest_checkpoint = None
        steps_completed = 0

        # In the test_checkpointing case, we will call make_workloads() twice but batches and w
        # will persist across both calls.
        batches = enumerate([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]])
        w = 0.0

        trial_class = tf_keras_one_var_model.OneVarTrial

        def make_workloads() -> workload.Stream:
            nonlocal w
            interceptor = workload.WorkloadResponseInterceptor()

            for idx, batch in batches:
                yield from interceptor.send(workload.train_workload(1))
                metrics = interceptor.metrics_result()

                # Calculate what the loss should be.
                loss = trial_class.calc_loss(w, batch)

                epsilon = 0.0001
                assert abs(metrics["metrics"]["avg_metrics"]["loss"] -
                           loss) < epsilon

                # Update what the weight should be.
                w = w - hparams["learning_rate"] * trial_class.calc_gradient(
                    w, batch)

                if test_checkpointing and idx == 3:
                    # Checkpoint and let the next TrialController finish the work.
                    interceptor = workload.WorkloadResponseInterceptor()
                    yield from interceptor.send(workload.checkpoint_workload())
                    nonlocal latest_checkpoint, steps_completed
                    latest_checkpoint = interceptor.metrics_result()["uuid"]
                    # steps_completed is unused, but can't be 0.
                    steps_completed = 1
                    break

        hparams = {
            "learning_rate": 0.001,
            "global_batch_size": 3,
            "dataset_range": 10
        }
        exp_config = utils.make_default_exp_config(
            hparams,
            scheduling_unit=100,
            searcher_metric=trial_class._searcher_metric)
        exp_config["records_per_epoch"] = 100
        controller = utils.make_trial_controller_from_trial_implementation(
            trial_class,
            hparams,
            make_workloads(),
            exp_config=exp_config,
            trial_seed=self.trial_seed,
            checkpoint_dir=checkpoint_dir,
        )
        controller.run()

        # In the checkpointing case, we need to create another controller to finish training.
        if test_checkpointing:
            controller = utils.make_trial_controller_from_trial_implementation(
                trial_class,
                hparams,
                make_workloads(),
                exp_config=exp_config,
                trial_seed=self.trial_seed,
                checkpoint_dir=checkpoint_dir,
                latest_checkpoint=latest_checkpoint,
                steps_completed=steps_completed,
            )
            controller.run()
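
Both OneVar tests rely on the same manual bookkeeping; a generic, illustrative restatement of that pattern (assuming plain SGD with the learning rate from hparams, exactly as the updates in the loops above do):

def expected_sgd_trajectory(w0, lr, batches, calc_loss, calc_gradient):
    # Mirror of the test loops above: compute the loss the framework should
    # report for each batch, then apply the SGD update w <- w - lr * grad to
    # get the weight that the next batch should be evaluated against.
    w = w0
    for batch in batches:
        expected_loss = calc_loss(w, batch)
        w = w - lr * calc_gradient(w, batch)
        yield expected_loss, w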