Exemplo n.º 1
0
# Remaining integer command-line options. NOTE: `parser` and the `--z` option
# (read below as `args.z`) are presumably defined earlier in the script.
parser.add_argument(
    '--k', type=int, default=500,
    help="Number mixture components in MoG prior")
parser.add_argument(
    '--iter_max', type=int, default=20000,
    help="Number of training iterations")
parser.add_argument(
    '--iter_save', type=int, default=10000,
    help="Save model every n iterations")
parser.add_argument(
    '--run', type=int, default=0,
    help="Run ID. In case you want to run replicates")
args = parser.parse_args()

# The run name encodes the hyper-parameters; e.g. z=10, k=500, run=0 gives
# "model=gmvae_z=10_k=500_run=0000".
layout = [('model={:s}', 'gmvae'), ('z={:02d}', args.z), ('k={:03d}', args.k),
          ('run={:04d}', args.run)]
model_name = '_'.join(fmt.format(val) for fmt, val in layout)
pprint(vars(args))
print('Model name:', model_name)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(
    device, use_test_subset=True)
gmvae = GMVAE(name=model_name, z_dim=args.z, k=args.k).to(device)

# Restore the checkpoint for step `args.iter_max` and report the lower bound
# on the labeled subset (IWAE disabled).
ut.load_model_by_name(gmvae, global_step=args.iter_max)
ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=False)

# Draw n_rows * n_cols samples and arrange them as an n_rows x n_cols grid of
# 28x28 images. Sampling needs no autograd graph, so it runs under no_grad.
# The result is moved to the CPU once: `.numpy()` raises on CUDA tensors, and
# `device` above may be CUDA.
n_rows, n_cols = 10, 20
with torch.no_grad():
    samples = gmvae.sample_x(n_rows * n_cols)
samples = samples.reshape(n_rows, n_cols, 28, 28).detach().cpu().numpy()

f, axarr = plt.subplots(n_rows, n_cols)

for i in range(n_rows):
    for j in range(n_cols):
        axarr[i, j].imshow(samples[i, j])
        axarr[i, j].axis('off')

plt.show()
Exemplo n.º 2
0
# Two loaders over `testgen` and `dc` (both presumably defined earlier in the
# script). Each yields one item per batch, in dataset order, loaded in the
# main process (no worker subprocesses).
test_loader, full_loader = (
    torch.utils.data.DataLoader(dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=0)
    for dataset in (testgen, dc)
)

# GMVAE built with the 'popv' network variant. The encoder input size is the
# length of the first item of `dc` (assumes that item is a flat feature
# vector -- TODO confirm).
gmvae = GMVAE(nn='popv',
              encode_dim=len(dc[0]),
              z_dim=args.z,
              k=args.k,
              name=model_name)
gmvae = gmvae.to(device)

if args.train:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=None,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    ut.evaluate_lower_bound(gmvae,
                            None,
Exemplo n.º 3
0
# Last two integer options. NOTE: `parser` and the --z/--k/--iter_max/--iter_save
# options used below are presumably registered earlier in the script.
parser.add_argument('--run', type=int, default=0,
                    help="Run ID. In case you want to run replicates")
parser.add_argument('--train', type=int, default=1,
                    help="Flag for training")
args = parser.parse_args()

# Build a run name from the hyper-parameters, e.g. "model=gmvae_z=10_k=500_run=0000"
# for z=10, k=500, run=0.
layout = [('model={:s}', 'gmvae'), ('z={:02d}', args.z), ('k={:03d}', args.k),
          ('run={:04d}', args.run)]
model_name = '_'.join(template.format(value) for template, value in layout)
pprint(vars(args))
print('Model name:', model_name)

# GPU if available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device,
                                                    use_test_subset=True)
gmvae = GMVAE(name=model_name, z_dim=args.z, k=args.k).to(device)

# --train is an integer flag: any non-zero value trains the model here, and a
# value of 2 additionally requests the IWAE bound in the final evaluation.
if args.train:
    # Log writer for this run name; overwrite_existing=True presumably
    # replaces logs from a previous run with the same name -- confirm in ut.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)

else:
    # No training requested: load the saved model for step args.iter_max
    # (presumably the checkpoint written at the end of a previous training run).
    ut.load_model_by_name(gmvae, global_step=args.iter_max, device=device)
