Exemplo n.º 1
0
    def train_gcn(self, epoch, labels, idx_train, idx_val):
        """Run one optimisation step of the GCN on the current estimated graph.

        The graph is built with ``self.generate_g`` from
        ``self.estimator.estimated_adj`` and the node features are
        ``self.estimator2.estimated_feature``. After the training step the
        model is evaluated on ``idx_val``; whenever validation accuracy or
        validation loss improves, the graph and a deep copy of the model
        weights are stored in ``self.best_graph`` / ``self.weights``.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param labels: node labels, indexed by ``idx_train`` / ``idx_val``.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        """
        args = self.args
        t = time.time()
        self.model.train()
        self.optimizer.zero_grad()
        g = self.generate_g(self.estimator.estimated_adj)
        output = self.model(g, self.estimator2.estimated_feature)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_train = loss_fcn(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
        loss_train.backward()
        self.optimizer.step()

        # Evaluate validation set performance separately; eval() deactivates
        # dropout and no_grad() skips building an unused autograd graph.
        self.model.eval()
        with torch.no_grad():
            output = self.model(g, self.estimator2.estimated_feature)
            loss_val = loss_fcn(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])

        if acc_val > self.best_val_acc:
            self.best_val_acc = acc_val
            self.best_graph = g
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_acc: %s' %
                      self.best_val_acc.item())

        if loss_val < self.best_val_loss:
            self.best_val_loss = loss_val
            self.best_graph = g
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_loss: %s' %
                      self.best_val_loss.item())
        # Only drops the local names; ``self.best_graph`` may still hold g.
        del g, output

        # Per-epoch summary, printed exactly once (debug mode used to print
        # this same line twice).
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))
Exemplo n.º 2
0
    def train_gcn(self, epoch, features, adj, labels, idx_train, idx_val):
        """Run one optimisation step of the DGL-based GCN on the estimated graph.

        The ``adj`` argument is not used directly: the graph is taken from
        ``self.estimator.normalize()`` and converted to a weighted ``DGLGraph``
        (edge weights = non-zero entries of the normalized adjacency). After the
        training step the model is evaluated on ``idx_val``; whenever validation
        accuracy or loss improves, the detached normalized adjacency and a deep
        copy of the model weights are stored in ``self.best_graph`` /
        ``self.weights``.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param features: node feature matrix fed to the model.
        :param adj: unused; overwritten by the estimator's normalized adjacency.
        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        """
        args = self.args
        estimator = self.estimator
        adj = estimator.normalize()
        t = time.time()
        self.model.train()
        self.optimizer.zero_grad()

        b = sp.coo_matrix(adj.detach().cpu().numpy())
        g = DGLGraph(b).to(self.device)
        # DGL edge features must be tensors on the graph's device; a raw
        # NumPy array cannot take part in torch-based message passing.
        g.edata['weight'] = torch.as_tensor(b.data, device=self.device)
        output = self.model(g, features)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_train = loss_fcn(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
        loss_train.backward()
        self.optimizer.step()

        # Evaluate validation set performance separately; eval() deactivates
        # dropout and no_grad() skips building an unused autograd graph.
        self.model.eval()
        with torch.no_grad():
            output = self.model(g, features)
            loss_val = loss_fcn(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])

        if acc_val > self.best_val_acc:
            self.best_val_acc = acc_val
            self.best_graph = adj.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_acc: %s' %
                      self.best_val_acc.item())

        if loss_val < self.best_val_loss:
            self.best_val_loss = loss_val
            self.best_graph = adj.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_loss: %s' %
                      self.best_val_loss.item())

        if args.debug:
            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_train: {:.4f}'.format(loss_train.item()),
                  'acc_train: {:.4f}'.format(acc_train.item()),
                  'loss_val: {:.4f}'.format(loss_val.item()),
                  'acc_val: {:.4f}'.format(acc_val.item()),
                  'time: {:.4f}s'.format(time.time() - t))
