def _control_loop(self) -> None:
    assert self.workloads is not None
    for wkld, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind}.")
        try:
            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Configure the state for a training step.
                self.train_response_func = response_func
                self.train_workload_batches = 0
                self.train_workload_inputs = 0
                self.train_workload_metrics = []
                self.train_workload_len = wkld.num_batches
                self.multiplexer.set_batches_requested(wkld.num_batches)
                return
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                action = "validation"
                response = {
                    "metrics": self._compute_validation_metrics(),
                    "stop_requested": self.context.get_stop_requested(),
                }  # type: workload.Response
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                action = "checkpointing"
                if self.is_chief:
                    metadata = {
                        "determined_version": det.__version__,
                        "steps_completed": self.steps_completed,
                        "framework": f"tensorflow-{tf.__version__}",
                        "format": "saved_weights",
                    }
                    with self.context._core.checkpoint.store_path(metadata) as (
                        path,
                        storage_id,
                    ):
                        self._save_checkpoint(path)
                    response = {"uuid": storage_id}
                else:
                    response = {}
            else:
                raise AssertionError(f"Unknown workload kind {wkld.kind}.")
        except det.InvalidHP as e:
            logging.info(f"Invalid hyperparameter exception during {action}: {e}")
            response = workload.InvalidHP()
        response_func(response)
        self.upload_tb_files()
    # End-of-training.
    self.multiplexer._corrected_train_end()
    raise det.errors.WorkerFinishedGracefully()
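All of these control loops share the same request/response shape: the harness yields (workload, response_func) pairs, the controller dispatches on the workload kind, builds a response, and hands it back through response_func. The following self-contained sketch illustrates only that pattern; WorkloadKind, make_workloads, and the handler bodies are hypothetical stand-ins, not the Determined API.

import enum
from typing import Any, Callable, Dict, Iterator, Tuple


class WorkloadKind(enum.Enum):
    # Hypothetical stand-in for workload.Workload.Kind used above.
    RUN_STEP = "RUN_STEP"
    COMPUTE_VALIDATION_METRICS = "COMPUTE_VALIDATION_METRICS"
    CHECKPOINT_MODEL = "CHECKPOINT_MODEL"


Response = Dict[str, Any]
ResponseFunc = Callable[[Response], None]


def make_workloads() -> Iterator[Tuple[WorkloadKind, ResponseFunc]]:
    # Toy schedule: one train step, one validation, one checkpoint, each
    # reporting its response by printing it.
    for kind in (
        WorkloadKind.RUN_STEP,
        WorkloadKind.COMPUTE_VALIDATION_METRICS,
        WorkloadKind.CHECKPOINT_MODEL,
    ):
        yield kind, lambda resp, k=kind: print(f"{k.name} -> {resp}")


def run(workloads: Iterator[Tuple[WorkloadKind, ResponseFunc]]) -> None:
    for kind, respond in workloads:
        # Dispatch on the workload kind, mirroring the if/elif chains above.
        if kind == WorkloadKind.RUN_STEP:
            response: Response = {"metrics": {"loss": 0.0}}
        elif kind == WorkloadKind.COMPUTE_VALIDATION_METRICS:
            response = {"metrics": {"accuracy": 1.0}}
        elif kind == WorkloadKind.CHECKPOINT_MODEL:
            response = {"uuid": "example-checkpoint-id"}
        else:
            raise AssertionError(f"Unknown workload kind {kind}.")
        respond(response)


if __name__ == "__main__":
    run(make_workloads())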
def _run(self) -> None:
    assert self.workloads is not None
    for w, response_func in self.workloads:
        try:
            if w.kind == workload.Workload.Kind.RUN_STEP:
                action = "training"
                response = {
                    "metrics": self._train_for_step(
                        w.step_id,
                        w.num_batches,
                        w.total_batches_processed,
                    ),
                    "stop_requested": self.context.get_stop_requested(),
                }  # type: workload.Response
            elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                action = "validation"
                response = {
                    "metrics": self._compute_validation_metrics(),
                    "stop_requested": self.context.get_stop_requested(),
                }
            elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                action = "checkpointing"
                if self.is_chief:
                    metadata = {
                        "steps_completed": self.steps_completed,
                        "framework": f"torch-{torch.__version__}",
                        "format": "pickle",
                    }
                    with self.context._core.checkpoint.store_path(metadata) as (
                        path,
                        storage_id,
                    ):
                        self._save(path)
                    response = {"uuid": storage_id}
                else:
                    response = {}
            else:
                raise AssertionError("Unexpected workload: {}".format(w.kind))
        except det.InvalidHP as e:
            logging.info(f"Invalid hyperparameter exception during {action}: {e}")
            response = workload.InvalidHP()
        response_func(response)
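Both loops above checkpoint through a context manager that yields a staging path and a storage id, with the chief writing its files there and returning {"uuid": storage_id}. Below is a minimal sketch of that shape, assuming a local directory as the backing store; ExampleCheckpointStore and every name in it are hypothetical stand-ins, not Determined's checkpoint API.

import contextlib
import json
import pathlib
import tempfile
import uuid
from typing import Any, Dict, Iterator, Tuple


class ExampleCheckpointStore:
    # Hypothetical stand-in for the store reached via context._core.checkpoint above.
    def __init__(self, base: pathlib.Path) -> None:
        self.base = base

    @contextlib.contextmanager
    def store_path(self, metadata: Dict[str, Any]) -> Iterator[Tuple[pathlib.Path, str]]:
        # Hand the caller a fresh staging directory together with its storage id.
        storage_id = str(uuid.uuid4())
        path = self.base / storage_id
        path.mkdir(parents=True)
        yield path, storage_id
        # On exit, persist the metadata next to whatever the caller wrote.
        (path / "metadata.json").write_text(json.dumps(metadata))


# Usage mirroring the chief branch: write weights into `path`, report the uuid.
store = ExampleCheckpointStore(pathlib.Path(tempfile.mkdtemp()))
with store.store_path({"steps_completed": 100, "format": "saved_weights"}) as (
    path,
    storage_id,
):
    (path / "weights.h5").write_bytes(b"\x00")  # stand-in for self._save_checkpoint(path)
response = {"uuid": storage_id}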
def control_loop(self) -> None:
    core = self.estimator_trial_controller.context._core
    assert self.estimator_trial_controller.workloads is not None
    for wkld, response_func in self.estimator_trial_controller.workloads:
        logging.debug(f"Received wkld {wkld.kind}.")
        try:
            if wkld.kind == workload.Workload.Kind.RUN_STEP:
                # Store values for the training loop.
                self.num_batches = wkld.num_batches
                self.train_response_func = response_func
                # Break out of the control loop so that the train process
                # re-enters the train_and_evaluate() loop.
                return
            elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                action = "validation"
                response = {
                    "metrics": self._compute_validation_metrics(),
                    "stop_requested": (
                        self.estimator_trial_controller.context.get_stop_requested()
                    ),
                }  # type: workload.Response
            elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                action = "checkpointing"
                self._save_model()
                if self.estimator_trial_controller.is_chief:
                    metadata = {
                        "steps_completed": self.steps_completed,
                        "framework": f"tensorflow-{tf.__version__}",
                        "format": "saved_model",
                    }
                    with core.checkpoint.store_path(metadata) as (path, storage_id):
                        self._checkpoint_model(path)
                    response = {"uuid": storage_id}
                else:
                    response = {}
            else:
                raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
        except det.InvalidHP as e:
            logging.info(f"Invalid hyperparameter exception during {action}: {e}")
            response = workload.InvalidHP()
        response_func(response)
    # End-of-training.
    raise det.errors.WorkerFinishedGracefully("Exiting normally.")
def _run(self) -> None:
    assert self.workloads is not None
    for w, response_func in self.workloads:
        try:
            if w.kind == workload.Workload.Kind.RUN_STEP:
                action = "training"
                response = {
                    "metrics": self._train_for_step(
                        w.step_id,
                        w.num_batches,
                        w.total_batches_processed,
                    ),
                    "stop_requested": self.context.get_stop_requested(),
                }  # type: workload.Response
            elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
                action = "validation"
                response = {
                    "metrics": self._compute_validation_metrics(),
                    "stop_requested": self.context.get_stop_requested(),
                }
            elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
                action = "checkpointing"
                # The checkpointing API would have been sufficient if the storage
                # manager's base_path were guaranteed to be a shared file system.
                #
                # Since we can't guarantee that, we use the base storage_manager
                # directly for more flexibility. Because checkpoints can be
                # distributed across multiple nodes, all ranks share the same uuid
                # but write to separate paths, and each node uploads its own files
                # to the storage manager individually.
                storage_manager = self.context._core.checkpoint._storage_manager
                if self.is_chief:
                    metadata = {
                        "steps_completed": self.steps_completed,
                        "framework": f"torch-{torch.__version__}",
                        "format": "pickle",
                    }
                    storage_id = str(uuid.uuid4())
                    with storage_manager.store_path(storage_id) as path:
                        # Broadcast the checkpoint path to all ranks.
                        self.context.distributed.broadcast((storage_id, path))
                        self._save(path)
                        # Gather resources across nodes.
                        all_resources = self.context.distributed.gather(
                            storage.StorageManager._list_directory(path)
                        )
                    resources = {k: v for d in all_resources for k, v in d.items()}
                    self.context._core.checkpoint._report_checkpoint(
                        storage_id, resources, metadata
                    )
                    response = {"uuid": storage_id}
                else:
                    storage_id, path = self.context.distributed.broadcast(None)
                    self._save(path)
                    # Gather resources across nodes.
                    _ = self.context.distributed.gather(
                        storage.StorageManager._list_directory(path)
                    )
                    if self.context.distributed.local_rank == 0:
                        storage_manager.post_store_path(str(path), storage_id)
                    response = {}
            else:
                raise AssertionError("Unexpected workload: {}".format(w.kind))
        except det.InvalidHP as e:
            logging.info(f"Invalid hyperparameter exception during {action}: {e}")
            response = workload.InvalidHP()
        response_func(response)
        self.upload_tb_files()
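In the sharded variant above, every rank lists the files it wrote and the chief flattens the gathered listings into a single resources dict before reporting the checkpoint. The sketch below illustrates just that merge step; the per-rank listings are made-up example data, not output from any real run.

from typing import Dict, List


def merge_resources(all_resources: List[Dict[str, int]]) -> Dict[str, int]:
    # Same dict comprehension the chief branch uses; later ranks win on key collisions.
    return {k: v for d in all_resources for k, v in d.items()}


# Made-up per-rank listings of {relative_path: size_in_bytes}.
per_rank_listings = [
    {"metadata.json": 128, "rank_0/model_shard.pt": 1_048_576},  # chief, rank 0
    {"rank_1/model_shard.pt": 1_048_576},  # worker, rank 1
]
print(merge_resources(per_rank_listings))
# {'metadata.json': 128, 'rank_0/model_shard.pt': 1048576, 'rank_1/model_shard.pt': 1048576}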