Ejemplo n.º 1
0
    def _atomic_save(self, checkpoint, filepath: str):
        """Serialize a checkpoint in memory, then write it to ``filepath`` in a single write.

        The checkpoint is dumped into an in-memory buffer first; the target file is only opened after
        ``torch.save`` has returned, so a failure during serialization does not touch ``filepath``.

        Args:
            checkpoint: The object to save.
                Built to be used with the ``dump_checkpoint`` method, but can deal with anything which
                ``torch.save`` accepts.
            filepath: The path to which the checkpoint will be saved.
                This points to the file that the checkpoint will be stored in.
        """
        # torch 1.6.0 has to keep the legacy (non-zipfile) format because of a bug in
        # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
        # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
        needs_legacy_format = LooseVersion(torch.__version__).version[:3] == [1, 6, 0]
        save_kwargs = {'_use_new_zipfile_serialization': False} if needs_legacy_format else {}

        buffer = io.BytesIO()
        torch.save(checkpoint, buffer, **save_kwargs)

        with cloud_open(filepath, 'wb') as target:
            target.write(buffer.getvalue())
Ejemplo n.º 2
0
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """Write hyperparameters to a YAML file.

    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved

    Raises:
        RuntimeError: if ``gfile.isdir`` reports that the parent folder of ``config_yaml``
            is not a directory.
    """
    folder = os.path.dirname(config_yaml)
    if not gfile.isdir(folder):
        raise RuntimeError(f"Missing folder: {folder}.")

    # normalise Namespace / AttributeDict inputs into a plain dict
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)

    # hparams that are (or contain) OmegaConf configs are written by OmegaConf itself
    if OmegaConf is not None:
        omega_config = None
        if OmegaConf.is_config(hparams):
            omega_config = hparams
        elif any(OmegaConf.is_config(value) for value in hparams.values()):
            # at least one nested value is a config: wrap the whole mapping
            omega_config = OmegaConf.create(hparams)

        if omega_config is not None:
            OmegaConf.save(omega_config, config_yaml, resolve=True)
            return

    # everything else goes through the plain YAML dumper
    assert isinstance(hparams, dict)
    with cloud_open(config_yaml, 'w', newline='') as stream:
        yaml.dump(hparams, stream)
Ejemplo n.º 3
0
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
    """Write hyperparameters to a two-column (``key``, ``value``) CSV file.

    Args:
        tags_csv: path to the CSV file to write
        hparams: parameters to be saved; a ``Namespace`` is converted with ``vars`` first

    Raises:
        RuntimeError: if ``gfile.isdir`` reports that the parent folder of ``tags_csv``
            is not a directory.
    """
    folder = os.path.dirname(tags_csv)
    if not gfile.isdir(folder):
        raise RuntimeError(f"Missing folder: {folder}.")

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with cloud_open(tags_csv, "w", newline="") as stream:
        writer = csv.DictWriter(stream, fieldnames=["key", "value"])
        # header row: each column is labelled with its own field name
        writer.writeheader()
        writer.writerows({"key": name, "value": value} for name, value in hparams.items())
