def run(self) -> None:
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Store the train_response_func for later.
            self.train_response_func = response_func

            # There are two possibilities when a RUN_STEP workload is received.
            # 1) This is the first training step seen by the trial
            #    container. In this case, enter the tf.keras fit() training loop.
            # 2) This is _not_ the first training step, meaning that the
            #    tf.keras fit() training loop is already active and paused.
            #    break to re-enter the training loop.
            if not self.fit_loop_started:
                try:
                    self._launch_fit()
                except det.errors.WorkerFinishedGracefully:
                    pass

                if not self.expect_terminate:
                    raise AssertionError(
                        "Training loop exited unexpectedly but without throwing any errors. "
                        "This is possibly due to a user callback causing the training loop to "
                        "exit, which is not supported at this time."
                    )
            break

        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self.compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            self.model.stop_training = True
            self.expect_terminate = True
            response_func({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}.")
def run(self) -> None:
    try:
        tf.estimator.train_and_evaluate(self.estimator, self.train_spec, self.eval_spec)
    except det.errors.WorkerFinishedGracefully:
        pass
    else:
        raise AssertionError(
            "Training loop exited unexpectedly but without throwing any errors. This is "
            "possibly due to either setting train_spec.max_steps to a non-None value or due to "
            "a user callback causing the training loop to exit, which is not supported at this "
            "time."
        )
    finally:
        for callback in self.train_hooks:
            if isinstance(callback, estimator.RunHook):
                callback.on_trial_close()

        if self.exit_response_func:
            self.exit_response_func({} if self.is_chief else workload.Skipped())
def yield_checkpoint_model(
    self, wkld: workload.Workload, respond: workload.ResponseFunc
) -> workload.Stream:
    start_time = _current_timestamp()

    # Only the chief container should checkpoint.
    if self.rendezvous_info.get_rank() != 0:
        respond(workload.Skipped())
        return

    # Save the workload completed message for after checkpoint upload completes.
    message = None  # type: Optional[workload.Response]

    def _respond(checkpoint_info: workload.Response) -> None:
        checkpoint_info = cast(Dict[str, Any], checkpoint_info)
        metadata = storage.StorageMetadata(
            storage_id,
            storage.StorageManager._list_directory(path),
            checkpoint_info.get("framework", ""),
            checkpoint_info.get("format", ""),
        )

        logging.info("Saved trial to checkpoint {}".format(metadata.storage_id))
        self.tensorboard_mgr.sync()

        nonlocal message
        message = {
            "type": "WORKLOAD_COMPLETED",
            "workload": wkld,
            "start_time": start_time,
            "end_time": _current_timestamp(),
            "metrics": metadata,
        }

    with self.storage_mgr.store_path() as (storage_id, path):
        yield wkld, [pathlib.Path(path)], _respond

    # Because the messaging is synchronous, the layer below us must have called _respond.
    check_not_none(message, "response function did not get called")
    message = cast(workload.Response, message)
    respond(message)
def on_train_batch_end(self, _: int, logs: Any = None) -> None:
    check.is_in("loss", logs)

    # Remove default keras metrics we aren't interested in like "batch" and "size".
    self.metrics.append({k: v for k, v in logs.items() if k not in {"batch", "size"}})
    self.batches_processed += 1
    if self.batches_processed != self.tf_keras_trial_controller.num_batches:
        return

    check.is_not_none(
        self.tf_keras_trial_controller.train_response_func,
        "Callback should avoid calling model.predict(), "
        "as this will affect Determined training behavior",
    )
    response_func = cast(
        workload.ResponseFunc, self.tf_keras_trial_controller.train_response_func
    )

    # TODO(DET-1278): Average training metrics across GPUs when using Horovod.
    num_inputs = (
        self.tf_keras_trial_controller.num_batches * self.tf_keras_trial_controller.batch_size
    )

    if self.tf_keras_trial_controller.is_chief:
        response = {
            "metrics": det.util.make_metrics(num_inputs, self.metrics),
            "stop_requested": self.tf_keras_trial_controller.context.get_stop_requested(),
        }
        response_func(response)
    else:
        response_func(workload.Skipped())

    self.tf_keras_trial_controller.train_response_func = None
    self.metrics = []
    self.batches_processed = 0

    self.tf_keras_trial_controller.run()

    if self.model.stop_training and version.parse(tf.__version__) >= version.parse("2.2.0"):
        # Starting with TF 2.2, `model.stop_training` is only checked at the end of epochs.
        raise det.errors.WorkerFinishedGracefully
def _control_loop(self) -> None:
    for wkld, args, response_func in self.workloads:
        logging.debug(f"Received wkld {wkld.kind} with args {args}.")
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Configure the state for a training step.
            self.train_response_func = response_func
            self.train_workload_batches = 0
            self.train_workload_metrics = []
            self.train_workload_len = wkld.num_batches
            self.multiplexer.set_batches_requested(wkld.num_batches)
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            try:
                response_func(
                    det.util.wrap_metrics(
                        self._compute_validation_metrics(),
                        self.context.get_stop_requested(),
                        invalid_hp=False,
                        init_invalid_hp=False,
                    )
                )
            except det.InvalidHP as e:
                logging.info(
                    "Invalid hyperparameter exception in trial validation step: {}".format(e)
                )
                response_func(
                    util.wrap_metrics(
                        {},
                        self.context.get_stop_requested(),
                        invalid_hp=True,
                        init_invalid_hp=False,
                    )
                )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            self.multiplexer._corrected_train_end()
            raise det.errors.WorkerFinishedGracefully
        else:
            raise AssertionError(f"Unknown workload kind {wkld.kind}.")
def compute_validation_metrics(self) -> workload.Response:
    metrics = self.estimator.evaluate(
        input_fn=self.eval_spec.input_fn, steps=self.eval_spec.steps, hooks=self.eval_spec.hooks
    )

    if self.hvd_config.use:
        metrics = self.average_metrics(metrics)
        if self.is_chief:
            logging.debug(f"Averaged validation metrics: {metrics}.")

    estimator._cleanup_after_validation_step(
        pathlib.Path(self.estimator._model_dir), self.is_chief
    )

    # Reset the per-evaluation set of allgather ops in the context.
    self.context.experimental._reset_allgather_ops()

    if not self.is_chief:
        return workload.Skipped()

    return {"validation_metrics": metrics}
def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
    # We assume that at least one training step has completed when saving a checkpoint.
    if not self.is_chief:
        return workload.Skipped()

    # Save training data iterator position.
    path.mkdir(parents=True, exist_ok=True)

    # Save model.
    self.model.save(path.joinpath("determined-keras-model.h5"), save_format="h5")

    det.util.write_checkpoint_metadata(
        path, self.env, {"tensorflow_version": tf.__version__, "format": "h5"}
    )

    return {}
def _control_loop(self) -> None:
    for wkld, args, response_func in self.workloads:
        if wkld.kind == workload.Workload.Kind.RUN_STEP:
            # Move on to the next step.
            self.train_response_func = response_func
            break
        elif wkld.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                det.util.wrap_metrics(
                    self._compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif wkld.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.len_eq(args, 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self.save_checkpoint(path))
        elif wkld.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            raise det.errors.WorkerFinishedGracefully("Exiting normally.")
        else:
            raise AssertionError(f"Unknown wkld kind {wkld.kind}")
def _checkpoint_model(self, checkpoint_path: pathlib.Path) -> workload.Response:
    self._save_model()

    if not self.estimator_trial_controller.is_chief:
        return workload.Skipped()

    self._copy_latest_checkpoint(checkpoint_path=checkpoint_path)
    self._save_serving_input_receiver_fns(checkpoint_path=str(checkpoint_path))

    det.util.write_checkpoint_metadata(
        checkpoint_path,
        self.estimator_trial_controller.env,
        {"tensorflow_version": tf.__version__, "format": "saved_model"},
    )

    for callback in self.estimator_trial_controller.train_hooks:
        if isinstance(callback, estimator.RunHook):
            callback.on_checkpoint_end(str(checkpoint_path))

    return {}
def _send_recv_workload(self, wkld: workload.Workload, args: List[Any]) -> workload.Response:
    # Broadcast every workload to every worker on this machine.
    self.broadcast_server.broadcast((wkld, args))

    try:
        responses, exception_received = self.broadcast_server.gather_with_polling(
            self._health_check
        )
    except det.errors.WorkerError:
        if wkld.kind == workload.Workload.Kind.TERMINATE:
            return {}
        raise

    if exception_received:
        raise det.errors.WorkerError("Training process died.")

    # Find the response from the chief worker for the trial (the only non-SkippedWorkload).
    # The chief may report to another container, in which case we will only have
    # SkippedWorkloads.
    chief_worker_response = None  # Optional[workload.Metrics]
    for response in responses:
        if isinstance(response, workload.Skipped):
            continue
        # Any other response must be a Dict[str, Any]-like object.
        check.is_instance(
            response, dict, f"Received non-metrics object from worker: {response}"
        )
        # There should only be one chief response.
        check.is_none(chief_worker_response, "Received multiple non-SkippedWorkload messages.")
        chief_worker_response = cast(Dict[str, Any], response)

    # Confirm that if we did not see a chief response, then we are not the chief machine.
    if chief_worker_response is None:
        check.gt(
            self.rendezvous_info.get_rank(),
            0,
            "Received SkippedWorkload message from chief worker.",
        )

    return workload.Skipped() if chief_worker_response is None else chief_worker_response
def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    # Save model weights. We use `tf` format because `h5` does not support
    # models that subclass `tf.keras.Model` and define custom `call()`
    # and/or `train_step()` functions.
    self.model.save_weights(
        str(path.joinpath("determined-keras-model-weights")), save_format="tf"
    )

    # Save optimizer(s) weights.
    with h5py.File(path.joinpath("determined-keras-optimizer-weights.h5"), "w") as h5file:
        for idx, optimizer in enumerate(self.context._optimizers):
            opt_group = h5file.create_group(f"optimizer-{idx}")
            save_optimizer_weights_to_hdf5_group(opt_group, optimizer)

    # Save RNG state.
    rng_state = {"np_rng_state": np.random.get_state(), "random_rng_state": random.getstate()}
    if version.parse(tf.__version__) < version.parse("2.0.0"):
        rng_state["tf_rng_global_seed"] = tf.compat.v1.random.get_seed(0)[0]
    else:
        generator = tf.random.get_global_generator()
        rng_state["tf2_rng_global_algorithm"] = generator.algorithm
        rng_state["tf2_rng_global_state"] = generator.state
    with open(path.joinpath("rng_state.pkl"), "wb") as f:
        pickle.dump(rng_state, f)

    # Save user code.
    det.util.write_user_code(path, self.env.on_cluster)

    # Save callback(s) state.
    callbacks_state = self.multiplexer._get_state()
    with path.joinpath("determined-callbacks.v1.pkl").open("wb") as f:
        pickle.dump(callbacks_state, f)

    self.multiplexer._checkpoint_end(path)

    return {"framework": f"tensorflow-{tf.__version__}", "format": "saved_weights"}
def _save(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    # The model code is the current working directory.
    util.write_checkpoint_metadata(
        path,
        self.env,
        {"torch_version": torch.__version__},  # type: ignore
    )

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint = {
        "model_state_dict": self.context.model.state_dict(),
        "optimizer_state_dict": self.context.optimizer.state_dict(),
    }

    if self.lr_helper:
        checkpoint["lr_scheduler"] = self.lr_helper.state_dict()

    for name, callback in self.callbacks.items():
        checkpoint.setdefault("callbacks", {})
        checkpoint["callbacks"][name] = callback.state_dict()

    torch.save(  # type: ignore
        checkpoint, str(path.joinpath("state_dict.pth")), pickle_module=cloudpickle
    )

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))

    return {}
def _trigger_epoch(self) -> None:
    """
    This runs at the end of each training step, sends the metrics back to the
    main process, and decides what to do next.
    """
    check.is_not_none(self.train_response_func, "no response_func at end of train_for_step")
    self.train_response_func = cast(workload.ResponseFunc, self.train_response_func)

    if self.is_chief:
        self.train_response_func(det.util.make_metrics(None, self.batch_metrics))
    else:
        self.train_response_func(workload.Skipped())

    self.train_response_func = None
    self.batch_metrics = []

    self._control_loop()
def _respond(in_response: workload.Response) -> None:
    # Only the chief container should actually respond to TRAIN_FOR_STEP.
    if self.rendezvous_info.get_rank() != 0:
        respond(workload.Skipped())
        return

    check_not_isinstance(in_response, workload.Skipped, "Chief skipped a workload.")
    in_response = cast(workload.Metrics, in_response)
    metrics = in_response["metrics"]
    metrics = cast(workload.Metrics, metrics)
    batch_metrics = metrics["batch_metrics"]

    # Sanity-check training metrics.
    det.util.validate_batch_metrics(batch_metrics)
    check_len(batch_metrics, wkld.num_batches)

    for callback in self.callbacks:
        callback.on_train_step_end(
            wkld.step_id, wkld.num_batches, wkld.total_batches_processed, batch_metrics
        )

    self.tensorboard_mgr.sync()

    out_response = {
        "type": "WORKLOAD_COMPLETED",
        "workload": wkld,
        "start_time": start_time,
        "end_time": _current_timestamp(),
        "metrics": metrics,
    }

    if in_response.get("stop_requested", False):
        out_response["exited_reason"] = "USER_CANCELED"

    # Send the response up.
    respond(out_response)
def _save(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    # The model code is the current working directory.
    util.write_user_code(path)

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint = {
        "models_state_dict": [model.state_dict() for model in self.context.models],
        "optimizers_state_dict": [
            optimizer.state_dict() for optimizer in self.context.optimizers
        ],
        "lr_schedulers_state_dict": [
            lr_scheduler.state_dict() for lr_scheduler in self.context.lr_schedulers
        ],
        "callbacks": {
            name: callback.state_dict() for name, callback in self.callbacks.items()
        },
    }

    torch.save(  # type: ignore
        checkpoint, str(path.joinpath("state_dict.pth")), pickle_module=cloudpickle
    )

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))

    return cast(
        workload.Response,
        {
            "framework": f"torch-{torch.__version__}",  # type: ignore
            "format": "cloudpickle",
        },
    )
