# --- GMVAE evaluation / sampling script (argparse setup begins outside this chunk) ---
parser.add_argument('--k', type=int, default=500,
                    help="Number mixture components in MoG prior")
parser.add_argument('--iter_max', type=int, default=20000,
                    help="Number of training iterations")
parser.add_argument('--iter_save', type=int, default=10000,
                    help="Save model every n iterations")
parser.add_argument('--run', type=int, default=0,
                    help="Run ID. In case you want to run replicates")
args = parser.parse_args()

# Encode the hyper-parameters into the checkpoint/model name.
layout = [
    ('model={:s}', 'gmvae'),
    ('z={:02d}', args.z),
    ('k={:03d}', args.k),
    ('run={:04d}', args.run)
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
gmvae = GMVAE(z_dim=args.z, k=args.k, name=model_name).to(device)

# Load the final checkpoint and report the (non-IWAE) lower bound.
ut.load_model_by_name(gmvae, global_step=args.iter_max)
ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=False)

# Draw 200 samples and display them as a 10x20 grid of 28x28 images.
samples = torch.reshape(gmvae.sample_x(200), (10, 20, 28, 28))
f, axarr = plt.subplots(10, 20)
for i in range(samples.shape[0]):
    for j in range(samples.shape[1]):
        # BUG FIX: move the tensor to CPU before .numpy() — conversion fails
        # for CUDA tensors, and .cpu() is a no-op when already on the CPU.
        axarr[i, j].imshow(samples[i, j].detach().cpu().numpy())
        axarr[i, j].axis('off')
plt.show()
train_loader=train_loader, # train_loader=data_loader_individual[0], labeled_subset=labeled_subset, device=device, tqdm=tqdm.tqdm, writer=writer, iter_max=10000, iter_save=args.iter_save) ut.evaluate_lower_bound(vae, labeled_subset, run_iwae=args.train == 2) train_args = 2 # train_args = None mean_set = [] variance_set = [] if train_args == 2: ut.load_model_by_name(vae, global_step=20000) para_set = [ get_mean_variance(vae, data_set_individual[i]) for i in range(10) ] for i, set in enumerate(para_set): temp_mean, temp_variance = ut.resample(10, set[0], set[1]) mean_set.append(temp_mean) variance_set.append(temp_variance) train_args = 3 # train_args = None if train_args == 3: writer = ut.prepare_writer(model_name, overwrite_existing=False) refine( train_loader_set=data_loader_individual, # train_loader=data_loader_individual[0],
# --- tail of a `layout` list literal started outside this chunk ---
('run={:04d}', args.run)
]
# Build the checkpoint/model name from the layout fields.
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, _ = ut.get_mnist_data(device, use_test_subset=True)
gmvae = GMVAE(z_dim=args.z, k=args.k, name=model_name).to(device)
if args.train:
    # Fresh training run: overwrite any existing logs for this model name.
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=gmvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
    # IWAE evaluation only when launched with --train 2.
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=args.train == 2)
else:
    # Evaluation only: load the final checkpoint and report the IWAE bound.
    ut.load_model_by_name(gmvae, global_step=args.iter_max, device=device)
    ut.evaluate_lower_bound(gmvae, labeled_subset, run_iwae=True)

# draw digits
# ut.load_model_by_name(gmvae, global_step=args.iter_max, device=device)
# sample_digits = gmvae.sample_x(200)
# #
# ut.plot_figures(sample_digits, 10, 20, 28, 28, 'q2_digits.png')
parser.add_argument('--dag', type=str, default="sup_dag", help="Flag for toy")
args = parser.parse_args()
device = torch.device("cuda:0" if(torch.cuda.is_available()) else "cpu")
layout = [
    ('model={:s}', 'causalvae'),
    ('run={:04d}', args.run),
    # NOTE(review): this template has no format placeholder, so .format()
    # ignores args.color and the field always renders as 'color=True'.
    ('color=True', args.color),
    ('toy={:s}', str(args.toy))
]
model_name = '_'.join([t.format(v) for (t, v) in layout])
if args.dag == "sup_dag":
    # Inference-mode CausalVAE with a supervised DAG over a 16-dim latent,
    # loaded from the step-0 checkpoint.
    lvae = sup_dag.CausalVAE(name=model_name, z_dim=16, inference = True).to(device)
    ut.load_model_by_name(lvae, 0)
# NOTE(review): `lvae` is only bound when args.dag == "sup_dag"; the print
# below raises NameError for any other --dag value — confirm intended.
if not os.path.exists('./figs_test_vae_pendulum/'):
    os.makedirs('./figs_test_vae_pendulum/')
