Пример #1
0
def check_sgd_gmm(D, K, N, plot=False, device=None):
    """Fit an ``SGDGMM`` to synthetic Gaussian clusters; optionally plot.

    Random means and covariances are drawn for ``K`` clusters in ``D``
    dimensions, ``N`` points are sampled from each cluster, and the pooled
    points are handed to ``SGDGMM.fit`` as a one-element list holding a
    float32 tensor of shape ``(N * K, D)``.

    Args:
        D: Dimensionality of the data.
        K: Number of clusters / mixture components.
        N: Number of samples drawn per cluster.
        plot: If true, scatter the first two coordinates of every cluster
            and overlay the true and fitted means/covariances. Columns 0
            and 1 are indexed, so this needs ``D >= 2``.
        device: Device passed to ``SGDGMM``; a CPU device is used when the
            argument is falsy.

    Returns:
        None.
    """
    device = device or torch.device('cpu')

    # Ground truth: means uniform in [-10, 10); covariances built as A^T A
    # so that each one is symmetric positive semi-definite.
    means = 20 * np.random.rand(K, D) - 10
    factors = 2 * np.random.randn(K, D, D)
    covars = factors.swapaxes(1, 2) @ factors

    X = np.empty((N, K, D))
    for k, (mu, sigma) in enumerate(zip(means, covars)):
        X[:, k, :] = np.random.multivariate_normal(
            mean=mu, cov=sigma, size=N
        )

    flat_points = X.reshape(-1, D).astype(np.float32)
    X_data = [torch.Tensor(flat_points)]

    gmm = SGDGMM(K, D, device=device, epochs=200)
    gmm.fit(X_data)

    if not plot:
        return

    _, ax = plt.subplots()

    # True clusters: samples plus the generating covariance, drawn in the
    # colour matplotlib assigned to that cluster's scatter.
    for k, (mu, sigma) in enumerate(zip(means, covars)):
        cluster_sc = ax.scatter(
            X[:, k, 0], X[:, k, 1],
            alpha=0.2, marker='x', label='Cluster {}'.format(k),
        )
        plot_covariance(mu, sigma, ax, color=cluster_sc.get_facecolor()[0])

    # Fitted model: all component means in one scatter, then one covariance
    # outline per component in that scatter's colour.
    # NOTE(review): assumes gmm.means / gmm.covars can be indexed and
    # plotted directly -- confirm against SGDGMM.
    fitted_sc = ax.scatter(
        gmm.means[:, 0], gmm.means[:, 1],
        marker='+', label='Fitted Gaussians',
    )
    for k in range(K):
        plot_covariance(
            gmm.means[k, :],
            gmm.covars[k, :, :],
            ax,
            color=fitted_sc.get_facecolor()[0],
        )

    ax.legend()
    plt.show()
Пример #2
0
def check_deconv_gmm(D, K, N, plot=False, device=None):
    """Fit a ``DeconvGMM`` to noisy synthetic clusters; optionally plot.

    Random means and covariances are drawn for ``K`` clusters in ``D``
    dimensions and ``N`` points are sampled per cluster. Every point then
    receives zero-mean Gaussian noise drawn from its own random noise
    covariance. The pooled noisy points and their noise covariances are
    passed to ``DeconvGMM.fit`` as a ``(points, noise_covars)`` tuple of
    float32 tensors moved to ``device``.

    Args:
        D: Dimensionality of the data.
        K: Number of clusters / mixture components.
        N: Number of samples drawn per cluster.
        plot: If true, scatter the first two coordinates of every (noisy)
            cluster and overlay the true and fitted means/covariances.
            Columns 0 and 1 are indexed, so this needs ``D >= 2``.
        device: Device used for the data and passed to ``DeconvGMM``; a
            CPU device is used when the argument is falsy.

    Returns:
        None.
    """
    device = device or torch.device('cpu')

    # Ground truth: means uniform in [-10, 10); covariances built as A^T A
    # so that each one is symmetric positive semi-definite.
    means = 20 * np.random.rand(K, D) - 10
    factors = 2 * np.random.randn(K, D, D)
    covars = factors.swapaxes(1, 2) @ factors

    # One noise covariance per individual sample, built the same way.
    noise_factors = 0.5 * np.random.randn(N, K, D, D)
    noise_covars = noise_factors.swapaxes(2, 3) @ noise_factors

    X = np.empty((N, K, D))
    zero_mean = np.zeros(D)

    for k in range(K):
        X[:, k, :] = np.random.multivariate_normal(
            mean=means[k, :], cov=covars[k, :, :], size=N
        )
        # Perturb every sample with noise from its own covariance.
        for n in range(N):
            X[n, k, :] += np.random.multivariate_normal(
                mean=zero_mean, cov=noise_covars[n, k, :, :]
            )

    points = torch.Tensor(X.reshape(-1, D).astype(np.float32)).to(device)
    point_noise = torch.Tensor(
        noise_covars.reshape(-1, D, D).astype(np.float32)
    ).to(device)
    data = (points, point_noise)

    gmm = DeconvGMM(K, D, max_iters=1000, device=device)
    gmm.fit(data)

    if not plot:
        return

    _, ax = plt.subplots()

    # Noisy samples per cluster with the noise-free generating covariance,
    # drawn in the colour matplotlib assigned to that cluster's scatter.
    for k in range(K):
        cluster_sc = ax.scatter(
            X[:, k, 0], X[:, k, 1],
            alpha=0.2, marker='x', label='Cluster {}'.format(k),
        )
        plot_covariance(
            means[k, :],
            covars[k, :, :],
            ax,
            color=cluster_sc.get_facecolor()[0],
        )

    # Fitted model: all component means in one scatter, then one covariance
    # outline per component in that scatter's colour.
    # NOTE(review): assumes gmm.means / gmm.covars can be indexed and
    # plotted directly -- confirm against DeconvGMM.
    fitted_sc = ax.scatter(
        gmm.means[:, 0], gmm.means[:, 1],
        marker='+', label='Fitted Gaussians',
    )
    for k in range(K):
        plot_covariance(
            gmm.means[k, :],
            gmm.covars[k, :, :],
            ax,
            color=fitted_sc.get_facecolor()[0],
        )

    ax.legend()
    plt.show()
