Example #1
    def hpc_save(self, folderpath: str, logger):
        # make sure the checkpoint folder exists
        folderpath = str(folderpath)  # because the tests pass a path object
        if not gfile.exists(folderpath):
            makedirs(folderpath)

        # save logger to make sure we get all the metrics
        logger.save()

        ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

        filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

        # give model a chance to do something on hpc_save
        model = self.get_model()
        checkpoint = self.dump_checkpoint()

        model.on_hpc_save(checkpoint)

        # do the actual save
        # TODO: fix for anything with multiprocess DP, DDP, DDP2
        try:
            atomic_save(checkpoint, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn(
                '`module_arguments` dropped from checkpoint.'
                f' An attribute is not picklable: {err}')
            atomic_save(checkpoint, filepath)

        return filepath
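
The helper max_ckpt_in_folder used above is not shown in this listing. A minimal sketch of what such a helper could look like, assuming a local directory and the hpc_ckpt_{n}.ckpt naming scheme from hpc_save (the function name and the convention of returning 0 for an empty folder are assumptions):

import os
import re

def max_ckpt_in_folder_sketch(folderpath: str) -> int:
    """Return the largest N among files named 'hpc_ckpt_N.ckpt', or 0 if none exist."""
    max_n = 0
    for name in os.listdir(folderpath):
        match = re.fullmatch(r'hpc_ckpt_(\d+)\.ckpt', name)
        if match:
            max_n = max(max_n, int(match.group(1)))
    return max_n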
Example #2
    def _del_model(self, filepath):
        if gfile.exists(filepath):
            try:
                # in compat mode, remove is not implemented, so if running this
                # against an actual remote file system and the correct remote
                # dependencies exist then this will work fine.
                gfile.remove(filepath)
            except AttributeError:
                os.remove(filepath)
    def _save_model(self, filepath, trainer, pl_module):

        # in debugging, track when we save checkpoints
        trainer.dev_debugger.track_checkpointing_history(filepath)

        # make paths
        if not gfile.exists(os.path.dirname(filepath)):
            makedirs(os.path.dirname(filepath))

        # delegate the saving to the model
        if self.save_function is not None:
            self.save_function(filepath, self.save_weights_only)
        else:
            raise ValueError(".save_function() not set")
    def _del_model(self, filepath):
        if gfile.exists(filepath):
            try:
                # in compat mode, remove is not implemented, so if running this
                # against an actual remote file system and the correct remote
                # dependencies exist then this will work fine.
                gfile.remove(filepath)
            except AttributeError:
                if is_remote_path(filepath):
                    log.warning(
                        "Unable to remove stale checkpoints due to running gfile in compatibility mode."
                        " Please install tensorflow to run gfile in full mode"
                        " if writing checkpoints to remote locations")
                else:
                    os.remove(filepath)
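
The is_remote_path predicate called above is not part of this example. A minimal sketch under the assumption that remote locations are recognised by a URL-style scheme prefix (the listed schemes are illustrative, not the library's actual list):

def is_remote_path_sketch(path: str) -> bool:
    # heuristic: treat URL-style paths such as gs://, s3:// or hdfs:// as remote
    remote_prefixes = ('gs://', 's3://', 'hdfs://')
    return str(path).startswith(remote_prefixes)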
Example #5
    def restore_hpc_weights_if_needed(self, model: LightningModule):
        """If there is a set of hpc weights, use as signal to restore model."""
        did_restore = False

        # look for hpc weights
        folderpath = str(self.weights_save_path)
        if gfile.exists(folderpath):
            files = gfile.listdir(folderpath)
            hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

            # if hpc weights exist restore model
            if len(hpc_weight_paths) > 0:
                self.hpc_load(folderpath, self.on_gpu)
                did_restore = True
        return did_restore
Example #6
    @property
    def experiment(self) -> SummaryWriter:
        r"""
        Actual tensorboard object. To use TensorBoard features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_tensorboard_function()

        """
        if self._experiment is not None:
            return self._experiment

        assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
        if self.root_dir and not gfile.exists(str(self.root_dir)):
            makedirs(self.root_dir)
        self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
        return self._experiment
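
Since experiment returns the underlying SummaryWriter, any of its methods can be called through self.logger.experiment from a LightningModule. A standalone sketch of the same calls, with a plain SummaryWriter standing in for self.logger.experiment (the log directory and tag names are placeholders):

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='lightning_logs/manual_demo')  # stands in for self.logger.experiment
writer.add_scalar('train/loss', 0.42, global_step=0)
writer.add_histogram('weights/layer0', torch.randn(64), global_step=0)
writer.close()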
Example #7
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    if not gfile.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
        return {}

    with cloud_open(config_yaml, "r") as fp:
        # an explicit loader avoids the unsafe-load warning on PyYAML >= 5.1
        tags = yaml.load(fp, Loader=yaml.FullLoader)

    return tags
Example #8
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    if not gfile.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
        return {}

    with cloud_open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

    return tags
Example #9
    def on_train_start(self, trainer, pl_module):
        """
        Determines model checkpoint save directory at runtime. References attributes from the
        trainer's logger to determine where to save checkpoints.
        The base path for saving weights is set in this priority:

        1.  Checkpoint callback's path (if passed in)
        2.  The default_root_dir from trainer if trainer has no logger
        3.  The weights_save_path from trainer, if user provides it
        4.  User-provided weights_save_path

        The base path gets extended with logger name and version (if these are available)
        and subfolder "checkpoints".
        """
        if self.dirpath is not None:
            return  # short circuit

        self.filename = '{epoch}'

        if trainer.logger is not None:
            if trainer.weights_save_path != trainer.default_root_dir:
                # the user has changed weights_save_path, it overrides anything
                save_dir = trainer.weights_save_path
            else:
                save_dir = trainer.logger.save_dir or trainer.default_root_dir

            if isinstance(trainer.logger.version, str):
                version = trainer.logger.version
            else:
                version = f'version_{trainer.logger.version}'
            ckpt_path = os.path.join(save_dir, trainer.logger.name, version, "checkpoints")
        else:
            ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

        self.dirpath = ckpt_path

        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
        if not gfile.exists(self.dirpath):
            makedirs(self.dirpath)
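
For concreteness, a worked sketch of the path resolution above with assumed example values for the logger attributes (save_dir, name, an integer version):

import os

save_dir = 'lightning_logs'   # assumed trainer.logger.save_dir
name = 'my_model'             # assumed trainer.logger.name
version = 3                   # assumed trainer.logger.version (an int here)

version_str = version if isinstance(version, str) else f'version_{version}'
ckpt_path = os.path.join(save_dir, name, version_str, 'checkpoints')
print(ckpt_path)  # lightning_logs/my_model/version_3/checkpoints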
Example #10
    def on_validation_end(self, trainer, pl_module):
        # only run on main process
        if trainer.global_rank != 0:
            return

        # TODO: remove when dict results are deprecated
        self.__warn_deprecated_monitor_key()

        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

        # support structured results
        if metrics.get('checkpoint_on') is not None:
            self.monitor = 'checkpoint_on'

        # conditioned val metrics override conditioned train loop metrics
        if metrics.get('val_checkpoint_on') is not None:
            self.monitor = 'val_checkpoint_on'

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
            # the save period has not elapsed yet, skip this epoch
            return

        self.epoch_last_check = epoch

        ckpt_name_metrics = trainer.logged_metrics
        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
        version_cnt = 0
        while gfile.exists(filepath):
            # a checkpoint for this epoch was saved before; bump the version suffix and retry
            filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
            version_cnt += 1

        if self.save_top_k != -1:
            current = metrics.get(self.monitor)

            if not isinstance(current, torch.Tensor):
                rank_zero_warn(
                    f'The metric you returned {current} must be a `torch.Tensor` instance, checkpoint not saved.'
                    f' HINT: what is the value of {self.monitor} in validation_epoch_end()?',
                    RuntimeWarning)
                if current is not None:
                    current = torch.tensor(current)

            if current is None:
                rank_zero_warn(
                    f'Can save best model only with {self.monitor} available, skipping.',
                    RuntimeWarning)
            elif self.check_monitor_top_k(current):
                self._do_check_save(filepath, current, epoch, trainer, pl_module)
            elif self.verbose > 0:
                log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')

        else:
            if self.verbose > 0:
                log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

            assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
            self._save_model(filepath, trainer, pl_module)

        if self.save_last:
            filepath = os.path.join(
                self.dirpath,
                self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST)
            self._save_model(filepath, trainer, pl_module)
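
check_monitor_top_k is referenced above but not shown. A sketch of the comparison it needs to make, assuming the best_k_models, kth_best_model_path, save_top_k, and mode attributes from Example #11; this is an illustration of the idea, not the library's implementation:

import torch

def check_monitor_top_k_sketch(current, best_k_models, kth_best_model_path, save_top_k, mode):
    """Should `current` displace one of the current top-k checkpoints?"""
    if len(best_k_models) < save_top_k:
        return True  # there is still room for another checkpoint
    kth_value = best_k_models[kth_best_model_path]
    op = torch.lt if mode == 'min' else torch.gt  # lower is better for 'min', higher for 'max'
    return bool(op(current, kth_value))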
Example #11
    def __init__(self,
                 filepath: Optional[str] = None,
                 monitor: str = 'val_loss',
                 verbose: bool = False,
                 save_last: bool = False,
                 save_top_k: int = 1,
                 save_weights_only: bool = False,
                 mode: str = 'auto',
                 period: int = 1,
                 prefix: str = ''):
        super().__init__()
        if filepath:
            # the tests pass in a py.path.local but we want a str
            filepath = str(filepath)
        if (save_top_k > 0 and filepath is not None
                and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0):
            rank_zero_warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                " All files in this directory will be deleted when a checkpoint is saved!"
            )
        self._rank = 0

        self.monitor = monitor
        self.verbose = verbose
        if filepath is None:  # will be determined by trainer at runtime
            self.dirpath, self.filename = None, None
        else:
            if gfile.isdir(filepath):
                self.dirpath, self.filename = filepath, '{epoch}'
            else:
                filepath = os.path.realpath(filepath)
                self.dirpath, self.filename = os.path.split(filepath)
            if not gfile.exists(self.dirpath):
                makedirs(self.dirpath)
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        self.epoch_last_check = None
        self.prefix = prefix
        self.best_k_models = {}
        # {filename: monitor}
        self.kth_best_model_path = ''
        self.best_model_score = 0
        self.best_model_path = ''
        self.save_function = None

        torch_inf = torch.tensor(np.Inf)
        mode_dict = {
            'min': (torch_inf, 'min'),
            'max': (-torch_inf, 'max'),
            'auto': (-torch_inf, 'max') if 'acc' in self.monitor
            or self.monitor.startswith('fmeasure') else (torch_inf, 'min'),
        }

        if mode not in mode_dict:
            rank_zero_warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                f'falling back to auto mode.', RuntimeWarning)
            mode = 'auto'

        self.kth_value, self.mode = mode_dict[mode]
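
A sketch of typical usage of this constructor with the pytorch_lightning API shown in these examples; the checkpoint directory pattern and the monitored metric name are placeholders:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 best checkpoints by validation loss; filenames are formatted from logged metrics
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/{epoch}-{val_loss:.2f}',  # placeholder path pattern
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)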