Example #1
0
parser.add_argument('--k',         type=int, default=500,   help="Number mixture components in MoG prior")
parser.add_argument('--iter_max',  type=int, default=20000, help="Number of training iterations")
parser.add_argument('--iter_save', type=int, default=10000, help="Save model every n iterations")
parser.add_argument('--run',       type=int, default=0,     help="Run ID. In case you want to run replicates")
args = parser.parse_args()
# Model name encodes the hyper-parameters so the checkpoint path is unique
# per configuration (and matches the name used when the model was trained).
layout = [
    ('model={:s}',  'gmvae'),
    ('z={:02d}',  args.z),
    ('k={:03d}',  args.k),
    ('run={:04d}', args.run)
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
gmvae = GMVAE(z_dim=args.z, k=args.k, name=model_name).to(device)

# Restore the fully trained checkpoint, report the ELBO, then draw samples.
ut.load_model_by_name(gmvae, global_step=args.iter_max)
ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=False)
# 200 samples arranged as a 10x20 grid of 28x28 MNIST digits.
samples = torch.reshape(gmvae.sample_x(200), (10, 20, 28, 28))

f, axarr = plt.subplots(10,20)

for i in range(samples.shape[0]):
    for j in range(samples.shape[1]):
        # .cpu() is required before .numpy(): when device is CUDA the
        # original .detach().numpy() raises a TypeError.
        axarr[i,j].imshow(samples[i,j].detach().cpu().numpy())
        axarr[i,j].axis('off')

plt.show()
Example #2
0
        train_loader=train_loader,
        # train_loader=data_loader_individual[0],
        labeled_subset=labeled_subset,
        device=device,
        tqdm=tqdm.tqdm,
        writer=writer,
        iter_max=10000,
        iter_save=args.iter_save)
    ut.evaluate_lower_bound(vae, labeled_subset, run_iwae=args.train == 2)

# Stage 2: load the base VAE and resample per-class prior parameters.
train_args = 2
# train_args = None
mean_set = []
variance_set = []
if train_args == 2:
    ut.load_model_by_name(vae, global_step=20000)
    # One (mean, variance) pair per digit class 0-9.
    para_set = [
        get_mean_variance(vae, data_set_individual[i]) for i in range(10)
    ]
    # NOTE: loop variable renamed from `set`, which shadowed the builtin.
    for params in para_set:
        temp_mean, temp_variance = ut.resample(10, params[0], params[1])
        mean_set.append(temp_mean)
        variance_set.append(temp_variance)

train_args = 3
# train_args = None
if train_args == 3:
    writer = ut.prepare_writer(model_name, overwrite_existing=False)
    refine(
        train_loader_set=data_loader_individual,
        # train_loader=data_loader_individual[0],
Example #3
0
File: run_gmvae.py  Project: wcAlex/CS236
    ('run={:04d}', args.run)
]
# Assemble the run name from the (format, value) layout pairs.
model_name = '_'.join(fmt.format(val) for fmt, val in layout)
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
gmvae = GMVAE(z_dim=args.z, k=args.k, name=model_name).to(device)

if not args.train:
    # Evaluation only: restore the checkpoint and report the IWAE bound.
    ut.load_model_by_name(gmvae, global_step=args.iter_max, device=device)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=True)
else:
    # Train from scratch, then report the bound (IWAE when train == 2).
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)

# draw digits
# ut.load_model_by_name(gmvae, global_step=args.iter_max, device=device)
# sample_digits = gmvae.sample_x(200)
# ut.plot_figures(sample_digits, 10, 20, 28, 28, 'q2_digits.png')
Example #4
0
parser.add_argument('--dag',       type=str, default="sup_dag",     help="Flag for toy")

args = parser.parse_args()
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")

