def check_sgd_gmm(D, K, N, plot=False, device=None):
    """Sanity-check SGDGMM on synthetic data.

    Draws N samples from each of K randomly generated D-dimensional
    Gaussian components, fits an SGDGMM to the pooled samples, and
    optionally plots the true components against the fitted ones.

    Args:
        D: dimensionality of the data.
        K: number of Gaussian components.
        N: samples drawn per component.
        plot: if True, show a scatter/covariance plot of truth vs. fit.
        device: torch device for fitting; defaults to CPU.
    """
    device = device or torch.device('cpu')

    # Random component means in [-10, 10) and random SPD covariances
    # built as q^T q.
    means = np.random.rand(K, D) * 20 - 10
    q = 2 * np.random.randn(K, D, D)
    covars = np.matmul(q.swapaxes(1, 2), q)

    # Keep samples separated by component so each cluster can be
    # plotted individually later.
    X = np.empty((N, K, D))
    for k in range(K):
        X[:, k, :] = np.random.multivariate_normal(
            mean=means[k, :],
            cov=covars[k, :, :],
            size=N,
        )

    X_data = [torch.Tensor(X.reshape(-1, D).astype(np.float32))]

    gmm = SGDGMM(K, D, device=device, epochs=200)
    gmm.fit(X_data)

    if not plot:
        return

    fig, ax = plt.subplots()
    for k in range(K):
        sc = ax.scatter(
            X[:, k, 0],
            X[:, k, 1],
            alpha=0.2,
            marker='x',
            label='Cluster {}'.format(k),
        )
        plot_covariance(
            means[k, :],
            covars[k, :, :],
            ax,
            color=sc.get_facecolor()[0],
        )
    sc = ax.scatter(
        gmm.means[:, 0],
        gmm.means[:, 1],
        marker='+',
        label='Fitted Gaussians',
    )
    for k in range(K):
        plot_covariance(
            gmm.means[k, :],
            gmm.covars[k, :, :],
            ax,
            color=sc.get_facecolor()[0],
        )
    ax.legend()
    plt.show()
def check_deconv_gmm(D, K, N, plot=False, device=None):
    """Sanity-check DeconvGMM on synthetic noisy data.

    Draws N samples from each of K random D-dimensional Gaussian
    components, perturbs every observation with its own random noise
    covariance, then fits a DeconvGMM given both the noisy points and
    the per-point noise covariances. Optionally plots truth vs. fit.

    Args:
        D: dimensionality of the data.
        K: number of Gaussian components.
        N: samples drawn per component.
        plot: if True, show a scatter/covariance plot of truth vs. fit.
        device: torch device for fitting; defaults to CPU.
    """
    device = device or torch.device('cpu')

    # Random component means in [-10, 10) and random SPD covariances
    # built as q^T q.
    means = np.random.rand(K, D) * 20 - 10
    q = 2 * np.random.randn(K, D, D)
    covars = np.matmul(q.swapaxes(1, 2), q)

    # One SPD noise covariance per (observation, component) pair.
    qn = 0.5 * np.random.randn(N, K, D, D)
    noise_covars = np.matmul(qn.swapaxes(2, 3), qn)

    X = np.empty((N, K, D))
    for k in range(K):
        X[:, k, :] = np.random.multivariate_normal(
            mean=means[k, :],
            cov=covars[k, :, :],
            size=N,
        )
        # Add zero-mean noise drawn from each point's own covariance.
        for n in range(N):
            X[n, k, :] += np.random.multivariate_normal(
                mean=np.zeros(D),
                cov=noise_covars[n, k, :, :],
            )

    # The fitter takes (noisy points, per-point noise covariances).
    data = (
        torch.Tensor(X.reshape(-1, D).astype(np.float32)).to(device),
        torch.Tensor(
            noise_covars.reshape(-1, D, D).astype(np.float32)
        ).to(device),
    )

    gmm = DeconvGMM(K, D, max_iters=1000, device=device)
    gmm.fit(data)

    if not plot:
        return

    fig, ax = plt.subplots()
    for k in range(K):
        sc = ax.scatter(
            X[:, k, 0],
            X[:, k, 1],
            alpha=0.2,
            marker='x',
            label='Cluster {}'.format(k),
        )
        plot_covariance(
            means[k, :],
            covars[k, :, :],
            ax,
            color=sc.get_facecolor()[0],
        )
    sc = ax.scatter(
        gmm.means[:, 0],
        gmm.means[:, 1],
        marker='+',
        label='Fitted Gaussians',
    )
    for k in range(K):
        plot_covariance(
            gmm.means[k, :],
            gmm.covars[k, :, :],
            ax,
            color=sc.get_facecolor()[0],
        )
    ax.legend()
    plt.show()
