def test_conditional_distribution():
    """Check means and covariances of a conditioned two-component GMM.

    The mixture is built from the module-level ``means`` and ``covariances``
    with equal priors. It is conditioned once on index 1 and once on index 0;
    for each case the mean and covariance of one component are compared with
    the expected values.
    """
    mixture = GMM(
        n_components=2,
        priors=np.array([0.5, 0.5]),
        means=means,
        covariances=covariances,
        random_state=check_random_state(0),
    )

    # (conditioned index, observed value, component to inspect, expected mean)
    cases = (
        (1, 1.0, 0, 0.0),
        (0, 2.0, 1, -1.0),
    )
    for given_index, given_value, component, expected_mean in cases:
        conditioned = mixture.condition(
            np.array([given_index]), np.array([given_value]))
        assert_array_almost_equal(
            conditioned.means[component], np.array([expected_mean]))
        # Both cases expect the same 1x1 conditional covariance.
        assert_array_almost_equal(
            conditioned.covariances[component], np.array([[0.3]]))
def choose(node_info: MixtureGaussianParams,
           pvals: List[Union[str, float]]) -> Optional[float]:
    """
    Draw one value for a MixtureGaussian node.

    node_info: node parameters; the keys "mean", "covars" and "coef" are read
    pvals: parent values; when non-empty the mixture is conditioned on them
        using indices 1..len(pvals)

    Returns the first coordinate of a single sample, or NaN when the mixture
    has no components or every parent value is NaN.
    """
    means = node_info["mean"]
    covars = node_info["covars"]
    weights = node_info["coef"]
    n_components = len(weights)

    # A node without mixture components cannot be sampled.
    if n_components == 0:
        return np.nan

    # No parents: sample from the unconditioned mixture.
    if not pvals:
        mixture = GMM(n_components=n_components, priors=weights,
                      means=means, covariances=covars)
        return mixture.sample(1)[0][0]

    parent_indices = list(range(1, len(pvals) + 1))

    # NOTE(review): only the all-NaN case is short-circuited; a partially-NaN
    # parent vector is still passed to condition() - confirm this is intended.
    if np.isnan(np.array(pvals)).all():
        return np.nan

    mixture = GMM(n_components=n_components, priors=weights,
                  means=means, covariances=covars)
    conditioned = mixture.condition(parent_indices, [pvals])
    return conditioned.sample(1)[0][0]
def test_from_samples_with_oas():
    """Fit a GMM with OAS on nine samples and check conditional covariances.

    Three groups of three samples are drawn from different 2D normal
    distributions using ``random_state``, which is not defined in this
    function (presumably a module-level object; verify). A three-component
    GMM is fitted with oracle approximating shrinkage, conditioned on
    index 0, and every conditional covariance must have non-negative
    eigenvalues.
    """
    n_samples = 9
    n_features = 2
    third = n_samples // 3

    # (row slice, mean, covariance) per group; draws happen in this order.
    groups = (
        (slice(None, third), [0.0, 1.0], [[0.5, -1.0], [-1.0, 5.0]]),
        (slice(third, -third), [-2.0, -2.0], [[3.0, 1.0], [1.0, 1.0]]),
        (slice(-third, None), [3.0, 3.0], [[3.0, -1.0], [-1.0, 1.0]]),
    )
    X = np.ndarray((n_samples, n_features))
    for rows, group_mean, group_cov in groups:
        X[rows, :] = random_state.multivariate_normal(
            group_mean, group_cov, size=(third, ))

    gmm = GMM(n_components=3, random_state=random_state)
    gmm.from_samples(X, init_params="kmeans++",
                     oracle_approximating_shrinkage=True)
    cond = gmm.condition(np.array([0]), np.array([1.0]))

    for i in range(cond.n_components):
        assert_true(all(np.linalg.eigvals(cond.covariances[i]) >= 0))
