Esempio n. 1
0
class TestPrepareGraphData(geomstats.tests.TestCase):
    """Tests for ``HyperbolicEmbedding`` on the karate club graph.

    Every test runs against the fixtures built in ``setup_method``:
    ``self.karate_graph`` and ``self.embedding``.
    """

    def setup_method(self):
        """Seed the random generator and build the shared fixtures.

        Sets ``self.karate_graph`` to the result of ``load_karate_graph()``
        and ``self.embedding`` to a small ``HyperbolicEmbedding``
        (dim 2, 3 epochs) so the tests stay cheap.
        """
        gs.random.seed(1234)
        self.karate_graph = load_karate_graph()
        self.embedding = HyperbolicEmbedding(
            dim=2,
            max_epochs=3,
            lr=0.05,
            n_context=1,
            n_negative=2,
        )

    def test_log_sigmoid(self):
        """Check log_sigmoid against hard-coded reference values."""
        point = gs.array([0.1, 0.3])
        result = self.embedding.log_sigmoid(point)

        expected = gs.array([-0.644397, -0.554355])
        self.assertAllClose(result, expected)

    def test_grad_log_sigmoid(self):
        """Check grad_log_sigmoid against hard-coded reference values."""
        point = gs.array([0.1, 0.3])
        result = self.embedding.grad_log_sigmoid(point)

        expected = gs.array([0.47502081, 0.42555748])
        self.assertAllClose(result, expected)

    def test_loss(self):
        """Check the loss value and its gradient on one fixed triple.

        ``loss`` is called with a point, a context point and a negative
        point; it returns a ``(value, gradient)`` pair that is compared
        with hard-coded reference values.
        """
        point = gs.array([0.5, 0.5])
        point_context = gs.array([0.6, 0.6])
        point_negative = gs.array([-0.4, -0.4])

        loss_value, loss_grad = self.embedding.loss(
            point, point_context, point_negative
        )

        expected_loss = 1.00322045
        expected_grad = gs.array([-0.16565083, -0.16565083])

        # Only the first entry of the loss is compared; the gradient is
        # squeezed before comparison with the flat reference array.
        self.assertAllClose(loss_value[0], expected_loss)
        self.assertAllClose(gs.squeeze(loss_grad), expected_grad)

    def test_embed(self):
        """Check that every embedded point belongs to the manifold."""
        embeddings = self.embedding.embed(self.karate_graph)
        self.assertTrue(gs.all(self.embedding.manifold.belongs(embeddings)))

    def setUp(self):
        """Build the same fixtures as ``setup_method``.

        Kept as a separate entry point so the class works whichever of
        the two hook names the test runner calls; delegating avoids two
        copies of the fixture code drifting apart.
        """
        self.setup_method()
def main():
    """Learn and cluster a Poincaré ball embedding of the karate graph.

    The graph returned by ``load_karate_graph`` is embedded with a
    default-configured ``HyperbolicEmbedding``; per the original
    description the embedding is learned by Riemannian gradient descent.
    Two figures are then shown, one after the other:

    1. the embedded nodes, colored with the labels stored on the graph;
    2. the same nodes colored with the clusters predicted by
       ``RiemannianKMeans`` (two clusters), plus the cluster centroids.

    Notes
    -----
    The random generator is seeded so repeated runs start from the same
    state. Depending on the matplotlib backend, each ``plt.show()`` may
    block until the figure window is closed.
    """
    gs.random.seed(1234)

    karate_graph = load_karate_graph()
    hyperbolic_embedding = HyperbolicEmbedding()
    embeddings = hyperbolic_embedding.embed(karate_graph)

    # ------------------------------------------------------------------
    # Figure 1: embedding colored by the labels stored on the graph.
    # ------------------------------------------------------------------
    true_colors = {1: "b", 2: "r"}
    group_1 = mpatches.Patch(color=true_colors[1], label="Group 1")
    group_2 = mpatches.Patch(color=true_colors[2], label="Group 2")

    circle = visualization.PoincareDisk(point_type="ball")

    _, ax = plt.subplots(figsize=(8, 8))
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
    circle.set_ax(ax)
    circle.draw(ax=ax)
    for pt_id, embedding in enumerate(embeddings):
        x, y = embedding[0], embedding[1]
        # Each entry of karate_graph.labels is indexed with [0] to get
        # the key used in true_colors.
        ax.scatter(
            x, y,
            c=true_colors[karate_graph.labels[pt_id][0]],
            s=150,
        )
        ax.annotate(pt_id, (x, y))

    ax.tick_params(which="both")
    ax.set_title("Poincare Ball Embedding of the Karate Club Network")
    ax.legend(handles=[group_1, group_2])
    plt.show()

    # ------------------------------------------------------------------
    # Figure 2: Riemannian K-means on the embedded points.
    # ------------------------------------------------------------------
    n_clusters = 2

    kmeans = RiemannianKMeans(
        riemannian_metric=hyperbolic_embedding.manifold.metric,
        n_clusters=n_clusters,
        init="random",
        mean_method="frechet-poincare-ball",
    )

    centroids = kmeans.fit(X=embeddings, max_iter=100)
    labels = kmeans.predict(X=embeddings)

    # Separate name from true_colors: this palette is indexed by position
    # (cluster 0, cluster 1, centroids), not by graph label.
    cluster_colors = ["g", "c", "m"]
    circle = visualization.PoincareDisk(point_type="ball")
    _, ax2 = plt.subplots(figsize=(8, 8))
    circle.set_ax(ax2)
    circle.draw(ax=ax2)
    ax2.axes.xaxis.set_visible(False)
    ax2.axes.yaxis.set_visible(False)
    group_1_predicted = mpatches.Patch(
        color=cluster_colors[0], label="Predicted Group 1")
    group_2_predicted = mpatches.Patch(
        color=cluster_colors[1], label="Predicted Group 2")
    group_centroids = mpatches.Patch(
        color=cluster_colors[2], label="Cluster centroids")

    # A single pass over the points is enough: the color already depends
    # on each point's predicted label, so repeating the pass once per
    # cluster would only stack identical markers and annotations.
    for pt_id, embedding in enumerate(embeddings):
        x, y = embedding[0], embedding[1]
        color = cluster_colors[0] if labels[pt_id] == 0 else cluster_colors[1]
        ax2.scatter(x, y, c=color, s=150)
        ax2.annotate(pt_id, (x, y))

    for centroid in centroids:
        ax2.scatter(
            centroid[0], centroid[1],
            c=cluster_colors[2],
            marker="*",
            s=150,
        )

    ax2.set_title("K-means applied to Karate club embedding")
    ax2.legend(
        handles=[group_1_predicted, group_2_predicted, group_centroids])
    plt.show()