Exemplo n.º 3
0
    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Train for ``train_iters`` epochs and keep the best validation output.

        Each epoch takes one Adam step on ``self._loss`` over ``idx_train``,
        then evaluates on ``idx_val``. The model output of any epoch that
        improves the validation loss or the validation accuracy is stored in
        ``self.output``. Model weights are NOT restored afterwards; only the
        output tensor is kept.

        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        :param train_iters: number of training epochs.
        :param verbose: print the training loss every 10 epochs when true.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)

        # +inf guarantees the first epoch is recorded; a finite sentinel
        # (previously 100) left ``self.output`` unset whenever every
        # validation loss exceeded it.
        best_loss_val = float('inf')
        best_acc_val = 0

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward()
            loss_train = self._loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward()
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output

        print('=== picking the best model according to the performance on validation ===')
Exemplo n.º 4
0
    def get_meta_grad(self, features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training):
        """Forward the surrogate GCN and return gradients of the attack loss.

        Applies every layer in ``self.weights`` (bias only when
        ``self.with_bias``, ReLU only when ``self.with_relu``), prints loss and
        accuracy on ``idx_unlabeled`` against the true ``labels`` plus the
        attack loss, and differentiates the attack loss with respect to
        ``self.adj_changes`` and/or ``self.feature_changes``.

        The attack loss is the labeled loss (``lambda_ == 1``), the
        self-training loss on unlabeled nodes (``lambda_ == 0``), or their
        ``lambda_``-weighted mix.

        :return: ``(adj_grad, feature_grad)``; an entry is ``None`` when the
            corresponding attack is disabled.
        """
        hidden = features
        for layer_idx, weight in enumerate(self.weights):
            bias = self.biases[layer_idx] if self.with_bias else 0
            if self.sparse_features:
                hidden = adj_norm @ torch.spmm(hidden, weight) + bias
            else:
                hidden = (adj_norm @ hidden) @ weight + bias
            hidden = F.relu(hidden) if self.with_relu else hidden

        log_probs = F.log_softmax(hidden, dim=1)
        unlabeled_log_probs = log_probs[idx_unlabeled]

        loss_labeled = F.nll_loss(log_probs[idx_train], labels[idx_train])
        loss_unlabeled = F.nll_loss(unlabeled_log_probs, labels_self_training[idx_unlabeled])
        loss_test_val = F.nll_loss(unlabeled_log_probs, labels[idx_unlabeled])

        lam = self.lambda_
        if lam == 1:
            attack_loss = loss_labeled
        elif lam == 0:
            attack_loss = loss_unlabeled
        else:
            attack_loss = lam * loss_labeled + (1 - lam) * loss_unlabeled

        unlabeled_acc = utils.accuracy(unlabeled_log_probs, labels[idx_unlabeled]).item()
        print(f'GCN loss on unlabled data: {loss_test_val.item()}')
        print(f'GCN acc on unlabled data: {unlabeled_acc}')
        print(f'attack loss: {attack_loss.item()}')

        adj_grad = (torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]
                    if self.attack_structure else None)
        feature_grad = (torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]
                        if self.attack_features else None)
        return adj_grad, feature_grad
Exemplo n.º 5
0
 def test(self, idx_test):
     """Print test-set loss and accuracy of the cached model output.

     No new forward pass is made: the stored ``self.output`` is scored
     against ``self.labels`` on the given nodes.

     :param idx_test: indices of the test nodes.
     """
     test_pred = self.output[idx_test]
     test_true = self.labels[idx_test]
     loss_test = F.nll_loss(test_pred, test_true)
     acc_test = utils.accuracy(test_pred, test_true)
     print("Test set results:",
           f"loss= {loss_test.item():.4f}",
           f"accuracy= {acc_test.item():.4f}")
Exemplo n.º 6
0
 def test(self, features, labels, idx_test):
     """Evaluate the model on the test nodes and print loss and accuracy.

     Uses ``self.best_graph`` when one has been recorded, otherwise the
     estimator's current normalized adjacency.

     :param features: node feature matrix fed to the model.
     :param labels: node labels.
     :param idx_test: indices of the test nodes.
     """
     print("\t=== testing ===")
     self.model.eval()
     graph = self.best_graph
     if graph is None:
         graph = self.estimator.normalize()
     logits = self.model(features, graph)
     pred, target = logits[idx_test], labels[idx_test]
     loss_test = F.nll_loss(pred, target)
     acc_test = accuracy(pred, target)
     print("\tTest set results:", f"loss= {loss_test.item():.4f}",
           f"accuracy= {acc_test.item():.4f}")
