def fit_gaia_lim_sgd(datafile, use_cuda=False, *, epochs=40, batch_size=512,
                     lr=1e-4):
    """Fit an SVIFlow model on a pre-split data file and print validation log prob.

    The file is opened with ``np.load`` and must provide the keys
    ``X_train``, ``L_train``, ``X_val`` and ``L_val``. Each ``X``/``L`` pair is
    converted with ``torch.Tensor`` and wrapped in a ``DeconvDataset``.

    Args:
        datafile: Anything ``np.load`` accepts; presumably an ``.npz`` archive,
            given the keyed access (TODO confirm with callers).
        use_cuda: Train on ``torch.device('cuda')`` when True, else on CPU.
        epochs: Number of training epochs passed to ``SVIFlow`` (default 40,
            the previously hard-coded value).
        batch_size: Batch size passed to ``SVIFlow`` (default 512).
        lr: Learning rate passed to ``SVIFlow`` (default 1e-4).

    Returns:
        None. The validation score from ``svi.score_batch(val_data,
        log_prob=True)``, divided by ``len(val_data)``, is printed.
    """
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Use the archive as a context manager so its file handle is released as
    # soon as the arrays have been converted, instead of staying open for the
    # entire training run (the original never closed it).
    with np.load(datafile) as data:
        train_data = DeconvDataset(torch.Tensor(data['X_train']),
                                   torch.Tensor(data['L_train']))
        val_data = DeconvDataset(torch.Tensor(data['X_val']),
                                 torch.Tensor(data['L_val']))

    # NOTE(review): 7 is presumably the data dimensionality (should equal
    # X_train.shape[1]) and 5 is forwarded unchanged; confirm against SVIFlow.
    svi = SVIFlow(7, 5, device=device, batch_size=batch_size, epochs=epochs,
                  lr=lr)
    svi.fit(train_data, val_data=val_data)

    val_log_prob = svi.score_batch(val_data, log_prob=True)
    print('Val log prob: {}'.format(val_log_prob / len(val_data)))
X_train = X_noisy[:(2 * N), :] X_test = X_noisy[(2 * N):, :] nc_train = S[:(2 * N), :, :] nc_test = S[(2 * N):, :, :] train_data = DeconvDataset( torch.Tensor(X_train.reshape(-1, D).astype(np.float32)), torch.Tensor(nc_train.reshape(-1, D, D).astype(np.float32))) test_data = DeconvDataset( torch.Tensor(X_test.reshape(-1, D).astype(np.float32)), torch.Tensor(nc_test.reshape(-1, D, D).astype(np.float32))) svi = SVIFlow(D, 5, device=device, batch_size=512, epochs=50, lr=1e-4) svi.fit(train_data, val_data=None) test_log_prob = svi.score_batch(test_data, log_prob=True) print('Test log prob: {}'.format(test_log_prob / len(test_data))) gmm = SGDDeconvGMM(K, D, device=device, batch_size=256, epochs=50, lr=1e-1) gmm.fit(train_data, val_data=test_data, verbose=True) test_log_prob = gmm.score_batch(test_data) print('Test log prob: {}'.format(test_log_prob / len(test_data))) if plot: x_width = 200 y_width = 200
elbo_params = [] for f in os.listdir(args.elbo_results_dir): path = os.path.join(args.elbo_results_dir, f) elbo_params.append(path) iw_params = [] for f in os.listdir(args.iw_results_dir): path = os.path.join(args.iw_results_dir, f) iw_params.append(path) svi = SVIFlow(2, 5, device=torch.device('cuda'), batch_size=512, epochs=100, lr=1e-4, n_samples=50, use_iwae=False, context_size=64, hidden_features=128) results = [] test_data = DeconvDataset(x_test.squeeze(), torch.cholesky(S.repeat(N, 1, 1))) torch.set_default_tensor_type(torch.cuda.FloatTensor) for p in pretrained_params: svi.model.load_state_dict(torch.load(p)) with torch.no_grad(): logv = svi.model._prior.log_prob(z_test[0].to(
parser.add_argument('pretrained_results_dir') parser.add_argument('posttrained_results_dir') parser.add_argument('svi_gmm_results_dir') parser.add_argument('svi_exact_gmm_results_dir') args = parser.parse_args() K = 3 D = 2 N = 10000 svi = SVIFlow(2, 5, device=torch.device('cuda'), batch_size=4096, epochs=100, lr=1e-4, n_samples=50, use_iwae=False, context_size=64) svi_gmm = SVIGMMFlow(2, 5, device=torch.device('cuda'), batch_size=512, epochs=100, lr=1e-4, n_samples=50, use_iwae=False, context_size=64, hidden_features=128) svi_exact_gmm = SVIGMMExact(2,
val_data = DeconvDataset(x_val.squeeze(), S.repeat(N_val, 1, 1)) gmm = SGDDeconvGMM( K, D, batch_size=200, epochs=args.epochs, lr=args.learning_rate, device=torch.device('cuda') ) gmm.fit(train_data, val_data=val_data, verbose=True) torch.save(gmm.module.state_dict(), args.output_prefix + '_params.pt') else: train_data = DeconvDataset(x_train.squeeze(), torch.cholesky(S.repeat(N, 1, 1))) val_data = DeconvDataset(x_val.squeeze(), torch.cholesky(S.repeat(N_val, 1, 1))) svi = SVIFlow( 2, 5, device=torch.device('cuda'), batch_size=512, epochs=args.epochs, lr=args.learning_rate, n_samples=args.samples, use_iwae=args.use_iwae, grad_clip_norm=args.grad_clip_norm, context_size=64, hidden_features=args.hidden_features ) svi.fit(train_data, val_data=val_data) torch.save(svi.model.state_dict(), args.output_prefix + '_params.pt')
K = 4 D = 2 N = 50000 N_val = int(0.25 * N) ref_gmm, S, (z_train, x_train), (z_val, x_val), _ = generate_mixture_data() train_data = DeconvDataset(x_train.squeeze(), torch.cholesky(S.repeat(N, 1, 1))) val_data = DeconvDataset(x_val.squeeze(), torch.cholesky(S.repeat(N, 1, 1))) svi = SVIFlow( 2, 5, device=torch.device('cuda'), batch_size=512, epochs=args.epochs, lr=args.learning_rate, n_samples=args.samples, use_iwae=False, context_size=64, hidden_features=args.hidden_features ) optimiser_prior = torch.optim.Adam( params=svi.model._prior.parameters(), lr=1e-3 ) scheduler_prior = torch.optim.lr_scheduler.ReduceLROnPlateau( optimiser_prior, mode='max', factor=0.5,