示例#1
0
def run_fix_mask(args, imp_num, adj_percent, wei_percent, dataset_dict):
    """Train a randomly pruned GIC model with frozen masks and print its scores.

    A ``GIC_GIN`` or ``GIC_GAT`` model is built on a DGL graph created from
    ``dataset_dict['adj']``. Masks are attached to ``model.gcn`` and pruned
    with ``random_pruning``, and every parameter whose name contains
    ``'mask'`` is frozen. The model is then trained for ``args.fix_epoch``
    epochs; after each step ``pruning.test`` scores it on the validation and
    test edge sets, and the test score at the best validation score is kept.

    Args:
        args: Attribute-style options. Reads ``seed``, ``num_clusters``,
            ``net`` (``'gin'`` or ``'gat'``), ``lr``, ``fix_epoch`` and
            ``dataset`` (the latter only for log lines).
        imp_num: Value echoed in the log lines (presumably the pruning round
            index -- confirm against the caller).
        adj_percent: Forwarded to ``random_pruning`` (presumably the share of
            adjacency entries to prune -- confirm in the pruning module).
        wei_percent: Forwarded to ``random_pruning`` (presumably the share of
            weights to prune -- confirm in the pruning module).
        dataset_dict: Mapping with the keys ``adj``, ``adj_sparse``,
            ``features``, ``val_edges``, ``val_edges_false``, ``test_edges``
            and ``test_edges_false``. ``adj`` must expose ``tocoo()``;
            ``features`` is indexed as ``(batch, node, feature)``.

    Returns:
        None. Per-epoch progress and a final summary line are printed.

    Raises:
        ValueError: If ``args.net`` is neither ``'gin'`` nor ``'gat'``.
    """
    pruning_gin.setup_seed(args.seed)
    num_clusters = int(args.num_clusters)

    batch_size = 1
    l2_coef = 0.0
    hid_units = 16
    sparse = False
    nonlinearity = 'prelu' # special name to separate parameters

    adj = dataset_dict['adj']
    adj_sparse = dataset_dict['adj_sparse']
    # NOTE(review): features is handed to the model without a .cuda() call
    # here, so it is assumed to already live on the GPU -- confirm in caller.
    features = dataset_dict['features']
    val_edges = dataset_dict['val_edges']
    val_edges_false = dataset_dict['val_edges_false']
    test_edges = dataset_dict['test_edges']
    test_edges_false = dataset_dict['test_edges_false']

    nb_nodes = features.shape[1]
    ft_size = features.shape[2]

    g = dgl.DGLGraph()
    g.add_nodes(nb_nodes)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)

    b_xent = nn.BCEWithLogitsLoss()

    # The literal 100 passed to the constructors and to the forward call is
    # kept verbatim; its meaning is defined by GIC_GIN / GIC_GAT.
    if args.net == 'gin':
        model = GIC_GIN(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        pruning_gin.add_mask(model.gcn)
        pruning_gin.random_pruning(model.gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gin.print_sparsity(model.gcn)
    elif args.net == 'gat':
        model = GIC_GAT(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        # Self-loops are added for the GAT variant before its masks are built.
        g.add_edges(list(range(nb_nodes)), list(range(nb_nodes)))
        pruning_gat.add_mask(model.gcn)
        pruning_gat.random_pruning(model.gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gat.print_sparsity(model.gcn)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError("unsupported net: {!r} (expected 'gin' or 'gat')".format(args.net))

    # Masks stay fixed in this phase; only the remaining parameters are trained.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    # Move the model to the GPU before the optimiser captures its parameters,
    # as recommended by the torch.optim documentation.
    model.cuda()
    optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=l2_coef)

    # Targets never change between epochs: ones for the first nb_nodes
    # columns, zeros for the next nb_nodes columns. Build them once.
    lbl = torch.cat((torch.ones(batch_size, nb_nodes),
                     torch.zeros(batch_size, nb_nodes)), 1).cuda()

    best_val_acc = {'val_acc': 0, 'epoch' : 0, 'test_acc':0}
    for epoch in range(1, args.fix_epoch + 1):
        model.train()
        optimiser.zero_grad()

        # Corrupted input: the same features with the node axis shuffled.
        idx = np.random.permutation(nb_nodes)
        shuf_fts = features[:, idx, :].cuda()

        logits, logits2 = model(features, shuf_fts,
                                g,
                                sparse, None, None, None, 100)
        loss = 0.5 * b_xent(logits, lbl) + 0.5 * b_xent(logits2, lbl)
        loss.backward()
        optimiser.step()

        with torch.no_grad():
            acc_val, _ = pruning.test(model, features,
                                      g,
                                      sparse,
                                      adj_sparse,
                                      val_edges,
                                      val_edges_false)
            acc_test, _ = pruning.test(model, features,
                                       g,
                                       sparse,
                                       adj_sparse,
                                       test_edges,
                                       test_edges_false)
            # Model selection is driven by the validation score only.
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['test_acc'] = acc_test
                best_val_acc['val_acc'] = acc_val
                best_val_acc['epoch'] = epoch

        print("RP [{}] ({} {} FIX Mask) Epoch:[{}/{}], Loss:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
              .format(imp_num,
                      args.net,
                      args.dataset,
                      epoch,
                      args.fix_epoch,
                      loss.item(),
                      acc_val * 100,
                      acc_test * 100,
                      best_val_acc['val_acc'] * 100,
                      best_val_acc['test_acc'] * 100,
                      best_val_acc['epoch'],
                      adj_spar,
                      wei_spar))

    print("syd final: RP[{}] ({} {} FIX Mask) | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
          .format(imp_num,
                  args.net,
                  args.dataset,
                  best_val_acc['val_acc'] * 100,
                  best_val_acc['test_acc'] * 100,
                  best_val_acc['epoch'],
                  adj_spar,
                  wei_spar))
示例#2
0
def run_get_mask(args, imp_num, rewind_weight_mask=None):
    """Jointly train weights and masks, returning the state chosen on validation.

    A ``GINNet`` or ``GATNet`` is built on a DGL graph created from the raw
    adjacency of ``args['dataset']`` and masks are attached. If
    ``rewind_weight_mask`` is given it is loaded into the network, then
    ``add_trainable_mask_noise`` is applied. The network is trained for
    ``args['mask_epoch']`` epochs with cross-entropy on the training nodes;
    ``subgradient_update_mask`` runs between ``backward()`` and the optimiser
    step. Whenever the validation micro-F1 improves, ``get_final_mask_epoch``
    produces a new value for the returned state.

    Args:
        args: Dict of options. Reads ``seed``, ``dataset``, ``net``
            (``'gin'`` or ``'gat'``), ``embedding_dim``, ``lr``,
            ``weight_decay`` and ``mask_epoch``; the whole dict is also
            forwarded to the pruning helpers.
        imp_num: Value echoed in the log lines (presumably the pruning round
            index -- confirm against the caller).
        rewind_weight_mask: Optional state dict loaded into the network
            before training. A falsy value (``None`` or an empty dict) is
            skipped.

    Returns:
        The state produced by ``get_final_mask_epoch`` at the best validation
        epoch, or a deep copy of the initial ``state_dict()`` if validation
        never improved.

    Raises:
        ValueError: If ``args['net']`` is neither ``'gin'`` nor ``'gat'``.
    """
    pruning.setup_seed(args['seed'])
    # The adjacency returned by load_data is not used; the raw one is loaded
    # separately below.
    _, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])
    adj = load_adj_raw(args['dataset'])

    node_num = features.size()[0]

    g = dgl.DGLGraph()
    g.add_nodes(node_num)
    adj = adj.tocoo()

    g.add_edges(adj.row, adj.col)
    features = features.cuda()
    labels = labels.cuda()

    loss_func = nn.CrossEntropyLoss()

    if args['net'] == 'gin':
        net_gcn = GINNet(args['embedding_dim'], g)
        pruning_gin.add_mask(net_gcn)
    elif args['net'] == 'gat':
        net_gcn = GATNet(args['embedding_dim'], g)
        # Self-loops are added for the GAT variant before its masks are built.
        g.add_edges(list(range(node_num)), list(range(node_num)))
        pruning_gat.add_mask(net_gcn)
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError(
            "unsupported net: {!r} (expected 'gin' or 'gat')".format(
                args['net']))

    net_gcn = net_gcn.cuda()

    if rewind_weight_mask:
        net_gcn.load_state_dict(rewind_weight_mask)

    if args['net'] == 'gin':
        pruning_gin.add_trainable_mask_noise(net_gcn, c=1e-5)
        adj_spar, wei_spar = pruning_gin.print_sparsity(net_gcn)
    else:
        pruning_gat.add_trainable_mask_noise(net_gcn, c=1e-5)
        adj_spar, wei_spar = pruning_gat.print_sparsity(net_gcn)

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}

    # Fallback return value in case validation never improves.
    rewind_weight = copy.deepcopy(net_gcn.state_dict())

    for epoch in range(args['mask_epoch']):

        # Evaluation below switches the module to eval mode; switch back so
        # every epoch (not just the first) trains in training mode.
        net_gcn.train()
        optimizer.zero_grad()
        # The two trailing 0 arguments are kept verbatim; their meaning is
        # defined by the network's forward().
        output = net_gcn(g, features, 0, 0)
        loss = loss_func(output[idx_train], labels[idx_train])
        loss.backward()
        if args['net'] == 'gin':
            pruning_gin.subgradient_update_mask(net_gcn, args)  # l1 norm
        else:
            pruning_gat.subgradient_update_mask(net_gcn, args)  # l1 norm

        optimizer.step()
        with torch.no_grad():
            net_gcn.eval()
            output = net_gcn(g, features, 0, 0)
            acc_val = f1_score(labels[idx_val].cpu().numpy(),
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(labels[idx_test].cpu().numpy(),
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            # Model selection is driven by the validation score only.
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch

                if args['net'] == 'gin':
                    rewind_weight, adj_spar, wei_spar = pruning_gin.get_final_mask_epoch(
                        net_gcn, rewind_weight, args)
                else:
                    rewind_weight, adj_spar, wei_spar = pruning_gat.get_final_mask_epoch(
                        net_gcn, rewind_weight, args)

        print(
            "IMP[{}] (Get Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
            .format(imp_num, epoch, args['mask_epoch'], loss.item(),
                    acc_val * 100, acc_test * 100,
                    best_val_acc['val_acc'] * 100,
                    best_val_acc['test_acc'] * 100, best_val_acc['epoch'],
                    adj_spar, wei_spar))

    return rewind_weight
示例#3
0
def run_fix_mask(args, imp_num, adj_percent, wei_percent):
    """Train a randomly pruned GIN/GAT network with frozen masks and print scores.

    A ``GINNet`` or ``GATNet`` is built on a DGL graph created from the raw
    adjacency of ``args['dataset']``. Masks are attached and pruned with
    ``random_pruning``, and every parameter whose name contains ``'mask'`` is
    frozen. The network is trained for ``args['fix_epoch']`` epochs with
    cross-entropy on the training nodes. After each step the validation and
    test micro-F1 are computed, and the test score at the best validation
    score is kept.

    Args:
        args: Dict of options. Reads ``seed``, ``dataset``, ``net``
            (``'gin'`` or ``'gat'``), ``embedding_dim``, ``lr``,
            ``weight_decay`` and ``fix_epoch``.
        imp_num: Value echoed in the log lines (presumably the pruning round
            index -- confirm against the caller).
        adj_percent: Forwarded to ``random_pruning`` (presumably the share of
            adjacency entries to prune -- confirm in the pruning module).
        wei_percent: Forwarded to ``random_pruning`` (presumably the share of
            weights to prune -- confirm in the pruning module).

    Returns:
        None. Per-epoch progress and a final summary line are printed.

    Raises:
        ValueError: If ``args['net']`` is neither ``'gin'`` nor ``'gat'``.
    """
    pruning.setup_seed(args['seed'])
    # The adjacency returned by load_data is not used; the raw one is loaded
    # separately below.
    _, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])
    adj = load_adj_raw(args['dataset'])

    node_num = features.size()[0]

    g = dgl.DGLGraph()
    g.add_nodes(node_num)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)
    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    if args['net'] == 'gin':
        net_gcn = GINNet(args['embedding_dim'], g)
        pruning_gin.add_mask(net_gcn)
        pruning_gin.random_pruning(net_gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gin.print_sparsity(net_gcn)

    elif args['net'] == 'gat':
        net_gcn = GATNet(args['embedding_dim'], g)
        # Self-loops are added for the GAT variant before its masks are built.
        g.add_edges(list(range(node_num)), list(range(node_num)))
        pruning_gat.add_mask(net_gcn)
        pruning_gat.random_pruning(net_gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gat.print_sparsity(net_gcn)

    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError(
            "unsupported net: {!r} (expected 'gin' or 'gat')".format(
                args['net']))

    net_gcn = net_gcn.cuda()
    # Masks stay fixed in this phase; only the remaining parameters train.
    for name, param in net_gcn.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}

    for epoch in range(args['fix_epoch']):

        # Evaluation below switches the module to eval mode; switch back so
        # every epoch (not just the first) trains in training mode.
        net_gcn.train()
        optimizer.zero_grad()
        # The two trailing 0 arguments are kept verbatim; their meaning is
        # defined by the network's forward().
        output = net_gcn(g, features, 0, 0)
        loss = loss_func(output[idx_train], labels[idx_train])
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            net_gcn.eval()
            output = net_gcn(g, features, 0, 0)
            acc_val = f1_score(labels[idx_val].cpu().numpy(),
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(labels[idx_test].cpu().numpy(),
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            # Model selection is driven by the validation score only.
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch

        print(
            "RP[{}] (Fix Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
            .format(imp_num, epoch, args['fix_epoch'], loss.item(),
                    acc_val * 100, acc_test * 100,
                    best_val_acc['val_acc'] * 100,
                    best_val_acc['test_acc'] * 100, best_val_acc['epoch']))

    print(
        "syd final: [{},{}] RP[{}] (Fix Mask) Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(args['dataset'], args['net'], imp_num,
                best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                best_val_acc['epoch'], adj_spar, wei_spar))
示例#4
0
def run_get_mask(args, imp_num, rewind_weight_mask, dataset_dict):
    """Jointly train a GIC model and its masks, returning the selected state.

    A ``GIC_GIN`` or ``GIC_GAT`` model is built on a DGL graph created from
    ``dataset_dict['adj']`` and masks are attached to ``model.gcn``. If
    ``rewind_weight_mask`` is not ``None`` it is loaded into the model, then
    ``add_trainable_mask_noise`` is applied. The model is trained for
    ``args.mask_epoch`` epochs; ``subgradient_update_mask`` runs between
    ``backward()`` and the optimiser step. After each step ``pruning.test``
    scores the model on the validation and test edge sets, and whenever the
    validation score improves ``get_final_mask_epoch`` produces a new value
    for the returned state.

    Args:
        args: Attribute-style options. Reads ``seed``, ``num_clusters``,
            ``net`` (``'gin'`` or ``'gat'``), ``lr``, ``mask_epoch`` and
            ``dataset`` (the latter only for log lines); the whole object is
            also forwarded to the pruning helpers.
        imp_num: Value echoed in the log lines (presumably the pruning round
            index -- confirm against the caller).
        rewind_weight_mask: State dict to load before training, or ``None``
            to keep the freshly initialised model.
        dataset_dict: Mapping with the keys ``adj``, ``adj_sparse``,
            ``features``, ``val_edges``, ``val_edges_false``, ``test_edges``
            and ``test_edges_false``. ``adj`` must expose ``tocoo()``;
            ``features`` is indexed as ``(batch, node, feature)``.

    Returns:
        The state produced by ``get_final_mask_epoch`` at the best validation
        epoch, or a deep copy of the initial ``state_dict()`` if validation
        never improved.

    Raises:
        ValueError: If ``args.net`` is neither ``'gin'`` nor ``'gat'``.
    """
    pruning_gin.setup_seed(args.seed)
    num_clusters = int(args.num_clusters)

    batch_size = 1
    l2_coef = 0.0
    hid_units = 16
    sparse = False
    nonlinearity = 'prelu' # special name to separate parameters

    adj = dataset_dict['adj']
    adj_sparse = dataset_dict['adj_sparse']
    # NOTE(review): features is handed to the model without a .cuda() call
    # here, so it is assumed to already live on the GPU -- confirm in caller.
    features = dataset_dict['features']
    val_edges = dataset_dict['val_edges']
    val_edges_false = dataset_dict['val_edges_false']
    test_edges = dataset_dict['test_edges']
    test_edges_false = dataset_dict['test_edges_false']

    nb_nodes = features.shape[1]
    ft_size = features.shape[2]

    g = dgl.DGLGraph()
    g.add_nodes(nb_nodes)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)

    b_xent = nn.BCEWithLogitsLoss()

    # The literal 100 passed to the constructors and to the forward call is
    # kept verbatim; its meaning is defined by GIC_GIN / GIC_GAT.
    if args.net == 'gin':
        model = GIC_GIN(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        pruning_gin.add_mask(model.gcn)
    elif args.net == 'gat':
        model = GIC_GAT(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        pruning_gat.add_mask(model.gcn)
        # Self-loops are added for the GAT variant (after add_mask here,
        # exactly as in the original ordering).
        g.add_edges(list(range(nb_nodes)), list(range(nb_nodes)))
    else:
        # An explicit raise survives `python -O`, unlike `assert False`.
        raise ValueError("unsupported net: {!r} (expected 'gin' or 'gat')".format(args.net))

    if rewind_weight_mask is not None:
        model.load_state_dict(rewind_weight_mask)

    if args.net == 'gin':
        pruning_gin.add_trainable_mask_noise(model.gcn, c=1e-4)
        adj_spar, wei_spar = pruning_gin.print_sparsity(model.gcn)
    else:
        pruning_gat.add_trainable_mask_noise(model.gcn, c=1e-4)
        adj_spar, wei_spar = pruning_gat.print_sparsity(model.gcn)

    # Move the model to the GPU before the optimiser captures its parameters,
    # as recommended by the torch.optim documentation.
    model.cuda()
    optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=l2_coef)

    # Targets never change between epochs: ones for the first nb_nodes
    # columns, zeros for the next nb_nodes columns. Build them once.
    lbl = torch.cat((torch.ones(batch_size, nb_nodes),
                     torch.zeros(batch_size, nb_nodes)), 1).cuda()

    best_val_acc = {'val_acc': 0, 'epoch' : 0, 'test_acc':0}
    # Fallback return value in case validation never improves.
    rewind_weight = copy.deepcopy(model.state_dict())
    for epoch in range(1, args.mask_epoch + 1):
        model.train()
        optimiser.zero_grad()

        # Corrupted input: the same features with the node axis shuffled.
        idx = np.random.permutation(nb_nodes)
        shuf_fts = features[:, idx, :].cuda()

        logits, logits2 = model(features, shuf_fts,
                                g,
                                sparse, None, None, None, 100)
        loss = 0.5 * b_xent(logits, lbl) + 0.5 * b_xent(logits2, lbl)
        loss.backward()
        # Mask update happens on the fresh gradients, before the optimiser step.
        if args.net == 'gin':
            pruning_gin.subgradient_update_mask(model.gcn, args)
        else:
            pruning_gat.subgradient_update_mask(model.gcn, args)

        optimiser.step()
        with torch.no_grad():
            acc_val, _ = pruning.test(model, features,
                                      g,
                                      sparse,
                                      adj_sparse,
                                      val_edges,
                                      val_edges_false)
            acc_test, _ = pruning.test(model, features,
                                       g,
                                       sparse,
                                       adj_sparse,
                                       test_edges,
                                       test_edges_false)
            # Model selection is driven by the validation score only.
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['test_acc'] = acc_test
                best_val_acc['val_acc'] = acc_val
                best_val_acc['epoch'] = epoch

                # NOTE(review): the full model (not model.gcn) is passed
                # here, unlike the other pruning calls -- kept as in the
                # original; confirm against get_final_mask_epoch.
                if args.net == 'gin':
                    rewind_weight, adj_spar, wei_spar = pruning_gin.get_final_mask_epoch(model, rewind_weight, args)
                else:
                    rewind_weight, adj_spar, wei_spar = pruning_gat.get_final_mask_epoch(model, rewind_weight, args)

        print("IMP[{}] ({}, {} Get Mask) Epoch:[{}/{}], Loss:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
              .format(imp_num,
                      args.net,
                      args.dataset,
                      epoch,
                      args.mask_epoch,
                      loss.item(),
                      acc_val * 100,
                      acc_test * 100,
                      best_val_acc['val_acc'] * 100,
                      best_val_acc['test_acc'] * 100,
                      best_val_acc['epoch'],
                      adj_spar,
                      wei_spar))

    return rewind_weight