def run(self) -> None:
    for w, args, response_func in self.workloads:
        if w.kind == workload.Workload.Kind.RUN_STEP:
            response_func(
                util.wrap_metrics(
                    self._train_for_step(w.step_id, w.num_batches, w.total_batches_processed),
                    self.context.get_stop_requested(),
                )
            )
        elif w.kind == workload.Workload.Kind.COMPUTE_VALIDATION_METRICS:
            response_func(
                util.wrap_metrics(
                    self._compute_validation_metrics(), self.context.get_stop_requested()
                )
            )
        elif w.kind == workload.Workload.Kind.CHECKPOINT_MODEL:
            check.eq(len(args), 1)
            check.is_instance(args[0], pathlib.Path)
            path = cast(pathlib.Path, args[0])
            response_func(self._save(path))
        elif w.kind == workload.Workload.Kind.TERMINATE:
            response_func({} if self.is_chief else workload.Skipped())
            break
        else:
            raise AssertionError("Unexpected workload: {}".format(w.kind))
def yield_checkpoint_model(
    self, wkld: workload.Workload, respond: workload.ResponseFunc
) -> workload.Stream:
    start_time = _current_timestamp()

    # Only the chief container should checkpoint.
    if self.rendezvous_info.get_rank() == 0:
        with self.storage_mgr.store_path() as (storage_id, path):

            def _respond(checkpoint_info: workload.Response) -> None:
                checkpoint_info = cast(Dict[str, Any], checkpoint_info)
                metadata = storage.StorageMetadata(
                    storage_id,
                    storage.StorageManager._list_directory(path),
                    checkpoint_info.get("framework", ""),
                    checkpoint_info.get("format", ""),
                )

                logging.info("Saved trial to checkpoint {}".format(metadata.storage_id))
                self.tensorboard_mgr.sync()

                message: workload.Response = {
                    "type": "WORKLOAD_COMPLETED",
                    "workload": wkld,
                    "start_time": start_time,
                    "end_time": _current_timestamp(),
                    "metrics": metadata,
                }
                respond(message)

            yield wkld, [pathlib.Path(path)], _respond
    else:
        respond(workload.Skipped())
