Exemplo n.º 1
0
import time
import argparse
import numpy as np
# from train import train
import pickle as pkl
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
# from args import get_text_args
from args import get_citation_args
from utils import *
from train import train_linear, adj, sp_adj, label_dict, index_dict
import torch.nn.functional as F
from models import get_model
from math import log

# Parse the citation-dataset command-line arguments and seed the RNGs.
args = get_citation_args()
set_seed(args.seed, args.cuda)

# Dense CPU copy of the sparse adjacency, then the precomputed features.
# ``args.degree - 1`` is what gets passed as the degree argument; the result
# appears to be keyed by split (``feat_dict["train"]`` is read below) —
# TODO confirm against sgc_precompute.
adj_dense = sparse_to_torch_dense(sp_adj, device='cpu')
feat_dict, precompute_time = sgc_precompute(
    adj,
    adj_dense,
    args.degree - 1,
    index_dict,
)

# The "mr" dataset gets a single output (presumably a one-logit binary
# setup — confirm); every other dataset gets max training label id + 1.
nclass = 1 if args.dataset == "mr" else label_dict["train"].max().item() + 1


def linear_objective(space):
    model = get_model(args.model,
                      nfeat=feat_dict["train"].size(1),
                      nclass=nclass,
                      nhid=0,
                      dropout=0,
Exemplo n.º 2
0
def _relabel_episode_classes(labels, select_class, node_num):
    """Relabel the episode's classes as 0..len(select_class)-1.

    Only the first ``node_num`` nodes are considered. Nodes whose label is
    ``select_class[p]`` are relabelled ``p``; all other nodes keep their
    original label.

    Returns:
        (labels_local, class_idx): a detached clone of ``labels`` with the
        relabelling applied, and a list where ``class_idx[p]`` holds the
        ascending node indices that belong to ``select_class[p]``.
    """
    labels_local = labels.clone().detach()
    head = labels_local[:node_num]  # a view: writes go through to labels_local
    # Build every mask before relabelling, so a node that has just been
    # relabelled (e.g. class 1 -> 0) is never matched again by a later class.
    masks = [head == cls for cls in select_class]
    class_idx = []
    for position, mask in enumerate(masks):
        class_idx.append(mask.nonzero(as_tuple=True)[0].tolist())
        head[mask] = position
    return labels_local, class_idx


def _sample_episode(class_idx, shot):
    """Sample ``shot`` support nodes per class; the rest of each class is the query set.

    Returns (train_idx, test_idx), both shuffled. ``random`` is consumed in
    this order: one ``sample`` per class, then shuffle train, then shuffle test.
    """
    support_per_class = [random.sample(idx, shot) for idx in class_idx]
    train_idx = []
    test_idx = []
    for idx, support in zip(class_idx, support_per_class):
        chosen = set(support)  # O(1) membership instead of scanning the list
        train_idx.extend(support)
        test_idx.extend(node for node in idx if node not in chosen)
    random.shuffle(train_idx)
    random.shuffle(test_idx)
    return train_idx, test_idx


def _append_to_log(dataset, text):
    """Append ``text`` to the ``<dataset>.txt`` results file."""
    with open('{}.txt'.format(dataset), 'a') as f:
        f.write(text)


def main():
    """Cross-validated meta-training / meta-testing on a citation dataset.

    For every ``n_way``-sized combination of classes (held out as test
    classes):

    1. A model is trained over ``iteration`` x ``step`` episodes drawn from
       the remaining classes.
    2. Its weights are saved to ``model.pkl``.
    3. For ``step`` episodes drawn from the held-out classes, a fresh copy of
       those weights is fine-tuned on the support set and evaluated on the
       query set.

    Per-fold and overall mean accuracies are appended to ``<dataset>.txt``.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cora'`` nor ``'citeseer'``.
    """
    args = get_citation_args()
    n_way = args.n_way
    train_shot = args.train_shot
    test_shot = args.test_shot
    step = args.step
    node_num = args.node_num
    iteration = args.iteration

    set_seed(args.seed, args.cuda)

    if args.dataset == 'cora':
        class_label = [0, 1, 2, 3, 4, 5, 6]
    elif args.dataset == 'citeseer':
        # Citeseer overrides the command-line node count and iteration count.
        node_num = 3327
        iteration = 15
        class_label = [0, 1, 2, 3, 4, 5]
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(
            "Unsupported dataset {!r}: expected 'cora' or 'citeseer'".format(
                args.dataset))
    combination = list(combinations(class_label, n_way))

    adj, features, labels = load_citation(args.dataset, args.normalization,
                                          args.cuda)

    if args.model == 'SGC':
        features = sgc_precompute(features, adj, args.degree)

    total_accuracy_meta_test = []

    for i, test_combo in enumerate(combination):
        fold = i + 1
        print('Cross_Validation: ', fold)
        test_label = list(test_combo)
        train_label = [n for n in class_label if n not in test_label]
        print('Cross_Validation {} Train_Label_List {}: '.format(
            fold, train_label))
        print('Cross_Validation {} Test_Label_List {}: '.format(
            fold, test_label))
        model = get_model(args.model, features.size(1), n_way, args.cuda)

        # Meta-training: one model is trained across episodes from the
        # training classes.
        for j in range(iteration):
            select_class = random.sample(train_label, n_way)
            print('Cross_Validation {} ITERATION {} Train_Label: {}'.format(
                fold, j + 1, select_class))
            labels_local, class_idx = _relabel_episode_classes(
                labels, select_class, node_num)
            for _ in range(step):
                train_idx, test_idx = _sample_episode(class_idx, train_shot)
                model = train_regression(model, features[train_idx],
                                         labels_local[train_idx], args.epochs,
                                         args.weight_decay, args.lr)
                # The accuracy is not used here; the call is kept for any
                # side effects test_regression may have (e.g. logging).
                test_regression(model, features[test_idx],
                                labels_local[test_idx])
                reset_array()

        torch.save(model.state_dict(), 'model.pkl')

        # Meta-testing on the held-out classes. Previously this sampled a
        # hard-coded 2 classes; n_way matches the model's output size and is
        # identical when n_way == 2.
        select_class = random.sample(test_label, n_way)
        reset_array()
        print('Cross_Validation {} Test_Label {}: '.format(
            fold, select_class))
        labels_local, class_idx = _relabel_episode_classes(
            labels, select_class, node_num)

        # Read the saved weights once; load_state_dict copies them into each
        # fresh model, so every step starts from the same meta-trained state.
        meta_state = torch.load('model.pkl')
        accuracy_meta_test = []
        for _ in range(step):
            train_idx, test_idx = _sample_episode(class_idx, test_shot)

            model_meta_trained = get_model(args.model, features.size(1), n_way,
                                           args.cuda)
            # Only move to the GPU when CUDA was requested; an unconditional
            # .cuda() breaks CPU-only runs.
            if args.cuda:
                model_meta_trained = model_meta_trained.cuda()
            model_meta_trained.load_state_dict(meta_state)

            model_meta_trained = train_regression(model_meta_trained,
                                                  features[train_idx],
                                                  labels_local[train_idx],
                                                  args.epochs,
                                                  args.weight_decay, args.lr)
            acc_test = test_regression(model_meta_trained, features[test_idx],
                                       labels_local[test_idx])
            accuracy_meta_test.append(acc_test)
            total_accuracy_meta_test.append(acc_test)
            reset_array()

        _append_to_log(
            args.dataset,
            'Cross_Validation: {} Meta-Test_Accuracy: {}'.format(
                fold, torch.tensor(accuracy_meta_test).numpy().mean()) + '\n')

    summary = 'Dataset: {}, Train_Shot: {}, Test_Shot: {}'.format(
        args.dataset, train_shot, test_shot)
    total = 'Total_Meta-Test_Accuracy: {}'.format(
        torch.tensor(total_accuracy_meta_test).numpy().mean())
    _append_to_log(args.dataset, summary + '\n' + total + '\n' + '\n\n\n')