# Indentation normalized: the original mixed tabs and spaces in this literal.
layout = [
    ('model={:s}',  'causalvae'),
    ('run={:04d}', args.run),
    # NOTE(review): this format string has no placeholder, so args.color is
    # ignored and the tag is always the literal 'color=True'. Kept as-is so
    # existing checkpoint names still resolve -- confirm before changing.
    ('color=True', args.color),
    ('toy={:s}', str(args.toy))
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
# NOTE(review): lvae is only bound on this branch; any other --dag value
# makes the print below raise NameError -- presumably only sup_dag is used.
if args.dag == "sup_dag":
    lvae = sup_dag.CausalVAE(name=model_name, z_dim=16, inference = True).to(device)
    ut.load_model_by_name(lvae, 0)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('./figs_test_vae_pendulum/', exist_ok=True)
means = torch.zeros(2,3,4).to(device)
z_mask = torch.zeros(2,3,4).to(device)

dataset_dir = './causal_data/pendulum'
train_dataset = get_batch_unin_dataset_withlabel(dataset_dir, 100, dataset="train")

count = 0
sample = False
print('DAG:{}'.format(lvae.dag.A))
for u,l in train_dataset:
  for i in range(4):
    for j in range(-5,5):
              gen_weight=args.gw,
              class_weight=args.cw,
              name=model_name,
              CNN=CNN).to(device)

# Hard-coded toggle: True trains a fresh model, False restores the checkpoint
# saved after args.iter_max steps.
Train = True
if Train:
    # overwrite_existing=True wipes any previous run under the same name.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=hkvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          y_status='hk',
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save,
          rec_step=args.rec_step,
          CNN=CNN)
else:
    ut.load_model_by_name(hkvae, args.iter_max)

# Evaluation scaffolding, currently disabled:
# pprint(vars(args))
# print('Model name:', model_name)
# print(hkvae.CNN)
# xl, yl = test_set
# yl = torch.tensor(np.eye(10)[yl]).float().to(device)
# test_set = (xl, yl)
# ut.evaluate_lower_bound_HK(hkvae, test_set)
# ut.evaluate_classifier_HK(hkvae, test_set)
Example #6
0
def run(args, verbose=False):
    """Build, train, or evaluate the (GM)VAE2 household-load model.

    Dispatches on ``args.mode``:
      - 'train': fit on the train split, validating against the val split;
      - 'val'/'test': restore the checkpoint and evaluate the lower bound;
      - 'plot': restore the checkpoint and render load/latent figures;
      - 'load': restore the checkpoint and just return the model.
    Returns the constructed model in every mode.
    """
    # Run name encodes the hyper-parameters so checkpoints/figures are keyed
    # uniquely per configuration.
    layout = [
        ('{:s}', "vae2"),
        ('{:s}', args.model),
        # ('x{:02d}',  24 if args.hourly==1 else 96),
        # ('z{:02d}',  args.z),
        ('k{:02d}', args.k),
        ('iw{:02d}', args.iw),
        ('vp{:02d}', args.var_pen),
        ('lr{:.4f}', args.lr),
        ('epo{:03d}', args.num_epochs),
        ('run{:02d}', args.run),
    ]
    model_name = '_'.join([t.format(v) for (t, v) in layout])
    if verbose: pprint(vars(args))
    print('Model name:', model_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data root depends on resolution: hourly (60-min) vs 15-min profiles.
    # cloud
    # root_dir = "../data/data15_final"
    # Oskar
    root_dir = "../data/CS236/data60/split" if (
        args.hourly == 1) else "../data/CS236/data15_final"
    # Will
    #root_dir = '/Users/willlauer/Desktop/latent_load_gen/data/split'

    # load train loader anyways - to get correct shift_scale values.
    train_loader = torch.utils.data.DataLoader(
        LoadDataset2(root_dir=root_dir,
                     mode='train',
                     shift_scale=None,
                     filter_ev=False,
                     smooth=args.smooth),
        batch_size=args.batch,
        shuffle=True,
    )
    # Normalization statistics computed on train; reused for val/test below.
    shift_scale = train_loader.dataset.shift_scale

    # k > 1 selects the mixture-of-Gaussians prior variant.
    if args.k > 1:
        model = GMVAE2(
            nn=args.model,
            z_dim=args.z,
            name=model_name,
            x_dim=24 if args.hourly == 1 else 96,
            warmup=(args.warmup == 1),
            var_pen=args.var_pen,
            k=args.k,
            y_dim=train_loader.dataset.dim_meta,
        ).to(device)
    else:
        model = VAE2(nn=args.model,
                     z_dim=args.z,
                     name=model_name,
                     x_dim=24 if args.hourly == 1 else 96,
                     y_dim=train_loader.dataset.dim_meta,
                     warmup=(args.warmup == 1),
                     var_pen=args.var_pen).to(device)

    if args.mode == 'train':
        # NOTE(review): val uses smooth=None while train uses args.smooth --
        # presumably deliberate (validate on the raw series); confirm.
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode='val',
            shift_scale=shift_scale,
            filter_ev=False,
            smooth=None,
        )
        # Whole validation split moved to device once, as plain tensors.
        val_set = {
            "x": torch.FloatTensor(split_set.other).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": None,
        }
        # maybe in future use Tensorboard?
        _ = ut.prepare_writer(model_name, overwrite_existing=True)
        train2(
            model=model,
            train_loader=train_loader,
            val_set=val_set,
            tqdm=tqdm.tqdm,
            # writer=writer,
            lr=args.lr,
            lr_gamma=args.lr_gamma,
            lr_milestone_every=args.lr_every,
            iw=args.iw,
            num_epochs=args.num_epochs)

    else:
        # Any non-train mode starts from the checkpoint at the final epoch.
        ut.load_model_by_name(model, global_step=args.num_epochs)

    if args.mode in ['val', 'test']:
        model.set_to_eval()
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode=args.mode,
            shift_scale=shift_scale,
            filter_ev=False,
            smooth=None,
        )
        val_set = {
            "x": torch.FloatTensor(split_set.other).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": None,
        }
        # Metric template filled in by the evaluation routine below.
        summaries = OrderedDict({
            'epoch': args.num_epochs,
            'loss': 0,
            'kl_z': 0,
            'rec_mse': 0,
            'rec_var': 0,
            'loss_type': 0,
            'lr': args.lr,
            'var_pen': model.var_pen,
        })

        ut.save_latent(model, val_set, mode=args.mode, is_car_model=False)

        ut.evaluate_lower_bound2(model,
                                 val_set,
                                 run_iwae=True,
                                 mode=args.mode,
                                 repeats=10,
                                 summaries=summaries)

    if args.mode == 'plot':

        # print(shift_scale["other"])
        # print(shift_scale)

        # De-normalize plots with the train-split shift/scale statistics.
        make_image_load(model, shift_scale["other"], (args.log_ev == 1))
        # make_image_load_day(model, shift_scale["other"], (args.log_ev==1))
        make_image_load_z(model, shift_scale["other"], (args.log_ev == 1))

    if args.mode == 'load':
        if verbose: print(model)
    return model
