Example #1
0
    def train_sequence(self, index):
        """Build the ``FullBatchSequence`` used for the nodes in ``index``.

        The sequence inputs are ``[cache.X, cache.A, index]`` and the targets
        are the entries of ``cache.label_onehot`` selected by ``index``; the
        sequence is created on ``self.device``.
        """
        # NOTE(review): the original author left an open question here about
        # what should happen if the graph is changed -- still unresolved.
        targets = self.cache.label_onehot[index]
        inputs = [self.cache.X, self.cache.A, index]
        return FullBatchSequence(inputs, targets, device=self.device)
Example #2
0
    def train_sequence(self, index):
        """Build the ``FullBatchSequence`` used for the nodes in ``index``.

        The sequence inputs are ``[cache.X, cache.A, index]`` and the targets
        are ``graph.node_label[index]``; the sequence is created on
        ``self.device``.
        """
        targets = self.graph.node_label[index]
        inputs = [self.cache.X, self.cache.A, index]
        return FullBatchSequence(inputs, targets, device=self.device)
Example #3
0
    def train_sequence(self, index):
        """Build the ``FullBatchSequence`` used for the nodes in ``index``.

        ``cache.X`` is the only input (wrapped in a one-element list), the
        targets are ``graph.node_label[index]`` and ``index`` itself is
        forwarded as ``out_weight``.
        """
        targets = self.graph.node_label[index]
        features = [self.cache.X]
        return FullBatchSequence(x=features,
                                 y=targets,
                                 out_weight=index,
                                 device=self.device)
Example #4
0
    def train_loader(self, index):
        """Build the ``FullBatchSequence`` used for the nodes in ``index``.

        The inputs are ``cache.X`` followed by every element of ``cache.E``
        (unpacked into the same list); the targets are
        ``graph.node_label[index]`` and ``index`` is forwarded as
        ``out_weight``. The sequence is created on ``self.data_device``.
        """
        targets = self.graph.node_label[index]
        inputs = [self.cache.X, *self.cache.E]
        return FullBatchSequence(inputs,
                                 targets,
                                 out_weight=index,
                                 device=self.data_device)
Example #5
0
 def train_sequence(self, index):
     """Build the ``FullBatchSequence`` used for the nodes in ``index``.

     The inputs are ``[cache.X, cache.G]``, the targets are
     ``graph.node_label[index]`` and ``index`` is forwarded as
     ``out_weight``. The concrete type of ``cache.G`` is passed as
     ``escape``.
     """
     targets = self.graph.node_label[index]
     inputs = [self.cache.X, self.cache.G]
     # NOTE(review): presumably ``escape`` tells the sequence to leave objects
     # of this type unconverted -- confirm against FullBatchSequence.
     return FullBatchSequence(inputs,
                              targets,
                              out_weight=index,
                              device=self.device,
                              escape=type(self.cache.G))
Example #6
0
    def train_sequence(self, index):
        """Build the ``FullBatchSequence`` used for the nodes in ``index``.

        The inputs are ``[cache.X, cache.A]``, the targets are
        ``cache.Y[index]`` and ``index`` is forwarded as ``out_weight``.
        """
        targets = self.cache.Y[index]
        inputs = [self.cache.X, self.cache.A]
        return FullBatchSequence(inputs,
                                 targets,
                                 out_weight=index,
                                 device=self.device)
Example #7
0
    def test_loader(self, index):
        """Build the ``FullBatchSequence`` used to evaluate the nodes in ``index``.

        ``cache.X`` is passed directly as the input (not wrapped in a list),
        the targets are ``graph.node_label[index]`` and ``index`` is forwarded
        as ``out_weight``. The sequence is created on ``self.data_device``.
        """
        targets = self.graph.node_label[index]
        return FullBatchSequence(x=self.cache.X,
                                 y=targets,
                                 out_weight=index,
                                 device=self.data_device)
Example #8
0
    def train_sequence(self, index):
        """Build the ``FullBatchSequence`` used for the nodes in ``index``.

        The inputs are ``[cache.X, cache.edge_index, cache.edge_x]``, the
        targets are ``graph.node_label[index]`` and ``index`` is forwarded as
        ``out_weight``.
        """
        targets = self.graph.node_label[index]
        inputs = [self.cache.X, self.cache.edge_index, self.cache.edge_x]
        return FullBatchSequence(inputs,
                                 targets,
                                 out_weight=index,
                                 device=self.device)