def _save_checkpoint(self, path: pathlib.Path) -> workload.Response:
    # We assume that at least one training step has completed when saving a checkpoint.
    if not self.is_chief:
        return workload.Skipped()

    # Save training data iterator position.
    path.mkdir(parents=True, exist_ok=True)

    # Save model.
    # tf.keras.models.save_model(
    #     self.model, path.joinpath("determined-keras-model"), overwrite=True,
    #     include_optimizer=True, save_format="tf", signatures=None, options=None,
    # )
    tf.saved_model.save(self.model, str(path.joinpath("determined-keras-model")))

    det.util.write_checkpoint_metadata(
        path, self.env, {"tensorflow_version": tf.__version__, "format": "tf"}
    )

    return {}
def _compute_validation_metrics(self) -> workload.Response:
    metrics = self._launch_evaluate()
    num_inputs = self.multiplexer.get_test_inputs()

    if self.hvd_config.use:
        # Use a global ZMQ barrier here because we have observed cases where hvd.allreduce
        # may hang when called minutes apart by different workers, which may happen if
        # workers complete evaluation at different speeds.
        self._global_barrier()

        num_inputs = hvd.allreduce(num_inputs, average=False, name="validation_num_inputs")
        if isinstance(num_inputs, EagerTensor):
            # Horovod will promote an int to a tensor in eager mode.
            num_inputs = num_inputs.numpy()

    metrics = self._allreduce_logs(metrics)
    check.gt(len(metrics), 0)

    self.multiplexer._test_end(metrics)

    if not self.is_chief:
        return workload.Skipped()

    return {"num_inputs": num_inputs, "validation_metrics": metrics}
