def on_save_checkpoint(self, trainer, pl_module):
    """Dump ``pl_module.hparams`` to ``hparams.yaml`` inside the logger's log dir.

    Only acts while ``trainer.current_epoch`` is 0; checkpoints saved during
    later epochs return immediately without touching the file.
    """
    if trainer.current_epoch != 0:
        return
    hparams_path = f"{trainer.logger.log_dir}/hparams.yaml"
    print(f"Saving hparams to file_path: {hparams_path}")
    save_hparams_to_yaml(config_yaml=hparams_path, hparams=pl_module.hparams)
def __init__(self, args_0, args_1, args_2, kwarg_1=None):
    """Capture constructor args as hparams, check them, round-trip them via YAML, re-check.

    NOTE(review): ``super().__init__()`` is intentionally(?) called last, after
    hparams have been saved and reassigned — presumably this exercises
    ``save_hyperparameters`` before base-class init; confirm before reordering.
    """
    # Record the constructor arguments as hyperparameters.
    self.save_hyperparameters()
    # ``test_hparams`` is not visible here; presumably a checker defined on this
    # class (or a base) that asserts on ``self.hparams`` — TODO confirm.
    self.test_hparams()
    # ``tmpdir`` is not a parameter; presumably captured from an enclosing test
    # function's scope — TODO confirm.
    config_file = f"{tmpdir}/hparams.yaml"
    # Write the hparams to YAML, replace them with what is read back, and check again.
    save_hparams_to_yaml(config_file, self.hparams)
    self.hparams = load_hparams_from_yaml(config_file)
    self.test_hparams()
    super().__init__()
def on_save_checkpoint(self, trainer, pl_module):
    """Dump ``pl_module.hparams`` to ``hparams.yaml`` in the logger's version directory.

    The directory is ``<save_dir>/<name>/<version dir>``. Only acts while
    ``trainer.current_epoch`` is 0; checkpoints saved during later epochs
    return immediately without touching the file.
    """
    if trainer.current_epoch != 0:
        return
    logger = trainer.logger
    version = logger.version
    # A string version is used verbatim as the directory name (this matches how
    # TensorBoardLogger/CSVLogger build ``log_dir``); only non-string versions
    # get the "version_" prefix. Hard-coding the prefix pointed string-versioned
    # runs at a directory that does not exist.
    version_dir = version if isinstance(version, str) else f"version_{version}"
    # An empty/None name means "no name sub-directory"; os.path.join skips "".
    file_path = os.path.join(logger.save_dir, logger.name or "", version_dir, "hparams.yaml")
    print(f"Saving hparams to file_path: {file_path}")
    save_hparams_to_yaml(config_yaml=file_path, hparams=pl_module.hparams)
def save(self) -> None:
    """Persist base-logger state, then write the hparams YAML once per log directory.

    The file ``NAME_HPARAMS_FILE`` is written under ``self.log_dir`` only when
    that directory exists on ``self._fs`` and the file is not already there.
    """
    super().save()
    target_dir = self.log_dir
    hparams_path = os.path.join(target_dir, self.NAME_HPARAMS_FILE)
    # Skip when the log directory is missing or an hparams file already exists.
    if not self._fs.isdir(target_dir) or self._fs.isfile(hparams_path):
        return
    save_hparams_to_yaml(hparams_path, self.hparams)
def save(self) -> None:
    """Persist base-logger state, then write the hparams YAML.

    The file goes into ``self.log_dir`` when that path is an existing local
    directory, otherwise into ``self.save_dir``. No existence check is made on
    the file itself before writing.
    """
    super().save()
    log_dir = self.log_dir
    target_dir = log_dir if os.path.isdir(log_dir) else self.save_dir
    save_hparams_to_yaml(os.path.join(target_dir, self.NAME_HPARAMS_FILE), self.hparams)
def save(self) -> None:
    """Touch the experiment, persist base-logger state, then write hparams once.

    The hparams YAML lives directly in ``self.save_dir`` and is only written
    when no file of that name exists yet.
    """
    # Accessed only for its side effect (the original notes this initialises
    # the experiment); the value itself is discarded.
    self.experiment  # noqa: B018
    super().save()
    hparams_path = os.path.join(self.save_dir, self.NAME_HPARAMS_FILE)
    if os.path.isfile(hparams_path):
        return
    save_hparams_to_yaml(hparams_path, self.hparams)
def save(self) -> None:
    """Write recorded hparams to YAML and all recorded metric rows to CSV.

    The hparams file under ``self.log_dir`` is always written. The metrics CSV
    at ``self.metrics_file_path`` is rewritten only when at least one row has
    been recorded; its header is the union of keys across all rows, in the
    order each key first appears.
    """
    save_hparams_to_yaml(os.path.join(self.log_dir, self.NAME_HPARAMS_FILE), self.hparams)

    if not self.metrics:
        return

    # dict.fromkeys keeps first-appearance order, exactly like successive dict.update().
    header = list(dict.fromkeys(key for row in self.metrics for key in row))

    with open(self.metrics_file_path, "w", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=header)
        writer.writeheader()
        writer.writerows(self.metrics)
