def train_gcn(self, epoch, labels, idx_train, idx_val):
    """Run one training step of the GCN on the current estimated graph and features.

    The graph is rebuilt from ``self.estimator.estimated_adj`` with
    ``self.generate_g``. The node features are taken from
    ``self.estimator2.estimated_feature``. After the optimizer step, the model
    is re-evaluated on ``idx_val`` (eval mode, no autograd). Whenever the
    validation accuracy or loss improves, the graph, the features and a copy
    of the model weights are stored as the current best.

    Args:
        epoch: zero-based epoch index; only used for logging.
        labels: node labels, indexed by ``idx_train`` / ``idx_val``.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
    """
    args = self.args
    t = time.time()
    features = self.estimator2.estimated_feature

    self.model.train()
    self.optimizer.zero_grad()
    g = self.generate_g(self.estimator.estimated_adj)
    output = self.model(g, features)
    loss_fcn = torch.nn.CrossEntropyLoss()
    loss_train = loss_fcn(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    self.optimizer.step()

    # Evaluate validation set performance separately;
    # eval() deactivates dropout during the validation run.
    self.model.eval()
    with torch.no_grad():
        output = self.model(g, features)
        loss_val = loss_fcn(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])

    if acc_val > self.best_val_acc:
        self.best_val_acc = acc_val
        self.best_graph = g
        # test() evaluates best_graph together with best_feature, so both must
        # come from the same epoch (train_feat / train_adj save both as well).
        self.best_feature = features.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())

    if loss_val < self.best_val_loss:
        self.best_val_loss = loss_val
        self.best_graph = g
        self.best_feature = features.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())

    del g, output

    # One progress line per epoch, regardless of args.debug.
    print('Epoch: {:04d}'.format(epoch + 1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))
def train_gcn(self, epoch, features, adj, labels, idx_train, idx_val):
    """Run one training step of the GCN on the estimator's normalized graph.

    The adjacency comes from ``self.estimator.normalize()`` and is converted
    to a weighted DGL graph from a detached copy. This step therefore only
    updates the GCN parameters, never the estimated adjacency. The model is
    then evaluated on ``idx_val``. When accuracy or loss improves, the
    normalized adjacency and a copy of the weights are stored as the current
    best.

    Args:
        epoch: zero-based epoch index; only used for logging.
        features: node feature matrix fed to the model.
        adj: unused. The graph is always taken from the estimator; the
            parameter is kept for interface compatibility.
        labels: node labels, indexed by ``idx_train`` / ``idx_val``.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
    """
    args = self.args
    estimator = self.estimator
    adj = estimator.normalize()

    t = time.time()
    self.model.train()
    self.optimizer.zero_grad()

    b = sp.coo_matrix(adj.detach().cpu().numpy())
    g = DGLGraph(b).to(self.device)
    # Assumes DGL assigns edge ids in COO entry order, so b.data lines up with
    # the edges. DGL edge data must be a tensor on the graph's device, not a
    # numpy array.
    g.edata['weight'] = torch.as_tensor(b.data, device=self.device)

    output = self.model(g, features)
    loss_fcn = torch.nn.CrossEntropyLoss()
    loss_train = loss_fcn(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    self.optimizer.step()

    # Evaluate validation set performance separately;
    # eval() deactivates dropout and no autograd graph is needed.
    self.model.eval()
    with torch.no_grad():
        output = self.model(g, features)
        loss_val = loss_fcn(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])

    if acc_val > self.best_val_acc:
        self.best_val_acc = acc_val
        self.best_graph = adj.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())

    if loss_val < self.best_val_loss:
        self.best_val_loss = loss_val
        self.best_graph = adj.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())

    if args.debug:
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))
def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
    """Train for ``train_iters`` epochs and keep the output of the best epoch.

    Each epoch takes one Adam step on ``self._loss`` over ``idx_train``, then
    evaluates the model on ``idx_val``. Whenever the validation NLL decreases
    or the validation accuracy increases, that epoch's output is cached in
    ``self.output``. Model weights are not restored; only the output is kept.

    Args:
        labels: node labels.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
        train_iters: number of training epochs.
        verbose: if true, log the training loss every 10 epochs.
    """
    optimizer = optim.Adam(self.parameters(), lr=self.lr)

    # Start from +inf so the first epoch always qualifies. A fixed sentinel
    # such as 100 can never be beaten if the validation loss starts above it,
    # which could leave self.output unset.
    best_loss_val = float('inf')
    best_acc_val = 0

    for i in range(train_iters):
        self.train()
        optimizer.zero_grad()
        output = self.forward()
        loss_train = self._loss(output[idx_train], labels[idx_train])
        loss_train.backward()
        optimizer.step()
        if verbose and i % 10 == 0:
            print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward()
        loss_val = F.nll_loss(output[idx_val], labels[idx_val])
        acc_val = utils.accuracy(output[idx_val], labels[idx_val])

        if best_loss_val > loss_val:
            best_loss_val = loss_val
            self.output = output

        if acc_val > best_acc_val:
            best_acc_val = acc_val
            self.output = output

    print('=== picking the best model according to the performance on validation ===')
