Exemple #1
0
def run_fix_mask(args, seed, adj_percent, wei_percent):
    """Train a randomly pruned GCN with frozen masks and track the best epoch.

    The model is built, masks are attached and randomly pruned via the
    ``pruning`` helpers, and then only the non-mask parameters are trained
    for ``args['total_epoch']`` epochs. After every step the model is
    evaluated (micro-F1) on the validation and test splits.

    Args:
        args: Mapping read for the keys ``'dataset'``, ``'embedding_dim'``,
            ``'lr'``, ``'weight_decay'`` and ``'total_epoch'``.
        seed: Value forwarded to ``pruning.setup_seed``.
        adj_percent: Forwarded to ``pruning.random_pruning``.
        wei_percent: Forwarded to ``pruning.random_pruning``.

    Returns:
        Tuple ``(best_val_acc, test_acc_at_best_val, best_epoch, adj_spar,
        wei_spar)`` where the sparsity values are those returned by
        ``pruning.print_sparsity`` right after pruning.
    """
    pruning.setup_seed(seed)
    adj, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])

    adj = adj.cuda()
    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    # Ground-truth labels never change, so convert them to NumPy once
    # instead of on every epoch.
    val_labels = labels[idx_val].cpu().numpy()
    test_labels = labels[idx_test].cpu().numpy()

    net_gcn = net.net_gcn(embedding_dim=args['embedding_dim'], adj=adj)
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cuda()
    pruning.random_pruning(net_gcn, adj_percent, wei_percent)

    adj_spar, wei_spar = pruning.print_sparsity(net_gcn)

    # Masks stay fixed for the whole run: exclude them from gradient updates.
    for name, param in net_gcn.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    acc_test = 0.0
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}

    for epoch in range(args['total_epoch']):

        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output[idx_train], labels[idx_train])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(val_labels,
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(test_labels,
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            # Model selection is done on validation accuracy only; the test
            # accuracy recorded is the one observed at that same epoch.
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['val_acc'] = acc_val
                best_val_acc['test_acc'] = acc_test
                best_val_acc['epoch'] = epoch

        print(
            "(Fix Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Final Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
            .format(epoch, acc_val * 100, acc_test * 100,
                    best_val_acc['val_acc'] * 100,
                    best_val_acc['test_acc'] * 100, best_val_acc['epoch']))

    return best_val_acc['val_acc'], best_val_acc['test_acc'], best_val_acc[
        'epoch'], adj_spar, wei_spar
Exemple #2
0
def main_fixed_mask(args, imp_num, resume_train_ckpt=None):
    """Retrain a link-prediction model with its pruning masks frozen.

    The model and predictor are restored from ``resume_train_ckpt`` (model
    state via ``pruning.resume_change``, predictor state from the
    ``'predictor_state_dict'`` entry), mask parameters are frozen, and the
    remaining weights are trained for ``args.fix_epochs`` epochs. A
    checkpoint is saved through ``pruning.save_all`` whenever validation
    Hits@50 improves.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``lr``,
            ``fix_epochs``, ``batch_size`` and ``model_save_path``;
            ``in_channels`` and ``num_tasks`` are written onto it.
        imp_num: Pruning-round index, used in log lines and the checkpoint
            name.
        resume_train_ckpt: Checkpoint mapping. Although it defaults to
            ``None``, it is subscripted unconditionally, so a real
            checkpoint must be supplied.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]
    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    predictor = LinkPredictor(args).to(device)

    # The sparsity values returned here are superseded by the explicit
    # print_sparsity call below, made after the state has been loaded.
    rewind_weight_mask, _, _ = pruning.resume_change(resume_train_ckpt, model, args)
    model.load_state_dict(rewind_weight_mask)
    predictor.load_state_dict(resume_train_ckpt['predictor_state_dict'])

    adj_spar, wei_spar = pruning.print_sparsity(model, args)

    # Masks are fixed in this phase: only the remaining weights are trained.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=args.lr)
    results = {'epoch': 0 }
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']

    for key in keys:
        results[key] = {k: 0 for k in hits}
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar

    start_epoch = 1
    # Metric used for model selection. Bound before the loop so the final
    # summary below still works when the loop body never runs.
    k = 'Hits@50'

    for epoch in range(start_epoch, args.fix_epochs + 1):

        t0 = time.time()
        epoch_loss = train.train_fixed(model, predictor, x, edge_index, split_edge, optimizer, args.batch_size, args)
        result = train.test(model, predictor, x, edge_index, split_edge, evaluator, args.batch_size, args)
        # Each metric maps to a (train, valid, test) tuple.
        train_result, valid_result, test_result = result[k]

        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch
            pruning.save_all(model, predictor, 
                                    rewind_weight_mask, 
                                    optimizer, 
                                    imp_num, 
                                    epoch, 
                                    args.model_save_path, 
                                    'IMP{}_fixed_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (FIX Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] Time:[{:.2f}min]'
              .format(imp_num, epoch, args.fix_epochs, epoch_loss, train_result * 100,
                                                               valid_result * 100,
                                                               test_result * 100, 
                                                               results['final_test'][k] * 100,
                                                               results['epoch'],
                                                               epoch_time))
    print("=" * 120)
    print("syd final: IMP:[{}], Train:[{:.2f}]  Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(imp_num,    results['final_train'][k] * 100,
                            results['highest_valid'][k] * 100,
                            results['epoch'],
                            results['final_test'][k] * 100,
                            results['adj_spar'],
                            results['wei_spar']))
    print("=" * 120)
Exemple #3
0
def main_get_mask(args, imp_num, rewind_weight_mask=None, rewind_predict_weight=None, resume_train_ckpt=None):
    """Train model and masks jointly, then derive the mask for this round.

    A fresh ``DeeperGCN`` and ``LinkPredictor`` are built, masks are
    attached and perturbed with ``pruning.add_trainable_mask_noise``, and
    the pair is trained for ``args.mask_epochs`` epochs. A checkpoint is
    saved whenever validation Hits@50 improves. After training,
    ``pruning.change`` produces the updated rewind state together with the
    sparsity figures.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``lr``,
            ``mask_epochs``, ``batch_size`` and ``model_save_path``;
            ``in_channels`` and ``num_tasks`` are written onto it.
        imp_num: Pruning-round index, used in logs and the checkpoint name.
        rewind_weight_mask: Overwritten inside the function: taken from the
            checkpoint when resuming, otherwise a deep copy of the freshly
            initialised model state.
        rewind_predict_weight: Returned unchanged when resuming; otherwise
            replaced by a deep copy of the initial predictor state.
        resume_train_ckpt: Optional checkpoint mapping with the keys
            ``'epoch'``, ``'rewind_weight_mask'``, ``'model_state_dict'``
            and ``'optimizer_state_dict'``.

    Returns:
        Tuple ``(rewind_weight_mask, rewind_predict_weight)``.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygLinkPropPredDataset(name=args.dataset)
    data = dataset[0]

    # Data(edge_index=[2, 2358104], edge_weight=[2358104, 1], edge_year=[2358104, 1], x=[235868, 128])
    split_edge = dataset.get_edge_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)

    edge_index = data.edge_index.to(device)

    args.in_channels = data.x.size(-1)
    args.num_tasks = 1

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args)
    predictor = LinkPredictor(args).to(device)
    pruning.add_trainable_mask_noise(model, args, c=1e-4)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=args.lr)

    results = {'epoch': 0 }
    keys = ['highest_valid', 'final_train', 'final_test', 'highest_train']
    hits = ['Hits@10', 'Hits@50', 'Hits@100']

    for key in keys:
        results[key] = {k: 0 for k in hits}

    start_epoch = 1
    if resume_train_ckpt:

        start_epoch = resume_train_ckpt['epoch']
        rewind_weight_mask = resume_train_ckpt['rewind_weight_mask']
        # Only load the checkpoint entries that exist in the current model;
        # everything else keeps its freshly initialised value.
        ori_model_dict = model.state_dict()
        over_lap = {k : v for k, v in resume_train_ckpt['model_state_dict'].items() if k in ori_model_dict.keys()}
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(imp_num, resume_train_ckpt['epoch'], len(over_lap.keys()), len(ori_model_dict.keys())))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(model, args)
    else:
        rewind_weight_mask = copy.deepcopy(model.state_dict())
        rewind_predict_weight = copy.deepcopy(predictor.state_dict())

    # Metric used for model selection. Bound before the loop so the final
    # summary still works when the loop body never runs (for example when
    # resuming at an epoch beyond args.mask_epochs).
    k = 'Hits@50'

    for epoch in range(start_epoch, args.mask_epochs + 1):

        t0 = time.time()     
        
        epoch_loss, prune_info_dict = train.train_mask(model, predictor, 
                                                                 x, 
                                                                 edge_index, 
                                                                 split_edge, 
                                                                 optimizer, args)

        result = train.test(model, predictor, 
                                   x, 
                                   edge_index, 
                                   split_edge, 
                                   evaluator, 
                                   args.batch_size, args)

        train_result, valid_result, test_result = result[k]
        if train_result > results['highest_train'][k]:
            results['highest_train'][k] = train_result

        if valid_result > results['highest_valid'][k]:
            results['highest_valid'][k] = valid_result
            results['final_train'][k] = train_result
            results['final_test'][k] = test_result
            results['epoch'] = epoch
            pruning.save_all(model, predictor, 
                                    rewind_weight_mask, 
                                    optimizer, 
                                    imp_num, 
                                    epoch, 
                                    args.model_save_path, 
                                    'IMP{}_train_ckpt'.format(imp_num))

        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (GET Mask) Epoch:[{}/{}] LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] | Adj[{:.3f}%] Wei[{:.3f}%] Time:[{:.2f}min]'
              .format(imp_num, epoch, args.mask_epochs, 
                                      epoch_loss, 
                                      train_result * 100,
                                      valid_result * 100,
                                      test_result * 100,
                                      results['final_test'][k] * 100,
                                      results['epoch'],
                                      prune_info_dict['adj_spar'], 
                                      prune_info_dict['wei_spar'],
                                      epoch_time))

    # Note: the mask is derived from the model as it stands after the last
    # epoch, not from the best-validation checkpoint saved above.
    rewind_weight_mask, adj_spar, wei_spar = pruning.change(rewind_weight_mask, model, args)
    print('-' * 100)
    print("INFO : IMP:[{}] (GET MASK) Final Result Train:[{:.2f}]  Valid:[{:.2f}]  Test:[{:.2f}] | Adj:[{:.3f}%] Wei:[{:.3f}%]"
        .format(imp_num, results['final_train'][k] * 100,
                         results['highest_valid'][k] * 100,
                         results['final_test'][k] * 100,
                         adj_spar, 
                         wei_spar))
    print('-' * 100)
    return rewind_weight_mask, rewind_predict_weight
