def score_sgd(datafile, results_dir, output_file):
    """Score saved SGD deconvolution-GMM parameter sets on held-out test data.

    Loads the test split from ``datafile`` (an .npz archive with
    'X_test'/'C_test' arrays), evaluates every ``sgd_512*.pkl`` state dict
    found in ``results_dir`` with a K=512, D=7 ``SGDDeconvGMM``, prints the
    mean +- std of the test scores, and writes the per-file scores to
    ``output_file`` as JSON.
    """
    data = np.load(datafile)
    test_data = SGDDeconvDataset(
        torch.Tensor(data['X_test']),
        torch.Tensor(data['C_test'])
    )

    param_files = [
        f for f in os.listdir(results_dir)
        if f.startswith('sgd_512') and f.endswith('.pkl')
    ]

    # K=512 components, D=7 dimensions: must match the runs being scored.
    gmm = SGDDeconvGMM(512, 7, batch_size=500)

    scores = []
    for p in param_files:
        # os.path.join instead of string concatenation: the original broke
        # when results_dir lacked a trailing path separator.
        state_dict = torch.load(os.path.join(results_dir, p),
                                map_location=torch.device('cpu'))
        gmm.module.load_state_dict(state_dict)
        test_score = gmm.score_batch(test_data)
        print(test_score)
        scores.append(test_score)

    print('Test Score: {} +- {}'.format(np.mean(scores), np.std(scores)))
    # Cast to plain float so numpy/torch scalar scores are JSON-serializable,
    # and use a context manager so the file handle is closed (the original
    # passed an unclosed open() straight to json.dump).
    with open(output_file, 'w') as f:
        json.dump([float(s) for s in scores], f)
def _create_prior(self):
    """Build a 3-component GMM prior and return its underlying module.

    The wrapper GMM is stored on ``self.gmm``; its module is swapped for a
    fresh ``SGDGMMModule`` with covariance regularisation disabled (w=0).
    """
    n_components = 3
    wrapper = SGDDeconvGMM(
        n_components,
        self.dimensions,
        batch_size=self.batch_size,
        epochs=self.epochs,
        lr=self.lr,
        device=self.device,
    )
    wrapper.module = SGDGMMModule(
        n_components, self.dimensions, w=0, device=self.device
    )
    self.gmm = wrapper
    return wrapper.module
def fit_gaia_lim_sgd(datafile, output_prefix, K, batch_size, epochs, lr,
                     w_reg, k_means_iters, lr_step, lr_gamma, use_cuda):
    """Fit a K-component, 7-D SGDDeconvGMM to the Gaia dataset in ``datafile``.

    Expects an .npz archive with 'X_train'/'C_train' and 'X_val'/'C_val'
    arrays. Writes '<output_prefix>_results.json' (timings, scores, loss
    curves) and '<output_prefix>_params.pkl' (the fitted module state dict).
    """
    data = np.load(datafile)

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    train_data = DeconvDataset(
        torch.Tensor(data['X_train']),
        torch.Tensor(data['C_train'])
    )
    val_data = DeconvDataset(
        torch.Tensor(data['X_val']),
        torch.Tensor(data['C_val'])
    )

    gmm = SGDDeconvGMM(
        K,
        7,
        device=device,
        batch_size=batch_size,
        epochs=epochs,
        w=w_reg,
        k_means_iters=k_means_iters,
        lr=lr,
        lr_step=lr_step,
        lr_gamma=lr_gamma
    )

    start_time = time.time()
    gmm.fit(train_data, val_data=val_data, verbose=True)
    end_time = time.time()

    train_score = gmm.score_batch(train_data)
    val_score = gmm.score_batch(val_data)

    print('Training score: {}'.format(train_score))
    print('Val score: {}'.format(val_score))

    results = {
        'start_time': start_time,
        'end_time': end_time,
        'train_score': train_score,
        'val_score': val_score,
        'train_curve': gmm.train_loss_curve,
        'val_curve': gmm.val_loss_curve
    }

    # str() on both output paths so output_prefix may be a pathlib.Path;
    # the original applied str() only to the JSON path, so Path + str
    # concatenation crashed on the params file. The context manager also
    # fixes the leaked file handle from json.dump(..., open(...)).
    with open(str(output_prefix) + '_results.json', mode='w') as f:
        json.dump(results, f)
    torch.save(gmm.module.state_dict(), str(output_prefix) + '_params.pkl')
