示例#1
0
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """Write hyperparameters to a YAML file.

    OmegaConf is used for serialization when it is available and either
    ``hparams`` itself or one of its values is an OmegaConf config; otherwise
    the plain ``yaml.dump`` path is taken.

    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved

    Raises:
        RuntimeError: if the parent folder of ``config_yaml`` is not an
            existing directory.
    """
    parent_dir = os.path.dirname(config_yaml)
    if not gfile.isdir(parent_dir):
        raise RuntimeError(f"Missing folder: {parent_dir}.")

    # Namespace and AttributeDict inputs are flattened into a plain dict
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)

    # OmegaConf path: taken when hparams is, or contains, an OmegaConf config
    if OmegaConf is not None:
        omega_payload = None
        if OmegaConf.is_config(hparams):
            omega_payload = hparams
        elif any(OmegaConf.is_config(value) for value in hparams.values()):
            omega_payload = OmegaConf.create(hparams)

        if omega_payload is not None:
            OmegaConf.save(omega_payload, config_yaml, resolve=True)
            return

    # plain YAML path
    assert isinstance(hparams, dict)
    with cloud_open(config_yaml, 'w', newline='') as stream:
        yaml.dump(hparams, stream)
示例#2
0
    def _get_next_version(self):
        """Return the next unused version index under ``save_dir/name``.

        Scans the logger root folder for sub-directories whose name starts
        with ``version_`` and parses the token after the first underscore as
        an integer.

        Returns:
            ``0`` when the root folder is missing or contains no usable
            version directory, otherwise the largest version found plus one.
        """
        root_dir = os.path.join(self.save_dir, self.name)

        if not gfile.isdir(root_dir):
            log.warning('Missing logger folder: %s', root_dir)
            return 0

        existing_versions = []
        for d in gfile.listdir(root_dir):
            # cheap string check first, so the filesystem is only queried for candidates
            if not (d.startswith("version_") and gfile.isdir(os.path.join(root_dir, d))):
                continue
            try:
                existing_versions.append(int(d.split("_")[1]))
            except ValueError:
                # a stray folder such as "version_tmp" must not abort logger setup
                log.warning('Ignoring folder with non-integer version suffix: %s', d)

        # default=-1 makes the empty case yield version 0
        return max(existing_versions, default=-1) + 1
示例#3
0
    def save(self) -> None:
        """Run the parent ``save`` and then write the hyperparameters to YAML.

        The file is placed in ``self.log_dir`` when that is an existing
        directory, and in ``self.save_dir`` otherwise.
        """
        super().save()

        # read log_dir once, then fall back to save_dir if it is not a directory
        log_dir = self.log_dir
        target_dir = log_dir if gfile.isdir(log_dir) else self.save_dir

        # write the hyperparameters file into the chosen folder
        save_hparams_to_yaml(
            os.path.join(target_dir, self.NAME_HPARAMS_FILE),
            self.hparams,
        )
示例#4
0
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
    """Write hyperparameters to a two-column (``key``, ``value``) CSV file.

    Args:
        tags_csv: path of the CSV file to write
        hparams: parameters to be saved; a ``Namespace`` is converted with
            ``vars`` first

    Raises:
        RuntimeError: if the parent folder of ``tags_csv`` is not an existing
            directory.
    """
    parent_dir = os.path.dirname(tags_csv)
    if not gfile.isdir(parent_dir):
        raise RuntimeError(f"Missing folder: {parent_dir}.")

    params = vars(hparams) if isinstance(hparams, Namespace) else hparams

    with cloud_open(tags_csv, "w", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=["key", "value"])
        # header row is the field names themselves: key,value
        writer.writeheader()
        writer.writerows({"key": name, "value": value} for name, value in params.items())
