Example 1
def test(args,meta_model,optimizer,test_loader,train_epoch,return_val=False,inner_steps=10,seed= 0):
    ''' Meta-Testing.

    Evaluates ``meta_model`` (or one of the non-meta baselines selected via
    ``args``) on every graph batch in ``test_loader`` and logs AUC/AP metrics
    to comet and/or wandb.

    Args:
        args: experiment namespace. Flags read here include
            ``adamic_adar_baseline``, ``deepwalk_baseline``,
            ``deepwalk_and_mlp``, ``random_baseline``, ``comet`` and
            ``wandb``. Side effects: ``args.resplit``, ``args.final_test``
            and (when ``return_val``) ``args.inner_steps`` are mutated.
        meta_model: model exposing ``split_edges``, ``recon_loss``, ``test``
            and ``named_parameters``.
        optimizer: outer-loop optimizer forwarded to ``meta_gradient_step``.
        test_loader: iterable of test graph batches.
        train_epoch: outer training epoch index, used for metric logging.
        return_val: when True, run ``inner_steps`` adaptation steps, record
            per-step metric curves, save the adapted model and return the
            best step's aggregate scores.
        inner_steps: number of inner adaptation steps when ``return_val``.
        seed: random seed forwarded to the DeepWalk baseline.

    Returns:
        ``(max_auc, max_ap)`` when ``return_val`` is True, otherwise ``None``.
    '''
    mode='Test'
    test_graph_id_local = 0
    test_graph_id_global = 0
    args.resplit = False
    epoch=0
    args.final_test = False
    inner_test_auc_array = None
    inner_test_ap_array = None
    if return_val:
        args.inner_steps = inner_steps
        args.final_test = True
        # One row per test graph, one column per 5 inner steps (up to 1000).
        inner_test_auc_array = np.zeros((len(test_loader)*args.test_batch_size, int(1000/5)))
        inner_test_ap_array = np.zeros((len(test_loader)*args.test_batch_size, int(1000/5)))

    meta_loss = torch.Tensor([0])
    test_avg_auc_list, test_avg_ap_list = [], []
    test_inner_avg_auc_list, test_inner_avg_ap_list = [], []
    # Initialized up front: the baseline branches `continue` past global_test(),
    # which previously left these names undefined for the logging code below.
    auc_list, ap_list = [], []
    for j,data in enumerate(test_loader):
        if args.adamic_adar_baseline:
            # Val Ratio is Fixed at 0.1
            meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio
            data = meta_model.split_edges(data[0],val_ratio=args.meta_val_edge_ratio,\
                    test_ratio=meta_test_edge_ratio)
            G_test = create_nx_graph(data)
            auc, ap = calc_adamic_adar_score(G_test,data.test_pos_edge_index,data.test_neg_edge_index)
            test_avg_auc_list.append(auc)
            test_avg_ap_list.append(ap)
            test_graph_id_global += 1
            continue
        if args.deepwalk_baseline or args.deepwalk_and_mlp:
            # Val Ratio is Fixed at 0.2
            meta_test_edge_ratio = 1 - args.meta_val_edge_ratio - args.meta_train_edge_ratio
            data = meta_model.split_edges(data[0], val_ratio=args.meta_val_edge_ratio, \
                                          test_ratio=meta_test_edge_ratio)
            G = create_nx_graph_deepwalk(data)
            node_vectors, entity2index, index2entity = train_deepwalk_model(G,seed=seed)
            if args.deepwalk_and_mlp:
                # Train an MLP on [node features || DeepWalk embedding].
                early_stopping = EarlyStopping(patience=args.patience, verbose=False)
                input_dim = args.num_features + node_vectors.shape[1]
                mlp = MLPEncoder(args, input_dim,
                                 args.num_channels).to(args.dev)
                mlp_optimizer = torch.optim.Adam(mlp.parameters(),
                                                 lr=args.mlp_lr)
                # Reorder the DeepWalk vectors so that row i corresponds to node i.
                all_node_list = list(range(0, len(data.x)))
                node_order = [entity2index[node_i] for node_i in all_node_list]
                node1 = torch.tensor(node_vectors[node_order])
                for mlp_epochs in range(0, args.epochs):
                    mlp_optimizer.zero_grad()
                    node_inp = torch.cat([data.x, node1], dim=1)
                    node_inp = node_inp.to(args.dev)
                    z = mlp(node_inp, edge_index=None)
                    loss = meta_model.recon_loss(z, data.train_pos_edge_index.cuda())
                    loss.backward()
                    mlp_optimizer.step()
                    if mlp_epochs % 10 == 0:
                        if mlp_epochs % 50 == 0:
                            print("Epoch %d, Loss: %f" %(mlp_epochs, loss))
                        with torch.no_grad():
                            val_auc, val_ap = meta_model.test(z, data.val_pos_edge_index,
                                                 data.val_neg_edge_index)
                        # NOTE(review): meta_model (not the mlp being trained)
                        # is handed to EarlyStopping for checkpointing — confirm
                        # this is intended.
                        early_stopping(val_auc, meta_model)
                    if early_stopping.early_stop:
                        print("Early stopping for Graph %d | AUC: %f AP: %f" \
                                %(test_graph_id_global, val_auc, val_ap))
                        break

                # Bug fix: evaluate on embeddings recomputed with the *final*
                # MLP weights. Previously the recomputed tensor was discarded
                # and the stale pre-optimizer-step `z` was tested instead.
                node_inp = torch.cat([data.x, node1], dim=1)
                node_inp = node_inp.to(args.dev)
                with torch.no_grad():
                    node_vectors = mlp(node_inp, edge_index=None)
                auc, ap = meta_model.test(node_vectors, data.test_pos_edge_index,
                                     data.test_neg_edge_index)
            else:
                node_vectors = node_vectors.detach().cpu().numpy()
                auc, ap = calc_deepwalk_score(data.test_pos_edge_index,
                                              data.test_neg_edge_index,
                                              node_vectors,entity2index)

            print("Graph %d| Test AUC: %f AP: %f" %(test_graph_id_global, auc, ap))
            test_avg_auc_list.append(auc)
            test_avg_ap_list.append(ap)
            test_graph_id_global += 1
            continue

        if not args.random_baseline and not args.adamic_adar_baseline:
            # Inner-loop adaptation on this test graph's support edges.
            test_graph_id_local, meta_loss, test_inner_avg_auc_list, test_inner_avg_ap_list = meta_gradient_step(meta_model,\
                    args,data,optimizer,args.inner_steps,args.inner_lr,args.order,test_graph_id_local,mode,\
                    test_inner_avg_auc_list, test_inner_avg_ap_list,epoch,j,False,\
                            inner_test_auc_array,inner_test_ap_array)
        auc_list, ap_list = global_test(args,meta_model,data,OrderedDict(meta_model.named_parameters()))
        test_avg_auc_list.append(sum(auc_list)/len(auc_list))
        test_avg_ap_list.append(sum(ap_list)/len(ap_list))

        ''' Test Logging '''
        if args.comet:
            if len(auc_list) > 0 and len(ap_list) > 0:
                auc_metric = 'Test_Outer_Batch_Graph_' + str(j) +'_AUC'
                ap_metric = 'Test_Outer_Batch_Graph_' + str(j) +'_AP'
                args.experiment.log_metric(auc_metric,sum(auc_list)/len(auc_list),step=train_epoch)
                args.experiment.log_metric(ap_metric,sum(ap_list)/len(ap_list),step=train_epoch)
        if args.wandb:
            if len(auc_list) > 0 and len(ap_list) > 0:
                auc_metric = 'Test_Outer_Batch_Graph_' + str(j) +'_AUC'
                ap_metric = 'Test_Outer_Batch_Graph_' + str(j) +'_AP'
                wandb.log({auc_metric:sum(auc_list)/len(auc_list),\
                        ap_metric:sum(ap_list)/len(ap_list),"x":epoch},commit=False)

    print("Failed on %d graphs" %(args.fail_counter))
    if len(test_avg_auc_list) > 0 and len(test_avg_ap_list) > 0:
        print("Epoch: %d | Test Global Avg Auc %f | Test Global Avg AP %f" \
                %(train_epoch, sum(test_avg_auc_list)/len(test_avg_auc_list),\
                        sum(test_avg_ap_list)/len(test_avg_ap_list)))
    if args.comet:
        # Bug fix: guard on the aggregate lists (the per-batch auc_list/ap_list
        # stay empty on baseline runs) and guard the inner averages separately,
        # mirroring the wandb branch below.
        if len(test_avg_auc_list) > 0 and len(test_avg_ap_list) > 0:
            auc_metric = 'Test_Avg_' +'_AUC'
            ap_metric = 'Test_Avg_' +'_AP'
            args.experiment.log_metric(auc_metric,sum(test_avg_auc_list)/len(test_avg_auc_list),step=train_epoch)
            args.experiment.log_metric(ap_metric,sum(test_avg_ap_list)/len(test_avg_ap_list),step=train_epoch)
        if len(test_inner_avg_auc_list) > 0 and len(test_inner_avg_ap_list) > 0:
            inner_auc_metric = 'Test_Inner_Avg' +'_AUC'
            inner_ap_metric = 'Test_Inner_Avg' +'_AP'
            args.experiment.log_metric(inner_auc_metric,sum(test_inner_avg_auc_list)/len(test_inner_avg_auc_list),step=train_epoch)
            args.experiment.log_metric(inner_ap_metric,sum(test_inner_avg_ap_list)/len(test_inner_avg_ap_list),step=train_epoch)
    if args.wandb:
        if len(test_avg_auc_list) > 0 and len(test_avg_ap_list) > 0:
            auc_metric = 'Test_Avg' +'_AUC'
            ap_metric = 'Test_Avg' +'_AP'
            wandb.log({auc_metric:sum(test_avg_auc_list)/len(test_avg_auc_list),\
                    ap_metric:sum(test_avg_ap_list)/len(test_avg_ap_list),\
                    "x":train_epoch},commit=False)
        if len(test_inner_avg_auc_list) > 0 and len(test_inner_avg_ap_list) > 0:
            inner_auc_metric = 'Test_Inner_Avg' +'_AUC'
            inner_ap_metric = 'Test_Inner_Avg' +'_AP'
            wandb.log({inner_auc_metric:sum(test_inner_avg_auc_list)/len(test_inner_avg_auc_list),
                    inner_ap_metric:sum(test_inner_avg_ap_list)/len(test_inner_avg_ap_list),
                    "x":train_epoch},commit=False)
    if len(test_inner_avg_ap_list) > 0:
        print('Epoch {:01d} | Test Inner AUC: {:.4f}, AP: {:.4f}'.format(train_epoch,sum(test_inner_avg_auc_list)/len(test_inner_avg_auc_list),sum(test_inner_avg_ap_list)/len(test_inner_avg_ap_list)))

    if return_val:
        # Remove all-zero rows (graphs that never recorded an inner step).
        test_auc_array = inner_test_auc_array[~np.all(inner_test_auc_array == 0, axis=1)]
        test_ap_array = inner_test_ap_array[~np.all(inner_test_ap_array == 0, axis=1)]
        # Average per-step curves over graphs, then report the best step.
        test_aggr_auc = np.sum(test_auc_array,axis=0)/len(test_loader)
        test_aggr_ap = np.sum(test_ap_array,axis=0)/len(test_loader)
        max_auc = np.max(test_aggr_auc)
        max_ap = np.max(test_aggr_ap)
        auc_metric = 'Test_Complete' +'_AUC'
        ap_metric = 'Test_Complete' +'_AP'
        for val_idx in range(0,test_auc_array.shape[1]):
            auc = test_aggr_auc[val_idx]
            ap = test_aggr_ap[val_idx]
            if args.comet:
                args.experiment.log_metric(auc_metric,auc,step=val_idx)
                args.experiment.log_metric(ap_metric,ap,step=val_idx)
            if args.wandb:
                wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx})
        print("Test Max AUC :%f | Test Max AP: %f" %(max_auc,max_ap))

        ''' Save Local final params '''
        if not os.path.exists('../saved_models/'):
            os.makedirs('../saved_models/')
        save_path = '../saved_models/' + args.namestr + '_local.pt'
        torch.save(meta_model.state_dict(), save_path)
        return max_auc, max_ap
