コード例 #1
0
def get_approx_digit_distribution(n_pca_comp=10,
                                  n_mixtures=5,
                                  covariance_type="spherical") -> Distribution:
    """
    Returns a GMM approximation to the 8x8 ``load_digits`` dataset.

    One PCA + Gaussian-mixture model is trained per digit class (0-9) and
    the ten per-class models are combined with ``Distribution.mixture``.

    Args:
        n_pca_comp: The number of dimensions to reduce the image size to.
        n_mixtures: The number of mixtures to use for each digit.
        covariance_type: Covariance type forwarded unchanged to
            ``train_gmm_pca_model`` for every per-digit model.

    Returns:
        Distribution: An approximate model of the digits data.
    """
    X, y = load_digits(return_X_y=True)
    # ``np.float`` was removed in NumPy 1.24; it aliased the builtin float
    # (float64), so ``np.float64`` keeps the exact same dtype.
    # Dividing by the global max scales pixels into [0, 1]
    # (assumes non-negative pixel values).
    X = X.astype(np.float64) / X.max()
    y = np.array([int(v) for v in y])
    distributions = []

    for i in range(10):
        print(f"training model on digit {i}")
        dist = train_gmm_pca_model(
            X[y == i, :],
            n_mixtures=n_mixtures,
            n_pca_comp=n_pca_comp,
            covariance_type=covariance_type,
        )
        distributions.append(dist)
    mnist_dist = Distribution.mixture(distributions)
    mnist_dist.visualize = plot_image_samples([8, 8], False)
    # NOTE(review): this replaces the mixture's own sampler with
    # ``get_samples(X)``, presumably drawing from the training images
    # rather than from the fitted GMM — confirm this is intended.
    mnist_dist.rvs = get_samples(X)
    return mnist_dist
コード例 #2
0
def get_mnist_distribution() -> Sampler:
    """
    Returns a MNIST sampler.

    Pixel values are mapped from [0, max] onto [-1, 1], and the images are
    reshaped to (N, 1, 28, 28).

    Returns:
        Sampler: Sampling mnist digits
    """
    X, _ = fetch_openml("mnist_784", version=1, return_X_y=True)
    # Recent scikit-learn versions return a pandas DataFrame here (so
    # ``.reshape`` would fail). ``np.asarray`` handles both DataFrame and
    # ndarray results. ``np.float64`` replaces the ``np.float`` alias that
    # NumPy 1.24 removed (same dtype).
    X = np.asarray(X, dtype=np.float64)
    X = 2 * X / X.max() - 1
    X = X.reshape([-1, 1, 28, 28])
    # Uniform noise in [0, 2/255): one quantization step of 0..255 data
    # after rescaling to [-1, 1].
    mnist_dist = Sampler.from_samples(
        X, noise=lambda shape: 2 * np.random.rand(*shape) / 255)
    mnist_dist.visualize = plot_image_samples([28, 28], False)
    return mnist_dist
コード例 #3
0
def get_digit_distribution() -> Sampler:
    """
    Returns a Digits Sampler.

    The 8x8 ``load_digits`` images have pixel values mapped from [0, max]
    onto [-1, 1] and are reshaped to (N, 1, 8, 8).

    Returns:
        Sampler: Sampling digits
    """
    X, _ = load_digits(return_X_y=True)
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the same dtype.
    X = 2 * X.astype(np.float64) / X.max() - 1
    n = X.shape[0]
    X = X.reshape([n, 1, 8, 8])
    # NOTE(review): 2/255 matches one quantization step of 0..255 data, but
    # these pixels have a much smaller max, so the noise is far narrower than
    # one level here. Likely copied from the MNIST sampler — confirm.
    digit_dist = Sampler.from_samples(
        X, noise=lambda shape: 2 * np.random.rand(*shape) / 255)
    digit_dist.visualize = plot_image_samples([8, 8], False)
    return digit_dist
コード例 #4
0
def get_pattern_distribution(patterns: Tuple[str, ...] = (
    "checkerboard", )) -> "Sampler":
    """
    Build a sampler over a fixed set of named patterns.

    Each name is looked up in ``PATTERNS``. The matching array is flattened
    and becomes one row of the sample matrix handed to
    ``Sampler.from_samples``.

    Args:
        patterns: Names of the ``PATTERNS`` entries to sample from.

    Returns:
        Sampler: A sampler whose samples are the flattened patterns.
    """
    # Stack the patterns into one row each, giving a (len(patterns), pixels) matrix.
    sample_matrix = np.stack(
        [PATTERNS[name].flatten() for name in patterns], axis=0)
    sampler = Sampler.from_samples(sample_matrix)
    # NOTE(review): the visualizer is hard-coded to 2x2 images — confirm every
    # PATTERNS entry has 4 pixels.
    sampler.visualize = plot_image_samples([2, 2], False)
    return sampler
コード例 #5
0
def get_sm_digit_distribution() -> Sampler:
    """
    Returns a small (4x4) Digits Sampler.

    The 8x8 ``load_digits`` images have pixel values mapped from [0, max]
    onto [-1, 1]. They are then downsampled to 4x4 with 2x2 average pooling.

    Returns:
        Sampler: Sampling downsampled digits shaped (N, 1, 4, 4).
    """
    X, _ = load_digits(return_X_y=True)
    # ``np.float`` was removed in NumPy 1.24; ``np.float64`` is the same dtype.
    X = 2 * X.astype(np.float64) / X.max() - 1
    n = X.shape[0]
    X_torch = torch.tensor(X.reshape([n, 1, 8, 8]))
    # AvgPool2d's stride defaults to kernel_size, so each 2x2 block becomes
    # one pixel: 8x8 -> 4x4.
    X = torch.nn.AvgPool2d(kernel_size=(2, 2))(X_torch).numpy()
    digit_dist = Sampler.from_samples(
        X, noise=lambda shape: 2 * np.random.rand(*shape) / 255
    )
    digit_dist.visualize = plot_image_samples([4, 4], False)
    return digit_dist