예제 #1
0
def val(model, val_features, val_index, val_label, start_select=False):
    """Evaluate ``model`` on the validation split and return its F1 scores.

    The validation features are merged with ``mp.combine_features_dict`` and
    pushed through the model in one forward pass with gradients disabled.
    Predictions and labels of all node types are pooled before scoring.

    Args:
        model: Network to evaluate. It is switched to eval mode and is NOT
            switched back afterwards.
        val_features: Per-node features of the validation nodes.
        val_index: Indices of the validation nodes.
        val_label: Labels of the validation nodes.
        start_select: Flag forwarded as the third argument of the model call.

    Returns:
        Tuple ``(micro_f1, macro_f1)`` computed with ``sm.f1_score``.
    """
    combined = mp.combine_features_dict(val_features, val_index, val_label, DEVICE)
    feature_dict, _index_dict, label_dict, rows_dict = combined

    model.eval()
    # Gradients are not needed for evaluation; disabling them saves memory.
    with torch.no_grad():
        logits_dict, _ = model(rows_dict, feature_dict, start_select)

        # Measuring F1 per node type would require scoring each type on its
        # own. The original author notes that their dataset has a single node
        # type, so pooling all types gives the same result.
        predicted = []
        expected = []
        for node_type in logits_dict:
            # Row-wise argmax of the logits is the predicted class.
            _, winners = logits_dict[node_type].max(1)
            predicted += winners.cpu().numpy().tolist()
            expected += label_dict[node_type].cpu().numpy().tolist()

        micro_f1 = sm.f1_score(expected, predicted, average='micro')
        macro_f1 = sm.f1_score(expected, predicted, average='macro')
    return micro_f1, macro_f1
