Example #1
def main(args):
    if args.randomseed > 0:
        torch.manual_seed(args.randomseed)

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)
    os.makedirs(log_path, exist_ok=True)
    load_func, subset = args.dataset.split('/')[:2]
    if load_func == 'WebKB':
        load_func = WebKB
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikiCS':
        load_func = WikiCS
        dataset = load_func(root=args.data_path)
    elif load_func == 'cora_ml':
        dataset = citation_datasets(
            root='../dataset/data/tmp/cora_ml/cora_ml.npz')
    elif load_func == 'citeseer_npz':
        dataset = citation_datasets(
            root='../dataset/data/tmp/citeseer_npz/citeseer_npz.npz')
    else:
        dataset = load_syn(args.data_path + args.dataset, None)

    data = dataset[0]
    if 'edge_weight' not in data:
        data.edge_weight = None
    data.y = data.y.long()
    num_classes = int(data.y.max() - data.y.min() + 1)
    data = data.to(device)
    if data.edge_weight is not None:
        data.edge_weight = torch.FloatTensor(data.edge_weight).to(device)
    if args.to_undirected:
        data.edge_index, data.edge_weight = to_undirected(
            data.edge_index, data.edge_weight)
    # normalize labels so the smallest class index is 0
    data.y = data.y - data.y.min()
    splits = data.train_mask.shape[1]
    if len(data.test_mask.shape) == 1:
        data.test_mask = data.test_mask.unsqueeze(1).repeat(1, splits)

    results = np.zeros((splits, 4))
    for split in range(splits):
        log_str_full = ''
        graphmodel = APPNP_Model(data.x.size(-1),
                                 num_classes,
                                 filter_num=args.num_filter,
                                 alpha=args.alpha,
                                 dropout=args.dropout,
                                 layer=args.layer).to(device)
        model = graphmodel  # nn.DataParallel(graphmodel)
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        #################################
        # Train/Validation/Test
        #################################
        best_val_err = float('inf')
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            ####################
            # Train
            ####################
            model.train()
            out = model(data)

            train_loss = F.nll_loss(out[data.train_mask[:, split]],
                                    data.y[data.train_mask[:, split]])
            pred_label = out.max(dim=1)[1]

            train_acc = acc(pred_label, data.y, data.train_mask[:, split])

            opt.zero_grad()
            train_loss.backward()
            opt.step()

            outstrtrain = 'Train loss:, %.6f, acc:, %.3f,' % (
                train_loss.detach().item(), train_acc)
            #scheduler.step()

            ####################
            # Validation
            ####################
            model.eval()
            out = model(data)
            pred_label = out.max(dim=1)[1]

            val_loss = F.nll_loss(out[data.val_mask[:, split]],
                                  data.y[data.val_mask[:, split]])
            val_acc = acc(pred_label, data.y, data.val_mask[:, split])

            outstrval = ' Val loss:, %.6f, acc:, %.3f,' % (
                val_loss.detach().item(), val_acc)

            duration = "---, %.4f, seconds ---" % (time.time() - start_time)
            log_str = (
                "%d, / ,%d, epoch," %
                (epoch, args.epochs)) + outstrtrain + outstrval + duration
            log_str_full += log_str + '\n'
            #print(log_str)

            ####################
            # Save weights
            ####################
            save_perform = val_loss.detach().item()
            if save_perform <= best_val_err:
                early_stopping = 0
                best_val_err = save_perform
                torch.save(model.state_dict(),
                           log_path + '/model' + str(split) + '.t7')
            else:
                early_stopping += 1
            if early_stopping > 500 or epoch == (args.epochs - 1):
                torch.save(model.state_dict(),
                           log_path + '/model_latest' + str(split) + '.t7')
                break

        write_log(vars(args), log_path)

        ####################
        # Testing
        ####################
        model.load_state_dict(
            torch.load(log_path + '/model' + str(split) + '.t7'))
        model.eval()
        preds = model(data)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred' + str(split), pred_label.cpu().numpy())

        acc_val = acc(pred_label, data.y, data.val_mask[:, split])
        acc_test = acc(pred_label, data.y, data.test_mask[:, split])

        model.load_state_dict(
            torch.load(log_path + '/model_latest' + str(split) + '.t7'))
        model.eval()
        preds = model(data)
        pred_label = preds.max(dim=1)[1]

        np.save(log_path + '/pred_latest' + str(split), pred_label.cpu().numpy())

        acc_val_latest = acc(pred_label, data.y, data.val_mask[:, split])
        acc_test_latest = acc(pred_label, data.y, data.test_mask[:, split])

        ####################
        # Save testing results
        ####################
        logstr = ('val_acc: ' + str(np.round(acc_val, 3)) +
                  ' test_acc: ' + str(np.round(acc_test, 3)) +
                  ' val_acc_latest: ' + str(np.round(acc_val_latest, 3)) +
                  ' test_acc_latest: ' + str(np.round(acc_test_latest, 3)))
        print(logstr)
        results[split] = [acc_val, acc_test, acc_val_latest, acc_test_latest]
        log_str_full += logstr
        with open(log_path + '/log' + str(split) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
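
Example #1 (and Example #2 below) relies on an `acc` helper imported from elsewhere in the repository; it is called as `acc(pred_label, data.y, mask)` here and as `acc(pred_label, y_train)` below. A minimal sketch consistent with both call sites, assuming it computes (optionally masked) classification accuracy, could look like this; it is an inference from the call sites, not the repository's definition:

import torch

def acc(pred, label, mask=None):
    """Accuracy of `pred` against `label`, optionally restricted to a boolean mask."""
    # Hypothetical reconstruction inferred from the call sites above;
    # the repository's actual helper may differ.
    if mask is not None:
        pred, label = pred[mask], label[mask]
    return (pred == label).float().mean().item()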
Example #2
def main(args):

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)

    load_func, subset = args.dataset.split('/')[:2]
    if load_func == 'WebKB':
        load_func = WebKB
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikiCS':
        load_func = WikiCS
        dataset = load_func(root=args.data_path)
    elif load_func == 'cora_ml':
        dataset = load_citation_link(
            root='../dataset/data/tmp/cora_ml/cora_ml.npz')
    elif load_func == 'citeseer':
        dataset = load_citation_link(
            root='../dataset/data/tmp/citeseer_npz/citeseer_npz.npz')
        #load telegram/synthetic here
    else:
        dataset = load_syn(args.data_path + args.dataset, None)

    os.makedirs(log_path, exist_ok=True)

    # load dataset
    data = dataset[0]
    edge_index = data.edge_index

    size = torch.max(edge_index).item() + 1
    # generate edge index dataset
    #if args.task == 2:
    #    datasets = generate_dataset_2class(edge_index, splits = 10, test_prob = args.drop_prob)
    #else:
    save_file = args.data_path + args.dataset + '/' + subset
    datasets = generate_dataset_3class(edge_index,
                                       size,
                                       save_file,
                                       splits=10,
                                       probs=args.split_prob,
                                       task=args.task,
                                       label_dim=args.num_class_link)

    if args.task != 2:
        results = np.zeros((10, 4))
    else:
        results = np.zeros((10, 4, 5))
    for i in range(10):
        log_str_full = ''
        edges = datasets[i]['graph']
        if args.to_undirected:
            edges = to_undirected(edges)

        ########################################
        # initialize model and load dataset
        ########################################
        #x = torch.ones(size).unsqueeze(-1).to(device)
        x = in_out_degree(edges, size).to(device)
        edges = edges.long().to(device)

        model = GIN_Link(x.size(-1),
                         args.num_class_link,
                         filter_num=args.num_filter,
                         dropout=args.dropout).to(device)
        #model = nn.DataParallel(graphmodel)
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        y_train = datasets[i]['train']['label']
        y_val = datasets[i]['validate']['label']
        y_test = datasets[i]['test']['label']
        y_train = torch.from_numpy(y_train).long().to(device)
        y_val = torch.from_numpy(y_val).long().to(device)
        y_test = torch.from_numpy(y_test).long().to(device)

        train_index = torch.from_numpy(
            datasets[i]['train']['pairs']).to(device)
        val_index = torch.from_numpy(
            datasets[i]['validate']['pairs']).to(device)
        test_index = torch.from_numpy(datasets[i]['test']['pairs']).to(device)

        #################################
        # Train/Validation/Test
        #################################
        best_val_err = float('inf')
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            if early_stopping > 500:
                break
            ####################
            # Train
            ####################
            model.train()
            out = model(x, edges, train_index)

            train_loss = F.nll_loss(out, y_train)
            pred_label = out.max(dim=1)[1]
            train_acc = acc(pred_label, y_train)

            opt.zero_grad()
            train_loss.backward()
            opt.step()
            outstrtrain = 'Train loss: %.6f, acc: %.3f,' % (
                train_loss.detach().item(), train_acc)

            ####################
            # Validation
            ####################
            model.eval()
            out = model(x, edges, val_index)

            val_loss = F.nll_loss(out, y_val)
            pred_label = out.max(dim=1)[1]
            val_acc = acc(pred_label, y_val)

            outstrval = ' Val loss: %.6f, acc: %.3f,' % (
                val_loss.detach().item(), val_acc)
            duration = " --- %.4f seconds ---" % (time.time() - start_time)
            log_str = ("%d / %d epoch," %
                       (epoch, args.epochs)) + outstrtrain + outstrval + duration
            #print(log_str)
            log_str_full += log_str + '\n'
            ####################
            # Save weights
            ####################
            save_perform = val_loss.detach().item()
            if save_perform <= best_val_err:
                early_stopping = 0
                best_val_err = save_perform
                torch.save(model.state_dict(),
                           log_path + '/model' + str(i) + '.t7')
            else:
                early_stopping += 1

        write_log(vars(args), log_path)
        torch.save(model.state_dict(),
                   log_path + '/model_latest' + str(i) + '.t7')
        if args.task != 2:
            ####################
            # Testing
            ####################
            model.load_state_dict(
                torch.load(log_path + '/model' + str(i) + '.t7'))
            model.eval()
            out = model(x, edges, val_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            val_acc = acc(pred_label, y_val)

            out = model(x, edges, test_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            test_acc = acc(pred_label, y_test)

            model.load_state_dict(
                torch.load(log_path + '/model_latest' + str(i) + '.t7'))
            model.eval()
            out = model(x, edges, val_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            val_acc_latest = acc(pred_label, y_val)

            out = model(x, edges, test_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            test_acc_latest = acc(pred_label, y_test)
            ####################
            # Save testing results
            ####################
            log_str = ('val_acc: {val_acc:.4f}, ' +
                       'test_acc: {test_acc:.4f}, ')
            log_str1 = log_str.format(val_acc=val_acc, test_acc=test_acc)
            log_str_full += log_str1

            log_str = ('val_acc_latest: {val_acc_latest:.4f}, ' +
                       'test_acc_latest: {test_acc_latest:.4f}, ')
            log_str2 = log_str.format(val_acc_latest=val_acc_latest,
                                      test_acc_latest=test_acc_latest)
            log_str_full += log_str2 + '\n'
            print(log_str1 + log_str2)

            results[i] = [val_acc, test_acc, val_acc_latest, test_acc_latest]
        else:
            model.load_state_dict(
                torch.load(log_path + '/model' + str(i) + '.t7'))
            model.eval()
            out_val = model(x, edges, val_index)
            out_test = model(x, edges, test_index)
            [[val_acc_full, val_acc, val_auc, val_f1_micro, val_f1_macro],
             [test_acc_full, test_acc, test_auc, test_f1_micro, test_f1_macro]
             ] = link_prediction_evaluation(out_val, out_test, y_val, y_test)

            model.load_state_dict(
                torch.load(log_path + '/model_latest' + str(i) + '.t7'))
            model.eval()
            out_val = model(x, edges, val_index)
            out_test = model(x, edges, test_index)
            [[
                val_acc_full_latest, val_acc_latest, val_auc_latest,
                val_f1_micro_latest, val_f1_macro_latest
            ],
             [
                 test_acc_full_latest, test_acc_latest, test_auc_latest,
                 test_f1_micro_latest, test_f1_macro_latest
             ]] = link_prediction_evaluation(out_val, out_test, y_val, y_test)
            ####################
            # Save testing results
            ####################
            log_str = (
                'val_acc_full: {val_acc_full:.4f}, val_acc: {val_acc:.4f}, val_auc: {val_auc:.4f}, '
                'val_f1_micro: {val_f1_micro:.4f}, val_f1_macro: {val_f1_macro:.4f}, '
                'test_acc_full: {test_acc_full:.4f}, test_acc: {test_acc:.4f}, test_auc: {test_auc:.4f}, '
                'test_f1_micro: {test_f1_micro:.4f}, test_f1_macro: {test_f1_macro:.4f}')
            log_str = log_str.format(val_acc_full=val_acc_full,
                                     val_acc=val_acc,
                                     val_auc=val_auc,
                                     val_f1_micro=val_f1_micro,
                                     val_f1_macro=val_f1_macro,
                                     test_acc_full=test_acc_full,
                                     test_acc=test_acc,
                                     test_auc=test_auc,
                                     test_f1_micro=test_f1_micro,
                                     test_f1_macro=test_f1_macro)
            log_str_full += log_str + '\n'
            print(log_str)

            log_str = (
                'val_acc_full_latest: {val_acc_full_latest:.4f}, val_acc_latest: {val_acc_latest:.4f}, val_auc_latest: {val_auc_latest:.4f}, '
                'val_f1_micro_latest: {val_f1_micro_latest:.4f}, val_f1_macro_latest: {val_f1_macro_latest:.4f}, '
                'test_acc_full_latest: {test_acc_full_latest:.4f}, test_acc_latest: {test_acc_latest:.4f}, test_auc_latest: {test_auc_latest:.4f}, '
                'test_f1_micro_latest: {test_f1_micro_latest:.4f}, test_f1_macro_latest: {test_f1_macro_latest:.4f}')
            log_str = log_str.format(val_acc_full_latest=val_acc_full_latest,
                                     val_acc_latest=val_acc_latest,
                                     val_auc_latest=val_auc_latest,
                                     val_f1_micro_latest=val_f1_micro_latest,
                                     val_f1_macro_latest=val_f1_macro_latest,
                                     test_acc_full_latest=test_acc_full_latest,
                                     test_acc_latest=test_acc_latest,
                                     test_auc_latest=test_auc_latest,
                                     test_f1_micro_latest=test_f1_micro_latest,
                                     test_f1_macro_latest=test_f1_macro_latest)
            log_str_full += log_str + '\n'
            print(log_str)

            results[i] = [
                [val_acc_full, val_acc, val_auc, val_f1_micro, val_f1_macro],
                [
                    test_acc_full, test_acc, test_auc, test_f1_micro,
                    test_f1_macro
                ],
                [
                    val_acc_full_latest, val_acc_latest, val_auc_latest,
                    val_f1_micro_latest, val_f1_macro_latest
                ],
                [
                    test_acc_full_latest, test_acc_latest, test_auc_latest,
                    test_f1_micro_latest, test_f1_macro_latest
                ]
            ]

        with open(log_path + '/log' + str(i) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
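
Example #2 builds node features with an `in_out_degree` helper, also imported from elsewhere. Judging from the call site (`in_out_degree(edges, size)`, whose output width feeds the model via `x.size(-1)`), a plausible sketch returns a `[size, 2]` tensor of per-node in- and out-degrees; this is an assumption, not the repository's definition:

import torch

def in_out_degree(edge_index, size):
    """Per-node (in-degree, out-degree) features as a [size, 2] float tensor."""
    # Hypothetical sketch inferred from the call site in Example #2.
    src, dst = edge_index[0].long(), edge_index[1].long()
    in_deg = torch.bincount(dst, minlength=size).float()
    out_deg = torch.bincount(src, minlength=size).float()
    return torch.stack([in_deg, out_deg], dim=1)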