def get_meta_grad(self, features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training):
    """Forward the surrogate with its current weights and return the meta-gradients.

    The attack loss mixes the NLL on labeled nodes and the NLL on unlabeled
    nodes (against ``labels_self_training``), weighted by ``self.lambda_``.
    The loss on the unlabeled nodes against ``labels``, the matching accuracy
    and the attack loss are printed.

    Returns:
        tuple: ``(adj_grad, feature_grad)``. ``adj_grad`` is d(attack_loss)/d
        ``self.adj_changes`` when ``self.attack_structure`` is set, else None.
        ``feature_grad`` is d(attack_loss)/d ``self.feature_changes`` when
        ``self.attack_features`` is set, else None.
    """
    h = features
    for layer_idx, weight in enumerate(self.weights):
        bias = self.biases[layer_idx] if self.with_bias else 0
        if self.sparse_features:
            h = adj_norm @ torch.spmm(h, weight) + bias
        else:
            h = adj_norm @ h @ weight + bias
        # NOTE(review): the ReLU is applied after every layer, including the
        # output layer that feeds log_softmax. This must stay in sync with the
        # inner training forward pass of this class; confirm it is intended.
        if self.with_relu:
            h = F.relu(h)
    log_probs = F.log_softmax(h, dim=1)

    loss_labeled = F.nll_loss(log_probs[idx_train], labels[idx_train])
    loss_unlabeled = F.nll_loss(log_probs[idx_unlabeled], labels_self_training[idx_unlabeled])
    loss_test_val = F.nll_loss(log_probs[idx_unlabeled], labels[idx_unlabeled])

    lam = self.lambda_
    if lam == 1:
        attack_loss = loss_labeled
    elif lam == 0:
        attack_loss = loss_unlabeled
    else:
        attack_loss = lam * loss_labeled + (1 - lam) * loss_unlabeled

    print('GCN loss on unlabled data: {}'.format(loss_test_val.item()))
    print('GCN acc on unlabled data: {}'.format(
        utils.accuracy(log_probs[idx_unlabeled], labels[idx_unlabeled]).item()))
    print('attack loss: {}'.format(attack_loss.item()))

    adj_grad = None
    feature_grad = None
    # retain_graph so the second autograd.grad call can reuse the same graph.
    if self.attack_structure:
        adj_grad = torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]
    if self.attack_features:
        feature_grad = torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]
    return adj_grad, feature_grad
def test(self, idx_test):
    """Print the test loss and accuracy of the cached output.

    No forward pass is run. The log-probabilities cached in ``self.output``
    (presumably set by ``_train_with_val``) are scored against
    ``self.labels`` on ``idx_test``.
    """
    cached = self.output
    truth = self.labels
    preds, gold = cached[idx_test], truth[idx_test]
    test_loss = F.nll_loss(preds, gold)
    test_acc = utils.accuracy(preds, gold)
    print("Test set results:",
          "loss= {:.4f}".format(test_loss.item()),
          "accuracy= {:.4f}".format(test_acc.item()))
