perturbed_adj, labels, idx_train, train_iters=200, verbose=True) # # using validation to pick model # model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True) model.eval() # You can use the inner function of model to test model.test(idx_test) print('==================') print( '=== load graph perturbed by DeepRobust 5% metattack (under seed 15) ===') perturbed_data = PtbDataset(root='/tmp/', name=args.dataset, attack_method='meta') perturbed_adj = perturbed_data.adj model.fit(features, perturbed_adj, labels, idx_train, train_iters=200, verbose=True) # # using validation to pick model # model.fit(features, perturbed_adj, labels, idx_train, idx_val, train_iters=200, verbose=True) model.eval() # You can use the inner function of model to test model.test(idx_test)
print('cuda: %s' % args.cuda) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # make sure you use the same data splits as you generated attacks np.random.seed(args.seed) torch.manual_seed(args.seed) if args.cuda: torch.cuda.manual_seed(args.seed) # load original dataset (to get clean features and labels) data = Dataset(root='/tmp/', name=args.dataset) adj, features, labels = data.adj, data.features, data.labels idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test # load pre-attacked graph perturbed_data = PtbDataset(root='/tmp/', name=args.dataset) perturbed_adj = perturbed_data.adj # Setup RGCN Model model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1], nclass=labels.max() + 1, nhid=64, device=device) model = model.to(device) model.fit(features, perturbed_adj, labels, idx_train,
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Set up the GCN model.
model = GCN(nfeat=features.shape[1], nhid=16, nclass=labels.max() + 1,
            device=device)
model = model.to(device)

model.fit(features, perturbed_adj, labels, idx_train,
          train_iters=200, verbose=True)
# Alternative: add idx_val after idx_train to pick the model on the
# validation split:
#   model.fit(features, perturbed_adj, labels, idx_train, idx_val,
#             train_iters=200, verbose=True)
model.eval()
# Evaluate on the held-out test nodes via the model's own test routine.
model.test(idx_test)

print('=== load graph perturbed by DeepRobust 5% metattack (under seed 15) ===')
# The perturbed adjacency must match the dataset that `features` and `labels`
# came from. A hard-coded 'cora' breaks that whenever another dataset is
# selected. Use args.dataset when it exists, and keep 'cora' only as the
# fallback for scripts whose args have no `dataset` attribute.
# NOTE(review): features/labels are presumably loaded from args.dataset
# earlier in this script; confirm.
ptb_dataset_name = getattr(args, 'dataset', 'cora')
perturbed_data = PtbDataset(root='/tmp/', name=ptb_dataset_name,
                            attack_method='meta')
perturbed_adj = perturbed_data.adj

# Train the same model object again, this time on the metattack graph.
model.fit(features, perturbed_adj, labels, idx_train,
          train_iters=200, verbose=True)
# Alternative (validation-based selection), as above:
#   model.fit(features, perturbed_adj, labels, idx_train, idx_val,
#             train_iters=200, verbose=True)
model.eval()
model.test(idx_test)