Example #1
0
def load_perterbued_data(dataset, ptb_rate, ptb_type="meta"):
    """Load a DeepRobust graph dataset together with a (possibly) perturbed adjacency.

    NOTE: the function name keeps its original misspelling ("perterbued") so
    that existing callers keep working.

    The clean graph is always loaded from ``/tmp/`` with ``setting='nettack'``,
    ``seed=15`` and ``require_mask=True``. The returned object additionally
    gets ``x`` (= ``features``), ``y`` (= ``labels``) and ``edge_index``.
    ``edge_index`` holds the adjacency object itself (``.adj`` /
    ``modified_adj``), not a COO edge list despite its name.

    Args:
        dataset: dataset name; lower-cased before being passed to DeepRobust.
        ptb_rate: perturbation rate. For ``'meta'`` it selects the pre-attacked
            graph (a value ``<= 0`` means "use the clean graph"). For the random
            attacks the number of perturbed edges is
            ``int(ptb_rate * num_edges)``.
        ptb_type: one of ``'meta'``, ``'random_add'`` or ``'random_remove'``.

    Returns:
        The DeepRobust ``Dataset`` object with ``x``, ``y`` and ``edge_index`` set.

    Raises:
        ValueError: if ``ptb_type`` is not supported. This is a subclass of
            ``Exception``, so callers catching ``Exception`` keep working.
    """
    # Validate first so an unsupported type fails fast, before any dataset I/O.
    if ptb_type not in ('meta', 'random_add', 'random_remove'):
        raise ValueError(
            f"the ptb_type of {ptb_type} has not been implemented")

    name = dataset.lower()
    # All perturbation types start from the same clean graph and split.
    data = Dataset(root='/tmp/',
                   name=name,
                   setting='nettack',
                   seed=15,
                   require_mask=True)
    data.x, data.y = data.features, data.labels

    if ptb_type == 'meta':
        if ptb_rate > 0:
            perturbed_data = PrePtbDataset(root='/tmp/',
                                           name=name,
                                           attack_method='meta',
                                           ptb_rate=ptb_rate)
            data.edge_index = perturbed_data.adj
        else:
            data.edge_index = data.adj
        return data

    # Random attacks: add or remove a ptb_rate fraction of the edges.
    # ``sum / 2`` counts each undirected edge once (assumes a symmetric adjacency).
    attack_kind = {'random_add': 'add', 'random_remove': 'remove'}[ptb_type]
    num_edge = data.adj.sum(axis=None) / 2
    attacker = Random()
    attacker.attack(data.adj,
                    n_perturbations=int(ptb_rate * num_edge),
                    type=attack_kind)
    # Random.attack stores its result on the attacker rather than returning it.
    data.edge_index = attacker.modified_adj
    return data