def test(self, features, labels, idx_test):
    """Evaluate the GCN on ``idx_test`` and print the loss and accuracy.

    Uses ``self.best_graph`` (the normalized adjacency saved during training).
    If no graph was saved, it falls back to ``self.estimator.normalize()``.

    Args:
        features: node feature matrix fed to the model.
        labels: node labels.
        idx_test: indices of the test nodes.

    Returns:
        float: the test accuracy. Previously the method returned None, so
        existing callers are unaffected.
    """
    print("\t=== testing ===")
    self.model.eval()
    adj = self.best_graph
    # No gradients are needed at test time; this avoids building an autograd
    # graph through the model and the estimator's normalization.
    with torch.no_grad():
        if adj is None:
            adj = self.estimator.normalize()
        output = self.model(features, adj)
        loss_test = F.nll_loss(output[idx_test], labels[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
    print("\tTest set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))
    return acc_test.item()
def test(self, features, labels, idx_test):
    """Evaluate the DGL-based GCN on ``idx_test``, print and log the result to wandb.

    Uses ``self.best_graph`` (normalized adjacency saved during training), or
    ``self.estimator.normalize()`` if none was saved. The result is converted
    to a weighted DGL graph before the forward pass.

    Args:
        features: node feature matrix fed to the model.
        labels: node labels.
        idx_test: indices of the test nodes.
    """
    print("\t=== testing ===")
    self.model.eval()
    with torch.no_grad():
        adj = self.best_graph
        if adj is None:
            adj = self.estimator.normalize()
        b = sp.coo_matrix(adj.detach().cpu().numpy())
        g = DGLGraph(b).to(self.device)
        # DGL edge data must be a tensor on the graph's device, not a numpy array.
        g.edata['weight'] = torch.as_tensor(b.data, device=self.device)

        output = self.model(g, features)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_test = loss_fcn(output[idx_test], labels[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
    print("\tTest set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))
    wandb.log({
        'Test_accuracy': acc_test.item(),
        'Test_loss': loss_test.item()
    })
def test(self, features, labels, idx_test):
    """Evaluate the GCN on ``idx_test`` with the best saved graph and features.

    Uses ``self.best_graph`` and ``self.best_feature`` from training. If no
    graph was saved, both are rebuilt from the current estimates
    (``self.estimator.estimated_adj`` via ``self.generate_g`` and
    ``self.estimator2.estimated_feature``). The accuracy is also written to
    the ``logging`` module.

    Args:
        features: unused. Features come from ``self.best_feature`` or
            ``self.estimator2``; the parameter is kept for interface
            compatibility.
        labels: node labels.
        idx_test: indices of the test nodes.
    """
    print("\t=== testing ===")
    self.model.eval()
    g = self.best_graph
    features = self.best_feature
    if g is None:
        g = self.generate_g(self.estimator.estimated_adj)
        features = self.estimator2.estimated_feature
    elif features is None:
        # A training step may record best_graph without best_feature; use the
        # current feature estimate rather than passing None to the model.
        features = self.estimator2.estimated_feature

    with torch.no_grad():
        output = self.model(g, features)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_test = loss_fcn(output[idx_test], labels[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
    print("\tTest set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))
    logging.info("Accuracy:" + str(acc_test.data))
def inner_train(self, features, modified_adj, idx_train, idx_unlabeled, labels, labels_self_training):
    """Train the surrogate GCN on ``modified_adj`` and accumulate meta-gradients.

    For each of ``self.train_iters`` steps:

    1. Run the forward pass.
    2. Take one optimizer step on the labeled NLL.
    3. Add the gradient of the attack loss to ``self.adj_grad_sum`` and/or
       ``self.feature_grad_sum`` (depending on ``self.attack_structure`` and
       ``self.attack_features``).

    The attack loss mixes the labeled NLL and the self-training NLL on
    ``idx_unlabeled``, weighted by ``self.lambda_``. After the loop, the loss
    and accuracy of the last output on the unlabeled nodes are printed.
    """
    adj_norm = utils.normalize_adj_tensor(modified_adj)
    for j in range(self.train_iters):
        hidden = features
        # Index biases per layer (as get_meta_grad does) instead of
        # zip(self.weights, self.biases). When with_bias is off, the bias list
        # is presumably empty, so zip would silently skip every layer.
        for ix, w in enumerate(self.weights):
            b = self.biases[ix] if self.with_bias else 0
            if self.sparse_features:
                hidden = adj_norm @ torch.spmm(hidden, w) + b
            else:
                hidden = adj_norm @ hidden @ w + b
            # NOTE(review): the ReLU is also applied after the output layer
            # (before log_softmax); a standard GCN leaves the last layer
            # linear. Left unchanged here; confirm intent.
            if self.with_relu:
                hidden = F.relu(hidden)

        output = F.log_softmax(hidden, dim=1)
        loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])
        loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled])

        if self.lambda_ == 1:
            attack_loss = loss_labeled
        elif self.lambda_ == 0:
            attack_loss = loss_unlabeled
        else:
            attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled

        # The inner model is trained on the labeled loss only.
        self.optimizer.zero_grad()
        loss_labeled.backward(retain_graph=True)
        self.optimizer.step()

        # backward() above also populated .grad on the perturbation tensors;
        # clear it and accumulate the attack-loss gradient explicitly.
        if self.attack_structure:
            self.adj_changes.grad.zero_()
            self.adj_grad_sum += torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]
        if self.attack_features:
            self.feature_changes.grad.zero_()
            self.feature_grad_sum += torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]

    loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled])
    print('GCN loss on unlabled data: {}'.format(loss_test_val.item()))
    print('GCN acc on unlabled data: {}'.format(
        utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item()))
def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
    """Train with validation-based model selection and restore the best weights.

    Runs ``train_iters`` epochs of Adam (``self.lr``, ``self.weight_decay``)
    on ``self.features`` / ``self.adj_norm``. After each epoch the model is
    evaluated on ``idx_val``. Whenever the validation NLL decreases or the
    validation accuracy increases, that epoch's output is cached in
    ``self.output`` and a copy of the weights is recorded. At the end, the
    recorded weights (if any) are loaded back into the model.

    Args:
        labels: node labels.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
        train_iters: number of training epochs.
        verbose: if true, print progress messages.
    """
    if verbose:
        print('=== training gcn model ===')
    optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    # +inf so the first epoch always qualifies (a fixed 100 might never be beaten).
    best_loss_val = float('inf')
    best_acc_val = 0
    # Stays None if no epoch improves (train_iters == 0, NaN losses, ...);
    # otherwise the final load_state_dict would hit an unbound name.
    weights = None

    for i in range(train_iters):
        self.train()
        optimizer.zero_grad()
        output = self.forward(self.features, self.adj_norm)
        loss_train = F.nll_loss(output[idx_train], labels[idx_train])
        loss_train.backward()
        optimizer.step()

        if verbose and i % 10 == 0:
            print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        loss_val = F.nll_loss(output[idx_val], labels[idx_val])
        acc_val = utils.accuracy(output[idx_val], labels[idx_val])

        if best_loss_val > loss_val:
            best_loss_val = loss_val
            self.output = output
            weights = deepcopy(self.state_dict())

        if acc_val > best_acc_val:
            best_acc_val = acc_val
            self.output = output
            weights = deepcopy(self.state_dict())

    if verbose:
        print('=== picking the best model according to the performance on validation ===')
    if weights is not None:
        self.load_state_dict(weights)
def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
    """Run one optimisation step of the structure estimator (``self.estimator``).

    The differentiable part of the objective,
    ``||A_hat - adj||_F + gamma * NLL(GCN) + lambda_ * smooth + phi * ||A_hat - A_hat^T||_F``,
    is minimised with ``self.optimizer_adj``. Then the nuclear-norm step (only
    when ``args.beta != 0``) and the L1 step are applied via their own
    optimizers. The estimate is clamped to [0, 1]. Finally the GCN is
    evaluated on the updated normalized adjacency. When validation accuracy
    or loss improves, that adjacency and a copy of the model weights are
    stored.

    Args:
        epoch: zero-based epoch index; only used for logging.
        features: node feature matrix fed to the GCN.
        adj: observed adjacency that the estimate is regularised towards.
        labels: node labels.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
    """
    estimator = self.estimator
    args = self.args
    if args.debug:
        print("\n=== train_adj ===")
    t = time.time()
    estimator.train()
    self.optimizer_adj.zero_grad()

    # Entrywise L1 norm; not part of the backward pass below, it is only
    # reported (the L1 penalty is applied by optimizer_l1).
    loss_l1 = torch.norm(estimator.estimated_adj, 1)
    loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
    normalized_adj = estimator.normalize()

    if args.lambda_:
        loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj, features)
    else:
        # Zero tensor (not int) so the debug log can call .item() on it.
        loss_smooth_feat = 0 * loss_l1

    output = self.model(features, normalized_adj)
    loss_gcn = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])

    # Penalises asymmetry of the estimated adjacency.
    loss_symmetric = torch.norm(estimator.estimated_adj \
                    - estimator.estimated_adj.t(), p="fro")

    loss_diffiential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_smooth_feat + args.phi * loss_symmetric

    loss_diffiential.backward()

    self.optimizer_adj.step()
    loss_nuclear = 0 * loss_fro
    if args.beta != 0:
        # presumably a proximal step for the nuclear norm; the resulting norm
        # value is read from the prox_operators module -- confirm.
        self.optimizer_nuclear.zero_grad()
        self.optimizer_nuclear.step()
        loss_nuclear = prox_operators.nuclear_norm

    # presumably a proximal step for the L1 penalty -- confirm.
    self.optimizer_l1.zero_grad()
    self.optimizer_l1.step()

    # Reporting only; not backpropagated.
    total_loss = loss_fro \
                + args.gamma * loss_gcn \
                + args.alpha * loss_l1 \
                + args.beta * loss_nuclear \
                + args.phi * loss_symmetric

    # Project the estimate back into [0, 1] in place.
    estimator.estimated_adj.data.copy_(
        torch.clamp(estimator.estimated_adj.data, min=0, max=1))

    # Evaluate validation set performance separately,
    # deactivates dropout during validation run.
    self.model.eval()
    # Re-normalize so validation sees the adjacency after all updates above.
    normalized_adj = estimator.normalize()
    output = self.model(features, normalized_adj)

    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch + 1),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))

    # Snapshot graph + weights on either a new best accuracy or a new best loss.
    if acc_val > self.best_val_acc:
        self.best_val_acc = acc_val
        self.best_graph = normalized_adj.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print(f'\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())

    if loss_val < self.best_val_loss:
        self.best_val_loss = loss_val
        self.best_graph = normalized_adj.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print(f'\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())

    if args.debug:
        if epoch % 1 == 0:
            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_fro: {:.4f}'.format(loss_fro.item()),
                  'loss_gcn: {:.4f}'.format(loss_gcn.item()),
                  'loss_feat: {:.4f}'.format(loss_smooth_feat.item()),
                  'loss_symmetric: {:.4f}'.format(loss_symmetric.item()),
                  'delta_l1_norm: {:.4f}'.format(torch.norm(estimator.estimated_adj - adj, 1).item()),
                  'loss_l1: {:.4f}'.format(loss_l1.item()),
                  'loss_total: {:.4f}'.format(total_loss.item()),
                  'loss_nuclear: {:.4f}'.format(loss_nuclear.item()))
def train_feat(self, epoch, features, adj, labels, idx_train, idx_val):
    """Run one optimisation step of the feature estimator (``self.estimator2``).

    Minimises ``||X_hat - features||_F + gamma * CE(GCN(g(adj), X_hat))`` with
    ``self.optimizer_feat``. When ``args.method == "smooth"``, the objective
    also includes ``lambda_ * feature_smoothing(adj, X_hat)``. The GCN is
    then evaluated on ``idx_val``. When validation accuracy or loss improves,
    the graph, the estimated features and a copy of the model weights are
    stored.

    Args:
        epoch: zero-based epoch index; only used for logging.
        features: observed node features that the estimate is regularised towards.
        adj: adjacency used to build the GCN graph (via ``self.generate_g``).
        labels: node labels.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
    """
    args = self.args
    if args.debug:
        print("\n === This the train_feature===")
    t = time.time()
    self.estimator2.train()
    self.optimizer_feat.zero_grad()

    loss_fro = torch.norm(self.estimator2.estimated_feature - features, p='fro')
    g = self.generate_g(adj)
    output = self.model(g, self.estimator2.estimated_feature)
    loss_fcn = torch.nn.CrossEntropyLoss()
    loss_gcn = loss_fcn(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])

    # Only the "smooth" method has a smoothness term. Default to a zero tensor
    # so other methods neither hit a NameError nor break .item() in the log.
    loss_feat = 0 * loss_fro
    if args.method == "smooth":
        loss_feat = self.feature_smoothing(adj, self.estimator2.estimated_feature)

    loss_differential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_feat
    loss_differential.backward()
    self.optimizer_feat.step()
    total_loss = loss_fro \
        + args.gamma * loss_gcn \
        + args.lambda_ * loss_feat

    # Validation run: eval() disables dropout, no autograd graph needed.
    self.model.eval()
    with torch.no_grad():
        output = self.model(g, self.estimator2.estimated_feature)
        loss_val = loss_fcn(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch + 1),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))
    torch.cuda.empty_cache()

    if acc_val > self.best_val_acc:
        self.best_val_acc = acc_val
        self.best_graph = g
        self.best_feature = self.estimator2.estimated_feature.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())

    if loss_val < self.best_val_loss:
        self.best_val_loss = loss_val
        self.best_graph = g
        self.best_feature = self.estimator2.estimated_feature.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())

    if args.debug:
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_fro: {:.4f}'.format(loss_fro.item()),
              'loss_gcn: {:.4f}'.format(loss_gcn.item()),
              'loss_feat: {:.4f}'.format(loss_feat.item()),
              'loss_total: {:.4f}'.format(total_loss.item()))