# NOTE(review): fragment — the enclosing function's signature (defining
# X_train, X_test, nc_*, D, K, device, plot) lies outside this excerpt.
train_data = DeconvDataset(
    torch.Tensor(X_train.reshape(-1, D).astype(np.float32)),
    torch.Tensor(nc_train.reshape(-1, D, D).astype(np.float32)))
test_data = DeconvDataset(
    torch.Tensor(X_test.reshape(-1, D).astype(np.float32)),
    torch.Tensor(nc_test.reshape(-1, D, D).astype(np.float32)))

# Fit the flow-based deconvolver and report mean per-point test log-prob.
svi = SVIFlow(D, 5, device=device, batch_size=512, epochs=50, lr=1e-4)
svi.fit(train_data, val_data=None)
test_log_prob = svi.score_batch(test_data, log_prob=True)
print('Test log prob: {}'.format(test_log_prob / len(test_data)))

# Baseline: SGD deconvolution GMM on the same train/test split.
gmm = SGDDeconvGMM(K, D, device=device, batch_size=256, epochs=50, lr=1e-1)
gmm.fit(train_data, val_data=test_data, verbose=True)
test_log_prob = gmm.score_batch(test_data)
print('Test log prob: {}'.format(test_log_prob / len(test_data)))

if plot:
    # Build a 200x200 evaluation grid over [-5, 10] x [-15, 15] and stack
    # the coordinates into a (y_width, x_width, 2) tensor of grid points.
    x_width = 200
    y_width = 200
    x = np.linspace(-5, 10, num=x_width, dtype=np.float32)
    y = np.linspace(-15, 15, num=y_width, dtype=np.float32)
    xx, yy = np.meshgrid(x, y)
    d = torch.tensor(np.concatenate((xx[:, :, None], yy[:, :, None]), axis=-1))
# NOTE(review): fragment — the for-loop header populating gmm_params starts
# before this excerpt; presumably it iterates os.listdir(args.gmm_results_dir).
    path = os.path.join(args.gmm_results_dir, f)
    gmm_params.append(path)

# Group ELBO result files by the run index embedded in the filename
# (characters 16-18) — TODO confirm the filename layout against the writer.
elbo_params = collections.defaultdict(list)
for f in os.listdir(args.elbo_results_dir):
    path = os.path.join(args.elbo_results_dir, f)
    elbo_params[int(f[16:18])].append(path)

# Same grouping for the importance-weighted (IW) results.
iw_params = collections.defaultdict(list)
for f in os.listdir(args.iw_results_dir):
    path = os.path.join(args.iw_results_dir, f)
    iw_params[int(f[16:18])].append(path)

# Models used to re-score the saved parameter sets (CUDA required).
gmm = SGDDeconvGMM(K, D, batch_size=200, device=torch.device('cuda'))
test_gmm = SGDGMM(K, D, batch_size=200, device=torch.device('cuda'))
svi = SVIFlow(2, 5,
              device=torch.device('cuda'),
              batch_size=512,
              epochs=100,
              lr=1e-4,
              n_samples=50,
              use_iwae=False,
              context_size=64,
              hidden_features=128)
results = []
elbo_params = []
# Collect every ELBO result file path.
for f in os.listdir(args.elbo_results_dir):
    path = os.path.join(args.elbo_results_dir, f)
    elbo_params.append(path)

