parser.add_argument('--kernel_size_in', type=int, default=7) parser.add_argument('--output_filters', type=int, default=1024) # Optim params parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--milestones', type=eval, default=[]) parser.add_argument('--gamma', type=float, default=0.1) args = parser.parse_args() setup.prepare_env(args) ################## ## Specify data ## ################## data = CategoricalCIFAR10() setup.register_data(data.train, data.test) ################### ## Specify model ## ################### model = AutoregressiveSubsetFlow2d( base_shape=( 3, 32, 32, ), transforms=[ QuadraticSplineAutoregressiveSubsetTransform2d(PixelCNN(
def eval_elbo(create_model_fn):
    """Evaluate the ELBO, or the IWBO with ``--k`` importance samples, and save it.

    Options are parsed from the command line:

    * ``--model_path``: run name, joined onto ``LOG_FOLDER`` (``args.pickle``
      is read from, and the result written to, this folder) and onto
      ``CHECK_FOLDER`` (``model.pt`` is read from there).
    * ``--device``: device the model is moved to and passed to the
      evaluation helper (default ``'cuda'``).
    * ``--batch_size``: per-step budget (default 100). Without ``--k`` it is
      the data-loader batch size; with ``--k`` it is divided between data
      points and importance samples (see the branch comments below).
    * ``--test_set``: evaluate on ``data.test`` if true (default), else on
      ``data.train``.
    * ``--double``: convert the model to float64 and pass ``double=True`` on.
    * ``--k``: if omitted, ``dataset_elbo_bpd`` is called; otherwise
      ``dataset_iwbo_bpd`` with ``k=--k``.

    The result is printed and written to ``elbo_bpd_<split><_double>.txt`` or
    ``iwbo<k>_bpd_<split><_double>.txt`` in the run's log folder.

    Args:
        create_model_fn: called with the unpickled training ``args``; must
            return a module that the state dict in ``model.pt`` can be
            loaded into.

    Exits through ``parser.error`` (status 2) when ``--batch_size`` or ``--k``
    is not positive, or when one does not divide the other as required.
    """
    parser = argparse.ArgumentParser()
    # Training args
    # NOTE(review): ``type=eval`` executes arbitrary Python taken from the
    # command line; acceptable only because the caller controls argv.
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_set', type=eval, default=True)
    parser.add_argument('--double', type=eval, default=False)
    parser.add_argument('--k', type=int, default=None)
    eval_args = parser.parse_args()

    # Validate with parser.error rather than assert so the checks still run
    # under ``python -O`` and zero/negative values cannot reach a modulo.
    if eval_args.batch_size < 1:
        parser.error('--batch_size must be positive, got {}'.format(eval_args.batch_size))
    if eval_args.k is None:
        batch_size = eval_args.batch_size
        iwbo_batch_size = None
    elif eval_args.k < 1:
        parser.error('--k must be positive, got {}'.format(eval_args.k))
    elif eval_args.k <= eval_args.batch_size:
        # The loader yields batch_size // k data points per step; presumably
        # dataset_iwbo_bpd expands each into k samples (verify there).
        if eval_args.batch_size % eval_args.k != 0:
            parser.error('--batch_size ({}) must be divisible by --k ({})'.format(
                eval_args.batch_size, eval_args.k))
        batch_size = eval_args.batch_size // eval_args.k
        iwbo_batch_size = None
    else:
        # One data point per step; --batch_size is handed to dataset_iwbo_bpd
        # as its own batch_size (presumably the chunk size over the k samples).
        if eval_args.k % eval_args.batch_size != 0:
            parser.error('--k ({}) must be divisible by --batch_size ({})'.format(
                eval_args.k, eval_args.batch_size))
        batch_size = 1
        iwbo_batch_size = eval_args.batch_size

    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # NOTE(review): pickle.load can execute code stored in the file; only
    # evaluate runs whose log folder is trusted.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    ##################
    ## Specify data ##
    ##################

    torch.manual_seed(0)
    data = CategoricalCIFAR10()
    dataset = data.test if eval_args.test_set else data.train
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    test_str = 'test' if eval_args.test_set else 'train'

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights. strict=False tolerates key mismatches; report
    # them so a wrong or stale checkpoint is not evaluated silently with some
    # parameters left at whatever create_model_fn initialised them to.
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    incompatible = model.load_state_dict(weights, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print('Warning: checkpoint does not match model. Missing keys: {}, unexpected keys: {}'.format(
            incompatible.missing_keys, incompatible.unexpected_keys))
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double:
        model = model.double()
    double_str = '_double' if eval_args.double else ''

    ####################
    ## Compute loglik ##
    ####################

    if eval_args.k is None:
        # Compute ELBO
        elbo = dataset_elbo_bpd(model, data_loader, device=eval_args.device, double=eval_args.double)
        print('Done, ELBO: {}'.format(elbo))
        fname = '{}/elbo_bpd_{}{}.txt'.format(model_log, test_str, double_str)
        with open(fname, "w") as f:
            f.write(str(elbo))
    else:
        iwbo = dataset_iwbo_bpd(model, data_loader, device=eval_args.device, double=eval_args.double,
                                k=eval_args.k, batch_size=iwbo_batch_size)
        print('Done, IWBO({}): {}'.format(eval_args.k, iwbo))
        fname = '{}/iwbo{}_bpd_{}{}.txt'.format(model_log, eval_args.k, test_str, double_str)
        with open(fname, "w") as f:
            f.write(str(iwbo))
def eval_exact(create_model_fn):
    """Evaluate the exact log-likelihood of a trained model and save it.

    Options are parsed from the command line:

    * ``--model_path``: run name, joined onto ``LOG_FOLDER`` (``args.pickle``
      is read from, and the result written to, this folder) and onto
      ``CHECK_FOLDER`` (``model.pt`` is read from there).
    * ``--device``: device the model is moved to and passed to the
      evaluation helper (default ``'cuda'``).
    * ``--batch_size``: data-loader batch size (default 64).
    * ``--test_set``: evaluate on ``data.test`` if true (default), else on
      ``data.train``.
    * ``--double``: convert the model to float64 and pass ``double=True`` on.

    The value returned by ``dataset_loglik_bpd`` is printed and written to
    ``exact_loglik_bpd_<split><_double>.txt`` in the run's log folder.

    Args:
        create_model_fn: called with the unpickled training ``args``; must
            return a module that the state dict in ``model.pt`` can be
            loaded into.
    """
    parser = argparse.ArgumentParser()
    # Training args
    # NOTE(review): ``type=eval`` executes arbitrary Python taken from the
    # command line; acceptable only because the caller controls argv.
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--test_set', type=eval, default=True)
    parser.add_argument('--double', type=eval, default=False)
    eval_args = parser.parse_args()

    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # NOTE(review): pickle.load can execute code stored in the file; only
    # evaluate runs whose log folder is trusted.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    ##################
    ## Specify data ##
    ##################

    torch.manual_seed(0)
    data = CategoricalCIFAR10()
    dataset = data.test if eval_args.test_set else data.train
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=eval_args.batch_size)
    test_str = 'test' if eval_args.test_set else 'train'

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights. strict=False tolerates key mismatches; report
    # them so a wrong or stale checkpoint is not evaluated silently with some
    # parameters left at whatever create_model_fn initialised them to.
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    incompatible = model.load_state_dict(weights, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print('Warning: checkpoint does not match model. Missing keys: {}, unexpected keys: {}'.format(
            incompatible.missing_keys, incompatible.unexpected_keys))
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double:
        model = model.double()
    double_str = '_double' if eval_args.double else ''

    ####################
    ## Compute loglik ##
    ####################

    bpd = dataset_loglik_bpd(model, data_loader, device=eval_args.device, double=eval_args.double)
    print('Done, bpd: {}'.format(bpd))
    fname = '{}/exact_loglik_bpd_{}{}.txt'.format(model_log, test_str, double_str)
    with open(fname, "w") as f:
        f.write(str(bpd))
def interpolate(create_model_fn, idx_list):
    """Save latent-space interpolations between pairs of CIFAR-10 test images.

    For every pair ``(i1, i2)`` in ``idx_list[--start:--end]``, both test
    images are passed through ``model.forward_transform``. A point inside each
    returned interval ``[z_lower, z_upper]`` is picked with one fixed uniform
    draw ``u`` (seed 0, shared by all images). That point is mapped to
    Gaussian space with the standard-normal inverse CDF. ``--row_length``
    mixtures of the two Gaussian codes are then mapped back through the CDF
    and decoded with ``model.inverse_transform``. Each row runs from image
    ``i2`` (first column) to image ``i1`` (last column). It is saved as
    ``i_<i1>_<i2>_l_<row_length><_double>.png`` in the run's log folder.

    Command-line options: ``--model_path``, ``--start``/``--end`` (slice of
    ``idx_list``), ``--device``, ``--batch_size`` (pairs per forward pass,
    default 8), ``--row_length`` (default 9), ``--double`` (float64 model
    and latents) and ``--clamp`` (forwarded to ``model.inverse_transform``).

    Args:
        create_model_fn: called with the unpickled training ``args``; must
            return a module that the state dict in ``model.pt`` can be
            loaded into.
        idx_list: sequence of ``(i1, i2)`` index pairs into ``data.test``.

    Exits through ``parser.error`` (status 2) when ``--batch_size`` or
    ``--row_length`` is not positive.
    """
    parser = argparse.ArgumentParser()
    # Training args
    # NOTE(review): ``type=eval`` executes arbitrary Python taken from the
    # command line; acceptable only because the caller controls argv.
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--start', type=int, default=0)
    parser.add_argument('--end', type=int, default=None)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--row_length', type=int, default=9)
    parser.add_argument('--double', type=eval, default=False)
    parser.add_argument('--clamp', type=eval, default=False)
    eval_args = parser.parse_args()

    # Otherwise these fail later as a ZeroDivisionError (batch modulo) or an
    # empty torch.cat / negative np.linspace (row weights).
    if eval_args.batch_size < 1:
        parser.error('--batch_size must be positive, got {}'.format(eval_args.batch_size))
    if eval_args.row_length < 1:
        parser.error('--row_length must be positive, got {}'.format(eval_args.row_length))

    model_log = os.path.join(LOG_FOLDER, eval_args.model_path)
    model_check = os.path.join(CHECK_FOLDER, eval_args.model_path)

    # NOTE(review): pickle.load can execute code stored in the file; only
    # load runs whose log folder is trusted.
    with open('{}/args.pickle'.format(model_log), 'rb') as f:
        args = pickle.load(f)

    # One seeded uniform draw, reused for every image, selects the point
    # inside each latent interval.
    torch.manual_seed(0)
    u = torch.rand(3, 32, 32).to(eval_args.device)
    if eval_args.double:
        u = u.double()
    double_str = '_double' if eval_args.double else ''

    ###############
    ## Load data ##
    ###############

    data = CategoricalCIFAR10()

    ################
    ## Load model ##
    ################

    model = create_model_fn(args)

    # Load pre-trained weights. strict=False tolerates key mismatches; report
    # them so a wrong or stale checkpoint is not used silently.
    weights = torch.load('{}/model.pt'.format(model_check), map_location='cpu')
    incompatible = model.load_state_dict(weights, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print('Warning: checkpoint does not match model. Missing keys: {}, unexpected keys: {}'.format(
            incompatible.missing_keys, incompatible.unexpected_keys))
    model = model.to(eval_args.device)
    model = model.eval()
    if eval_args.double:
        model = model.double()

    ############################
    ## Perform interpolations ##
    ############################

    gaussian = Normal(0, 1)
    # Interpolation weights (a, b), identical for every batch, so computed
    # once. a**2 + b**2 == 1 for every w; w = 0 gives (0, 1) -> pure g2 and
    # w = 1 gives (1, 0) -> pure g1.
    ws = [(w / math.sqrt(w**2 + (1 - w)**2),
           (1 - w) / math.sqrt(w**2 + (1 - w)**2))
          for w in np.linspace(0, 1, eval_args.row_length)]
    idxs = idx_list[eval_args.start:eval_args.end]

    with torch.no_grad():
        data1, data2, batch_idxs = [], [], []
        for n, (i1, i2) in enumerate(idxs):
            data1.append(data.test[i1][0].unsqueeze(0))
            data2.append(data.test[i2][0].unsqueeze(0))
            batch_idxs.append((i1, i2))
            # Run the model once a full batch is collected, or at the last pair.
            if (n + 1) % eval_args.batch_size != 0 and (n + 1) != len(idxs):
                continue

            x1 = torch.cat(data1, dim=0)
            x2 = torch.cat(data2, dim=0)
            # The final batch may hold fewer than --batch_size pairs, so the
            # start index comes from the number of pairs actually collected.
            print("Matching pairs", (n + 1) - len(batch_idxs), "-", n + 1, "/", len(idxs))
            if eval_args.double:
                x1 = x1.double()
                x2 = x2.double()

            z_lower1, z_upper1 = model.forward_transform(x1.to(eval_args.device))
            z_lower2, z_upper2 = model.forward_transform(x2.to(eval_args.device))
            z1 = z_lower1 + (z_upper1 - z_lower1) * u
            z2 = z_lower2 + (z_upper2 - z_lower2) * u

            # Move latent to Gaussian space. z exactly 0 or 1 gives -inf/+inf,
            # which is clamped to -1e9/+1e9; icdf is small and finite strictly
            # inside (0, 1), and NaN passes through clamp unchanged.
            g1 = torch.clamp(gaussian.icdf(z1), min=-1e9, max=1e9)
            g2 = torch.clamp(gaussian.icdf(z2), min=-1e9, max=1e9)

            # Interpolate in Gaussian space, map back with the CDF and decode.
            zw = torch.cat([gaussian.cdf(a * g1 + b * g2) for a, b in ws], dim=0)
            xw = model.inverse_transform(zw, clamp=eval_args.clamp).cpu().float() / 255
            # torch.cat stacked row_length blocks of len(batch_idxs) images,
            # so after this reshape xw[step, pair] selects one frame.
            xw = xw.reshape(eval_args.row_length, len(batch_idxs), *xw.shape[1:])
            for pair, (j1, j2) in enumerate(batch_idxs):
                vutils.save_image(
                    xw[:, pair],
                    '{}/i_{}_{}_l_{}{}.png'.format(model_log, j1, j2, eval_args.row_length, double_str),
                    nrow=eval_args.row_length,
                    padding=2)
            print("Stored interpolations")
            data1, data2, batch_idxs = [], [], []