Ejemplo n.º 1
0
def load_perterbued_data(dataset, ptb_rate, ptb_type="meta"):
    if ptb_type == 'meta':
        data = Dataset(root='/tmp/',
                       name=dataset.lower(),
                       setting='nettack',
                       seed=15,
                       require_mask=True)
        data.x, data.y = data.features, data.labels
        if ptb_rate > 0:
            perturbed_data = PrePtbDataset(root='/tmp/',
                                           name=dataset.lower(),
                                           attack_method='meta',
                                           ptb_rate=ptb_rate)
            data.edge_index = perturbed_data.adj
        else:
            data.edge_index = data.adj
        return data

    elif ptb_type == 'random_add':
        data = Dataset(root='/tmp/',
                       name=dataset.lower(),
                       setting='nettack',
                       seed=15,
                       require_mask=True)
        data.x, data.y = data.features, data.labels
        num_edge = data.adj.sum(axis=None) / 2
        attacker = Random()
        attacker.attack(data.adj,
                        n_perturbations=int(ptb_rate * num_edge),
                        type='add')
        data.edge_index = attacker.modified_adj
        return data

    elif ptb_type == 'random_remove':
        data = Dataset(root='/tmp/',
                       name=dataset.lower(),
                       setting='nettack',
                       seed=15,
                       require_mask=True)
        data.x, data.y = data.features, data.labels
        num_edge = data.adj.sum(axis=None) / 2
        attacker = Random()
        attacker.attack(data.adj,
                        n_perturbations=int(ptb_rate * num_edge),
                        type='remove')
        data.edge_index = attacker.modified_adj
        return data

    raise Exception(f"the ptb_type of {ptb_type} has not been implemented")
Ejemplo n.º 2
0
# Preprocess the perturbed graph. preprocess_adj=False presumably skips
# adjacency normalization inside `preprocess`; it is normalized below instead.
# NOTE(review): `device` is passed to preprocess, but the explicit .cuda() calls
# here (and cuda=True in load_data) pin everything to CUDA regardless — confirm.
perturbed_adj, features, labels = preprocess(
    perturbed_adj, features, labels, preprocess_adj=False, device=device
)
labels = labels.cuda()
# Move the adjacency to the GPU, convert it with `to_sparse`, then normalize it.
# The meaning of the `True` flag is defined in `ut`; presumably it selects
# sparse handling — confirm.
perturbed_adj = ut.normalize_adj_tensor(to_sparse(perturbed_adj.cuda()), True)
features = features.cuda()
# Clip feature values in place so nothing exceeds 1.
features.masked_fill_(features > 1, 1)

# Build the dataset object, then replace its tensors with the perturbed,
# preprocessed ones prepared above.
# NOTE(review): `no_fea_norm` is presumably a store_false flag (True => normalize).
data = load_data(
    args.data,
    normalize_feature=args.no_fea_norm,
    missing_rate=args.missing_rate,
    cuda=True,
)
data.x, data.y, data.adj = features, labels, perturbed_adj
# NOTE(review): the idx_* splits are stored in the *_mask attributes — confirm
# downstream code accepts them in that form.
data.train_mask, data.val_mask, data.test_mask = idx_train, idx_val, idx_test

# Model input width comes from the feature dimension; the class count is
# the largest label id plus one.
nfeat = data.x.shape[1]
nclass = 1 + int(data.y.max())
net = getattr(models, args.model)(
    nfeat, args.hid, nclass,
    dropout=args.dropout, nhead=args.nhead, nlayer=args.nlayer,
    norm_mode=args.norm_mode, norm_scale=args.norm_scale,
    residual=args.residual,
)