Exemplo n.º 7
0
 def test(self, features, labels, idx_test):
     """Evaluate the DGL model on the test nodes, print and log the results.

     The graph is ``self.best_graph`` when recorded, otherwise the
     estimator's current normalized adjacency; it is converted to a weighted
     ``DGLGraph`` whose edge weights are the non-zero adjacency entries.
     Test loss and accuracy are printed and sent to ``wandb.log``.

     :param features: node feature matrix fed to the model.
     :param labels: node labels.
     :param idx_test: indices of the test nodes.
     """
     print("\t=== testing ===")
     self.model.eval()
     adj = self.best_graph
     if adj is None:
         adj = self.estimator.normalize()
     b = sp.coo_matrix(adj.detach().cpu().numpy())
     g = DGLGraph(b).to(self.device)
     # DGL edge features must be tensors on the graph's device.
     g.edata['weight'] = torch.as_tensor(b.data, device=self.device)
     loss_fcn = torch.nn.CrossEntropyLoss()
     # Inference only: no autograd graph is needed.
     with torch.no_grad():
         output = self.model(g, features)
         loss_test = loss_fcn(output[idx_test], labels[idx_test])
         acc_test = accuracy(output[idx_test], labels[idx_test])
     print("\tTest set results:", "loss= {:.4f}".format(loss_test.item()),
           "accuracy= {:.4f}".format(acc_test.item()))
     wandb.log({
         'Test_accuracy': acc_test.item(),
         'Test_loss': loss_test.item()
     })
Exemplo n.º 8
0
    def test(self, features, labels, idx_test):
        """Evaluate the model on the test nodes with the best recorded graph.

        Uses ``self.best_graph`` / ``self.best_feature`` when recorded. If no
        best graph exists, the graph is generated from the estimator's
        current adjacency and the features come from ``self.estimator2``.
        The ``features`` argument is not used.

        :param features: unused; kept for interface compatibility.
        :param labels: node labels.
        :param idx_test: indices of the test nodes.
        """
        print("\t=== testing ===")
        self.model.eval()
        g = self.best_graph
        features = self.best_feature

        if g is None:
            adj = self.estimator.estimated_adj
            features = self.estimator2.estimated_feature
            g = self.generate_g(adj)
        elif features is None:
            # A best graph can be recorded by a step that does not record a
            # best feature matrix; fall back to the current estimated
            # features instead of passing None to the model.
            features = self.estimator2.estimated_feature

        with torch.no_grad():
            output = self.model(g, features)
            loss_fcn = torch.nn.CrossEntropyLoss()
            loss_test = loss_fcn(output[idx_test], labels[idx_test])
            acc_test = accuracy(output[idx_test], labels[idx_test])
        print("\tTest set results:", "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))

        logging.info("Accuracy:" + str(acc_test.data))
Exemplo n.º 9
0
    def inner_train(self, features, modified_adj, idx_train, idx_unlabeled, labels, labels_self_training):
        """Train the surrogate GCN and accumulate meta-gradients of the attack loss.

        For ``self.train_iters`` iterations: forward the surrogate on the
        normalized ``modified_adj``, compute the attack loss, add its gradient
        with respect to ``self.adj_changes`` / ``self.feature_changes`` to
        ``self.adj_grad_sum`` / ``self.feature_grad_sum`` (when the
        corresponding attack is enabled), and take one optimizer step on the
        labeled loss. Finally prints loss and accuracy of the last iterate on
        ``idx_unlabeled`` against the true ``labels``.

        The attack loss is the labeled loss (``lambda_ == 1``), the
        self-training loss on unlabeled nodes (``lambda_ == 0``), or their
        ``lambda_``-weighted mix.
        """
        adj_norm = utils.normalize_adj_tensor(modified_adj)

        for j in range(self.train_iters):
            hidden = features
            for ix, w in enumerate(self.weights):
                # Same bias handling as the meta-gradient forward pass, so the
                # surrogate is trained with the architecture it is attacked with.
                b = self.biases[ix] if self.with_bias else 0
                if self.sparse_features:
                    hidden = adj_norm @ torch.spmm(hidden, w) + b
                else:
                    hidden = adj_norm @ hidden @ w + b
                if self.with_relu:
                    hidden = F.relu(hidden)

            output = F.log_softmax(hidden, dim=1)
            loss_labeled = F.nll_loss(output[idx_train], labels[idx_train])
            loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled])

            if self.lambda_ == 1:
                attack_loss = loss_labeled
            elif self.lambda_ == 0:
                attack_loss = loss_unlabeled
            else:
                attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled

            self.optimizer.zero_grad()
            loss_labeled.backward(retain_graph=True)

            # The meta-gradients are taken BEFORE optimizer.step(). The step
            # presumably updates the layer parameters in place, and those
            # tensors are saved in the autograd graph of ``attack_loss``. If
            # the step ran first, autograd would raise an in-place modification
            # error, or older versions would silently use the updated values.
            if self.attack_structure:
                # backward() above presumably also accumulated into
                # adj_changes.grad; clear it, only autograd.grad's result is used.
                self.adj_changes.grad.zero_()
                self.adj_grad_sum += torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0]
            if self.attack_features:
                self.feature_changes.grad.zero_()
                self.feature_grad_sum += torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0]

            self.optimizer.step()

        loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled])
        print('GCN loss on unlabled data: {}'.format(loss_test_val.item()))
        print('GCN acc on unlabled data: {}'.format(utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item()))