def main_fixed_mask(args, imp_num, final_state_dict=None, resume_train_ckpt=None):
    """Train a cluster-sampled DeeperGCN with its pruning masks frozen.

    Validation sub-graphs are generated once up front; training sub-graphs
    are re-partitioned every epoch. Whenever validation ROC-AUC improves, the
    current state is saved through ``pruning.save_all``.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``aggr``,
            ``num_evals``, ``valid_cluster_number``, ``cluster_number``,
            ``lr``, ``epochs`` and ``model_save_path``; ``num_tasks`` and
            ``nf_path`` are written onto it.
        imp_num: Pruning-round index, used in logs and the checkpoint name.
        final_state_dict: Optional state handed to
            ``pruning.retrain_operation`` before training starts.
        resume_train_ckpt: Optional checkpoint mapping with the keys
            ``'adj'``, ``'epoch'``, ``'model_state_dict'`` and
            ``'optimizer_state_dict'``.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = OGBNDataset(dataset_name=args.dataset)
    nf_path = dataset.extract_node_features(args.aggr)

    args.num_tasks = dataset.num_tasks
    args.nf_path = nf_path

    evaluator = Evaluator(args.dataset)
    criterion = torch.nn.BCEWithLogitsLoss()

    # Validation partitions are fixed for the whole run so that per-epoch
    # scores are comparable.
    valid_data_list = []
    for i in range(args.num_evals):

        parts = dataset.random_partition_graph(dataset.total_no_of_nodes, cluster_number=args.valid_cluster_number)
        valid_data = dataset.generate_sub_graphs(parts, cluster_number=args.valid_cluster_number)
        valid_data_list.append(valid_data)

    print("-" * 120)
    model = DeeperGCN(args).to(device)
    pruning.add_mask(model)

    if final_state_dict is not None:
        pruning.retrain_operation(dataset, model, final_state_dict)

    # Measure sparsity unconditionally: the values are needed below even
    # when no final_state_dict is supplied (previously that path hit an
    # UnboundLocalError on adj_spar / wei_spar).
    adj_spar, wei_spar = pruning.print_sparsity(dataset, model)

    # Masks are fixed in this phase: only the remaining weights are trained.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    results = {'highest_valid': 0, 'final_train': 0, 'final_test': 0, 'highest_train': 0, 'epoch':0}
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar
    
    start_epoch = 1
    if resume_train_ckpt:
        dataset.adj = resume_train_ckpt['adj']
        start_epoch = resume_train_ckpt['epoch']
        # Only load the checkpoint entries that exist in the current model.
        ori_model_dict = model.state_dict()
        over_lap = {k : v for k, v in resume_train_ckpt['model_state_dict'].items() if k in ori_model_dict.keys()}
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(imp_num, resume_train_ckpt['epoch'], len(over_lap.keys()), len(ori_model_dict.keys())))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(dataset, model)
        # Keep the logged sparsity in sync with the restored adjacency/model.
        results['adj_spar'] = adj_spar
        results['wei_spar'] = wei_spar
    
    for epoch in range(start_epoch, args.epochs + 1):
        # do random partition every epoch
        t0 = time.time()
        train_parts = dataset.random_partition_graph(dataset.total_no_of_nodes, cluster_number=args.cluster_number)
        data = dataset.generate_sub_graphs(train_parts, cluster_number=args.cluster_number, ifmask=True)
        epoch_loss = train.train_fixed(data, dataset, model, optimizer, criterion, device, args)
        result = train.multi_evaluate(valid_data_list, dataset, model, evaluator, device)

        train_result = result['train']['rocauc']
        valid_result = result['valid']['rocauc']
        test_result = result['test']['rocauc']

        if valid_result > results['highest_valid']:
            results['highest_valid'] = valid_result
            results['final_train'] = train_result
            results['final_test'] = test_result
            results['epoch'] = epoch
            final_state_dict = pruning.save_all(dataset, 
                                                model, 
                                                None, 
                                                optimizer, 
                                                imp_num, 
                                                epoch, 
                                                args.model_save_path, 
                                                'IMP{}_fixed_ckpt'.format(imp_num))
        epoch_time = (time.time() - t0) / 60
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
              'IMP:[{}] (FIX Mask) Epoch[{}/{}] LOSS[{:.4f}] Train[{:.2f}] Valid[{:.2f}] Test[{:.2f}] | Update Test[{:.2f}] at epoch[{}] | Adj[{:.2f}%] Wei[{:.2f}%] Time[{:.2f}min]'
              .format(imp_num, epoch, args.epochs, epoch_loss, train_result * 100,
                                                               valid_result * 100,
                                                               test_result * 100,
                                                               results['final_test'] * 100,
                                                               results['epoch'],
                                                               results['adj_spar'] * 100,
                                                               results['wei_spar'] * 100,
                                                               epoch_time))
    print("=" * 120)
    print("INFO final: IMP:[{}], Train:[{:.2f}]  Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] | Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(imp_num,    results['final_train'] * 100,
                            results['highest_valid'] * 100,
                            results['epoch'],
                            results['final_test'] * 100,
                            results['adj_spar'] * 100,
                            results['wei_spar'] * 100))
    print("=" * 120)
def main_fixed_mask(args,
                    imp_num,
                    adj_percent,
                    wei_percent,
                    resume_train_ckpt=None):
    """Train a randomly pruned DeeperGCN node classifier with frozen masks.

    Masks are attached and randomly pruned via the ``pruning`` helpers, then
    frozen; the remaining weights are trained for ``args.epochs`` epochs. A
    checkpoint is saved through ``pruning.save_all`` whenever validation
    accuracy improves.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``self_loop``,
            ``num_layers``, ``lr``, ``epochs`` and ``model_save_path``;
            ``in_channels`` and ``num_tasks`` are written onto it.
        imp_num: Round index, used in logs and the checkpoint name.
        adj_percent: Forwarded to ``pruning.random_pruning``.
        wei_percent: Forwarded to ``pruning.random_pruning``.
        resume_train_ckpt: Optional checkpoint mapping with the keys
            ``'epoch'``, ``'model_state_dict'`` and
            ``'optimizer_state_dict'``.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygNodePropPredDataset(name=args.dataset)
    data = dataset[0]
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    y_true = data.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)

    if args.self_loop:
        edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]

    args.in_channels = data.x.size(-1)
    args.num_tasks = dataset.num_classes

    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args.num_layers)
    pruning.random_pruning(model, args, adj_percent, wei_percent)
    adj_spar, wei_spar = pruning.print_sparsity(model, args)

    # Masks are fixed in this phase: only the remaining weights are trained.
    for name, param in model.named_parameters():
        if 'mask' in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    results = {
        'highest_valid': 0,
        'final_train': 0,
        'final_test': 0,
        'highest_train': 0,
        'epoch': 0
    }
    results['adj_spar'] = adj_spar
    results['wei_spar'] = wei_spar

    start_epoch = 1
    if resume_train_ckpt:

        start_epoch = resume_train_ckpt['epoch']
        # Only load the checkpoint entries that exist in the current model.
        ori_model_dict = model.state_dict()
        over_lap = {
            k: v
            for k, v in resume_train_ckpt['model_state_dict'].items()
            if k in ori_model_dict.keys()
        }
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("(RP FIXED MASK) Resume at epoch:[{}] len:[{}/{}]!".format(
            resume_train_ckpt['epoch'], len(over_lap.keys()),
            len(ori_model_dict.keys())))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(model, args)
        # The restored state replaces the freshly sampled random masks, so
        # report the sparsity of what is actually being trained.
        results['adj_spar'] = adj_spar
        results['wei_spar'] = wei_spar

    for epoch in range(start_epoch, args.epochs + 1):

        epoch_loss = train_fixed(model, x, edge_index, y_true, train_idx,
                                 optimizer, args)
        result = test(model, x, edge_index, y_true, split_idx, evaluator)
        train_accuracy, valid_accuracy, test_accuracy = result

        if valid_accuracy > results['highest_valid']:
            results['highest_valid'] = valid_accuracy
            results['final_train'] = train_accuracy
            results['final_test'] = test_accuracy
            results['epoch'] = epoch
            pruning.save_all(model, None, optimizer, imp_num, epoch,
                             args.model_save_path,
                             'RP{}_fixed_ckpt'.format(imp_num))

        print(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
            'RP:[{}] (FIX Mask) Epoch:[{}/{}]\t LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}]'
            .format(imp_num, epoch, args.epochs, epoch_loss, train_accuracy *
                    100, valid_accuracy * 100, test_accuracy *
                    100, results['final_test'] * 100, results['epoch']))
    print("=" * 120)
    print(
        "syd final: RP:[{}], Train:[{:.2f}]  Best Val:[{:.2f}] at epoch:[{}] | Final Test Acc:[{:.2f}] Adj:[{:.2f}%] Wei:[{:.2f}%]"
        .format(imp_num, results['final_train'] * 100,
                results['highest_valid'] * 100, results['epoch'],
                results['final_test'] * 100, results['adj_spar'],
                results['wei_spar']))
    print("=" * 120)
