Exemplo n.º 1
0
def main():
    """Restore a trained VAE/GMVAE checkpoint and render it with make_image_load_z.

    Command-line options are declared from a single spec table. The model
    name is built from --model, --z and --run. A GMVAE (with --k mixture
    components) is built when --model is 'gmvae'; any other value yields a
    plain VAE. Both use x_dim=24 and receive --warmup unchanged (an int).
    The checkpoint saved at global step --iter_max is then loaded.
    Note: --train is parsed but not used by this function.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # (flag, type, default, help) -- declared in the same order as before so
    # the generated --help text is unchanged.
    option_specs = (
        ('--model', str, 'gmvae', "run VAE or GMVAE"),
        ('--z', int, 4, "Number of latent dimensions"),
        ('--iter_max', int, 10000, "Number of training iterations"),
        ('--run', int, 0, "Run ID. In case you want to run replicates"),
        ('--train', int, 0, "Flag for training"),
        ('--k', int, 16, "Number mixture components in MoG prior"),
        ('--warmup', int, 1, "Fix variance during first 1/4 of training"),
    )
    for flag, kind, fallback, description in option_specs:
        cli.add_argument(flag, type=kind, default=fallback, help=description)
    args = cli.parse_args()

    model_name = '_'.join((
        'model={:s}'.format(args.model),
        'z={:02d}'.format(args.z),
        'run={:04d}'.format(args.run),
    ))
    pprint(vars(args))
    print('Model name:', model_name)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Constructor arguments common to both model families.
    shared = dict(z_dim=args.z, name=model_name, x_dim=24, warmup=args.warmup)
    if args.model != 'gmvae':
        model = VAE(**shared).to(device)
    else:
        model = GMVAE(k=args.k, **shared).to(device)

    ut.load_model_by_name(model, global_step=args.iter_max)
    make_image_load_z(model)
Exemplo n.º 2
0
# Remaining integer CLI options. `parser` and the options behind `args.z` /
# `args.iter_max` are presumably declared before this point.
for _flag, _default, _help in (
        ('--iter_save', 10000, "Save model every n iterations"),
        ('--run', 0, "Run ID. In case you want to run replicates"),
        ('--train', 1, "Flag for training"),
):
    parser.add_argument(_flag, type=int, default=_default, help=_help)
args = parser.parse_args()

layout = [
    ('model={:s}', 'vae'),
    ('z={:02d}', args.z),
    ('run={:04d}', args.run),
]
model_name = '_'.join(template.format(value) for template, value in layout)
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
vae = VAE(z_dim=args.z, name=model_name).to(device)

if args.train:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=vae, train_loader=train_loader, labeled_subset=labeled_subset,
          device=device, tqdm=tqdm.tqdm, writer=writer,
          iter_max=args.iter_max, iter_save=args.iter_save)
    # --train 2 additionally sets run_iwae for the evaluation.
    ut.evaluate_lower_bound(vae, labeled_subset, run_iwae=(args.train == 2))
    # Draw 100 samples, view them as a (100, 1, 28, 28) batch, and display it.
    x = vae.sample_x(100).view(100, 1, 28, 28)
    db.printTensor(x)
    ImUtil.showBatch(x, show=True)
Exemplo n.º 3
0
# Further CLI options; `parser` and the options behind `args.z` /
# `args.iter_max` are presumably declared before this point.
parser.add_argument('--iter_save',
                    type=int,
                    default=10000,
                    help="Save model every n iterations")
parser.add_argument('--run',
                    type=int,
                    default=0,
                    help="Run ID. In case you want to run replicates")
parser.add_argument('--train',
                    type=int,
                    default=1,
                    help="Flag for training")
args = parser.parse_args()

layout = [('model={:s}', 'vae'), ('z={:02d}', args.z), ('run={:04d}', args.run)]
model_name = '_'.join(fmt.format(val) for fmt, val in layout)
pprint(vars(args))
print('Model name:', model_name)

_use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _use_cuda else 'cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
vae = VAE(z_dim=args.z, name=model_name).to(device)

if not args.train:
    # Evaluation-only run: restore the checkpoint written at step --iter_max.
    ut.load_model_by_name(vae, global_step=args.iter_max, device=device)
else:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=vae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    # --train 2 additionally sets run_iwae for the evaluation.
    ut.evaluate_lower_bound(vae, labeled_subset, run_iwae=args.train == 2)
Exemplo n.º 4
0
train_loader, labeled_subset, _ = ut.get_mnist_data(
    device, train_set, test_set, use_test_subset=True)

# Balanced per-digit subsets: each digit contributes
# `data_individual_length` examples.
data_individual_length = 5000
data_set_individual, data_loader_individual = generate_individual_set_loader(
    device, train_set, data_individual_length)

# Frozen (requires_grad=False) prior parameters: zero mean, unit variance,
# one entry per latent dimension.
z_prior_m, z_prior_v = (
    torch.nn.Parameter(init(args.z), requires_grad=False).to(device)
    for init in (torch.zeros, torch.ones))
vae = VAE(z_dim=args.z,
          name=model_name,
          z_prior_m=z_prior_m,
          z_prior_v=z_prior_v).to(device)

# Training stage selector:
#   1 -> step 1: get the model
#   2 -> step 2: get mean and variance
#   3 -> step 3: refine the model
# Set to one of the values above to run that stage (e.g. train_args = 1);
# None presumably runs none of them.
train_args = None
if train_args == 1:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(
        model=vae,
        train_loader=train_loader,
        # train_loader=data_loader_individual[0],
def refine(train_loader_set,
           mean_set,
           variance_set,
           z_dim,
           device,
           tqdm,
           writer,
           iter_max=np.inf,
           iter_save=np.inf,
           model_name='model',
           y_status='none',
           reinitialize=False,
           init_step=20000):
    """Fine-tune a VAE encoder on each data subset, each with its own fixed prior.

    For every loader in ``train_loader_set`` a fresh ``VAE`` is built. Its
    frozen prior mean/variance come from the matching entries of ``mean_set`` /
    ``variance_set``. Weights are restored from the checkpoint at ``init_step``
    before the very first gradient step and from the ``iter_save`` checkpoint
    afterwards. The model is then trained with Adam (lr=1e-3) on
    ``vae.loss_encoder`` for one pass over that loader and saved back under
    ``iter_save``. Passes over all loaders repeat until ``iter_max`` gradient
    steps have been taken. The current model is then saved under step 0 and
    the function returns.

    Args:
        train_loader_set: iterable of iterables yielding ``(x, y)`` batches;
            it is materialized once so generators can be reused across passes.
        mean_set, variance_set: per-loader prior parameters, indexed in the
            same order as ``train_loader_set``.
        z_dim: latent dimensionality passed to ``VAE``.
        device: torch device for the models and the data.
        tqdm: progress-bar factory used as a context manager (e.g. ``tqdm.tqdm``).
        writer: summary writer handed to ``ut.log_summaries`` every 50 steps.
        iter_max: total number of gradient steps before returning.
        iter_save: global-step tag of the intermediate checkpoint.
        model_name: name given to every constructed ``VAE``.
        y_status, reinitialize: accepted for interface compatibility; unused.
        init_step: global step of the checkpoint loaded before the first
            gradient step (previously hard-coded to 20000).

    Raises:
        ValueError: if ``train_loader_set`` is empty. The loop could otherwise
            never reach ``iter_max`` and would spin forever.
    """
    # Materialize once: a generator would be exhausted after the first pass,
    # leaving the `while True` loop below with nothing to iterate.
    train_loaders = list(train_loader_set)
    if not train_loaders:
        raise ValueError("refine() needs at least one train loader")

    i = 0  # number of gradient steps taken so far
    with tqdm(total=iter_max) as pbar:
        while True:
            for index, train_loader in enumerate(train_loaders):
                print("Iteration:", i)
                print("index: ", index)

                z_prior_m = torch.nn.Parameter(mean_set[index].cpu(),
                                               requires_grad=False).to(device)
                z_prior_v = torch.nn.Parameter(variance_set[index].cpu(),
                                               requires_grad=False).to(device)
                vae = VAE(z_dim=z_dim,
                          name=model_name,
                          z_prior_m=z_prior_m,
                          z_prior_v=z_prior_v).to(device)
                optimizer = optim.Adam(vae.parameters(), lr=1e-3)

                # Start from the pretrained checkpoint before any step has been
                # taken; afterwards continue from the checkpoint written below.
                print("Load model")
                ut.load_model_by_name(
                    vae, global_step=init_step if i == 0 else iter_save)

                # Labels are not needed by loss_encoder, so they are ignored.
                for xu, _ in train_loader:
                    optimizer.zero_grad()

                    # Flatten each example and binarize it by Bernoulli sampling.
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    loss, summaries = vae.loss_encoder(xu)

                    loss.backward()
                    optimizer.step()

                    pbar.set_postfix(loss='{:.2e}'.format(loss.item()))
                    pbar.update(1)

                    i += 1
                    if i % 50 == 0:
                        ut.log_summaries(writer, summaries, i)

                    # `>=` rather than `==`: a non-positive or non-integer
                    # iter_max must still terminate instead of looping forever.
                    if i >= iter_max:
                        ut.save_model_by_name(vae, 0)
                        return

                ut.save_model_by_name(vae, iter_save)
Exemplo n.º 6
0
# Encoder smoke test: print the shapes of the three values MLP.encode returns.
MLP = v3.MLP(784, 512, 64)
l_enc_a, mu, var = MLP.encode(x)
for _t in (l_enc_a, mu, var):
    print(_t.shape)

# Decode the mean and print the output shape.
Dec = v3.FinalDecoder()
out = Dec.decode(mu)
print(out.shape)

# LVAE smoke test: partial encode/decode, the rest of the decode, then the losses.
lvae = LVAE()
vae = VAE(z_dim=2)
z_down, mu0, var0, _, _ = lvae.EncoderPartialDecoder(x)
print("mu0", mu0.shape)

decoded_logits = lvae.DecodeRest(mu0, var0)
print(decoded_logits.shape)

nelbo, _, _ = lvae.negative_elbo_bound(x, 1)
print(nelbo)
loss, summaries = lvae.loss(x, 1)

train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
beta = DeterministicWarmup(n=50, t_max=1)
for i in range(10):
Exemplo n.º 7
0
                                           num_workers=0)

# Unshuffled, one-item-per-batch loaders over the held-out split and the
# full dataset (created in that order).
test_loader, full_loader = (
    torch.utils.data.DataLoader(source,
                                batch_size=1,
                                shuffle=False,
                                num_workers=0)
    for source in (testgen, dc))

# The encoder input width is taken from the length of the first dataset item.
vae = VAE(nn='popv',
          encode_dim=len(dc[0]),
          z_dim=args.z,
          name=model_name).to(device)

if args.train:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=vae, train_loader=train_loader, labeled_subset=None,
          device=device, tqdm=tqdm.tqdm, writer=writer,
          iter_max=args.iter_max, iter_save=args.iter_save)
    # Evaluate on the held-out data passed via `ox`; --train 2 also sets run_iwae.
    ut.evaluate_lower_bound(vae, None,
                            run_iwae=(args.train == 2),
                            ox=torch.tensor(testgen))
Exemplo n.º 8
0
                    default=10000,
                    help="Save model every n iterations")
parser.add_argument('--run',
                    type=int,
                    default=0,
                    help="Run ID. In case you want to run replicates")
args = parser.parse_args()
layout = [('model={:s}', 'vae'), ('z={:02d}', args.z),
          ('run={:04d}', args.run)]
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device,
                                                    use_test_subset=True)
vae = VAE(z_dim=args.z, name=model_name).to(device)

# Restore the checkpoint saved at step --iter_max and report its bound.
ut.load_model_by_name(vae, global_step=args.iter_max)
ut.evaluate_lower_bound(vae, labeled_subset, run_iwae=False)

# Draw 200 samples and arrange them as a 10 x 20 grid of 28x28 images.
samples = torch.reshape(vae.sample_x(200), (10, 20, 28, 28))
# Matplotlib needs host memory. Detach and copy to CPU once: calling
# .numpy() on a CUDA tensor raises, and `device` is CUDA whenever available.
images = samples.detach().cpu().numpy()

f, axarr = plt.subplots(10, 20)

for i in range(images.shape[0]):
    for j in range(images.shape[1]):
        axarr[i, j].imshow(images[i, j])
        axarr[i, j].axis('off')

plt.show()
Exemplo n.º 9
0
for _flag, _default, _help in (
        ('--run', 0, "Run ID. In case you want to run replicates"),
        ('--train', 1, "Flag for training"),
):
    parser.add_argument(_flag, type=int, default=_default, help=_help)
args = parser.parse_args()

layout = [
    ('model={:s}', 'vae'),
    ('z={:02d}', args.z),
    ('run={:04d}', args.run),
]
model_name = '_'.join(pattern.format(field) for pattern, field in layout)
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
vae = VAE(z_dim=args.z, name=model_name).to(device)

if not args.train:
    # Evaluation-only run: restore the checkpoint saved at step --iter_max.
    ut.load_model_by_name(vae, global_step=args.iter_max)
else:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    _train_kwargs = dict(train_loader=train_loader,
                         labeled_subset=labeled_subset,
                         device=device,
                         tqdm=tqdm.tqdm,
                         writer=writer,
                         iter_max=args.iter_max,
                         iter_save=args.iter_save)
    train(model=vae, **_train_kwargs)
    # --train 2 additionally sets run_iwae for the evaluation.
    ut.evaluate_lower_bound(vae, labeled_subset, run_iwae=(args.train == 2))
Exemplo n.º 10
0
    # 'car': (-8.742585689402953, 3.867340374856824),
    # 'other': (-0.5925285502290069, 2.732820150602578),
}  # computed on train subset, log-transformed

# choose model: GMVAE when --model is 'gmvae', otherwise a plain VAE.
# Both receive the same constructor options; only GMVAE also receives k.
_model_kwargs = dict(z_dim=args.z,
                     name=model_name,
                     x_dim=24,
                     warmup=(args.warmup == 1),
                     var_pen=args.var_pen)
if args.model != 'gmvae':
    model = VAE(**_model_kwargs).to(device)
else:
    model = GMVAE(k=args.k, **_model_kwargs).to(device)

if args.mode == 'train':
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train_loader = ut.get_load_data(device,
                                    split='train',
                                    batch_size=args.batch,
                                    in_memory=True,
                                    log_normal=True,
                                    shift_scale=shift_scale)
    train(model=model,
          train_loader=train_loader,
          device=device,
          tqdm=tqdm.tqdm,