Beispiel #1
0
def load_log_data(log_file, data_types_to_skip=()):
    """Parses tagged JSON log lines into data[data_type][metric][index].

    Only lines containing _TAG are considered; the text that follows the first
    occurrence of _TAG on such a line is decoded as JSON. Decoded entries that
    have no _TYPE key, or whose _TYPE value is in data_types_to_skip, are
    dropped. Every remaining entry of a given type must carry the same set of
    metric keys, otherwise an AssertionError is raised.
    """
    assert pathmgr.exists(log_file), "Log file not found: {}".format(log_file)
    with pathmgr.open(log_file, "r") as f:
        raw_lines = f.readlines()
    # Keep only the payload that follows the tag on tagged lines
    payloads = []
    for raw in raw_lines:
        if _TAG in raw:
            start = raw.find(_TAG) + len(_TAG)
            payloads.append(raw[start:])
    # Decode every payload before filtering so malformed JSON always surfaces
    entries = [simplejson.loads(payload) for payload in payloads]
    entries = [
        entry
        for entry in entries
        if _TYPE in entry and entry[_TYPE] not in data_types_to_skip
    ]
    # Group entries by type (type key removed): grouped[data_type][index][metric]
    grouped = {}
    for entry in entries:
        data_type = entry.pop(_TYPE)
        grouped.setdefault(data_type, []).append(entry)
    # Transpose each group into grouped[data_type][metric][index]
    for data_type, rows in grouped.items():
        metrics = sorted(rows[0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(
            data_type, metrics)
        assert all(sorted(row.keys()) == metrics for row in rows), err_str
        # Rebinding an existing key does not resize the dict, so this is safe
        grouped[data_type] = {m: [row[m] for row in rows] for m in metrics}
    return grouped
Beispiel #2
0
def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None):
    """
    Restores weights (and optionally optimizer state) from a checkpoint file.

    A checkpoint normally stores both the model weights and the model_ema weights;
    checkpoints written by older code may lack the ema weights and the error values.
    When both model and model_ema are passed, each receives its own set of weights.
    When only model is passed (model_ema=None), the set of weights with the lower
    recorded error is loaded into model (test_err vs. ema_err; ties go to the
    model weights).

    Missing "test_err" / "ema_err" entries default to 100, and a missing "ema_state"
    falls back to "model_state", which keeps older checkpoints loadable.

    Returns a tuple (epoch, test_err, ema_err) taken from the checkpoint.
    """
    assert pathmgr.exists(checkpoint_file), "Checkpoint '{}' not found".format(
        checkpoint_file
    )
    with pathmgr.open(checkpoint_file, "rb") as f:
        ckpt = torch.load(f, map_location="cpu")
    # Error values may be absent in checkpoints from older code
    test_err = ckpt.get("test_err", 100)
    ema_err = ckpt.get("ema_err", 100)
    # Older checkpoints have no separate ema weights; reuse the model weights
    ema_key = "ema_state" if "ema_state" in ckpt else "model_state"
    if not model_ema:
        # Single model requested: load whichever weights scored the lower error
        chosen_key = "model_state" if test_err <= ema_err else ema_key
        unwrap_model(model).load_state_dict(ckpt[chosen_key])
    else:
        unwrap_model(model).load_state_dict(ckpt["model_state"])
        unwrap_model(model_ema).load_state_dict(ckpt[ema_key])
    # Optimizer state is restored only when an optimizer is supplied
    if optimizer:
        optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["epoch"], test_err, ema_err
Beispiel #3
0
def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Restores model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is loaded onto the CPU. Its "model_state" entry is loaded into
    the unwrapped model and, if an optimizer is supplied, its "optimizer_state"
    entry is loaded into the optimizer. Returns the checkpoint's "epoch" value.
    """
    assert pathmgr.exists(checkpoint_file), "Checkpoint '{}' not found".format(
        checkpoint_file
    )
    with pathmgr.open(checkpoint_file, "rb") as f:
        state = torch.load(f, map_location="cpu")
    unwrap_model(model).load_state_dict(state["model_state"])
    # Optimizer state is restored only when an optimizer is supplied
    if optimizer:
        optimizer.load_state_dict(state["optimizer_state"])
    return state["epoch"]
Beispiel #4
0
def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
    """Writes a checkpoint for this epoch and maintains the best checkpoint.

    Does nothing (returns None) unless called from the main process. Otherwise the
    full training state is written to get_checkpoint(epoch + 1) and that path is
    returned. The best checkpoint tracks the model weights and the model_ema
    weights independently: each is replaced only when its own error improves.
    """
    # Only the main process writes checkpoints
    if not dist.is_main_proc():
        return
    pathmgr.mkdirs(get_checkpoint_dir())
    # Snapshot everything needed to resume training
    checkpoint = {
        "epoch": epoch,
        "test_err": test_err,
        "ema_err": ema_err,
        "model_state": unwrap_model(model).state_dict(),
        "ema_state": unwrap_model(model_ema).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    checkpoint_file = get_checkpoint(epoch + 1)
    with pathmgr.open(checkpoint_file, "wb") as f:
        torch.save(checkpoint, f)
    # No best checkpoint yet: the one just written becomes the best
    if not pathmgr.exists(get_checkpoint_best()):
        pathmgr.copy(checkpoint_file, get_checkpoint_best())
        return checkpoint_file
    with pathmgr.open(get_checkpoint_best(), "rb") as f:
        best = torch.load(f, map_location="cpu")
    # Judge the model weights and the ema weights separately
    model_improved = test_err < best["test_err"]
    ema_improved = ema_err < best["ema_err"]
    if model_improved:
        best["model_state"] = checkpoint["model_state"]
        best["test_err"] = test_err
    if ema_improved:
        best["ema_state"] = checkpoint["ema_state"]
        best["ema_err"] = ema_err
    # Rewrite the best checkpoint only if something actually changed
    if model_improved or ema_improved:
        with pathmgr.open(get_checkpoint_best(), "wb") as f:
            torch.save(best, f)
    return checkpoint_file
Beispiel #5
0
 def _load_data(self):
     """Reads the pickled batches for this split and returns (inputs, labels).

     The "train" split reads data_batch_1 .. data_batch_5; any other split reads
     test_batch. inputs is a float32 array reshaped to (-1, 3, IM_SIZE, IM_SIZE)
     and labels is a flat list accumulated across batches.
     """
     logger.info("{} data path: {}".format(self._split, self._data_path))
     # Decide which batch files make up this split
     batch_names = ["test_batch"]
     if self._split == "train":
         batch_names = ["data_batch_{}".format(idx) for idx in range(1, 6)]
     # Read each batch file; keys in the pickled dicts are bytes
     # NOTE: pickle must only be used on trusted dataset files
     inputs = []
     labels = []
     for name in batch_names:
         path = os.path.join(self._data_path, name)
         with pathmgr.open(path, "rb") as f:
             batch = pickle.load(f, encoding="bytes")
         inputs.append(batch[b"data"])
         labels.extend(batch[b"labels"])
     # Stack the batches and reshape the rows into image tensors
     im_size = cfg.TRAIN.IM_SIZE
     assert im_size == 32, "CIFAR-10 images are 32x32"
     stacked = np.vstack(inputs).astype(np.float32)
     inputs = stacked.reshape((-1, 3, im_size, im_size))
     return inputs, labels
Beispiel #6
0
def save_checkpoint(model, optimizer, epoch, best):
    """Writes a checkpoint for this epoch, optionally marking it as the best.

    Does nothing (returns None) unless called from the master process. Otherwise
    the state is written to get_checkpoint(epoch + 1) and that path is returned.
    When best is truthy the written file is also copied to get_checkpoint_best().
    """
    # Only the master process writes checkpoints
    if not dist.is_master_proc():
        return
    pathmgr.mkdirs(get_checkpoint_dir())
    # Snapshot everything needed to resume training
    state = {
        "epoch": epoch,
        "model_state": unwrap_model(model).state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    target = get_checkpoint(epoch + 1)
    with pathmgr.open(target, "wb") as f:
        torch.save(state, f)
    # Keep a copy under the best-checkpoint path when requested
    if best:
        pathmgr.copy(target, get_checkpoint_best())
    return target
Beispiel #7
0
def load_cfg(cfg_file):
    """Reads a config from cfg_file and merges it into _C."""
    with pathmgr.open(cfg_file, "r") as f:
        loaded = _C.load_cfg(f)
        # Merge while the file is still open, as the load may read from it
        _C.merge_from_other_cfg(loaded)
Beispiel #8
0
def dump_cfg():
    """Writes _C to the file named _C.CFG_DEST inside the _C.OUT_DIR directory."""
    destination = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
    with pathmgr.open(destination, "w") as out:
        _C.dump(stream=out)