The examples below show typical usages of save_hparams_to_yaml (and its counterpart load_hparams_from_yaml) from pytorch_lightning.core.saving, drawn from callbacks, logger implementations, tests, and training scripts.

Example 1
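A Callback hook that writes the module's hyperparameters to hparams.yaml under the logger's log directory, once, when the first checkpoint of training is saved.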
    def on_save_checkpoint(self, trainer, pl_module):
        # only save the hyperparameters once, on the first epoch
        if trainer.current_epoch == 0:
            file_path = f"{trainer.logger.log_dir}/hparams.yaml"
            print(f"Saving hparams to {file_path}")
            save_hparams_to_yaml(config_yaml=file_path,
                                 hparams=pl_module.hparams)
Example 2
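A LightningModule constructor from a test: it saves the captured hyperparameters to YAML, loads them back, and verifies them both before and after the round trip. Note that tmpdir and test_hparams are defined in the enclosing test scope.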
    def __init__(self, args_0, args_1, args_2, kwarg_1=None):
        # tmpdir and test_hparams come from the enclosing test scope
        self.save_hyperparameters()
        self.test_hparams()
        # round-trip the hparams through a YAML file and re-check them
        config_file = f"{tmpdir}/hparams.yaml"
        save_hparams_to_yaml(config_file, self.hparams)
        self.hparams = load_hparams_from_yaml(config_file)
        self.test_hparams()
        # note: super().__init__() runs only after the round trip in this test
        super().__init__()
Example 3
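A variant of Example 1 that builds the hparams.yaml path explicitly from the logger's save_dir, name, and version instead of relying on logger.log_dir.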
    def on_save_checkpoint(self, trainer, pl_module):
        # only save the hyperparameters once, on the first epoch
        if trainer.current_epoch == 0:
            # build <save_dir>/<name>/version_<version>/hparams.yaml explicitly
            file_path = os.path.join(trainer.logger.save_dir,
                                     trainer.logger.name,
                                     f"version_{trainer.logger.version}",
                                     "hparams.yaml")
            print(f"Saving hparams to {file_path}")
            save_hparams_to_yaml(config_yaml=file_path,
                                 hparams=pl_module.hparams)
Example 4
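A logger's save() override that writes hparams.yaml only when the log directory already exists and the file has not yet been written.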
    def save(self) -> None:
        super().save()
        dir_path = self.log_dir

        # prepare the file path
        hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

        # save the metatags file if it doesn't exist and the log directory exists
        if self._fs.isdir(dir_path) and not self._fs.isfile(hparams_file):
            save_hparams_to_yaml(hparams_file, self.hparams)
Example 5
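Similar to Example 4, but it falls back to save_dir when log_dir does not exist, and it rewrites the file on every call.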
    def save(self) -> None:
        super().save()
        dir_path = self.log_dir
        if not os.path.isdir(dir_path):
            dir_path = self.save_dir

        # prepare the file path
        hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

        # save the metatags file
        save_hparams_to_yaml(hparams_file, self.hparams)
Example 6
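A save() that first touches self.experiment so the experiment (and its directory) is initialized, then writes hparams.yaml only if it is missing.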
    def save(self) -> None:
        # Initialize experiment
        _ = self.experiment

        super().save()

        # prepare the file path
        hparams_file = os.path.join(self.save_dir, self.NAME_HPARAMS_FILE)

        # save the metatags file if it doesn't exist
        if not os.path.isfile(hparams_file):
            save_hparams_to_yaml(hparams_file, self.hparams)
Example 7
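A CSV logger's save(): it writes hparams.yaml, then dumps all recorded metric rows to a CSV file via csv.DictWriter, using the union of keys across rows as the header.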
    def save(self) -> None:
        """Save recorded hparams and metrics into files."""
        hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
        save_hparams_to_yaml(hparams_file, self.hparams)

        if not self.metrics:
            return

        last_m = {}
        for m in self.metrics:
            last_m.update(m)
        metrics_keys = list(last_m.keys())

        with open(self.metrics_file_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=metrics_keys)
            writer.writeheader()
            writer.writerows(self.metrics)
Example 8
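A TensorBoard-style save() that flushes the underlying SummaryWriter before writing hparams.yaml, with a fallback for old PyTorch versions whose writer lacks flush().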
    def save(self) -> None:
        super().save()
        try:
            self.experiment.flush()
        except AttributeError:
            # PyTorch < v1.2 has no SummaryWriter.flush(); flush the
            # underlying file writer directly instead
            self.experiment._get_file_writer().flush()

        dir_path = self.log_dir
        if not os.path.isdir(dir_path):
            dir_path = self.save_dir

        # prepare the file path
        hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

        # save the metatags file
        save_hparams_to_yaml(hparams_file, self.hparams)
Example 9
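Like Example 5, but omegaconf containers are saved with OmegaConf.save (resolving interpolations) instead of save_hparams_to_yaml; Container is None when omegaconf is not installed.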
    def save(self) -> None:
        super().save()
        dir_path = self.log_dir
        if not os.path.isdir(dir_path):
            dir_path = self.save_dir

        # prepare the file path
        hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

        # save the metatags file; Container is None when omegaconf is not installed
        if Container is not None and isinstance(self.hparams, Container):
            # preserve the omegaconf structure and resolve interpolations
            from omegaconf import OmegaConf
            OmegaConf.save(self.hparams, hparams_file, resolve=True)
        else:
            save_hparams_to_yaml(hparams_file, self.hparams)
Example 10
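A test asserting that save_hparams_to_yaml round-trips a plain dict, an argparse Namespace, an AttributeDict, and an omegaconf object back to the same dictionary.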
def test_hparams_save_yaml(tmpdir):
    hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
                   nested=dict(any_num=123, anystr='abcd'))
    path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')

    save_hparams_to_yaml(path_yaml, hparams)
    assert load_hparams_from_yaml(path_yaml) == hparams

    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
    assert load_hparams_from_yaml(path_yaml) == hparams

    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
    assert load_hparams_from_yaml(path_yaml) == hparams

    save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
    assert load_hparams_from_yaml(path_yaml) == hparams
Example 11
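A later version of the same test: it also covers Enum values (serialized by name) and loads with use_omegaconf=False, exercising the omegaconf path only when _OMEGACONF_AVAILABLE is set.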
def test_hparams_save_yaml(tmpdir):
    class Options(str, Enum):
        option1name = "option1val"
        option2name = "option2val"
        option3name = "option3val"

    hparams = dict(
        batch_size=32,
        learning_rate=0.001,
        data_root="./any/path/here",
        nested=dict(any_num=123, anystr="abcd"),
        switch=Options.option3name,
    )
    path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")

    def _compare_params(loaded_params, default_params: dict):
        assert isinstance(loaded_params, (dict, DictConfig))
        assert loaded_params.keys() == default_params.keys()
        for k, v in default_params.items():
            if isinstance(v, Enum):
                assert v.name == loaded_params[k]
            else:
                assert v == loaded_params[k]

    save_hparams_to_yaml(path_yaml, hparams)
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False),
                    hparams)

    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False),
                    hparams)

    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
    _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False),
                    hparams)

    if _OMEGACONF_AVAILABLE:
        save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
        _compare_params(load_hparams_from_yaml(path_yaml), hparams)
Example 12
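The tail of a training script (the snippet opens mid-way through the model constructor call): all CLI arguments are dumped to args.yaml next to the logs, and after training the five most recent step checkpoints are averaged parameter-wise before running trainer.test().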
            nhead=args.nhead,
            num_encoder_layers=args.num_encoder_layers,
            num_decoder_layers=args.num_decoder_layers,
            dim_feedforward=args.dim_feedforward,
            dropout=args.dropout,
            activation=args.activation,
            warmup=args.warmup,
            bpe_file=args.bpe_file,
            lenpen=args.lenpen,
            beam_size=args.beam_size,
            ckpt_steps=args.ckpt_steps)
    trainer = pl.Trainer.from_argparse_args(args)
    # hack for saving all args: dump the full argparse Namespace to YAML
    from pytorch_lightning.core.saving import save_hparams_to_yaml
    os.makedirs(trainer.logger.log_dir, exist_ok=True)
    save_hparams_to_yaml(f'{trainer.logger.log_dir}/args.yaml', args)

    trainer.fit(mt, datamodule=dataset)

    # average the state_dicts of the 5 most recent step checkpoints
    fnames = os.listdir(f'{trainer.logger.log_dir}/checkpoints/')
    fnames = filter(lambda s: s.startswith('step'), fnames)
    # sort numerically by the step count encoded in the filename
    fnames = sorted(fnames, key=lambda s: int(s.split('.')[1]))
    fnames = fnames[-5:]
    fnames = [f'{trainer.logger.log_dir}/checkpoints/{f}' for f in fnames]
    state_dicts = [torch.load(f)['state_dict'] for f in fnames]
    state_dict = {}
    for k in state_dicts[0].keys():
        # parameter-wise mean across the selected checkpoints
        state_dict[k] = torch.stack([s[k] for s in state_dicts]).mean(0)
    mt.load_state_dict(state_dict)
    trainer.test()