def choose(node_info: Dict[str, Dict[str, CondMixtureGaussParams]],
           pvals: List[Union[str, float]]) -> Optional[float]:
    """
    Draw one value for a ConditionalMixtureGaussian node.

    params:
    node_info: node parameters; node_info["hybcprob"] is indexed by the
        string form of the list of str/int parent values
    pvals: parent values; str and int entries select the distribution, the
        remaining entries are used to condition the mixture on indices
        1..len(remaining)

    Returns the first coordinate of a single sample, or NaN when the selected
    mixture has no components or every remaining parent value is NaN.
    """
    discrete_values = []
    continuous_values = []
    for value in pvals:
        bucket = (discrete_values if isinstance(value, (str, int))
                  else continuous_values)
        bucket.append(value)

    distribution = node_info["hybcprob"][str(discrete_values)]
    means = distribution["mean"]
    covars = distribution["covars"]
    weights = distribution["coef"]

    # A combination without mixture components cannot be sampled.
    if len(weights) == 0:
        return np.nan

    # No remaining parents: sample from the unconditioned mixture.
    if len(continuous_values) == 0:
        mixture = GMM(n_components=len(weights), priors=weights,
                      means=means, covariances=covars)
        return mixture.sample(1)[0][0]

    parent_indices = list(range(1, len(continuous_values) + 1))

    # NOTE(review): only the all-NaN case is short-circuited; a partially-NaN
    # parent vector is still passed to condition() - confirm this is intended.
    if np.isnan(np.array(continuous_values)).all():
        return np.nan

    mixture = GMM(n_components=len(weights), priors=weights,
                  means=means, covariances=covars)
    conditioned = mixture.condition(parent_indices, [continuous_values])
    return conditioned.sample(1)[0][0]
# Build the GMM from the parameters of the already fitted ``bgmm``.
gmm = GMM(
    n_components=n_components,
    priors=bgmm.weights_,
    means=bgmm.means_,
    covariances=bgmm.covariances_,
    random_state=random_state,
)

plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.title("Confidence Interval from GMM")
plt.plot(X[:, :, 0].T, X[:, :, 1].T, c="k", alpha=0.1)

# For each value in t: condition on index 0, collapse the conditional mixture
# to a single MVN, record its mean and the standard deviation of index 1,
# and scatter 100 samples drawn from the conditional mixture.
means_over_time = []
y_stds = []
for step in t:
    conditional_gmm = gmm.condition([0], np.array([step]))
    conditional_mvn = conditional_gmm.to_mvn()
    means_over_time.append(conditional_mvn.mean)
    y_stds.append(np.sqrt(conditional_mvn.covariance[1, 1]))
    samples = conditional_gmm.sample(100)
    plt.scatter(samples[:, 0], samples[:, 1], s=1)

means_over_time = np.array(means_over_time)
y_stds = np.array(y_stds)

# Mean curve with a band of +/- 1.96 standard deviations around it.
mean_first = means_over_time[:, 0]
mean_second = means_over_time[:, 1]
band_half_width = 1.96 * y_stds
plt.plot(mean_first, mean_second, c="r", lw=2)
plt.fill_between(
    mean_first,
    mean_second - band_half_width,
    mean_second + band_half_width,
    color="r", alpha=0.5)
# Scratch (1, 2) array; it is overwritten on every iteration with the current
# (X_point, Y_point) pair before being passed to gmm2.predict.
X2 = np.ndarray((1, 2))
for i in range(len(X_test)):
    X_point = X_test[i]
    Y_point = Y[i]
    X2[0, 0] = X_point
    X2[0, 1] = Y_point
    # Condition the same mixture once on x_axes and once on y_axes, then pick
    # the component whose summed conditional priors are largest.
    conditioned = gmm.condition(x_axes, X_point)
    conditioned_y = gmm.condition(y_axes, Y_point)
    total_priors = conditioned.priors + conditioned_y.priors
    total_max = np.argmax(total_priors)
    # Label predicted by the second model; labels 0 and 1 are swapped below.
    # NOTE(review): presumably aligns gmm2's component order with gmm's -
    # confirm. total_max2 is not used in the visible part of this loop.
    total_max2 = gmm2.predict(X2)[0]
    if (total_max2 == 0):
        total_max2 = 1
    elif (total_max2 == 1):
        total_max2 = 0
    # Mark the conditional mean of the selected component at X_point.
    plt.scatter(X_point, conditioned.means[total_max], s=200, marker=">", color="k")
# Keep every 50th point, take a two-sided finite difference divided by dt as
# X_dot, and trim X to the rows for which X_dot is defined.
X = points[::50]
X_dot = (X[2:] - X[:-2]) / dt
X = X[1:-1]
X_train = np.hstack((X, X_dot))

random_state = np.random.RandomState(0)
n_components = 15

# Fit a Bayesian GMM on the stacked (X, X_dot) rows and copy its parameters
# into a GMM object that supports conditioning.
bgmm = BayesianGaussianMixture(
    n_components=n_components,
    max_iter=500,
    random_state=random_state,
).fit(X_train)
gmm = GMM(
    n_components=n_components,
    priors=bgmm.weights_,
    means=bgmm.means_,
    covariances=bgmm.covariances_,
    random_state=random_state,
)