def _compute_validation_metrics(self) -> workload.Response:
    self.context.experimental.reset_reducers()
    # Set the behavior of certain layers (e.g., dropout) that are
    # different between training and inference.
    for model in self.context.models:
        model.eval()

    for callback in self.callbacks.values():
        logging.warning(
            "on_validation_step_start is now deprecated, please use on_validation_start instead"
        )
        callback.on_validation_step_start()

    for callback in self.callbacks.values():
        callback.on_validation_start()

    num_inputs = 0
    metrics = {}  # type: Dict[str, Any]

    if self._evaluate_batch_defined():
        keys = None
        batch_metrics = []

        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        check.gt(len(self.validation_loader), 0)
        for batch in self.validation_loader:
            batch = self.context.to_device(batch)
            num_inputs += pytorch.data_length(batch)

            vld_metrics = self.trial.evaluate_batch(batch=batch)
            # Verify validation metric names are the same across batches.
            if keys is None:
                keys = vld_metrics.keys()
            else:
                check.eq(
                    keys,
                    vld_metrics.keys(),
                    "Validation metric names must match across all batches of data.",
                )
            check.is_instance(
                vld_metrics,
                dict,
                "validation_metrics() must return a "
                "dictionary of string names to Tensor "
                "metrics",
            )
            # TODO: For performance perform -> cpu() only at the end of validation.
            batch_metrics.append(self._convert_metrics_to_numpy(vld_metrics))
            if self.env.test_mode:
                break

        metrics = self._reduce_metrics(
            batch_metrics=batch_metrics,
            keys=keys,
            metrics_reducers=self._prepare_metrics_reducers(keys=keys),
        )

        if self.hvd_config.use:
            num_inputs *= hvd.size()

    else:
        check.true(self._evaluate_full_dataset_defined())
        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        if self.is_chief:
            metrics = self.trial.evaluate_full_dataset(data_loader=self.validation_loader)

            check.is_instance(
                metrics, dict, f"eval() must return a dictionary, got {type(metrics)}."
            )

            metrics = self._convert_metrics_to_numpy(metrics)
            num_inputs = self.context.get_per_slot_batch_size() * len(self.validation_loader)

    metrics.update(
        self._convert_metrics_to_numpy(
            self.context.experimental.reduce_metrics(for_training=False)
        )
    )

    if self.hvd_config.use and any(
        map(
            lambda c: util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
            or util.is_overridden(c.on_validation_step_end, pytorch.PyTorchCallback),
            self.callbacks.values(),
        )
    ):
        logging.debug(
            "Broadcasting metrics to all worker processes to execute a "
            "validation step end callback"
        )
        metrics = hvd.broadcast_object(metrics, root_rank=0)

    for callback in self.callbacks.values():
        logging.warning(
            "on_validation_step_end is now deprecated, please use on_validation_end instead"
        )
        callback.on_validation_step_end(metrics)

    for callback in self.callbacks.values():
        callback.on_validation_end(metrics)

    if not self.is_chief:
        return workload.Skipped()

    return {"num_inputs": num_inputs, "validation_metrics": metrics}
