Ejemplo n.º 1
0
        def make_workloads() -> workload.Stream:
            nonlocal w
            interceptor = workload.WorkloadResponseInterceptor()

            for idx, batch in batches:
                yield from interceptor.send(workload.train_workload(1))
                metrics = interceptor.metrics_result()

                # Calculate what the loss should be.
                loss = trial_class.calc_loss(w, batch)

                epsilon = 0.0001
                assert abs(metrics["metrics"]["avg_metrics"]["loss"] -
                           loss) < epsilon

                # Update what the weight should be.
                w = w - hparams["learning_rate"] * trial_class.calc_gradient(
                    w, batch)

                if test_checkpointing and idx == 3:
                    # Checkpoint and let the next TrialController finish the work.
                    interceptor = workload.WorkloadResponseInterceptor()
                    yield from interceptor.send(workload.checkpoint_workload())
                    nonlocal latest_checkpoint, steps_completed
                    latest_checkpoint = interceptor.metrics_result()["uuid"]
                    # steps_completed is unused, but can't be 0.
                    steps_completed = 1
                    break
Ejemplo n.º 2
0
        def make_workloads1() -> workload.Stream:
            nonlocal controller

            yield workload.train_workload(
                1, 1, 0), [], workload.ignore_workload_response
            assert controller is not None, "controller was never set!"
            assert controller.trial.counter.__dict__ == {
                "validation_steps_started": 0,
                "validation_steps_ended": 0,
                "checkpoints_ended": 0,
            }

            yield workload.validation_workload(
            ), [], workload.ignore_workload_response
            assert controller.trial.counter.__dict__ == {
                "validation_steps_started": 1,
                "validation_steps_ended": 1,
                "checkpoints_ended": 0,
            }

            yield workload.checkpoint_workload(), [
                checkpoint_dir
            ], workload.ignore_workload_response
            assert controller.trial.counter.__dict__ == {
                "validation_steps_started": 1,
                "validation_steps_ended": 1,
                "checkpoints_ended": 1,
            }

            yield workload.terminate_workload(
            ), [], workload.ignore_workload_response
Ejemplo n.º 3
0
 def make_workloads_2() -> workload.Stream:
     """Run one train/validate step, then checkpoint and terminate.

     The checkpoint workload is given ``checkpoint_dir`` from the enclosing
     scope; responses to the checkpoint and terminate workloads are ignored.
     """
     runner = utils.TrainAndValidate()
     yield from runner.send(steps=1, validation_freq=1)

     checkpoint_args = [checkpoint_dir]
     yield workload.checkpoint_workload(), checkpoint_args, workload.ignore_workload_response
     yield workload.terminate_workload(), [], workload.ignore_workload_response
Ejemplo n.º 4
0
def make_test_workloads(
    checkpoint_dir: pathlib.Path, config: det.ExperimentConfig
) -> workload.Stream:
    """Yield the workloads of a minimal test experiment, printing progress.

    The sequence is: one training workload, one validation workload, a
    checkpoint written to ``checkpoint_dir``, then a terminate workload.

    Args:
        checkpoint_dir: Passed as the argument of the checkpoint workload.
        config: Supplies ``batches_per_step()``, which is both sent with the
            training workload and checked against the number of batch metrics
            that come back.
    """
    print("Start training a test experiment.")
    responses = workload.WorkloadResponseInterceptor()

    print("Training 1 step.")
    train_args = [config.batches_per_step()]
    yield from responses.send(workload.train_workload(1), train_args)
    batch_metrics = responses.metrics_result()["batch_metrics"]
    # One metrics entry is expected per batch in the step.
    check.eq(len(batch_metrics), config.batches_per_step())
    print(f"Finished training. Metrics: {batch_metrics}")

    print("Validating.")
    yield from responses.send(workload.validation_workload(1), [])
    v_metrics = responses.metrics_result()["validation_metrics"]
    print(f"Finished validating. Validation metrics: {v_metrics}")

    print(f"Saving a checkpoint to {checkpoint_dir}")
    yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response
    print(f"Finished saving a checkpoint to {checkpoint_dir}.")

    yield workload.terminate_workload(), [], workload.ignore_workload_response
    print("The test experiment passed.")
Ejemplo n.º 5
0
        def make_workloads() -> workload.Stream:
            nonlocal w
            interceptor = workload.WorkloadResponseInterceptor()

            for idx, batch in batches:
                yield from interceptor.send(workload.train_workload(1), [1])
                metrics = interceptor.metrics_result()

                # Calculate what the loss should be.
                loss = trial_class.calc_loss(w, batch)

                assert metrics["avg_metrics"]["loss"] == pytest.approx(loss)

                # Update what the weight should be.
                w = w - hparams["learning_rate"] * trial_class.calc_gradient(
                    w, batch)

                if test_checkpointing and idx == 3:
                    # Checkpoint and let the next TrialController finish the work.l
                    yield workload.checkpoint_workload(), [
                        checkpoint_dir
                    ], workload.ignore_workload_response
                    break

            yield workload.terminate_workload(
            ), [], workload.ignore_workload_response
 def make_workloads_1() -> workload.Stream:
     trainer = utils.TrainAndValidate()
     yield from trainer.send(steps=1, validation_freq=1)
     interceptor = workload.WorkloadResponseInterceptor()
     yield from interceptor.send(workload.checkpoint_workload())
     nonlocal latest_checkpoint, steps_completed
     latest_checkpoint = interceptor.metrics_result()["uuid"]
     steps_completed = trainer.get_steps_completed()