n_steps = 500
confidence_alpha = 0.7
sampled_path = []
x = np.array([75.0, 90.0])  # left bottom
sampling_dt = 0.2  # increases sampling frequency
for t in range(n_steps):
    sampled_path.append(x)
    # Condition on the first two dimensions at the current x.
    cgmm = gmm.condition([0, 1], x)
    # default alpha defines the confidence region (e.g., 0.7 -> 70 %)
    x_dot = cgmm.sample_confidence_region(1, alpha=confidence_alpha)[0]
    # Step along the sampled x_dot; this rebinds x to a new array.
    x = x + sampling_dt * x_dot
sampled_path = np.array(sampled_path)

plt.plot(sampled_path[:, 0], sampled_path[:, 1])
plt.plot(X[:, 0], X[:, 1], alpha=0.2)
plt.show()
def test_condition_numerical_issue():
    """Test for numerical issue in #27.

    A two-component, 14-dimensional GMM with hard-coded parameters (two
    all-zero rows/columns per covariance, and a second covariance with
    entries around 1e-25 and below) gets oracle approximating shrinkage
    applied and is then conditioned on its first 13 dimensions. The
    conditional priors must be finite and the conditional covariances must
    have non-negative eigenvalues.
    """
    covariances = np.array([
        [
            [6.56478114e-03, 0.00000000e+00, 4.35794725e-01, 0.00000000e+00,
             1.74768907e-03, -2.36645017e-02, 5.64492049e-01, -2.35229252e-02,
             -1.87923556e-02, -9.39617778e-02, 2.44300622e-02, -6.63130035e-01,
             3.85327366e-01, -4.16132516e-01],
            [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00],
            [4.35794725e-01, 0.00000000e+00, 1.33846496e+02, 0.00000000e+00,
             5.36771200e-01, -5.15467320e+00, 7.84028133e+01, -4.98159335e+00,
             -5.77173333e+00, -2.88586667e+01, 7.50325333e+00, -9.75587840e+01,
             6.22656653e+01, -9.68002133e+01],
            [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00],
            [1.74768907e-03, 0.00000000e+00, 5.36771200e-01, 0.00000000e+00,
             2.15264000e-03, -2.06720400e-02, 3.14422667e-01, -1.99779293e-02,
             -2.31466667e-02, -1.15733333e-01, 3.00906667e-02, -3.91244800e-01,
             2.49707067e-01, -3.88202667e-01],
            [-2.36645017e-02, 0.00000000e+00, -5.15467320e+00, 0.00000000e+00,
             -2.06720400e-02, 5.04383573e-01, -4.56174933e+00, 1.93516556e-01,
             2.22280000e-01, 1.11140000e+00, -2.88964000e-01, 4.54903147e+00,
             -3.73500573e+00, 6.17985600e+00],
            [5.64492049e-01, 0.00000000e+00, 7.84028133e+01, 0.00000000e+00,
             3.14422667e-01, -4.56174933e+00, 2.13116622e+02, -3.97556151e+00,
             -3.38088889e+00, -1.69044444e+01, 4.39515556e+00, -8.21465200e+01,
             5.40752489e+01, -9.20326222e+01],
            [-2.35229252e-02, 0.00000000e+00, -4.98159335e+00, 0.00000000e+00,
             -1.99779293e-02, 1.93516556e-01, -3.97556151e+00, 2.23553458e-01,
             2.14816444e-01, 1.07408222e+00, -2.79261378e-01, 3.84379952e+00,
             -2.48378654e+00, 3.73002618e+00],
            [-1.87923556e-02, 0.00000000e+00, -5.77173333e+00, 0.00000000e+00,
             -2.31466667e-02, 2.22280000e-01, -3.38088889e+00, 2.14816444e-01,
             2.48888889e-01, 1.24444444e+00, -3.23555556e-01, 4.20693333e+00,
             -2.68502222e+00, 4.17422222e+00],
            [-9.39617778e-02, 0.00000000e+00, -2.88586667e+01, 0.00000000e+00,
             -1.15733333e-01, 1.11140000e+00, -1.69044444e+01, 1.07408222e+00,
             1.24444444e+00, 6.22222222e+00, -1.61777778e+00, 2.10346667e+01,
             -1.34251111e+01, 2.08711111e+01],
            [2.44300622e-02, 0.00000000e+00, 7.50325333e+00, 0.00000000e+00,
             3.00906667e-02, -2.88964000e-01, 4.39515556e+00, -2.79261378e-01,
             -3.23555556e-01, -1.61777778e+00, 4.20622222e-01, -5.46901333e+00,
             3.49052889e+00, -5.42648889e+00],
            [-6.63130035e-01, 0.00000000e+00, -9.75587840e+01, 0.00000000e+00,
             -3.91244800e-01, 4.54903147e+00, -8.21465200e+01, 3.84379952e+00,
             4.20693333e+00, 2.10346667e+01, -5.46901333e+00, 1.15347429e+02,
             -6.85498213e+01, 8.26589867e+01],
            [3.85327366e-01, 0.00000000e+00, 6.22656653e+01, 0.00000000e+00,
             2.49707067e-01, -3.73500573e+00, 5.40752489e+01, -2.48378654e+00,
             -2.68502222e+00, -1.34251111e+01, 3.49052889e+00, -6.85498213e+01,
             4.73001529e+01, -5.91605822e+01],
            [-4.16132516e-01, 0.00000000e+00, -9.68002133e+01, 0.00000000e+00,
             -3.88202667e-01, 6.17985600e+00, -9.20326222e+01, 3.73002618e+00,
             4.17422222e+00, 2.08711111e+01, -5.42648889e+00, 8.26589867e+01,
             -5.91605822e+01, 9.61286222e+01],
        ],
        [
            [4.43734259e-31, 0.00000000e+00, 1.53827877e-29, 0.00000000e+00,
             7.39557099e-31, 8.28303950e-30, 1.41994963e-29, 5.32481111e-30,
             5.91645679e-30, 4.16518558e-28, 2.83989926e-29, 5.30114528e-28,
             4.25984889e-29, 3.54987407e-29],
            [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00],
            [1.53827877e-29, 0.00000000e+00, 5.33269972e-28, 0.00000000e+00,
             2.56379794e-29, 2.87145370e-28, 4.92249205e-28, 1.84593452e-28,
             2.05103835e-28, 1.44393100e-26, 9.84498410e-28, 1.83773036e-26,
             1.47674761e-27, 1.23062301e-27],
            [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
             0.00000000e+00, 0.00000000e+00],
            [7.39557099e-31, 0.00000000e+00, 2.56379794e-29, 0.00000000e+00,
             1.23259516e-30, 1.38050658e-29, 2.36658272e-29, 8.87468518e-30,
             9.86076132e-30, 6.94197597e-28, 4.73316543e-29, 8.83524214e-28,
             7.09974815e-29, 5.91645679e-29],
            [8.28303950e-30, 0.00000000e+00, 2.87145370e-28, 0.00000000e+00,
             1.38050658e-29, 1.54616737e-28, 2.65057264e-28, 9.93964741e-29,
             1.10440527e-28, 7.77501308e-27, 5.30114528e-28, 9.89547120e-27,
             7.95171792e-28, 6.62643160e-28],
            [1.41994963e-29, 0.00000000e+00, 4.92249205e-28, 0.00000000e+00,
             2.36658272e-29, 2.65057264e-28, 4.54383881e-28, 1.70393956e-28,
             1.89326617e-28, 1.33285939e-26, 9.08767763e-28, 1.69636649e-26,
             1.36315164e-27, 1.13595970e-27],
            [5.32481111e-30, 0.00000000e+00, 1.84593452e-28, 0.00000000e+00,
             8.87468518e-30, 9.93964741e-29, 1.70393956e-28, 6.38977333e-29,
             7.09974815e-29, 4.99822270e-27, 3.40787911e-28, 6.36137434e-27,
             5.11181867e-28, 4.25984889e-28],
            [5.91645679e-30, 0.00000000e+00, 2.05103835e-28, 0.00000000e+00,
             9.86076132e-30, 1.10440527e-28, 1.89326617e-28, 7.09974815e-29,
             7.88860905e-29, 5.55358077e-27, 3.78653235e-28, 7.06819371e-27,
             5.67979852e-28, 4.73316543e-28],
            [4.16518558e-28, 0.00000000e+00, 1.44393100e-26, 0.00000000e+00,
             6.94197597e-28, 7.77501308e-27, 1.33285939e-26, 4.99822270e-27,
             5.55358077e-27, 3.90972086e-25, 2.66571877e-26, 4.97600837e-25,
             3.99857816e-26, 3.33214846e-26],
            [2.83989926e-29, 0.00000000e+00, 9.84498410e-28, 0.00000000e+00,
             4.73316543e-29, 5.30114528e-28, 9.08767763e-28, 3.40787911e-28,
             3.78653235e-28, 2.66571877e-26, 1.81753553e-27, 3.39273298e-26,
             2.72630329e-27, 2.27191941e-27],
            [5.30114528e-28, 0.00000000e+00, 1.83773036e-26, 0.00000000e+00,
             8.83524214e-28, 9.89547120e-27, 1.69636649e-26,
             6.36137434e-27, 7.06819371e-27, 4.97600837e-25, 3.39273298e-26,
             6.33310156e-25, 5.08909947e-26, 4.24091623e-26],
            [4.25984889e-29, 0.00000000e+00, 1.47674761e-27, 0.00000000e+00,
             7.09974815e-29, 7.95171792e-28, 1.36315164e-27, 5.11181867e-28,
             5.67979852e-28, 3.99857816e-26, 2.72630329e-27, 5.08909947e-26,
             4.08945493e-27, 3.40787911e-27],
            [3.54987407e-29, 0.00000000e+00, 1.23062301e-27, 0.00000000e+00,
             5.91645679e-29, 6.62643160e-28, 1.13595970e-27, 4.25984889e-28,
             4.73316543e-28, 3.33214846e-26, 2.27191941e-27, 4.24091623e-26,
             3.40787911e-27, 2.83989926e-27],
        ],
    ])
    gmm = GMM(
        n_components=2,
        priors=np.array([0.9375, 0.0625]),
        means=np.array([
            [1.08140667e-01, 0.00000000e+00, 1.32820000e+01, 0.00000000e+00,
             5.31400000e-01, 6.35760000e+00, 8.18266667e+01, 2.49500667e+00,
             2.53333333e+00, 1.90666667e+02, 1.84066667e+01, 3.86002000e+02,
             1.30506667e+01, 2.83733333e+01],
            [2.89550000e-01, 0.00000000e+00, 1.05900000e+01, 0.00000000e+00,
             4.89000000e-01, 5.41200000e+00, 9.80000000e+00, 3.58750000e+00,
             4.00000000e+00, 2.77000000e+02, 1.86000000e+01, 3.48930000e+02,
             2.95500000e+01, 2.37000000e+01],
        ]),
        covariances=covariances)
    # Observation of the first 13 of the 14 dimensions.
    x = np.array([
        1.0959e-01, 0.0000e+00, 1.1930e+01, 0.0000e+00, 5.7300e-01,
        6.7940e+00, 8.9300e+01, 2.3889e+00, 1.0000e+00, 2.7300e+02,
        2.1000e+01, 3.9345e+02, 6.4800e+00
    ])
    gmm.apply_oracle_approximating_shrinkage(n_samples_eff=506)
    cond_gmm = gmm.condition(np.arange(len(x)), x)
    assert_true(all(np.isfinite(cond_gmm.priors)))
    # The comparison must be applied to the eigenvalues. Previously the
    # closing parenthesis was misplaced, so eigvals() received the boolean
    # matrix ``covariance >= 0`` and the eigenvalues of the covariance itself
    # were never checked.
    assert_true(all(np.linalg.eigvals(cond_gmm.covariances[0]) >= 0))
    assert_true(all(np.linalg.eigvals(cond_gmm.covariances[1]) >= 0))
