parser.add_argument('--kernel_size_in', type=int, default=7)
parser.add_argument('--output_filters', type=int, default=1024)

# Optim params
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--milestones', type=eval, default=[])
parser.add_argument('--gamma', type=float, default=0.1)

args = parser.parse_args()
setup.prepare_env(args)

##################
## Specify data ##
##################

data = CategoricalCIFAR10()

setup.register_data(data.train, data.test)

###################
## Specify model ##
###################

model = AutoregressiveSubsetFlow2d(
    base_shape=(
        3,
        32,
        32,
    ),
    transforms=[
        QuadraticSplineAutoregressiveSubsetTransform2d(PixelCNN(
Esempio n. 2
0
def eval_elbo(create_model_fn):
    """Evaluate a trained model's ELBO (or IWBO when --k is given) in bits/dim.

    Command-line options select the checkpoint (--model_path), device,
    batch size, data split (--test_set) and the number of importance
    samples (--k). The result is printed and written to a text file in
    the run's log folder.

    Args:
        create_model_fn: callable taking the unpickled training ``args``
            namespace and returning a model instance.

    Raises:
        ValueError: if --k and --batch_size are not mutually divisible.
    """

    parser = argparse.ArgumentParser()

    # Training args
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=100)
    # NOTE(review): type=eval executes arbitrary CLI text; kept for
    # compatibility with 'True'/'False' flags, but a strict bool parser
    # would be safer.
    parser.add_argument('--test_set', type=eval, default=True)
    parser.add_argument('--double', type=eval, default=False)
    parser.add_argument('--k', type=int, default=None)

    eval_args = parser.parse_args()

    # Pick the data batch size so that k importance samples fit in memory:
    #   k is None       -> plain ELBO at the requested batch size.
    #   k <= batch_size -> shrink the data batch so data * k == batch_size.
    #   k >  batch_size -> one datapoint at a time, k chunked internally.
    if eval_args.k is None:
        batch_size = eval_args.batch_size
        iwbo_batch_size = None
    elif eval_args.k <= eval_args.batch_size:
        # Raise instead of assert: asserts are stripped under `python -O`.
        if eval_args.batch_size % eval_args.k != 0:
            raise ValueError('--batch_size must be divisible by --k')
        batch_size = eval_args.batch_size // eval_args.k
        iwbo_batch_size = None
    else:
        if eval_args.k % eval_args.batch_size != 0:
            raise ValueError('--k must be divisible by --batch_size')
        batch_size = 1
        iwbo_batch_size = eval_args.batch_size
    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # Restore the argparse namespace the model was trained with.
    # NOTE(review): pickle.load is only safe on trusted checkpoints.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    ##################
    ## Specify data ##
    ##################

    torch.manual_seed(0)

    data = CategoricalCIFAR10()

    if eval_args.test_set:
        data_loader = torch.utils.data.DataLoader(data.test,
                                                  batch_size=batch_size)
    else:
        data_loader = torch.utils.data.DataLoader(data.train,
                                                  batch_size=batch_size)
    test_str = 'test' if eval_args.test_set else 'train'

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights.
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # confirm this is intentional for these checkpoints.
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    model.load_state_dict(weights, strict=False)
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double:
        model = model.double()
    double_str = '_double' if eval_args.double else ''

    ####################
    ## Compute loglik ##
    ####################

    if eval_args.k is None:
        # Compute ELBO (single-sample bound).
        elbo = dataset_elbo_bpd(model,
                                data_loader,
                                device=eval_args.device,
                                double=eval_args.double)
        print('Done, ELBO: {}'.format(elbo))
        fname = '{}/elbo_bpd_{}{}.txt'.format(model_log, test_str, double_str)
        with open(fname, "w") as f:
            f.write(str(elbo))
    else:
        # Compute IWBO (importance-weighted bound with k samples).
        iwbo = dataset_iwbo_bpd(model,
                                data_loader,
                                device=eval_args.device,
                                double=eval_args.double,
                                k=eval_args.k,
                                batch_size=iwbo_batch_size)
        print('Done, IWBO({}): {}'.format(eval_args.k, iwbo))
        fname = '{}/iwbo{}_bpd_{}{}.txt'.format(model_log, eval_args.k,
                                                test_str, double_str)
        with open(fname, "w") as f:
            f.write(str(iwbo))
Esempio n. 3
0
def eval_exact(create_model_fn):
    """Compute a saved model's exact log-likelihood on CIFAR-10 in bits/dim.

    Checkpoint path, device, batch size, data split and precision are read
    from the command line. The bits/dim figure is printed and written to a
    text file inside the run's log folder.

    Args:
        create_model_fn: callable taking the unpickled training ``args``
            namespace and returning a model instance.
    """

    parser = argparse.ArgumentParser()

    # Training args
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--test_set', type=eval, default=True)
    parser.add_argument('--double', type=eval, default=False)

    cli = parser.parse_args()

    batch_size = cli.batch_size
    model_log = os.path.join(LOG_FOLDER, cli.model_path)
    model_check = os.path.join(CHECK_FOLDER, cli.model_path)

    # Restore the argparse namespace the model was trained with.
    with open(f'{model_log}/args.pickle', 'rb') as f:
        args = pickle.load(f)

    ##################
    ## Specify data ##
    ##################

    torch.manual_seed(0)

    data = CategoricalCIFAR10()

    dataset = data.test if cli.test_set else data.train
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size)
    test_str = 'test' if cli.test_set else 'train'

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Restore pre-trained weights and switch to eval mode.
    weights = torch.load(f'{model_check}/model.pt', map_location='cpu')
    model.load_state_dict(weights, strict=False)
    model = model.to(cli.device).eval()
    if cli.double:
        model = model.double()
    double_str = '_double' if cli.double else ''

    ####################
    ## Compute loglik ##
    ####################

    bpd = dataset_loglik_bpd(model,
                             data_loader,
                             device=cli.device,
                             double=cli.double)
    print(f'Done, bpd: {bpd}')
    fname = f'{model_log}/exact_loglik_bpd_{test_str}{double_str}.txt'
    with open(fname, "w") as f:
        f.write(str(bpd))
def interpolate(create_model_fn, idx_list):
    """Save latent-space interpolations between pairs of CIFAR-10 test images.

    For each (i1, i2) pair in ``idx_list`` (optionally sliced by
    --start/--end), both images are encoded to latent intervals, a shared
    random point inside each interval is mapped to Gaussian space,
    interpolated with L2-normalized weights, mapped back, decoded, and the
    resulting row of ``--row_length`` images is written as a PNG into the
    model's log folder.

    Args:
        create_model_fn: callable taking the unpickled training ``args``
            namespace and returning a model instance.
        idx_list: sequence of (i1, i2) index pairs into the test set.
    """

    parser = argparse.ArgumentParser()

    # Training args
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=None)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--row_length', type=int, default=9)
    parser.add_argument('--double', type=eval, default=False)
    parser.add_argument('--clamp', type=eval, default=False)

    eval_args = parser.parse_args()

    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # Restore the argparse namespace the model was trained with.
    # NOTE(review): pickle.load is only safe on trusted checkpoints.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    torch.manual_seed(0)

    # Fixed (seeded) uniform noise, shared by every image pair, used to pick
    # one point inside each latent interval returned by forward_transform.
    u = torch.rand(3, 32, 32).to(eval_args.device)
    if eval_args.double: u = u.double()

    ###############
    ## Load data ##
    ###############

    data = CategoricalCIFAR10()

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights.
    # NOTE(review): strict=False silently ignores missing/unexpected keys;
    # confirm this is intentional for these checkpoints.
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    model.load_state_dict(weights, strict=False)
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double: model = model.double()

    ############################
    ## Perform interpolations ##
    ############################

    # Standard normal used to move latents to/from Gaussian space via cdf/icdf.
    gaussian = Normal(0, 1)

    idxs = idx_list[eval_args.start:eval_args.end]

    with torch.no_grad():
        data1, data2 = [], []
        batch_idxs = []
        for n, (i1, i2) in enumerate(idxs):

            # Accumulate image pairs until a full batch (or the final,
            # possibly partial, batch) is ready.
            data1.append(data.test[i1][0].unsqueeze(0))
            data2.append(data.test[i2][0].unsqueeze(0))
            batch_idxs.append((i1, i2))

            if (n + 1) % eval_args.batch_size == 0 or (n + 1) == len(idxs):

                data1 = torch.cat(data1, dim=0)
                data2 = torch.cat(data2, dim=0)

                print("Matching pairs", (n + 1) - eval_args.batch_size, "-",
                      n + 1, "/", len(idxs))

                if eval_args.double:
                    data1 = data1.double()
                    data2 = data2.double()
                double_str = '_double' if eval_args.double else ''

                # Encode each image to its latent interval; u selects one
                # concrete latent point inside [z_lower, z_upper].
                # NOTE(review): assumes forward_transform returns
                # (lower, upper) bounds of matching shape — confirm.
                z_lower1, z_upper1 = model.forward_transform(
                    data1.to(eval_args.device))
                z_lower2, z_upper2 = model.forward_transform(
                    data2.to(eval_args.device))

                z1 = z_lower1 + (z_upper1 - z_lower1) * u
                z2 = z_lower2 + (z_upper2 - z_lower2) * u

                # Move latent to Gaussian space; icdf maps 0/1 to -/+inf,
                # so clamp infinities to large finite values.
                g1 = gaussian.icdf(z1)
                g2 = gaussian.icdf(z2)
                g1[g1 == -math.inf] = -1e9
                g1[g1 == math.inf] = 1e9
                g2[g2 == -math.inf] = -1e9
                g2[g2 == math.inf] = 1e9

                # Interpolation in Gaussian space: weights (w, 1-w) rescaled
                # to unit L2 norm so interpolants keep roughly unit variance.
                ws = [(w / (math.sqrt(w**2 + (1 - w)**2)),
                       (1 - w) / (math.sqrt(w**2 + (1 - w)**2)))
                      for w in np.linspace(0, 1, eval_args.row_length)]
                zw = torch.cat(
                    [gaussian.cdf(w[0] * g1 + w[1] * g2) for w in ws], dim=0)
                # Decode all interpolants in one batch, then regroup as
                # (row_length, batch, C, H, W) and save one PNG per pair.
                xw = model.inverse_transform(
                    zw, clamp=eval_args.clamp).cpu().float() / 255
                xw = xw.reshape(eval_args.row_length, len(batch_idxs),
                                *xw.shape[1:])
                for i, (i1, i2) in enumerate(batch_idxs):
                    vutils.save_image(xw[:, i],
                                      '{}/i_{}_{}_l_{}{}.png'.format(
                                          model_log, i1, i2,
                                          eval_args.row_length, double_str),
                                      nrow=eval_args.row_length,
                                      padding=2)
                print("Stored interpolations")

                # Reset accumulators for the next batch.
                data1, data2 = [], []
                batch_idxs = []