def run_fix_mask(args, seed, adj_percent, wei_percent):
    """Train a GCN whose masks are randomly pruned once and then frozen.

    Args:
        args: Mapping read for ``'dataset'``, ``'embedding_dim'``, ``'lr'``,
            ``'weight_decay'`` and ``'total_epoch'``.
        seed: Value handed to ``pruning.setup_seed``.
        adj_percent: Forwarded to ``pruning.random_pruning``.
        wei_percent: Forwarded to ``pruning.random_pruning``.

    Returns:
        Tuple ``(val_acc, test_acc, epoch, adj_spar, wei_spar)``: the best
        validation micro-F1, the test micro-F1 and the epoch index recorded
        when that best was reached, followed by the two values returned by
        ``pruning.print_sparsity``.
    """
    pruning.setup_seed(seed)
    adj, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])

    adj = adj.cuda()
    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    net_gcn = net.net_gcn(embedding_dim=args['embedding_dim'], adj=adj)
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cuda()
    pruning.random_pruning(net_gcn, adj_percent, wei_percent)
    adj_spar, wei_spar = pruning.print_sparsity(net_gcn)

    # The masks are fixed for this run: only non-mask parameters are trained.
    for name, param in net_gcn.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}

    # Ground-truth labels never change between epochs, so slice them (and
    # move the evaluation ones to NumPy) once instead of on every iteration.
    train_labels = labels[idx_train]
    val_labels = labels[idx_val].cpu().numpy()
    test_labels = labels[idx_test].cpu().numpy()

    for epoch in range(args['total_epoch']):
        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output[idx_train], train_labels)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(val_labels,
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(test_labels,
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            # Model selection is on validation score; the test score is only
            # recorded alongside it.
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch

        print(
            "(Fix Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
            .format(epoch, acc_val * 100, acc_test * 100,
                    best_val_acc['val_acc'] * 100,
                    best_val_acc['test_acc'] * 100,
                    best_val_acc['epoch']))

    return (best_val_acc['val_acc'], best_val_acc['test_acc'],
            best_val_acc['epoch'], adj_spar, wei_spar)
def main_fixed_mask(args):
    """Baseline node-property training run with the masks kept frozen.

    Loads the dataset named by ``args.dataset``, builds a ``DeeperGCN`` with
    masks attached, disables gradients on every parameter whose name contains
    ``'mask'`` and trains for ``args.epochs`` epochs, printing a progress line
    per epoch and a summary at the end.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``self_loop``,
            ``num_layers``, ``lr`` and ``epochs``. ``in_channels`` and
            ``num_tasks`` are written onto it before the model is built.

    Returns:
        None. Results are only printed.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygNodePropPredDataset(name=args.dataset)
    graph = dataset[0]
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(args.dataset)

    x = graph.x.to(device)
    y_true = graph.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = to_undirected(graph.edge_index.to(device), graph.num_nodes)
    if args.self_loop:
        edge_index = add_self_loops(edge_index, num_nodes=graph.num_nodes)[0]

    # The model constructor reads these two fields from args.
    args.in_channels = graph.x.size(-1)
    args.num_tasks = dataset.num_classes

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args.num_layers)

    # Freeze every mask parameter; everything else stays trainable.
    for pname, tensor in model.named_parameters():
        if 'mask' not in pname:
            continue
        tensor.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    results = dict(highest_valid=0, final_train=0, final_test=0,
                   highest_train=0, epoch=0)

    for epoch in range(1, args.epochs + 1):
        epoch_loss = train_fixed(model, x, edge_index, y_true, train_idx,
                                 optimizer, args)
        train_accuracy, valid_accuracy, test_accuracy = test(
            model, x, edge_index, y_true, split_idx, evaluator)

        # Keep the train/test scores observed at the best validation epoch.
        if valid_accuracy > results['highest_valid']:
            results.update(highest_valid=valid_accuracy,
                           final_train=train_accuracy,
                           final_test=test_accuracy,
                           epoch=epoch)

        stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        progress = 'Baseline (FIX Mask) Epoch:[{}/{}]\t LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}]'.format(
            epoch, args.epochs, epoch_loss,
            train_accuracy * 100, valid_accuracy * 100, test_accuracy * 100,
            results['final_test'] * 100, results['epoch'])
        print(stamp + ' | ' + progress)

    rule = "=" * 120
    print(rule)
    print("syd final: Baseline, Train:[{:.2f}] Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}]"
          .format(results['final_train'] * 100,
                  results['highest_valid'] * 100,
                  results['epoch'],
                  results['final_test'] * 100))
    print(rule)
def main_fixed_mask(args, imp_num, resume_train_ckpt=None):
    """Retrain the link-prediction model using masks restored from a checkpoint.

    The model state comes from ``pruning.resume_change`` and the predictor
    state from ``resume_train_ckpt['predictor_state_dict']``. Mask parameters
    are frozen, then model and predictor are trained for ``args.fix_epochs``
    epochs. A checkpoint is written through ``pruning.save_all`` whenever the
    validation ``Hits@50`` improves.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``lr``,
            ``fix_epochs``, ``batch_size`` and ``model_save_path``.
            ``in_channels`` and ``num_tasks`` are written onto it.
        imp_num: Pruning-round identifier, used in log lines and in the
            checkpoint name ``'IMP{}_fixed_ckpt'``.
        resume_train_ckpt: Checkpoint mapping. Despite the ``None`` default it
            is required in practice: it is subscripted unconditionally.

    Returns:
        None. Results are only printed.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    # Original author's note on the loaded object:
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    predictor = LinkPredictor(args).to(device)

    # The sparsity values returned here are superseded by print_sparsity below.
    rewind_weight_mask, _, _ = pruning.resume_change(resume_train_ckpt, model, args)
    model.load_state_dict(rewind_weight_mask)
    predictor.load_state_dict(resume_train_ckpt['predictor_state_dict'])
    adj_spar, wei_spar = pruning.print_sparsity(model, args)

    # Masks are fixed in this phase.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(predictor.parameters()), lr=args.lr)

    results = {'epoch': 0}
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']
    for key in keys:
        results[key] = {hit: 0 for hit in hits}
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar

    # Bound before the loop: the summary print below reads it even when
    # args.fix_epochs < 1 and the loop body never runs.
    k = 'Hits@50'
    start_epoch = 1
    for epoch in range(start_epoch, args.fix_epochs + 1):
        t0 = time.time()
        epoch_loss = train.train_fixed(model, predictor, x, edge_index,
                                       split_edge, optimizer, args.batch_size,
                                       args)
        result = train.test(model, predictor, x, edge_index, split_edge,
                            evaluator, args.batch_size, args)
        # result[k] unpacks into (train, valid, test) scores.
        train_result, valid_result, test_result = result[k]

        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch
            pruning.save_all(model, predictor, rewind_weight_mask, optimizer,
                             imp_num, epoch, args.model_save_path,
                             'IMP{}_fixed_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (FIX Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] Time:[{:.2f}min]'
              .format(imp_num, epoch, args.fix_epochs, epoch_loss,
                      train_result * 100, valid_result * 100,
                      test_result * 100, results['final_test'][k] * 100,
                      results['epoch'], epoch_time))

    print("=" * 120)
    print("syd final: IMP:[{}], Train:[{:.2f}] Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] Adj:[{:.2f}%] Wei:[{:.2f}%]"
          .format(imp_num, results['final_train'][k] * 100,
                  results['highest_valid'][k] * 100, results['epoch'],
                  results['final_test'][k] * 100, results['adj_spar'],
                  results['wei_spar']))
    print("=" * 120)
def main_get_mask(args, imp_num, rewind_weight_mask=None, rewind_predict_weight=None, resume_train_ckpt=None):
    """Train model and masks jointly, then derive the pruned rewind state.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``lr``,
            ``mask_epochs``, ``batch_size`` and ``model_save_path``.
            ``in_channels`` and ``num_tasks`` are written onto it.
        imp_num: Pruning-round identifier, used in log lines and in the
            checkpoint name ``'IMP{}_train_ckpt'``.
        rewind_weight_mask: Overwritten on every path before it is read
            (from the checkpoint when resuming, otherwise from a deep copy
            of the fresh model state), so the passed value has no effect.
        rewind_predict_weight: Replaced by a deep copy of the fresh predictor
            state unless ``resume_train_ckpt`` is given, in which case the
            passed value is returned unchanged.
        resume_train_ckpt: Optional checkpoint mapping providing ``'epoch'``,
            ``'rewind_weight_mask'``, ``'model_state_dict'`` and
            ``'optimizer_state_dict'``.

    Returns:
        Tuple ``(rewind_weight_mask, rewind_predict_weight)``, where the first
        item is what ``pruning.change`` returned after training.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    # Original author's note on the loaded object:
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    predictor = LinkPredictor(args).to(device)
    pruning.add_trainable_mask_noise(model, args, c=1e-4)
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(predictor.parameters()), lr=args.lr)

    results = {'epoch': 0}
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']
    for key in keys:
        results[key] = {hit: 0 for hit in hits}

    start_epoch = 1
    if resume_train_ckpt:
        start_epoch = resume_train_ckpt['epoch']
        rewind_weight_mask = resume_train_ckpt['rewind_weight_mask']
        ori_model_dict = model.state_dict()
        # Only restore entries the current model actually has.
        over_lap = {
            name: value
            for name, value in resume_train_ckpt['model_state_dict'].items()
            if name in ori_model_dict
        }
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(
            imp_num, resume_train_ckpt['epoch'], len(over_lap),
            len(ori_model_dict)))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        # NOTE(review): the predictor state is not restored on this path.
        adj_spar, wei_spar = pruning.print_sparsity(model, args)
    else:
        rewind_weight_mask = copy.deepcopy(model.state_dict())
        rewind_predict_weight = copy.deepcopy(predictor.state_dict())

    # Bound before the loop: the summary print below reads it even when the
    # epoch range is empty.
    k = 'Hits@50'
    for epoch in range(start_epoch, args.mask_epochs + 1):
        t0 = time.time()
        epoch_loss, prune_info_dict = train.train_mask(
            model, predictor, x, edge_index, split_edge, optimizer, args)
        result = train.test(model, predictor, x, edge_index, split_edge,
                            evaluator, args.batch_size, args)
        train_result, valid_result, test_result = result[k]

        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch
            pruning.save_all(model, predictor, rewind_weight_mask, optimizer,
                             imp_num, epoch, args.model_save_path,
                             'IMP{}_train_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (GET Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] | Adj[{:.3f}%] Wei[{:.3f}%] Time:[{:.2f}min]'
              .format(imp_num, epoch, args.mask_epochs, epoch_loss,
                      train_result * 100, valid_result * 100,
                      test_result * 100, results['final_test'][k] * 100,
                      results['epoch'], prune_info_dict['adj_spar'],
                      prune_info_dict['wei_spar'], epoch_time))

    rewind_weight_mask, adj_spar, wei_spar = pruning.change(
        rewind_weight_mask, model, args)
    print('-' * 100)
    print("INFO : IMP:[{}] (GET MASK) Final Result Train:[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Adj:[{:.3f}%] Wei:[{:.3f}%]"
          .format(imp_num, results['final_train'][k] * 100,
                  results['highest_valid'][k] * 100,
                  results['final_test'][k] * 100, adj_spar, wei_spar))
    print('-' * 100)
    return rewind_weight_mask, rewind_predict_weight
def main_fixed_mask(args, imp_num, final_state_dict=None, resume_train_ckpt=None):
    """Train a masked ``DeeperGCN`` on random graph partitions, masks frozen.

    A fixed list of validation partitions is generated up front. Each epoch
    draws a fresh training partition, trains on it and evaluates with ROC-AUC.
    A checkpoint is written through ``pruning.save_all`` whenever the
    validation score improves.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``aggr``,
            ``num_evals``, ``valid_cluster_number``, ``cluster_number``,
            ``lr``, ``epochs`` and ``model_save_path``. ``num_tasks`` and
            ``nf_path`` are written onto it.
        imp_num: Pruning-round identifier, used in log lines and in the
            checkpoint name ``'IMP{}_fixed_ckpt'``.
        final_state_dict: When not ``None``, handed to
            ``pruning.retrain_operation`` before training starts.
        resume_train_ckpt: Optional checkpoint mapping providing ``'adj'``,
            ``'epoch'``, ``'model_state_dict'`` and
            ``'optimizer_state_dict'``.

    Returns:
        None. Results are only printed.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = OGBNDataset(dataset_name=args.dataset)
    nf_path = dataset.extract_node_features(args.aggr)

    args.num_tasks = dataset.num_tasks
    args.nf_path = nf_path

    evaluator = Evaluator(args.dataset)
    criterion = torch.nn.BCEWithLogitsLoss()

    # Validation partitions are sampled once and reused for every epoch.
    valid_data_list = []
    for _ in range(args.num_evals):
        parts = dataset.random_partition_graph(
            dataset.total_no_of_nodes,
            cluster_number=args.valid_cluster_number)
        valid_data = dataset.generate_sub_graphs(
            parts, cluster_number=args.valid_cluster_number)
        valid_data_list.append(valid_data)

    print("-" * 120)
    model = DeeperGCN(args).to(device)
    pruning.add_mask(model)

    if final_state_dict is not None:
        pruning.retrain_operation(dataset, model, final_state_dict)
    adj_spar, wei_spar = pruning.print_sparsity(dataset, model)

    # Masks are fixed in this phase.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    results = {'highest_valid': 0, 'final_train': 0, 'final_test': 0,
               'highest_train': 0, 'epoch': 0}

    start_epoch = 1
    if resume_train_ckpt:
        dataset.adj = resume_train_ckpt['adj']
        start_epoch = resume_train_ckpt['epoch']
        ori_model_dict = model.state_dict()
        # Only restore entries the current model actually has.
        over_lap = {
            name: value
            for name, value in resume_train_ckpt['model_state_dict'].items()
            if name in ori_model_dict
        }
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(
            imp_num, resume_train_ckpt['epoch'], len(over_lap),
            len(ori_model_dict)))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(dataset, model)

    # Stored after the optional resume so the logged sparsity describes the
    # state that is actually trained, not the pre-resume one.
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar

    for epoch in range(start_epoch, args.epochs + 1):
        # A new random training partition is drawn every epoch.
        t0 = time.time()
        train_parts = dataset.random_partition_graph(
            dataset.total_no_of_nodes, cluster_number=args.cluster_number)
        data = dataset.generate_sub_graphs(
            train_parts, cluster_number=args.cluster_number, ifmask=True)
        epoch_loss = train.train_fixed(data, dataset, model, optimizer,
                                       criterion, device, args)
        result = train.multi_evaluate(valid_data_list, dataset, model,
                                      evaluator, device)

        train_result = result['train']['rocauc']
        valid_result = result['valid']['rocauc']
        test_result = result['test']['rocauc']

        if valid_result > results['highest_valid']:
            results['highest_valid'] = valid_result
            results['final_train'] = train_result
            results['final_test'] = test_result
            results['epoch'] = epoch
            pruning.save_all(dataset, model, None, optimizer, imp_num, epoch,
                             args.model_save_path,
                             'IMP{}_fixed_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (FIX Mask) Epoch[{}/{}] LOSS[{:.4f}] Train[{:.2f}] Valid[{:.2f}] Test[{:.2f}] | Update Test[{:.2f}] at epoch[{}] | Adj[{:.2f}%] Wei[{:.2f}%] Time[{:.2f}min]'
              .format(imp_num, epoch, args.epochs, epoch_loss,
                      train_result * 100, valid_result * 100,
                      test_result * 100, results['final_test'] * 100,
                      results['epoch'], results['adj_spar'] * 100,
                      results['wei_spar'] * 100, epoch_time))

    print("=" * 120)
    print("INFO final: IMP:[{}], Train:[{:.2f}] Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
          .format(imp_num, results['final_train'] * 100,
                  results['highest_valid'] * 100, results['epoch'],
                  results['final_test'] * 100, results['adj_spar'] * 100,
                  results['wei_spar'] * 100))
    print("=" * 120)
