Example #1
    def _control_loop(self) -> None:
        assert self.workloads is not None
        for wkld, response_func in self.workloads:
            logging.debug(f"Received wkld {wkld.kind}.")

            try:
                if wkld.kind == workload.Workload.Kind.RUN_STEP:
                    # Configure the state for a training step.
                    self.train_response_func = response_func
                    self.train_workload_batches = 0
                    self.train_workload_inputs = 0
                    self.train_workload_metrics = []
                    self.train_workload_len = wkld.num_batches
                    self.multiplexer.set_batches_requested(wkld.num_batches)
                    return

                elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    action = "validation"
                    response = {
                        "metrics": self._compute_validation_metrics(),
                        "stop_requested": self.context.get_stop_requested(),
                    }  # type: workload.Response

                elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    action = "checkpointing"
                    if self.is_chief:
                        metadata = {
                            "determined_version": det.__version__,
                            "steps_completed": self.steps_completed,
                            "framework": f"tensorflow-{tf.__version__}",
                            "format": "saved_weights",
                        }
                        with self.context._core.checkpoint.store_path(metadata) as (
                            path,
                            storage_id,
                        ):
                            self._save_checkpoint(path)
                        response = {"uuid": storage_id}
                    else:
                        response = {}

                else:
                    raise AssertionError(f"Unknown workload kind {wkld.kind}.")

            except det.InvalidHP as e:
                logging.info(f"Invalid hyperparameter exception during {action}: {e}")
                response = workload.InvalidHP()
            response_func(response)
            self.upload_tb_files()

        # End-of-training.
        self.multiplexer._corrected_train_end()
        raise det.errors.WorkerFinishedGracefully()
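All of these loops share one dispatch skeleton: drain (workload, response_func) pairs, branch on the workload kind, answer exactly once through response_func, and convert det.InvalidHP into an error response. Below is a minimal, self-contained sketch of that skeleton; the Kind enum, InvalidHP class, and placeholder metric values are illustrative stand-ins, not Determined's real APIs.

import enum
import logging
from typing import Any, Callable, Dict, Iterator, Tuple


class Kind(enum.Enum):
    # Stand-in for workload.Workload.Kind in the snippet above.
    RUN_STEP = 1
    COMPUTE_VALIDATION_METRICS = 2
    CHECKPOINT_MODEL = 3


class InvalidHP(Exception):
    # Stand-in for det.InvalidHP.
    pass


Response = Dict[str, Any]


def control_loop(workloads: Iterator[Tuple[Kind, Callable[[Response], None]]]) -> None:
    for kind, response_func in workloads:
        action = "unknown"
        try:
            if kind == Kind.RUN_STEP:
                action = "training"
                response: Response = {"metrics": {"batches": 1}}  # placeholder
            elif kind == Kind.COMPUTE_VALIDATION_METRICS:
                action = "validation"
                response = {"metrics": {"loss": 0.0}}  # placeholder
            elif kind == Kind.CHECKPOINT_MODEL:
                action = "checkpointing"
                response = {"uuid": "placeholder-storage-id"}
            else:
                raise AssertionError(f"Unknown workload kind {kind}.")
        except InvalidHP as e:
            # A rejected hyperparameter still produces a response, so the
            # caller is answered exactly once per workload.
            logging.info(f"Invalid hyperparameter exception during {action}: {e}")
            response = {"invalid_hp": True}
        response_func(response)


control_loop(iter([(Kind.COMPUTE_VALIDATION_METRICS, print)]))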
Example #2
    def _run(self) -> None:
        assert self.workloads is not None
        for w, response_func in self.workloads:
            try:
                if w.kind == workload.Workload.Kind.RUN_STEP:
                    action = "training"
                    response = {
                        "metrics": self._train_for_step(
                            w.step_id,
                            w.num_batches,
                            w.total_batches_processed,
                        ),
                        "stop_requested": self.context.get_stop_requested(),
                    }  # type: workload.Response

                elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    action = "validation"
                    response = {
                        "metrics": self._compute_validation_metrics(),
                        "stop_requested": self.context.get_stop_requested(),
                    }

                elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    action = "checkpointing"
                    if self.is_chief:
                        metadata = {
                            "steps_completed": self.steps_completed,
                            "framework": f"torch-{torch.__version__}",
                            "format": "pickle",
                        }
                        with self.context._core.checkpoint.store_path(metadata) as (
                            path,
                            storage_id,
                        ):
                            self._save(path)
                        response = {"uuid": storage_id}
                    else:
                        response = {}

                else:
                    raise AssertionError("Unexpected workload: {}".format(
                        w.kind))

            except det.InvalidHP as e:
                logging.info(f"Invalid hyperparameter exception during {action}: {e}")
                response = workload.InvalidHP()
            response_func(response)
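In the checkpoint branches above, only the chief enters store_path(metadata), writes files under the yielded path, and reports {"uuid": storage_id}, while every other rank answers {}. The store_path below is a hypothetical mock that imitates that yield-(path, storage_id) contract for illustration; it is not Determined's implementation.

import contextlib
import json
import pathlib
import tempfile
import uuid
from typing import Any, Dict, Iterator, Tuple


@contextlib.contextmanager
def store_path(metadata: Dict[str, Any]) -> Iterator[Tuple[pathlib.Path, str]]:
    # Mock of context._core.checkpoint.store_path: yield a writable directory
    # plus a storage id; on exit, a real implementation would upload the
    # directory's contents and record the metadata.
    storage_id = str(uuid.uuid4())
    with tempfile.TemporaryDirectory() as d:
        path = pathlib.Path(d)
        yield path, storage_id
        (path / "metadata.json").write_text(json.dumps(metadata))


is_chief = True  # assumption: this process is the chief rank
if is_chief:
    with store_path({"format": "pickle"}) as (path, storage_id):
        (path / "weights.pkl").write_bytes(b"")  # stand-in for self._save(path)
    response: Dict[str, Any] = {"uuid": storage_id}
else:
    response = {}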
Example #3
    def control_loop(self) -> None:
        core = self.estimator_trial_controller.context._core

        assert self.estimator_trial_controller.workloads is not None
        for wkld, response_func in self.estimator_trial_controller.workloads:
            logging.debug(f"Received wkld {wkld.kind}.")

            try:
                if wkld.kind == workload.Workload.Kind.RUN_STEP:
                    # Store values for the training loop.
                    self.num_batches = wkld.num_batches
                    self.train_response_func = response_func
                    # Break out of the control loop so that the train process
                    # re-enters the train_and_evaluate() loop.
                    return

                elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    action = "validation"
                    response = {
                        "metrics": self._compute_validation_metrics(),
                        "stop_requested": (
                            self.estimator_trial_controller.context.get_stop_requested()
                        ),
                    }  # type: workload.Response

                elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    action = "checkpointing"
                    self._save_model()
                    if self.estimator_trial_controller.is_chief:
                        metadata = {
                            "steps_completed": self.steps_completed,
                            "framework": f"tensorflow-{tf.__version__}",
                            "format": "saved_model",
                        }
                        with core.checkpoint.store_path(metadata) as (path, storage_id):
                            self._checkpoint_model(path)
                        response = {"uuid": storage_id}
                    else:
                        response = {}

                else:
                    raise AssertionError(f"Unknown wkld kind {wkld.kind}.")

            except det.InvalidHP as e:
                logging.info(f"Invalid hyperparameter exception during {action}: {e}")
                response = workload.InvalidHP()

            response_func(response)

        # End-of-training.
        raise det.errors.WorkerFinishedGracefully("Exiting normally.")
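This control loop, like Example #1's, does not train inline: RUN_STEP stashes num_batches and response_func and returns so the training process re-enters train_and_evaluate(), and the outer loop answers the stashed callback once the requested batches have run. Here is a self-contained sketch of that stash-and-return handoff; the Controller class and string workload kinds are illustrative, not the real types.

from typing import Callable, Dict, Iterator, Optional, Tuple

Workload = Tuple[str, int]  # (kind, num_batches)
ResponseFunc = Callable[[Dict], None]


class Controller:
    def __init__(self, workloads: Iterator[Tuple[Workload, ResponseFunc]]) -> None:
        self.workloads = workloads
        self.train_response_func: Optional[ResponseFunc] = None
        self.num_batches = 0

    def control_loop(self) -> None:
        for (kind, num_batches), response_func in self.workloads:
            if kind == "RUN_STEP":
                # Stash the callback and return so the training loop
                # regains control.
                self.num_batches = num_batches
                self.train_response_func = response_func
                return
            response_func({})  # answer non-training workloads inline

    def train(self) -> None:
        self.control_loop()
        while self.train_response_func is not None:
            metrics = {"batches_trained": self.num_batches}  # run the batches
            func, self.train_response_func = self.train_response_func, None
            func({"metrics": metrics})
            self.control_loop()  # fetch the next workload


Controller(iter([(("RUN_STEP", 100), print)])).train()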
Example #4
    def _run(self) -> None:
        assert self.workloads is not None
        for w, response_func in self.workloads:
            try:
                if w.kind == workload.Workload.Kind.RUN_STEP:
                    action = "training"
                    response = {
                        "metrics": self._train_for_step(
                            w.step_id,
                            w.num_batches,
                            w.total_batches_processed,
                        ),
                        "stop_requested": self.context.get_stop_requested(),
                    }  # type: workload.Response
                elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                    action = "validation"
                    response = {
                        "metrics": self._compute_validation_metrics(),
                        "stop_requested": self.context.get_stop_requested(),
                    }
                elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                    action = "checkpointing"
                    # The checkpointing API would have been sufficient if the
                    # storage manager's base_path were guaranteed to be a
                    # shared file system.
                    #
                    # Since we can't guarantee that, we use the base
                    # storage_manager instead for more flexibility.  Since
                    # checkpoints can be distributed across multiple nodes,
                    # every node saves under the same uuid and path string (a
                    # separate local directory on each node), but each node
                    # uploads its checkpoint files to the storage manager
                    # individually.
                    storage_manager = self.context._core.checkpoint._storage_manager
                    if self.is_chief:
                        metadata = {
                            "steps_completed": self.steps_completed,
                            "framework": f"torch-{torch.__version__}",
                            "format": "pickle",
                        }
                        storage_id = str(uuid.uuid4())
                        with storage_manager.store_path(storage_id) as path:
                            # Broadcast checkpoint path to all ranks.
                            self.context.distributed.broadcast((storage_id, path))
                            self._save(path)
                            # Gather resources across nodes.
                            all_resources = self.context.distributed.gather(
                                storage.StorageManager._list_directory(path)
                            )
                        resources = {k: v for d in all_resources for k, v in d.items()}

                        self.context._core.checkpoint._report_checkpoint(
                            storage_id, resources, metadata
                        )
                        response = {"uuid": storage_id}
                    else:
                        storage_id, path = self.context.distributed.broadcast(None)
                        self._save(path)
                        # Gather resources across nodes.
                        _ = self.context.distributed.gather(
                            storage.StorageManager._list_directory(path)
                        )
                        if self.context.distributed.local_rank == 0:
                            storage_manager.post_store_path(str(path), storage_id)
                        response = {}

                else:
                    raise AssertionError("Unexpected workload: {}".format(
                        w.kind))

            except det.InvalidHP as e:
                logging.info(f"Invalid hyperparameter exception during {action}: {e}")
                response = workload.InvalidHP()
            response_func(response)
            self.upload_tb_files()
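The comment block in this last example describes a three-step choreography: the chief picks a storage id and broadcasts (storage_id, path), every rank saves its shard under that path, and a gather merges the per-rank file listings into the resources dict passed to _report_checkpoint. Below is a single-process simulation of that flow; world_size, the shard file names, and the temp directory are made up for illustration.

import tempfile
import uuid
from pathlib import Path
from typing import Dict, List


def save_shard(path: Path, rank: int) -> Dict[str, int]:
    # Each rank writes its own shard and reports {filename: size}, playing the
    # roles of self._save(path) and StorageManager._list_directory(path).
    shard = path / f"shard-{rank}.pkl"
    shard.write_bytes(b"\x00" * (rank + 1))
    return {shard.name: shard.stat().st_size}


world_size = 4  # assumption: four ranks simulated in one process
storage_id = str(uuid.uuid4())  # chief's half of the broadcast payload
path = Path(tempfile.mkdtemp()) / storage_id  # stand-in for the store path
path.mkdir(parents=True)

# "Broadcast" then "gather": every rank saves its shard, and the merged
# listings are what the chief would report alongside the checkpoint metadata.
all_resources: List[Dict[str, int]] = [save_shard(path, r) for r in range(world_size)]
resources = {k: v for d in all_resources for k, v in d.items()}
print(storage_id, resources)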