def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
    """Run one optimisation step of the structure estimator (``self.estimator``).

    Minimises ``||A_hat - adj||_F + gamma * CE(GCN(g(A_hat), features))`` plus,
    when ``args.lambda_`` is set, ``lambda_ * feature_smoothing(A_hat, features)``.
    The estimate is then clamped to [0, 1]. The GCN is evaluated on a graph
    rebuilt from the updated estimate. When validation accuracy or loss
    improves, that graph, the current feature estimate of ``self.estimator2``
    and a copy of the model weights are stored.

    :param epoch: zero-based epoch index, used for logging only.
    :param features: node feature matrix fed to the GCN.
    :param adj: observed adjacency that the estimate is regularised towards.
    :param labels: node labels.
    :param idx_train: indices of the training nodes.
    :param idx_val: indices of the validation nodes.
    :return: None
    """
    estimator = self.estimator
    estimator2 = self.estimator2
    args = self.args
    if args.debug:
        print("\n=== train_adj ===")
    t = time.time()
    estimator.train()
    self.optimizer_adj.zero_grad()

    loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
    g = self.generate_g(estimator.estimated_adj)
    if args.lambda_:
        loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj, features)
    else:
        loss_smooth_feat = 0

    # Leftover diagnostic; only emitted in debug mode instead of every epoch.
    if args.debug:
        print(features.dtype)
    output = self.model(g, features)
    loss_fcn = torch.nn.CrossEntropyLoss()
    loss_gcn = loss_fcn(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])

    loss_differential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_smooth_feat
    loss_differential.backward()
    self.optimizer_adj.step()
    # Project the estimate back into [0, 1] in place.
    estimator.estimated_adj.data.copy_(
        torch.clamp(estimator.estimated_adj.data, min=0, max=1))

    # Evaluate validation set performance separately on the updated graph;
    # eval() deactivates dropout during the validation run.
    self.model.eval()
    g = self.generate_g(estimator.estimated_adj)
    with torch.no_grad():
        output = self.model(g, features)
        loss_val = loss_fcn(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch + 1),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))

    # NOTE(review): validation uses the `features` argument, but the snapshot
    # stores estimator2's estimate; confirm the caller passes the estimate.
    if acc_val > self.best_val_acc:
        self.best_val_acc = acc_val
        self.best_graph = g
        self.best_feature = estimator2.estimated_feature.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())

    if loss_val < self.best_val_loss:
        self.best_val_loss = loss_val
        self.best_graph = g
        self.best_feature = estimator2.estimated_feature.detach()
        self.weights = deepcopy(self.model.state_dict())
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())
def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
    """Run one optimisation step of the structure estimator with a DGL GCN.

    The differentiable objective,
    ``||A_hat - adj||_F + gamma * CE + lambda_ * smooth + phi * ||A_hat - A_hat^T||_F``,
    is minimised with ``self.optimizer_adj``. Then the nuclear-norm step
    (only when ``args.beta != 0``) and the L1 step are applied. The estimate
    is clamped to [0, 1]. Finally the GCN is evaluated on a DGL graph rebuilt
    from the *updated* normalized adjacency. On a new best validation
    accuracy or loss, that adjacency and a copy of the model weights are
    stored and the wandb summary is updated.

    Args:
        epoch: zero-based epoch index; only used for logging.
        features: node feature matrix fed to the GCN.
        adj: observed adjacency that the estimate is regularised towards.
        labels: node labels.
        idx_train: indices of the training nodes.
        idx_val: indices of the validation nodes.
    """
    estimator = self.estimator
    args = self.args
    if args.debug:
        print("\n=== train_adj ===")
    t = time.time()

    def to_dgl(normalized):
        # Build a weighted DGL graph from a detached dense adjacency. Edge data
        # must be a tensor on the graph's device, not a numpy array.
        b = sp.coo_matrix(normalized.detach().cpu().numpy())
        graph = DGLGraph(b).to(self.device)
        graph.edata['weight'] = torch.as_tensor(b.data, device=self.device)
        return graph

    estimator.train()
    self.optimizer_adj.zero_grad()

    loss_l1 = torch.norm(estimator.estimated_adj, 1)
    loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
    normalized_adj = estimator.normalize()
    # NOTE(review): the graph is built from a detached numpy copy, so loss_gcn
    # has no gradient path back to estimator.estimated_adj.
    g = to_dgl(normalized_adj)

    if args.lambda_:
        loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj, features)
    else:
        loss_smooth_feat = 0 * loss_l1

    output = self.model(g, features)
    loss_fcn = torch.nn.CrossEntropyLoss()
    loss_gcn = loss_fcn(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])

    loss_symmetric = torch.norm(estimator.estimated_adj - estimator.estimated_adj.t(), p="fro")

    loss_differential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_smooth_feat + args.phi * loss_symmetric
    loss_differential.backward()
    self.optimizer_adj.step()

    loss_nuclear = 0 * loss_fro
    if args.beta != 0:
        self.optimizer_nuclear.zero_grad()
        self.optimizer_nuclear.step()
        loss_nuclear = prox_operators.nuclear_norm

    self.optimizer_l1.zero_grad()
    self.optimizer_l1.step()

    # Reporting only; not backpropagated.
    total_loss = loss_fro \
        + args.gamma * loss_gcn \
        + args.alpha * loss_l1 \
        + args.beta * loss_nuclear \
        + args.phi * loss_symmetric

    estimator.estimated_adj.data.copy_(
        torch.clamp(estimator.estimated_adj.data, min=0, max=1))

    # Evaluate validation performance on the *updated* adjacency: the graph
    # must be rebuilt here, otherwise the stale pre-update `g` is scored while
    # best_graph stores the new adjacency. eval() deactivates dropout.
    self.model.eval()
    with torch.no_grad():
        normalized_adj = estimator.normalize()
        g = to_dgl(normalized_adj)
        output = self.model(g, features)
        loss_val = loss_fcn(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch + 1),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))

    if acc_val > self.best_val_acc:
        self.best_val_acc = acc_val
        self.best_graph = normalized_adj.detach()
        self.weights = deepcopy(self.model.state_dict())
        wandb.run.summary["best_accuracy"] = self.best_val_acc
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())

    if loss_val < self.best_val_loss:
        self.best_val_loss = loss_val
        self.best_graph = normalized_adj.detach()
        self.weights = deepcopy(self.model.state_dict())
        wandb.run.summary["best_accuracy"] = self.best_val_acc
        if args.debug:
            print('\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())

    if args.debug:
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_fro: {:.4f}'.format(loss_fro.item()),
              'loss_gcn: {:.4f}'.format(loss_gcn.item()),
              'loss_feat: {:.4f}'.format(loss_smooth_feat.item()),
              'loss_symmetric: {:.4f}'.format(loss_symmetric.item()),
              'delta_l1_norm: {:.4f}'.format(torch.norm(estimator.estimated_adj - adj, 1).item()),
              'loss_l1: {:.4f}'.format(loss_l1.item()),
              'loss_total: {:.4f}'.format(total_loss.item()),
              'loss_nuclear: {:.4f}'.format(loss_nuclear.item()))