Beispiel #1
0
def test_conditional_distribution():
    """Check component means and covariances of a conditioned GMM.

    A two-component mixture with equal priors is built from the module-level
    ``means`` and ``covariances``. It is conditioned once on each dimension
    and one component of each result is compared with its expected moments.
    """
    rng = check_random_state(0)
    mixture = GMM(
        n_components=2,
        priors=np.array([0.5, 0.5]),
        means=means,
        covariances=covariances,
        random_state=rng,
    )
    expected_cov = np.array([[0.3]])

    # Dimension 1 fixed to 1.0: inspect the first component.
    given_dim1 = mixture.condition(np.array([1]), np.array([1.0]))
    assert_array_almost_equal(given_dim1.means[0], np.array([0.0]))
    assert_array_almost_equal(given_dim1.covariances[0], expected_cov)

    # Dimension 0 fixed to 2.0: inspect the second component.
    given_dim0 = mixture.condition(np.array([0]), np.array([2.0]))
    assert_array_almost_equal(given_dim0.means[1], np.array([-1.0]))
    assert_array_almost_equal(given_dim0.covariances[1], expected_cov)
Beispiel #2
0
def test_conditional_distribution():
    """Verify the moments of individual components after conditioning.

    Each case conditions the mixture on a single dimension and checks the
    mean and covariance of one component of the resulting conditional GMM.
    """
    mixture = GMM(
        n_components=2,
        priors=np.array([0.5, 0.5]),
        means=means,
        covariances=covariances,
        random_state=check_random_state(0),
    )

    # (conditioned dimension, observed value, component index, expected mean)
    cases = (
        (1, 1.0, 0, 0.0),
        (0, 2.0, 1, -1.0),
    )
    for dim, value, component, expected_mean in cases:
        result = mixture.condition(np.array([dim]), np.array([value]))
        assert_array_almost_equal(result.means[component],
                                  np.array([expected_mean]))
        assert_array_almost_equal(result.covariances[component],
                                  np.array([[0.3]]))
 def choose(node_info: MixtureGaussianParams,
            pvals: List[Union[str, float]]) -> Optional[float]:
     """
     Draw one value for a MixtureGaussian node.

     node_info: distribution parameters of the node; the keys "mean",
         "covars" and "coef" are read
     pvals: parent values
     Return np.nan when "coef" is empty or when every parent value is NaN;
     otherwise the first coordinate of a single sample drawn from the
     mixture (conditioned on the parent values when there are any).
     """
     component_means = node_info["mean"]
     component_covs = node_info["covars"]
     weights = node_info["coef"]
     n_comp = len(weights)

     # No mixture components: nothing to sample from.
     if n_comp == 0:
         return np.nan

     # NOTE(review): only the all-NaN case is short-circuited; a partially
     # NaN `pvals` is still passed on to condition() below.
     if pvals and np.isnan(np.array(pvals)).all():
         return np.nan

     mixture = GMM(n_components=n_comp,
                   priors=weights,
                   means=component_means,
                   covariances=component_covs)

     # Without parents the value is drawn from the unconditioned mixture.
     if not pvals:
         return mixture.sample(1)[0][0]

     # Parent values are matched to dimensions 1..len(pvals); presumably
     # dimension 0 is the node's own variable - TODO confirm.
     parent_dims = list(range(1, len(pvals) + 1))
     conditioned = mixture.condition(parent_dims, [pvals])
     return conditioned.sample(1)[0][0]