示例#5
0
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """Write hyperparameters to a YAML file.

    An OmegaConf ``Container`` is saved through ``OmegaConf.save`` when
    OmegaConf is available; any other input is converted to a plain dict and
    written with ``yaml.dump``.

    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved

    Raises:
        RuntimeError: if the parent folder of ``config_yaml`` is not an
            existing directory.
    """
    parent_dir = os.path.dirname(config_yaml)
    if not gfile.isdir(parent_dir):
        raise RuntimeError(f"Missing folder: {parent_dir}.")

    # OmegaConf containers are written by OmegaConf itself
    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
        from omegaconf import OmegaConf

        OmegaConf.save(hparams, config_yaml, resolve=True)
        return

    # everything else is reduced to a plain dict and dumped as YAML
    if isinstance(hparams, Namespace):
        plain_params = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        plain_params = dict(hparams)
    else:
        plain_params = hparams
    assert isinstance(plain_params, dict)

    with cloud_open(config_yaml, "w", newline="") as stream:
        yaml.dump(plain_params, stream)
    def __init__(self,
                 filepath: Optional[str] = None,
                 monitor: str = 'val_loss',
                 verbose: bool = False,
                 save_last: bool = False,
                 save_top_k: int = 1,
                 save_weights_only: bool = False,
                 mode: str = 'auto',
                 period: int = 1,
                 prefix: str = ''):
        """Set up checkpoint bookkeeping and resolve the output location.

        Args:
            filepath: checkpoint location. ``None`` leaves ``dirpath`` and
                ``filename`` unset (to be determined by the trainer at
                runtime). An existing directory is used as ``dirpath`` with
                the filename template ``'{epoch}'``. Anything else is split
                into directory and filename; local paths are passed through
                ``os.path.realpath`` first.
            monitor: name of the monitored quantity; also used to resolve
                ``mode='auto'``.
            verbose: stored on the instance.
            save_last: stored on the instance.
            save_top_k: stored on the instance; a value ``> 0`` triggers a
                warning when ``filepath`` is a non-empty directory.
            save_weights_only: stored on the instance.
            mode: one of ``'min'``, ``'max'`` or ``'auto'``. Unknown values
                emit a ``RuntimeWarning`` and fall back to ``'auto'``.
            period: stored on the instance.
            prefix: stored on the instance.
        """
        super().__init__()
        if filepath:
            # the tests pass in a py.path.local but we want a str
            filepath = str(filepath)
        if save_top_k > 0 and filepath is not None and gfile.isdir(
                filepath) and len(gfile.listdir(filepath)) > 0:
            # the leading space keeps the two sentences from running together
            rank_zero_warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                " All files in this directory will be deleted when a checkpoint is saved!"
            )
        self._rank = 0

        self.monitor = monitor
        self.verbose = verbose
        if filepath is None:  # will be determined by trainer at runtime
            self.dirpath, self.filename = None, None
        else:
            if gfile.isdir(filepath):
                self.dirpath, self.filename = filepath, '{epoch}'
            else:
                if not is_remote_path(filepath):  # dont normalize remote paths
                    filepath = os.path.realpath(filepath)
                self.dirpath, self.filename = os.path.split(filepath)
            makedirs(self.dirpath)  # calls with exist_ok
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epoch_last_check = None
        self.prefix = prefix
        self.best_k_models = {}
        # {filename: monitor}
        self.kth_best_model_path = ''
        self.best_model_score = 0
        self.best_model_path = ''
        self.save_function = None
        self.warned_result_obj = False

        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed
        torch_inf = torch.tensor(np.inf)
        # each entry is (initial kth_value, resolved mode); 'auto' picks 'max'
        # for accuracy / f-measure style monitors and 'min' otherwise
        mode_dict = {
            'min': (torch_inf, 'min'),
            'max': (-torch_inf, 'max'),
            'auto': (-torch_inf, 'max') if 'acc' in self.monitor
            or self.monitor.startswith('fmeasure') else (torch_inf, 'min'),
        }

        if mode not in mode_dict:
            rank_zero_warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                f'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        self.kth_value, self.mode = mode_dict[mode]