def _train_for_step(
    self, step_id: int, num_batches: int, total_batches_processed: int
) -> workload.Response:
    check.gt(step_id, 0)

    self.context.experimental.reset_reducers()

    # Set the behavior of certain layers (e.g., dropout) that are different
    # between training and inference.
    for model in self.context.models:
        model.train()

    start = total_batches_processed
    end = start + num_batches

    per_batch_metrics = []  # type: List[Dict]
    num_inputs = 0

    for batch_idx in range(start, end):
        batch = next(self.training_iterator)

        ## old code:
        # num_inputs += pytorch.data_length(batch)
        # batch = self.context.to_device(batch)
        num_inputs += self.trial._records_in_batch(batch)
        batch = self.trial._batch_to_device(batch, self.context)

        self.context._current_batch_idx = batch_idx
        self.context._loss_ids = {}
        tr_metrics = self.trial.train_batch(
            batch=batch,
            epoch_idx=self.get_epoch_idx(batch_idx),
            batch_idx=batch_idx,
        )
        if isinstance(tr_metrics, torch.Tensor):
            tr_metrics = {"loss": tr_metrics}
        check.is_instance(
            tr_metrics,
            dict,
            "train_batch() must return a dictionary "
            f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
        )

        # Step learning rate of a pytorch.LRScheduler.
        for lr_scheduler in self.context.lr_schedulers:
            self._auto_step_lr_scheduler_per_batch(batch_idx, lr_scheduler)

        for name, metric in tr_metrics.items():
            # Convert PyTorch metric values to NumPy, so that
            # `det.util.encode_json` handles them properly without
            # needing a dependency on PyTorch.
            if isinstance(metric, torch.Tensor):
                metric = metric.cpu().detach().numpy()
            tr_metrics[name] = metric

        per_batch_metrics.append(tr_metrics)

    # Aggregate and reduce training metrics from all the training processes.
    if self.hvd_config.use and self.hvd_config.average_training_metrics:
        per_batch_metrics = self._average_training_metrics(per_batch_metrics)
    if self.hvd_config.use:
        num_inputs *= hvd.size()
    metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

    # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch
    # metrics are even logical for a custom reducer.
    metrics["avg_metrics"].update(
        self._convert_metrics_to_numpy(
            self.context.experimental.reduce_metrics(for_training=True)
        )
    )

    if not self.is_chief:
        # The training metrics are reported only in the chief process.
        return workload.Skipped()

    logging.debug(f"Done training step: {num_inputs} records in {num_batches} batches.")

    return metrics
def _save(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    util.write_user_code(path, self.env.on_cluster)

    rng_state = {
        "cpu_rng_state": torch.random.get_rng_state(),  # type: ignore
        "np_rng_state": np.random.get_state(),
        "random_rng_state": random.getstate(),
    }

    if torch.cuda.device_count():
        rng_state["gpu_rng_state"] = torch.cuda.get_rng_state(  # type: ignore
            self.context.distributed.get_local_rank()
        )

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint = {
        "models_state_dict": [model.state_dict() for model in self.context.models],
        "optimizers_state_dict": [
            optimizer.state_dict() for optimizer in self.context.optimizers
        ],
        "lr_schedulers_state_dict": [
            lr_scheduler.state_dict() for lr_scheduler in self.context.lr_schedulers
        ],
        "callbacks": {
            name: callback.state_dict() for name, callback in self.callbacks.items()
        },
        "rng_state": rng_state,
    }

    if self.context._scaler:
        checkpoint["scaler_state_dict"] = self.context._scaler.state_dict()

    if self.context._use_apex:
        checkpoint["amp_state"] = apex.amp.state_dict()

    torch.save(  # type: ignore
        checkpoint, str(path.joinpath("state_dict.pth")), pickle_module=cloudpickle
    )

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))

    return cast(
        workload.Response,
        {
            "framework": f"torch-{torch.__version__}",  # type: ignore
            "format": "cloudpickle",
        },
    )
def _respond(_: workload.Response) -> None:
    respond(workload.Skipped())