Beispiel #4
0
def test_from_samples_with_oas():
    """Fit a GMM with oracle approximating shrinkage and condition it.

    Nine 2-D points are drawn in three equally sized groups, a
    three-component GMM is fitted with ``oracle_approximating_shrinkage``
    enabled, and every component covariance of the conditional distribution
    is required to have only non-negative eigenvalues.
    """
    n_samples = 9
    n_features = 2
    group_size = n_samples // 3

    # (mean, covariance) of each group, in the order the samples are drawn.
    groups = (
        ([0.0, 1.0], [[0.5, -1.0], [-1.0, 5.0]]),
        ([-2.0, -2.0], [[3.0, 1.0], [1.0, 1.0]]),
        ([3.0, 3.0], [[3.0, -1.0], [-1.0, 1.0]]),
    )

    X = np.ndarray((n_samples, n_features))
    for k, (group_mean, group_cov) in enumerate(groups):
        start = k * group_size
        X[start:start + group_size, :] = random_state.multivariate_normal(
            group_mean, group_cov, size=(group_size, ))

    mixture = GMM(n_components=3, random_state=random_state)
    mixture.from_samples(X,
                         init_params="kmeans++",
                         oracle_approximating_shrinkage=True)

    conditional = mixture.condition(np.array([0]), np.array([1.0]))
    for k in range(conditional.n_components):
        spectrum = np.linalg.eigvals(conditional.covariances[k])
        assert_true(all(spectrum >= 0))
 def choose(node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
            pvals: List[Union[str, float]]) -> Optional[float]:
     """
     Draw one value for a ConditionalMixtureGaussian node.

     params:
     node_info: nodes info from distributions; the mixture parameters are
         looked up under node_info["hybcprob"][str(<discrete parent values>)]
     pvals: parent values; str and int entries select the mixture, all
         other entries are used to condition it
     Return np.nan when the selected mixture has no components or when
     every conditioning value is NaN; otherwise the first coordinate of a
     single sample.
     """
     discrete_vals = []
     continuous_vals = []
     for value in pvals:
         # str/int parents pick the distribution, the rest condition it.
         target = discrete_vals if isinstance(value, (str, int)) else continuous_vals
         target.append(value)

     params = node_info["hybcprob"][str(discrete_vals)]
     component_means = params["mean"]
     component_covs = params["covars"]
     weights = params["coef"]
     n_comp = len(weights)

     # No mixture components: nothing to sample from.
     if n_comp == 0:
         return np.nan

     # NOTE(review): only the all-NaN case is short-circuited; partially
     # NaN conditioning values are still passed on to condition() below.
     if continuous_vals and np.isnan(np.array(continuous_vals)).all():
         return np.nan

     mixture = GMM(n_components=n_comp,
                   priors=weights,
                   means=component_means,
                   covariances=component_covs)

     # Without conditioning values, draw from the unconditioned mixture.
     if not continuous_vals:
         return mixture.sample(1)[0][0]

     # Conditioning values are matched to dimensions 1..len(continuous_vals);
     # presumably dimension 0 is the node's own variable - TODO confirm.
     parent_dims = list(range(1, len(continuous_vals) + 1))
     conditioned = mixture.condition(parent_dims, [continuous_vals])
     return conditioned.sample(1)[0][0]
Beispiel #6
0
# Wrap the mixture parameters held by `bgmm` (weights, means, covariances)
# in a GMM object so that it can be conditioned below.
# NOTE(review): `bgmm`, `n_components`, `random_state`, `X` and `t` are
# defined outside this snippet - presumably `bgmm` is an already fitted
# mixture model; confirm.
gmm = GMM(n_components=n_components,
          priors=bgmm.weights_,
          means=bgmm.means_,
          covariances=bgmm.covariances_,
          random_state=random_state)

plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.title("Confidence Interval from GMM")

# X is indexed with three axes; the last axis supplies the two plotted
# coordinates. Presumably X is (n_trajectories, n_steps, 2) - TODO confirm.
plt.plot(X[:, :, 0].T, X[:, :, 1].T, c="k", alpha=0.1)

# For every value in `t`: condition the GMM on dimension 0 taking that
# value, collapse the conditional mixture into a single MVN, and record its
# mean together with the standard deviation of its second coordinate.
means_over_time = []
y_stds = []
for step in t:
    conditional_gmm = gmm.condition([0], np.array([step]))
    conditional_mvn = conditional_gmm.to_mvn()
    means_over_time.append(conditional_mvn.mean)
    y_stds.append(np.sqrt(conditional_mvn.covariance[1, 1]))
    # Also show 100 samples drawn from the conditional mixture itself.
    samples = conditional_gmm.sample(100)
    plt.scatter(samples[:, 0], samples[:, 1], s=1)
means_over_time = np.array(means_over_time)
y_stds = np.array(y_stds)