def main_fixed_mask(args, imp_num, adj_percent, wei_percent, resume_train_ckpt=None):
    """Train a randomly pruned ``DeeperGCN`` on node prediction, masks frozen.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``self_loop``,
            ``num_layers``, ``lr``, ``epochs`` and ``model_save_path``.
            ``in_channels`` and ``num_tasks`` are written onto it.
        imp_num: Round identifier, used in log lines and in the checkpoint
            name ``'RP{}_fixed_ckpt'``.
        adj_percent: Forwarded to ``pruning.random_pruning``.
        wei_percent: Forwarded to ``pruning.random_pruning``.
        resume_train_ckpt: Optional checkpoint mapping providing ``'epoch'``,
            ``'model_state_dict'`` and ``'optimizer_state_dict'``.

    Returns:
        None. Results are only printed.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygNodePropPredDataset(name=args.dataset)
    data = dataset[0]
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    y_true = data.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)
    if args.self_loop:
        edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]

    args.in_channels = data.x.size(-1)
    args.num_tasks = dataset.num_classes

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args.num_layers)
    pruning.random_pruning(model, args, adj_percent, wei_percent)
    adj_spar, wei_spar = pruning.print_sparsity(model, args)

    # Masks are fixed in this phase.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    results = {
        'highest_valid': 0,
        'final_train': 0,
        'final_test': 0,
        'highest_train': 0,
        'epoch': 0,
    }

    start_epoch = 1
    if resume_train_ckpt:
        start_epoch = resume_train_ckpt['epoch']
        ori_model_dict = model.state_dict()
        # Only restore entries the current model actually has.
        over_lap = {
            name: value
            for name, value in resume_train_ckpt['model_state_dict'].items()
            if name in ori_model_dict
        }
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("(RP FIXED MASK) Resume at epoch:[{}] len:[{}/{}]!".format(
            resume_train_ckpt['epoch'], len(over_lap), len(ori_model_dict)))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(model, args)

    # Stored after the optional resume so the final report describes the
    # restored state rather than the freshly pruned one.
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar

    for epoch in range(start_epoch, args.epochs + 1):
        epoch_loss = train_fixed(model, x, edge_index, y_true, train_idx,
                                 optimizer, args)
        result = test(model, x, edge_index, y_true, split_idx, evaluator)
        train_accuracy, valid_accuracy, test_accuracy = result

        if valid_accuracy > results['highest_valid']:
            results['highest_valid'] = valid_accuracy
            results['final_train'] = train_accuracy
            results['final_test'] = test_accuracy
            results['epoch'] = epoch
            pruning.save_all(model, None, optimizer, imp_num, epoch,
                             args.model_save_path,
                             'RP{}_fixed_ckpt'.format(imp_num))

        print(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
            'RP:[{}] (FIX Mask) Epoch:[{}/{}]\t LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}]'
            .format(imp_num, epoch, args.epochs, epoch_loss,
                    train_accuracy * 100, valid_accuracy * 100,
                    test_accuracy * 100, results['final_test'] * 100,
                    results['epoch']))

    print("=" * 120)
    print(
        "syd final: RP:[{}], Train:[{:.2f}] Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(imp_num, results['final_train'] * 100,
                results['highest_valid'] * 100, results['epoch'],
                results['final_test'] * 100, results['adj_spar'],
                results['wei_spar']))
    print("=" * 120)
def run_get_admm_weight_mask(args, index, wei_percent, seed):
    """Train the baseline GCN on a stored ADMM adjacency and extract a weight mask.

    The adjacency is read from ``./ADMM/admm_<dataset>/adj_<index>.npy``,
    normalized and converted to a dense CUDA tensor. The weight mask is
    recomputed every time the validation micro-F1 improves.

    Args:
        args: Mapping read for ``'dataset'``, ``'embedding_dim'``, ``'lr'``,
            ``'weight_decay'`` and ``'total_epoch'``.
        index: Index of the stored adjacency file to load.
        wei_percent: Forwarded to ``pruning.get_final_weight_mask_epoch``.
        seed: Value handed to ``pruning.setup_seed``.

    Returns:
        Tuple ``(best_epoch_mask, rewind_weight)``: the mask computed at the
        best validation epoch, and a deep copy of the model state taken
        before training.
    """
    adj = np.load("./ADMM/admm_{}/adj_{}.npy".format(args['dataset'], index))
    adj = utils.normalize_adj(adj)
    adj = utils.sparse_mx_to_torch_sparse_tensor(adj)

    pruning.setup_seed(seed)
    # The dataset's own adjacency is discarded in favour of the ADMM one.
    _, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])
    adj = adj.to_dense()

    adj = adj.cuda()
    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    net_gcn = net.net_gcn_baseline(embedding_dim=args['embedding_dim'])
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cuda()

    # Mask parameters are not trained in this run.
    for name, param in net_gcn.named_parameters():
        if 'mask' in name:
            param.requires_grad = False
            print("NAME:{}\tSHAPE:{}\tGRAD:{}".format(name, param.shape,
                                                      param.requires_grad))

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}
    # Snapshot of the untrained state, returned for later rewinding.
    rewind_weight = copy.deepcopy(net_gcn.state_dict())

    # Ground-truth labels never change between epochs, so slice them (and
    # move the evaluation ones to NumPy) once instead of on every iteration.
    train_labels = labels[idx_train]
    val_labels = labels[idx_val].cpu().numpy()
    test_labels = labels[idx_test].cpu().numpy()

    for epoch in range(args['total_epoch']):
        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output[idx_train], train_labels)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(val_labels,
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(test_labels,
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['test_acc'] = acc_test
                best_val_acc['val_acc'] = acc_val
                best_val_acc['epoch'] = epoch
                best_epoch_mask = pruning.get_final_weight_mask_epoch(
                    net_gcn, wei_percent=wei_percent)

        print(
            "(ADMM Get Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
            .format(epoch, acc_val * 100, acc_test * 100,
                    best_val_acc['val_acc'] * 100,
                    best_val_acc['test_acc'] * 100,
                    best_val_acc['epoch']))

    return best_epoch_mask, rewind_weight
def run_get_mask(args, seed, imp_num, rewind_weight_mask=None):
    """Train a masked GCN with a mask subgradient step and extract the best mask.

    Args:
        args: Mapping read for ``'dataset'``, ``'embedding_dim'``,
            ``'weight_dir'``, ``'init_soft_mask_type'``, ``'lr'``,
            ``'weight_decay'``, ``'mask_epoch'``, ``'pruning_percent_adj'``
            and ``'pruning_percent_wei'``. It is also passed whole to
            ``pruning.subgradient_update_mask``.
        seed: Value handed to ``pruning.setup_seed`` and
            ``pruning.soft_mask_init``.
        imp_num: Not read by this function.
        rewind_weight_mask: Optional state dict loaded into the model before
            the soft masks are initialised.

    Returns:
        Tuple ``(best_epoch_mask, rewind_weight)``: the mask computed at the
        best validation epoch, and a deep copy of the model state taken
        before training.
    """
    pruning.setup_seed(seed)
    adj, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])

    adj = adj.cuda()
    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    net_gcn = net.net_gcn(embedding_dim=args['embedding_dim'], adj=adj)
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cuda()

    if args['weight_dir']:
        print("load : {}".format(args['weight_dir']))
        # NOTE(review): torch.load unpickles the file; only point
        # 'weight_dir' at checkpoints from a trusted source.
        cl_ckpt = torch.load(args['weight_dir'], map_location='cuda')
        # Overwrite only the first layer's 'weight_orig_weight' entry.
        encoder_weight = {'weight_orig_weight': cl_ckpt['gcn.fc.weight']}
        ori_state_dict = net_gcn.net_layer[0].state_dict()
        ori_state_dict.update(encoder_weight)
        net_gcn.net_layer[0].load_state_dict(ori_state_dict)

    if rewind_weight_mask:
        net_gcn.load_state_dict(rewind_weight_mask)
        pruning.soft_mask_init(net_gcn, args['init_soft_mask_type'], seed)
        # The returned sparsity values are not needed in this function.
        pruning.print_sparsity(net_gcn)
    else:
        pruning.soft_mask_init(net_gcn, args['init_soft_mask_type'], seed)

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}
    # Snapshot of the state before mask training, returned for rewinding.
    rewind_weight = copy.deepcopy(net_gcn.state_dict())

    # Ground-truth labels never change between epochs, so slice them (and
    # move the evaluation ones to NumPy) once instead of on every iteration.
    train_labels = labels[idx_train]
    val_labels = labels[idx_val].cpu().numpy()
    test_labels = labels[idx_test].cpu().numpy()

    for epoch in range(args['mask_epoch']):
        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output[idx_train], train_labels)
        loss.backward()
        # Runs between backward() and step(); the original comment labels
        # this step "l1 norm".
        pruning.subgradient_update_mask(net_gcn, args)
        optimizer.step()

        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(val_labels,
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(test_labels,
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['test_acc'] = acc_test
                best_val_acc['val_acc'] = acc_val
                best_val_acc['epoch'] = epoch
                best_epoch_mask = pruning.get_final_mask_epoch(
                    net_gcn,
                    adj_percent=args['pruning_percent_adj'],
                    wei_percent=args['pruning_percent_wei'])

        print(
            "(Get Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
            .format(epoch, acc_val * 100, acc_test * 100,
                    best_val_acc['val_acc'] * 100,
                    best_val_acc['test_acc'] * 100,
                    best_val_acc['epoch']))

    return best_epoch_mask, rewind_weight
def main_get_mask(args, imp_num):
    """Train the link-prediction model and predictor with frozen masks.

    Masks are attached and immediately frozen, then ``train.train_fixed`` is
    run for ``args.mask_epochs`` epochs while the best validation ``Hits@50``
    and its matching train/test scores are tracked.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``lr``,
            ``mask_epochs`` and ``batch_size``. ``in_channels`` and
            ``num_tasks`` are written onto it.
        imp_num: Round identifier, used only in log lines.

    Returns:
        None. Results are only printed.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    # Original author's note on the loaded object:
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    # Masks are not trained here.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    predictor = LinkPredictor(args).to(device)
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(predictor.parameters()), lr=args.lr)

    results = {'epoch': 0}
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']
    for key in keys:
        results[key] = {hit: 0 for hit in hits}

    # Bound before the loop: the summary print below reads it even when
    # args.mask_epochs < 1 and the loop body never runs.
    k = 'Hits@50'
    start_epoch = 1
    for epoch in range(start_epoch, args.mask_epochs + 1):
        t0 = time.time()
        epoch_loss = train.train_fixed(model, predictor, x, edge_index,
                                       split_edge, optimizer, args.batch_size,
                                       args)
        result = train.test(model, predictor, x, edge_index, split_edge,
                            evaluator, args.batch_size, args)
        train_result, valid_result, test_result = result[k]

        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch

        epoch_time = (time.time() - t0) / 60
        print(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
            'IMP:[{}] (GET Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] Time:[{:.2f}min]'
            .format(imp_num, epoch, args.mask_epochs, epoch_loss,
                    train_result * 100, valid_result * 100,
                    test_result * 100, results['final_test'][k] * 100,
                    results['epoch'], epoch_time))

    print('-' * 100)
    print(
        "syd : IMP:[{}] (FIX Mask) Final Result Train:[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}]"
        .format(imp_num, results['final_train'][k] * 100,
                results['highest_valid'][k] * 100,
                results['final_test'][k] * 100))
    print('-' * 100)