def _respond(in_response: workload.Response) -> None:
    # Only the chief container should actually respond to COMPUTE_VALIDATION_METRICS.
    if self.rendezvous_info.get_rank() != 0:
        respond(workload.Skipped())
        return

    check_not_isinstance(in_response, workload.Skipped, "Chief skipped a workload.")
    in_response = cast(Dict[str, Any], in_response)
    metrics = in_response["metrics"]
    metrics = cast(workload.Metrics, metrics)
    v_metrics = metrics["validation_metrics"]

    for callback in self.callbacks:
        callback.on_validation_step_end(wkld.step_id, wkld.total_batches_processed, v_metrics)

    self.tensorboard_mgr.sync()

    # Check that the validation metrics computed by the model code
    # include the metric used by the search method.
    searcher_metric = self.env.experiment_config["searcher"]["metric"]
    if searcher_metric not in v_metrics:
        raise AssertionError(
            "Search method is configured to use metric '{}' but model "
            "definition returned validation metrics {}. The metric "
            "used by the search method must be one of the validation "
            "metrics returned by the model definition.".format(
                searcher_metric, list(v_metrics.keys())
            )
        )

    non_serializable_metrics = set()
    # NaN and bytes are not JSON serializable. None does not have a
    # canonical JSON representation. In the case of trial implementation bugs
    # or numerical instability issues, validation metric functions may
    # return None or NaN values. For now, immediately fail any trial that
    # encounters such a None metric. For NaN metrics, if it's the target of
    # the searcher, we set it to +/- max_float depending on whether the searcher
    # is optimizing for the max or min. NaN metrics which are not the
    # target of the searcher are dropped.
    # TODO (DET-2495): Do not replace NaN metric values.
    for metric_name, metric_value in v_metrics.items():
        metric_is_none = metric_value is None
        metric_is_nan = tensorboard.metric_writers.util.is_numerical_scalar(
            metric_value
        ) and math.isnan(metric_value)

        if metric_is_none or metric_is_nan:
            raise AssertionError(
                "Validation metric '{}' returned an invalid scalar value: {}".format(
                    metric_name, metric_value
                )
            )

        if isinstance(metric_value, (bytes, bytearray)):
            non_serializable_metrics.add(metric_name)

    if len(non_serializable_metrics):
        logging.warning(
            "Removed non serializable metrics: %s", ", ".join(non_serializable_metrics)
        )
        for metric_name in non_serializable_metrics:
            del v_metrics[metric_name]

    out_response = {
        "type": "WORKLOAD_COMPLETED",
        "workload": wkld,
        "start_time": start_time,
        "end_time": _current_timestamp(),
        "metrics": metrics,
    }

    if in_response.get("stop_requested", False):
        out_response["exited_reason"] = "USER_CANCELED"

    respond(out_response)
def _train_for_step(self, step_id: int, batches_per_step: int) -> workload.Response:
    check.gt(step_id, 0)

    # Set the behavior of certain layers (e.g., dropout) that are different
    # between training and inference.
    self.context.model.train()

    for callback in self.callbacks.values():
        callback.on_train_step_start(step_id)

    step_idx = step_id - 1
    start = step_idx * batches_per_step
    end = start + batches_per_step

    per_batch_metrics = []  # type: List[Dict]
    num_inputs = 0

    for batch_idx in range(start, end):
        batch = next(self.training_iterator)
        num_inputs += data_length(batch)

        batch = self._to_device(batch)
        # Forward pass.
        tr_metrics = self.trial.train_batch(
            batch=batch,
            model=self.context.model,
            epoch_idx=self.get_epoch_idx(batch_idx),
            batch_idx=batch_idx,
        )

        if isinstance(tr_metrics, torch.Tensor):
            tr_metrics = {"loss": tr_metrics}
        check.is_instance(
            tr_metrics,
            dict,
            "train_batch() must return a dictionary "
            f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
        )
        check.is_in("loss", tr_metrics.keys(), 'Please include "loss" in your training metrics.')

        # Backwards pass.
        loss = tr_metrics["loss"]
        communicate_and_update = (batch_idx + 1) % self.hvd_config.aggregation_frequency == 0
        if self.use_amp():
            with apex.amp.scale_loss(loss, self.context.optimizer) as scaled_loss:
                scaled_loss.backward()
                if self.hvd_config.use and communicate_and_update:
                    # When using horovod, we need to finish communicating gradient
                    # updates before they are unscaled, which happens when we exit
                    # this context manager.
                    self.context.optimizer.synchronize()
        else:
            loss.backward()

            # Communication needs to be synchronized so that it is completed
            # before we apply gradient clipping and `step()`.
            if communicate_and_update and self.hvd_config.use:
                self.context.optimizer.synchronize()

        if communicate_and_update:
            parameters = (
                self.context.model.parameters()
                if not self.use_amp()
                else apex.amp.master_params(self.context.optimizer)
            )

            if self.hvd_config.average_aggregated_gradients:
                self._average_gradients(
                    parameters=parameters, divisor=self.hvd_config.aggregation_frequency
                )

            # TODO: Remove this check in v0.12.8.
            check.false(
                self.env.hparams.get("clip_grad_l2_norm", None)
                or self.env.hparams.get("clip_grad_val", None),
                "Please specify gradient clipping via callbacks.",
            )

            for callback in self.callbacks.values():
                callback.on_before_optimizer_step(parameters)

            if self.hvd_config.use:
                with self.context.optimizer.skip_synchronize():
                    self.context.optimizer.step()
            else:
                self.context.optimizer.step()
            self.context.optimizer.zero_grad()

        # Step learning rate of a LRScheduler.
        if self.context.lr_scheduler is not None:
            self._auto_step_lr_scheduler_per_batch(batch_idx, self.context.lr_scheduler)

        for name, metric in tr_metrics.items():
            # Convert PyTorch metric values to NumPy, so that
            # `det.util.encode_json` handles them properly without
            # needing a dependency on PyTorch.
            if isinstance(metric, torch.Tensor):
                metric = metric.cpu().detach().numpy()
            tr_metrics[name] = metric

        check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.')
        per_batch_metrics.append(tr_metrics)

    if self.hvd_config.use and self.hvd_config.average_training_metrics:
        per_batch_metrics = self._average_training_metrics(per_batch_metrics)

    if self.hvd_config.use:
        num_inputs *= hvd.size()

    metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

    for callback in self.callbacks.values():
        callback.on_train_step_end(step_id, metrics)

    if not self.is_chief:
        return workload.Skipped()

    logging.debug(f"Done training step: {num_inputs} records in {batches_per_step} batches.")

    return metrics
