def _del_model(self, filepath):
    """Remove a stale checkpoint file, tolerating gfile compatibility mode.

    Args:
        filepath: path of the checkpoint to delete (local or remote).
    """
    if not gfile.exists(filepath):
        return
    try:
        # In compat mode, remove is not implemented; when running against an
        # actual remote file system with the correct remote dependencies
        # installed, this call works fine.
        gfile.remove(filepath)
    except AttributeError:
        if is_remote_path(filepath):
            # Remote path but no full gfile support — nothing we can do.
            log.warning(
                "Unable to remove stale checkpoints due to running gfile in compatibility mode."
                " Please install tensorflow to run gfile in full mode"
                " if writing checkpoints to remote locations")
        else:
            # Local path: plain os.remove is a safe fallback.
            os.remove(filepath)
def __init__(self,
             filepath: Optional[str] = None,
             monitor: str = 'val_loss',
             verbose: bool = False,
             save_last: bool = False,
             save_top_k: int = 1,
             save_weights_only: bool = False,
             mode: str = 'auto',
             period: int = 1,
             prefix: str = ''):
    """Configure model checkpointing.

    Args:
        filepath: directory or file-name template for checkpoints; ``None``
            means the path will be determined by the trainer at runtime.
        monitor: metric key used to rank checkpoints.
        verbose: whether to log on each save.
        save_last: also keep a "last" checkpoint.
        save_top_k: how many best checkpoints to retain (existing files in a
            non-empty target directory are deleted when saving).
        save_weights_only: save only the model weights instead of full state.
        mode: ``'min'``, ``'max'``, or ``'auto'`` — whether a lower or higher
            monitored value is better; ``'auto'`` infers from the monitor name
            (``'acc'``/``'fmeasure'`` → max, otherwise min). Unknown values
            fall back to ``'auto'`` with a RuntimeWarning.
        period: number of epochs between checkpoint checks.
        prefix: prefix prepended to checkpoint file names.
    """
    super().__init__()
    if filepath:
        # the tests pass in a py.path.local but we want a str
        filepath = str(filepath)
    if (save_top_k > 0 and filepath is not None and gfile.isdir(filepath)
            and len(gfile.listdir(filepath)) > 0):
        rank_zero_warn(
            f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
            "All files in this directory will be deleted when a checkpoint is saved!"
        )
    self._rank = 0

    self.monitor = monitor
    self.verbose = verbose
    if filepath is None:
        # will be determined by trainer at runtime
        self.dirpath, self.filename = None, None
    else:
        if gfile.isdir(filepath):
            self.dirpath, self.filename = filepath, '{epoch}'
        else:
            if not is_remote_path(filepath):
                # dont normalize remote paths
                filepath = os.path.realpath(filepath)
            self.dirpath, self.filename = os.path.split(filepath)
        makedirs(self.dirpath)  # calls with exist_ok
    self.save_last = save_last
    self.save_top_k = save_top_k
    self.save_weights_only = save_weights_only
    self.period = period
    self.epoch_last_check = None
    self.prefix = prefix
    self.best_k_models = {}  # maps checkpoint filename -> monitored value
    self.kth_best_model_path = ''
    self.best_model_score = 0
    self.best_model_path = ''
    self.save_function = None
    self.warned_result_obj = False

    # FIX: use np.inf — the np.Inf alias was removed in NumPy 2.0.
    torch_inf = torch.tensor(np.inf)
    mode_dict = {
        'min': (torch_inf, 'min'),
        'max': (-torch_inf, 'max'),
        'auto': (-torch_inf, 'max')
        if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
        else (torch_inf, 'min'),
    }

    if mode not in mode_dict:
        rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
                       f'fallback to auto mode.', RuntimeWarning)
        mode = 'auto'

    self.kth_value, self.mode = mode_dict[mode]