# Mean path plus a band of +/- 1.96 standard deviations around its second
# coordinate (1.96 sigma is the 95 % interval of a normal distribution).
plt.plot(means_over_time[:, 0], means_over_time[:, 1], c="r", lw=2)
plt.fill_between(means_over_time[:, 0],
                 means_over_time[:, 1] - 1.96 * y_stds,
                 means_over_time[:, 1] + 1.96 * y_stds,
                 color="r",
                 alpha=0.5)



# Scratch buffer of shape (1, 2); both entries are overwritten on every
# iteration before it is handed to gmm2.predict().
# NOTE(review): `X_test`, `Y`, `gmm`, `gmm2`, `x_axes` and `y_axes` are
# defined outside this snippet.
X2 = np.ndarray((1, 2))

for i in range(len(X_test)):
    X_point = X_test[i]
    Y_point = Y[i]



    X2[0, 0] = X_point
    X2[0, 1] = Y_point

    # Condition the mixture once on the x value and once on the y value.
    conditioned = gmm.condition(x_axes, X_point)
    conditioned_y = gmm.condition(y_axes, Y_point)


    # Pick the component with the largest sum of the two conditional priors.
    total_priors = conditioned.priors + conditioned_y.priors
    total_max = np.argmax(total_priors)

    # Label predicted by the second model for the (x, y) pair.
    total_max2 = gmm2.predict(X2)[0]

    # Swap labels 0 and 1.
    # NOTE(review): total_max2 is not used again in the visible code -
    # confirm whether a later part of the script consumes it.
    if (total_max2 == 0):
        total_max2 = 1
    elif (total_max2 == 1):
        total_max2 = 0


    # Mark the mean of the selected component of the x-conditioned mixture.
    plt.scatter(X_point, conditioned.means[total_max], s=200, marker=">", color="k")
# Keep every 50th entry of `points` and estimate a velocity for each
# interior sample from the difference between its two neighbours, divided
# by `dt`. NOTE(review): `points` and `dt` are defined outside this snippet.
X = points[::50]
X_dot = (X[2:] - X[:-2]) / dt
# Drop the first and last sample so that X lines up with X_dot.
X = X[1:-1]
# Training rows are [position, velocity] stacked column-wise.
X_train = np.hstack((X, X_dot))

random_state = np.random.RandomState(0)
n_components = 15

# Fit a Bayesian mixture and copy its parameters into a GMM object, which
# provides the conditioning used below.
bgmm = BayesianGaussianMixture(n_components=n_components,
                               max_iter=500,
                               random_state=random_state).fit(X_train)
gmm = GMM(n_components=n_components,
          priors=bgmm.weights_,
          means=bgmm.means_,
          covariances=bgmm.covariances_,
          random_state=random_state)

# Roll out a path: at each of 500 steps condition the mixture on the current
# position (dimensions 0 and 1), sample a velocity and take one explicit
# step of length sampling_dt.
sampled_path = []
x = np.array([75.0, 90.0])  # left bottom
sampling_dt = 0.2  # increases sampling frequency
for t in range(500):
    sampled_path.append(x)
    cgmm = gmm.condition([0, 1], x)
    # default alpha defines the confidence region (e.g., 0.7 -> 70 %)
    x_dot = cgmm.sample_confidence_region(1, alpha=0.7)[0]
    x = x + sampling_dt * x_dot
sampled_path = np.array(sampled_path)