Example 2
def main(args):
    '''Meta-train, validate and test a (V)GAE meta-model for link prediction.

    Builds the model/optimizer from ``args``. If a baseline flag (random,
    Adamic-Adar, DeepWalk) or ``args.wl`` is set, runs that path and exits.
    Otherwise meta-trains for ``args.epochs`` epochs with per-epoch
    validation/testing on copies of the model, saves a checkpoint, then runs
    both splits to convergence (1000 inner steps).

    Args:
        args: experiment namespace (model/encoder choice, loaders config,
            learning rates, logging flags, ...). Mutated as a side effect
            (``kl_anneal``, ``allow_unused``, ``resplit``, ``inner_lr``, ...).

    Returns:
        The validation selection metric ``0.5*AUC + 0.5*AP`` from the
        convergence run.
    '''
    assert args.model in ['GAE', 'VGAE']
    kwargs = {'GAE': MyGAE, 'VGAE': MyVGAE}
    kwargs_enc = {'GCN': MetaEncoder, 'FC': MLPEncoder, 'MLP': MetaMLPEncoder,
                  'GraphSignature': MetaSignatureEncoder,
                  'GatedGraphSignature': MetaGatedSignatureEncoder}

    train_loader, val_loader, test_loader = load_dataset(args.dataset,args)
    meta_model = kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev)
    if args.train_only_gs:
        # Freeze everything except the GraphSignature parameters.
        trainable_parameters = []
        for name, p in meta_model.named_parameters():
            if "signature" in name:
                trainable_parameters.append(p)
            else:
                p.requires_grad = False
        optimizer = torch.optim.Adam(trainable_parameters, lr=args.meta_lr)
    else:
        optimizer = torch.optim.Adam(meta_model.parameters(), lr=args.meta_lr)

    total_loss = 0
    if not args.do_kl_anneal:
        args.kl_anneal = 1

    # Signature encoders may have parameters that receive no gradient in the
    # inner loop, so autograd must be told to allow unused inputs.
    if args.encoder == 'GraphSignature' or args.encoder == 'GatedGraphSignature':
        args.allow_unused = True
    else:
        args.allow_unused = False

    ''' Random or Adamic Adar Baseline '''
    if args.random_baseline or args.adamic_adar_baseline or args.deepwalk_baseline:
        test_inner_avg_auc, test_inner_avg_ap = test(args,meta_model,optimizer,test_loader,0,\
                return_val=True,inner_steps=1000,seed=args.seed)
        sys.exit()

    ''' Run WL-Kernel '''
    if args.wl:
        load_path = '../saved_models/' + args.namestr + '.pt'
        meta_model.load_state_dict(torch.load(load_path))
        run_analysis(args, meta_model,train_loader)
        test(args,meta_model,optimizer,test_loader,0)
        sys.exit()

    ''' Meta-training '''
    mode = 'Train'
    meta_loss = torch.Tensor([0])
    args.final_test = False
    epoch = 0  # keeps `epoch` defined after the loop even if args.epochs == 0
    for epoch in range(0,args.epochs):
        graph_id_local = 0
        graph_id_global = 0
        train_inner_avg_auc_list, train_inner_avg_ap_list = [], []
        if epoch > 0 and args.dataset !='PPI':
            args.resplit = False
        for i,data in enumerate(train_loader):
            if args.debug:
                ''' Print the Computation Graph '''
                # Bug fix: this call previously referenced the undefined
                # test_inner_avg_* lists (NameError); use the train lists.
                dot = make_dot(meta_gradient_step(meta_model,args,data,optimizer,args.inner_steps,args.inner_lr,\
                        args.order,graph_id_local,mode,train_inner_avg_auc_list, train_inner_avg_ap_list, \
                        epoch,i,True)[1],params=dict(meta_model.named_parameters()))
                dot.format = 'png'
                dot.render(args.debug_name)
                quit()

            graph_id_local, meta_loss, train_inner_avg_auc_list, train_inner_avg_ap_list = meta_gradient_step(meta_model,\
                    args,data,optimizer,args.inner_steps,args.inner_lr,args.order,graph_id_local,\
                    mode,train_inner_avg_auc_list, train_inner_avg_ap_list,epoch,i,True)
            if args.do_kl_anneal:
                args.kl_anneal = args.kl_anneal + 1/args.epochs

            auc_list, ap_list = global_test(args,meta_model,data,OrderedDict(meta_model.named_parameters()))
            if args.comet:
                if len(ap_list) > 0:
                    auc_metric = 'Train_Global_Batch_Graph_' + str(i) +'_AUC'
                    ap_metric = 'Train_Global_Batch_Graph_' + str(i) +'_AP'
                    args.experiment.log_metric(auc_metric,sum(auc_list)/len(auc_list),step=epoch)
                    args.experiment.log_metric(ap_metric,sum(ap_list)/len(ap_list),step=epoch)
            if args.wandb:
                if len(ap_list) > 0:
                    auc_metric = 'Train_Global_Batch_Graph_' + str(i) +'_AUC'
                    ap_metric = 'Train_Global_Batch_Graph_' + str(i) +'_AP'
                    wandb.log({auc_metric:sum(auc_list)/len(auc_list),\
                            ap_metric:sum(ap_list)/len(ap_list),"x":epoch},commit=False)
            graph_id_global += len(ap_list)

            if args.wandb:
                # Bug fix: wandb.log() with no arguments raises TypeError.
                # An empty dict commits the metrics buffered with commit=False.
                wandb.log({})

        if args.comet:
            if len(train_inner_avg_ap_list) > 0:
                auc_metric = 'Train_Inner_Avg' +'_AUC'
                # Consistency fix: drop the stray batch index from the AP
                # metric name so it matches its AUC counterpart.
                ap_metric = 'Train_Inner_Avg' +'_AP'
                args.experiment.log_metric(auc_metric,sum(train_inner_avg_auc_list)/len(train_inner_avg_auc_list),step=epoch)
                args.experiment.log_metric(ap_metric,sum(train_inner_avg_ap_list)/len(train_inner_avg_ap_list),step=epoch)
        if args.wandb:
            if len(train_inner_avg_ap_list) > 0:
                auc_metric = 'Train_Inner_Avg' +'_AUC'
                ap_metric = 'Train_Inner_Avg' +'_AP'
                wandb.log({auc_metric:sum(train_inner_avg_auc_list)/len(train_inner_avg_auc_list),\
                        ap_metric:sum(train_inner_avg_ap_list)/len(train_inner_avg_ap_list),\
                        "x":epoch},commit=False)

        if len(train_inner_avg_ap_list) > 0:
            print('Train Inner AUC: {:.4f}, AP: {:.4f}'.format(sum(train_inner_avg_auc_list)/len(train_inner_avg_auc_list),\
                            sum(train_inner_avg_ap_list)/len(train_inner_avg_ap_list)))

        ''' Meta-Testing After every Epoch'''
        # Evaluate on copies so fine-tuning during val/test never leaks back
        # into the meta-parameters being trained.
        meta_model_copy = kwargs[args.model](kwargs_enc[args.encoder](args, args.num_features, args.num_channels)).to(args.dev)
        meta_model_copy.load_state_dict(meta_model.state_dict())
        if args.train_only_gs:
            optimizer_copy = torch.optim.Adam(trainable_parameters, lr=args.meta_lr)
        else:
            optimizer_copy = torch.optim.Adam(meta_model_copy.parameters(), lr=args.meta_lr)
        optimizer_copy.load_state_dict(optimizer.state_dict())
        validation(args,meta_model_copy,optimizer_copy,val_loader,epoch)
        test(args,meta_model_copy,optimizer_copy,test_loader,epoch,inner_steps=args.inner_steps)

    print("Failed on %d Training graphs" %(args.fail_counter))

    ''' Save Global Params '''
    if not os.path.exists('../saved_models/'):
        os.makedirs('../saved_models/')
    save_path = '../saved_models/' + args.namestr + '_global_.pt'
    torch.save(meta_model.state_dict(), save_path)

    ''' Run to Convergence '''
    if args.ego:
        optimizer = torch.optim.Adam(meta_model.parameters(), lr=args.meta_lr)
        args.inner_lr = args.inner_lr * args.reset_inner_factor
    val_inner_avg_auc, val_inner_avg_ap = test(args,meta_model,optimizer,val_loader,epoch,\
            return_val=True,inner_steps=1000)
    if args.ego:
        optimizer = torch.optim.Adam(meta_model.parameters(), lr=args.meta_lr)
        args.inner_lr = args.inner_lr * args.reset_inner_factor
    test_inner_avg_auc, test_inner_avg_ap = test(args,meta_model,optimizer,test_loader,epoch,\
            return_val=True,inner_steps=1000)
    if args.comet:
        args.experiment.end()

    val_eval_metric = 0.5*val_inner_avg_auc + 0.5*val_inner_avg_ap
    test_eval_metric = 0.5*test_inner_avg_auc + 0.5*test_inner_avg_ap
    return val_eval_metric