Example #7
0
                    type=int,
                    default=10000,
                    help="Save model every n iterations")
parser.add_argument('--run', type=int, default=0,
                    help="Run ID. In case you want to run replicates")
args = parser.parse_args()

# Run name assembled from the (format, value) layout pairs.
layout = [
    ('model={:s}', 'fsvae'),
    ('run={:04d}', args.run),
]
model_name = '_'.join(fmt.format(val) for fmt, val in layout)
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, test_set = ut.get_svhn_data(device)
fsvae = FSVAE(name=model_name).to(device)

# Training disabled; the pre-trained 60k-step checkpoint is used below.
# writer = ut.prepare_writer(model_name, overwrite_existing=True)
# train(model=fsvae,
#       train_loader=train_loader,
#       labeled_subset=labeled_subset,
#       device=device,
#       y_status='fullsup',
#       tqdm=tqdm.tqdm,
#       writer=writer,
#       iter_max=args.iter_max,
#       iter_save=args.iter_save)

ut.load_model_by_name(fsvae, global_step=60000, device=device)
ut.plot_grid_fsvae(fsvae)
def refine(train_loader_set,
           mean_set,
           variance_set,
           z_dim,
           device,
           tqdm,
           writer,
           iter_max=np.inf,
           iter_save=np.inf,
           model_name='model',
           y_status='none',
           reinitialize=False):
    """Fine-tune the VAE encoder against each per-class prior in turn.

    For every loader in ``train_loader_set`` a fresh VAE is built whose
    latent prior is the matching frozen (mean, variance) pair from
    ``mean_set``/``variance_set``; the model is warm-started from the last
    checkpoint and trained with the encoder-only loss.  Returns after
    ``iter_max`` total gradient steps (the final model is saved at step 0).
    """
    i = 0  # total gradient steps taken across all loaders
    with tqdm(total=iter_max) as pbar:
        while True:
            for index, train_loader in enumerate(train_loader_set):
                print("Iteration:", i)
                print("index: ", index)

                # Frozen per-class prior parameters for this loader.
                z_prior_m = torch.nn.Parameter(mean_set[index].cpu(),
                                               requires_grad=False).to(device)
                z_prior_v = torch.nn.Parameter(variance_set[index].cpu(),
                                               requires_grad=False).to(device)
                vae = VAE(z_dim=z_dim,
                          name=model_name,
                          z_prior_m=z_prior_m,
                          z_prior_v=z_prior_v).to(device)
                optimizer = optim.Adam(vae.parameters(), lr=1e-3)
                if i == 0:
                    # First pass: warm-start from the fully trained base model.
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=20000)
                else:
                    # Later passes: resume from the checkpoint written below.
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=iter_save)
                for batch_idx, (xu, yu) in enumerate(train_loader):
                    # i is num of gradient steps taken by end of loop iteration
                    optimizer.zero_grad()

                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    # Dead code removed: the original built a one-hot yu here,
                    # but loss_encoder() only consumes xu.
                    loss, summaries = vae.loss_encoder(xu)

                    loss.backward()
                    optimizer.step()

                    # Feel free to modify the progress bar
                    pbar.set_postfix(loss='{:.2e}'.format(loss))
                    pbar.update(1)

                    i += 1
                    # Log summaries
                    if i % 50 == 0: ut.log_summaries(writer, summaries, i)

                    if i == iter_max:
                        ut.save_model_by_name(vae, 0)
                        return

                # Save model so the next loader's pass resumes from here.
                ut.save_model_by_name(vae, iter_save)