# Sampled path on top of the (faded) subsampled input points.
plt.plot(sampled_path[:, 0], sampled_path[:, 1])
plt.plot(X[:, 0], X[:, 1], alpha=0.2)
plt.show()
Beispiel #9
0
def test_condition_numerical_issue():
    """Test for numerical issue in #27.

    A two-component, 14-dimensional GMM is built whose covariances contain
    all-zero rows/columns and, for the second component, entries many orders
    of magnitude smaller than those of the first. After oracle approximating
    shrinkage the model is conditioned on its first 13 dimensions; the
    conditional priors must be finite and the eigenvalues of both conditional
    covariances must be non-negative.
    """
    covariances = np.array(
        [[[
            6.56478114e-03, 0.00000000e+00, 4.35794725e-01, 0.00000000e+00,
            1.74768907e-03, -2.36645017e-02, 5.64492049e-01, -2.35229252e-02,
            -1.87923556e-02, -9.39617778e-02, 2.44300622e-02, -6.63130035e-01,
            3.85327366e-01, -4.16132516e-01
        ],
          [
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00
          ],
          [
              4.35794725e-01, 0.00000000e+00, 1.33846496e+02, 0.00000000e+00,
              5.36771200e-01, -5.15467320e+00, 7.84028133e+01, -4.98159335e+00,
              -5.77173333e+00, -2.88586667e+01, 7.50325333e+00,
              -9.75587840e+01, 6.22656653e+01, -9.68002133e+01
          ],
          [
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00
          ],
          [
              1.74768907e-03, 0.00000000e+00, 5.36771200e-01, 0.00000000e+00,
              2.15264000e-03, -2.06720400e-02, 3.14422667e-01, -1.99779293e-02,
              -2.31466667e-02, -1.15733333e-01, 3.00906667e-02,
              -3.91244800e-01, 2.49707067e-01, -3.88202667e-01
          ],
          [
              -2.36645017e-02, 0.00000000e+00, -5.15467320e+00, 0.00000000e+00,
              -2.06720400e-02, 5.04383573e-01, -4.56174933e+00, 1.93516556e-01,
              2.22280000e-01, 1.11140000e+00, -2.88964000e-01, 4.54903147e+00,
              -3.73500573e+00, 6.17985600e+00
          ],
          [
              5.64492049e-01, 0.00000000e+00, 7.84028133e+01, 0.00000000e+00,
              3.14422667e-01, -4.56174933e+00, 2.13116622e+02, -3.97556151e+00,
              -3.38088889e+00, -1.69044444e+01, 4.39515556e+00,
              -8.21465200e+01, 5.40752489e+01, -9.20326222e+01
          ],
          [
              -2.35229252e-02, 0.00000000e+00, -4.98159335e+00, 0.00000000e+00,
              -1.99779293e-02, 1.93516556e-01, -3.97556151e+00, 2.23553458e-01,
              2.14816444e-01, 1.07408222e+00, -2.79261378e-01, 3.84379952e+00,
              -2.48378654e+00, 3.73002618e+00
          ],
          [
              -1.87923556e-02, 0.00000000e+00, -5.77173333e+00, 0.00000000e+00,
              -2.31466667e-02, 2.22280000e-01, -3.38088889e+00, 2.14816444e-01,
              2.48888889e-01, 1.24444444e+00, -3.23555556e-01, 4.20693333e+00,
              -2.68502222e+00, 4.17422222e+00
          ],
          [
              -9.39617778e-02, 0.00000000e+00, -2.88586667e+01, 0.00000000e+00,
              -1.15733333e-01, 1.11140000e+00, -1.69044444e+01, 1.07408222e+00,
              1.24444444e+00, 6.22222222e+00, -1.61777778e+00, 2.10346667e+01,
              -1.34251111e+01, 2.08711111e+01
          ],
          [
              2.44300622e-02, 0.00000000e+00, 7.50325333e+00, 0.00000000e+00,
              3.00906667e-02, -2.88964000e-01, 4.39515556e+00, -2.79261378e-01,
              -3.23555556e-01, -1.61777778e+00, 4.20622222e-01,
              -5.46901333e+00, 3.49052889e+00, -5.42648889e+00
          ],
          [
              -6.63130035e-01, 0.00000000e+00, -9.75587840e+01, 0.00000000e+00,
              -3.91244800e-01, 4.54903147e+00, -8.21465200e+01, 3.84379952e+00,
              4.20693333e+00, 2.10346667e+01, -5.46901333e+00, 1.15347429e+02,
              -6.85498213e+01, 8.26589867e+01
          ],
          [
              3.85327366e-01, 0.00000000e+00, 6.22656653e+01, 0.00000000e+00,
              2.49707067e-01, -3.73500573e+00, 5.40752489e+01, -2.48378654e+00,
              -2.68502222e+00, -1.34251111e+01, 3.49052889e+00,
              -6.85498213e+01, 4.73001529e+01, -5.91605822e+01
          ],
          [
              -4.16132516e-01, 0.00000000e+00, -9.68002133e+01, 0.00000000e+00,
              -3.88202667e-01, 6.17985600e+00, -9.20326222e+01, 3.73002618e+00,
              4.17422222e+00, 2.08711111e+01, -5.42648889e+00, 8.26589867e+01,
              -5.91605822e+01, 9.61286222e+01
          ]],
         [[
             4.43734259e-31, 0.00000000e+00, 1.53827877e-29, 0.00000000e+00,
             7.39557099e-31, 8.28303950e-30, 1.41994963e-29, 5.32481111e-30,
             5.91645679e-30, 4.16518558e-28, 2.83989926e-29, 5.30114528e-28,
             4.25984889e-29, 3.54987407e-29
         ],
          [
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00
          ],
          [
              1.53827877e-29, 0.00000000e+00, 5.33269972e-28, 0.00000000e+00,
              2.56379794e-29, 2.87145370e-28, 4.92249205e-28, 1.84593452e-28,
              2.05103835e-28, 1.44393100e-26, 9.84498410e-28, 1.83773036e-26,
              1.47674761e-27, 1.23062301e-27
          ],
          [
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
              0.00000000e+00, 0.00000000e+00
          ],
          [
              7.39557099e-31, 0.00000000e+00, 2.56379794e-29, 0.00000000e+00,
              1.23259516e-30, 1.38050658e-29, 2.36658272e-29, 8.87468518e-30,
              9.86076132e-30, 6.94197597e-28, 4.73316543e-29, 8.83524214e-28,
              7.09974815e-29, 5.91645679e-29
          ],
          [
              8.28303950e-30, 0.00000000e+00, 2.87145370e-28, 0.00000000e+00,
              1.38050658e-29, 1.54616737e-28, 2.65057264e-28, 9.93964741e-29,
              1.10440527e-28, 7.77501308e-27, 5.30114528e-28, 9.89547120e-27,
              7.95171792e-28, 6.62643160e-28
          ],
          [
              1.41994963e-29, 0.00000000e+00, 4.92249205e-28, 0.00000000e+00,
              2.36658272e-29, 2.65057264e-28, 4.54383881e-28, 1.70393956e-28,
              1.89326617e-28, 1.33285939e-26, 9.08767763e-28, 1.69636649e-26,
              1.36315164e-27, 1.13595970e-27
          ],
          [
              5.32481111e-30, 0.00000000e+00, 1.84593452e-28, 0.00000000e+00,
              8.87468518e-30, 9.93964741e-29, 1.70393956e-28, 6.38977333e-29,
              7.09974815e-29, 4.99822270e-27, 3.40787911e-28, 6.36137434e-27,
              5.11181867e-28, 4.25984889e-28
          ],
          [
              5.91645679e-30, 0.00000000e+00, 2.05103835e-28, 0.00000000e+00,
              9.86076132e-30, 1.10440527e-28, 1.89326617e-28, 7.09974815e-29,
              7.88860905e-29, 5.55358077e-27, 3.78653235e-28, 7.06819371e-27,
              5.67979852e-28, 4.73316543e-28
          ],
          [
              4.16518558e-28, 0.00000000e+00, 1.44393100e-26, 0.00000000e+00,
              6.94197597e-28, 7.77501308e-27, 1.33285939e-26, 4.99822270e-27,
              5.55358077e-27, 3.90972086e-25, 2.66571877e-26, 4.97600837e-25,
              3.99857816e-26, 3.33214846e-26
          ],
          [
              2.83989926e-29, 0.00000000e+00, 9.84498410e-28, 0.00000000e+00,
              4.73316543e-29, 5.30114528e-28, 9.08767763e-28, 3.40787911e-28,
              3.78653235e-28, 2.66571877e-26, 1.81753553e-27, 3.39273298e-26,
              2.72630329e-27, 2.27191941e-27
          ],
          [
              5.30114528e-28, 0.00000000e+00, 1.83773036e-26, 0.00000000e+00,
              8.83524214e-28, 9.89547120e-27, 1.69636649e-26, 6.36137434e-27,
              7.06819371e-27, 4.97600837e-25, 3.39273298e-26, 6.33310156e-25,
              5.08909947e-26, 4.24091623e-26
          ],
          [
              4.25984889e-29, 0.00000000e+00, 1.47674761e-27, 0.00000000e+00,
              7.09974815e-29, 7.95171792e-28, 1.36315164e-27, 5.11181867e-28,
              5.67979852e-28, 3.99857816e-26, 2.72630329e-27, 5.08909947e-26,
              4.08945493e-27, 3.40787911e-27
          ],
          [
              3.54987407e-29, 0.00000000e+00, 1.23062301e-27, 0.00000000e+00,
              5.91645679e-29, 6.62643160e-28, 1.13595970e-27, 4.25984889e-28,
              4.73316543e-28, 3.33214846e-26, 2.27191941e-27, 4.24091623e-26,
              3.40787911e-27, 2.83989926e-27
          ]]])
    gmm = GMM(
        n_components=2,
        priors=np.array([0.9375, 0.0625]),
        means=np.array([[
            1.08140667e-01, 0.00000000e+00, 1.32820000e+01, 0.00000000e+00,
            5.31400000e-01, 6.35760000e+00, 8.18266667e+01, 2.49500667e+00,
            2.53333333e+00, 1.90666667e+02, 1.84066667e+01, 3.86002000e+02,
            1.30506667e+01, 2.83733333e+01
        ],
                        [
                            2.89550000e-01, 0.00000000e+00, 1.05900000e+01,
                            0.00000000e+00, 4.89000000e-01, 5.41200000e+00,
                            9.80000000e+00, 3.58750000e+00, 4.00000000e+00,
                            2.77000000e+02, 1.86000000e+01, 3.48930000e+02,
                            2.95500000e+01, 2.37000000e+01
                        ]]),
        covariances=covariances)
    # 13 observed values: conditioning on indices 0..12 leaves only the last
    # of the 14 dimensions free.
    x = np.array([
        1.0959e-01, 0.0000e+00, 1.1930e+01, 0.0000e+00, 5.7300e-01, 6.7940e+00,
        8.9300e+01, 2.3889e+00, 1.0000e+00, 2.7300e+02, 2.1000e+01, 3.9345e+02,
        6.4800e+00
    ])
    gmm.apply_oracle_approximating_shrinkage(n_samples_eff=506)
    cond_gmm = gmm.condition(np.arange(len(x)), x)
    assert_true(all(np.isfinite(cond_gmm.priors)))
    # The comparison has to be applied to the eigenvalues themselves; taking
    # the eigenvalues of the boolean matrix `cov >= 0` checks nothing useful.
    assert_true(all(np.linalg.eigvals(cond_gmm.covariances[0]) >= 0))
    assert_true(all(np.linalg.eigvals(cond_gmm.covariances[1]) >= 0))