Exemplo n.º 10
0
    def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose):
        """Train for ``train_iters`` epochs and restore the best validation weights.

        Each epoch takes one Adam step (with ``self.weight_decay``) on the
        NLL loss over ``idx_train``, then evaluates on ``idx_val``. The output
        and a deep copy of the ``state_dict`` are recorded whenever validation
        loss or validation accuracy improves. After training, the recorded
        weights are loaded back into the model.

        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        :param train_iters: number of training epochs.
        :param verbose: print progress messages when true.
        """
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        # +inf guarantees the first epoch is recorded. The old finite sentinel
        # (100) could leave ``weights`` unbound, crashing load_state_dict.
        best_loss_val = float('inf')
        best_acc_val = 0
        weights = None

        for i in range(train_iters):
            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if best_loss_val > loss_val:
                best_loss_val = loss_val
                self.output = output
                weights = deepcopy(self.state_dict())

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        # Nothing was recorded (e.g. train_iters == 0): keep current weights.
        if weights is not None:
            self.load_state_dict(weights)
Exemplo n.º 11
0
    def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
        """Run one optimisation step on the estimated adjacency matrix.

        Steps:
        1. ``self.optimizer_adj`` minimises the differentiable objective: the
           Frobenius distance to ``adj``, the ``gamma``-weighted GCN training
           loss, the ``lambda_``-weighted feature-smoothness term (only when
           ``args.lambda_`` is set) and the ``phi``-weighted asymmetry penalty.
        2. ``self.optimizer_nuclear`` is stepped (only if ``args.beta != 0``),
           then ``self.optimizer_l1``.
        3. The estimate is clamped to ``[0, 1]``.
        4. The GCN is evaluated on ``idx_val`` with the updated graph. On
           improvement of validation accuracy or loss, the graph and model
           weights are recorded.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param features: node feature matrix fed to the GCN.
        :param adj: adjacency matrix the estimate is regularised towards.
        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        """
        estimator = self.estimator
        args = self.args
        if args.debug:
            print("\n=== train_adj ===")
        t = time.time()
        estimator.train()
        self.optimizer_adj.zero_grad()

        loss_l1 = torch.norm(estimator.estimated_adj, 1)
        loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
        normalized_adj = estimator.normalize()

        if args.lambda_:
            loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj,
                                                      features)
        else:
            # Zero tensor (not int) so ``.item()`` works in the debug print.
            loss_smooth_feat = 0 * loss_l1

        output = self.model(features, normalized_adj)
        loss_gcn = F.nll_loss(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])

        loss_symmetric = torch.norm(estimator.estimated_adj
                                    - estimator.estimated_adj.t(), p="fro")

        loss_diffiential = (loss_fro
                            + args.gamma * loss_gcn
                            + args.lambda_ * loss_smooth_feat
                            + args.phi * loss_symmetric)
        loss_diffiential.backward()
        self.optimizer_adj.step()

        loss_nuclear = 0 * loss_fro
        if args.beta != 0:
            # NOTE(review): step() without a preceding backward(); presumably
            # this optimizer applies a proximal operator that ignores .grad.
            self.optimizer_nuclear.zero_grad()
            self.optimizer_nuclear.step()
            loss_nuclear = prox_operators.nuclear_norm

        self.optimizer_l1.zero_grad()
        self.optimizer_l1.step()

        # Reported in the debug print only; the smoothness term is not included.
        total_loss = (loss_fro
                      + args.gamma * loss_gcn
                      + args.alpha * loss_l1
                      + args.beta * loss_nuclear
                      + args.phi * loss_symmetric)

        estimator.estimated_adj.data.copy_(
            torch.clamp(estimator.estimated_adj.data, min=0, max=1))

        # Evaluate validation set performance separately; eval() deactivates
        # dropout and no_grad() skips building an unused autograd graph.
        self.model.eval()
        with torch.no_grad():
            normalized_adj = estimator.normalize()
            output = self.model(features, normalized_adj)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])
        print('Epoch: {:04d}'.format(epoch + 1),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))

        if acc_val > self.best_val_acc:
            self.best_val_acc = acc_val
            self.best_graph = normalized_adj.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_acc: %s' %
                      self.best_val_acc.item())

        if loss_val < self.best_val_loss:
            self.best_val_loss = loss_val
            self.best_graph = normalized_adj.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_loss: %s' %
                      self.best_val_loss.item())

        if args.debug:
            print(
                'Epoch: {:04d}'.format(epoch + 1),
                'loss_fro: {:.4f}'.format(loss_fro.item()),
                'loss_gcn: {:.4f}'.format(loss_gcn.item()),
                'loss_feat: {:.4f}'.format(loss_smooth_feat.item()),
                'loss_symmetric: {:.4f}'.format(loss_symmetric.item()),
                'delta_l1_norm: {:.4f}'.format(
                    torch.norm(estimator.estimated_adj - adj, 1).item()),
                'loss_l1: {:.4f}'.format(loss_l1.item()),
                'loss_total: {:.4f}'.format(total_loss.item()),
                'loss_nuclear: {:.4f}'.format(loss_nuclear.item()))