Ejemplo n.º 7
0
        def make_workloads(checkpoint_dir: pathlib.Path) -> workload.Stream:
            """Train for 10 steps, validating every 5, then checkpoint and terminate.

            Args:
                checkpoint_dir: Passed as the argument of the checkpoint workload.
            """
            runner = utils.TrainAndValidate()
            yield from runner.send(steps=10, validation_freq=5, batches_per_step=5)

            checkpoint_args = [checkpoint_dir]
            yield workload.checkpoint_workload(), checkpoint_args, workload.ignore_workload_response
            yield workload.terminate_workload(), [], workload.ignore_workload_response
        def make_workloads() -> workload.Stream:
            trainer = utils.TrainAndValidate()
            yield from trainer.send(steps=1,
                                    validation_freq=1,
                                    scheduling_unit=10)

            interceptor = workload.WorkloadResponseInterceptor()
            yield from interceptor.send(workload.checkpoint_workload())
            nonlocal latest_checkpoint
            latest_checkpoint = interceptor.metrics_result()["uuid"]
        def make_workloads1() -> workload.Stream:
            nonlocal controller
            assert controller.trial.counter.trial_startups == 1

            yield workload.train_workload(1, 1, 0,
                                          4), workload.ignore_workload_response
            assert controller is not None, "controller was never set!"
            assert controller.trial.counter.__dict__ == {
                "trial_startups": 1,
                "validation_steps_started": 0,
                "validation_steps_ended": 0,
                "checkpoints_ended": 0,
                "training_started_times": 1,
                "training_epochs_started": 2,
                "training_epochs_ended": 2,
                "trial_shutdowns": 0,
            }
            assert controller.trial.legacy_counter.__dict__ == {
                "legacy_on_training_epochs_start_calls": 2
            }

            yield workload.validation_workload(
            ), workload.ignore_workload_response
            assert controller.trial.counter.__dict__ == {
                "trial_startups": 1,
                "validation_steps_started": 1,
                "validation_steps_ended": 1,
                "checkpoints_ended": 0,
                "training_started_times": 1,
                "training_epochs_started": 2,
                "training_epochs_ended": 2,
                "trial_shutdowns": 0,
            }
            assert controller.trial.legacy_counter.__dict__ == {
                "legacy_on_training_epochs_start_calls": 2
            }

            interceptor = workload.WorkloadResponseInterceptor()
            yield from interceptor.send(workload.checkpoint_workload())
            nonlocal latest_checkpoint, steps_completed
            latest_checkpoint = interceptor.metrics_result()["uuid"]
            steps_completed = 1
            assert controller.trial.counter.__dict__ == {
                "trial_startups": 1,
                "validation_steps_started": 1,
                "validation_steps_ended": 1,
                "checkpoints_ended": 1,
                "training_started_times": 1,
                "training_epochs_started": 2,
                "training_epochs_ended": 2,
                "trial_shutdowns": 0,
            }
            assert controller.trial.legacy_counter.__dict__ == {
                "legacy_on_training_epochs_start_calls": 2
            }
Ejemplo n.º 10
0
        def make_workloads_1() -> workload.Stream:
            nonlocal old_loss

            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10, validation_freq=10)
            training_metrics, validation_metrics = trainer.result()
            old_loss = validation_metrics[-1]["val_loss"]

            interceptor = workload.WorkloadResponseInterceptor()
            yield from interceptor.send(workload.checkpoint_workload())
            nonlocal latest_checkpoint, steps_completed
            latest_checkpoint = interceptor.metrics_result()["uuid"]
            steps_completed = trainer.get_steps_completed()
Ejemplo n.º 11
0
        def make_workloads_1() -> workload.Stream:
            nonlocal old_loss

            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10, validation_freq=10)
            training_metrics, validation_metrics = trainer.result()
            old_loss = validation_metrics[-1]["val_loss"]

            yield workload.checkpoint_workload(), [
                checkpoint_dir
            ], workload.ignore_workload_response

            yield workload.terminate_workload(), [], workload.ignore_workload_response
Ejemplo n.º 12
0
    def make_workloads(steps: int, tag: str, checkpoint: bool) -> workload.Stream:
        trainer = TrainAndValidate()

        yield from trainer.send(steps, validation_freq=1, scheduling_unit=100)
        tm, vm = trainer.result()
        training_metrics[tag] += tm
        validation_metrics[tag] += vm

        if checkpoint is not None:
            interceptor = workload.WorkloadResponseInterceptor()
            yield from interceptor.send(workload.checkpoint_workload())
            nonlocal latest_checkpoint, steps_completed
            latest_checkpoint = interceptor.metrics_result()["uuid"]
            steps_completed = trainer.get_steps_completed()