Example #9
0
parser.add_argument('--train', type=int, default=1, help="Flag for training")
args = parser.parse_args()

# Run name built from the (format, value) layout pairs.
layout = [
    ('model={:s}', 'ssvae'),
    ('gw={:03d}', args.gw),
    ('cw={:03d}', args.cw),
    ('run={:04d}', args.run),
]
model_name = '_'.join(fmt.format(val) for fmt, val in layout)
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, test_set = ut.get_mnist_data(
    device, use_test_subset=False)
ssvae = SSVAE(gen_weight=args.gw, class_weight=args.cw,
              name=model_name).to(device)

if not args.train:
    # Evaluate an existing checkpoint instead of training.
    ut.load_model_by_name(ssvae, args.iter_max, device=device)
else:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=ssvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          y_status='semisup',
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)

# Report classification accuracy on the held-out test set either way.
ut.evaluate_classifier(ssvae, test_set)
Example #10
0
def run(args, verbose=False):
    """Build, train, or evaluate the conditional EV-charging ('car') model.

    Same dispatch as the base vae2 runner, but the car model conditions on
    the latent code of a pre-trained household-use VAE2 ('use_model'):
      - 'train': fit on the train split (use_model optionally frozen);
      - 'val'/'test': restore the checkpoint and evaluate the lower bound;
      - 'plot': restore the checkpoint and render figures;
      - 'load': restore the checkpoint and just return the model.
    Returns the constructed model in every mode.
    """
    # Run name encodes the hyper-parameters; 'car' prefix separates these
    # checkpoints from the base vae2 runs.
    layout = [
        ('{:s}', "vae2"),
        ('{:s}', args.model),
        # ('x{:02d}',  24 if args.hourly==1 else 96),
        # ('z{:02d}',  args.z),
        ('k{:02d}', args.k),
        ('iw{:02d}', args.iw),
        ('vp{:02d}', args.var_pen),
        ('lr{:.4f}', args.lr),
        ('epo{:03d}', args.num_epochs),
        ('run{:02d}', args.run)
    ]
    model_name = 'car' + '_'.join([t.format(v) for (t, v) in layout])
    if verbose: pprint(vars(args))
    print('Model name:', model_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data root depends on resolution: hourly (60-min) vs 15-min profiles.
    # cloud
    # root_dir = "../data/data15_final"
    # Oskar
    root_dir = "../data/CS236/data60/split" if (
        args.hourly == 1) else "../data/CS236/data15_final"

    # load train loader anyways - to get correct shift_scale values.
    train_loader = torch.utils.data.DataLoader(
        LoadDataset2(root_dir=root_dir,
                     mode='train',
                     shift_scale=None,
                     filter_ev=False,
                     log_car=(args.log_ev == 1),
                     smooth=args.smooth),
        batch_size=args.batch,
        shuffle=True,
    )
    # Normalization statistics computed on train; reused for val/test below.
    shift_scale = train_loader.dataset.shift_scale

    # Load the pre-trained household-use VAE2 whose latent code conditions
    # the car model; configuration is hardcoded to a specific trained run.
    use_model = run_vae2.main({
        "mode": 'load',
        "model": 'ff-s-dec',  # hardcode
        "lr": 0.01,  # hardcode
        "k": 1,  # hardcode
        "iw": 0,  # hardcode
        "num_epochs": 20,  # hardcode
        "var_pen": 1,  # hardcode
        "run": 1,  # hardcode
    })

    # k > 1 selects the mixture-of-Gaussians prior variant.
    if args.k > 1:
        # NOTE(review): stray debug print -- consider removing.
        print('36')
        model = GMVAE2CAR(
            nn=args.model,
            name=model_name,
            z_dim=args.z,
            x_dim=24 if args.hourly == 1 else 96,
            c_dim=use_model.z_dim,
            warmup=(args.warmup == 1),
            var_pen=args.var_pen,
            use_model=use_model,
            k=args.k,
            y_dim=train_loader.dataset.dim_meta,
        ).to(device)
    else:
        model = VAE2CAR(
            nn=args.model,
            name=model_name,
            z_dim=args.z,
            x_dim=24 if args.hourly == 1 else 96,
            c_dim=use_model.z_dim,
            warmup=(args.warmup == 1),
            var_pen=args.var_pen,
            use_model=use_model,
            y_dim=train_loader.dataset.dim_meta,
        ).to(device)

    if args.mode == 'train':
        # NOTE(review): val uses smooth=None while train uses args.smooth --
        # presumably deliberate (validate on the raw series); confirm.
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode='val',
            shift_scale=shift_scale,
            filter_ev=False,
            log_car=(args.log_ev == 1),
            smooth=None,
        )
        # x is the car (EV) load; c conditions on the household load.
        val_set = {
            "x": torch.FloatTensor(split_set.car).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": torch.FloatTensor(split_set.other).to(device),
        }
        _ = ut.prepare_writer(model_name, overwrite_existing=True)

        # make sure not to train the first VAE
        if not (args.finetune == 1):
            for p in model.use_model.parameters():
                p.requires_grad = False

        train2(
            model=model,
            train_loader=train_loader,
            val_set=val_set,
            tqdm=tqdm.tqdm,
            lr=args.lr,
            lr_gamma=args.lr_gamma,
            lr_milestone_every=args.lr_every,
            iw=args.iw,
            num_epochs=args.num_epochs,
            is_car_model=True,
        )

    else:
        # Any non-train mode starts from the checkpoint at the final epoch.
        ut.load_model_by_name(model, global_step=args.num_epochs)

    if args.mode in ['val', 'test']:
        model.set_to_eval()
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode=args.mode,
            shift_scale=shift_scale,
            filter_ev=False,
            log_car=(args.log_ev == 1),
            smooth=None,
        )
        val_set = {
            "x": torch.FloatTensor(split_set.car).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": torch.FloatTensor(split_set.other).to(device),
        }
        # Metric template; deep-copied below so the evaluation cannot
        # mutate this original.
        summaries = OrderedDict({
            'epoch': args.num_epochs,
            'loss': 0,
            'kl_z': 0,
            'rec_mse': 0,
            'rec_var': 0,
            'loss_type': 0,
            'lr': args.lr,
            'var_pen': model.var_pen,
        })

        ut.save_latent(model, val_set, mode=args.mode, is_car_model=True)

        ut.evaluate_lower_bound2(model,
                                 val_set,
                                 run_iwae=True,
                                 mode=args.mode,
                                 repeats=10,
                                 summaries=copy.deepcopy(summaries))

    if args.mode == 'plot':
        # De-normalize plots with the train-split shift/scale statistics.
        make_image_load(model, shift_scale["car"], (args.log_ev == 1))
        # make_image_load_day(model, shift_scale["car"], (args.log_ev==1))
        make_image_load_z(model, shift_scale["car"], (args.log_ev == 1))
        make_image_load_z_use(model, shift_scale["car"], (args.log_ev == 1))

    if args.mode == 'load':
        if verbose: print(model)
    return model