Exemple #6
0
def run_get_mask(args, seed, imp_num, rewind_weight_mask=None):
    """Jointly train GCN weights and soft masks, returning the best mask.

    The model is optionally initialised from a pretrained encoder weight
    file and/or a rewind state, soft masks are initialised, and training
    runs for ``args['mask_epoch']`` epochs with
    ``pruning.subgradient_update_mask`` applied between backward and the
    optimizer step. The mask is re-extracted every time validation micro-F1
    improves.

    Args:
        args: Mapping read for the keys ``'dataset'``, ``'embedding_dim'``,
            ``'weight_dir'``, ``'init_soft_mask_type'``, ``'lr'``,
            ``'weight_decay'``, ``'mask_epoch'``, ``'pruning_percent_adj'``
            and ``'pruning_percent_wei'``.
        seed: Forwarded to ``pruning.setup_seed`` and
            ``pruning.soft_mask_init``.
        imp_num: Unused here; kept for interface compatibility.
        rewind_weight_mask: Optional state dict loaded into the model before
            the soft masks are initialised.

    Returns:
        Tuple ``(best_epoch_mask, rewind_weight)`` where ``rewind_weight`` is
        a deep copy of the model state taken just before training starts.
    """
    pruning.setup_seed(seed)
    adj, features, labels, idx_train, idx_val, idx_test = load_data(
        args['dataset'])

    adj = adj.cuda()
    features = features.cuda()
    labels = labels.cuda()
    loss_func = nn.CrossEntropyLoss()

    # Ground-truth labels never change, so convert them to NumPy once
    # instead of on every epoch.
    val_labels = labels[idx_val].cpu().numpy()
    test_labels = labels[idx_test].cpu().numpy()

    net_gcn = net.net_gcn(embedding_dim=args['embedding_dim'], adj=adj)
    pruning.add_mask(net_gcn)
    net_gcn = net_gcn.cuda()

    if args['weight_dir']:

        print("load : {}".format(args['weight_dir']))
        encoder_weight = {}
        # NOTE(review): torch.load unpickles the file; only point
        # 'weight_dir' at checkpoints from a trusted source.
        cl_ckpt = torch.load(args['weight_dir'], map_location='cuda')
        encoder_weight['weight_orig_weight'] = cl_ckpt['gcn.fc.weight']
        ori_state_dict = net_gcn.net_layer[0].state_dict()
        ori_state_dict.update(encoder_weight)
        net_gcn.net_layer[0].load_state_dict(ori_state_dict)

    if rewind_weight_mask:
        net_gcn.load_state_dict(rewind_weight_mask)
        pruning.soft_mask_init(net_gcn, args['init_soft_mask_type'], seed)
        pruning.print_sparsity(net_gcn)
    else:
        pruning.soft_mask_init(net_gcn, args['init_soft_mask_type'], seed)

    optimizer = torch.optim.Adam(net_gcn.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])

    acc_test = 0.0
    best_val_acc = {'val_acc': 0, 'epoch': 0, 'test_acc': 0}
    best_epoch_mask = None
    rewind_weight = copy.deepcopy(net_gcn.state_dict())
    for epoch in range(args['mask_epoch']):

        optimizer.zero_grad()
        output = net_gcn(features, adj)
        loss = loss_func(output[idx_train], labels[idx_train])
        loss.backward()
        pruning.subgradient_update_mask(net_gcn, args)  # l1 norm
        optimizer.step()
        with torch.no_grad():
            output = net_gcn(features, adj, val_test=True)
            acc_val = f1_score(val_labels,
                               output[idx_val].cpu().numpy().argmax(axis=1),
                               average='micro')
            acc_test = f1_score(test_labels,
                                output[idx_test].cpu().numpy().argmax(axis=1),
                                average='micro')
            if acc_val > best_val_acc['val_acc']:
                best_val_acc['test_acc'] = acc_test
                best_val_acc['val_acc'] = acc_val
                best_val_acc['epoch'] = epoch
                best_epoch_mask = pruning.get_final_mask_epoch(
                    net_gcn,
                    adj_percent=args['pruning_percent_adj'],
                    wei_percent=args['pruning_percent_wei'])

            print(
                "(Get Mask) Epoch:[{}] Val:[{:.2f}] Test:[{:.2f}] | Best Val:[{:.2f}] Test:[{:.2f}] at Epoch:[{}]"
                .format(epoch, acc_val * 100, acc_test * 100,
                        best_val_acc['val_acc'] * 100,
                        best_val_acc['test_acc'] * 100, best_val_acc['epoch']))

    if best_epoch_mask is None:
        # Validation accuracy never rose above zero (or no epochs ran), so
        # no mask was captured. Fall back to the mask of the current model
        # instead of failing with an UnboundLocalError at the return.
        with torch.no_grad():
            best_epoch_mask = pruning.get_final_mask_epoch(
                net_gcn,
                adj_percent=args['pruning_percent_adj'],
                wei_percent=args['pruning_percent_wei'])

    return best_epoch_mask, rewind_weight
