import os

import simplejson
import torch

# The names pathmgr, logger, dist, cfg, unwrap_model, the checkpoint path
# helpers (get_checkpoint_dir, get_checkpoint, get_checkpoint_best), and the
# _TAG/_TYPE/_LOG_FILE/_NAME_PREFIX constants are defined elsewhere in the
# surrounding project modules.


def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
    with pathmgr.open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG):] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and l[_TYPE] not in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data
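# Usage sketch (illustrative, not part of the original module). The log path
# and the "test_epoch"/"top1_err" names are assumptions about how a typical
# run logs its per-epoch stats via the _TAG/_TYPE JSON lines parsed above.
def _example_read_test_curve(log_file="/tmp/run/stdout.log"):
    """Hypothetical helper: extract the per-epoch test error curve."""
    data = load_log_data(log_file, data_types_to_skip=("train_iter",))
    # data[data_type][metric] is a list with one entry per logged line
    return data["test_epoch"]["epoch"], data["test_epoch"]["top1_err"]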
def __init__(self, data_path, split):
    assert pathmgr.exists(data_path), "Data path '{}' not found".format(data_path)
    splits = ["train", "test"]
    assert split in splits, "Split '{}' not supported for cifar".format(split)
    logger.info("Constructing CIFAR-10 {}...".format(split))
    self._data_path, self._split = data_path, split
    self._inputs, self._labels = self._load_data()
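# Usage sketch (illustrative): constructing the dataset for each split. The
# class name Cifar10 and the data path are assumptions; the constructor only
# requires that data_path exists and that split is "train" or "test".
def _example_build_cifar(data_path="/tmp/cifar10"):
    """Hypothetical helper: build the train and test datasets."""
    train_set = Cifar10(data_path, "train")
    test_set = Cifar10(data_path, "test")
    return train_set, test_set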
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Get all log files in directory containing subdirs of trained models."""
    names = [n for n in sorted(pathmgr.ls(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, log_file) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if pathmgr.exists(f)]
    files, names = zip(*f_n_ps) if f_n_ps else ([], [])
    return files, names
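# Usage sketch (illustrative): combine get_log_files with load_log_data to
# collect the logs of every trained model under a sweep directory. The sweep
# directory path is an assumption.
def _example_load_sweep(log_dir="/tmp/sweep"):
    """Hypothetical helper: parse the logs of all runs under log_dir."""
    files, names = get_log_files(log_dir, name_filter="")
    # Maps each run name to its parsed data[data_type][metric][index] dict
    return {name: load_log_data(f) for f, name in zip(files, names)}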
def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None):
    """
    Loads a checkpoint selectively based on the input options.

    Each checkpoint contains both the model and model_ema weights (except
    checkpoints created by old versions of the code). If both the model and
    model_ema weights are requested, both sets of weights are loaded. If only
    the model weights are requested (that is, if model_ema=None), the *better*
    set of weights is selected to be loaded (according to the lesser of
    test_err and ema_err, also stored in the checkpoint).

    The code is backward compatible with checkpoints that do not store the
    ema weights.
    """
    err_str = "Checkpoint '{}' not found"
    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    # Get test_err and ema_err (with backward compatibility)
    test_err = checkpoint["test_err"] if "test_err" in checkpoint else 100
    ema_err = checkpoint["ema_err"] if "ema_err" in checkpoint else 100
    # Load model and optionally model_ema weights (with backward compatibility)
    ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
    if model_ema:
        unwrap_model(model).load_state_dict(checkpoint["model_state"])
        unwrap_model(model_ema).load_state_dict(checkpoint[ema_state])
    else:
        best_state = "model_state" if test_err <= ema_err else ema_state
        unwrap_model(model).load_state_dict(checkpoint[best_state])
    # Load optimizer if requested
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"], test_err, ema_err
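# Usage sketch (illustrative): two common ways to call the loader. Passing
# model_ema restores both weight sets for resuming training; omitting it
# loads only the better of the two sets (by stored error) for evaluation.
# The checkpoint path is an assumption.
def _example_resume_or_eval(model, model_ema, optimizer, resume=True):
    """Hypothetical helper: resume training or restore the best weights."""
    checkpoint_file = "/tmp/checkpoints/model_epoch_0100.pyth"
    if resume:
        epoch, _, _ = load_checkpoint(checkpoint_file, model, model_ema, optimizer)
    else:
        epoch, _, _ = load_checkpoint(checkpoint_file, model)
    return epoch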
def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    unwrap_model(model).load_state_dict(checkpoint["model_state"])
    # Restore the optimizer state only if an optimizer is provided
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
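# Usage sketch (illustrative) for the simpler loader above: restore weights
# only, with no optimizer, when all that is needed is inference. The
# checkpoint path is an assumption.
def _example_restore_for_inference(model):
    """Hypothetical helper: load weights for evaluation only."""
    epoch = load_checkpoint("/tmp/checkpoints/model_epoch_0100.pyth", model)
    model.eval()
    return epoch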
def delete_checkpoints(checkpoint_dir=None, keep="all"):
    """Deletes unneeded checkpoints, keep can be "all", "last", or "none"."""
    assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep)
    checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir()
    if keep == "all" or not pathmgr.exists(checkpoint_dir):
        return 0
    checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f]
    checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints
    for checkpoint in checkpoints:
        pathmgr.rm(os.path.join(checkpoint_dir, checkpoint))
    return len(checkpoints)
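# Usage sketch (illustrative): prune all but the most recent checkpoint once
# a run finishes, relying on the lexicographic sort of _NAME_PREFIX files.
def _example_cleanup_after_training():
    """Hypothetical helper: keep only the latest checkpoint in the run dir."""
    num_deleted = delete_checkpoints(keep="last")
    return num_deleted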
def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
    """Saves a checkpoint and also the best weights so far in a best checkpoint."""
    # Save checkpoints only from the main process
    if not dist.is_main_proc():
        return
    # Ensure that the checkpoint dir exists
    pathmgr.mkdirs(get_checkpoint_dir())
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "test_err": test_err,
        "ema_err": ema_err,
        "model_state": unwrap_model(model).state_dict(),
        "ema_state": unwrap_model(model_ema).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # Store the best model and model_ema weights so far
    if not pathmgr.exists(get_checkpoint_best()):
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
    else:
        with pathmgr.open(get_checkpoint_best(), "rb") as f:
            best = torch.load(f, map_location="cpu")
        # Select the best model weights and the best model_ema weights
        if test_err < best["test_err"] or ema_err < best["ema_err"]:
            if test_err < best["test_err"]:
                best["model_state"] = checkpoint["model_state"]
                best["test_err"] = test_err
            if ema_err < best["ema_err"]:
                best["ema_state"] = checkpoint["ema_state"]
                best["ema_err"] = ema_err
            with pathmgr.open(get_checkpoint_best(), "wb") as f:
                torch.save(best, f)
    return checkpoint_file
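# Usage sketch (illustrative): the end-of-epoch save pattern the function is
# designed for. The evaluate() helper computing test_err/ema_err is an
# assumption; only save_checkpoint itself comes from the code above.
def _example_save_each_epoch(model, model_ema, optimizer, epoch):
    """Hypothetical helper: checkpoint after evaluating both weight sets."""
    test_err = evaluate(model)     # assumed helper returning top-1 error
    ema_err = evaluate(model_ema)  # assumed helper returning top-1 error
    # Writes the periodic checkpoint and updates the running best checkpoint
    return save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err)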
def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not pathmgr.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in pathmgr.ls(checkpoint_dir))
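# Usage sketch (illustrative): the resume-or-start decision at the top of a
# training loop, pairing has_checkpoint with the EMA-aware loader above.
# get_last_checkpoint is an assumed companion helper that returns the path of
# the newest checkpoint file.
def _example_setup_start_epoch(model, model_ema, optimizer):
    """Hypothetical helper: resume from the last checkpoint if one exists."""
    if has_checkpoint():
        epoch, _, _ = load_checkpoint(get_last_checkpoint(), model, model_ema, optimizer)
        return epoch + 1
    return 0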