def _atomic_save(self, checkpoint, filepath: str):
    """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

    The checkpoint is first fully serialized into an in-memory buffer and only
    then written out to ``filepath`` in a single pass, so a partially written
    checkpoint file is never visible at the final location.

    Args:
        checkpoint: The object to save.
            Built to be used with the ``dump_checkpoint`` method, but can deal with
            anything which ``torch.save`` accepts.
        filepath: The path to which the checkpoint will be saved.
            This points to the file that the checkpoint will be stored in.
    """
    bytesbuffer = io.BytesIO()
    # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
    # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
    # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
    if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
        torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
    else:
        torch.save(checkpoint, bytesbuffer)
    with cloud_open(filepath, 'wb') as f:
        f.write(bytesbuffer.getvalue())

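# A minimal standalone sketch (not part of the source) of the same
# buffer-then-write pattern: serializing with ``torch.save`` into ``io.BytesIO``
# first means the destination file only ever receives complete bytes.
# ``atomic_torch_save`` is a hypothetical helper name, and plain ``open``
# stands in for ``cloud_open`` here.
import io

import torch


def atomic_torch_save(obj, filepath: str) -> None:
    buffer = io.BytesIO()
    # Serialize fully in memory first ...
    torch.save(obj, buffer)
    # ... then write the finished bytes to disk in one pass.
    with open(filepath, 'wb') as f:
        f.write(buffer.getvalue())


# atomic_torch_save({'epoch': 3, 'state_dict': {}}, 'model.ckpt')
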
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """
    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved
    """
    if not gfile.isdir(os.path.dirname(config_yaml)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

    # convert Namespace or AttributeDict to dict
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)

    # saving with OmegaConf objects
    if OmegaConf is not None:
        if OmegaConf.is_config(hparams):
            OmegaConf.save(hparams, config_yaml, resolve=True)
            return
        for v in hparams.values():
            if OmegaConf.is_config(v):
                OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True)
                return

    # saving the standard way
    assert isinstance(hparams, dict)
    with cloud_open(config_yaml, 'w', newline='') as fp:
        yaml.dump(hparams, fp)

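# Hypothetical usage sketch for the second OmegaConf branch above: a plain dict
# whose *values* include an OmegaConf config is re-created as a single config
# and saved with interpolations resolved. The file name is illustrative only.
from omegaconf import OmegaConf

hparams_example = {
    'lr': 0.001,
    'model': OmegaConf.create({'hidden': 128, 'layers': 4}),
}
# save_hparams_to_yaml('hparams.yaml', hparams_example)
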
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
    """Save hparams as key/value rows in a CSV file with a ``key,value`` header."""
    if not gfile.isdir(os.path.dirname(tags_csv)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with cloud_open(tags_csv, "w", newline="") as fp:
        fieldnames = ["key", "value"]
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        # Header row ("key,value"), then one row per hyperparameter.
        writer.writeheader()
        for k, v in hparams.items():
            writer.writerow({"key": k, "value": v})

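# Illustrative round trip (hypothetical path): writing a Namespace produces a
# two-column CSV that load_hparams_from_tags_csv (below) reads back.
#
#   save_hparams_to_tags_csv('/tmp/hparams.csv', Namespace(batch_size=32))
#
# resulting file contents:
#   key,value
#   batch_size,32
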
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    if not gfile.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
        return {}

    with cloud_open(config_yaml, "r") as fp:
        # ``full_load`` avoids the deprecation warning raised by calling
        # ``yaml.load`` without an explicit Loader (PyYAML >= 5.1).
        tags = yaml.full_load(fp)

    return tags

def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    if not gfile.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
        return {}

    with cloud_open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        # Skip the header row; ``convert`` restores each value's original type.
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

    return tags

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
    if self.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
        return

    # track the best model path
    best_model_path = None
    if self.checkpoint_callback is not None:
        best_model_path = self.checkpoint_callback.best_model_path

    if self.global_rank == 0 and mp_queue is not None:
        rank_zero_warn('cleaning up ddp environment...')

        # todo, pass complete checkpoint as state dictionary
        mp_queue.put(best_model_path)
        mp_queue.put(results)

        # save the last weights
        last_path = None
        if not self.testing and best_model_path is not None and len(best_model_path) > 0:
            # escape the dot so only a literal ".ckpt" suffix is rewritten
            last_path = re.sub(r'\.ckpt$', '.tmp_end.ckpt', best_model_path)
            # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
            # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
            # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
            bytesbuffer = io.BytesIO()
            if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
                torch.save(model.state_dict(), bytesbuffer, _use_new_zipfile_serialization=False)
            else:
                torch.save(model.state_dict(), bytesbuffer)
            with cloud_open(last_path, 'wb') as f:
                f.write(bytesbuffer.getvalue())
        mp_queue.put(last_path)

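# Hypothetical consumer-side sketch: in the parent process, values come back in
# the same order they were queued above (best model path, results, last path).
# The trainer's actual retrieval code may differ; this only illustrates the
# queue protocol implied by the three put() calls.
#
#   best_model_path = mp_queue.get()
#   results = mp_queue.get()
#   last_path = mp_queue.get()
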
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """
    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved
    """
    if not gfile.isdir(os.path.dirname(config_yaml)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
        from omegaconf import OmegaConf
        OmegaConf.save(hparams, config_yaml, resolve=True)
        return

    # saving the standard way
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)
    assert isinstance(hparams, dict)

    with cloud_open(config_yaml, "w", newline="") as fp:
        yaml.dump(hparams, fp)

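# Hypothetical usage sketch for the Container branch: an OmegaConf DictConfig is
# saved directly, with any ${...} interpolations resolved to literal values.
# The file name is illustrative only.
from omegaconf import OmegaConf

conf = OmegaConf.create({'lr': 0.01, 'warmup_lr': '${lr}'})
# save_hparams_to_yaml('hparams.yaml', conf)  # saved YAML contains warmup_lr: 0.01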