Example #1
    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
        # return the highest checkpoint number found in `path`, or 0 if none exist
        files = gfile.listdir(str(path))
        files = [x for x in files if name_key in x]
        if len(files) == 0:
            return 0

        # keep only the digits that follow `name_key` in each filename
        ckpt_vs = []
        for name in files:
            name = name.split(name_key)[-1]
            name = re.sub('[^0-9]', '', name)
            ckpt_vs.append(int(name))

        return max(ckpt_vs)
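
A minimal standalone sketch of the same idea, assuming plain local paths and the standard library in place of the `gfile` wrapper (the free-function form is ours, for illustration):

import os
import re

def max_ckpt_in_folder(path, name_key='ckpt_'):
    """Return the highest checkpoint number found in `path`, or 0 if none exist."""
    files = [f for f in os.listdir(str(path)) if name_key in f]
    if not files:
        return 0
    # strip everything but the digits after the key, e.g. 'ckpt_12.pt' -> 12
    return max(int(re.sub('[^0-9]', '', f.split(name_key)[-1])) for f in files)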
Example #2
    def restore_hpc_weights_if_needed(self, model: LightningModule):
        """If there is a set of hpc weights, use as signal to restore model."""
        did_restore = False

        # look for hpc weights
        folderpath = str(self.weights_save_path)
        if gfile.exists(folderpath):
            files = gfile.listdir(folderpath)
            hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

            # if hpc weights exist restore model
            if len(hpc_weight_paths) > 0:
                self.hpc_load(folderpath, self.on_gpu)
                did_restore = True
        return did_restore
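
Only the detection half of this pattern, sketched as a standalone helper over a local directory; the actual `hpc_load` restore step is elided, and the name `has_hpc_weights` is ours:

import os

def has_hpc_weights(weights_save_path):
    """Return True when 'hpc_ckpt' files exist under `weights_save_path`."""
    folderpath = str(weights_save_path)
    if not os.path.exists(folderpath):
        return False
    # any file whose name contains 'hpc_ckpt' counts as a restorable checkpoint
    return any('hpc_ckpt' in f for f in os.listdir(folderpath))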
Example #3
    def _get_next_version(self):
        root_dir = os.path.join(self.save_dir, self.name)

        if not gfile.isdir(root_dir):
            log.warning('Missing logger folder: %s', root_dir)
            return 0

        # collect the indices of existing `version_<n>` sub-folders
        existing_versions = []
        for d in gfile.listdir(root_dir):
            if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                existing_versions.append(int(d.split("_")[1]))

        if len(existing_versions) == 0:
            return 0

        return max(existing_versions) + 1
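
A local-filesystem sketch of the same `version_0`, `version_1`, ... numbering scheme, using `os` in place of `gfile` (standalone function for illustration):

import os

def get_next_version(save_dir, name):
    """Return 0 for a fresh logger folder, else 1 + the highest existing version."""
    root_dir = os.path.join(save_dir, name)
    if not os.path.isdir(root_dir):
        return 0
    existing_versions = [
        int(d.split('_')[1])
        for d in os.listdir(root_dir)
        if d.startswith('version_') and os.path.isdir(os.path.join(root_dir, d))
    ]
    return max(existing_versions) + 1 if existing_versions else 0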
Example #4
    def __init__(self,
                 filepath: Optional[str] = None,
                 monitor: str = 'val_loss',
                 verbose: bool = False,
                 save_last: bool = False,
                 save_top_k: int = 1,
                 save_weights_only: bool = False,
                 mode: str = 'auto',
                 period: int = 1,
                 prefix: str = ''):
        super().__init__()
        if filepath:
            # the tests pass in a py.path.local but we want a str
            filepath = str(filepath)
        if (save_top_k > 0 and filepath is not None
                and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0):
            rank_zero_warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0. "
                "All files in this directory will be deleted when a checkpoint is saved!"
            )
        self._rank = 0

        self.monitor = monitor
        self.verbose = verbose
        if filepath is None:  # will be determined by trainer at runtime
            self.dirpath, self.filename = None, None
        else:
            if gfile.isdir(filepath):
                self.dirpath, self.filename = filepath, '{epoch}'
            else:
                if not is_remote_path(filepath):  # dont normalize remote paths
                    filepath = os.path.realpath(filepath)
                self.dirpath, self.filename = os.path.split(filepath)
            makedirs(self.dirpath)  # calls with exist_ok
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epoch_last_check = None
        self.prefix = prefix
        self.best_k_models = {}  # maps checkpoint filename -> monitored value
        self.kth_best_model_path = ''
        self.best_model_score = 0
        self.best_model_path = ''
        self.save_function = None
        self.warned_result_obj = False

        torch_inf = torch.tensor(np.inf)
        mode_dict = {
            'min': (torch_inf, 'min'),
            'max': (-torch_inf, 'max'),
            # 'auto' infers the direction from the monitored metric's name
            'auto': (-torch_inf, 'max')
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
            else (torch_inf, 'min'),
        }

        if mode not in mode_dict:
            rank_zero_warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'falling back to auto mode.', RuntimeWarning)
            mode = 'auto'

        self.kth_value, self.mode = mode_dict[mode]
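
To make the 'auto' mode concrete: the lookup above picks a comparison direction from the monitored metric's name. Here it is isolated as a small runnable function (the name `resolve_mode` and the silent fallback are ours; the original warns before falling back):

import numpy as np
import torch

def resolve_mode(monitor, mode='auto'):
    """Map a mode string to (initial kth_value, comparison direction)."""
    torch_inf = torch.tensor(np.inf)
    mode_dict = {
        'min': (torch_inf, 'min'),
        'max': (-torch_inf, 'max'),
        # metrics that look like accuracy/f-measure should be maximized
        'auto': (-torch_inf, 'max')
        if 'acc' in monitor or monitor.startswith('fmeasure')
        else (torch_inf, 'min'),
    }
    return mode_dict.get(mode, mode_dict['auto'])

print(resolve_mode('val_acc'))   # (tensor(-inf), 'max')
print(resolve_mode('val_loss'))  # (tensor(inf), 'min')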