def _load(self) -> None:
    if not self.load_path:
        return

    # Backwards compat with older checkpoint formats. List is newest to
    # oldest known state_dict locations.
    potential_paths = [
        ["state_dict.pth"],
        ["determined", "state_dict.pth"],
        ["pedl", "state_dict.pth"],
        ["checkpoint.pt"],
    ]

    checkpoint: Optional[Dict[str, Any]] = None
    for ckpt_path in potential_paths:
        maybe_ckpt = self.load_path.joinpath(*ckpt_path)
        if maybe_ckpt.exists():
            checkpoint = torch.load(str(maybe_ckpt), map_location="cpu")  # type: ignore
            break

    if checkpoint is None or not isinstance(checkpoint, dict):
        return

    for callback in self.callbacks.values():
        callback.on_checkpoint_load_start(checkpoint)

    if "model_state_dict" in checkpoint:
        # Backward compatible with older checkpoint format.
        check.not_in("models_state_dict", checkpoint)
        check.eq(len(self.context.models), 1)
        self.context.models[0].load_state_dict(checkpoint["model_state_dict"])
    else:
        for idx, model in enumerate(self.context.models):
            model.load_state_dict(checkpoint["models_state_dict"][idx])

    if "optimizer_state_dict" in checkpoint:
        # Backward compatible with older checkpoint format.
        check.not_in("optimizers_state_dict", checkpoint)
        check.eq(len(self.context.optimizers), 1)
        self.context.optimizers[0].load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        for idx, optimizer in enumerate(self.context.optimizers):
            optimizer.load_state_dict(checkpoint["optimizers_state_dict"][idx])

    if "lr_scheduler" in checkpoint:
        # Backward compatible with older checkpoint format.
        check.not_in("lr_schedulers_state_dict", checkpoint)
        check.eq(len(self.context.lr_schedulers), 1)
        self.context.lr_schedulers[0].load_state_dict(checkpoint["lr_scheduler"])
    else:
        for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
            lr_scheduler.load_state_dict(checkpoint["lr_schedulers_state_dict"][idx])

    if "scaler_state_dict" in checkpoint:
        if self.context._scaler:
            self.context._scaler.load_state_dict(checkpoint["scaler_state_dict"])
        else:
            logging.warning(
                "There exists scaler_state_dict in checkpoint but the experiment is not using "
                "AMP."
            )
    else:
        if self.context._scaler:
            logging.warning(
                "The experiment is using AMP but scaler_state_dict does not exist in the "
                "checkpoint."
            )

    if "amp_state" in checkpoint:
        if self.context._use_apex:
            apex.amp.load_state_dict(checkpoint["amp_state"])
        else:
            logging.warning(
                "There exists amp_state in checkpoint but the experiment is not using Apex."
            )
    else:
        if self.context._use_apex:
            logging.warning(
                "The experiment is using Apex but amp_state does not exist in the checkpoint."
            )

    if "rng_state" in checkpoint:
        rng_state = checkpoint["rng_state"]
        np.random.set_state(rng_state["np_rng_state"])
        random.setstate(rng_state["random_rng_state"])
        torch.random.set_rng_state(rng_state["cpu_rng_state"])

        if torch.cuda.device_count():
            if "gpu_rng_state" in rng_state:
                torch.cuda.set_rng_state(
                    rng_state["gpu_rng_state"],
                    device=self.context.distributed.get_local_rank(),
                )
            else:
                logging.warning(
                    "The system has a gpu but no gpu_rng_state exists in the checkpoint."
                )
        else:
            if "gpu_rng_state" in rng_state:
                logging.warning(
                    "There exists gpu_rng_state in checkpoint but the system has no gpu."
                )
    else:
        logging.warning("The checkpoint has no random state to restore.")

    callback_state = checkpoint.get("callbacks", {})
    for name in self.callbacks:
        if name in callback_state:
            self.callbacks[name].load_state_dict(callback_state[name])
        elif util.is_overridden(self.callbacks[name].load_state_dict, pytorch.PyTorchCallback):
            logging.warning(
                f"Callback '{name}' implements load_state_dict(), but no callback state "
                "was found for that name when restoring from checkpoint. This "
                "callback will be initialized from scratch."
            )
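
# Hedged sketch, not taken from the source: a minimal checkpoint dict in the newest
# layout that _load() above accepts. The key names mirror the ones read above; the
# helper name and the idea of assembling the dict by hand are illustrative assumptions.
def _example_checkpoint_layout(models, optimizers, lr_schedulers, callbacks):
    import random

    import numpy as np
    import torch

    checkpoint = {
        "models_state_dict": [m.state_dict() for m in models],
        "optimizers_state_dict": [o.state_dict() for o in optimizers],
        "lr_schedulers_state_dict": [s.state_dict() for s in lr_schedulers],
        "rng_state": {
            "np_rng_state": np.random.get_state(),
            "random_rng_state": random.getstate(),
            "cpu_rng_state": torch.random.get_rng_state(),
        },
        # Assumes each callback exposes a state_dict() matching the load_state_dict()
        # that _load() calls; callbacks without state can simply be omitted here.
        "callbacks": {name: cb.state_dict() for name, cb in callbacks.items()},
    }
    # "gpu_rng_state" is only consulted by _load() when a GPU is visible.
    if torch.cuda.device_count():
        checkpoint["rng_state"]["gpu_rng_state"] = torch.cuda.get_rng_state()
    return checkpoint
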
def from_configs(
    experiment_config: ExperimentConfig,
    rendezvous_info: RendezvousInfo,
    hparams: Dict[str, Any],
) -> "HorovodContext":
    """
    Create the HorovodContext according to experiment config and rendezvous info for this trial.
    """

    # Horovod is always used for multi-machine distributed training. For
    # single-machine multi-GPU training, Horovod is used when native_parallel is
    # disabled.
    multi_machine_trial = rendezvous_info.get_size() > 1
    multi_slot_trial = experiment_config["resources"]["slots_per_trial"] > 1
    use_horovod = multi_machine_trial or (
        multi_slot_trial and not experiment_config.native_parallel_enabled()
    )

    check.is_in("optimizations", experiment_config)
    optimizations_config = cast(Dict[str, Any], experiment_config.get("optimizations"))

    check.is_in("aggregation_frequency", optimizations_config)
    check.is_in("gradient_compression", optimizations_config)
    check.is_in("average_training_metrics", optimizations_config)

    # Help users migrate from the old locations for these settings, in hparams.
    def error_message_removed_from_hparams(removed_hparam: str) -> str:
        return (
            f"Please move `{removed_hparam}` from the `hyperparameters` section of the "
            "experiment config to the `optimizations` section."
        )

    check.not_in(
        "aggregation_frequency",
        hparams,
        error_message_removed_from_hparams("aggregation_frequency"),
    )
    check.not_in(
        "gradient_compression",
        hparams,
        error_message_removed_from_hparams("gradient_compression"),
    )
    check.not_in(
        "grad_updates_size_file",
        hparams,
        error_message_removed_from_hparams("grad_updates_size_file"),
    )

    hvd_config = HorovodContext(
        use=use_horovod,
        aggregation_frequency=cast(int, optimizations_config.get("aggregation_frequency")),
        fp16_compression=cast(bool, optimizations_config.get("gradient_compression")),
        grad_updates_size_file=optimizations_config.get("grad_updates_size_file", None),
        average_aggregated_gradients=cast(
            bool, optimizations_config.get("average_aggregated_gradients")
        ),
        average_training_metrics=cast(
            bool, optimizations_config.get("average_training_metrics")
        ),
    )

    if hvd_config.use and hvd_config.aggregation_frequency > 1:
        logging.info(
            f"Setting `aggregation_frequency` to {hvd_config.aggregation_frequency} "
            "to optimize training."
        )

    if hvd_config.use and hvd_config.fp16_compression:
        logging.info("Enabling `gradient_compression` to optimize training.")

    return hvd_config
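
# Hedged example, not from the source: the parts of an experiment config that
# from_configs() above actually reads. The values are illustrative placeholders,
# not the platform's defaults.
_EXAMPLE_EXPERIMENT_CONFIG_SUBSET = {
    "resources": {"slots_per_trial": 2},
    "optimizations": {
        "aggregation_frequency": 1,
        "gradient_compression": False,
        "grad_updates_size_file": None,
        "average_aggregated_gradients": True,
        "average_training_metrics": False,
    },
}
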
def _load(self, load_path: pathlib.Path) -> None:
    # Backwards compat with older checkpoint formats. List is newest to
    # oldest known state_dict locations.
    potential_paths = [
        ["state_dict.pth"],
        ["determined", "state_dict.pth"],
        ["pedl", "state_dict.pth"],
        ["checkpoint.pt"],
    ]

    checkpoint: Optional[Dict[str, Any]] = None
    for ckpt_path in potential_paths:
        maybe_ckpt = load_path.joinpath(*ckpt_path)
        if maybe_ckpt.exists():
            checkpoint = torch.load(str(maybe_ckpt), map_location="cpu")  # type: ignore
            break

    if checkpoint is None or not isinstance(checkpoint, dict):
        return

    for callback in self.callbacks.values():
        callback.on_checkpoint_load_start(checkpoint)

    if "model_state_dict" in checkpoint:
        # Backward compatible with older checkpoint format.
        check.not_in("models_state_dict", checkpoint)
        check.eq(len(self.context.models), 1)
        self.context.models[0].load_state_dict(checkpoint["model_state_dict"])
    else:
        for idx, model in enumerate(self.context.models):
            model_state_dict = checkpoint["models_state_dict"][idx]
            try:
                model.load_state_dict(model_state_dict)
            except Exception:
                # If the checkpointed model is non-DDP and the current model is DDP, add the
                # module prefix to the checkpointed keys.
                if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    logging.debug("Loading non-DDP checkpoint into a DDP model")
                    self._add_prefix_in_state_dict_if_not_present(model_state_dict, "module.")
                else:
                    # If the checkpointed model is DDP and we are currently running in
                    # single-slot mode, remove the module prefix from the checkpointed keys.
                    logging.debug("Loading DDP checkpoint into a non-DDP model")
                    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                        model_state_dict, "module."
                    )
                model.load_state_dict(model_state_dict)

    if "optimizer_state_dict" in checkpoint:
        # Backward compatible with older checkpoint format.
        check.not_in("optimizers_state_dict", checkpoint)
        check.eq(len(self.context.optimizers), 1)
        self.context.optimizers[0].load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        for idx, optimizer in enumerate(self.context.optimizers):
            optimizer.load_state_dict(checkpoint["optimizers_state_dict"][idx])

    if "lr_scheduler" in checkpoint:
        # Backward compatible with older checkpoint format.
        check.not_in("lr_schedulers_state_dict", checkpoint)
        check.eq(len(self.context.lr_schedulers), 1)
        self.context.lr_schedulers[0].load_state_dict(checkpoint["lr_scheduler"])
    else:
        for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
            lr_scheduler.load_state_dict(checkpoint["lr_schedulers_state_dict"][idx])

    if "scaler_state_dict" in checkpoint:
        if self.context._scaler:
            self.context._scaler.load_state_dict(checkpoint["scaler_state_dict"])
        else:
            logging.warning(
                "There exists scaler_state_dict in checkpoint but the experiment is not using "
                "AMP."
            )
    else:
        if self.context._scaler:
            logging.warning(
                "The experiment is using AMP but scaler_state_dict does not exist in the "
                "checkpoint."
            )

    if "amp_state" in checkpoint:
        if self.context._use_apex:
            apex.amp.load_state_dict(checkpoint["amp_state"])
        else:
            logging.warning(
                "There exists amp_state in checkpoint but the experiment is not using Apex."
            )
    else:
        if self.context._use_apex:
            logging.warning(
                "The experiment is using Apex but amp_state does not exist in the checkpoint."
            )

    if "rng_state" in checkpoint:
        rng_state = checkpoint["rng_state"]
        np.random.set_state(rng_state["np_rng_state"])
        random.setstate(rng_state["random_rng_state"])
        torch.random.set_rng_state(rng_state["cpu_rng_state"])

        if torch.cuda.device_count():
            if "gpu_rng_state" in rng_state:
                torch.cuda.set_rng_state(
                    rng_state["gpu_rng_state"],
                    device=self.context.distributed.local_rank,
                )
            else:
                logging.warning(
                    "The system has a gpu but no gpu_rng_state exists in the checkpoint."
                )
        else:
            if "gpu_rng_state" in rng_state:
                logging.warning(
                    "There exists gpu_rng_state in checkpoint but the system has no gpu."
                )
    else:
        logging.warning("The checkpoint has no random state to restore.")

    callback_state = checkpoint.get("callbacks", {})
    for name in self.callbacks:
        if name in callback_state:
            self.callbacks[name].load_state_dict(callback_state[name])
        elif util.is_overridden(self.callbacks[name].load_state_dict, pytorch.PyTorchCallback):
            logging.warning(
                f"Callback '{name}' implements load_state_dict(), but no callback state "
                "was found for that name when restoring from checkpoint. This "
                "callback will be initialized from scratch."
            )

    # Load workload sequencer state.
    wlsq_path = load_path.joinpath("workload_sequencer.pkl")
    if self.wlsq is not None and wlsq_path.exists():
        with wlsq_path.open("rb") as f:
            self.wlsq.load_state(pickle.load(f))
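
# Hedged sketch of the prefix helper referenced in _load() above when loading a
# non-DDP checkpoint into a DDP-wrapped model; the real implementation may differ
# (for example, it may also rewrite state_dict._metadata). Shown as a staticmethod
# for the sketch. It is the inverse of
# torch.nn.modules.utils.consume_prefix_in_state_dict_if_present: every top-level
# key that lacks the prefix is renamed in place to carry it.
@staticmethod
def _add_prefix_in_state_dict_if_not_present(state_dict: Dict[str, Any], prefix: str) -> None:
    # Snapshot the keys first so the dict is not mutated while iterating.
    for key in [k for k in state_dict.keys() if not k.startswith(prefix)]:
        state_dict[prefix + key] = state_dict.pop(key)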