Example #9
0
 def train_sequence(self, index):
     """Build the ``FullBatchSequence`` used for the nodes in ``index``.

     Five cached objects are passed as inputs, in this order: ``X``, ``A``,
     ``knn_graph``, ``pseudo_labels`` and ``node_pairs``. The targets are
     ``graph.node_label[index]`` and ``index`` is forwarded as
     ``out_weight``.
     """
     targets = self.graph.node_label[index]
     store = self.cache
     inputs = [
         store.X,
         store.A,
         store.knn_graph,
         store.pseudo_labels,
         store.node_pairs,
     ]
     return FullBatchSequence(x=inputs,
                              y=targets,
                              out_weight=index,
                              device=self.device)
Example #10
0
    def train_sequence(self, index, batch_size=np.inf):
        """Build a ``FullBatchSequence`` over a sub-graph grown around ``index``.

        ``index`` is first converted to a mask over all ``graph.num_nodes``
        nodes, then expanded with ``get_indice_graph`` until it holds at least
        ``cache.K`` entries. Features and structure are restricted to the
        expanded node set; labels are taken only for the expanded nodes that
        were in the original ``index``.

        Args:
            index: indices of the nodes whose labels are used.
            batch_size: forwarded to the first ``get_indice_graph`` call only.

        Returns:
            A ``FullBatchSequence`` whose inputs are
            ``[features, structure, mask]``.
        """
        # Remember which nodes were requested before ``index`` is expanded.
        seed_mask = gf.index_to_mask(index, self.graph.num_nodes)

        nodes = get_indice_graph(self.cache.A, index, batch_size)
        # NOTE(review): this loop only ends once the expansion reaches
        # ``cache.K`` nodes; if ``get_indice_graph`` can stop growing the set
        # (e.g. a small component) it would never terminate -- confirm.
        while nodes.size < self.cache.K:
            nodes = get_indice_graph(self.cache.A, nodes)

        # Restrict the structure to the rows and columns of the expanded set.
        sub_structure = self.cache.A[nodes][:, nodes]
        sub_features = self.cache.X[nodes]
        # Mask re-expressed relative to the expanded node set.
        local_mask = seed_mask[nodes]
        targets = self.graph.node_label[nodes[local_mask]]

        return FullBatchSequence([sub_features, sub_structure, local_mask],
                                 targets,
                                 device=self.device)
Example #11
0
    def train_sequence(self, index, batch_size=np.inf):
        """Build a ``FullBatchSequence`` over a sub-graph grown around ``index``.

        ``index`` is first converted to a mask over all ``graph.num_nodes``
        nodes, then expanded with ``get_indice_graph`` until it holds at least
        ``cfg.model.K`` entries. Features and structure are restricted to the
        expanded node set; labels are taken only for the expanded nodes that
        were in the original ``index``, and the mask is forwarded as
        ``out_weight``.

        Args:
            index: indices of the nodes whose labels are used.
            batch_size: forwarded to the first ``get_indice_graph`` call only.

        Returns:
            A ``FullBatchSequence`` whose inputs are ``[features, structure]``.
        """
        store = self.cache
        # Remember which nodes were requested before ``index`` is expanded.
        seed_mask = gf.index_to_mask(index, self.graph.num_nodes)

        nodes = get_indice_graph(store.A, index, batch_size)
        # NOTE(review): this loop only ends once the expansion reaches
        # ``cfg.model.K`` nodes; if ``get_indice_graph`` can stop growing the
        # set it would never terminate -- confirm.
        while nodes.size < self.cfg.model.K:
            nodes = get_indice_graph(store.A, nodes)

        # Restrict the structure to the rows and columns of the expanded set.
        sub_structure = store.A[nodes][:, nodes]
        sub_features = store.X[nodes]
        # Mask re-expressed relative to the expanded node set.
        local_mask = seed_mask[nodes]
        targets = self.graph.node_label[nodes[local_mask]]

        return FullBatchSequence([sub_features, sub_structure],
                                 targets,
                                 out_weight=local_mask,
                                 device=self.device)
