def predict(self, index):
    """Predict outputs for the nodes in ``index`` using per-cluster mini-batches.

    Every cluster that contains at least one requested node is fed to
    ``self.model.predict_on_batch`` as ``(self.batch_x[cluster],
    self.batch_adj[cluster], local_positions)``. Here ``local_positions`` are
    the positions, inside ``self.cluster_member[cluster]``, of the requested
    nodes.

    Parameters
    ----------
    index : array-like of int
        Node ids to predict. A list or an array is accepted. Repeated ids are
        allowed, and every occurrence receives the same row.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(index), self.graph.num_node_classes)`` with dtype
        ``self.floatx``. Row ``i`` holds the model output for ``index[i]``.
        Rows for requested nodes that appear in no cluster are left as zeros.
    """
    index = np.asarray(index)
    mask = gf.indices2mask(index, self.graph.num_nodes)

    # Predict each distinct node once. ``inverse`` maps every original
    # position back to its distinct node, so duplicated ids in ``index`` are
    # all filled (previously only the last occurrence was).
    unique_index, inverse = np.unique(index, return_inverse=True)
    position = {node: pos for pos, node in enumerate(unique_index)}

    orders = []
    batch_x, batch_adj, batch_idx = [], [], []
    for cluster in range(self.n_clusters):
        nodes = self.cluster_member[cluster]
        mini_mask = mask[nodes]
        batch_nodes = np.asarray(nodes)[mini_mask]
        if batch_nodes.size == 0:
            # No requested node in this cluster: skip the forward pass.
            continue
        batch_x.append(self.batch_x[cluster])
        batch_adj.append(self.batch_adj[cluster])
        # Positions of the requested nodes inside this cluster's member list.
        batch_idx.append(np.where(mini_mask)[0])
        # Rows of ``unique_logit`` that this batch's outputs are written to.
        orders.append([position[n] for n in batch_nodes])

    batch_data = tuple(zip(batch_x, batch_adj, batch_idx))
    batch_data = gf.astensors(batch_data, device=self.device)

    unique_logit = np.zeros(
        (unique_index.size, self.graph.num_node_classes), dtype=self.floatx
    )
    with tf.device(self.device):
        for order, inputs in zip(orders, batch_data):
            unique_logit[order] = self.model.predict_on_batch(inputs)
    return unique_logit[inverse]
def train_sequence(self, index, batch_size=np.inf):
    """Build a full-batch training sequence on a subgraph grown around ``index``.

    The node set starts from the training nodes. It is expanded once with
    ``get_indice_graph(..., batch_size)``. It is then expanded repeatedly
    (without a size argument) until it holds at least ``self.K`` nodes, or
    until an expansion step adds no further nodes.

    Parameters
    ----------
    index : array-like of int
        Training node ids.
    batch_size : int or float, default ``np.inf``
        Passed to the first ``get_indice_graph`` call as its size argument.

    Returns
    -------
    FullBatchNodeSequence
        Built from ``[feature_inputs, structure_inputs, mask]`` and
        ``labels``:

        - ``feature_inputs`` and ``structure_inputs`` are the rows (and, for
          the structure, columns) of the selected subgraph.
        - ``mask`` marks which subgraph positions are training nodes.
        - ``labels`` holds the labels of those training nodes, in subgraph
          order.
    """
    mask = gf.indices2mask(index, self.graph.num_nodes)
    index = get_indice_graph(self.structure_inputs, index, batch_size)
    while index.size < self.K:
        grown = get_indice_graph(self.structure_inputs, index)
        if grown.size <= index.size:
            # The expansion reached no new nodes (e.g. the reachable component
            # has fewer than K nodes). Looping again would never terminate,
            # so proceed with the largest node set found.
            # NOTE(review): assumes get_indice_graph is deterministic for a
            # fixed input when no size cap is given — confirm.
            break
        index = grown

    structure_inputs = self.structure_inputs[index][:, index]
    feature_inputs = self.feature_inputs[index]
    # Restrict the global training mask to the selected subgraph's positions.
    mask = mask[index]
    labels = self.graph.node_label[index[mask]]
    sequence = FullBatchNodeSequence(
        [feature_inputs, structure_inputs, mask], labels, device=self.device)
    return sequence
def train_sequence(self, index):
    """Return a MiniBatchSequence with one batch per cluster holding training nodes.

    Each batch is ``(self.batch_x[c], self.batch_adj[c], local_idx)``, where
    ``local_idx`` are the positions, within ``self.cluster_member[c]``, of the
    nodes selected by ``index``. The matching labels are those nodes' entries
    of ``self.graph.node_label``. Clusters containing no selected node are
    omitted.
    """
    node_mask = gf.indices2mask(index, self.graph.num_nodes)
    all_labels = self.graph.node_label

    features, adjacencies, local_positions, label_chunks = [], [], [], []
    for cid in range(self.n_clusters):
        members = self.cluster_member[cid]
        selected = node_mask[members]
        chosen_labels = all_labels[members][selected]
        if not chosen_labels.size:
            continue
        features.append(self.batch_x[cid])
        adjacencies.append(self.batch_adj[cid])
        local_positions.append(np.where(selected)[0])
        label_chunks.append(chosen_labels)

    inputs = tuple(zip(features, adjacencies, local_positions))
    return MiniBatchSequence(inputs, label_chunks, device=self.device)