plt.figure(figsize=(10, 5))

# Left panel: scatter of the data with the error ellipses of the GMM.
ax = plt.subplot(121)
ax.set_title("Dataset and GMM")
ax.scatter(x, y, s=1)
colors = ["r", "g", "b", "orange"]
plot_error_ellipses(ax, gmm, colors=colors)
ax.set_xlabel("x")
ax.set_ylabel("y")

# Right panel: density of the mixture conditioned on index 0 at X_test,
# evaluated on 1000 evenly spaced points in [0, 1].
ax = plt.subplot(122)
ax.set_title("Conditional Distribution")

Y = np.linspace(0, 1, 1000)
Y_test = Y[:, np.newaxis]
X_test = 0.5
conditional_gmm = gmm.condition([0], [X_test])

p_of_Y = conditional_gmm.to_probability_density(Y_test)
ax.plot(Y, p_of_Y, color="k", label="GMR", lw=3)

# Each component's density scaled by its conditional prior.
for component_idx in range(conditional_gmm.n_components):
    component_prior = conditional_gmm.priors[component_idx]
    component_mvn = conditional_gmm.extract_mvn(component_idx)
    p_of_Y = component_prior * component_mvn.to_probability_density(Y_test)
    component_label = "Component %d" % (component_idx + 1)
    ax.plot(Y, p_of_Y, color=colors[component_idx], label=component_label)

