def hpc_save(self, folderpath: str, logger):
    # make sure the checkpoint folder exists
    folderpath = str(folderpath)  # because the tests pass a path object
    if not gfile.exists(folderpath):
        makedirs(folderpath)

    # save logger to make sure we get all the metrics
    logger.save()

    ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
    filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

    # give model a chance to do something on hpc_save
    model = self.get_model()
    checkpoint = self.dump_checkpoint()
    model.on_hpc_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn('`module_arguments` dropped from checkpoint.'
                       f' An attribute is not picklable: {err}')
        atomic_save(checkpoint, filepath)

    return filepath
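# `hpc_save` above assumes a `max_ckpt_in_folder` helper that returns the
# highest checkpoint counter already written to the folder. A minimal sketch of
# such a helper, following the `hpc_ckpt_<n>.ckpt` naming convention used above
# (an assumption about the parsing scheme, not the library's verbatim
# implementation):
import re

def max_ckpt_in_folder(folderpath: str, name_key: str = 'ckpt_') -> int:
    files = [x for x in gfile.listdir(str(folderpath)) if name_key in x]
    if not files:
        return 0
    # pull the integer counter out of names like `hpc_ckpt_7.ckpt`
    return max(int(re.sub('[^0-9]', '', name.split(name_key)[-1])) for name in files)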
def _save_model(self, filepath, trainer, pl_module):
    # in debugging, track when we save checkpoints
    trainer.dev_debugger.track_checkpointing_history(filepath)

    # make paths
    if not gfile.exists(os.path.dirname(filepath)):
        makedirs(os.path.dirname(filepath))

    # delegate the saving to the model
    if self.save_function is not None:
        self.save_function(filepath, self.save_weights_only)
    else:
        raise ValueError(".save_function() not set")
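# `_save_model` delegates the actual write to `self.save_function`, which is
# expected to be injected before training starts. A minimal sketch of the
# wiring, assuming a `save_checkpoint(filepath, weights_only)` method on the
# trainer (hypothetical names for illustration):
#
#     checkpoint_callback.save_function = trainer.save_checkpoint
#
# With that in place, `self.save_function(filepath, self.save_weights_only)`
# resolves to a real checkpoint write instead of raising the ValueError above.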
def _del_model(self, filepath):
    if gfile.exists(filepath):
        try:
            # in compat mode, remove is not implemented, so if running this
            # against an actual remote file system and the correct remote
            # dependencies exist then this will work fine.
            gfile.remove(filepath)
        except AttributeError:
            if is_remote_path(filepath):
                log.warning(
                    "Unable to remove stale checkpoints due to running gfile in compatibility mode."
                    " Please install tensorflow to run gfile in full mode"
                    " if writing checkpoints to remote locations")
            else:
                os.remove(filepath)
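# `_del_model` above calls an `is_remote_path` predicate that is not defined in
# this excerpt. A minimal sketch, assuming remote paths are recognized by a
# URL-style scheme prefix (hypothetical helper, not the library's verbatim
# implementation):
def is_remote_path(path: str) -> bool:
    # s3://bucket/..., gs://bucket/..., hdfs://... all carry a scheme separator
    return '://' in str(path)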
def restore_hpc_weights_if_needed(self, model: LightningModule):
    """If there is a set of hpc weights, use as signal to restore model."""
    did_restore = False

    # look for hpc weights
    folderpath = str(self.weights_save_path)
    if gfile.exists(folderpath):
        files = gfile.listdir(folderpath)
        hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

        # if hpc weights exist restore model
        if len(hpc_weight_paths) > 0:
            self.hpc_load(folderpath, self.on_gpu)
            did_restore = True
    return did_restore
@property
def experiment(self) -> SummaryWriter:
    r"""
    Actual tensorboard object. To use TensorBoard features in your
    :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

    Example::

        self.logger.experiment.some_tensorboard_function()

    """
    if self._experiment is not None:
        return self._experiment

    assert rank_zero_only.rank == 0, 'tried to init log dirs in non global_rank=0'
    if self.root_dir and not gfile.exists(str(self.root_dir)):
        makedirs(self.root_dir)
    self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
    return self._experiment
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    if not gfile.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning)
        return {}

    with cloud_open(config_yaml, "r") as fp:
        # `yaml.load` without an explicit Loader is deprecated in PyYAML >= 5.1
        tags = yaml.full_load(fp)
    return tags
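# `cloud_open` (used above and in the CSV loader below) is assumed to return a
# file-like object for both local and remote paths. A minimal sketch, assuming
# a gfile backend with a plain-`open` fallback (hypothetical; the real helper
# may differ):
def cloud_open(path: str, mode: str, newline: str = None):
    try:
        # tf.io.gfile.GFile handles gs://, s3://, hdfs://, and local paths
        return gfile.GFile(str(path), mode)
    except NotImplementedError:
        # compat-mode gfile cannot open remote paths; fall back to local open
        return open(path, mode, newline=newline)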
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    if not gfile.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", RuntimeWarning)
        return {}

    with cloud_open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

    return tags
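# The CSV loader above relies on a `convert` helper to turn the strings read
# from disk back into Python values. A minimal sketch with `ast.literal_eval`
# semantics and a plain-string fallback (hypothetical, not the library's
# verbatim implementation):
from ast import literal_eval

def convert(val: str):
    try:
        return literal_eval(val)  # ints, floats, bools, None, lists, dicts, ...
    except (ValueError, SyntaxError):
        return val  # anything unparseable stays a string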
def on_train_start(self, trainer, pl_module):
    """
    Determines the model checkpoint save directory at runtime. References attributes
    from the trainer's logger to determine where to save checkpoints.
    The base path for saving weights is set in this priority:

    1.  The filepath argument passed to this callback (if given)
    2.  The trainer's weights_save_path, if the user changed it from the default
    3.  The logger's save_dir, if the trainer has a logger
    4.  The trainer's default_root_dir

    The base path gets extended with the logger name and version (if these are
    available) and subfolder "checkpoints".
    """
    if self.dirpath is not None:
        return  # short circuit

    self.filename = '{epoch}'

    if trainer.logger is not None:
        if trainer.weights_save_path != trainer.default_root_dir:
            # the user has changed weights_save_path, it overrides anything
            save_dir = trainer.weights_save_path
        else:
            save_dir = trainer.logger.save_dir or trainer.default_root_dir

        version = trainer.logger.version if isinstance(
            trainer.logger.version, str) else f'version_{trainer.logger.version}'
        ckpt_path = os.path.join(save_dir, trainer.logger.name, version, "checkpoints")
    else:
        ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

    self.dirpath = ckpt_path

    assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
    if not gfile.exists(self.dirpath):
        makedirs(self.dirpath)
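# For illustration: with a logger named 'lightning_logs' reporting version 3
# and an unchanged weights_save_path, the code above resolves the checkpoint
# directory to (example values, not fixed by the library):
#
#     <save_dir>/lightning_logs/version_3/checkpoints/
#
# while a trainer without a logger falls back to:
#
#     <weights_save_path>/checkpoints/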
def on_validation_end(self, trainer, pl_module):
    # only run on main process
    if trainer.global_rank != 0:
        return

    # TODO: remove when dict results are deprecated
    self.__warn_deprecated_monitor_key()

    metrics = trainer.callback_metrics
    epoch = trainer.current_epoch

    # support structured results
    if metrics.get('checkpoint_on') is not None:
        self.monitor = 'checkpoint_on'

    # conditioned val metrics override conditioned train loop metrics
    if metrics.get('val_checkpoint_on') is not None:
        self.monitor = 'val_checkpoint_on'

    if self.save_top_k == 0:
        # no models are saved
        return
    if self.epoch_last_check is not None and (epoch - self.epoch_last_check) < self.period:
        # skipping in this term
        return

    self.epoch_last_check = epoch

    ckpt_name_metrics = trainer.logged_metrics
    filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
    version_cnt = 0
    while gfile.exists(filepath):
        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
        # this epoch called before
        version_cnt += 1

    if self.save_top_k != -1:
        current = metrics.get(self.monitor)

        if not isinstance(current, torch.Tensor):
            rank_zero_warn(
                f'The metric you returned {current} must be a `torch.Tensor` instance, checkpoint not saved'
                f' HINT: what is the value of {self.monitor} in validation_epoch_end()?', RuntimeWarning)
            if current is not None:
                current = torch.tensor(current)

        if current is None:
            rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
        elif self.check_monitor_top_k(current):
            self._do_check_save(filepath, current, epoch, trainer, pl_module)
        elif self.verbose > 0:
            log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
    else:
        if self.verbose > 0:
            log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

        assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
        self._save_model(filepath, trainer, pl_module)

    if self.save_last:
        filepath = os.path.join(self.dirpath, self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST)
        self._save_model(filepath, trainer, pl_module)
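# `check_monitor_top_k` above decides whether `current` earns a spot in the
# best-k set. A minimal sketch, assuming the `best_k_models`,
# `kth_best_model_path`, and `mode` attributes set up in `__init__` below
# (hypothetical, not the library's verbatim implementation):
def check_monitor_top_k(self, current: torch.Tensor) -> bool:
    if len(self.best_k_models) < self.save_top_k:
        return True  # still filling the top-k slots
    monitor_op = {'min': torch.lt, 'max': torch.gt}[self.mode]
    # keep the candidate only if it beats the current k-th best score
    return bool(monitor_op(current, self.best_k_models[self.kth_best_model_path]))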
def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
             save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
             mode: str = 'auto', period: int = 1, prefix: str = ''):
    super().__init__()
    if filepath:
        filepath = str(filepath)  # the tests pass in a py.path.local but we want a str
    if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
        rank_zero_warn(
            f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
            " All files in this directory will be deleted when a checkpoint is saved!"
        )
    self._rank = 0

    self.monitor = monitor
    self.verbose = verbose
    if filepath is None:  # will be determined by trainer at runtime
        self.dirpath, self.filename = None, None
    else:
        if gfile.isdir(filepath):
            self.dirpath, self.filename = filepath, '{epoch}'
        else:
            filepath = os.path.realpath(filepath)
            self.dirpath, self.filename = os.path.split(filepath)
        if not gfile.exists(self.dirpath):
            makedirs(self.dirpath)
    self.save_last = save_last
    self.save_top_k = save_top_k
    self.save_weights_only = save_weights_only
    self.period = period
    self.epoch_last_check = None
    self.prefix = prefix
    self.best_k_models = {}  # {filename: monitor}
    self.kth_best_model_path = ''
    self.best_model_score = 0
    self.best_model_path = ''
    self.save_function = None

    torch_inf = torch.tensor(np.Inf)
    mode_dict = {
        'min': (torch_inf, 'min'),
        'max': (-torch_inf, 'max'),
        'auto': (-torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
        else (torch_inf, 'min'),
    }

    if mode not in mode_dict:
        rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
                       f'fallback to auto mode.', RuntimeWarning)
        mode = 'auto'

    self.kth_value, self.mode = mode_dict[mode]
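# Example usage (a sketch; the keyword names follow the signature above, and
# passing the callback via `checkpoint_callback` matches the Trainer API of
# this codebase's era):
from pytorch_lightning import Trainer

checkpoint_callback = ModelCheckpoint(
    filepath='my/path/{epoch}-{val_loss:.2f}',  # formatted per checkpoint
    monitor='val_loss',
    save_top_k=3,  # keep the 3 best checkpoints by val_loss
    mode='min',
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)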