Example #1
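Link-prediction training loop: it loads a directed-graph dataset (WebKB, WikipediaNetwork, WikiCS, citation, or synthetic), generates 10 edge-level train/validate/test splits, trains a GIN_Link model per split with early stopping on validation loss, and evaluates both the best-validation checkpoint and the latest one.
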
def main(args):

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)

    load_func, subset = args.dataset.split('/')[0], args.dataset.split('/')[1]
    if load_func == 'WebKB':
        load_func = WebKB
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikiCS':
        load_func = WikiCS
        dataset = load_func(root=args.data_path)
    elif load_func == 'cora_ml':
        dataset = load_citation_link(
            root='../dataset/data/tmp/cora_ml/cora_ml.npz')
    elif load_func == 'citeseer':
        dataset = load_citation_link(
            root='../dataset/data/tmp/citeseer_npz/citeseer_npz.npz')
    # telegram / synthetic datasets are handled by the fallback below
    else:
        dataset = load_syn(args.data_path + args.dataset, None)

    if not os.path.isdir(log_path):
        os.makedirs(log_path)

    # every branch above assigns `dataset`, so it can be unpacked directly
    data = dataset[0]
    edge_index = data.edge_index

    size = torch.max(edge_index).item() + 1

    # generate edge index dataset
    #if args.task == 2:
    #    datasets = generate_dataset_2class(edge_index, splits = 10, test_prob = args.drop_prob)
    #else:
    save_file = args.data_path + args.dataset + '/' + subset
    datasets = generate_dataset_3class(edge_index,
                                       size,
                                       save_file,
                                       splits=10,
                                       probs=args.split_prob,
                                       task=args.task,
                                       label_dim=args.num_class_link)
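    # results: one row per split; for task != 2 the columns are
    # [val_acc, test_acc, val_acc_latest, test_acc_latest], while for
    # task 2 each split stores 4 metric groups of 5 scores each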

    if args.task != 2:
        results = np.zeros((10, 4))
    else:
        results = np.zeros((10, 4, 5))
    for i in range(10):
        log_str_full = ''
        edges = datasets[i]['graph']
        if args.to_undirected:
            edges = to_undirected(edges)

        ########################################
        # initialize model and load dataset
        ########################################
        #x = torch.ones(size).unsqueeze(-1).to(device)
        x = in_out_degree(edges, size).to(device)
        edges = edges.long().to(device)

        model = GIN_Link(x.size(-1),
                         args.num_class_link,
                         filter_num=args.num_filter,
                         dropout=args.dropout).to(device)
        #model = nn.DataParallel(graphmodel)
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        y_train = datasets[i]['train']['label']
        y_val = datasets[i]['validate']['label']
        y_test = datasets[i]['test']['label']
        y_train = torch.from_numpy(y_train).long().to(device)
        y_val = torch.from_numpy(y_val).long().to(device)
        y_test = torch.from_numpy(y_test).long().to(device)

        train_index = torch.from_numpy(
            datasets[i]['train']['pairs']).to(device)
        val_index = torch.from_numpy(
            datasets[i]['validate']['pairs']).to(device)
        test_index = torch.from_numpy(datasets[i]['test']['pairs']).to(device)

        #################################
        # Train/Validation/Test
        #################################
        best_test_err = 1000.0
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            if early_stopping > 500:
                break
            ####################
            # Train
            ####################
            train_loss, train_acc = 0.0, 0.0
            model.train()
            out = model(x, edges, train_index)

            train_loss = F.nll_loss(out, y_train)
            pred_label = out.max(dim=1)[1]
            train_acc = acc(pred_label, y_train)

            opt.zero_grad()
            train_loss.backward()
            opt.step()
            outstrtrain = 'Train loss: %.6f, acc: %.3f' % (
                train_loss.detach().item(), train_acc)

            ####################
            # Validation
            ####################
            test_loss, test_acc = 0.0, 0.0
            model.eval()
            out = model(x, edges, val_index)

            test_loss = F.nll_loss(out, y_val)
            pred_label = out.max(dim=1)[1]
            test_acc = acc(pred_label, y_val)

            outstrval = ' Validation loss: %.6f, acc: %.3f' % (
                test_loss.detach().item(), test_acc)
            duration = "--- %.4f seconds ---" % (time.time() - start_time)
            log_str = (
                "%d / %d epoch" %
                (epoch, args.epochs)) + outstrtrain + outstrval + duration
            #print(log_str)
            log_str_full += log_str + '\n'
            ####################
            # Save weights
            ####################
            save_perform = test_loss.detach().item()
            if save_perform <= best_test_err:
                early_stopping = 0
                best_test_err = save_perform
                torch.save(model.state_dict(),
                           log_path + '/model' + str(i) + '.t7')
            else:
                early_stopping += 1

        write_log(vars(args), log_path)
        torch.save(model.state_dict(),
                   log_path + '/model_latest' + str(i) + '.t7')
        if args.task != 2:
            ####################
            # Testing
            ####################
            model.load_state_dict(
                torch.load(log_path + '/model' + str(i) + '.t7'))
            model.eval()
            out = model(x, edges, val_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            val_acc = acc(pred_label, y_val)

            out = model(x, edges, test_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            test_acc = acc(pred_label, y_test)

            model.load_state_dict(
                torch.load(log_path + '/model_latest' + str(i) + '.t7'))
            model.eval()
            out = model(x, edges, val_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            val_acc_latest = acc(pred_label, y_val)

            out = model(x, edges, test_index)[:, :2]
            pred_label = out.max(dim=1)[1]
            test_acc_latest = acc(pred_label, y_test)
            ####################
            # Save testing results
            ####################
            log_str = ('val_acc: {val_acc:.4f}, ' +
                       'test_acc: {test_acc:.4f}, ')
            log_str1 = log_str.format(val_acc=val_acc, test_acc=test_acc)
            log_str_full += log_str1

            log_str = ('val_acc_latest: {val_acc_latest:.4f}, ' +
                       'test_acc_latest: {test_acc_latest:.4f}, ')
            log_str2 = log_str.format(val_acc_latest=val_acc_latest,
                                      test_acc_latest=test_acc_latest)
            log_str_full += log_str2 + '\n'
            print(log_str1 + log_str2)

            results[i] = [val_acc, test_acc, val_acc_latest, test_acc_latest]
        else:
            model.load_state_dict(
                torch.load(log_path + '/model' + str(i) + '.t7'))
            model.eval()
            out_val = model(x, edges, val_index)
            out_test = model(x, edges, test_index)
            [[val_acc_full, val_acc, val_auc, val_f1_micro, val_f1_macro],
             [test_acc_full, test_acc, test_auc, test_f1_micro, test_f1_macro]
             ] = link_prediction_evaluation(out_val, out_test, y_val, y_test)

            model.load_state_dict(
                torch.load(log_path + '/model_latest' + str(i) + '.t7'))
            model.eval()
            out_val = model(x, edges, val_index)
            out_test = model(x, edges, test_index)
            [[
                val_acc_full_latest, val_acc_latest, val_auc_latest,
                val_f1_micro_latest, val_f1_macro_latest
            ],
             [
                 test_acc_full_latest, test_acc_latest, test_auc_latest,
                 test_f1_micro_latest, test_f1_macro_latest
             ]] = link_prediction_evaluation(out_val, out_test, y_val, y_test)
            ####################
            # Save testing results
            ####################
            log_str = (
                'val_acc_full: {val_acc_full:.4f}, val_acc: {val_acc:.4f}, val_auc: {val_auc:.4f}, '
                'val_f1_micro: {val_f1_micro:.4f}, val_f1_macro: {val_f1_macro:.4f}, '
                'test_acc_full: {test_acc_full:.4f}, test_acc: {test_acc:.4f}, '
                'test_f1_micro: {test_f1_micro:.4f}, test_f1_macro: {test_f1_macro:.4f}')
            log_str = log_str.format(val_acc_full=val_acc_full,
                                     val_acc=val_acc,
                                     val_auc=val_auc,
                                     val_f1_micro=val_f1_micro,
                                     val_f1_macro=val_f1_macro,
                                     test_acc_full=test_acc_full,
                                     test_acc=test_acc,
                                     test_f1_micro=test_f1_micro,
                                     test_f1_macro=test_f1_macro)
            log_str_full += log_str + '\n'
            print(log_str)

            log_str = (
                'val_acc_full_latest: {val_acc_full_latest:.4f}, val_acc_latest: {val_acc_latest:.4f}, val_auc_latest: {val_auc_latest:.4f}, '
                'val_f1_micro_latest: {val_f1_micro_latest:.4f}, val_f1_macro_latest: {val_f1_macro_latest:.4f}, '
                'test_acc_full_latest: {test_acc_full_latest:.4f}, test_acc_latest: {test_acc_latest:.4f}, '
                'test_f1_micro_latest: {test_f1_micro_latest:.4f}, test_f1_macro_latest: {test_f1_macro_latest:.4f}')
            log_str = log_str.format(val_acc_full_latest=val_acc_full_latest,
                                     val_acc_latest=val_acc_latest,
                                     val_auc_latest=val_auc_latest,
                                     val_f1_micro_latest=val_f1_micro_latest,
                                     val_f1_macro_latest=val_f1_macro_latest,
                                     test_acc_full_latest=test_acc_full_latest,
                                     test_acc_latest=test_acc_latest,
                                     test_f1_micro_latest=test_f1_micro_latest,
                                     test_f1_macro_latest=test_f1_macro_latest)
            log_str_full += log_str + '\n'
            print(log_str)

            results[i] = [
                [val_acc_full, val_acc, val_auc, val_f1_micro, val_f1_macro],
                [
                    test_acc_full, test_acc, test_auc, test_f1_micro,
                    test_f1_macro
                ],
                [
                    val_acc_full_latest, val_acc_latest, val_auc_latest,
                    val_f1_micro_latest, val_f1_macro_latest
                ],
                [
                    test_acc_full_latest, test_acc_latest, test_auc_latest,
                    test_f1_micro_latest, test_f1_macro_latest
                ]
            ]

        with open(log_path + '/log' + str(i) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
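
The returned results array holds one row per split. A minimal summary sketch for a caller (assuming args.task != 2, so the shape is (10, 4); main and args are the ones defined above):

import numpy as np

results = main(args)  # shape (10, 4) when args.task != 2
# columns: val_acc, test_acc, val_acc_latest, test_acc_latest
mean, std = results.mean(axis=0), results.std(axis=0)
print('test_acc: %.3f +/- %.3f' % (mean[1], std[1]))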
Example #2
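Node-classification training loop: it loads a dataset the same way as Example #1, trains an APPNP_Model on each of the precomputed train/val/test mask splits with early stopping on validation loss, and evaluates both the best and the latest checkpoint.
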
def main(args):
    if args.randomseed > 0:
        torch.manual_seed(args.randomseed)

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)
    if not os.path.isdir(log_path):
        try:
            os.makedirs(log_path)
        except FileExistsError:
            print('Folder exists!')
    load_func, subset = args.dataset.split('/')[0], args.dataset.split('/')[1]
    if load_func == 'WebKB':
        load_func = WebKB
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
        dataset = load_func(root=args.data_path, name=subset)
    elif load_func == 'WikiCS':
        load_func = WikiCS
        dataset = load_func(root=args.data_path)
    elif load_func == 'cora_ml':
        dataset = citation_datasets(
            root='../dataset/data/tmp/cora_ml/cora_ml.npz')
    elif load_func == 'citeseer_npz':
        dataset = citation_datasets(
            root='../dataset/data/tmp/citeseer_npz/citeseer_npz.npz')
    else:
        dataset = load_syn(args.data_path + args.dataset, None)

    data = dataset[0]
    if 'edge_weight' not in data:
        data.edge_weight = None
    data.y = data.y.long()
    # normalize labels: the minimum label is treated as class index 0
    num_classes = int(data.y.max() - data.y.min() + 1)
    data = data.to(device)
    if data.edge_weight is not None:
        data.edge_weight = torch.FloatTensor(data.edge_weight).to(device)
    if args.to_undirected:
        data.edge_index, data.edge_weight = to_undirected(
            data.edge_index, data.edge_weight)
    splits = data.train_mask.shape[1]
    if len(data.test_mask.shape) == 1:
        data.test_mask = data.test_mask.unsqueeze(1).repeat(1, splits)

    results = np.zeros((splits, 4))
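    # columns: val_acc, test_acc, val_acc_latest, test_acc_latest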
    for split in range(splits):
        log_str_full = ''
        graphmodel = APPNP_Model(data.x.size(-1),
                                 num_classes,
                                 filter_num=args.num_filter,
                                 alpha=args.alpha,
                                 dropout=args.dropout,
                                 layer=args.layer).to(device)
        model = graphmodel  # nn.DataParallel(graphmodel)
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        #################################
        # Train/Validation/Test
        #################################
        best_test_err = 1000.0
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            ####################
            # Train
            ####################
            train_loss, train_acc = 0.0, 0.0

            # for loop for batch loading
            model.train()
            out = model(data)

            train_loss = F.nll_loss(out[data.train_mask[:, split]],
                                    data.y[data.train_mask[:, split]])
            pred_label = out.max(dim=1)[1]

            train_acc = acc(pred_label, data.y, data.train_mask[:, split])

            opt.zero_grad()
            train_loss.backward()
            opt.step()

            outstrtrain = 'Train loss:, %.6f, acc:, %.3f,' % (
                train_loss.detach().item(), train_acc)
            #scheduler.step()

            ####################
            # Validation
            ####################
            model.eval()
            test_loss, test_acc = 0.0, 0.0

            out = model(data)
            pred_label = out.max(dim=1)[1]

            test_loss = F.nll_loss(out[data.val_mask[:, split]],
                                   data.y[data.val_mask[:, split]])
            test_acc = acc(pred_label, data.y, data.val_mask[:, split])

            outstrval = ' Validation loss:, %.6f, acc:, %.3f,' % (
                test_loss.detach().item(), test_acc)

            duration = "---, %.4f, seconds ---" % (time.time() - start_time)
            log_str = (
                "%d, / ,%d, epoch," %
                (epoch, args.epochs)) + outstrtrain + outstrval + duration
            log_str_full += log_str + '\n'
            #print(log_str)

            ####################
            # Save weights
            ####################
            save_perform = test_loss.detach().item()
            if save_perform <= best_test_err:
                early_stopping = 0
                best_test_err = save_perform
                torch.save(model.state_dict(),
                           log_path + '/model' + str(split) + '.t7')
            else:
                early_stopping += 1
            if early_stopping > 500 or epoch == (args.epochs - 1):
                torch.save(model.state_dict(),
                           log_path + '/model_latest' + str(split) + '.t7')
                break

        write_log(vars(args), log_path)

        ####################
        # Testing
        ####################
        model.load_state_dict(
            torch.load(log_path + '/model' + str(split) + '.t7'))
        model.eval()
        preds = model(data)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred' + str(split), pred_label.to('cpu'))

        acc_val = acc(pred_label, data.y, data.val_mask[:, split])
        acc_test = acc(pred_label, data.y, data.test_mask[:, split])

        model.load_state_dict(
            torch.load(log_path + '/model_latest' + str(split) + '.t7'))
        model.eval()
        preds = model(data)
        pred_label = preds.max(dim=1)[1]

        np.save(log_path + '/pred_latest' + str(split), pred_label.to('cpu'))

        acc_val_latest = acc(pred_label, data.y, data.val_mask[:, split])
        acc_test_latest = acc(pred_label, data.y, data.test_mask[:, split])

        ####################
        # Save testing results
        ####################
        logstr = ('val_acc: ' + str(np.round(acc_val, 3)) +
                  ' test_acc: ' + str(np.round(acc_test, 3)) +
                  ' val_acc_latest: ' + str(np.round(acc_val_latest, 3)) +
                  ' test_acc_latest: ' + str(np.round(acc_test_latest, 3)))
        print(logstr)
        results[split] = [acc_val, acc_test, acc_val_latest, acc_test_latest]
        log_str_full += logstr
        with open(log_path + '/log' + str(split) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
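
Each example reads its hyperparameters from an argparse namespace. A hypothetical driver for Example #2: the flag names mirror the attributes this main() actually reads, but the defaults are illustrative, not the authors' settings.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--randomseed', type=int, default=0)
parser.add_argument('--log_root', default='../logs/')
parser.add_argument('--log_path', default='test')
parser.add_argument('--save_name', default='appnp')
parser.add_argument('--data_path', default='../dataset/data/tmp/')
parser.add_argument('--dataset', default='WebKB/Cornell')
parser.add_argument('--to_undirected', action='store_true')
parser.add_argument('--num_filter', type=int, default=16)
parser.add_argument('--alpha', type=float, default=0.1)
parser.add_argument('--layer', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lr', type=float, default=5e-3)
parser.add_argument('--l2', type=float, default=5e-4)
parser.add_argument('--epochs', type=int, default=1500)

results = main(parser.parse_args())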
Example #3
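Node-classification training loop for a spectral model: it builds (or loads a cached copy of) a parameterized sparse complex Laplacian via geometric_dataset_sparse, feeds the real and imaginary parts to a ChebNet, and otherwise follows the same per-split train/early-stop/evaluate pattern as Example #2.
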
def main(args):
    if args.randomseed > 0:
        torch.manual_seed(args.randomseed)

    date_time = datetime.now().strftime('%m-%d-%H:%M:%S')
    log_path = os.path.join(args.log_root, args.log_path, args.save_name,
                            date_time)
    if not os.path.isdir(log_path):
        try:
            os.makedirs(log_path)
        except FileExistsError:
            print('Folder exists!')

    load_func, subset = args.dataset.split('/')[0], args.dataset.split('/')[1]
    if load_func == 'WebKB':
        load_func = WebKB
    elif load_func == 'WikipediaNetwork':
        load_func = WikipediaNetwork
    elif load_func == 'WikiCS':
        load_func = WikiCS
    elif load_func == 'cora_ml':
        load_func = citation_datasets
    elif load_func == 'citeseer_npz':
        load_func = citation_datasets
    else:
        load_func = load_syn

    _file_ = args.data_path + args.dataset + '/data' + str(args.q) + '_' + str(
        args.K) + '_sparse.pk'
    if os.path.isfile(_file_):
        with open(_file_, 'rb') as f:
            data = pk.load(f)
        L = data['L']
        X, label, train_mask, val_mask, test_mask = geometric_dataset_sparse(
            args.q,
            args.K,
            root=args.data_path + args.dataset,
            subset=subset,
            dataset=load_func,
            load_only=True,
            save_pk=False)
    else:
        X, label, train_mask, val_mask, test_mask, L = geometric_dataset_sparse(
            args.q,
            args.K,
            root=args.data_path + args.dataset,
            subset=subset,
            dataset=load_func,
            load_only=False,
            save_pk=True)

    # normalize label, the minimum should be 0 as class index
    _label_ = label - np.amin(label)
    cluster_dim = np.amax(_label_) + 1

    # convert the real and imaginary parts of each sparse Laplacian to torch sparse tensors
    L_img = []
    L_real = []
    for i in range(len(L)):
        L_img.append(sparse_mx_to_torch_sparse_tensor(L[i].imag).to(device))
        L_real.append(sparse_mx_to_torch_sparse_tensor(L[i].real).to(device))

    label = torch.from_numpy(_label_[np.newaxis]).to(device)
    X_img = torch.FloatTensor(X).to(device)
    X_real = torch.FloatTensor(X).to(device)
    criterion = nn.NLLLoss()

    splits = train_mask.shape[1]
    if len(test_mask.shape) == 1:
        #data.test_mask = test_mask.unsqueeze(1).repeat(1, splits)
        test_mask = np.repeat(test_mask[:, np.newaxis], splits, 1)

    results = np.zeros((splits, 4))
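    # columns: val_acc, test_acc, val_acc_latest, test_acc_latest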
    for split in range(splits):
        log_str_full = ''

        model = ChebNet(X_real.size(-1),
                        L_real,
                        L_img,
                        K=args.K,
                        label_dim=cluster_dim,
                        layer=args.layer,
                        activation=args.activation,
                        num_filter=args.num_filter,
                        dropout=args.dropout).to(device)

        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)

        train_index = train_mask[:, split]
        val_index = val_mask[:, split]
        test_index = test_mask[:, split]

        #################################
        # Train/Validation/Test
        #################################
        best_test_err = 1000.0
        early_stopping = 0
        for epoch in range(args.epochs):
            start_time = time.time()
            ####################
            # Train
            ####################
            count, train_loss, train_acc = 0.0, 0.0, 0.0

            # for loop for batch loading
            count += np.sum(train_index)

            model.train()
            preds = model(X_real, X_img)
            train_loss = criterion(preds[:, :, train_index],
                                   label[:, train_index])
            pred_label = preds.max(dim=1)[1]
            train_acc = 1.0 * (
                (pred_label[:, train_index]
                 == label[:, train_index])).sum().detach().item() / count
            opt.zero_grad()
            train_loss.backward()
            opt.step()

            outstrtrain = 'Train loss:, %.6f, acc:, %.3f,' % (
                train_loss.detach().item(), train_acc)
            #scheduler.step()
            ####################
            # Validation
            ####################
            model.eval()
            count, test_loss, test_acc = 0.0, 0.0, 0.0

            # for loop for batch loading
            count += np.sum(val_index)
            preds = model(X_real, X_img)
            pred_label = preds.max(dim=1)[1]

            test_loss = criterion(preds[:, :, val_index], label[:, val_index])
            test_acc = 1.0 * (
                (pred_label[:, val_index]
                 == label[:, val_index])).sum().detach().item() / count

            outstrval = ' Validation loss:, %.6f, acc:, %.3f,' % (
                test_loss.detach().item(), test_acc)

            duration = "---, %.4f, seconds ---" % (time.time() - start_time)
            log_str = (
                "%d ,/, %d ,epoch," %
                (epoch, args.epochs)) + outstrtrain + outstrval + duration
            log_str_full += log_str + '\n'
            #print(log_str)

            ####################
            # Save weights
            ####################
            save_perform = test_loss.detach().item()
            if save_perform <= best_test_err:
                early_stopping = 0
                best_test_err = save_perform
                torch.save(model.state_dict(),
                           log_path + '/model' + str(split) + '.t7')
            else:
                early_stopping += 1
            if early_stopping > 500 or epoch == (args.epochs - 1):
                torch.save(model.state_dict(),
                           log_path + '/model_latest' + str(split) + '.t7')
                break

        write_log(vars(args), log_path)

        ####################
        # Testing
        ####################
        model.load_state_dict(
            torch.load(log_path + '/model' + str(split) + '.t7'))
        model.eval()
        preds = model(X_real, X_img)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred' + str(split), pred_label.to('cpu'))

        count = np.sum(val_index)
        acc_val = 1.0 * (pred_label[:, val_index]
                         == label[:, val_index]).sum().detach().item() / count

        count = np.sum(test_index)
        acc_test = 1.0 * (pred_label[:, test_index]
                          == label[:, test_index]).sum().detach().item() / count

        model.load_state_dict(
            torch.load(log_path + '/model_latest' + str(split) + '.t7'))
        model.eval()
        preds = model(X_real, X_img)
        pred_label = preds.max(dim=1)[1]
        np.save(log_path + '/pred_latest' + str(split), pred_label.to('cpu'))

        count = np.sum(val_index)
        acc_val_latest = 1.0 * (pred_label[:, val_index]
                                == label[:, val_index]).sum().detach().item() / count

        count = np.sum(test_index)
        acc_test_latest = 1.0 * (pred_label[:, test_index]
                                 == label[:, test_index]).sum().detach().item() / count

        ####################
        # Save testing results
        ####################
        logstr = ('val_acc: ' + str(np.round(acc_val, 3)) +
                  ' test_acc: ' + str(np.round(acc_test, 3)) +
                  ' val_acc_latest: ' + str(np.round(acc_val_latest, 3)) +
                  ' test_acc_latest: ' + str(np.round(acc_test_latest, 3)))
        print(logstr)
        results[split] = [acc_val, acc_test, acc_val_latest, acc_test_latest]
        log_str_full += logstr
        with open(log_path + '/log' + str(split) + '.csv', 'w') as file:
            file.write(log_str_full)
            file.write('\n')
        torch.cuda.empty_cache()
    return results
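
Example #3 depends on a sparse_mx_to_torch_sparse_tensor helper that is not shown. A common implementation (a sketch; the repository's own helper may differ) converts a scipy sparse matrix to a torch sparse COO tensor:

import numpy as np
import torch

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    # Convert a scipy sparse matrix to a torch sparse COO tensor.
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    return torch.sparse_coo_tensor(indices, values,
                                   torch.Size(sparse_mx.shape))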