ax.set_xlabel("y")
ax.set_ylabel("$p(y|x=%.1f)$" % X_test)
ax.legend(loc="best")
plt.tight_layout()
random_state = check_random_state(0)

n_samples = 300
n_features = 2
cluster_size = n_samples // 3

# Three equally sized groups, each drawn from its own 2D normal distribution.
# Entries are (row slice, mean, covariance); draws happen in this order.
cluster_specs = (
    (slice(None, cluster_size),
     [0.0, 1.0], [[0.5, -1.0], [-1.0, 5.0]]),
    (slice(cluster_size, -cluster_size),
     [-2.0, -2.0], [[3.0, 1.0], [1.0, 1.0]]),
    (slice(-cluster_size, None),
     [3.0, 1.0], [[3.0, -1.0], [-1.0, 1.0]]),
)
X = np.ndarray((n_samples, n_features))
for cluster_rows, cluster_mean, cluster_cov in cluster_specs:
    X[cluster_rows, :] = random_state.multivariate_normal(
        cluster_mean, cluster_cov, size=(cluster_size,))

# Fit a three-component GMM and condition it on index 0 at the value 1.0.
gmm = GMM(n_components=3, random_state=random_state)
gmm.from_samples(X)
cond = gmm.condition(np.array([0]), np.array([1.0]))