means = torch.zeros(2,3,4).to(device)
z_mask = torch.zeros(2,3,4).to(device)
dataset_dir = './causal_data/pendulum'
# Batches of 100 labelled pendulum samples from the training split.
train_dataset = get_batch_unin_dataset_withlabel(dataset_dir, 100,dataset="train")
count = 0
sample = False
print('DAG:{}'.format(lvae.dag.A))
# Sweep each of 4 latent concepts over intervention values in [-5, 5).
# (loop body continues beyond the end of this chunk)
for u,l in train_dataset:
    for i in range(4):
        for j in range(-5,5):
# --- continuation of a model-constructor call whose opening is outside this chunk ---
gen_weight=args.gw,
class_weight=args.cw,
name=model_name,
CNN=CNN).to(device)

# NOTE(review): hard-coded flag overrides any command-line training switch,
# making the else branch below unreachable — confirm intended.
Train = True
if Train:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=hkvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          y_status='hk',
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save,
          rec_step=args.rec_step,
          CNN=CNN)
else:
    # Evaluation path: load the final checkpoint instead of training.
    ut.load_model_by_name(hkvae, args.iter_max)

# pprint(vars(args))
# print('Model name:', model_name)
# print(hkvae.CNN)
# xl, yl = test_set
# yl = torch.tensor(np.eye(10)[yl]).float().to(device)
# test_set = (xl, yl)
# ut.evaluate_lower_bound_HK(hkvae, test_set)
# ut.evaluate_classifier_HK(hkvae, test_set)
def run(args, verbose=False):
    """Build, train, evaluate, plot, or load a (GM)VAE2 load-profile model.

    args.mode selects the action ('train', 'val'/'test', 'plot', 'load');
    args.k > 1 selects the mixture model GMVAE2, otherwise VAE2.
    Returns the constructed model (checkpoint loaded for non-train modes).
    """
    # Encode the hyper-parameters in the checkpoint/model name.
    layout = [
        ('{:s}', "vae2"),
        ('{:s}', args.model),
        # ('x{:02d}', 24 if args.hourly==1 else 96),
        # ('z{:02d}', args.z),
        ('k{:02d}', args.k),
        ('iw{:02d}', args.iw),
        ('vp{:02d}', args.var_pen),
        ('lr{:.4f}', args.lr),
        ('epo{:03d}', args.num_epochs),
        ('run{:02d}', args.run),
    ]
    model_name = '_'.join([t.format(v) for (t, v) in layout])
    if verbose:
        pprint(vars(args))
        print('Model name:', model_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # cloud
    # root_dir = "../data/data15_final"  # Oskar
    # hourly==1 -> 60-minute data (24 points/day); else 15-minute data.
    root_dir = "../data/CS236/data60/split" if (
        args.hourly == 1) else "../data/CS236/data15_final"  # Will
    #root_dir = '/Users/willlauer/Desktop/latent_load_gen/data/split'

    # load train loader anyways - to get correct shift_scale values.
    train_loader = torch.utils.data.DataLoader(
        LoadDataset2(root_dir=root_dir,
                     mode='train',
                     shift_scale=None,
                     filter_ev=False,
                     smooth=args.smooth),
        batch_size=args.batch,
        shuffle=True,
    )
    shift_scale = train_loader.dataset.shift_scale
    if args.k > 1:
        # Mixture-prior variant.
        model = GMVAE2(
            nn=args.model,
            z_dim=args.z,
            name=model_name,
            x_dim=24 if args.hourly == 1 else 96,
            warmup=(args.warmup == 1),
            var_pen=args.var_pen,
            k=args.k,
            y_dim=train_loader.dataset.dim_meta,
        ).to(device)
    else:
        model = VAE2(nn=args.model,
                     z_dim=args.z,
                     name=model_name,
                     x_dim=24 if args.hourly == 1 else 96,
                     y_dim=train_loader.dataset.dim_meta,
                     warmup=(args.warmup == 1),
                     var_pen=args.var_pen).to(device)
    if args.mode == 'train':
        # Validation split, normalised with the training shift/scale values.
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode='val',
            shift_scale=shift_scale,
            filter_ev=False,
            smooth=None,
        )
        val_set = {
            "x": torch.FloatTensor(split_set.other).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": None,  # no conditioning input for the non-car model
        }
        # maybe in future use Tensorboard?
        _ = ut.prepare_writer(model_name, overwrite_existing=True)
        train2(
            model=model,
            train_loader=train_loader,
            val_set=val_set,
            tqdm=tqdm.tqdm,
            # writer=writer,
            lr=args.lr,
            lr_gamma=args.lr_gamma,
            lr_milestone_every=args.lr_every,
            iw=args.iw,
            num_epochs=args.num_epochs)
    else:
        # Any non-train mode starts from the final checkpoint.
        ut.load_model_by_name(model, global_step=args.num_epochs)
    if args.mode in ['val', 'test']:
        model.set_to_eval()
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode=args.mode,
            shift_scale=shift_scale,
            filter_ev=False,
            smooth=None,
        )
        val_set = {
            "x": torch.FloatTensor(split_set.other).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": None,
        }
        # Accumulator passed through evaluation; metric fields start at 0.
        summaries = OrderedDict({
            'epoch': args.num_epochs,
            'loss': 0,
            'kl_z': 0,
            'rec_mse': 0,
            'rec_var': 0,
            'loss_type': 0,
            'lr': args.lr,
            'var_pen': model.var_pen,
        })
        ut.save_latent(model, val_set, mode=args.mode, is_car_model=False)
        ut.evaluate_lower_bound2(model, val_set, run_iwae=True,
                                 mode=args.mode, repeats=10,
                                 summaries=summaries)
    if args.mode == 'plot':
        # print(shift_scale["other"])
        # print(shift_scale)
        make_image_load(model, shift_scale["other"], (args.log_ev == 1))
        # make_image_load_day(model, shift_scale["other"], (args.log_ev==1))
        make_image_load_z(model, shift_scale["other"], (args.log_ev == 1))
    if args.mode == 'load':
        if verbose:
            print(model)
    return model
