def _save(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    # The model code is the current working directory.
    util.write_user_code(path)

    rng_state = {
        "cpu_rng_state": torch.random.get_rng_state(),  # type: ignore
        "np_rng_state": np.random.get_state(),
        "random_rng_state": random.getstate(),
    }
    if torch.cuda.device_count():
        rng_state["gpu_rng_state"] = torch.cuda.get_rng_state(  # type: ignore
            self.context.distributed.get_local_rank()
        )

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint = {
        "models_state_dict": [model.state_dict() for model in self.context.models],
        "optimizers_state_dict": [
            optimizer.state_dict() for optimizer in self.context.optimizers
        ],
        "lr_schedulers_state_dict": [
            lr_scheduler.state_dict() for lr_scheduler in self.context.lr_schedulers
        ],
        "callbacks": {
            name: callback.state_dict() for name, callback in self.callbacks.items()
        },
        "rng_state": rng_state,
    }

    if self.context._use_amp:
        checkpoint["amp_state"] = apex.amp.state_dict()

    torch.save(  # type: ignore
        checkpoint,
        str(path.joinpath("state_dict.pth")),
        pickle_module=cloudpickle,
    )

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))

    return cast(
        workload.Response,
        {
            "framework": f"torch-{torch.__version__}",  # type: ignore
            "format": "cloudpickle",
        },
    )
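# For context: a minimal sketch of the inverse of the save above, assuming the
# checkpoint layout it writes ("models_state_dict", "optimizers_state_dict",
# "lr_schedulers_state_dict", "callbacks", "rng_state"), the same lists of
# models/optimizers/schedulers at restore time, and callbacks that expose a
# `load_state_dict` method. `load_cloudpickle_checkpoint` is a hypothetical
# helper name, not part of the original code.
import pathlib
import random

import cloudpickle
import numpy as np
import torch


def load_cloudpickle_checkpoint(
    path: pathlib.Path, models, optimizers, lr_schedulers, callbacks
):
    # Load with the same pickle module that was passed to torch.save above.
    checkpoint = torch.load(
        str(path.joinpath("state_dict.pth")),
        pickle_module=cloudpickle,
        map_location="cpu",
    )
    for model, state in zip(models, checkpoint["models_state_dict"]):
        model.load_state_dict(state)
    for optimizer, state in zip(optimizers, checkpoint["optimizers_state_dict"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(lr_schedulers, checkpoint["lr_schedulers_state_dict"]):
        scheduler.load_state_dict(state)
    for name, state in checkpoint["callbacks"].items():
        if name in callbacks:
            callbacks[name].load_state_dict(state)
    # RNG state is returned so the caller can decide when to reinstate it.
    return checkpoint["rng_state"]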
def _save(self, path: pathlib.Path) -> None:
    if self.context.distributed.local_rank == 0:
        path.mkdir(parents=True, exist_ok=True)
    _ = self.context.distributed.gather_local(None)  # sync

    if self.is_chief:
        # We assume these stateful objects should be the same across slots and
        # only have the chief save them.
        util.write_user_code(path, self.env.on_cluster)

        if self.wlsq is not None:
            with path.joinpath("workload_sequencer.pkl").open("wb") as f:
                pickle.dump(self.wlsq.get_state(), f)

    # Save a per-rank Determined checkpoint.
    rng_state = {
        "cpu_rng_state": torch.random.get_rng_state(),
        "np_rng_state": np.random.get_state(),
        "random_rng_state": random.getstate(),
    }
    if torch.cuda.device_count():
        rng_state["gpu_rng_state"] = torch.cuda.get_rng_state(
            self.context.distributed.get_local_rank()
        )
    checkpoint = {"rng_state": rng_state}

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint["callbacks"] = {
        name: callback.state_dict() for name, callback in self.callbacks.items()
    }

    for callback in self.callbacks.values():
        callback.on_checkpoint_save_start(checkpoint)

    ckpt_name = f"det_state_dict_rank{self.context.distributed.rank}.pth"
    torch.save(checkpoint, str(path.joinpath(ckpt_name)))

    # We allow users to override save behavior if needed, but we default to
    # using the save method provided by DeepSpeed.
    self.trial.save(self.context, path)

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))
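# For context: a minimal sketch of restoring the per-rank file written by the
# DeepSpeed-style variant above, where each rank reads back its own
# `det_state_dict_rank{rank}.pth` and reinstates RNG and callback state. The
# function and parameter names are illustrative assumptions, not the library's
# API; model/optimizer state is expected to be restored separately by the
# DeepSpeed load path.
import pathlib
import random

import numpy as np
import torch


def load_rank_checkpoint(path: pathlib.Path, rank: int, local_rank: int, callbacks):
    checkpoint = torch.load(str(path.joinpath(f"det_state_dict_rank{rank}.pth")))
    rng_state = checkpoint["rng_state"]
    torch.random.set_rng_state(rng_state["cpu_rng_state"])
    np.random.set_state(rng_state["np_rng_state"])
    random.setstate(rng_state["random_rng_state"])
    # The GPU RNG state was saved per local rank, so restore it to the same device.
    if "gpu_rng_state" in rng_state and torch.cuda.device_count():
        torch.cuda.set_rng_state(rng_state["gpu_rng_state"], local_rank)
    for name, state in checkpoint["callbacks"].items():
        if name in callbacks:
            callbacks[name].load_state_dict(state)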
def _save(self, path: pathlib.Path) -> workload.Response:
    if not self.is_chief:
        return workload.Skipped()

    path.mkdir(parents=True, exist_ok=True)

    # The model code is the current working directory.
    util.write_user_code(path)

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint = {
        "model_state_dict": self.context.model.state_dict(),
        "optimizer_state_dict": self.context.optimizer.state_dict(),
    }
    if self.context.lr_scheduler is not None:
        checkpoint["lr_scheduler"] = self.context.lr_scheduler.state_dict()

    for name, callback in self.callbacks.items():
        checkpoint.setdefault("callbacks", {})[name] = callback.state_dict()

    torch.save(  # type: ignore
        checkpoint,
        str(path.joinpath("state_dict.pth")),
        pickle_module=cloudpickle,
    )

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))

    return cast(
        workload.Response,
        {
            "framework": f"torch-{torch.__version__}",  # type: ignore
            "format": "cloudpickle",
        },
    )
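# For context: the single-model layout above restores correspondingly flat.
# A minimal sketch, assuming `model`, `optimizer`, and `lr_scheduler` are
# rebuilt the same way as at save time; `load_single_checkpoint` is a
# hypothetical helper name.
import pathlib

import cloudpickle
import torch


def load_single_checkpoint(path: pathlib.Path, model, optimizer, lr_scheduler=None):
    checkpoint = torch.load(
        str(path.joinpath("state_dict.pth")), pickle_module=cloudpickle
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # The scheduler key is only present when a scheduler was configured.
    if lr_scheduler is not None and "lr_scheduler" in checkpoint:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])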
def _save(self, path: pathlib.Path) -> None:
    path.mkdir(parents=True, exist_ok=True)

    util.write_user_code(path, self.env.on_cluster)

    rng_state = {
        "cpu_rng_state": torch.random.get_rng_state(),
        "np_rng_state": np.random.get_state(),
        "random_rng_state": random.getstate(),
    }
    if torch.cuda.device_count():
        rng_state["gpu_rng_state"] = torch.cuda.get_rng_state(
            self.context.distributed.local_rank
        )

    # PyTorch uses optimizer objects that take the model parameters to
    # optimize on construction, so we store and reload the `state_dict()`
    # of the model and optimizer explicitly (instead of dumping the entire
    # objects) to avoid breaking the connection between the model and the
    # optimizer.
    checkpoint = {
        "models_state_dict": [model.state_dict() for model in self.context.models],
        "optimizers_state_dict": [
            optimizer.state_dict() for optimizer in self.context.optimizers
        ],
        "lr_schedulers_state_dict": [
            lr_scheduler.state_dict() for lr_scheduler in self.context.lr_schedulers
        ],
        "callbacks": {
            name: callback.state_dict() for name, callback in self.callbacks.items()
        },
        "rng_state": rng_state,
    }

    if self.context._scaler:
        checkpoint["scaler_state_dict"] = self.context._scaler.state_dict()

    if self.context._use_apex:
        checkpoint["amp_state"] = apex.amp.state_dict()

    for callback in self.callbacks.values():
        callback.on_checkpoint_save_start(checkpoint)

    torch.save(checkpoint, str(path.joinpath("state_dict.pth")))

    for callback in self.callbacks.values():
        callback.on_checkpoint_end(str(path))

    if self.wlsq is not None:
        with path.joinpath("workload_sequencer.pkl").open("wb") as f:
            pickle.dump(self.wlsq.get_state(), f)

    trial_cls = type(self.trial)
    with path.joinpath("load_data.json").open("w") as f2:
        json.dump(
            {
                "trial_type": "PyTorchTrial",
                "experiment_config": self.context.env.experiment_config,
                "hparams": self.context.env.hparams,
                "trial_cls_spec": f"{trial_cls.__module__}:{trial_cls.__qualname__}",
            },
            f2,
        )
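# For context: a minimal sketch of consuming the `load_data.json` metadata
# written above, resolving the `module:qualname` string in `trial_cls_spec`
# back to a class object. Only the JSON layout is taken from the save code;
# `resolve_trial_class` is a hypothetical helper name.
import importlib
import json
import pathlib


def resolve_trial_class(path: pathlib.Path):
    with path.joinpath("load_data.json").open("r") as f:
        load_data = json.load(f)
    module_name, _, qualname = load_data["trial_cls_spec"].partition(":")
    obj = importlib.import_module(module_name)
    # Walk nested attributes to handle qualified names like "Outer.Inner".
    for attr in qualname.split("."):
        obj = getattr(obj, attr)
    return obj, load_data["experiment_config"], load_data["hparams"]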