Exemplo n.º 12
0
    def train_feat(self, epoch, features, adj, labels, idx_train, idx_val):
        """Run one optimisation step on the estimated feature matrix.

        ``self.optimizer_feat`` minimises the Frobenius distance between
        ``self.estimator2.estimated_feature`` and ``features``, plus the
        ``gamma``-weighted GCN training loss on the graph generated from
        ``adj``, plus (only when ``args.method == "smooth"``) the
        ``lambda_``-weighted feature-smoothness term. The GCN is then evaluated
        on ``idx_val``. On improvement of validation accuracy or loss, the
        graph, the detached estimated features and the model weights are
        recorded.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param features: feature matrix the estimate is regularised towards.
        :param adj: adjacency used to generate the graph.
        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        """
        args = self.args
        if args.debug:
            print("\n === This the train_feature===")
        t = time.time()
        self.estimator2.train()
        self.optimizer_feat.zero_grad()
        loss_fro = torch.norm(self.estimator2.estimated_feature - features,
                              p='fro')
        g = self.generate_g(adj)
        output = self.model(g, self.estimator2.estimated_feature)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_gcn = loss_fcn(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
        if args.method == "smooth":
            loss_feat = self.feature_smoothing(
                adj, self.estimator2.estimated_feature)
        else:
            # No smoothness regulariser for other methods. A zero tensor keeps
            # the loss arithmetic and the debug ``.item()`` call valid (this
            # was previously a NameError).
            loss_feat = 0 * loss_fro

        loss_diffiential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_feat
        loss_diffiential.backward()
        self.optimizer_feat.step()

        total_loss = (loss_fro
                      + args.gamma * loss_gcn
                      + args.lambda_ * loss_feat)

        self.model.eval()
        with torch.no_grad():
            output = self.model(g, self.estimator2.estimated_feature)
            loss_val = loss_fcn(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])
        print('Epoch: {:04d}'.format(epoch + 1),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))
        torch.cuda.empty_cache()
        if acc_val > self.best_val_acc:
            self.best_val_acc = acc_val
            self.best_graph = g
            self.best_feature = self.estimator2.estimated_feature.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_acc: %s' %
                      self.best_val_acc.item())

        if loss_val < self.best_val_loss:
            self.best_val_loss = loss_val
            self.best_graph = g
            self.best_feature = self.estimator2.estimated_feature.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_loss: %s' %
                      self.best_val_loss.item())
        if args.debug:
            print('Epoch: {:04d}'.format(epoch + 1),
                  'loss_fro: {:.4f}'.format(loss_fro.item()),
                  'loss_gcn: {:.4f}'.format(loss_gcn.item()),
                  'loss_feat: {:.4f}'.format(loss_feat.item()),
                  'loss_total: {:.4f}'.format(total_loss.item()))