# --- tail of a parser.add_argument('--iter_save', ...) call started outside this chunk ---
type=int,
default=10000,
help="Save model every n iterations")
parser.add_argument('--run', type=int, default=0,
                    help="Run ID. In case you want to run replicates")
args = parser.parse_args()
# Checkpoint/model name, e.g. model=fsvae_run=0000.
layout = [('model={:s}', 'fsvae'), ('run={:04d}', args.run)]
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader, labeled_subset, test_set = ut.get_svhn_data(device)
fsvae = FSVAE(name=model_name).to(device)
# Training is disabled; this script only loads a checkpoint and plots.
# writer = ut.prepare_writer(model_name, overwrite_existing=True)
# train(model=fsvae,
#       train_loader=train_loader,
#       labeled_subset=labeled_subset,
#       device=device,
#       y_status='fullsup',
#       tqdm=tqdm.tqdm,
#       writer=writer,
#       iter_max=args.iter_max,
#       iter_save=args.iter_save)
# NOTE(review): global_step=60000 is hard-coded rather than args.iter_max —
# confirm this matches the saved checkpoint.
ut.load_model_by_name(fsvae, global_step=60000, device=device)
ut.plot_grid_fsvae(fsvae)
def refine(train_loader_set,
           mean_set,
           variance_set,
           z_dim,
           device,
           tqdm,
           writer,
           iter_max=np.inf,
           iter_save=np.inf,
           model_name='model',
           y_status='none',
           reinitialize=False):
    """Fine-tune a VAE encoder against fixed per-class latent priors.

    Cycles over `train_loader_set` (one DataLoader per class); for class
    `index` a VAE is built whose prior is frozen to (mean_set[index],
    variance_set[index]) and only `loss_encoder` is optimised. Weights are
    handed from one class to the next through the `iter_save` checkpoint.

    Args:
        train_loader_set: sequence of DataLoaders, one per class.
        mean_set / variance_set: per-class latent prior parameters.
        z_dim: latent dimensionality of the VAE.
        device: torch device for the model and data.
        tqdm: progress-bar constructor (e.g. tqdm.tqdm).
        writer: summary writer passed to ut.log_summaries.
        iter_max: total gradient steps; on reaching it the model is saved
            under global step 0 and the function returns.
        iter_save: checkpoint id used to pass weights between classes.
        model_name: checkpoint name used for save/load.
        y_status, reinitialize: accepted for interface compatibility; unused.
    """
    i = 0  # total gradient steps taken across all loaders
    with tqdm(total=iter_max) as pbar:
        while True:  # keep cycling over the classes until iter_max steps
            for index, train_loader in enumerate(train_loader_set):
                print("Iteration:", i)
                print("index: ", index)
                # Freeze this class's prior; requires_grad=False keeps the
                # prior parameters out of the optimiser.
                z_prior_m = torch.nn.Parameter(mean_set[index].cpu(),
                                               requires_grad=False).to(device)
                z_prior_v = torch.nn.Parameter(variance_set[index].cpu(),
                                               requires_grad=False).to(device)
                vae = VAE(z_dim=z_dim, name=model_name,
                          z_prior_m=z_prior_m, z_prior_v=z_prior_v).to(device)
                optimizer = optim.Adam(vae.parameters(), lr=1e-3)
                if i == 0:
                    # First pass: start from the pre-trained 20k-step model.
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=20000)
                else:
                    # Later passes: resume from the checkpoint saved below.
                    print("Load model")
                    ut.load_model_by_name(vae, global_step=iter_save)
                # Labels are unused by loss_encoder, so they are discarded
                # here (the original built an unused one-hot encoding of the
                # labels on every batch — dead work, now removed).
                for xu, _ in train_loader:
                    optimizer.zero_grad()
                    # Flatten to (batch, pixels) and binarise the images.
                    xu = torch.bernoulli(xu.to(device).reshape(xu.size(0), -1))
                    loss, summaries = vae.loss_encoder(xu)
                    loss.backward()
                    optimizer.step()
                    pbar.set_postfix(loss='{:.2e}'.format(loss))
                    pbar.update(1)
                    i += 1
                    # Log summaries
                    if i % 50 == 0:
                        ut.log_summaries(writer, summaries, i)
                    if i == iter_max:
                        # Final model is saved under global step 0.
                        ut.save_model_by_name(vae, 0)
                        return
                # Save model so the next class resumes from these weights.
                ut.save_model_by_name(vae, iter_save)