axis_limits = (-10, 10)
plt.figure(figsize=(15, 5))

# First panel: samples and error ellipses of the fitted mixture.
plt.subplot(1, 3, 1)
plt.title("Gaussian Mixture Model")
plt.xlim(axis_limits)
plt.ylim(axis_limits)
plot_error_ellipses(plt.gca(), gmm, colors=["r", "g", "b"])
plt.scatter(X[:, 0], X[:, 1])

# Second panel: set up the axes and a 100 x 100 evaluation grid.
plt.subplot(1, 3, 2)
plt.title("Probability Density and Samples")
plt.xlim(axis_limits)
plt.ylim(axis_limits)
x, y = np.meshgrid(np.linspace(-10, 10, 100),
                   np.linspace(-10, 10, 100))
if __name__ == "__main__":
    random_state = check_random_state(0)

    n_samples = 300
    n_features = 2
    X = np.ndarray((n_samples, n_features))
    # Floor division keeps the slice bounds and sample counts integers; with
    # true division (`/`) Python 3 yields floats, which are rejected both as
    # slice indices and as `size`.
    # The first two covariance matrices previously had off-diagonal entries
    # (-2.0 and 2.0) that made them indefinite (negative determinant); they
    # are reduced so every matrix is a valid covariance.
    X[:n_samples // 3, :] = random_state.multivariate_normal(
        [0.0, 1.0], [[0.5, -1.0], [-1.0, 5.0]], size=(n_samples // 3,))
    X[n_samples // 3:-n_samples // 3, :] = random_state.multivariate_normal(
        [-2.0, -2.0], [[3.0, 1.0], [1.0, 1.0]], size=(n_samples // 3,))
    X[-n_samples // 3:, :] = random_state.multivariate_normal(
        [3.0, 1.0], [[3.0, -1.0], [-1.0, 1.0]], size=(n_samples // 3,))

    # Fit a three-component GMM and condition it on index 0 at the value 1.0.
    gmm = GMM(n_components=3, random_state=random_state)
    gmm.from_samples(X)
    cond = gmm.condition(np.array([0]), np.array([1.0]))

    plt.figure(figsize=(15, 5))

    # First panel: samples and error ellipses of the fitted mixture.
    plt.subplot(1, 3, 1)
    plt.title("Gaussian Mixture Model")
    plt.xlim((-10, 10))
    plt.ylim((-10, 10))
    plot_error_ellipses(plt.gca(), gmm, colors=["r", "g", "b"])
    plt.scatter(X[:, 0], X[:, 1])

    # Second panel: set up the axes and a 100 x 100 evaluation grid.
    plt.subplot(1, 3, 2)
    plt.title("Probability Density and Samples")
    plt.xlim((-10, 10))
    plt.ylim((-10, 10))
    x, y = np.meshgrid(np.linspace(-10, 10, 100),
                       np.linspace(-10, 10, 100))