def test_data_parallel(self):
        """Smoke-test DataParallelDistribution wrapped around StandardNormal.

        Calls log_prob, sample and sample_with_log_prob once each and checks
        only that every returned object is a torch.Tensor; neither values nor
        shapes are verified here.
        """
        num_items = 12
        event_shape = [2, 3, 4]
        inputs = torch.rand([num_items, *event_shape])
        wrapped = DataParallelDistribution(StandardNormal(event_shape))

        # Exercise the three public entry points in turn.
        scored = wrapped.log_prob(inputs)
        drawn = wrapped.sample(num_items)
        drawn_again, drawn_again_scored = wrapped.sample_with_log_prob(num_items)

        for result in (scored, drawn, drawn_again_scored, drawn_again):
            self.assertIsInstance(result, torch.Tensor)
Esempio n. 2
0
###################
## Specify model ##
###################

# Build the model from the stored training arguments. The data-parallel
# wrapper is applied before load_state_dict, presumably so parameter names
# line up with how the checkpoint was saved -- confirm against the trainer.
model = get_model(args, data_shape=data_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)

# map_location='cpu' lets a checkpoint written on a GPU machine load on a
# CPU-only host; the model is moved to the selected device further down.
checkpoint = torch.load(path_check, map_location='cpu')
model.load_state_dict(checkpoint['model'])
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
    eval_args.model, checkpoint['current_epoch'], eval_args.seed)
# makedirs also creates missing parent directories and, with exist_ok=True,
# does not fail if the directory already exists (no check-then-create race).
os.makedirs(os.path.dirname(path_samples), exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model = model.eval()
if eval_args.double:
    model = model.double()

# Samples are divided by the largest value representable in args.num_bits
# bits -- presumably to map integer pixel values into [0, 1] before saving.
samples = model.sample(
    eval_args.samples).cpu().float() / (2**args.num_bits - 1)
vutils.save_image(samples, fp=path_samples, nrow=eval_args.nrow)
Esempio n. 3
0
out_dir = os.path.join(f"{eval_args.model}",
                       f"joint_samples/seed{eval_args.seed}/")
# A single makedirs call creates every missing level of out_dir and is a
# no-op when the directory already exists, replacing the level-by-level
# exists()/mkdir() checks.
os.makedirs(out_dir, exist_ok=True)

for prior_temperature in eval_args.prior_temperature:
    # Re-seed before each temperature so every pass starts from the same
    # random state.
    torch.manual_seed(eval_args.seed)

    # sample low-resolution from prior
    low_res = prior_model.sample(eval_args.samples,
                                 temperature=prior_temperature)
    # NOTE(review): the raw low-resolution samples are written to the
    # "big_prior_..." file while the enlarged copy below goes to the
    # "prior_..." file; the names look swapped -- confirm before renaming.
    path_lr = os.path.join(
        out_dir,
        f"big_prior_low_resolution_e{prior_checkpoint['current_epoch']}_temperature{int(100 * prior_temperature)}.png"
    )
    save_images(low_res, path_lr)

    # Enlarge by repeating every element args.sr_scale_factor times along
    # dim 2 and then dim 3 (presumably height and width -- TODO confirm).
    big_lr = torch.repeat_interleave(low_res, args.sr_scale_factor, dim=2)
    big_lr = torch.repeat_interleave(big_lr, args.sr_scale_factor, dim=3)
    path_big_lr = os.path.join(
        out_dir,
        f"prior_low_resolution_e{prior_checkpoint['current_epoch']}_temperature{int(100 * prior_temperature)}.png"
    )
    save_images(big_lr, path_big_lr)
Esempio n. 4
0
            f"big_lowres_e{checkpoint['current_epoch']}_temperature{int(100 * temperature)}.png"
        )
        save_images(big_lr, path_big_lr)

    else:
        num_samples_or_context = eval_args.samples
        imgs = batch  #[:eval_args.samples]

    # Draw samples at the current temperature and write them under out_dir.
    if args.boosted_components > 1:
        # More than one boosted component is configured: sample each
        # component on its own and tag the file with the component index.
        for c in range(model.num_components):
            samples = model.sample(
                num_samples_or_context, component=c, temperature=temperature)
            path_samples = os.path.join(
                out_dir,
                f"sample_e{checkpoint['current_epoch']}_c{c}_temperature{int(100 * temperature)}.png"
            )
            save_images(samples, path_samples)

    else:
        # Otherwise a single call to model.sample produces the one output.
        samples = model.sample(num_samples_or_context, temperature=temperature)
        path_samples = os.path.join(
            out_dir,
            f"sample_e{checkpoint['current_epoch']}_temperature{int(100 * temperature)}.png"
        )
        save_images(samples, path_samples)

    # save real samples too
    path_true_samples = os.path.join(
        out_dir,
        f"true_e{checkpoint['current_epoch']}_temperature{int(100 * temperature)}.png"
Esempio n. 5
0
    #out /= max_color
    vutils.save_image(out, file_path)


assert len(eval_args.temperature) > 0
for t, temperature in enumerate(eval_args.temperature):
    # Re-seed before each temperature so every pass starts from the same
    # random state.
    torch.manual_seed(eval_args.seed)

    # Restart the running index for every temperature so that yhat_{i}.png in
    # each temperature directory lines up with y_{i}.png written on the first
    # pass (assuming eval_loader yields items in the same order each pass).
    i = 0

    yhat_dir = os.path.join(f"{eval_args.model}",
                            f"yhat_test/temperature{int(100 * temperature)}/")
    # yhat_dir ends with a slash, so os.path.dirname(yhat_dir) is the same
    # directory rather than its parent; makedirs creates the intermediate
    # "yhat_test" level that a plain mkdir would require to exist already.
    os.makedirs(yhat_dir, exist_ok=True)

    if t == 0:
        # Ground-truth images are written once, during the first temperature.
        y_dir = os.path.join(f"{eval_args.model}", f"y_test/")
        os.makedirs(y_dir, exist_ok=True)

    for x in eval_loader:
        # Each batch is indexed as x[0] (saved as "y") and x[1] (passed to
        # the model as sampling context).
        context = x[1].to(device)
        samples = model.sample(context, temperature=temperature)
        for y, yhat in zip(x[0], samples):
            save_images(yhat, os.path.join(yhat_dir, f"yhat_{i}.png"))
            if t == 0:
                save_images(y, os.path.join(y_dir, f"y_{i}.png"))

            i += 1