def _train_for_step(
    self, step_id: int, num_batches: int, total_batches_processed: int
) -> workload.Response:
    check.gt(step_id, 0)

    # Set the behavior of certain layers (e.g., dropout) that are different
    # between training and inference.
    for model in self.context.models:
        model.train()

    start = total_batches_processed
    end = start + num_batches

    per_batch_metrics = []  # type: List[Dict]
    num_inputs = 0

    for batch_idx in range(start, end):
        batch = next(self.training_iterator)
        num_inputs += data_length(batch)

        batch = self.context._to_device(batch)
        self.context._current_batch_idx = batch_idx
        self.context._loss_ids = {}
        tr_metrics = self.trial.train_batch(
            batch=batch,
            model=self.context.models[0],
            epoch_idx=self.get_epoch_idx(batch_idx),
            batch_idx=batch_idx,
        )
        if isinstance(tr_metrics, torch.Tensor):
            tr_metrics = {"loss": tr_metrics}
        check.is_instance(
            tr_metrics,
            dict,
            "train_batch() must return a dictionary "
            f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
        )
        check.is_in("loss", tr_metrics.keys(), 'Please include "loss" in your training metrics.')

        # Step learning rate of a LRScheduler.
        for lr_scheduler in self.context.lr_schedulers:
            self._auto_step_lr_scheduler_per_batch(batch_idx, lr_scheduler)

        for name, metric in tr_metrics.items():
            # Convert PyTorch metric values to NumPy, so that
            # `det.util.encode_json` handles them properly without
            # needing a dependency on PyTorch.
            if isinstance(metric, torch.Tensor):
                metric = metric.cpu().detach().numpy()
            tr_metrics[name] = metric

        check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.')
        per_batch_metrics.append(tr_metrics)

    # Aggregate and reduce training metrics from all the training processes.
    if self.hvd_config.use and self.hvd_config.average_training_metrics:
        per_batch_metrics = self._average_training_metrics(per_batch_metrics)
    if self.hvd_config.use:
        num_inputs *= hvd.size()
    metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

    if not self.is_chief:
        # The training metrics are reported only in the chief process.
        return workload.Skipped()

    logging.debug(f"Done training step: {num_inputs} records in {num_batches} batches.")

    return metrics
def _train_for_step(
    self, step_id: int, num_batches: int, total_batches_processed: int
) -> workload.Response:
    check.gt(step_id, 0)

    self.context.reset_reducers()

    # Set the behavior of certain layers (e.g., dropout) that are different
    # between training and inference.
    for model in self.context.models:
        model.train()

    start = total_batches_processed
    end = start + num_batches

    per_batch_metrics = []  # type: List[Dict]
    num_inputs = 0

    for batch_idx in range(start, end):
        batch_start_time = time.time()
        self.prof.update_batch_idx(batch_idx)
        with self.prof.record_timing("dataloader_next"):
            batch = next(self.training_iterator)
        batch_inputs = self.trial.get_batch_length(batch)
        num_inputs += batch_inputs

        with self.prof.record_timing("to_device"):
            batch = self.context.to_device(batch)

        self.context._current_batch_idx = batch_idx
        if self.context.is_epoch_start():
            for callback in self.callbacks.values():
                with self.prof.record_timing(
                    f"callbacks.{callback.__class__.__name__}.on_training_epoch_start"
                ):
                    callback.on_training_epoch_start()
        self.context._loss_ids = {}

        with self.prof.record_timing("train_batch"):
            if self.context.profiler:
                with self.context.profiler as torch_profiler:
                    tr_metrics = self.trial.train_batch(
                        batch=batch,
                        epoch_idx=self.get_epoch_idx(batch_idx),
                        batch_idx=batch_idx,
                    )
                    torch_profiler.step()
            else:
                tr_metrics = self.trial.train_batch(
                    batch=batch,
                    epoch_idx=self.get_epoch_idx(batch_idx),
                    batch_idx=batch_idx,
                )

        if self._should_update_scaler():
            self.context._scaler.update()

        if isinstance(tr_metrics, torch.Tensor):
            tr_metrics = {"loss": tr_metrics}
        check.is_instance(
            tr_metrics,
            dict,
            "train_batch() must return a dictionary "
            f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
        )

        # Step learning rate of a pytorch.LRScheduler.
        with self.prof.record_timing("step_lr_schedulers"):
            for lr_scheduler in self.context.lr_schedulers:
                self._auto_step_lr_scheduler_per_batch(batch_idx, lr_scheduler)

        with self.prof.record_timing("from_device"):
            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

        batch_dur = time.time() - batch_start_time
        samples_per_second = batch_inputs / batch_dur
        self.prof.record_metric("samples_per_second", samples_per_second)

        per_batch_metrics.append(tr_metrics)

    # Aggregate and reduce training metrics from all the training processes.
    if self.hvd_config.use and self.hvd_config.average_training_metrics:
        with self.prof.record_timing("average_training_metrics"):
            per_batch_metrics = self._average_training_metrics(per_batch_metrics)
    if self.hvd_config.use:
        num_inputs *= hvd.size()
    metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

    # Ignore batch_metrics entirely for custom reducers; there's no guarantee that per-batch
    # metrics are even logical for a custom reducer.
    with self.prof.record_timing("reduce_metrics"):
        metrics["avg_metrics"].update(
            self._convert_metrics_to_numpy(self.context.reduce_metrics(for_training=True))
        )

    if not self.is_chief:
        # The training metrics are reported only in the chief process.
        return workload.Skipped()

    logging.debug(f"Done training step: {num_inputs} records in {num_batches} batches.")

    return metrics