Example #12
0
 def train_loader(self, index):
     """Build the ``FullBatchSequence`` used for the nodes in ``index``.

     Both the input (``cache.X``) and the targets (``graph.node_label``) are
     sliced by ``index`` before being handed to the sequence, which is
     created on ``self.data_device``.
     """
     targets = self.graph.node_label[index]
     features = self.cache.X[index]
     return FullBatchSequence(features, targets, device=self.data_device)
Example #13
0
    def train(self,
              idx_train,
              idx_val=None,
              pre_train_epochs=100,
              epochs=100,
              early_stopping=None,
              verbose=1,
              save_best=True,
              ckpt_path=None,
              as_model=False,
              monitor='val_accuracy',
              early_stop_metric='val_loss'):
        """Run three consecutive training stages and return their histories.

        Stages, in order:

        1. ``model_q`` is trained on ``idx_train`` for ``pre_train_epochs``
           epochs (always with ``as_model=True``).
        2. ``model_p`` is trained for ``epochs`` epochs on one-hot labels
           built from the stage-1 predictions, with the known labels of
           ``idx_train`` written over the predicted ones.
        3. ``model_q`` is trained again for ``epochs`` epochs, using the
           softmax of ``model_p``'s outputs as targets (again overwritten
           with the known one-hot labels on ``idx_train``).

        Args:
            idx_train: indices of the training nodes.
            idx_val: optional indices of the validation nodes.
            pre_train_epochs: number of epochs for stage 1.
            epochs: number of epochs for stages 2 and 3.
            early_stopping, verbose, save_best, ckpt_path, monitor,
            early_stop_metric: forwarded unchanged to every
                ``super().train`` call.
            as_model: forwarded to stages 2 and 3 only.

        Returns:
            A list with the three values returned by ``super().train``, in
            stage order.
        """
        # Options that are identical for all three ``super().train`` calls.
        shared_options = dict(early_stopping=early_stopping,
                              verbose=verbose,
                              save_best=save_best,
                              ckpt_path=ckpt_path,
                              monitor=monitor,
                              early_stop_metric=early_stop_metric)

        histories = []
        all_nodes = tf.range(self.graph.num_nodes, dtype=self.intx)

        # Stage 1: pre-train model_q.
        self.model = self.model_q
        histories.append(super().train(idx_train,
                                       idx_val,
                                       epochs=pre_train_epochs,
                                       as_model=True,
                                       **shared_options))

        # Hard labels predicted for every node; known training labels win.
        hard_labels = self.predict(all_nodes).argmax(1)
        hard_labels[idx_train] = self.graph.node_label[idx_train]
        onehot_labels = tf.one_hot(hard_labels,
                                   depth=self.graph.num_node_classes)

        # Stage 2: train model_p first. It receives the one-hot labels both
        # as input and as target.
        p_train_sequence = FullBatchSequence(
            [onehot_labels, self.cache.A, all_nodes],
            onehot_labels,
            device=self.device)
        p_val_sequence = None
        if idx_val is not None:
            p_val_sequence = FullBatchSequence(
                [onehot_labels, self.cache.A, idx_val],
                self.cache.label_onehot[idx_val],
                device=self.device)

        self.model = self.model_p
        histories.append(super().train(p_train_sequence,
                                       p_val_sequence,
                                       epochs=epochs,
                                       as_model=as_model,
                                       **shared_options))

        # Stage 3: turn the current model's outputs into soft targets, then
        # train model_q again.
        model_inputs = gf.astensors(onehot_labels,
                                    self.cache.A,
                                    all_nodes,
                                    device=self.device)
        soft_labels = softmax(self.model.predict_on_batch(model_inputs))
        if tf.is_tensor(soft_labels):
            # Item assignment below needs a mutable array, not a tensor.
            soft_labels = soft_labels.numpy()
        soft_labels[idx_train] = self.cache.label_onehot[idx_train]

        self.model = self.model_q
        q_train_sequence = FullBatchSequence(
            [self.cache.X, self.cache.A, all_nodes],
            soft_labels,
            device=self.device)
        histories.append(super().train(q_train_sequence,
                                       idx_val,
                                       epochs=epochs,
                                       as_model=as_model,
                                       **shared_options))

        return histories