예제 #2
0
def train(model, epochs, method="all_node", ablation="all"):
    """Train ``model`` and checkpoint the best-validating weights.

    Training and validation features are loaded once up front through the
    module-level ``pool``. For the first ``start_select`` (50) epochs the model
    is called with its third argument ``False``; from then on with ``True``,
    at which point the best-score bookkeeping is reset and checkpoints start
    being written. With ``ablation="all"`` one ``Discriminator`` per node type
    is trained jointly with the classifier.

    Relies on module-level names defined elsewhere in the file, e.g.
    ``train_list``, ``val_list``, ``pool``, ``data``, ``batch_size``,
    ``num_batch_per_epoch``, ``learning_rate``, ``criterion``, ``dataset``.

    Args:
        model: Network to train.
        epochs: Number of epochs to run.
        method: Forwarded to ``mp.index_to_features``.
        ablation: ``"all"`` adds the discriminator losses to the
            classification loss; ``"no_align"`` trains on the classification
            loss only.

    Raises:
        ValueError: If ``ablation`` is not one of the supported modes.

    Side effects:
        Writes ``checkpoint/<dataset>_best_val`` (only once selection has
        started), ``checkpoint/<dataset>_final`` and appends one line to
        ``result/<dataset>_time.txt``.
    """
    # Validate before the expensive data loading; an ``assert`` would be
    # stripped when Python runs with -O.
    if ablation not in ("all", "no_align"):
        raise ValueError(
            "ablation must be 'all' or 'no_align', got %r" % (ablation,))

    # Load the training and validation sets into memory up front to save time.
    def index_to_feature_wrapper(metapath_dict):
        return mp.index_to_features(metapath_dict, data.x, method)

    start_select = 50
    train_index = train_list[:, 0].tolist()
    print("Loading dataset with thread pool...")
    train_metapath = pool.map(node_search_wrapper, train_index)
    train_features = pool.map(index_to_feature_wrapper, train_metapath)
    val_index = val_list[:, 0].tolist()
    val_label = val_list[:, 1]
    val_metapath = pool.map(node_search_wrapper, val_index)
    val_features = pool.map(index_to_feature_wrapper, val_metapath)
    best_micro_f1 = 0
    best_macro_f1 = 0

    # Collect, per node type (first element of each key), the metapath keys
    # seen in the validation nodes.
    type_set = set()
    for node in val_metapath:
        for key in node:
            type_set.add(key[0])
    metapath_set = {}
    for node_type in type_set:
        metapath_set[node_type] = set()
    for node in val_metapath:
        for key in node:
            # A key is only added while the set is smaller than this node's
            # key count minus one; single-element keys are never added.
            if len(node) - 1 > len(metapath_set[key[0]]) and len(key) > 1:
                metapath_set[key[0]].add(key)

    # Build, per node type, the discriminator targets (one multi-hot block of
    # ``batch_size`` rows per metapath) plus the discriminator and optimizer.
    discriminator = {}
    d_optimizer = {}
    label = {}
    for node_type in type_set:
        type_targets = []
        for metapath in metapath_set[node_type]:
            target = torch.zeros(batch_size, data.type_num, device=DEVICE)
            for all_type in data.node_dict:
                if all_type in metapath[1:]:
                    target[:, data.node_dict[all_type]] = 1
            type_targets.append(target)
        # The row order follows the iteration order of metapath_set[node_type],
        # which the training loop below reuses for the predictions.
        label[node_type] = torch.cat(type_targets, dim=0)
        discriminator[node_type] = Discriminator(info_section, data.type_num).to(DEVICE)
        d_optimizer[node_type] = optim.Adam(discriminator[node_type].parameters(), lr=0.01,
                                            weight_decay=5e-4)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    bce_loss = nn.BCELoss()  # stateless, so build it once instead of per batch

    select_flag = False
    time1 = time.time()
    # Initialise the timing values so they are defined even when validation
    # never improves or selection never starts.
    time2 = time1
    pretrain_convergence = None
    select_convergence = None
    for e in range(epochs):
        # val() leaves the model in eval mode at the end of every epoch, so
        # training mode has to be re-enabled here, not just once before the loop.
        model.train()
        # NOTE: the original author had periodic re-sampling of the training
        # metapaths (every 20 epochs) here, but it was disabled.
        for _ in range(num_batch_per_epoch):
            batch_src_choice = np.random.choice(train_list.shape[0], size=(batch_size,), replace=False)
            batch_src_index = train_list[batch_src_choice, 0]
            batch_src_label = train_list[batch_src_choice, 1]
            batch_feature_list = [train_features[i] for i in batch_src_choice]
            (batch_train_feature_dict, batch_src_index_dict, batch_src_label_dict,
             batch_train_rows_dict) = mp.combine_features_dict(
                batch_feature_list,
                batch_src_index,
                batch_src_label, DEVICE)

            optimizer.zero_grad()
            for d_opt in d_optimizer.values():
                d_opt.zero_grad()
            # The third argument turns True once the pretraining epochs are over.
            batch_train_logits_dict, GAN_input = model(batch_train_rows_dict, batch_train_feature_dict,
                                                       e >= start_select)

            for node_type in batch_train_logits_dict:
                loss = criterion(batch_train_logits_dict[node_type], batch_src_label_dict[node_type])
                if ablation == "all":
                    pred_d = discriminator[node_type](GAN_input[node_type])
                    pred_shuffle = discriminator[node_type](GAN_input[node_type], True)

                    # Concatenate in metapath_set order so the rows line up
                    # with label[node_type].
                    sorted_pred_d = torch.cat(
                        [pred_d[metapath] for metapath in metapath_set[node_type]], dim=0)
                    sorted_pred_shuffle = torch.cat(
                        [pred_shuffle[metapath] for metapath in metapath_set[node_type]], dim=0)

                    loss_d = bce_loss(sorted_pred_d, label[node_type])
                    # Predictions from the second discriminator call are
                    # pushed towards all-zero targets.
                    loss_d_shuffle = bce_loss(sorted_pred_shuffle,
                                              torch.zeros_like(sorted_pred_shuffle, device=DEVICE))

                    loss = loss + loss_d + loss_d_shuffle

                loss.backward()
                d_optimizer[node_type].step()
            optimizer.step()

        if e >= start_select and not select_flag:
            select_flag = True
            pretrain_convergence = time2 - time1
            print("Start select! Best f1-score reset to 0.")
            print("Pretrain convergence time:", pretrain_convergence)
            # Restart both timestamps so the select-phase convergence time
            # can never be negative.
            time1 = time2 = time.time()
            best_micro_f1 = 0
            best_macro_f1 = 0

        micro_f1, macro_f1 = val(model, val_features, val_index, val_label, select_flag)
        if select_flag:
            model.show_metapath_importance()

        micro_improved = micro_f1 > best_micro_f1
        macro_tiebreak = micro_f1 == best_micro_f1 and macro_f1 > best_macro_f1
        if micro_improved or macro_tiebreak:
            # NOTE: as in the original, only a strict micro-F1 improvement
            # moves the convergence timestamp; a macro-F1 tie-break does not.
            if micro_improved:
                time2 = time.time()
            best_micro_f1 = micro_f1
            best_macro_f1 = macro_f1
            if select_flag:
                torch.save(model.state_dict(), "checkpoint/" + dataset + "_best_val")
        select_convergence = time2 - time1
        print("Epoch ", e, ",Val Micro_f1 is ", micro_f1, ", Macro_f1 is ", macro_f1, ", the best micro is ",
              best_micro_f1, ", the best macro is ",
              best_macro_f1)
        print("Select convergence time:", select_convergence)
    # pretrain_convergence is written as "None" when selection never started.
    with open("result/" + dataset + "_time.txt", "a") as f:
        f.write("Pretrain_convergence:" + str(pretrain_convergence) + "\tSelect_convergence:" + str(select_convergence) + "\n")
    torch.save(model.state_dict(), "checkpoint/" + dataset + "_final")