def _compute_validation_metrics(self) -> workload.Response:
    # Set the behavior of certain layers (e.g., dropout) that are
    # different between training and inference.
    self.model.eval()

    num_inputs = 0
    metrics = {}  # type: Optional[Dict[str, Any]]

    if self._evaluate_batch_defined():
        keys = None
        batch_metrics = []

        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        check.gt(len(self.validation_loader), 0)
        for batch in self.validation_loader:
            batch = self._to_device(batch)
            num_inputs += data_length(batch)

            vld_metrics = self.trial.evaluate_batch(batch=batch, model=self.model)
            # Verify validation metric names are the same across batches.
            if keys is None:
                keys = vld_metrics.keys()
            else:
                check.eq(
                    keys,
                    vld_metrics.keys(),
                    "Validation metric names must match across all batches of data.",
                )
            check.is_instance(
                vld_metrics,
                dict,
                "validation_metrics() must return a "
                "dictionary of string names to Tensor "
                "metrics",
            )
            # TODO: For performance perform -> cpu() only at the end of validation.
            batch_metrics.append(self._convert_metrics_to_numpy(vld_metrics))

        keys = cast(Any, keys)
        metrics = self._reduce_metrics(
            batch_metrics=batch_metrics,
            keys=keys,
            metrics_reducers=self._prepare_metrics_reducers(keys=keys),
        )

        if self.hvd_config.use:
            num_inputs *= hvd.size()

    else:
        check.true(self._evaluate_full_dataset_defined())
        self.validation_loader = cast(torch.utils.data.DataLoader, self.validation_loader)
        if self.is_chief:
            metrics = self.trial.evaluate_full_dataset(
                data_loader=self.validation_loader, model=self.model
            )

            check.is_instance(
                metrics, dict, f"eval() must return a dictionary, got {type(metrics)}."
            )

            metrics = self._convert_metrics_to_numpy(metrics)
            num_inputs = self.context.get_per_slot_batch_size() * len(self.validation_loader)

    if not self.is_chief:
        return workload.Skipped()

    return {"num_inputs": num_inputs, "validation_metrics": metrics}
def _train_for_step(self, step_id: int, batches_per_step: int) -> workload.Response:
    check.gt(step_id, 0)

    step_idx = step_id - 1
    start = step_idx * batches_per_step
    end = start + batches_per_step

    # Set the behavior of certain layers (e.g., dropout) that are different
    # between training and inference.
    self.model.train()

    per_batch_metrics = []  # type: List[Dict]
    num_inputs = 0

    for batch_idx in range(start, end):
        batch = next(self.training_iterator)
        num_inputs += data_length(batch)

        batch = self._to_device(batch)
        # Forward pass.
        tr_metrics = self.trial.train_batch(
            batch=batch,
            model=self.model,
            epoch_idx=self.get_epoch_idx(batch_idx),
            batch_idx=batch_idx,
        )

        if isinstance(tr_metrics, torch.Tensor):
            tr_metrics = {"loss": tr_metrics}
        check.is_instance(
            tr_metrics,
            dict,
            "train_batch() must return a dictionary "
            f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
        )
        check.is_in("loss", tr_metrics.keys(), 'Please include "loss" in your training metrics.')

        # Backwards pass.
        loss = tr_metrics["loss"]
        communicate_and_update = (batch_idx + 1) % self.hvd_config.aggregation_frequency == 0
        if self.use_amp():
            with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
                if self.hvd_config.use and communicate_and_update:
                    self.optimizer.synchronize()
        else:
            loss.backward()

        if communicate_and_update:
            parameters = (
                self.model.parameters()
                if not self.use_amp()
                else apex.amp.master_params(self.optimizer)
            )

            if self.hvd_config.average_aggregated_gradients:
                self._average_gradients(
                    parameters=parameters, divisor=self.hvd_config.aggregation_frequency
                )

            self._clip_grads(parameters)

            if self.hvd_config.use and self.use_amp():
                with self.optimizer.skip_synchronize():
                    self.optimizer.step()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

        if self.lr_helper.should_step_lr(
            batches_completed=batch_idx + 1,
            epoch_length=len(self.training_loader),
            aggregation_frequency=self.hvd_config.aggregation_frequency,
        ):
            self.lr_helper.step()

        for name, metric in tr_metrics.items():
            # Convert PyTorch metric values to NumPy, so that
            # `det.util.encode_json` handles them properly without
            # needing a dependency on PyTorch.
            if isinstance(metric, torch.Tensor):
                metric = metric.cpu().detach().numpy()
            tr_metrics[name] = metric

        check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.')
        per_batch_metrics.append(tr_metrics)

    if self.hvd_config.use and self.hvd_config.average_training_metrics:
        per_batch_metrics = self._average_training_metrics(per_batch_metrics)

    if not self.is_chief:
        return workload.Skipped()

    if self.hvd_config.use:
        num_inputs *= hvd.size()

    logging.debug(f"Done training step: {num_inputs} records in {batches_per_step} batches.")

    return det.util.make_metrics(num_inputs, per_batch_metrics)