Exemplo n.º 13
0
    def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
        """Run one optimisation step on the estimated adjacency matrix.

        ``self.optimizer_adj`` minimises the Frobenius distance to ``adj``,
        plus the ``gamma``-weighted GCN training loss on the graph generated
        from ``estimator.estimated_adj``, plus (when ``args.lambda_`` is set)
        the ``lambda_``-weighted feature-smoothness term. The estimate is then
        clamped to ``[0, 1]`` and the GCN is evaluated on ``idx_val`` with a
        graph regenerated from the updated estimate. On improvement of
        validation accuracy or loss, the graph, the detached
        ``estimator2.estimated_feature`` and the model weights are recorded.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param features: node feature matrix fed to the GCN.
        :param adj: adjacency matrix the estimate is regularised towards.
        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        :return: None.
        """
        estimator = self.estimator
        estimator2 = self.estimator2
        args = self.args
        if args.debug:
            print("\n=== train_adj ===")
        t = time.time()
        estimator.train()
        self.optimizer_adj.zero_grad()
        loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
        g = self.generate_g(estimator.estimated_adj)
        if args.lambda_:
            loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj,
                                                      features)
        else:
            loss_smooth_feat = 0
        if args.debug:
            # Diagnostic only; this used to be printed on every call.
            print(features.dtype)
        output = self.model(g, features)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_gcn = loss_fcn(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
        loss_diffiential = loss_fro + args.gamma * loss_gcn + args.lambda_ * loss_smooth_feat
        loss_diffiential.backward()
        self.optimizer_adj.step()

        estimator.estimated_adj.data.copy_(
            torch.clamp(estimator.estimated_adj.data, min=0, max=1))

        # Evaluate validation set performance separately on the updated
        # graph; eval() deactivates dropout during the validation run.
        self.model.eval()
        g = self.generate_g(estimator.estimated_adj)
        with torch.no_grad():
            output = self.model(g, features)
            loss_val = loss_fcn(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])
        print('Epoch: {:04d}'.format(epoch + 1),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))

        if acc_val > self.best_val_acc:
            self.best_val_acc = acc_val
            self.best_graph = g
            self.best_feature = estimator2.estimated_feature.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_acc: %s' %
                      self.best_val_acc.item())

        if loss_val < self.best_val_loss:
            self.best_val_loss = loss_val
            self.best_graph = g
            self.best_feature = estimator2.estimated_feature.detach()
            self.weights = deepcopy(self.model.state_dict())
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_loss: %s' %
                      self.best_val_loss.item())
