예제 #1
0
    def process_step(self):
        """Partition the graph into clusters and keep the results on ``self``.

        The node attributes are passed through ``self.attr_transform``, then
        ``gf.graph_partition`` is called with the adjacency matrix, the
        transformed attributes and ``self.n_clusters``.  Its three results
        are handled as follows:

        * the first is expanded into ``self.adj_transform`` and the result
          is converted by ``gf.astensors`` and stored as ``self.batch_adj``;
        * the second is converted by ``gf.astensors`` and stored as
          ``self.batch_x``;
        * the third is stored unchanged as ``self.cluster_member``.
        """
        source = self.graph
        features = self.attr_transform(source.node_attr)

        partitioned = gf.graph_partition(source.adj_matrix,
                                         features,
                                         n_clusters=self.n_clusters)
        adj_parts, attr_parts, self.cluster_member = partitioned

        # The adjacency result is star-expanded, so it is presumably a
        # sequence with one entry per cluster -- TODO confirm.
        adj_parts = self.adj_transform(*adj_parts)

        tensors = gf.astensors(adj_parts, attr_parts, device=self.device)
        self.batch_adj, self.batch_x = tensors
예제 #2
0
    def process_step(self):
        """Partition the transformed graph and cache the per-cluster data.

        Steps, in order:

        1. ``self.transform.graph_transform`` is applied to ``self.graph``.
        2. The resulting graph's ``node_attr`` is replaced by the output of
           ``self.transform.attr_transform``.
        3. ``gf.graph_partition`` splits the graph using
           ``self.cache.n_clusters``.
        4. The adjacency result is passed (star-expanded) through
           ``self.transform.adj_transform``.
        5. Adjacency and attribute results are converted with
           ``gf.astensors`` for ``self.device``.
        6. ``batch_x``, ``batch_adj`` and ``cluster_member`` are registered
           in the cache under those names.
        """
        transformed = self.transform.graph_transform(self.graph)
        # NOTE(review): this assigns onto the object returned by
        # ``graph_transform``; if that is ``self.graph`` itself rather than
        # a copy, the stored graph is modified as well -- confirm.
        new_attr = self.transform.attr_transform(transformed.node_attr)
        transformed.node_attr = new_attr

        partitioned = gf.graph_partition(transformed,
                                         n_clusters=self.cache.n_clusters)
        adj_parts, attr_parts, members = partitioned

        adj_parts = self.transform.adj_transform(*adj_parts)
        adj_tensors, attr_tensors = gf.astensors(adj_parts,
                                                 attr_parts,
                                                 device=self.device)

        # Cached for later use; registration order matches the original
        # (attributes, adjacency, membership).
        to_cache = (
            ("batch_x", attr_tensors),
            ("batch_adj", adj_tensors),
            ("cluster_member", members),
        )
        for key, value in to_cache:
            self.register_cache(key, value)
예제 #3
0
    def data_step(self,
                  adj_transform="normalize_adj",
                  attr_transform=None,
                  num_clusters=10):
        """Partition ``self.graph`` and cache the transformed cluster data.

        Parameters
        ----------
        adj_transform:
            Identifier resolved through ``gf.get``; the resulting callable
            receives the star-expanded adjacency result of the partition.
        attr_transform:
            Identifier resolved through ``gf.get``; the resulting callable
            receives the star-expanded attribute result of the partition.
        num_clusters:
            Forwarded to ``gf.graph_partition`` as ``num_clusters``.

        The partition is always requested with ``metis_partition=True``.
        Both transformed results are converted with ``gf.astensors`` for
        ``self.data_device`` and registered in the cache together with the
        cluster membership.
        """
        source = self.graph
        partitioned = gf.graph_partition(source,
                                         num_clusters=num_clusters,
                                         metis_partition=True)
        adj_parts, attr_parts, members = partitioned

        # Resolve and apply each transform in turn (adjacency first).
        adj_fn = gf.get(adj_transform)
        adj_parts = adj_fn(*adj_parts)
        # NOTE(review): with the default ``attr_transform=None`` this relies
        # on ``gf.get(None)`` returning a usable callable -- confirm.
        attr_fn = gf.get(attr_transform)
        attr_parts = attr_fn(*attr_parts)

        adj_tensors, attr_tensors = gf.astensors(adj_parts,
                                                 attr_parts,
                                                 device=self.data_device)

        # Cached for later use.
        cached = {
            "batch_x": attr_tensors,
            "batch_adj": adj_tensors,
            "cluster_member": members,
        }
        self.register_cache(**cached)
예제 #4
0
    def data_step(self,
                  adj_transform="normalize_adj",
                  feat_transform=None,
                  num_clusters=10,
                  partition='louvain'):
        """Partition ``self.graph`` and cache the transformed cluster data.

        Parameters
        ----------
        adj_transform:
            Identifier resolved through ``gf.get``; the resulting callable
            receives the star-expanded adjacency result of the partition.
        feat_transform:
            Identifier resolved through ``gf.get``; the resulting callable
            receives the star-expanded feature result of the partition.
        num_clusters:
            Forwarded to ``gf.graph_partition`` as ``num_clusters``.
        partition:
            Partitioning method; one of ``'metis'``, ``'random'`` or
            ``'louvain'``.

        Raises
        ------
        ValueError
            If ``partition`` is not one of the accepted method names.

        Side effects
        ------------
        Registers ``batch_feat``, ``batch_adj`` and ``cluster_member`` in
        the cache and overwrites ``self.num_clusters`` with
        ``len(cluster_member)``.
        """
        # Validate explicitly rather than with ``assert``: assertions are
        # removed under ``python -O`` and carry no message for the caller.
        valid_partitions = ('metis', 'random', 'louvain')
        if partition not in valid_partitions:
            raise ValueError(
                f"`partition` must be one of {valid_partitions}, "
                f"but got {partition!r}.")

        graph = self.graph
        batch_adj, batch_feat, cluster_member = gf.graph_partition(
            graph, num_clusters=num_clusters, partition=partition)

        batch_adj = gf.get(adj_transform)(*batch_adj)
        # NOTE(review): with the default ``feat_transform=None`` this relies
        # on ``gf.get(None)`` returning a usable callable -- confirm.
        batch_feat = gf.get(feat_transform)(*batch_feat)

        batch_adj, batch_feat = gf.astensors(batch_adj, batch_feat,
                                             device=self.data_device)

        self.register_cache(batch_feat=batch_feat, batch_adj=batch_adj,
                            cluster_member=cluster_member)
        # Record the number of clusters actually produced. Per the original
        # comment this matters for louvain clustering, where the count
        # presumably may differ from the requested ``num_clusters``.
        self.num_clusters = len(cluster_member)