Example #1
    total_f1_list = []
    total_loss_list = []
    best_f1_list = []

    total_runs = args.total_runs

    for l in range(total_runs):
        model = GTN(num_edge=A.shape[-1],
                    num_channels=num_channels,
                    w_in=node_features.shape[1],
                    w_out=node_dim,
                    num_class=num_classes,
                    num_layers=num_layers,
                    norm=norm)
        if adaptive_lr == 'false':
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=0.005,
                                         weight_decay=0.001)
        else:
            # higher learning rate for the graph-transformer layers themselves
            optimizer = torch.optim.Adam([{'params': model.weight},
                                          {'params': model.linear1.parameters()},
                                          {'params': model.linear2.parameters()},
                                          {'params': model.layers.parameters(), 'lr': 0.5}],
                                         lr=0.005,
                                         weight_decay=0.001)
Example #2
 
node_features = torch.from_numpy(node_features).type(torch.FloatTensor)
train_node = torch.from_numpy(np.array(labels[0])[:, 0]).type(torch.LongTensor)
train_target = torch.from_numpy(np.array(labels[0])[:, 1]).type(torch.LongTensor)
valid_node = torch.from_numpy(np.array(labels[1])[:, 0]).type(torch.LongTensor)
valid_target = torch.from_numpy(np.array(labels[1])[:, 1]).type(torch.LongTensor)
test_node = torch.from_numpy(np.array(labels[2])[:, 0]).type(torch.LongTensor)
test_target = torch.from_numpy(np.array(labels[2])[:, 1]).type(torch.LongTensor)

num_classes = torch.max(train_target).item() + 1
final_f1 = 0
for l in range(1):
    model = GTN(num_edge=A.shape[-1],
                num_channels=num_channels,
                w_in=node_features.shape[1],
                w_out=node_dim,
                num_class=num_classes,
                num_layers=num_layers,
                norm=norm)
    if adaptive_lr == 'false':
        optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
    else:
        optimizer = torch.optim.Adam([{'params': model.weight},
                                      {'params': model.linear1.parameters()},
                                      {'params': model.linear2.parameters()},
                                      {'params': model.layers.parameters(), 'lr': 0.5}],
                                     lr=0.005, weight_decay=0.001)
    loss = nn.CrossEntropyLoss()
    # Train & Valid & Test
    best_val_loss = 10000
    best_test_loss = 10000
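    # NOTE: the original example is truncated here. What follows is a minimal
    # sketch of the epoch loop these trackers feed -- an assumption, not the
    # original code. It relies on model.forward(A, X, nodes, targets)
    # returning (loss, logits, W), the same signature test() uses below.
    for epoch in range(50):  # epoch count is illustrative
        model.train()
        model.zero_grad()
        train_loss, y_train, W = model.forward(A, node_features, train_node,
                                               train_target)
        train_loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_loss, y_valid, _ = model.forward(A, node_features, valid_node,
                                                 valid_target)
            test_loss, y_test, _ = model.forward(A, node_features, test_node,
                                                 test_target)
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_test_loss = test_loss.item()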
# Imports needed by the functions below; GTN, PopulationWeightedFusion,
# f1_score, minmax_sc, arg_parse, train and graph_data_loader are assumed
# to come from project-local modules.
import pickle
import time

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns


def test(args):
    node_dim = args.node_dim
    num_channels = args.num_channels
    num_layers = args.num_layers
    norm = args.norm

    if args.ogb_arxiv:
        # The OGB arxiv path is not implemented; the rest of test() expects
        # the pickle-based data loaded below.
        raise NotImplementedError("OGB arxiv loading is not supported in test()")
    else:
        with open('data/' + args.dataset + '/node_features.pkl', 'rb') as f:
            node_features = pickle.load(f)
        with open('data/' + args.dataset + '/edges.pkl', 'rb') as f:
            edges = pickle.load(f)
        with open('data/' + args.dataset + '/labels.pkl', 'rb') as f:
            labels = pickle.load(f)

    num_nodes = edges[0].shape[0]

    # Build A: one dense adjacency per edge type, plus an identity channel.
    edge_tensors = [
        torch.from_numpy(edge.todense()).type(torch.FloatTensor)
        for edge in edges
    ]
    edge_tensors.append(torch.eye(num_nodes).type(torch.FloatTensor))
    A = torch.stack(edge_tensors, dim=-1)
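    # Sanity check (an added assertion, not in the original): the channel
    # dimension holds one slice per edge type plus the identity.
    assert A.shape == (num_nodes, num_nodes, len(edges) + 1)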

    node_features = torch.from_numpy(node_features).type(torch.FloatTensor)
    test_node = torch.from_numpy(np.array(labels[2])[:, 0]).type(torch.LongTensor)
    test_target = torch.from_numpy(np.array(labels[2])[:, 1]).type(torch.LongTensor)

    # num_classes was never defined in this function; derive it from the test
    # labels here (the training script derives it from train_target).
    num_classes = torch.max(test_target).item() + 1

    model = GTN(num_edge=A.shape[-1],
                num_channels=num_channels,
                w_in=node_features.shape[1],
                w_out=node_dim,
                num_class=num_classes,
                num_layers=num_layers,
                norm=norm)

    print('loading pretrained model from %s' % args.saved_model)
    model.load_state_dict(torch.load(args.saved_model))

    dur_test = time.time()
    model.eval()
    with torch.no_grad():
        test_loss, y_test, W = model.forward(A, node_features, test_node,
                                             test_target)
        test_f1 = torch.mean(
            f1_score(torch.argmax(y_test, dim=1),
                     test_target,
                     num_classes=num_classes)).cpu().numpy()
    dur_test = (time.time() - dur_test) / 60.0

    print('Test - Loss: {}, Macro_F1: {}, Duration: {} min'.format(
        test_loss.detach().cpu().numpy(), test_f1, dur_test))
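
# The f1_score used in test() is not sklearn's: it takes num_classes and
# returns a per-class tensor whose mean is the macro F1. A hypothetical
# stand-in with that contract (the real helper lives in a project module):
def f1_score_sketch(pred, target, num_classes):
    """Per-class F1 as a FloatTensor; its mean gives the macro F1."""
    f1s = torch.zeros(num_classes)
    for c in range(num_classes):
        tp = ((pred == c) & (target == c)).sum().float()
        fp = ((pred == c) & (target != c)).sum().float()
        fn = ((pred != c) & (target == c)).sum().float()
        denom = 2 * tp + fp + fn
        f1s[c] = 2 * tp / denom if denom > 0 else 0.0
    return f1s


# minmax_sc in main() is likewise assumed to be per-view min-max scaling of a
# 2-D matrix into [0, 1]; a hypothetical version:
def minmax_sc_sketch(x):
    rng = x.max() - x.min()
    return (x - x.min()) / rng if rng > 0 else np.zeros_like(x)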
def main():
    args = arg_parse()
    print("Main : ", args)

    # Load the full edge pickle once instead of re-reading it on every
    # iteration.
    with open('LH_edges', 'rb') as f:
        all_graphs = pickle.load(f)

    # arr2 and mean were used below but never initialised; hold one entry
    # per subject.
    arr2 = [None] * 77
    mean = [None] * 77

    for i in range(77):
        graphs_list = np.asarray(all_graphs[i])
        graphs_list = np.reshape(graphs_list, (1, 35, 35, 4))

        for subject in range(len(graphs_list)):
            for view in range(graphs_list[0].shape[2]):
                graphs_list[subject][:, :, view] = minmax_sc(
                    graphs_list[subject][:, :, view])

        print(graphs_list.shape)

        graphs_stacked = np.stack(graphs_list, axis=0)
        graphs_torch = torch.from_numpy(graphs_stacked)
        print(graphs_torch.shape)

        edges_number = graphs_list[0].shape[-1]

        graph_dataloader = graph_data_loader.GraphDataLoader(graphs_list)

        model_GTN = GTN(num_edge=edges_number,
                        num_channels=2,
                        num_layers=args.num_layers,
                        norm=True)
        model_PopulationFusion = PopulationWeightedFusion(
            num_subjects=len(graphs_list))
        H_population, Epoch_losses, GTN_losses, Fusion_losses = train(
            args, graph_dataloader, graphs_torch, model_GTN,
            model_PopulationFusion)
        arr2[i] = H_population
        mean[i] = arr2[i].mean()
        print(arr2[i])
        # save epoch losses and fusion output tensor
        with open('H_population{}'.format(i), 'wb') as f:
            pickle.dump(H_population, f)
        with open('epoch_losses{}'.format(i), 'wb') as f:
            pickle.dump(Epoch_losses, f)

        # plot loss evolution across epochs; open a fresh figure per subject
        # so curves from earlier iterations do not accumulate
        x_epochs = list(range(args.num_epochs))
        plt.figure()
        plt1, = plt.plot(x_epochs, Epoch_losses, label='Total loss')
        plt2, = plt.plot(x_epochs, GTN_losses, label='Transformer loss')
        plt3, = plt.plot(x_epochs, Fusion_losses, label='Fusion loss')
        plt.legend(handles=[plt1, plt2, plt3])
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.savefig("H{}".format(i))
        plt.close()

        # plot fusion template tensor as a lower-triangular heatmap
        mask_adj = np.zeros_like(H_population)
        mask_adj[np.triu_indices_from(mask_adj)] = True

        with sns.axes_style("white"):
            f, ax = plt.subplots(figsize=(30, 30))
            ax = sns.heatmap(H_population,
                             mask=mask_adj,
                             square=True,
                             annot=True)
            # the original never saved or closed this figure; the filename
            # here is an assumption
            f.savefig("H_population_heatmap{}".format(i))
            plt.close(f)
        Epoch_losses = 0
        GTN_losses = 0
        Fusion_losses = 0
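
# Standard entry-point guard so the module can be imported without running
# the full 77-subject loop.
if __name__ == '__main__':
    main()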