Ejemplo n.º 4
0
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
    """Load hparams from a YAML file.

    Args:
        config_yaml: path to the YAML file to read

    Returns:
        The parsed content of the file, or an empty dict (after emitting a
        ``RuntimeWarning`` through ``rank_zero_warn``) when ``gfile.exists`` reports
        that the file is missing.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    if not gfile.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
        return {}

    with cloud_open(config_yaml, "r") as fp:
        # ``yaml.load`` without an explicit Loader warns on PyYAML 5.1+ and raises
        # TypeError on PyYAML 6+. ``full_load`` selects FullLoader, which was the
        # implicit default, so existing files keep loading the same way.
        # NOTE(review): FullLoader is not meant for untrusted input; switch to
        # ``yaml.safe_load`` if these files can come from outside the project.
        tags = yaml.full_load(fp)

    return tags
Ejemplo n.º 5
0
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Read hparams back from a two-column CSV file.

    The first row is treated as a header and skipped. For every remaining row the first
    column becomes the key and the second column is passed through ``convert`` to become
    the value.

    Returns an empty dict (after emitting a ``RuntimeWarning`` through ``rank_zero_warn``)
    when ``gfile.exists`` reports that the file is missing.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    if not gfile.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
        return {}

    with cloud_open(tags_csv, "r", newline="") as stream:
        rows = csv.reader(stream, delimiter=",")
        # drop the header row; the default keeps an empty file from raising StopIteration
        next(rows, None)
        return {record[0]: convert(record[1]) for record in rows}
    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue,
                                                results):
        """Hand the end-of-fit state back through ``mp_queue``.

        Does nothing unless ``self.distributed_backend`` is one of ``ddp_spawn``, ``ddp_cpu``
        or ``tpu``. On global rank 0 with a queue present, three items are put on the queue
        in this order: the best model path, ``results``, and the path of a temporary file
        holding ``model.state_dict()`` (``None`` when no such file was written).

        Args:
            model: the module whose ``state_dict()`` is saved as the last weights.
            mp_queue: queue that receives the three items; may be ``None``, in which case
                nothing is sent or saved.
            results: object forwarded unchanged through the queue.
        """
        if self.distributed_backend.lower() not in [
                'ddp_spawn', 'ddp_cpu', 'tpu'
        ]:
            return

        # track the best model path
        best_model_path = None
        if self.checkpoint_callback is not None:
            best_model_path = self.checkpoint_callback.best_model_path

        if self.global_rank == 0 and mp_queue is not None:
            rank_zero_warn('cleaning up ddp environment...')
            # todo, pass complete checkpoint as state dictionary
            mp_queue.put(best_model_path)
            mp_queue.put(results)

            # save the last weights
            last_path = None
            if not self.testing and best_model_path:
                # Only rewrite a literal trailing ``.ckpt`` extension. The dot is escaped and
                # the pattern anchored so that directory names such as ``my_ckpts/`` are not
                # rewritten as well.
                last_path, replaced = re.subn(r'\.ckpt$', '.tmp_end.ckpt',
                                              best_model_path)
                if not replaced:
                    # No ``.ckpt`` extension: append the suffix so the temporary weights
                    # never overwrite the best checkpoint itself.
                    last_path = best_model_path + '.tmp_end.ckpt'

                # Can't use the new zipfile serialization for 1.6.0 because there's a bug in
                # torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
                # More details can be found here: https://github.com/pytorch/pytorch/issues/42239
                bytesbuffer = io.BytesIO()
                if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
                    torch.save(model.state_dict(),
                               bytesbuffer,
                               _use_new_zipfile_serialization=False)
                else:
                    torch.save(model.state_dict(), bytesbuffer)
                with cloud_open(last_path, 'wb') as f:
                    f.write(bytesbuffer.getvalue())
            mp_queue.put(last_path)
Ejemplo n.º 7
0
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """Write hyperparameters to a YAML file.

    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved

    Raises:
        RuntimeError: if ``gfile.isdir`` reports that the parent folder of ``config_yaml``
            is not a directory.
    """
    parent_dir = os.path.dirname(config_yaml)
    if not gfile.isdir(parent_dir):
        raise RuntimeError(f"Missing folder: {parent_dir}.")

    # OmegaConf containers are written by OmegaConf itself
    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
        from omegaconf import OmegaConf

        OmegaConf.save(hparams, config_yaml, resolve=True)
        return

    # saving the standard way: reduce the input to a plain dict first
    if isinstance(hparams, Namespace):
        plain_hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        plain_hparams = dict(hparams)
    else:
        plain_hparams = hparams
    assert isinstance(plain_hparams, dict)

    with cloud_open(config_yaml, "w", newline="") as stream:
        yaml.dump(plain_hparams, stream)