Beispiel #10
0
# Two-panel figure: the data with the fitted GMM on the left, the density of
# y given a fixed x on the right.
# NOTE(review): `x`, `y`, `gmm` and `plot_error_ellipses` are defined outside
# this snippet.
plt.figure(figsize=(10, 5))

ax = plt.subplot(121)
ax.set_title("Dataset and GMM")
ax.scatter(x, y, s=1)
colors = ["r", "g", "b", "orange"]
plot_error_ellipses(ax, gmm, colors=colors)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax = plt.subplot(122)
ax.set_title("Conditional Distribution")
# 1000 evenly spaced query values in [0, 1], reshaped into a column so that
# each row is one query point.
Y = np.linspace(0, 1, 1000)
Y_test = Y[:, np.newaxis]
X_test = 0.5
# Condition the mixture on dimension 0 taking the value X_test.
conditional_gmm = gmm.condition([0], [X_test])
p_of_Y = conditional_gmm.to_probability_density(Y_test)
ax.plot(Y, p_of_Y, color="k", label="GMR", lw=3)
# Overlay the density of every component, scaled by its conditional prior.
# `colors` has four entries, so this assumes at most four components.
for component_idx in range(conditional_gmm.n_components):
    p_of_Y = (conditional_gmm.priors[component_idx] *
              conditional_gmm.extract_mvn(
                  component_idx).to_probability_density(Y_test))
    ax.plot(Y,
            p_of_Y,
            color=colors[component_idx],
            label="Component %d" % (component_idx + 1))
