import copy

import dgl
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score

# Repo-local pruning utilities; the exact import paths follow the project layout.
import pruning
import pruning_gin
import pruning_gat

# The model classes (GIC_GIN, GIC_GAT, GINNet, GATNet) and the data helpers
# (load_data, load_adj_raw) are assumed importable from the repo's own modules.


def run_fix_mask(args, imp_num, adj_percent, wei_percent, dataset_dict):
    """Random-pruning (RP) baseline for the GIC link-prediction setting:
    randomly prune edges/weights, freeze the masks, then retrain."""
    pruning_gin.setup_seed(args.seed)
    num_clusters = int(args.num_clusters)
    dataset = args.dataset

    batch_size = 1
    patience = 50
    l2_coef = 0.0
    hid_units = 16
    # sparse = True
    sparse = False
    nonlinearity = 'prelu'  # special name to separate parameters

    adj = dataset_dict['adj']
    adj_sparse = dataset_dict['adj_sparse']
    features = dataset_dict['features']
    labels = dataset_dict['labels']
    val_edges = dataset_dict['val_edges']
    val_edges_false = dataset_dict['val_edges_false']
    test_edges = dataset_dict['test_edges']
    test_edges_false = dataset_dict['test_edges_false']

    nb_nodes = features.shape[1]
    ft_size = features.shape[2]

    g = dgl.DGLGraph()
    g.add_nodes(nb_nodes)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)

    b_xent = nn.BCEWithLogitsLoss()
    b_bce = nn.BCELoss()

    if args.net == 'gin':
        model = GIC_GIN(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        pruning_gin.add_mask(model.gcn)
        pruning_gin.random_pruning(model.gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gin.print_sparsity(model.gcn)
    elif args.net == 'gat':
        model = GIC_GAT(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        g.add_edges(list(range(nb_nodes)), list(range(nb_nodes)))  # self-loops for GAT
        pruning_gat.add_mask(model.gcn)
        pruning_gat.random_pruning(model.gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gat.print_sparsity(model.gcn)
    else:
        raise ValueError("unsupported net: {}".format(args.net))

    # Freeze the pruning masks; only the surviving weights are trained.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=l2_coef)
    model.cuda()
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}

    for epoch in range(1, args.fix_epoch + 1):

        model.train()
        optimiser.zero_grad()

        # Corrupt the graph by shuffling node features to build negative samples.
        idx = np.random.permutation(nb_nodes)
        shuf_fts = features[:, idx, :]

        lbl_1 = torch.ones(batch_size, nb_nodes)
        lbl_2 = torch.zeros(batch_size, nb_nodes)
        lbl = torch.cat((lbl_1, lbl_2), 1)

        shuf_fts = shuf_fts.cuda()
        lbl = lbl.cuda()

        logits, logits2 = model(features, shuf_fts, g, sparse, None, None, None, 100)
        loss = 0.5 * b_xent(logits, lbl) + 0.5 * b_xent(logits2, lbl)
        loss.backward()
        optimiser.step()

        with torch.no_grad():
            acc_val, _ = pruning.test(model, features, g, sparse, adj_sparse,
                                      val_edges, val_edges_false)
            acc_test, _ = pruning.test(model, features, g, sparse, adj_sparse,
                                       test_edges, test_edges_false)

        if acc_val > best_val_acc['val_acc']:
            best_val_acc['test_acc'] = acc_test
            best_val_acc['val_acc'] = acc_val
            best_val_acc['epoch'] = epoch

        print("RP [{}] ({} {} FIX Mask) Epoch:[{}/{}] Loss:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
              .format(imp_num, args.net, args.dataset, epoch, args.fix_epoch, loss,
                      acc_val * 100, acc_test * 100,
                      best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                      best_val_acc['epoch'], adj_spar, wei_spar))

    print("syd final: RP[{}] ({} {} FIX Mask) | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
          .format(imp_num, args.net, args.dataset,
                  best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                  best_val_acc['epoch'], adj_spar, wei_spar))
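
# Usage sketch (an assumption, not part of the original source): the GIC functions in
# this file read attribute-style `args` (at least seed, num_clusters, dataset, net, lr,
# fix_epoch, mask_epoch) and a prebuilt `dataset_dict` with exactly the keys unpacked
# above. The 5%-edges / 20%-weights per-round schedule and the function name below are
# illustrative only.
def example_gic_random_pruning(args, dataset_dict, n_rounds=20):
    for imp_num in range(1, n_rounds + 1):
        adj_percent = 1 - 0.95 ** imp_num  # cumulative fraction of edges to drop
        wei_percent = 1 - 0.80 ** imp_num  # cumulative fraction of weights to drop
        run_fix_mask(args, imp_num, adj_percent, wei_percent, dataset_dict)
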
def run_get_mask(args, imp_num, rewind_weight_mask=None):
    """IMP mask search for node classification: train differentiable masks with an
    L1 subgradient penalty and record the best masks at the rewound weights."""
    pruning.setup_seed(args['seed'])
    adj, features, labels, idx_train, idx_val, idx_test = load_data(args['dataset'])
    adj = load_adj_raw(args['dataset'])

    node_num = features.size()[0]
    class_num = labels.numpy().max() + 1

    g = dgl.DGLGraph()
    g.add_nodes(node_num)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)

    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    if args['net'] == 'gin':
        net_gcn = GINNet(args['embedding_dim'], g)
        pruning_gin.add_mask(net_gcn)
    elif args['net'] == 'gat':
        net_gcn = GATNet(args['embedding_dim'], g)
        g.add_edges(list(range(node_num)), list(range(node_num)))  # self-loops for GAT
        pruning_gat.add_mask(net_gcn)
    else:
        raise ValueError("unsupported net: {}".format(args['net']))

    net_gcn = net_gcn.cuda()

    if rewind_weight_mask is not None:
        net_gcn.load_state_dict(rewind_weight_mask)

    # Perturb the trainable masks (and report sparsity) every round, including the first.
    if args['net'] == 'gin':
        pruning_gin.add_trainable_mask_noise(net_gcn, c=1e-5)
        adj_spar, wei_spar = pruning_gin.print_sparsity(net_gcn)
    else:
        pruning_gat.add_trainable_mask_noise(net_gcn, c=1e-5)
        adj_spar, wei_spar = pruning_gat.print_sparsity(net_gcn)

    optimizer = torch.optim.Adam(net_gcn.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}
    rewind_weight = copy.deepcopy(net_gcn.state_dict())

    for epoch in range(args['mask_epoch']):

        net_gcn.train()  # eval() below would otherwise persist into later epochs
        optimizer.zero_grad()
        output = net_gcn(g, features, 0, 0)
        loss = loss_func(output[idx_train], labels[idx_train])
        loss.backward()

        # l1 norm: subgradient step on the trainable masks
        if args['net'] == 'gin':
            pruning_gin.subgradient_update_mask(net_gcn, args)
        else:
            pruning_gat.subgradient_update_mask(net_gcn, args)
        optimizer.step()

        with torch.no_grad():
            net_gcn.eval()
            output = net_gcn(g, features, 0, 0)
            acc_val = f1_score(labels[idx_val].cpu().numpy(),
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(labels[idx_test].cpu().numpy(),
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch
                # Harden the best-so-far masks into the rewind checkpoint.
                if args['net'] == 'gin':
                    rewind_weight, adj_spar, wei_spar = pruning_gin.get_final_mask_epoch(
                        net_gcn, rewind_weight, args)
                else:
                    rewind_weight, adj_spar, wei_spar = pruning_gat.get_final_mask_epoch(
                        net_gcn, rewind_weight, args)

        print("IMP[{}] (Get Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
              .format(imp_num, epoch, args['mask_epoch'], loss,
                      acc_val * 100, acc_test * 100,
                      best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                      best_val_acc['epoch'], adj_spar, wei_spar))

    return rewind_weight
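
# Usage sketch (an assumption, not part of the original source): run_get_mask is meant
# to be iterated, with each round's returned rewind checkpoint (original weights plus
# the newly hardened binary masks) seeding the next round; the driver name and round
# count are hypothetical.
def example_imp_mask_search(args, n_rounds=20):
    rewind_weight = None
    for imp_num in range(1, n_rounds + 1):
        rewind_weight = run_get_mask(args, imp_num, rewind_weight)
    return rewind_weight
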
def run_fix_mask(args, imp_num, adj_percent, wei_percent):
    """Random-pruning (RP) baseline for node classification: randomly prune
    edges/weights, freeze the masks, then retrain."""
    pruning.setup_seed(args['seed'])
    adj, features, labels, idx_train, idx_val, idx_test = load_data(args['dataset'])
    adj = load_adj_raw(args['dataset'])

    node_num = features.size()[0]
    class_num = labels.numpy().max() + 1

    g = dgl.DGLGraph()
    g.add_nodes(node_num)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)

    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    if args['net'] == 'gin':
        net_gcn = GINNet(args['embedding_dim'], g)
        pruning_gin.add_mask(net_gcn)
        pruning_gin.random_pruning(net_gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gin.print_sparsity(net_gcn)
    elif args['net'] == 'gat':
        net_gcn = GATNet(args['embedding_dim'], g)
        g.add_edges(list(range(node_num)), list(range(node_num)))  # self-loops for GAT
        pruning_gat.add_mask(net_gcn)
        pruning_gat.random_pruning(net_gcn, adj_percent, wei_percent)
        adj_spar, wei_spar = pruning_gat.print_sparsity(net_gcn)
    else:
        raise ValueError("unsupported net: {}".format(args['net']))

    net_gcn = net_gcn.cuda()

    # Freeze the pruning masks; only the surviving weights are trained.
    for name, param in net_gcn.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(net_gcn.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}

    for epoch in range(args['fix_epoch']):

        net_gcn.train()  # eval() below would otherwise persist into later epochs
        optimizer.zero_grad()
        output = net_gcn(g, features, 0, 0)
        loss = loss_func(output[idx_train], labels[idx_train])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            net_gcn.eval()
            output = net_gcn(g, features, 0, 0)
            acc_val = f1_score(labels[idx_val].cpu().numpy(),
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(labels[idx_test].cpu().numpy(),
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch

        print("RP[{}] (Fix Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
              .format(imp_num, epoch, args['fix_epoch'], loss,
                      acc_val * 100, acc_test * 100,
                      best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                      best_val_acc['epoch']))

    print("syd final: [{},{}] RP[{}] (Fix Mask) Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
          .format(args['dataset'], args['net'], imp_num,
                  best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                  best_val_acc['epoch'], adj_spar, wei_spar))
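
# Usage sketch (an assumption, not part of the original source): the node-classification
# RP baseline swept over an increasing sparsity schedule. `args` is the same dict of
# hyperparameters read above; the schedule and function name are illustrative.
def example_random_pruning(args, n_rounds=20):
    for imp_num in range(1, n_rounds + 1):
        run_fix_mask(args, imp_num,
                     adj_percent=1 - 0.95 ** imp_num,
                     wei_percent=1 - 0.80 ** imp_num)
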
def run_get_mask(args, imp_num, rewind_weight_mask, dataset_dict):
    """IMP mask search for the GIC link-prediction setting: train differentiable
    masks with an L1 subgradient penalty and record the best masks."""
    pruning_gin.setup_seed(args.seed)
    num_clusters = int(args.num_clusters)
    dataset = args.dataset

    batch_size = 1
    patience = 50
    l2_coef = 0.0
    hid_units = 16
    # sparse = True
    sparse = False
    nonlinearity = 'prelu'  # special name to separate parameters

    adj = dataset_dict['adj']
    adj_sparse = dataset_dict['adj_sparse']
    features = dataset_dict['features']
    labels = dataset_dict['labels']
    val_edges = dataset_dict['val_edges']
    val_edges_false = dataset_dict['val_edges_false']
    test_edges = dataset_dict['test_edges']
    test_edges_false = dataset_dict['test_edges_false']

    nb_nodes = features.shape[1]
    ft_size = features.shape[2]

    g = dgl.DGLGraph()
    g.add_nodes(nb_nodes)
    adj = adj.tocoo()
    g.add_edges(adj.row, adj.col)

    b_xent = nn.BCEWithLogitsLoss()
    b_bce = nn.BCELoss()

    if args.net == 'gin':
        model = GIC_GIN(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        pruning_gin.add_mask(model.gcn)
    elif args.net == 'gat':
        model = GIC_GAT(nb_nodes, ft_size, hid_units, nonlinearity, num_clusters, 100, g)
        pruning_gat.add_mask(model.gcn)
        g.add_edges(list(range(nb_nodes)), list(range(nb_nodes)))  # self-loops for GAT
    else:
        raise ValueError("unsupported net: {}".format(args.net))

    if rewind_weight_mask is not None:
        model.load_state_dict(rewind_weight_mask)

    # Perturb the trainable masks (and report sparsity) every round, including the first.
    if args.net == 'gin':
        pruning_gin.add_trainable_mask_noise(model.gcn, c=1e-4)
        adj_spar, wei_spar = pruning_gin.print_sparsity(model.gcn)
    else:
        pruning_gat.add_trainable_mask_noise(model.gcn, c=1e-4)
        adj_spar, wei_spar = pruning_gat.print_sparsity(model.gcn)

    optimiser = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=l2_coef)

    # Early-stopping bookkeeping (unused below, kept from the original).
    cnt_wait = 0
    best = 1e9
    best_t = 0

    model.cuda()
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}
    rewind_weight = copy.deepcopy(model.state_dict())

    for epoch in range(1, args.mask_epoch + 1):

        model.train()
        optimiser.zero_grad()

        # Corrupt the graph by shuffling node features to build negative samples.
        idx = np.random.permutation(nb_nodes)
        shuf_fts = features[:, idx, :]

        lbl_1 = torch.ones(batch_size, nb_nodes)
        lbl_2 = torch.zeros(batch_size, nb_nodes)
        lbl = torch.cat((lbl_1, lbl_2), 1)

        shuf_fts = shuf_fts.cuda()
        lbl = lbl.cuda()

        logits, logits2 = model(features, shuf_fts, g, sparse, None, None, None, 100)
        loss = 0.5 * b_xent(logits, lbl) + 0.5 * b_xent(logits2, lbl)
        loss.backward()

        # l1 norm: subgradient step on the trainable masks
        if args.net == 'gin':
            pruning_gin.subgradient_update_mask(model.gcn, args)
        else:
            pruning_gat.subgradient_update_mask(model.gcn, args)
        optimiser.step()

        with torch.no_grad():
            acc_val, _ = pruning.test(model, features, g, sparse, adj_sparse,
                                      val_edges, val_edges_false)
            acc_test, _ = pruning.test(model, features, g, sparse, adj_sparse,
                                       test_edges, test_edges_false)

        if acc_val > best_val_acc['val_acc']:
            best_val_acc['test_acc'] = acc_test
            best_val_acc['val_acc'] = acc_val
            best_val_acc['epoch'] = epoch
            # Harden the best-so-far masks into the rewind checkpoint.
            if args.net == 'gin':
                rewind_weight, adj_spar, wei_spar = pruning_gin.get_final_mask_epoch(
                    model, rewind_weight, args)
            else:
                rewind_weight, adj_spar, wei_spar = pruning_gat.get_final_mask_epoch(
                    model, rewind_weight, args)

        print("IMP[{}] ({}, {} Get Mask) Epoch:[{}/{}] Loss:[{:.4f}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
              .format(imp_num, args.net, args.dataset, epoch, args.mask_epoch, loss,
                      acc_val * 100, acc_test * 100,
                      best_val_acc['val_acc'] * 100, best_val_acc['test_acc'] * 100,
                      best_val_acc['epoch'], adj_spar, wei_spar))

    return rewind_weight
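
# Usage sketch (an assumption, not part of the original source): the GIC counterpart of
# the mask search, threading the same dataset_dict contract as run_fix_mask at the top
# of this file; the driver name and round count are illustrative.
def example_gic_mask_search(args, dataset_dict, n_rounds=20):
    rewind_weight = None
    for imp_num in range(1, n_rounds + 1):
        rewind_weight = run_get_mask(args, imp_num, rewind_weight, dataset_dict)
    return rewind_weight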