def check_online_deconv_gmm(D, K, N, plot=False, device=None, verbose=False):
    """Sanity-check OnlineDeconvGMM on generated train/test data.

    Uses `generate_data` to build noisy train and test sets, fits an
    OnlineDeconvGMM on the training set (validating on the test set),
    prints both scores, and optionally plots the loss curves plus the
    true and fitted components.

    Args:
        D: dimensionality of the data.
        K: number of Gaussian components.
        N: samples per component passed to `generate_data`.
        plot: if True, show LL curves and a truth-vs-fit scatter plot.
        device: torch device for fitting; defaults to CPU.
        verbose: forwarded to the fitter's `fit` call.
    """
    device = device or torch.device('cpu')

    data, params = generate_data(D, K, N)
    X_train, nc_train, X_test, nc_test = data
    means, covars = params

    # Wrap (points, noise covariances) pairs as float32 tensors.
    train_data = DeconvDataset(
        torch.Tensor(X_train.reshape(-1, D).astype(np.float32)),
        torch.Tensor(nc_train.reshape(-1, D, D).astype(np.float32)),
    )
    test_data = DeconvDataset(
        torch.Tensor(X_test.reshape(-1, D).astype(np.float32)),
        torch.Tensor(nc_test.reshape(-1, D, D).astype(np.float32)),
    )

    gmm = OnlineDeconvGMM(
        K,
        D,
        device=device,
        batch_size=500,
        step_size=1e-1,
        restarts=1,
        epochs=20,
        k_means_iters=20,
        lr_step=10,
        w=1e-3,
    )
    gmm.fit(train_data, val_data=test_data, verbose=verbose)

    train_score = gmm.score_batch(train_data)
    test_score = gmm.score_batch(test_data)
    print('Training score: {}'.format(train_score))
    print('Test score: {}'.format(test_score))

    if not plot:
        return

    # Training / validation log-likelihood curves.
    fig, ax = plt.subplots()
    ax.plot(gmm.train_ll_curve, label='Training LL')
    ax.plot(gmm.val_ll_curve, label='Validation LL')
    ax.legend()
    plt.show()

    # True clusters vs. fitted Gaussians.
    fig, ax = plt.subplots()
    for k in range(K):
        sc = ax.scatter(
            X_train[:, k, 0],
            X_train[:, k, 1],
            alpha=0.2,
            marker='x',
            label='Cluster {}'.format(k),
        )
        plot_covariance(
            means[k, :],
            covars[k, :, :],
            ax,
            color=sc.get_facecolor()[0],
        )
    sc = ax.scatter(
        gmm.means[:, 0],
        gmm.means[:, 1],
        marker='+',
        label='Fitted Gaussians',
    )
    for k in range(K):
        plot_covariance(
            gmm.means[k, :],
            gmm.covars[k, :, :],
            ax,
            color=sc.get_facecolor()[0],
        )
    ax.legend()
    plt.show()
# Script fragment: load a saved SVI model and plot prior vs. posterior
# samples side by side, saving the figure to disk.
# NOTE(review): `svi`, `p`, `test_point`, `mean`, `cov`, `i`, and `N`
# are defined earlier in the file (not visible here) — presumably part
# of a loop over checkpoints/test points given the `format(i)` filename;
# verify against the enclosing context.
svi.model.load_state_dict(torch.load(p))
z_samples = svi.sample_prior(10000)
x_samples = svi.sample_posterior(test_point, 10000)
ax_lim = (-4, 4)
fig, axes = plt.subplots(1, 2, figsize=(3, 1.5), sharex=True, sharey=True)
# Left panel: first N prior samples; right panel: first N posterior
# samples (batch index 1 — assumes at least two batch entries; confirm).
corner.hist2d(z_samples[0, :N, 0].numpy(), z_samples[0, :N, 1].numpy(),
              ax=axes[0])
corner.hist2d(x_samples[1, :N, 0].numpy(), x_samples[1, :N, 1].numpy(),
              ax=axes[1])
# Overlay the reference Gaussian on the posterior panel in red.
plot_covariance(mean[1], cov[1], ax=axes[1], color='r')
axes[0].set_title(r'$p_{\theta}(\mathbf{v})$')
axes[1].set_title(r'$q_{\phi}(\mathbf{v})$')
# Strip ticks for a compact publication-style figure.
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_xlim(ax_lim)
axes[1].set_ylim(ax_lim)
fig.tight_layout()
fig.savefig('additional_{}.pdf'.format(i))