ax.set_xlabel("y")
ax.set_ylabel("$p(y|x=%.1f)$" % X_test)
ax.legend(loc="best")

plt.tight_layout()
Beispiel #11
0
# Demo: draw three 2-D Gaussian clusters, fit a three-component GMM to them
# and condition it on the first dimension.
random_state = check_random_state(0)

n_samples = 300
n_features = 2
# X is allocated uninitialised; the three slice assignments below fill all
# of its rows with n_samples // 3 draws per cluster.
X = np.ndarray((n_samples, n_features))
X[:n_samples // 3, :] = random_state.multivariate_normal(
    [0.0, 1.0], [[0.5, -1.0], [-1.0, 5.0]], size=(n_samples // 3,))
X[n_samples // 3:-n_samples // 3, :] = random_state.multivariate_normal(
    [-2.0, -2.0], [[3.0, 1.0], [1.0, 1.0]], size=(n_samples // 3,))
X[-n_samples // 3:, :] = random_state.multivariate_normal(
    [3.0, 1.0], [[3.0, -1.0], [-1.0, 1.0]], size=(n_samples // 3,))

gmm = GMM(n_components=3, random_state=random_state)
gmm.from_samples(X)
# Mixture obtained by conditioning on dimension 0 taking the value 1.0.
cond = gmm.condition(np.array([0]), np.array([1.0]))

plt.figure(figsize=(15, 5))

# Panel 1: the samples together with the error ellipses of the fitted GMM.
plt.subplot(1, 3, 1)
plt.title("Gaussian Mixture Model")
plt.xlim((-10, 10))
plt.ylim((-10, 10))
plot_error_ellipses(plt.gca(), gmm, colors=["r", "g", "b"])
plt.scatter(X[:, 0], X[:, 1])

# Panel 2: set up axes and a 100 x 100 evaluation grid over [-10, 10]^2.
# NOTE(review): the snippet ends here; `x`, `y` and `cond` are not used in
# the visible code.
plt.subplot(1, 3, 2)
plt.title("Probability Density and Samples")
plt.xlim((-10, 10))
plt.ylim((-10, 10))
x, y = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))
Beispiel #12
0
if __name__ == "__main__":
    # Demo: draw three 2-D Gaussian clusters, fit a three-component GMM to
    # them and condition it on the first dimension.
    random_state = check_random_state(0)

    n_samples = 300
    n_features = 2
    # Floor division keeps slice bounds and `size` integral; true division
    # yields floats, which Python 3 rejects as slice indices and NumPy
    # rejects as a sample count.
    X = np.ndarray((n_samples, n_features))
    # Covariance matrices must be positive semi-definite. The off-diagonal
    # entries of the first two are chosen so that their determinants are
    # positive (1.5 and 2.0).
    X[:n_samples // 3, :] = random_state.multivariate_normal(
        [0.0, 1.0], [[0.5, -1.0], [-1.0, 5.0]], size=(n_samples // 3,))
    X[n_samples // 3:-n_samples // 3, :] = random_state.multivariate_normal(
        [-2.0, -2.0], [[3.0, 1.0], [1.0, 1.0]], size=(n_samples // 3,))
    X[-n_samples // 3:, :] = random_state.multivariate_normal(
        [3.0, 1.0], [[3.0, -1.0], [-1.0, 1.0]], size=(n_samples // 3,))

    gmm = GMM(n_components=3, random_state=random_state)
    gmm.from_samples(X)
    # Mixture obtained by conditioning on dimension 0 taking the value 1.0.
    cond = gmm.condition(np.array([0]), np.array([1.0]))

    plt.figure(figsize=(15, 5))

    # Panel 1: the samples together with the error ellipses of the fitted GMM.
    plt.subplot(1, 3, 1)
    plt.title("Gaussian Mixture Model")
    plt.xlim((-10, 10))
    plt.ylim((-10, 10))
    plot_error_ellipses(plt.gca(), gmm, colors=["r", "g", "b"])
    plt.scatter(X[:, 0], X[:, 1])

    # Panel 2: set up axes and a 100 x 100 evaluation grid over [-10, 10]^2.
    plt.subplot(1, 3, 2)
    plt.title("Probability Density and Samples")
    plt.xlim((-10, 10))
    plt.ylim((-10, 10))
    x, y = np.meshgrid(np.linspace(-10, 10, 100), np.linspace(-10, 10, 100))