Example #2
0
data = Dataset(root='/tmp/', name=args.dataset, setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# Select the adjacency matrix the model will be trained on.
if args.attack == 'no':
    perturbed_adj = adj
elif args.attack == 'random':
    from deeprobust.graph.global_attack import Random
    attacker = Random()
    # ``adj.sum() // 2`` counts each undirected edge once (assumes a symmetric
    # adjacency matrix).
    n_perturbations = int(args.ptb_rate * (adj.sum() // 2))
    # Random.attack stores the perturbed graph on ``attacker.modified_adj``;
    # read it from there instead of relying on the call's return value.
    attacker.attack(adj, n_perturbations, type='add')
    perturbed_adj = attacker.modified_adj
elif args.attack in ('meta', 'nettack'):
    perturbed_data = PrePtbDataset(root='/tmp/',
                                   name=args.dataset,
                                   attack_method=args.attack,
                                   ptb_rate=args.ptb_rate)
    perturbed_adj = perturbed_data.adj
    if args.attack == 'nettack':
        # Targeted attack: evaluate only on the attacked target nodes.
        idx_test = perturbed_data.target_nodes
else:
    # Fail here with a clear message instead of a NameError on perturbed_adj later.
    raise ValueError(f"unsupported attack: {args.attack}")

# NOTE: seeds are set after the graph is (randomly) perturbed, so the random
# attack above is not governed by args.seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

model = GCN(nfeat=features.shape[1],
            nhid=args.hidden,
            nclass=labels.max().item() + 1,
            dropout=args.dropout,
            device=device)
perturbed_adj, features, labels = preprocess(perturbed_adj,
Example #3
0
                A[i] = 0
                # A[n2, n1] = 0
                removed_cnt += 1

    return removed_cnt


if __name__ == "__main__":
    import torch
    from deeprobust.graph.data import PrePtbDataset, Dataset

    # Load the clean graph and its train/val/test split.
    dataset_str = 'pubmed'
    data = Dataset(root='/tmp/', name=dataset_str, seed=15)
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

    # Load the pre-attacked graph; attack method and perturbation rate are
    # left at PrePtbDataset's defaults.
    perturbed_data = PrePtbDataset(root='/tmp/', name=dataset_str)
    perturbed_adj = perturbed_data.adj

    # Use the GPU when available instead of crashing on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Train the GCNJaccard defense model on the perturbed graph.
    print("Test GCNJaccard")
    model = GCNJaccard(nfeat=features.shape[1],
                       nhid=16,
                       nclass=labels.max().item() + 1,
                       binary_feature=False,
                       dropout=0.5,
                       device=device).to(device)
    model.fit(features,
              perturbed_adj,
              labels,
              idx_train,
              idx_val,
              threshold=0.1)
Example #4
0
# The split must match the one used when the perturbed graph was generated.
# With the 'nettack' setting that means reusing the same random seed:
# data = Dataset(root='/tmp/', name=args.dataset, setting='nettack', seed=15)
# Alternatively, setting='prognn' (used below) loads those splits directly.
# NOTE(review): machine-specific absolute path; consider making it an argument.
# The raw string ends in two literal backslashes, presumably because a raw
# string literal cannot end in a single backslash.
DATA_ROOT = r'D:\Python Project\defense\Low_pass_defense\fold_defense\tmp\\'

data = Dataset(root=DATA_ROOT, name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# load pre-attacked graph by Zugner: https://github.com/danielzuegner/gnn-meta-attack
print('==================')
print('=== load graph perturbed by Zugner metattack (under seed 15) ===')
perturbed_data = PrePtbDataset(root=DATA_ROOT,
                               name=args.dataset,
                               attack_method='meta',
                               ptb_rate=args.ptb_rate)
perturbed_adj = perturbed_data.adj

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Setup GCN Model
# ``.item()`` makes nclass a plain Python int rather than a NumPy scalar.
model = GCN(nfeat=features.shape[1],
            nhid=16,
            nclass=labels.max().item() + 1,
            device=device)
model = model.to(device)
Example #5
0
    def __init__(self, dataset, args, data_path="data", task_type="full"):
        """Load a graph dataset and prepare the tensors and index arrays used later.

        The graph is first read with ``data_loader`` (normalization ``"NoNorm"``).
        When ``args.ptb_rate > 0``, the graph, features, labels and splits are
        replaced by DeepRobust's ``nettack``-setting copy of ``args.dataset``
        (split seed 15, stored under ``/tmp/``). Unless ``args.ptb_rate == 10``,
        the adjacency is additionally replaced by the pre-computed metattack
        graph at that perturbation rate.

        Args:
            dataset: dataset name forwarded to ``data_loader``.
            args: run configuration. This method reads ``seed``, ``ptb_rate``,
                ``dataset``, ``train_size``, ``fastmode`` and ``label_rate``.
            data_path: directory forwarded to ``data_loader``.
            task_type: task type forwarded to ``data_loader``.
        """
        self.dataset = dataset
        self.data_path = data_path
        (self.adj, self.train_adj, self.features, self.train_features,
         self.labels, self.idx_train, self.idx_val, self.idx_test, self.degree,
         self.learning_type) = data_loader(dataset,
                                           data_path,
                                           "NoNorm",
                                           False,
                                           task_type,
                                           seed=args.seed)

        if args.ptb_rate > 0:
            # need to install deeprobust: https://github.com/DSE-MSU/DeepRobust
            from deeprobust.graph.data import Dataset, PrePtbDataset
            # NOTE(review): this uses ``args.dataset`` while data_loader above
            # used the ``dataset`` argument; presumably they are the same name.
            data = Dataset(root='/tmp/',
                           name=args.dataset,
                           setting='nettack',
                           seed=15)
            self.adj = data.adj
            self.features = data.features.todense()
            self.labels = data.labels
            self.idx_train, self.idx_val, self.idx_test = (
                data.idx_train, data.idx_val, data.idx_test)
            # NOTE(review): ptb_rate == 10 appears to act as a sentinel for
            # "DeepRobust clean graph, no metattack"; confirm with callers.
            # Also note that self.degree still comes from data_loader and is
            # not recomputed for this graph.
            if args.ptb_rate != 10:
                perturbed_data = PrePtbDataset(root='/tmp/',
                                               name=args.dataset,
                                               attack_method='meta',
                                               ptb_rate=args.ptb_rate)
                self.adj = perturbed_data.adj

        # Everything below is treated as transductive: the "training" graph and
        # features are the full graph and features.
        self.train_adj = self.adj
        self.train_features = self.features
        self.learning_type = 'transductive'
        # ``np.int`` was only an alias of the builtin ``int`` and was removed in
        # NumPy 1.24; ``int`` yields the same dtype without the AttributeError.
        self.labels = self.labels.astype(int)

        self.features = torch.FloatTensor(self.features).float()
        self.train_features = torch.FloatTensor(self.train_features).float()

        if args.train_size and not args.fastmode:
            # Re-split the nodes with get_splits_each_class using args.train_size.
            self.idx_train, self.idx_val, self.idx_test = get_splits_each_class(
                labels=self.labels, train_size=args.train_size)

        if args.fastmode:
            from deeprobust.graph.utils import get_train_test
            # Label-stratified train/test split with test_size = 1 - label_rate;
            # idx_val is left as it was.
            self.idx_train, self.idx_test = get_train_test(
                nnodes=self.adj.shape[0],
                test_size=1 - args.label_rate,
                stratify=self.labels)
            # Keep only the first 1000 test nodes.
            self.idx_test = self.idx_test[:1000]

        self.labels_torch = torch.LongTensor(self.labels)
        self.idx_train_torch = torch.LongTensor(self.idx_train)
        self.idx_val_torch = torch.LongTensor(self.idx_val)
        self.idx_test_torch = torch.LongTensor(self.idx_test)
        # Positions *within* idx_train (not node ids) whose label is 1 / 0.
        self.pos_train_idx = np.where(self.labels[self.idx_train] == 1)[0]
        self.neg_train_idx = np.where(self.labels[self.idx_train] == 0)[0]

        self.nfeat = self.features.shape[1]
        self.nclass = int(self.labels.max().item() + 1)
        self.trainadj_cache = {}
        self.adj_cache = {}
        self.degree_p = None