parser.add_argument('--train', type=int, default=1, help="Flag for training")
args = parser.parse_args()
# Checkpoint/model name built from the generative/classification weights and run id.
layout = [('model={:s}', 'ssvae'),
          ('gw={:03d}', args.gw),
          ('cw={:03d}', args.cw),
          ('run={:04d}', args.run)]
model_name = '_'.join([t.format(v) for (t, v) in layout])
pprint(vars(args))
print('Model name:', model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# use_test_subset=False — presumably returns the full test set for
# classifier evaluation; verify against ut.get_mnist_data.
train_loader, labeled_subset, test_set = ut.get_mnist_data(
    device, use_test_subset=False)
ssvae = SSVAE(gen_weight=args.gw,
              class_weight=args.cw,
              name=model_name).to(device)
if args.train:
    writer = ut.prepare_writer(model_name, overwrite_existing=True)
    train(model=ssvae,
          train_loader=train_loader,
          labeled_subset=labeled_subset,
          device=device,
          y_status='semisup',
          tqdm=tqdm.tqdm,
          writer=writer,
          iter_max=args.iter_max,
          iter_save=args.iter_save)
else:
    # Evaluation only: load the final checkpoint and score the classifier.
    # NOTE(review): indentation reconstructed — evaluate_classifier placed in
    # the else branch; confirm it should not also run after training.
    ut.load_model_by_name(ssvae, args.iter_max, device=device)
    ut.evaluate_classifier(ssvae, test_set)
def run(args, verbose=False):
    """Build, train, evaluate, plot, or load a (GM)VAE2CAR model.

    Mirrors the non-car `run`, but conditions each sample on the latent of a
    pre-trained use-model loaded via run_vae2.main (the "c" entry of the data
    dicts). args.mode selects the action; args.k > 1 selects GMVAE2CAR.
    Returns the constructed model.
    """
    layout = [
        ('{:s}', "vae2"),
        ('{:s}', args.model),
        # ('x{:02d}', 24 if args.hourly==1 else 96),
        # ('z{:02d}', args.z),
        ('k{:02d}', args.k),
        ('iw{:02d}', args.iw),
        ('vp{:02d}', args.var_pen),
        ('lr{:.4f}', args.lr),
        ('epo{:03d}', args.num_epochs),
        ('run{:02d}', args.run)
    ]
    # 'car' prefix keeps these checkpoints distinct from the use-model's.
    model_name = 'car' + '_'.join([t.format(v) for (t, v) in layout])
    if verbose:
        pprint(vars(args))
        print('Model name:', model_name)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # cloud
    # root_dir = "../data/data15_final"  # Oskar
    # hourly==1 -> 60-minute data; else 15-minute data.
    root_dir = "../data/CS236/data60/split" if (
        args.hourly == 1) else "../data/CS236/data15_final"

    # load train loader anyways - to get correct shift_scale values.
    train_loader = torch.utils.data.DataLoader(
        LoadDataset2(root_dir=root_dir,
                     mode='train',
                     shift_scale=None,
                     filter_ev=False,
                     log_car=(args.log_ev == 1),
                     smooth=args.smooth),
        batch_size=args.batch,
        shuffle=True,
    )
    shift_scale = train_loader.dataset.shift_scale

    # load use-model
    # NOTE(review): these hard-coded hyper-parameters must match the
    # checkpoint name of the pre-trained use-model — confirm before changing.
    use_model = run_vae2.main({
        "mode": 'load',
        "model": 'ff-s-dec',  # hardcode
        "lr": 0.01,  # hardcode
        "k": 1,  # hardcode
        "iw": 0,  # hardcode
        "num_epochs": 20,  # hardcode
        "var_pen": 1,  # hardcode
        "run": 1,  # hardcode
    })
    if args.k > 1:
        print('36')
        model = GMVAE2CAR(
            nn=args.model,
            name=model_name,
            z_dim=args.z,
            x_dim=24 if args.hourly == 1 else 96,
            c_dim=use_model.z_dim,  # conditioning dim = use-model latent dim
            warmup=(args.warmup == 1),
            var_pen=args.var_pen,
            use_model=use_model,
            k=args.k,
            y_dim=train_loader.dataset.dim_meta,
        ).to(device)
    else:
        model = VAE2CAR(
            nn=args.model,
            name=model_name,
            z_dim=args.z,
            x_dim=24 if args.hourly == 1 else 96,
            c_dim=use_model.z_dim,
            warmup=(args.warmup == 1),
            var_pen=args.var_pen,
            use_model=use_model,
            y_dim=train_loader.dataset.dim_meta,
        ).to(device)
    if args.mode == 'train':
        # Validation split, normalised with the training shift/scale values.
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode='val',
            shift_scale=shift_scale,
            filter_ev=False,
            log_car=(args.log_ev == 1),
            smooth=None,
        )
        # x: split_set.car; c: split_set.other (conditioning input) — see LoadDataset2.
        val_set = {
            "x": torch.FloatTensor(split_set.car).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": torch.FloatTensor(split_set.other).to(device),
        }
        _ = ut.prepare_writer(model_name, overwrite_existing=True)
        # make sure not to train the first VAE
        if not (args.finetune == 1):
            for p in model.use_model.parameters():
                p.requires_grad = False
        train2(
            model=model,
            train_loader=train_loader,
            val_set=val_set,
            tqdm=tqdm.tqdm,
            lr=args.lr,
            lr_gamma=args.lr_gamma,
            lr_milestone_every=args.lr_every,
            iw=args.iw,
            num_epochs=args.num_epochs,
            is_car_model=True,
        )
    else:
        # Any non-train mode starts from the final checkpoint.
        ut.load_model_by_name(model, global_step=args.num_epochs)
    if args.mode in ['val', 'test']:
        model.set_to_eval()
        split_set = LoadDataset2(
            root_dir=root_dir,
            mode=args.mode,
            shift_scale=shift_scale,
            filter_ev=False,
            log_car=(args.log_ev == 1),
            smooth=None,
        )
        val_set = {
            "x": torch.FloatTensor(split_set.car).to(device),
            "y": torch.FloatTensor(split_set.meta).to(device),
            "c": torch.FloatTensor(split_set.other).to(device),
        }
        # Accumulator passed through evaluation; metric fields start at 0.
        summaries = OrderedDict({
            'epoch': args.num_epochs,
            'loss': 0,
            'kl_z': 0,
            'rec_mse': 0,
            'rec_var': 0,
            'loss_type': 0,
            'lr': args.lr,
            'var_pen': model.var_pen,
        })
        ut.save_latent(model, val_set, mode=args.mode, is_car_model=True)
        # deepcopy so evaluation cannot mutate the template above.
        ut.evaluate_lower_bound2(model, val_set, run_iwae=True,
                                 mode=args.mode, repeats=10,
                                 summaries=copy.deepcopy(summaries))
    if args.mode == 'plot':
        make_image_load(model, shift_scale["car"], (args.log_ev == 1))
        # make_image_load_day(model, shift_scale["car"], (args.log_ev==1))
        make_image_load_z(model, shift_scale["car"], (args.log_ev == 1))
        make_image_load_z_use(model, shift_scale["car"], (args.log_ev == 1))
    if args.mode == 'load':
        if verbose:
            print(model)
    return model