def save(self) -> None:
    """Persist base-logger state, flush the experiment writer, then write hparams YAML.

    The hparams file goes into ``self.log_dir`` when that path is an existing
    local directory, otherwise into ``self.save_dir``.
    """
    super().save()
    try:
        self.experiment.flush()
    except AttributeError:
        # Older PyTorch (< v1.2) writers lack a public ``flush``; flush the
        # underlying file writer instead.
        self.experiment._get_file_writer().flush()

    log_dir = self.log_dir
    target_dir = log_dir if os.path.isdir(log_dir) else self.save_dir
    save_hparams_to_yaml(os.path.join(target_dir, self.NAME_HPARAMS_FILE), self.hparams)
def save(self) -> None:
    """Persist base-logger state, then write the hparams YAML.

    The file goes into ``self.log_dir`` when that path is an existing local
    directory, otherwise into ``self.save_dir``. OmegaConf containers are
    written with ``OmegaConf.save(..., resolve=True)``; everything else goes
    through ``save_hparams_to_yaml``.
    """
    super().save()
    log_dir = self.log_dir
    target_dir = log_dir if os.path.isdir(log_dir) else self.save_dir
    hparams_path = os.path.join(target_dir, self.NAME_HPARAMS_FILE)

    # ``Container`` is presumably None when omegaconf is not installed; the
    # short-circuit keeps isinstance() from ever being called with None.
    if Container is not None and isinstance(self.hparams, Container):
        from omegaconf import OmegaConf

        OmegaConf.save(self.hparams, hparams_path, resolve=True)
    else:
        save_hparams_to_yaml(hparams_path, self.hparams)
def test_hparams_save_yaml(tmpdir):
    """Round-trip a nested hparams dict through YAML in several container types.

    The plain dict, a Namespace, an AttributeDict and an OmegaConf config are
    each saved to the same file; every reload must equal the original dict.
    """
    hparams = dict(
        batch_size=32,
        learning_rate=0.001,
        data_root='./any/path/here',
        nasted=dict(any_num=123, anystr='abcd'),
    )
    path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')

    wrappers = (
        lambda d: d,
        lambda d: Namespace(**d),
        AttributeDict,
        OmegaConf.create,
    )
    for wrap in wrappers:
        save_hparams_to_yaml(path_yaml, wrap(hparams))
        assert load_hparams_from_yaml(path_yaml) == hparams
def test_hparams_save_yaml(tmpdir):
    """Round-trip hparams (including a str-valued Enum member) through YAML.

    Plain dict, Namespace and AttributeDict forms are reloaded with
    ``use_omegaconf=False``. When omegaconf is available, an OmegaConf form is
    reloaded with the default loader. Enum members are expected to come back
    as their ``name``.
    """

    class Options(str, Enum):
        option1name = "option1val"
        option2name = "option2val"
        option3name = "option3val"

    hparams = dict(
        batch_size=32,
        learning_rate=0.001,
        data_root="./any/path/here",
        nested=dict(any_num=123, anystr="abcd"),
        switch=Options.option3name,
    )
    path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")

    def _compare_params(loaded_params, default_params: dict):
        assert isinstance(loaded_params, (dict, DictConfig))
        assert loaded_params.keys() == default_params.keys()
        for key, expected in default_params.items():
            # Enum members are compared by name rather than by value.
            wanted = expected.name if isinstance(expected, Enum) else expected
            assert wanted == loaded_params[key]

    for as_container in (lambda d: d, lambda d: Namespace(**d), AttributeDict):
        save_hparams_to_yaml(path_yaml, as_container(hparams))
        _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

    if _OMEGACONF_AVAILABLE:
        save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
        _compare_params(load_hparams_from_yaml(path_yaml), hparams)
nhead=args.nhead, num_encoder_layers=args.num_encoder_layers, num_decoder_layers=args.num_decoder_layers, dim_feedforward=args.dim_feedforward, dropout=args.dropout, activation=args.activation, warmup=args.warmup, bpe_file=args.bpe_file, lenpen=args.lenpen, beam_size=args.beam_size, ckpt_steps=args.ckpt_steps) trainer = pl.Trainer.from_argparse_args(args) # hack for saving all args from pytorch_lightning.core.saving import save_hparams_to_yaml os.makedirs(trainer.logger.log_dir) save_hparams_to_yaml(f'{trainer.logger.log_dir}/args.yaml', args) trainer.fit(mt, datamodule=dataset) # average state_dict fnames = os.listdir(f'{trainer.logger.log_dir}/checkpoints/') fnames = filter(lambda s: s.startswith('step'), fnames) fnames = sorted(fnames, key=lambda s: int(s.split('.')[1])) fnames = fnames[-5:] fnames = [f'{trainer.logger.log_dir}/checkpoints/{f}' for f in fnames] state_dicts = [torch.load(f)['state_dict'] for f in fnames] state_dict = {} for k in state_dicts[0].keys(): state_dict[k] = torch.stack([s[k] for s in state_dicts]).mean(0) mt.load_state_dict(state_dict) trainer.test()