コード例 #1
0
ファイル: stackmnist_mode.py プロジェクト: wgrathwohl/VERA
    def load_model(directory, itr,  return_p=False):
        """
        Load the saved models and arguments of an experiment.

        Reads ``experiments/<directory>/args.txt`` (JSON) to rebuild the
        argument namespace, constructs the networks with ``get_models`` and
        restores their weights from
        ``experiments/<directory>/save_model/<itr as 6 digits>.pt``.

        Args:
            directory: Experiment sub-directory under ``experiments``.
            itr: Iteration number of the checkpoint to load.
            return_p: When true, also return the empirical label
                distribution of the training set.

        Returns:
            ``(logp_net, g, args)``, or ``(logp_net, g, args, label_counts)``
            when ``return_p`` is true. ``label_counts`` is a length-1000
            tensor normalized to sum to 1.
        """
        path = os.path.join("experiments", directory, "save_model", "{:06d}.pt".format(itr))

        # load arguments
        with open(os.path.join("experiments", directory, "args.txt"), 'r') as f:
            args = argparse.Namespace(**json.load(f))

        logp_net, g = get_models(args, log=False)

        # Load onto CPU so the checkpoint can be read without a GPU.
        ckpt = torch.load(path, map_location=torch.device('cpu'))

        logp_net.load_state_dict(ckpt["model"]["logp_net"])
        g.load_state_dict(ckpt["model"]["g"])

        # get true labels
        # NOTE(review): this runs even when return_p is False; kept that way
        # because get_data may have side effects that are not visible here.
        train_loader, _, _ = utils.data.get_data(args)
        # 1000 bins -- presumably one per stacked-MNIST mode; TODO confirm.
        label_counts = torch.zeros(1000)
        for _, label in train_loader:
            idx = label.long()
            # ``label_counts[idx] += 1`` would count a label only once per
            # batch when it appears several times, because in-place indexed
            # addition does not accumulate over duplicate indices.
            # accumulate=True adds one for every occurrence.
            label_counts.index_put_(
                (idx,),
                torch.ones_like(idx, dtype=label_counts.dtype),
                accumulate=True,
            )

        label_counts /= label_counts.sum()

        if return_p:
            return logp_net, g, args, label_counts
        else:
            return logp_net, g, args
コード例 #2
0
ファイル: eval.py プロジェクト: wgrathwohl/VERA
def main(args):
    """
    Load a checkpoint and run the evaluation selected by ``args.eval``.

    Creates ``args.save_dir``, seeds torch, builds the two networks with
    ``get_models``, restores their weights from ``args.ckpt_path`` and
    dispatches to one of the evaluation routines.

    Args:
        args: Namespace providing at least ``save_dir``, ``ckpt_path`` and
            ``eval``. Recognized ``eval`` values are ``"OOD"``,
            ``"test_clf"``, ``"cond_samples"``, ``"uncond_samples"`` and
            ``"logp_hist"``; any other value runs no evaluation.
    """
    utils.makedirs(args.save_dir)

    # NOTE(review): ``seed`` is not defined in this function; presumably a
    # module-level constant defined elsewhere in the file -- confirm.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    f, g = get_models(args)

    print(f"loading model from {args.ckpt_path}")

    # load em up
    # Map tensors to CPU first so a checkpoint saved on a GPU can still be
    # loaded on a CPU-only machine; the models are moved to ``device`` below.
    ckpt = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
    f.load_state_dict(ckpt["model"]["logp_net"])
    g.load_state_dict(ckpt["model"]["g"])

    f = f.to(device)
    g = g.to(device)
    # Only f is switched to eval mode; g is left in whatever mode
    # get_models returned it in.
    f.eval()

    if args.eval == "OOD":
        OODAUC(f, args, device)
    elif args.eval == "test_clf":
        test_clf(f, args, device)
    elif args.eval == "cond_samples":
        cond_samples(f, g, args, device)
    elif args.eval == "uncond_samples":
        uncond_samples(f, g, args, device)
    elif args.eval == "logp_hist":
        logp_hist(f, args, device)