# Observed-space GMM (no deconvolution) and deconvolution GMM used for
# evaluation; w=0 disables covariance regularisation.
test_gmm = SGDGMM(
    K, D, batch_size=200, device=torch.device('cuda'), w=0
)
test_deconv_gmm = SGDDeconvGMM(
    K, D, batch_size=512, device=torch.device('cuda'), w=0
)
# NOTE(review): fragment — this SVIGMMExact(...) call continues past the end
# of this excerpt; any remaining keyword arguments are out of view.
svi_gmm = SVIGMMExact(
    2, 5, device=torch.device('cuda'), batch_size=512, epochs=100,
    lr=1e-4, n_samples=50, use_iwae=False, context_size=64,
    hidden_features=128
def generate_mixture_data():
    """Generate a reproducible 3-component, 2-D mixture dataset with known noise.

    Returns ``(ref_gmm, S, (z_train, x_train), (z_val, x_val),
    (z_test, x_test))`` where the ``z_*`` tensors are latent draws from the
    hand-set reference GMM and the ``x_*`` tensors are those draws with
    zero-mean Gaussian noise of covariance ``S`` added. Fixed seeds make the
    splits deterministic; the caller's global RNG state is restored before
    returning.
    """
    n_components = 3
    n_dims = 2
    n_train = 50000
    n_val = int(0.25 * n_train)

    torch.set_default_tensor_type(torch.FloatTensor)

    ref_gmm = SGDDeconvGMM(
        n_components,
        n_dims,
        batch_size=512,
        device=torch.device('cpu')
    )
    # Equal mixture weights; components placed on three sides of the origin.
    ref_gmm.module.soft_weights.data = torch.zeros(n_components)
    scale = 2
    ref_gmm.module.means.data = torch.Tensor([
        [-scale, 0],
        [0, -scale],
        [0, scale]
    ])
    # Anisotropic components: each is narrow along one axis, wide along the
    # other. l_diag holds log standard deviations.
    short_std = 0.3
    long_std = 1
    stds = torch.Tensor([
        [short_std, long_std],
        [long_std, short_std],
        [long_std, short_std]
    ])
    ref_gmm.module.l_diag.data = torch.log(stds)

    # Seed deterministically, but restore the caller's RNG state afterwards.
    saved_rng_state = torch.get_rng_state()
    torch.manual_seed(432988)

    z_train = ref_gmm.sample_prior(n_train)
    z_val = ref_gmm.sample_prior(n_val)

    # Axis-aligned observation-noise covariance shared by every point.
    noise_short = 0.1
    noise_long = 1.0
    S = torch.Tensor([
        [noise_short, 0],
        [0, noise_long]
    ])
    noise_distribution = torch.distributions.MultivariateNormal(
        loc=torch.Tensor([0, 0]),
        covariance_matrix=S
    )
    x_train = z_train + noise_distribution.sample([n_train])
    x_val = z_val + noise_distribution.sample([n_val])

    # Independent seed for the test split.
    torch.manual_seed(263568)
    z_test = ref_gmm.sample_prior(n_train)
    x_test = z_test + noise_distribution.sample([n_train])

    torch.set_rng_state(saved_rng_state)

    return (
        ref_gmm,
        S,
        (z_train, x_train),
        (z_val, x_val),
        (z_test, x_test)
    )
# NOTE(review): fragment — these keyword arguments complete a constructor
# call whose opening lies before this excerpt.
    n_samples=50, use_iwae=False, context_size=64, hidden_features=128)
svi_exact_gmm = SVIGMMExact(2, 5,
                            device=torch.device('cuda'),
                            batch_size=512,
                            epochs=100,
                            lr=1e-4,
                            n_samples=50,
                            use_iwae=False,
                            context_size=64,
                            hidden_features=128)
# Deconvolution GMM used for evaluation; w=0 disables regularisation.
test_gmm = SGDDeconvGMM(K, D, batch_size=200,
                        device=torch.device('cuda'), w=0)
# Every tensor created after this point defaults to a CUDA float tensor.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
# Take the first saved parameter file from each results directory, in a
# fixed order: pretrained, posttrained, SVI GMM, SVI exact GMM.
# NOTE(review): os.listdir order is arbitrary — presumably each directory
# holds exactly one file; verify.
params = []
params.append(
    os.path.join(args.pretrained_results_dir,
                 os.listdir(args.pretrained_results_dir)[0]))
params.append(
    os.path.join(args.posttrained_results_dir,
                 os.listdir(args.posttrained_results_dir)[0]))
params.append(
    os.path.join(args.svi_gmm_results_dir,
                 os.listdir(args.svi_gmm_results_dir)[0]))
params.append(
    os.path.join(args.svi_exact_gmm_results_dir,
                 os.listdir(args.svi_exact_gmm_results_dir)[0]))
