示例#1
0
    def train_sequence(self, index):
        """Build a mini-batch sequence with one batch per non-empty cluster.

        ``index`` is turned into a node-level mask via ``gf.indices2mask``
        (sized by ``self.graph.num_nodes``).  For every cluster, the mask
        and the labels are restricted to that cluster's members; clusters
        in which no member is selected by the mask are skipped.

        Each kept cluster contributes an input triple
        ``(self.batch_x[c], self.batch_adj[c], local_idx)`` and the labels
        of its selected members, where ``local_idx`` holds the positions
        (relative to the cluster's member list) of the selected nodes.

        NOTE(review): the mask and labels are presumably NumPy arrays
        (``.size`` and boolean indexing are used) -- confirm against
        ``gf.indices2mask`` and the graph class.

        Returns:
            A ``MiniBatchSequence`` built from the tuple of input triples
            and the list of per-cluster labels, placed on ``self.device``.
        """
        train_mask = gf.indices2mask(index, self.graph.num_nodes)
        node_labels = self.graph.node_label

        inputs = []
        targets = []
        for cid in range(self.n_clusters):
            members = self.cluster_member[cid]
            selected = train_mask[members]
            cluster_labels = node_labels[members][selected]
            # Nothing selected in this cluster: it yields no batch.
            if cluster_labels.size == 0:
                continue
            inputs.append((
                self.batch_x[cid],
                self.batch_adj[cid],
                # Positions of the selected nodes within ``members``.
                np.where(selected)[0],
            ))
            targets.append(cluster_labels)

        return MiniBatchSequence(tuple(inputs), targets, device=self.device)
示例#2
0
    def train_sequence(self, index):
        """Build a mini-batch sequence with one batch per non-empty cluster.

        ``index`` is converted to a node-level mask with
        ``gf.index_to_mask`` (sized by ``self.graph.num_nodes``).  For each
        of the ``self.cfg.process.num_clusters`` clusters, the mask and the
        labels are restricted to the cluster's members taken from
        ``self.cache.cluster_member``; clusters with no selected member are
        skipped.

        Each kept cluster contributes the input pair
        ``(cache.batch_x[c], cache.batch_adj[c])``, the labels of its
        selected members, and its per-member mask, which is forwarded to
        ``MiniBatchSequence`` as ``out_weight``.

        NOTE(review): the mask and labels are presumably NumPy arrays
        (``.size`` and boolean indexing are used) -- confirm against
        ``gf.index_to_mask`` and the graph class.

        Returns:
            A ``MiniBatchSequence`` over the kept clusters, placed on
            ``self.device``.
        """
        train_mask = gf.index_to_mask(index, self.graph.num_nodes)
        all_labels = self.graph.node_label
        store = self.cache

        pairs = []
        targets = []
        weights = []
        for cid in range(self.cfg.process.num_clusters):
            members = store.cluster_member[cid]
            member_mask = train_mask[members]
            member_labels = all_labels[members][member_mask]
            # Skip clusters that contain none of the requested nodes.
            if member_labels.size == 0:
                continue
            pairs.append((store.batch_x[cid], store.batch_adj[cid]))
            weights.append(member_mask)
            targets.append(member_labels)

        return MiniBatchSequence(
            tuple(pairs),
            targets,
            out_weight=weights,
            device=self.device,
        )