예제 #3
0
def test(model, batch_size=200, test_method="best_val"):
    """Evaluate a saved checkpoint on the test split and log its F1 scores.

    Loads ``checkpoint/<dataset>_<test_method>`` into ``model``, builds the
    test features through the module-level ``pool``, runs the model batch by
    batch with gradients disabled and appends the scores to a result file.

    Relies on module-level names defined elsewhere in the file, e.g.
    ``dataset``, ``test_list``, ``pool``, ``data``, ``DEVICE``, ``shuffle``,
    ``train_percent`` and ``ablation``.

    Args:
        model: Network whose weights are replaced by the checkpoint.
        batch_size: Number of test nodes per forward pass; must be positive.
        test_method: Checkpoint suffix, e.g. ``"best_val"``.

    Returns:
        The micro-averaged F1 score on the test split.
    """
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded on the device this run actually uses.
    model.load_state_dict(
        torch.load("checkpoint/" + dataset + "_" + test_method, map_location=DEVICE))

    def index_to_feature_wrapper(metapath_dict):
        return mp.index_to_features(metapath_dict, data.x)

    test_index = test_list[:, 0].tolist()
    print("Loading dataset with thread pool...")
    time1 = time.time()
    test_metapath = pool.map(node_search_wrapper, test_index)
    test_features = pool.map(index_to_feature_wrapper, test_metapath)
    time2 = time.time()
    print("Dataset Loaded. Time consumption:", time2 - time1)

    # Measuring F1 per node type would require scoring each type on its own.
    # The original author notes that their dataset has a single node type, so
    # pooling all types gives the same result.
    y_pred = []
    y_true = []
    total = len(test_index)
    model.eval()
    with torch.no_grad():
        # range() raises on a zero step instead of looping forever the way a
        # manual "while" counter would.
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            batch_test_index = test_list[start:end, 0]
            batch_test_label = test_list[start:end, 1]
            batch_feature_list = test_features[start:end]

            (batch_test_feature_dict, batch_test_index_dict, batch_test_label_dict,
             batch_test_rows_dict) = mp.combine_features_dict(
                batch_feature_list,
                batch_test_index,
                batch_test_label, DEVICE)
            batch_test_logits_dict, _ = model(batch_test_rows_dict, batch_test_feature_dict, True)
            for node_type in batch_test_logits_dict:
                # Row-wise argmax of the logits is the predicted class.
                y_pred.extend(batch_test_logits_dict[node_type].max(1)[1].cpu().numpy().tolist())
                y_true.extend(batch_test_label_dict[node_type].cpu().numpy().tolist())

    micro_f1 = sm.f1_score(y_true, y_pred, average='micro')
    macro_f1 = sm.f1_score(y_true, y_pred, average='macro')
    print("Final F1 @ " + test_method + ":")
    print("Micro_F1:\t", micro_f1, "\tMacro_F1:", macro_f1)
    model.show_metapath_importance()

    # Kept as an equality test: the type of the module-level ``shuffle`` is
    # not visible here, so truthiness could change the result file name.
    feature_mode = "shuffle_" if shuffle == True else ""  # noqa: E712
    result_path = ("result/" + dataset + "_" + str(train_percent) + "_" + test_method
                   + "_" + feature_mode + ablation + ".txt")
    with open(result_path, "a") as f:
        f.write("Micro_F1:" + str(micro_f1) + "\tMacro_F1:" + str(macro_f1) + "\n")
    return micro_f1