def check_sgd_deconv_gmm(D, K, N, plot=False, verbose=False, device=None):
    """Fit an SGDDeconvGMM to synthetic data and report train/test scores.

    Generates data with ``generate_data(D, K, N)``, fits a K-component GMM,
    prints both scores, and (optionally) plots the loss curves plus the
    fitted vs. true component means and covariances.
    """
    if not device:
        device = torch.device('cpu')

    data, params = generate_data(D, K, N)
    X_train, nc_train, X_test, nc_test = data
    means, covars = params

    def _as_dataset(points, noise_covars):
        # Flatten the per-cluster arrays into float32 (N*K, D) points with
        # matching (N*K, D, D) noise covariances.
        return DeconvDataset(
            torch.Tensor(points.reshape(-1, D).astype(np.float32)),
            torch.Tensor(noise_covars.reshape(-1, D, D).astype(np.float32)))

    train_data = _as_dataset(X_train, nc_train)
    test_data = _as_dataset(X_test, nc_test)

    gmm = SGDDeconvGMM(K, D, device=device, batch_size=250,
                       epochs=200, restarts=1, lr=1e-1)
    gmm.fit(train_data, val_data=test_data, verbose=verbose)

    train_score = gmm.score_batch(train_data)
    test_score = gmm.score_batch(test_data)
    print('Training score: {}'.format(train_score))
    print('Test score: {}'.format(test_score))

    if plot:
        fig, ax = plt.subplots()
        ax.plot(gmm.train_loss_curve, label='Training Loss')
        ax.plot(gmm.val_loss_curve, label='Validation Loss')

        fig, ax = plt.subplots()
        # True clusters with their generating covariances.
        for i in range(K):
            sc = ax.scatter(X_train[:, i, 0], X_train[:, i, 1],
                            alpha=0.2, marker='x',
                            label='Cluster {}'.format(i))
            plot_covariance(means[i, :], covars[i, :, :], ax,
                            color=sc.get_facecolor()[0])
        # Fitted component means and covariances on the same axes.
        sc = ax.scatter(gmm.means[:, 0], gmm.means[:, 1],
                        marker='+', label='Fitted Gaussians')
        for i in range(K):
            plot_covariance(gmm.means[i, :], gmm.covars[i, :, :], ax,
                            color=sc.get_facecolor()[0])
        ax.legend()
        plt.show()
# NOTE(review): fragment — the first line completes a constructor call opened
# before this excerpt, and the trailing SVIFlow(...) call continues past its
# end. The outer if/else nesting is inferred from the dangling `else:`.
        hidden_features=args.hidden_features
    )
    if args.freeze_gmm:
        # Initialise the flow's prior from the reference GMM and freeze it so
        # SVI training only updates the remaining flow parameters.
        svi_gmm.model._prior.load_state_dict(ref_gmm.module.state_dict())
        for param in svi_gmm.model._prior.parameters():
            param.requires_grad = False
        svi_gmm.fit(train_data, val_data=val_data)
        torch.save(svi_gmm.model.state_dict(),
                   args.output_prefix + '_params.pt')
    else:
        # Deconvolution-GMM baseline: every point carries the full noise
        # covariance S.
        train_data = DeconvDataset(x_train.squeeze(), S.repeat(N, 1, 1))
        val_data = DeconvDataset(x_val.squeeze(), S.repeat(N_val, 1, 1))
        gmm = SGDDeconvGMM(
            K,
            D,
            batch_size=200,
            epochs=args.epochs,
            lr=args.learning_rate,
            device=torch.device('cuda')
        )
        gmm.fit(train_data, val_data=val_data, verbose=True)
        torch.save(gmm.module.state_dict(),
                   args.output_prefix + '_params.pt')
else:
    # Flow branch: datasets carry the Cholesky factor of S rather than S
    # itself — presumably the form SVIFlow expects; TODO confirm.
    train_data = DeconvDataset(x_train.squeeze(),
                               torch.cholesky(S.repeat(N, 1, 1)))
    val_data = DeconvDataset(x_val.squeeze(),
                             torch.cholesky(S.repeat(N_val, 1, 1)))
    svi = SVIFlow(
        2,
        5,
        device=torch.device('cuda'),
        batch_size=512,
        epochs=args.epochs,
        lr=args.learning_rate,