def train_sequence(self, index):
    """Assemble the cluster mini-batch sequence used for training.

    A node mask is built from ``index`` over all ``self.graph.num_nodes``
    nodes. The clusters ``0 .. self.n_clusters - 1`` are then walked in
    order. Each cluster's member nodes are restricted to the masked ones;
    clusters without any masked member are skipped.

    Args:
        index: node indices selected for training. They are passed to
            ``gf.indices2mask`` to build the mask.

    Returns:
        A ``MiniBatchSequence`` whose inputs are a tuple of
        ``(batch_x[c], batch_adj[c], local_positions)`` triples. The
        labels are a list holding, per kept cluster, the labels of its
        masked members. ``local_positions`` are positions *within the
        cluster's member list*, not global node ids.
    """
    train_mask = gf.indices2mask(index, self.graph.num_nodes)
    node_labels = self.graph.node_label

    inputs, targets = [], []
    for c in range(self.n_clusters):
        members = self.cluster_member[c]
        local_mask = train_mask[members]
        local_labels = node_labels[members][local_mask]
        # A cluster with no selected node would produce an empty batch.
        if local_labels.size == 0:
            continue
        inputs.append((self.batch_x[c], self.batch_adj[c], np.where(local_mask)[0]))
        targets.append(local_labels)

    return MiniBatchSequence(tuple(inputs), targets, device=self.device)
def train_sequence(self, index):
    """Assemble the cluster mini-batch sequence used for training.

    ``gf.index_to_mask`` turns ``index`` into a mask over all
    ``self.graph.num_nodes`` nodes. The method then visits every cluster id
    in ``range(self.cfg.process.num_clusters)`` using the data cached on
    ``self.cache``. Clusters whose members contain no selected node are
    skipped.

    Args:
        index: node indices selected for training.

    Returns:
        A ``MiniBatchSequence`` built from three parts:
        - inputs: a tuple of ``(batch_x[c], batch_adj[c])`` pairs;
        - labels: a list of the labels of each kept cluster's selected
          members;
        - ``out_weight``: a list of each kept cluster's member-level
          boolean mask (presumably used to pick the outputs that count
          towards the loss; confirm against ``MiniBatchSequence``).
    """
    in_train = gf.index_to_mask(index, self.graph.num_nodes)
    node_labels = self.graph.node_label
    store = self.cache

    inputs, targets, weights = [], [], []
    for k in range(self.cfg.process.num_clusters):
        members = store.cluster_member[k]
        keep = in_train[members]
        y = node_labels[members][keep]
        # No selected node in this cluster -> nothing to train on.
        if y.size == 0:
            continue
        inputs.append((store.batch_x[k], store.batch_adj[k]))
        weights.append(keep)
        targets.append(y)

    return MiniBatchSequence(tuple(inputs), targets,
                             out_weight=weights, device=self.device)