Пример #1
0
    def predict(self, index):
        """Predict logits for the nodes in ``index``, one cluster at a time.

        Every cluster that contains at least one requested node is fed to
        ``self.model.predict_on_batch`` once, and the returned rows are
        scattered back to the positions of those nodes in ``index``.

        Parameters
        ----------
        index : array-like of int
            Node ids to predict. May contain repeated ids; every occurrence
            receives the prediction for its node.

        Returns
        -------
        np.ndarray
            Array of shape ``(len(index), self.graph.num_node_classes)`` and
            dtype ``self.floatx``; row ``i`` holds the logits of
            ``index[i]``. Rows of nodes not found in any cluster stay zero.
        """
        index = np.asarray(index)
        # Work on distinct node ids so that a repeated id gets its prediction
        # at every position (an id -> position dict would keep only the last
        # one and leave the other rows zero). ``inverse[i]`` is the row of
        # ``unique_nodes`` that corresponds to ``index[i]``.
        unique_nodes, inverse = np.unique(index, return_inverse=True)
        mask = gf.indices2mask(unique_nodes, self.graph.num_nodes)

        batch_idx, rows = [], []
        batch_x, batch_adj = [], []
        for cluster in range(self.n_clusters):
            nodes = self.cluster_member[cluster]
            mini_mask = mask[nodes]
            batch_nodes = np.asarray(nodes)[mini_mask]
            if batch_nodes.size == 0:
                # none of the requested nodes lives in this cluster
                continue
            batch_x.append(self.batch_x[cluster])
            batch_adj.append(self.batch_adj[cluster])
            # positions of the requested nodes inside the cluster's member list
            batch_idx.append(np.where(mini_mask)[0])
            # rows of ``unique_nodes`` filled by this batch, in the same order
            # as ``batch_nodes`` (``unique_nodes`` is sorted by np.unique)
            rows.append(np.searchsorted(unique_nodes, batch_nodes))

        batch_data = tuple(zip(batch_x, batch_adj, batch_idx))
        batch_data = gf.astensors(batch_data, device=self.device)

        unique_logit = np.zeros(
            (unique_nodes.size, self.graph.num_node_classes),
            dtype=self.floatx)

        model = self.model
        with tf.device(self.device):
            for row, inputs in zip(rows, batch_data):
                unique_logit[row] = model.predict_on_batch(inputs)

        # expand back to one row per entry of ``index`` (duplicates included)
        return unique_logit[inverse]
Пример #2
0
    def train_sequence(self, index, batch_size=np.inf):
        """Build a full-batch training sequence on a subgraph around ``index``.

        The seed nodes are expanded with ``get_indice_graph`` (the first call
        is bounded by ``batch_size``), then expanded again until the subgraph
        holds at least ``self.K`` nodes or an expansion adds no new nodes.

        Parameters
        ----------
        index : array-like of int
            Training node ids used as seeds.
        batch_size : int or float, optional
            Size argument forwarded to the first ``get_indice_graph`` call.
            Defaults to ``np.inf``.

        Returns
        -------
        FullBatchNodeSequence
            Inputs ``[feature_inputs, structure_inputs, mask]`` restricted to
            the subgraph, where ``mask`` marks which subgraph nodes are in the
            original ``index``, plus the labels of those marked nodes.
        """
        mask = gf.indices2mask(index, self.graph.num_nodes)
        index = get_indice_graph(self.structure_inputs, index, batch_size)
        while index.size < self.K:
            previous_size = index.size
            index = get_indice_graph(self.structure_inputs, index)
            if index.size <= previous_size:
                # Expansion made no progress (e.g. the seeds' connected
                # component has fewer than ``self.K`` nodes). Without this
                # check the loop would never terminate. NOTE(review): assumes
                # get_indice_graph is deterministic with its default args, so
                # a non-growing call would never grow on a retry; the returned
                # subgraph may then have fewer than ``self.K`` nodes.
                break

        # restrict every input to the subgraph, keeping node order of ``index``
        structure_inputs = self.structure_inputs[index][:, index]
        feature_inputs = self.feature_inputs[index]
        mask = mask[index]
        labels = self.graph.node_label[index[mask]]

        sequence = FullBatchNodeSequence(
            [feature_inputs, structure_inputs, mask],
            labels,
            device=self.device)
        return sequence
Пример #3
0
    def train_sequence(self, index):
        """Build a mini-batch training sequence with one batch per cluster.

        Clusters that contain no node of ``index`` are skipped. Each kept
        batch is a ``(cluster features, cluster adjacency, positions)``
        tuple, where ``positions`` are the offsets, within the cluster's
        member list, of the nodes that belong to ``index``.

        Parameters
        ----------
        index : array-like of int
            Training node ids.

        Returns
        -------
        MiniBatchSequence
            The per-cluster input tuples together with the labels of the
            training nodes in each kept cluster.
        """
        train_mask = gf.indices2mask(index, self.graph.num_nodes)
        all_labels = self.graph.node_label

        features, adjacencies, positions, targets = [], [], [], []
        for c in range(self.n_clusters):
            members = self.cluster_member[c]
            in_train = train_mask[members]
            cluster_labels = all_labels[members][in_train]
            if not cluster_labels.size:
                # no training node falls into this cluster
                continue
            features.append(self.batch_x[c])
            adjacencies.append(self.batch_adj[c])
            positions.append(np.flatnonzero(in_train))
            targets.append(cluster_labels)

        inputs = tuple(zip(features, adjacencies, positions))
        return MiniBatchSequence(inputs, targets, device=self.device)