def main_get_mask(args, imp_num, rewind_weight_mask=None, resume_train_ckpt=None):
    """Train model and masks on node prediction and return the rewind state.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``self_loop``,
            ``num_layers`` and ``mask_epochs``. ``in_channels`` and
            ``num_tasks`` are written onto it. The optimizer learning rate
            is the literal ``1e-3``, not a field of ``args``.
        imp_num: Pruning-round identifier, used only in log lines.
        rewind_weight_mask: Optional state dict loaded into the model before
            mask noise is added. It is later replaced either by the
            checkpoint's ``'rewind_weight_mask'`` or by a deep copy of the
            model state.
        resume_train_ckpt: Optional checkpoint mapping providing ``'epoch'``,
            ``'rewind_weight_mask'``, ``'model_state_dict'`` and
            ``'optimizer_state_dict'``.

    Returns:
        The ``rewind_weight_mask`` most recently produced by
        ``pruning.get_final_mask_epoch``, or the initial one if validation
        never improved.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygNodePropPredDataset(name=args.dataset)
    graph = dataset[0]
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(args.dataset)

    x = graph.x.to(device)
    y_true = graph.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = to_undirected(graph.edge_index.to(device), graph.num_nodes)
    if args.self_loop:
        edge_index = add_self_loops(edge_index, num_nodes=graph.num_nodes)[0]

    # The model constructor reads these two fields from args.
    args.in_channels = graph.x.size(-1)
    args.num_tasks = dataset.num_classes

    print("-" * 120)
    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args.num_layers)

    if rewind_weight_mask:
        model.load_state_dict(rewind_weight_mask)
        adj_spar, wei_spar = pruning.print_sparsity(model, args)

    pruning.add_trainable_mask_noise(model, args, c=1e-5)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    results = dict(highest_valid=0, final_train=0, final_test=0,
                   highest_train=0, epoch=0)

    first_epoch = 1
    if not resume_train_ckpt:
        rewind_weight_mask = copy.deepcopy(model.state_dict())
    else:
        first_epoch = resume_train_ckpt['epoch']
        rewind_weight_mask = resume_train_ckpt['rewind_weight_mask']

        # Restore only the checkpoint entries the current model also has.
        current_state = model.state_dict()
        restored = {}
        for key, value in resume_train_ckpt['model_state_dict'].items():
            if key in current_state:
                restored[key] = value
        current_state.update(restored)
        model.load_state_dict(current_state)

        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(
            imp_num, resume_train_ckpt['epoch'], len(restored),
            len(current_state)))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(model, args)

    for epoch in range(first_epoch, args.mask_epochs + 1):
        epoch_loss = train(model, x, edge_index, y_true, train_idx, optimizer,
                           args)
        train_accuracy, valid_accuracy, test_accuracy = test(
            model, x, edge_index, y_true, split_idx, evaluator)

        if valid_accuracy > results['highest_valid']:
            results.update(highest_valid=valid_accuracy,
                           final_train=train_accuracy,
                           final_test=test_accuracy,
                           epoch=epoch)
            # The mask is re-derived only at a new best validation epoch.
            rewind_weight_mask, adj_spar, wei_spar = pruning.get_final_mask_epoch(
                model, rewind_weight_mask, args)
            # Checkpoint saving is disabled in this variant.

        stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        progress = 'IMP:[{}] (GET Mask) Epoch:[{}/{}]\t LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]'.format(
            imp_num, epoch, args.mask_epochs, epoch_loss,
            train_accuracy * 100, valid_accuracy * 100, test_accuracy * 100,
            results['final_test'] * 100, results['epoch'],
            adj_spar, wei_spar)
        print(stamp + ' | ' + progress)

    rule = '-' * 100
    print(rule)
    print(
        "INFO : IMP:[{}] (GET MASK) Final Result Train:[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Adj:[{:.2f}%] Wei:[{:.2f}%] "
        .format(imp_num, results['final_train'] * 100,
                results['highest_valid'] * 100,
                results['final_test'] * 100, adj_spar, wei_spar))
    print(rule)
    return rewind_weight_mask