def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
    with pathmgr.open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG):] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and l[_TYPE] not in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data

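# Usage sketch (not part of the original module): reading one metric curve out of
# the nested dict returned by load_log_data. The "train_iter" and "test_epoch"
# data-type names and the "epoch"/"top1_err" metric names are assumptions; the
# actual keys depend on what the training loop wrote to the log.
def example_read_test_curve(log_file):
    """Returns (epochs, test errors) from a training log, if present."""
    data = load_log_data(log_file, data_types_to_skip=("train_iter",))
    if "test_epoch" not in data:
        return [], []
    return data["test_epoch"]["epoch"], data["test_epoch"]["top1_err"]
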
def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None):
    """
    Loads a checkpoint selectively based on the input options.

    Each checkpoint contains both the model and model_ema weights (except
    checkpoints created by old versions of the code). If both the model and
    model_ema weights are requested, both sets of weights are loaded. If only
    the model weights are requested (that is, if model_ema is None), the
    *better* set of weights is selected to be loaded (according to the lesser
    of test_err and ema_err, also stored in the checkpoint).

    The code is backward compatible with checkpoints that do not store the
    ema weights.
    """
    err_str = "Checkpoint '{}' not found"
    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    # Get test_err and ema_err (with backward compatibility)
    test_err = checkpoint["test_err"] if "test_err" in checkpoint else 100
    ema_err = checkpoint["ema_err"] if "ema_err" in checkpoint else 100
    # Load model and optionally model_ema weights (with backward compatibility)
    ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
    if model_ema:
        unwrap_model(model).load_state_dict(checkpoint["model_state"])
        unwrap_model(model_ema).load_state_dict(checkpoint[ema_state])
    else:
        best_state = "model_state" if test_err <= ema_err else ema_state
        unwrap_model(model).load_state_dict(checkpoint[best_state])
    # Load optimizer if requested
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"], test_err, ema_err

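# Usage sketch (not part of the original module): the two common ways to call
# load_checkpoint. Passing model_ema and optimizer restores both weight sets and
# the optimizer state, e.g. to resume training; omitting them loads only the
# better of the two weight sets into model, e.g. for evaluation.
def example_resume_or_evaluate(checkpoint_file, model, model_ema=None, optimizer=None):
    """Resumes training state if model_ema/optimizer are given, else loads weights for eval."""
    if model_ema is not None and optimizer is not None:
        # Restore model weights, ema weights, and optimizer state
        epoch, _, _ = load_checkpoint(checkpoint_file, model, model_ema, optimizer)
        return epoch + 1  # next epoch to run
    # Only the better of the model/ema weights (by stored error) is loaded into model
    epoch, _, _ = load_checkpoint(checkpoint_file, model)
    return epoch
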
def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
    with pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    unwrap_model(model).load_state_dict(checkpoint["model_state"])
    # Load optimizer state if an optimizer is provided
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]

def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
    """Saves a checkpoint and also the best weights so far in a best checkpoint."""
    # Save checkpoints only from the main process
    if not dist.is_main_proc():
        return
    # Ensure that the checkpoint dir exists
    pathmgr.mkdirs(get_checkpoint_dir())
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "test_err": test_err,
        "ema_err": ema_err,
        "model_state": unwrap_model(model).state_dict(),
        "ema_state": unwrap_model(model_ema).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # Store the best model and model_ema weights so far
    if not pathmgr.exists(get_checkpoint_best()):
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
    else:
        with pathmgr.open(get_checkpoint_best(), "rb") as f:
            best = torch.load(f, map_location="cpu")
        # Select the best model weights and the best model_ema weights
        if test_err < best["test_err"] or ema_err < best["ema_err"]:
            if test_err < best["test_err"]:
                best["model_state"] = checkpoint["model_state"]
                best["test_err"] = test_err
            if ema_err < best["ema_err"]:
                best["ema_state"] = checkpoint["ema_state"]
                best["ema_err"] = ema_err
            with pathmgr.open(get_checkpoint_best(), "wb") as f:
                torch.save(best, f)
    return checkpoint_file

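# Usage sketch (not part of the original module): calling save_checkpoint once per
# epoch from a training loop. The eval_model helper is a hypothetical stand-in;
# in practice test_err and ema_err come from evaluating the raw and EMA weights
# on the test set.
def example_checkpoint_each_epoch(model, model_ema, optimizer, epoch, eval_model):
    """Evaluates both weight sets and writes the epoch and best checkpoints."""
    test_err = eval_model(model)
    ema_err = eval_model(model_ema)
    # Writes the checkpoint for epoch + 1 and updates the best checkpoint per weight set
    return save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err)
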
def _load_data(self):
    """Loads data into memory."""
    logger.info("{} data path: {}".format(self._split, self._data_path))
    # Compute data batch names
    if self._split == "train":
        batch_names = ["data_batch_{}".format(i) for i in range(1, 6)]
    else:
        batch_names = ["test_batch"]
    # Load data batches
    inputs, labels = [], []
    for batch_name in batch_names:
        batch_path = os.path.join(self._data_path, batch_name)
        with pathmgr.open(batch_path, "rb") as f:
            data = pickle.load(f, encoding="bytes")
        inputs.append(data[b"data"])
        labels += data[b"labels"]
    # Combine and reshape the inputs
    assert cfg.TRAIN.IM_SIZE == 32, "CIFAR-10 images are 32x32"
    inputs = np.vstack(inputs).astype(np.float32)
    inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
    return inputs, labels

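# Shape sketch (not part of the original module): CIFAR-10 ships five train
# batches of 10,000 images and one test batch of 10,000, so the train split is
# expected to yield 50,000 images. The dataset argument is assumed to be an
# instance of the surrounding class constructed for the "train" split.
def example_check_cifar_train_shapes(dataset):
    """Sanity-checks the arrays returned by _load_data for the train split."""
    inputs, labels = dataset._load_data()
    assert inputs.shape == (50000, 3, 32, 32) and inputs.dtype == np.float32
    assert len(labels) == 50000
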
def save_checkpoint(model, optimizer, epoch, best):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    pathmgr.mkdirs(get_checkpoint_dir())
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "model_state": unwrap_model(model).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # If best, copy the checkpoint to the best checkpoint
    if best:
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
    return checkpoint_file

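# Usage sketch (not part of the original module): computing the `best` flag in a
# training loop by tracking the lowest test error seen so far. The err value is
# hypothetical and would come from the test-set evaluation for this epoch.
def example_save_if_best(model, optimizer, epoch, err, best_err):
    """Saves the epoch checkpoint, updating the best checkpoint when err improves."""
    best = err < best_err
    save_checkpoint(model, optimizer, epoch, best)
    return min(err, best_err)
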
def load_cfg(cfg_file):
    """Loads config from specified file."""
    with pathmgr.open(cfg_file, "r") as f:
        _C.merge_from_other_cfg(_C.load_cfg(f))

def dump_cfg():
    """Dumps the config to the output directory."""
    cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
    with pathmgr.open(cfg_file, "w") as f:
        _C.dump(stream=f)

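# Usage sketch (not part of the original module): a typical config round trip.
# The yaml path is hypothetical; load_cfg merges the file into the global _C
# config, and dump_cfg records the merged config under _C.OUT_DIR/_C.CFG_DEST
# so the run can be reproduced later.
def example_setup_cfg(cfg_file="configs/my_experiment.yaml"):
    """Loads a config file and saves the resulting config to the output dir."""
    load_cfg(cfg_file)
    dump_cfg()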