Example 3
def validation(args,meta_model,optimizer,val_loader,train_epoch,return_val=False,inner_steps=10):
    ''' Meta-Validation.

    Mirror of ``test`` for the validation split: adapts ``meta_model`` on each
    validation graph's support edges, evaluates, and logs AUC/AP.

    Args:
        args: experiment namespace; ``args.resplit``, ``args.final_val`` and
            (when ``return_val``) ``args.inner_steps`` are mutated.
        meta_model: model exposing ``named_parameters`` etc.
        optimizer: optimizer forwarded to ``meta_gradient_step``.
        val_loader: iterable of validation graph batches.
        train_epoch: outer training epoch index, used for metric logging.
        return_val: when True, record per-step metric curves and return the
            best step's aggregate scores.
        inner_steps: inner adaptation steps when ``return_val`` is True.
            (Bug fix: this was previously read as an undefined free name,
            raising NameError whenever ``return_val`` was set.)

    Returns:
        ``(max_auc, max_ap)`` when ``return_val`` is True, otherwise ``None``.
    '''
    mode='Val'
    val_graph_id_local = 0
    val_graph_id_global = 0
    args.resplit = True
    epoch=0
    meta_loss = torch.Tensor([0])
    val_avg_auc_list, val_avg_ap_list = [], []
    val_inner_avg_auc_list, val_inner_avg_ap_list = [], []
    args.final_val = False
    inner_val_auc_array = None
    inner_val_ap_array = None
    if return_val:
        args.inner_steps = inner_steps
        args.final_val = True
        # One row per val graph, one column per 5 inner steps (up to 1000).
        inner_val_auc_array = np.zeros((len(val_loader)*args.val_batch_size, int(1000/5)))
        inner_val_ap_array = np.zeros((len(val_loader)*args.val_batch_size, int(1000/5)))
    for j,data in enumerate(val_loader):
        if not args.random_baseline:
            # Inner-loop adaptation on this validation graph.
            val_graph_id_local, meta_loss, val_inner_avg_auc_list, val_inner_avg_ap_list = meta_gradient_step(meta_model,\
                    args,data,optimizer,args.inner_steps,args.inner_lr,args.order,val_graph_id_local,mode,\
                    val_inner_avg_auc_list,val_inner_avg_ap_list,epoch,j,False,\
                            inner_val_auc_array,inner_val_ap_array)
        auc_list, ap_list = global_test(args,meta_model,data,OrderedDict(meta_model.named_parameters()))
        val_avg_auc_list.append(sum(auc_list)/len(auc_list))
        val_avg_ap_list.append(sum(ap_list)/len(ap_list))
        if args.comet:
            if len(auc_list) > 0 and len(ap_list) > 0:
                auc_metric = 'Val_Batch_Graph_' + str(j) +'_AUC'
                ap_metric = 'Val_Batch_Graph_' + str(j) +'_AP'
                args.experiment.log_metric(auc_metric,sum(auc_list)/len(auc_list),step=train_epoch)
                args.experiment.log_metric(ap_metric,sum(ap_list)/len(ap_list),step=train_epoch)
        if args.wandb:
            if len(auc_list) > 0 and len(ap_list) > 0:
                auc_metric = 'Val_Batch_Graph_' + str(j) +'_AUC'
                ap_metric = 'Val_Batch_Graph_' + str(j) +'_AP'
                wandb.log({auc_metric:sum(auc_list)/len(auc_list),\
                        ap_metric:sum(ap_list)/len(ap_list),"x":epoch},commit=False)

    if len(val_avg_auc_list) > 0 and len(val_avg_ap_list) > 0:
        print("Val Avg Auc %f | Val Avg AP %f" %(sum(val_avg_auc_list)/len(val_avg_auc_list),\
                sum(val_avg_ap_list)/len(val_avg_ap_list)))
    if len(val_inner_avg_ap_list) > 0:
        print("Val Inner Avg Auc %f | Val Avg AP %f" %(sum(val_inner_avg_auc_list)/len(val_inner_avg_auc_list),\
                sum(val_inner_avg_ap_list)/len(val_inner_avg_ap_list)))
    if args.comet:
        # Bug fix: the inner averages are guarded separately — with
        # random_baseline the inner lists stay empty and the unguarded
        # division previously raised ZeroDivisionError.
        if len(val_avg_auc_list) > 0 and len(val_avg_ap_list) > 0:
            auc_metric = 'Val_Avg_' +'_AUC'
            ap_metric = 'Val_Avg_' +'_AP'
            args.experiment.log_metric(auc_metric,sum(val_avg_auc_list)/len(val_avg_auc_list),step=train_epoch)
            args.experiment.log_metric(ap_metric,sum(val_avg_ap_list)/len(val_avg_ap_list),step=train_epoch)
        if len(val_inner_avg_auc_list) > 0 and len(val_inner_avg_ap_list) > 0:
            inner_auc_metric = 'Val_Inner_Avg' +'_AUC'
            inner_ap_metric = 'Val_Inner_Avg' +'_AP'
            args.experiment.log_metric(inner_auc_metric,sum(val_inner_avg_auc_list)/len(val_inner_avg_auc_list),step=train_epoch)
            args.experiment.log_metric(inner_ap_metric,sum(val_inner_avg_ap_list)/len(val_inner_avg_ap_list),step=train_epoch)
    if args.wandb:
        if len(val_avg_auc_list) > 0 and len(val_avg_ap_list) > 0:
            auc_metric = 'Val_Avg' +'_AUC'
            ap_metric = 'Val_Avg' +'_AP'
            wandb.log({auc_metric:sum(val_avg_auc_list)/len(val_avg_auc_list),\
                    ap_metric:sum(val_avg_ap_list)/len(val_avg_ap_list),\
                    "x":epoch},commit=False)
        if len(val_inner_avg_auc_list) > 0 and len(val_inner_avg_ap_list) > 0:
            inner_auc_metric = 'Val_Inner_Avg' +'_AUC'
            inner_ap_metric = 'Val_Inner_Avg' +'_AP'
            wandb.log({inner_auc_metric:sum(val_inner_avg_auc_list)/len(val_inner_avg_auc_list),\
                    inner_ap_metric:sum(val_inner_avg_ap_list)/len(val_inner_avg_ap_list),\
                    "x":epoch},commit=False)

    if return_val:
        # Remove all-zero rows (graphs that never recorded an inner step).
        val_auc_array = inner_val_auc_array[~np.all(inner_val_auc_array == 0, axis=1)]
        val_ap_array = inner_val_ap_array[~np.all(inner_val_ap_array == 0, axis=1)]

        # Average per-step curves over graphs, then report the best step.
        val_aggr_auc = np.sum(val_auc_array,axis=0)/len(val_loader)
        val_aggr_ap = np.sum(val_ap_array,axis=0)/len(val_loader)
        max_auc = np.max(val_aggr_auc)
        max_ap = np.max(val_aggr_ap)
        auc_metric = 'Val_Complete' +'_AUC'
        ap_metric = 'Val_Complete' +'_AP'
        for val_idx in range(0,val_auc_array.shape[1]):
            auc = val_aggr_auc[val_idx]
            ap = val_aggr_ap[val_idx]
            if args.comet:
                args.experiment.log_metric(auc_metric,auc,step=val_idx)
                args.experiment.log_metric(ap_metric,ap,step=val_idx)
            if args.wandb:
                wandb.log({auc_metric:auc,ap_metric:ap,"x":val_idx})
        print("Val Max AUC :%f | Val Max AP: %f" %(max_auc,max_ap))
        return max_auc, max_ap