Ejemplo n.º 13
0
        def make_workloads(checkpoint_dir: str = "") -> workload.Stream:
            nonlocal training_metrics

            trainer = utils.TrainAndValidate()

            yield from trainer.send(steps=10, validation_freq=10, batches_per_step=1)
            tm, _ = trainer.result()
            training_metrics += tm

            if checkpoint_dir:
                yield workload.checkpoint_workload(), [
                    checkpoint_dir
                ], workload.ignore_workload_response

            yield workload.terminate_workload(), [], workload.ignore_workload_response
Ejemplo n.º 14
0
    def make_workloads(
        steps: int, tag: str, checkpoint_dir: Optional[pathlib.Path] = None
    ) -> workload.Stream:
        """Train and validate for ``steps`` steps, optionally checkpoint, then terminate.

        The metrics returned by the trainer are appended to
        ``training_metrics[tag]`` and ``validation_metrics[tag]`` in the
        enclosing scope.

        Args:
            steps: Number of steps passed to the trainer.
            tag: Key under which this run's metrics are accumulated.
            checkpoint_dir: When not ``None``, a checkpoint workload with this
                directory is yielded before the terminate workload.
        """
        runner = TrainAndValidate()
        yield from runner.send(steps, validation_freq=1, batches_per_step=100)

        new_training, new_validation = runner.result()
        training_metrics[tag] += new_training
        validation_metrics[tag] += new_validation

        if checkpoint_dir is not None:
            checkpoint_args = [checkpoint_dir]
            yield workload.checkpoint_workload(), checkpoint_args, workload.ignore_workload_response

        yield workload.terminate_workload(), [], workload.ignore_workload_response
Ejemplo n.º 15
0
def _make_test_workloads(config: det.ExperimentConfig) -> workload.Stream:
    """Yield one training, one validation and one checkpoint workload, logging progress.

    Args:
        config: Supplies ``scheduling_unit()``, which is checked against the
            number of batch metrics returned by the training workload.
    """
    responses = workload.WorkloadResponseInterceptor()

    logging.info("Training one batch")
    yield from responses.send(workload.train_workload(1))
    batch_metrics = responses.metrics_result()["metrics"]["batch_metrics"]
    # One metrics entry is expected per batch in the scheduling unit.
    check.eq(len(batch_metrics), config.scheduling_unit())
    logging.info(f"Finished training, metrics: {batch_metrics}")

    logging.info("Validating one batch")
    yield from responses.send(workload.validation_workload(1))
    v_metrics = responses.metrics_result()["metrics"]["validation_metrics"]
    logging.info(f"Finished validating, validation metrics: {v_metrics}")

    logging.info("Saving a checkpoint.")
    yield workload.checkpoint_workload(), workload.ignore_workload_response
    logging.info("Finished saving a checkpoint.")
Ejemplo n.º 16
0
def _make_test_workloads(checkpoint_dir: pathlib.Path,
                         config: det.ExperimentConfig) -> workload.Stream:
    """Yield the workloads of a minimal test experiment, logging progress.

    The sequence is: one training workload, one validation workload, a
    checkpoint written to ``checkpoint_dir``, then a terminate workload.

    Args:
        checkpoint_dir: Passed as the argument of the checkpoint workload.
        config: Supplies ``scheduling_unit()``, which is checked against the
            number of batch metrics returned by the training workload.
    """
    responses = workload.WorkloadResponseInterceptor()

    logging.info("Training one batch")
    yield from responses.send(workload.train_workload(1), [])
    batch_metrics = responses.metrics_result()["metrics"]["batch_metrics"]
    # One metrics entry is expected per batch in the scheduling unit.
    check.eq(len(batch_metrics), config.scheduling_unit())
    logging.debug(f"Finished training, metrics: {batch_metrics}")

    logging.info("Validating one step")
    yield from responses.send(workload.validation_workload(1), [])
    v_metrics = responses.metrics_result()["metrics"]["validation_metrics"]
    logging.debug(f"Finished validating, validation metrics: {v_metrics}")

    logging.info(f"Saving a checkpoint to {checkpoint_dir}.")
    yield workload.checkpoint_workload(), [checkpoint_dir], workload.ignore_workload_response
    logging.info(f"Finished saving a checkpoint to {checkpoint_dir}.")

    yield workload.terminate_workload(), [], workload.ignore_workload_response
    logging.info("The test experiment passed.")
Ejemplo n.º 17
0
 def make_workloads() -> workload.Stream:
     """Yield a 100-batch training workload followed by a checkpoint workload.

     The training response is ignored; the checkpoint response is handed to
     ``checkpoint_response_func`` from the enclosing scope.
     """
     train_step = workload.train_workload(1, num_batches=100)
     yield train_step, [], workload.ignore_workload_response

     checkpoint_step = workload.checkpoint_workload()
     yield checkpoint_step, [], checkpoint_response_func