Exemplo n.º 4
0
# Final option plus run setup. NOTE: `parser` and the --z/--k/--run/--train/
# --iter_max/--iter_save options used below are presumably registered earlier
# in the script. Fixed the "smapling" typo in the user-facing --help text.
parser.add_argument('--sample',    type=int, default=0,     help="Flag for sampling")
args = parser.parse_args()
# Run name built from the hyper-parameters, e.g. "model=gmvae_z=10_k=500_run=0000".
layout = [
    ('model={:s}',  'gmvae'),
    ('z={:02d}',  args.z),
    ('k={:03d}',  args.k),
    ('run={:04d}', args.run)
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)

# Use the GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
gmvae = GMVAE(z_dim=args.z, k=args.k, name=model_name).to(device)

# --train is an integer flag: non-zero trains here; 2 also requests the IWAE
# bound in the final evaluation.
if args.train:
    # overwrite_existing=True presumably replaces logs of a previous run with
    # the same model_name -- confirm in ut.prepare_writer.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)

else:
    # Load the saved model for step args.iter_max instead of training.
    # NOTE(review): unlike other variants of this script, no `device` is
    # passed here; a checkpoint saved on another device may need one --
    # confirm against ut.load_model_by_name's signature.
    ut.load_model_by_name(gmvae, global_step=args.iter_max)
Exemplo n.º 5
0
# Last two integer options. NOTE: `parser` and the other options read below
# (--z, --k, --iter_max, --iter_save) are presumably registered earlier.
parser.add_argument('--run',   type=int, default=0, help="Run ID. In case you want to run replicates")
parser.add_argument('--train', type=int, default=1, help="Flag for training")
args = parser.parse_args()

# Encode the hyper-parameters in the run name,
# e.g. "model=gmvae_z=10_k=500_run=0000" for z=10, k=500, run=0.
layout = [
    ('model={:s}', 'gmvae'),
    ('z={:02d}', args.z),
    ('k={:03d}', args.k),
    ('run={:04d}', args.run),
]
model_name = '_'.join(pattern.format(field) for pattern, field in layout)
pprint(vars(args))
print('Model name:', model_name)

# GPU when available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
gmvae = GMVAE(name=model_name, z_dim=args.z, k=args.k).to(device)

# --train is an integer flag: non-zero trains here; 2 also requests the IWAE
# bound in the evaluation below. (Its value is already printed by
# pprint(vars(args)) above, so no separate debug print is needed.)
if args.train:
    # overwrite_existing=True presumably replaces logs of a previous run with
    # the same model_name -- confirm in ut.prepare_writer.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)

    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)
    # 100 generated samples reshaped to (100, 1, 28, 28), i.e. single-channel
    # 28x28 images (presumably used further down to save/show a sample grid).
    x = gmvae.sample_x(100).view(100, 1, 28, 28)
Exemplo n.º 6
0
# Remaining options. NOTE: `parser` and --z/--k/--iter_max/--iter_save are
# presumably registered earlier in the script.
parser.add_argument('--run',       type=int, default=0,     help="Run ID. In case you want to run replicates")
parser.add_argument('--train',     type=int, default=1,     help="Flag for training")
args = parser.parse_args()

# Run name assembled from hyper-parameters, e.g. "model=gmvae_z=10_k=500_run=0000".
layout = [
    ('model={:s}', 'gmvae'),
    ('z={:02d}', args.z),
    ('k={:03d}', args.k),
    ('run={:04d}', args.run),
]
model_name = '_'.join(spec.format(val) for spec, val in layout)
pprint(vars(args))
print('Model name:', model_name)

# Select the GPU if present.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(
    device,
    use_test_subset=True,
)
gmvae = GMVAE(name=model_name, z_dim=args.z, k=args.k).to(device)

# --train is an integer flag: non-zero trains here; 2 also requests the IWAE
# bound in the final evaluation.
if args.train:
    # overwrite_existing=True presumably replaces logs of a previous run with
    # the same model_name -- confirm in ut.prepare_writer.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)

else:
    # No training: load the saved model for step args.iter_max onto `device`.
    ut.load_model_by_name(gmvae, global_step=args.iter_max, device=device)