Exemple #7
0
def main_get_mask(args,
                  imp_num,
                  rewind_weight_mask=None,
                  resume_train_ckpt=None):
    """Train a DeeperGCN node classifier with trainable masks.

    Masks are attached to a fresh model, optionally restored from
    ``rewind_weight_mask``, perturbed with
    ``pruning.add_trainable_mask_noise`` and trained for
    ``args.mask_epochs`` epochs. Each time validation accuracy improves,
    ``pruning.get_final_mask_epoch`` refreshes the rewind state and the
    sparsity figures.

    Args:
        args: Namespace read for ``device``, ``dataset``, ``self_loop``,
            ``num_layers`` and ``mask_epochs``; ``in_channels`` and
            ``num_tasks`` are written onto it. The optimizer learning rate is
            fixed at ``1e-3`` here rather than taken from ``args``.
        imp_num: Pruning-round index, used in log lines.
        rewind_weight_mask: Optional state dict loaded into the model before
            mask noise is added. Overwritten from the checkpoint when
            resuming, otherwise replaced by a deep copy of the model state.
        resume_train_ckpt: Optional checkpoint mapping with the keys
            ``'epoch'``, ``'rewind_weight_mask'``, ``'model_state_dict'`` and
            ``'optimizer_state_dict'``.

    Returns:
        The rewind state as last produced by
        ``pruning.get_final_mask_epoch``, or the initial rewind state if
        validation accuracy never improved.
    """
    device = torch.device("cuda:" + str(args.device))
    dataset = PygNodePropPredDataset(name=args.dataset)
    data = dataset[0]
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(args.dataset)

    x = data.x.to(device)
    y_true = data.y.to(device)
    train_idx = split_idx['train'].to(device)

    edge_index = data.edge_index.to(device)
    edge_index = to_undirected(edge_index, data.num_nodes)

    if args.self_loop:
        edge_index = add_self_loops(edge_index, num_nodes=data.num_nodes)[0]

    args.in_channels = data.x.size(-1)
    args.num_tasks = dataset.num_classes

    print("-" * 120)
    model = DeeperGCN(args).to(device)
    pruning.add_mask(model, args.num_layers)

    if rewind_weight_mask:
        model.load_state_dict(rewind_weight_mask)

    # Always take a sparsity reading up front: the per-epoch and final log
    # lines use adj_spar / wei_spar, which were previously unbound when no
    # rewind state was given and validation accuracy had not yet improved.
    adj_spar, wei_spar = pruning.print_sparsity(model, args)

    pruning.add_trainable_mask_noise(model, args, c=1e-5)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    results = {
        'highest_valid': 0,
        'final_train': 0,
        'final_test': 0,
        'highest_train': 0,
        'epoch': 0
    }

    start_epoch = 1
    if resume_train_ckpt:

        start_epoch = resume_train_ckpt['epoch']
        rewind_weight_mask = resume_train_ckpt['rewind_weight_mask']
        # Only load the checkpoint entries that exist in the current model.
        ori_model_dict = model.state_dict()
        over_lap = {
            k: v
            for k, v in resume_train_ckpt['model_state_dict'].items()
            if k in ori_model_dict.keys()
        }
        ori_model_dict.update(over_lap)
        model.load_state_dict(ori_model_dict)
        print("Resume at IMP:[{}] epoch:[{}] len:[{}/{}]!".format(
            imp_num, resume_train_ckpt['epoch'], len(over_lap.keys()),
            len(ori_model_dict.keys())))
        optimizer.load_state_dict(resume_train_ckpt['optimizer_state_dict'])
        adj_spar, wei_spar = pruning.print_sparsity(model, args)
    else:
        rewind_weight_mask = copy.deepcopy(model.state_dict())

    for epoch in range(start_epoch, args.mask_epochs + 1):

        epoch_loss = train(model, x, edge_index, y_true, train_idx, optimizer,
                           args)
        result = test(model, x, edge_index, y_true, split_idx, evaluator)
        train_accuracy, valid_accuracy, test_accuracy = result

        if valid_accuracy > results['highest_valid']:
            results['highest_valid'] = valid_accuracy
            results['final_train'] = train_accuracy
            results['final_test'] = test_accuracy
            results['epoch'] = epoch
            # No checkpoint is written in this phase; only the rewind state
            # and sparsity figures are refreshed.
            rewind_weight_mask, adj_spar, wei_spar = pruning.get_final_mask_epoch(
                model, rewind_weight_mask, args)

        print(
            time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' | ' +
            'IMP:[{}] (GET Mask) Epoch:[{}/{}]\t LOSS:[{:.4f}] Train :[{:.2f}] Valid:[{:.2f}] Test:[{:.2f}] | Update Test:[{:.2f}] at epoch:[{}] | Adj:[{:.2f}%] Wei:[{:.2f}%]'
            .format(imp_num, epoch, args.mask_epochs, epoch_loss,
                    train_accuracy * 100, valid_accuracy * 100, test_accuracy *
                    100, results['final_test'] *
                    100, results['epoch'], adj_spar, wei_spar))
    print('-' * 100)
    print(
        "INFO : IMP:[{}] (GET MASK) Final Result Train:[{:.2f}]  Valid:[{:.2f}]  Test:[{:.2f}] | Adj:[{:.2f}%] Wei:[{:.2f}%] "
        .format(imp_num, results['final_train'] * 100,
                results['highest_valid'] * 100, results['final_test'] * 100,
                adj_spar, wei_spar))
    print('-' * 100)
    return rewind_weight_mask