예제 #4
0
def test(model, batch_size=200, test_method="best_val"):
    """Evaluate a saved checkpoint on the test split and plot path vectors.

    Loads ``checkpoint/<dataset>_<test_method>`` into ``model``, runs the test
    split batch by batch with gradients disabled, prints micro/macro F1, and
    passes a random sample of at most 1000 collected path vectors, with the
    integer id of their metapath, to ``pca`` for visualization.

    Relies on module-level names defined elsewhere in the file, e.g.
    ``dataset``, ``test_list``, ``pool``, ``data``, ``DEVICE`` and ``pca``.

    Args:
        model: Network whose weights are replaced by the checkpoint.
        batch_size: Number of test nodes per forward pass; must be positive.
            Only batches of exactly this size contribute path vectors.
        test_method: Checkpoint suffix, e.g. ``"best_val"``.

    Returns:
        None.
    """
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded on the device this run actually uses.
    model.load_state_dict(
        torch.load("checkpoint/" + dataset + "_" + test_method,
                   map_location=DEVICE))

    def index_to_feature_wrapper(metapath_dict):
        return mp.index_to_features(metapath_dict, data.x)

    test_index = test_list[:, 0].tolist()
    print("Loading dataset with thread pool...")
    time1 = time.time()
    test_metapath = pool.map(node_search_wrapper, test_index)
    test_features = pool.map(index_to_feature_wrapper, test_metapath)
    time2 = time.time()
    print("Dataset Loaded. Time consumption:", time2 - time1)

    print("Ready for visualization...")
    # Collect, per node type (first element of each key), the metapath keys
    # seen in the test nodes.
    type_set = set()
    for node in test_metapath:
        for key in node:
            type_set.add(key[0])
    metapath_set = {}
    for node_type in type_set:
        metapath_set[node_type] = set()
    for node in test_metapath:
        for key in node:
            # A key is only added while the set is smaller than this node's
            # key count minus one; single-element keys are never added.
            if len(node) - 1 > len(metapath_set[key[0]]) and len(key) > 1:
                metapath_set[key[0]].add(key)

    # For every node type build one block of ``batch_size`` integer ids per
    # metapath, in the iteration order of metapath_set[node_type].
    # NOTE(review): onehot2name is keyed by the id alone, so with more than
    # one node type later types overwrite earlier names - confirm single type.
    onehot2name = {}
    one_hot = {}
    for node_type in type_set:
        id_blocks = []
        for i, metapath in enumerate(metapath_set[node_type]):
            metapath_ids = torch.zeros(batch_size).long()
            metapath_ids[:] = i
            onehot2name[i] = metapath
            id_blocks.append(metapath_ids)
        one_hot[node_type] = torch.cat(id_blocks, dim=0)

    # Measuring F1 per node type would require scoring each type on its own.
    # The original author notes that their dataset has a single node type, so
    # pooling all types gives the same result.
    y_pred = []
    y_true = []
    path = []
    final_onehot = []
    total = len(test_index)
    model.eval()
    with torch.no_grad():
        # range() raises on a zero step instead of looping forever the way a
        # manual "while" counter would.
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            batch_test_index = test_list[start:end, 0]
            batch_test_label = test_list[start:end, 1]
            batch_feature_list = test_features[start:end]
            (batch_test_feature_dict, batch_test_index_dict, batch_test_label_dict,
             batch_test_rows_dict) = mp.combine_features_dict(
                batch_feature_list, batch_test_index, batch_test_label, DEVICE)
            batch_test_logits_dict, path_vec = model(batch_test_rows_dict,
                                                     batch_test_feature_dict,
                                                     True)
            # The id blocks are batch_size rows long, so a short final batch
            # is left out of the visualization data.
            full_batch = end - start == batch_size
            for node_type in batch_test_logits_dict:
                if full_batch:
                    for metapath in metapath_set[node_type]:
                        path.append(path_vec[node_type][metapath].cpu())
                    final_onehot.append(one_hot[node_type])
                # Row-wise argmax of the logits is the predicted class.
                y_pred.extend(batch_test_logits_dict[node_type].max(1)[1].cpu().numpy().tolist())
                y_true.extend(batch_test_label_dict[node_type].cpu().numpy().tolist())

    micro_f1 = sm.f1_score(y_true, y_pred, average='micro')
    macro_f1 = sm.f1_score(y_true, y_pred, average='macro')
    print("Final F1 @ " + test_method + ":")
    print("Micro_F1:\t", micro_f1, "\tMacro_F1:", macro_f1)
    model.show_metapath_importance()

    # torch.cat cannot take an empty list, which happens when the test split
    # has no batch of exactly batch_size nodes.
    if not path:
        print("No full batch of path vectors collected; skipping visualization.")
        return
    path = torch.cat(path, dim=0).numpy()
    final_onehot = torch.cat(final_onehot, dim=0).numpy()

    # Sampling without replacement cannot draw more rows than exist.
    sample_size = min(1000, path.shape[0])
    index = np.random.choice(path.shape[0], sample_size, replace=False)
    path = path[index, :]
    final_onehot = final_onehot[index]
    pca(path, final_onehot, 3, onehot2name)