def __init__(self, n_meridians=40, n_circles_latitude=None, points=None):
    """Sample the surface of a unit sphere centred at the origin.

    The surface is parameterized by an azimuthal angle ``u`` sampled from
    ``[0, 2*pi)`` and a polar angle ``v`` sampled from ``[0, pi)`` (both via
    ``gs.arange``, so the end values are excluded). The resulting Cartesian
    coordinate grids are stored in ``sphere_x``, ``sphere_y`` and
    ``sphere_z``.

    Parameters
    ----------
    n_meridians : int, optional
        Controls the azimuthal step ``2 * pi / n_meridians``. Default 40.
    n_circles_latitude : int or float, optional
        Controls the polar step ``pi / n_circles_latitude``. When omitted it
        is set to ``max(n_meridians / 2, 4)`` (true division, so it may be a
        float).
    points : optional
        When given, forwarded to ``self.add_points`` once the grids exist.
    """
    if n_circles_latitude is None:
        latitude_count = max(n_meridians / 2, 4)
    else:
        latitude_count = n_circles_latitude

    azimuth_step = 2 * gs.pi / n_meridians
    polar_step = gs.pi / latitude_count
    azimuth_grid, polar_grid = gs.meshgrid(
        gs.arange(0, 2 * gs.pi, azimuth_step),
        gs.arange(0, gs.pi, polar_step),
    )

    self.center = gs.zeros(3)
    self.radius = 1

    # sin(v) appears in both the x and y components; compute it once.
    sin_polar = gs.sin(polar_grid)
    self.sphere_x = (
        self.center[0] + self.radius * gs.cos(azimuth_grid) * sin_polar
    )
    self.sphere_y = (
        self.center[1] + self.radius * gs.sin(azimuth_grid) * sin_polar
    )
    self.sphere_z = self.center[2] + self.radius * gs.cos(polar_grid)

    self.points = []
    if points is None:
        return
    self.add_points(points)
def plot_gaussian_mixture_distribution(
    data,
    mixture_coefficients,
    means,
    variances,
    plot_precision=DEFAULT_PLOT_PRECISION,
    save_path="",
    metric=None,
):
    """Plot Gaussian Mixture Model.

    Evaluates the mixture density on a regular ``plot_precision`` x
    ``plot_precision`` grid covering ``[-1, 1] x [-1, 1]`` and draws it as a
    3D surface. The unit circle, the data points and the component means are
    drawn on a horizontal plane at ``z = -0.8`` below the surface.

    Parameters
    ----------
    data : iterable of points
        Samples; only the first two coordinates of each point are plotted.
    mixture_coefficients, means, variances
        Mixture parameters forwarded to ``weighted_gmm_pdf``. Only the first
        two coordinates of each mean are plotted.
    plot_precision : int
        Number of grid samples along each axis.
    save_path : str or path-like, optional
        When non-empty, the figure is written there in PDF format. The
        default ``""`` skips saving.
    metric : optional
        Forwarded unchanged to ``weighted_gmm_pdf``.

    Returns
    -------
    module
        The ``matplotlib.pyplot`` module.
    """
    x_axis_samples = gs.linspace(-1, 1, plot_precision)
    y_axis_samples = gs.linspace(-1, 1, plot_precision)
    x_axis_samples, y_axis_samples = gs.meshgrid(x_axis_samples, y_axis_samples)

    z_axis_samples = gs.zeros((plot_precision, plot_precision))
    # Evaluate the density one grid row at a time: each row is turned into a
    # (plot_precision, 2) batch of (x, y) points for weighted_gmm_pdf.
    for row_index in range(plot_precision):
        x_y_plane_mesh = gs.concatenate(
            (
                gs.expand_dims(x_axis_samples[row_index], -1),
                gs.expand_dims(y_axis_samples[row_index], -1),
            ),
            axis=-1,
        )
        mesh_probabilities = weighted_gmm_pdf(
            mixture_coefficients, x_y_plane_mesh, means, variances, metric
        )
        # Presumably the last axis indexes mixture components (TODO confirm
        # against weighted_gmm_pdf); summing it gives the mixture density.
        z_axis_samples[row_index] = mesh_probabilities.sum(-1)

    fig = plt.figure(
        "Learned Gaussian Mixture Model "
        "via Expectation Maximization on Poincaré Disc"
    )
    # Figure.gca() no longer accepts ``projection`` (deprecated in Matplotlib
    # 3.4, removed in 3.6); add_subplot is the supported way to get 3D axes.
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(
        x_axis_samples,
        y_axis_samples,
        z_axis_samples,
        rstride=1,
        cstride=1,
        linewidth=1,
        antialiased=True,
        cmap=plt.get_cmap("viridis"),
    )

    # Draw the unit circle (disc boundary) on a plane below the surface.
    z_circle = -0.8
    p = Circle((0, 0), 1, edgecolor="b", lw=1, facecolor="none")
    ax.add_patch(p)
    art3d.pathpatch_2d_to_3d(p, z=z_circle, zdir="z")

    for point in data:
        ax.scatter(point[0], point[1], z_circle, c="b", marker=".")
    for mean in means:
        ax.scatter(mean[0], mean[1], z_circle, c="r", marker="D")

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-1.2, 1.2)
    ax.set_zlim(-0.8, 0.4)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("P")

    # An empty path cannot be opened for writing, so the default "" means
    # "do not save" rather than crashing inside savefig.
    if save_path:
        fig.savefig(save_path, format="pdf")
    return plt