def train_sequence(self, index):
    """Build a full-batch training sequence for the given node indices.

    Uses cached one-hot labels; the node indices are passed as a model
    input alongside the cached features and adjacency.

    NOTE(review): the cached A/X are reused as-is — if the graph has
    changed since caching, the cache must be refreshed first (this was
    an open question in the original code).
    """
    targets = self.cache.label_onehot[index]
    inputs = [self.cache.X, self.cache.A, index]
    return FullBatchSequence(inputs, targets, device=self.device)
def train_sequence(self, index):
    """Create a full-batch training sequence over the selected nodes.

    Inputs are cached features, cached adjacency, and the node indices;
    targets are the corresponding graph node labels.
    """
    node_labels = self.graph.node_label[index]
    batch_inputs = [self.cache.X, self.cache.A, index]
    return FullBatchSequence(
        batch_inputs,
        node_labels,
        device=self.device,
    )
def train_sequence(self, index):
    """Full-batch training sequence using features only.

    The selected indices are supplied via ``out_weight`` rather than as
    a model input.
    """
    node_labels = self.graph.node_label[index]
    return FullBatchSequence(
        x=[self.cache.X],
        y=node_labels,
        out_weight=index,
        device=self.device,
    )
def train_loader(self, index):
    """Build a full-batch training loader from X plus the cached edge tensors.

    ``cache.E`` is unpacked so each edge tensor becomes a separate input;
    the selected indices go in via ``out_weight``.
    """
    node_labels = self.graph.node_label[index]
    batch_inputs = [self.cache.X, *self.cache.E]
    return FullBatchSequence(
        batch_inputs,
        node_labels,
        out_weight=index,
        device=self.data_device,
    )
def train_sequence(self, index):
    """Full-batch training sequence over the cached graph object.

    ``escape`` is set to the graph object's own type, presumably so the
    sequence leaves it un-converted — TODO confirm against
    FullBatchSequence's handling of ``escape``.
    """
    node_labels = self.graph.node_label[index]
    graph_obj = self.cache.G
    return FullBatchSequence(
        [self.cache.X, graph_obj],
        node_labels,
        out_weight=index,
        device=self.device,
        escape=type(graph_obj),
    )
def train_sequence(self, index):
    """Training sequence whose targets come from cached ``cache.Y``.

    Unlike the sibling builders this reads targets from the cache rather
    than from ``self.graph.node_label``.
    """
    targets = self.cache.Y[index]
    return FullBatchSequence(
        [self.cache.X, self.cache.A],
        targets,
        out_weight=index,
        device=self.device,
    )
def test_loader(self, index):
    """Build a full-batch evaluation loader over the selected nodes.

    Features only — no structure input; indices are passed as the
    output weight.
    """
    node_labels = self.graph.node_label[index]
    return FullBatchSequence(
        x=self.cache.X,
        y=node_labels,
        out_weight=index,
        device=self.data_device,
    )
def train_sequence(self, index):
    """Training sequence with edge-level inputs.

    Inputs are node features, the cached edge index, and cached edge
    features; selected node indices go in via ``out_weight``.
    """
    node_labels = self.graph.node_label[index]
    batch_inputs = [self.cache.X, self.cache.edge_index, self.cache.edge_x]
    return FullBatchSequence(
        batch_inputs,
        node_labels,
        out_weight=index,
        device=self.device,
    )
def train_sequence(self, index):
    """Training sequence bundling the full set of cached auxiliary inputs.

    Inputs: features, adjacency, kNN graph, pseudo labels, and node
    pairs — all read from ``self.cache``.
    """
    c = self.cache
    node_labels = self.graph.node_label[index]
    return FullBatchSequence(
        x=[c.X, c.A, c.knn_graph, c.pseudo_labels, c.node_pairs],
        y=node_labels,
        out_weight=index,
        device=self.device,
    )
def train_sequence(self, index, batch_size=np.inf):
    """Grow a subgraph batch around ``index`` and build its sequence.

    The seed indices are expanded through the adjacency via
    ``get_indice_graph`` until the node set reaches ``cache.K`` nodes,
    then the induced sub-adjacency/sub-features are extracted. The mask
    marks which nodes in the subgraph are the original training nodes.
    """
    seed_mask = gf.index_to_mask(index, self.graph.num_nodes)
    # Expand the node set through the cached adjacency until it is
    # at least cache.K nodes.
    index = get_indice_graph(self.cache.A, index, batch_size)
    while index.size < self.cache.K:
        index = get_indice_graph(self.cache.A, index)
    sub_adj = self.cache.A[index][:, index]
    sub_feats = self.cache.X[index]
    seed_mask = seed_mask[index]
    # Labels only for the seed (training) nodes within the subgraph.
    sub_labels = self.graph.node_label[index[seed_mask]]
    return FullBatchSequence(
        [sub_feats, sub_adj, seed_mask],
        sub_labels,
        device=self.device,
    )