Exemplo n.º 14
0
    def train_adj(self, epoch, features, adj, labels, idx_train, idx_val):
        """Run one optimisation step on the estimated adjacency (DGL model).

        Steps:
        1. ``self.optimizer_adj`` minimises the Frobenius distance to ``adj``,
           the ``gamma``-weighted GCN training loss, the ``lambda_``-weighted
           feature-smoothness term (only when ``args.lambda_`` is set) and the
           ``phi``-weighted asymmetry penalty.
        2. ``self.optimizer_nuclear`` is stepped (only if ``args.beta != 0``),
           then ``self.optimizer_l1``.
        3. The estimate is clamped to ``[0, 1]``.
        4. The GCN is evaluated on ``idx_val`` with a graph rebuilt from the
           updated estimate. On improvement of validation accuracy or loss,
           the detached normalized adjacency and the model weights are
           recorded, and ``wandb.run.summary["best_accuracy"]`` is updated.

        :param epoch: zero-based epoch index (printed as ``epoch + 1``).
        :param features: node feature matrix fed to the GCN.
        :param adj: adjacency matrix the estimate is regularised towards.
        :param labels: node labels.
        :param idx_train: indices of the training nodes.
        :param idx_val: indices of the validation nodes.
        """
        estimator = self.estimator
        args = self.args
        if args.debug:
            print("\n=== train_adj ===")
        t = time.time()
        estimator.train()
        self.optimizer_adj.zero_grad()

        def to_dgl_graph(dense_adj):
            # Weighted DGLGraph on self.device whose edges are the non-zero
            # entries of ``dense_adj``. DGL edge features must be tensors on
            # the graph's device, not NumPy arrays.
            coo = sp.coo_matrix(dense_adj.detach().cpu().numpy())
            graph = DGLGraph(coo).to(self.device)
            graph.edata['weight'] = torch.as_tensor(coo.data, device=self.device)
            return graph

        loss_l1 = torch.norm(estimator.estimated_adj, 1)
        loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
        normalized_adj = estimator.normalize()
        # NOTE(review): the graph is built from a detached NumPy copy, so
        # loss_gcn cannot back-propagate into ``estimated_adj`` through it.
        g = to_dgl_graph(normalized_adj)

        if args.lambda_:
            loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj,
                                                      features)
        else:
            # Zero tensor (not int) so ``.item()`` works in the debug print.
            loss_smooth_feat = 0 * loss_l1

        output = self.model(g, features)
        loss_fcn = torch.nn.CrossEntropyLoss()
        loss_gcn = loss_fcn(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])

        loss_symmetric = torch.norm(estimator.estimated_adj
                                    - estimator.estimated_adj.t(), p="fro")

        loss_diffiential = (loss_fro
                            + args.gamma * loss_gcn
                            + args.lambda_ * loss_smooth_feat
                            + args.phi * loss_symmetric)
        loss_diffiential.backward()
        self.optimizer_adj.step()

        loss_nuclear = 0 * loss_fro
        if args.beta != 0:
            self.optimizer_nuclear.zero_grad()
            self.optimizer_nuclear.step()
            loss_nuclear = prox_operators.nuclear_norm

        self.optimizer_l1.zero_grad()
        self.optimizer_l1.step()

        # Reported in the debug print only; the smoothness term is not included.
        total_loss = (loss_fro
                      + args.gamma * loss_gcn
                      + args.alpha * loss_l1
                      + args.beta * loss_nuclear
                      + args.phi * loss_symmetric)

        estimator.estimated_adj.data.copy_(
            torch.clamp(estimator.estimated_adj.data, min=0, max=1))

        # Evaluate on the UPDATED graph. ``g`` is rebuilt from the post-step
        # normalized adjacency so the validation metrics describe the same
        # graph that is stored in ``self.best_graph`` below. Previously the
        # pre-step graph was evaluated.
        self.model.eval()
        with torch.no_grad():
            normalized_adj = estimator.normalize()
            g = to_dgl_graph(normalized_adj)
            output = self.model(g, features)
            loss_val = loss_fcn(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])
        print('Epoch: {:04d}'.format(epoch + 1),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))
        if acc_val > self.best_val_acc:
            self.best_val_acc = acc_val
            self.best_graph = normalized_adj.detach()
            self.weights = deepcopy(self.model.state_dict())
            wandb.run.summary["best_accuracy"] = self.best_val_acc
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_acc: %s' %
                      self.best_val_acc.item())

        if loss_val < self.best_val_loss:
            self.best_val_loss = loss_val
            self.best_graph = normalized_adj.detach()
            self.weights = deepcopy(self.model.state_dict())
            # NOTE(review): re-writes the (unchanged) best accuracy; a
            # loss-specific summary key may have been intended.
            wandb.run.summary["best_accuracy"] = self.best_val_acc
            if args.debug:
                print('\t=== saving current graph/gcn, best_val_loss: %s' %
                      self.best_val_loss.item())

        if args.debug:
            print(
                'Epoch: {:04d}'.format(epoch + 1),
                'loss_fro: {:.4f}'.format(loss_fro.item()),
                'loss_gcn: {:.4f}'.format(loss_gcn.item()),
                'loss_feat: {:.4f}'.format(loss_smooth_feat.item()),
                'loss_symmetric: {:.4f}'.format(loss_symmetric.item()),
                'delta_l1_norm: {:.4f}'.format(
                    torch.norm(estimator.estimated_adj - adj, 1).item()),
                'loss_l1: {:.4f}'.format(loss_l1.item()),
                'loss_total: {:.4f}'.format(total_loss.item()),
                'loss_nuclear: {:.4f}'.format(loss_nuclear.item()))