Пример #3
0
def check_online_deconv_gmm(D, K, N, plot=False, device=None, verbose=False):
    """Fit an ``OnlineDeconvGMM`` to generated data and report its scores.

    Data comes from ``generate_data(D, K, N)``, whose result is unpacked as
    ``((X_train, nc_train, X_test, nc_test), (means, covars))``. The train
    and test splits are wrapped in ``DeconvDataset`` objects, the model is
    fitted with the test split as validation data, and the
    ``score_batch`` results for both splits are printed.

    Args:
        D: Dimensionality of the data.
        K: Number of clusters / mixture components.
        N: Size argument forwarded to ``generate_data``.
        plot: If true, show the training/validation log-likelihood curves
            and then a scatter of the first two coordinates of the training
            data with the true and fitted means/covariances overlaid.
            Columns 0 and 1 are indexed, so this needs ``D >= 2``.
        device: Device passed to ``OnlineDeconvGMM``; a CPU device is used
            when the argument is falsy.
        verbose: Forwarded to ``OnlineDeconvGMM.fit``.

    Returns:
        None.
    """
    device = device or torch.device('cpu')

    (X_train, nc_train, X_test, nc_test), (means, covars) = generate_data(
        D, K, N
    )

    def to_dataset(points, noise):
        # Flatten to (-1, D) points and (-1, D, D) noise covariances,
        # both as float32 tensors.
        return DeconvDataset(
            torch.Tensor(points.reshape(-1, D).astype(np.float32)),
            torch.Tensor(noise.reshape(-1, D, D).astype(np.float32)),
        )

    train_data = to_dataset(X_train, nc_train)
    test_data = to_dataset(X_test, nc_test)

    hyper_params = {
        'batch_size': 500,
        'step_size': 1e-1,
        'restarts': 1,
        'epochs': 20,
        'k_means_iters': 20,
        'lr_step': 10,
        'w': 1e-3,
    }
    gmm = OnlineDeconvGMM(K, D, device=device, **hyper_params)
    gmm.fit(train_data, val_data=test_data, verbose=verbose)

    train_score = gmm.score_batch(train_data)
    test_score = gmm.score_batch(test_data)

    print('Training score: {}'.format(train_score))
    print('Test score: {}'.format(test_score))

    if not plot:
        return

    # Figure 1: log-likelihood curves recorded on the fitted model.
    _, curve_ax = plt.subplots()
    curve_ax.plot(gmm.train_ll_curve, label='Training LL')
    curve_ax.plot(gmm.val_ll_curve, label='Validation LL')
    curve_ax.legend()
    plt.show()

    # Figure 2: training samples per cluster with the true covariance,
    # followed by the fitted means and covariances.
    # NOTE(review): assumes X_train is indexable as [sample, cluster, dim]
    # and that gmm.means / gmm.covars can be plotted directly -- confirm.
    _, ax = plt.subplots()

    for k in range(K):
        cluster_sc = ax.scatter(
            X_train[:, k, 0], X_train[:, k, 1],
            alpha=0.2, marker='x', label='Cluster {}'.format(k),
        )
        plot_covariance(
            means[k, :],
            covars[k, :, :],
            ax,
            color=cluster_sc.get_facecolor()[0],
        )

    fitted_sc = ax.scatter(
        gmm.means[:, 0], gmm.means[:, 1],
        marker='+', label='Fitted Gaussians',
    )
    for k in range(K):
        plot_covariance(
            gmm.means[k, :],
            gmm.covars[k, :, :],
            ax,
            color=fitted_sc.get_facecolor()[0],
        )

    ax.legend()
    plt.show()
Пример #4
0
        svi.model.load_state_dict(torch.load(p))
        z_samples = svi.sample_prior(10000)
        x_samples = svi.sample_posterior(test_point, 10000)

    ax_lim = (-4, 4)

    fig, axes = plt.subplots(1, 2, figsize=(3, 1.5), sharex=True, sharey=True)

    corner.hist2d(z_samples[0, :N, 0].numpy(),
                  z_samples[0, :N, 1].numpy(),
                  ax=axes[0])
    corner.hist2d(x_samples[1, :N, 0].numpy(),
                  x_samples[1, :N, 1].numpy(),
                  ax=axes[1])

    plot_covariance(mean[1], cov[1], ax=axes[1], color='r')

    axes[0].set_title(r'$p_{\theta}(\mathbf{v})$')
    axes[1].set_title(r'$q_{\phi}(\mathbf{v})$')

    axes[0].set_xticks([])
    axes[0].set_yticks([])
    axes[1].set_xticks([])
    axes[1].set_yticks([])

    axes[1].set_xlim(ax_lim)
    axes[1].set_ylim(ax_lim)

    fig.tight_layout()
    fig.savefig('additional_{}.pdf'.format(i))