def train_sequence(self, index, batch_size=np.inf):
    """Grow a subgraph batch around ``index`` and build its sequence.

    Same expansion scheme as the sibling builder, but the minimum node
    count comes from ``self.cfg.model.K`` and the training-node mask is
    supplied via ``out_weight`` instead of as a model input.
    """
    c = self.cache
    seed_mask = gf.index_to_mask(index, self.graph.num_nodes)
    # Expand through the adjacency until at least cfg.model.K nodes.
    index = get_indice_graph(c.A, index, batch_size)
    while index.size < self.cfg.model.K:
        index = get_indice_graph(c.A, index)
    sub_adj = c.A[index][:, index]
    sub_feats = c.X[index]
    seed_mask = seed_mask[index]
    sub_labels = self.graph.node_label[index[seed_mask]]
    return FullBatchSequence(
        [sub_feats, sub_adj],
        sub_labels,
        out_weight=seed_mask,
        device=self.device,
    )
def train_loader(self, index):
    """Build a loader over node features only (no graph structure)."""
    node_labels = self.graph.node_label[index]
    node_feats = self.cache.X[index]
    return FullBatchSequence(node_feats, node_labels, device=self.data_device)
def train(self, idx_train, idx_val=None, pre_train_epochs=100, epochs=100,
          early_stopping=None, verbose=1, save_best=True, ckpt_path=None,
          as_model=False, monitor='val_accuracy', early_stop_metric='val_loss'):
    """Three-phase alternating training of the two internal models.

    Phase 1 pre-trains ``model_q`` on the true labels; phase 2 trains
    ``model_p`` on ``model_q``'s (clamped) predictions; phase 3 retrains
    ``model_q`` on ``model_p``'s soft outputs. ``self.model`` is swapped
    before each phase so ``super().train`` targets the right network.

    Parameters: ``idx_train``/``idx_val`` are node index sets;
    ``pre_train_epochs`` applies to phase 1, ``epochs`` to phases 2-3;
    the remaining keywords are forwarded to ``super().train``.

    Returns a list of the three training histories.
    """
    histories = []
    # All node indices, typed to match the model's expected index dtype.
    index_all = tf.range(self.graph.num_nodes, dtype=self.intx)

    # Phase 1: pre-train model_q on ground-truth labels.
    self.model = self.model_q
    history = super().train(idx_train, idx_val,
                            epochs=pre_train_epochs,
                            early_stopping=early_stopping,
                            verbose=verbose, save_best=save_best,
                            ckpt_path=ckpt_path,
                            as_model=True,
                            monitor=monitor,
                            early_stop_metric=early_stop_metric)
    histories.append(history)

    # Predict labels for every node, then clamp the training nodes to
    # their true labels before one-hot encoding.
    label_predict = self.predict(index_all).argmax(1)
    label_predict[idx_train] = self.graph.node_label[idx_train]
    label_predict = tf.one_hot(label_predict,
                               depth=self.graph.num_node_classes)
    # Phase 2: train model_p first, using the predicted one-hot labels
    # both as input features and as targets (self-reconstruction).
    train_sequence = FullBatchSequence(
        [label_predict, self.cache.A, index_all],
        label_predict,
        device=self.device)
    if idx_val is not None:
        # Validation targets come from the cached true one-hot labels.
        val_sequence = FullBatchSequence(
            [label_predict, self.cache.A, idx_val],
            self.cache.label_onehot[idx_val],
            device=self.device)
    else:
        val_sequence = None

    self.model = self.model_p
    history = super().train(train_sequence, val_sequence,
                            epochs=epochs,
                            early_stopping=early_stopping,
                            verbose=verbose,
                            save_best=save_best,
                            ckpt_path=ckpt_path,
                            as_model=as_model,
                            monitor=monitor,
                            early_stop_metric=early_stop_metric)
    histories.append(history)

    # Phase 3: train model_q again on model_p's softened predictions,
    # with training nodes clamped to their true one-hot labels.
    label_predict = self.model.predict_on_batch(
        gf.astensors(label_predict, self.cache.A, index_all,
                     device=self.device))
    label_predict = softmax(label_predict)
    if tf.is_tensor(label_predict):
        # Targets must be mutable for the clamp below.
        label_predict = label_predict.numpy()
    label_predict[idx_train] = self.cache.label_onehot[idx_train]

    self.model = self.model_q
    train_sequence = FullBatchSequence(
        [self.cache.X, self.cache.A, index_all],
        label_predict,
        device=self.device)
    history = super().train(train_sequence, idx_val,
                            epochs=epochs,
                            early_stopping=early_stopping,
                            verbose=verbose,
                            save_best=save_best,
                            ckpt_path=ckpt_path,
                            as_model=as_model,
                            monitor=monitor,
                            early_stop_metric=early_stop_metric)
    histories.append(history)

    return histories