def val(model, val_features, val_index, val_label, start_select=False):
    """Score ``model`` on the validation samples and return its F1 values.

    The samples are merged with ``mp.combine_features_dict`` and sent through
    the model in one forward pass with gradient tracking disabled. The
    predicted class of a row is the index of its largest logit. Predictions of
    all node types are pooled into one list before scoring, so this does not
    report per-type F1.

    The model is switched to evaluation mode here and is not switched back.

    Args:
        model: Network called as ``model(rows, features, start_select)``; its
            first return value is indexed by node type to obtain logits.
        val_features: Per-sample features passed to ``mp.combine_features_dict``.
        val_index: Sample indices passed to ``mp.combine_features_dict``.
        val_label: Sample labels passed to ``mp.combine_features_dict``.
        start_select: Forwarded unchanged as the model's third argument.

    Returns:
        Tuple ``(micro_f1, macro_f1)`` computed with ``sm.f1_score``.
    """
    features_by_type, _index_by_type, labels_by_type, rows_by_type = mp.combine_features_dict(
        val_features, val_index, val_label, DEVICE)

    model.eval()  # evaluation mode
    with torch.no_grad():  # no gradients needed; keeps memory usage down
        logits_by_type, _ = model(rows_by_type, features_by_type, start_select)

    # Per-type F1 would require keeping the types apart. The dataset used here
    # has a single node type, so pooling gives the same result.
    predicted = []
    expected = []
    for node_type, logits in logits_by_type.items():
        # Predicted label: row-wise argmax of the logits.
        predicted += logits.max(1)[1].cpu().numpy().tolist()
        expected += labels_by_type[node_type].cpu().numpy().tolist()

    micro_f1 = sm.f1_score(expected, predicted, average='micro')
    macro_f1 = sm.f1_score(expected, predicted, average='macro')
    return micro_f1, macro_f1
def train(model, epochs, method="all_node", ablation="all"):
    """Train ``model``, validate after every epoch and checkpoint the best state.

    Workflow, as implemented below:

    1. Metapaths and features of the training and validation nodes are loaded
       once through ``pool.map``.
    2. For every node type found in the validation metapaths a set of
       metapaths is collected, a multi-hot target (one column per entry of
       ``data.node_dict``) is built for each metapath, and one
       ``Discriminator`` with its own Adam optimizer is created.
    3. Each epoch runs ``num_batch_per_epoch`` random batches. From epoch
       ``start_select`` (50) onwards the model is called with its third
       argument set to ``True``; at that moment the best scores are reset.
    4. After each epoch ``val`` is called. While the selection phase is active,
       an improved state is written to ``checkpoint/<dataset>_best_val``.
    5. Timing information is appended to ``result/<dataset>_time.txt`` and the
       final weights are written to ``checkpoint/<dataset>_final``.

    Args:
        model: Network to optimise.
        epochs: Number of training epochs.
        method: Forwarded to ``mp.index_to_features``.
        ablation: ``"all"`` adds the two discriminator losses to the
            classification loss; ``"no_align"`` trains on the classification
            loss only.

    Raises:
        ValueError: If ``ablation`` is not one of the two supported values.
    """
    # Validate up front instead of asserting inside the batch loop: the check
    # then also holds under ``python -O`` and fails before any data is loaded.
    if ablation not in ("all", "no_align"):
        raise ValueError(
            'ablation must be "all" or "no_align", got {!r}'.format(ablation))

    def index_to_feature_wrapper(metapath_dict):
        return mp.index_to_features(metapath_dict, data.x, method)

    start_select = 50

    # Load the training and validation sets into memory once to save time.
    train_index = train_list[:, 0].tolist()
    print("Loading dataset with thread pool...")
    train_metapath = pool.map(node_search_wrapper, train_index)
    train_features = pool.map(index_to_feature_wrapper, train_metapath)
    val_index = val_list[:, 0].tolist()
    val_label = val_list[:, 1]
    val_metapath = pool.map(node_search_wrapper, val_index)
    val_features = pool.map(index_to_feature_wrapper, val_metapath)

    model.train()  # training mode
    best_micro_f1 = 0
    best_macro_f1 = 0

    # Collect, per node type (first element of a metapath key), the metapath
    # keys longer than one element. A key is only added while the set is still
    # smaller than ``len(node) - 1``.
    metapath_set = {}
    for node in val_metapath:
        for key in node:
            collected = metapath_set.setdefault(key[0], set())
            if len(node) - 1 > len(collected) and len(key) > 1:
                collected.add(key)

    # Discriminator targets: for every metapath a (batch_size, type_num) block
    # with ones in the columns of the node types occurring after the first
    # element. Blocks are concatenated in the iteration order of the metapath
    # set, which is the order used again in the batch loop (the sets are not
    # modified in between).
    label = {}
    discriminator = {}
    d_optimizer = {}
    for node_type, metapaths in metapath_set.items():
        blocks = []
        for metapath in metapaths:
            multi_hot = torch.zeros(batch_size, data.type_num, device=DEVICE)
            for type_name in data.node_dict:
                if type_name in metapath[1:]:
                    multi_hot[:, data.node_dict[type_name]] = 1
            blocks.append(multi_hot)
        label[node_type] = torch.cat(blocks, dim=0)
        discriminator[node_type] = Discriminator(info_section, data.type_num).to(DEVICE)
        d_optimizer[node_type] = optim.Adam(
            discriminator[node_type].parameters(), lr=0.01, weight_decay=5e-4)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)
    bce_loss = nn.BCELoss()
    best_val_path = "checkpoint/" + dataset + "_best_val"

    select_flag = False
    time1 = time.time()
    # Initialised so the timing code below never reads an unbound name
    # (e.g. when ``epochs <= start_select`` or no positive micro-F1 was seen).
    time2 = time1
    pretrain_convergence = None
    select_convergence = None

    # NOTE: the original code carried a disabled option to re-sample the
    # training metapaths every 20 epochs when ``single_path_limit`` is set.
    for e in range(epochs):
        # val() leaves the model in evaluation mode; switch back before
        # running the training batches of this epoch.
        model.train()
        selecting = e >= start_select

        for _ in range(num_batch_per_epoch):
            batch_src_choice = np.random.choice(
                train_list.shape[0], size=(batch_size,), replace=False)
            batch_src_index = train_list[batch_src_choice, 0]
            batch_src_label = train_list[batch_src_choice, 1]
            batch_feature_list = [train_features[i] for i in batch_src_choice]
            (batch_train_feature_dict, _, batch_src_label_dict,
             batch_train_rows_dict) = mp.combine_features_dict(
                batch_feature_list, batch_src_index, batch_src_label, DEVICE)

            optimizer.zero_grad()
            for d_opt in d_optimizer.values():
                d_opt.zero_grad()

            # Model output: per-type logits and the discriminator input.
            batch_train_logits_dict, GAN_input = model(
                batch_train_rows_dict, batch_train_feature_dict, selecting)

            for node_type in batch_train_logits_dict:
                loss = criterion(batch_train_logits_dict[node_type],
                                 batch_src_label_dict[node_type])
                if ablation == "all":
                    pred_d = discriminator[node_type](GAN_input[node_type])
                    pred_shuffle = discriminator[node_type](GAN_input[node_type], True)
                    metapaths = metapath_set[node_type]
                    sorted_pred_d = torch.cat(
                        [pred_d[metapath] for metapath in metapaths], dim=0)
                    sorted_pred_shuffle = torch.cat(
                        [pred_shuffle[metapath] for metapath in metapaths], dim=0)
                    loss_d = bce_loss(sorted_pred_d, label[node_type])
                    # The second discriminator pass is trained towards all zeros.
                    loss_d_shuffle = bce_loss(
                        sorted_pred_shuffle,
                        torch.zeros_like(sorted_pred_shuffle, device=DEVICE))
                    loss = loss + loss_d + loss_d_shuffle
                loss.backward()
                d_optimizer[node_type].step()
            optimizer.step()

        if selecting and not select_flag:
            # First epoch of the selection phase: record the pre-training
            # timing, restart the clock and forget the earlier best scores.
            select_flag = True
            pretrain_convergence = time2 - time1
            print("Start select! Best f1-score reset to 0.")
            print("Pretrain convergence time:", pretrain_convergence)
            time1 = time.time()
            best_micro_f1 = 0
            best_macro_f1 = 0

        if select_flag:
            micro_f1, macro_f1 = val(model, val_features, val_index, val_label, True)
            model.show_metapath_importance()
        else:
            micro_f1, macro_f1 = val(model, val_features, val_index, val_label)

        micro_improved = micro_f1 > best_micro_f1
        macro_improved_on_tie = micro_f1 == best_micro_f1 and macro_f1 > best_macro_f1
        if micro_improved:
            # Only a strict micro-F1 improvement moves the convergence time.
            time2 = time.time()
        if micro_improved or macro_improved_on_tie:
            best_micro_f1 = micro_f1
            best_macro_f1 = macro_f1
            if select_flag:
                torch.save(model.state_dict(), best_val_path)

        select_convergence = time2 - time1
        print("Epoch ", e, ",Val Micro_f1 is ", micro_f1, ", Macro_f1 is ", macro_f1,
              ", the best micro is ", best_micro_f1, ", the best macro is ", best_macro_f1)

    print("Select convergence time:", select_convergence)
    with open("result/" + dataset + "_time.txt", "a") as f:
        f.write("Pretrain_convergence:" + str(pretrain_convergence)
                + "\tSelect_convergence:" + str(select_convergence) + "\n")
    torch.save(model.state_dict(), "checkpoint/" + dataset + "_final")
def test(model, batch_size=200, test_method="best_val"):
    """Load a checkpoint, evaluate it on the test set and log the F1 scores.

    The weights are read from ``checkpoint/<dataset>_<test_method>``. Test
    metapaths and features are loaded through ``pool.map``, the test nodes are
    processed in consecutive batches with gradients disabled, and predictions
    (row-wise argmax of the logits) of all node types are pooled before
    scoring. The scores are printed and appended to a file under ``result/``.

    NOTE(review): ``dataset``, ``shuffle``, ``train_percent`` and ``ablation``
    are not parameters of this function; they are read from the enclosing
    module scope. SOURCE also contains a second function named ``test``; if
    both live in the same module the later one replaces this one -- confirm.

    Args:
        model: Network whose ``state_dict`` is replaced by the checkpoint.
        batch_size: Number of test nodes per forward pass; must be positive.
        test_method: Checkpoint suffix, also used in the result file name.

    Returns:
        The micro-averaged F1 score.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A non-positive step would never advance through the test set.
        raise ValueError("batch_size must be a positive integer")

    # map_location allows a checkpoint saved on another device to be loaded.
    model.load_state_dict(
        torch.load("checkpoint/" + dataset + "_" + test_method, map_location=DEVICE))

    def index_to_feature_wrapper(metapath_dict):
        return mp.index_to_features(metapath_dict, data.x)

    test_index = test_list[:, 0].tolist()
    print("Loading dataset with thread pool...")
    time1 = time.time()
    test_metapath = pool.map(node_search_wrapper, test_index)
    test_features = pool.map(index_to_feature_wrapper, test_metapath)
    time2 = time.time()
    print("Dataset Loaded. Time consumption:", time2 - time1)

    # Per-type F1 would require keeping the types apart. The dataset used here
    # has a single node type, so pooling gives the same result.
    y_pred = []
    y_true = []
    total = len(test_index)

    model.eval()  # evaluation mode
    with torch.no_grad():
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            batch_test_index = test_list[start:end, 0]
            batch_test_label = test_list[start:end, 1]
            batch_feature_list = test_features[start:end]
            (batch_test_feature_dict, _, batch_test_label_dict,
             batch_test_rows_dict) = mp.combine_features_dict(
                batch_feature_list, batch_test_index, batch_test_label, DEVICE)
            batch_test_logits_dict, _ = model(
                batch_test_rows_dict, batch_test_feature_dict, True)
            for node_type in batch_test_logits_dict:
                y_pred.extend(
                    batch_test_logits_dict[node_type].max(1)[1].cpu().numpy().tolist())
                y_true.extend(batch_test_label_dict[node_type].cpu().numpy().tolist())

    micro_f1 = sm.f1_score(y_true, y_pred, average='micro')
    macro_f1 = sm.f1_score(y_true, y_pred, average='macro')
    print("Final F1 @ " + test_method + ":")
    print("Micro_F1:\t", micro_f1, "\tMacro_F1:", macro_f1)
    model.show_metapath_importance()

    # Equality test kept as written: ``shuffle`` is defined outside this block.
    feature_mode = "shuffle_" if shuffle == True else ""
    result_path = ("result/" + dataset + "_" + str(train_percent) + "_"
                   + test_method + "_" + feature_mode + ablation + ".txt")
    with open(result_path, "a") as f:
        f.write("Micro_F1:" + str(micro_f1) + "\tMacro_F1:" + str(macro_f1) + "\n")
    return micro_f1
def test(model, batch_size=200, test_method="best_val"):
    """Evaluate a checkpoint on the test set and visualise its path vectors.

    The weights are read from ``checkpoint/<dataset>_<test_method>``. Besides
    computing micro/macro F1 over all test nodes, the second output of the
    model is collected per metapath for every *full* batch (a trailing partial
    batch contributes to the F1 scores only). Up to 1000 randomly chosen
    vectors and their metapath ids are then handed to ``pca``.

    NOTE(review): SOURCE contains an earlier function that is also named
    ``test``; if both live in the same module this definition replaces it --
    confirm that this is intended.
    NOTE(review): ``onehot2name`` is keyed by the metapath id only, so ids of
    different node types would overwrite each other; harmless if there is a
    single node type, as the original comment states.

    Args:
        model: Network whose ``state_dict`` is replaced by the checkpoint.
        batch_size: Number of test nodes per forward pass; must be positive.
        test_method: Checkpoint suffix.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A non-positive step would never advance through the test set.
        raise ValueError("batch_size must be a positive integer")

    # map_location allows a checkpoint saved on another device to be loaded.
    model.load_state_dict(
        torch.load("checkpoint/" + dataset + "_" + test_method, map_location=DEVICE))

    def index_to_feature_wrapper(metapath_dict):
        return mp.index_to_features(metapath_dict, data.x)

    test_index = test_list[:, 0].tolist()
    print("Loading dataset with thread pool...")
    time1 = time.time()
    test_metapath = pool.map(node_search_wrapper, test_index)
    test_features = pool.map(index_to_feature_wrapper, test_metapath)
    time2 = time.time()
    print("Dataset Loaded. Time consumption:", time2 - time1)
    print("Ready for visualization...")

    # Collect, per node type (first element of a metapath key), the metapath
    # keys longer than one element. A key is only added while the set is still
    # smaller than ``len(node) - 1``.
    metapath_set = {}
    for node in test_metapath:
        for key in node:
            collected = metapath_set.setdefault(key[0], set())
            if len(node) - 1 > len(collected) and len(key) > 1:
                collected.add(key)

    # Metapath ids: ``batch_size`` copies of the metapath's position in its
    # set, concatenated in set iteration order. The same order is used when
    # the path vectors are collected below (the sets are not modified).
    onehot2name = {}
    one_hot = {}
    for node_type, metapaths in metapath_set.items():
        id_blocks = []
        for i, metapath in enumerate(metapaths):
            onehot2name[i] = metapath
            id_blocks.append(torch.full((batch_size,), i, dtype=torch.long))
        one_hot[node_type] = torch.cat(id_blocks, dim=0)

    # Per-type F1 would require keeping the types apart. The dataset used here
    # has a single node type, so pooling gives the same result.
    y_pred = []
    y_true = []
    path = []
    final_onehot = []
    total = len(test_index)

    model.eval()  # evaluation mode
    with torch.no_grad():
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            # The id blocks are ``batch_size`` rows long, so only full batches
            # can be paired with them.
            is_full_batch = end - start == batch_size
            batch_test_index = test_list[start:end, 0]
            batch_test_label = test_list[start:end, 1]
            batch_feature_list = test_features[start:end]
            (batch_test_feature_dict, _, batch_test_label_dict,
             batch_test_rows_dict) = mp.combine_features_dict(
                batch_feature_list, batch_test_index, batch_test_label, DEVICE)
            # Model output: per-type logits and per-type, per-metapath vectors.
            batch_test_logits_dict, path_vec = model(
                batch_test_rows_dict, batch_test_feature_dict, True)
            for node_type in batch_test_logits_dict:
                if is_full_batch:
                    for metapath in metapath_set[node_type]:
                        path.append(path_vec[node_type][metapath].cpu())
                    final_onehot.append(one_hot[node_type])
                # Predicted label: row-wise argmax of the logits.
                y_pred.extend(
                    batch_test_logits_dict[node_type].max(1)[1].cpu().numpy().tolist())
                y_true.extend(batch_test_label_dict[node_type].cpu().numpy().tolist())

    micro_f1 = sm.f1_score(y_true, y_pred, average='micro')
    macro_f1 = sm.f1_score(y_true, y_pred, average='macro')
    print("Final F1 @ " + test_method + ":")
    print("Micro_F1:\t", micro_f1, "\tMacro_F1:", macro_f1)
    model.show_metapath_importance()

    if not path:
        # No full batch was processed, so there is nothing to concatenate.
        print("No full batch available; skipping visualization.")
        return

    path = torch.cat(path, dim=0).numpy()
    final_onehot = torch.cat(final_onehot, dim=0).numpy()
    # Sampling without replacement cannot draw more rows than exist.
    sample_size = min(1000, path.shape[0])
    index = np.random.choice(path.shape[0], sample_size, replace=False)
    path = path[index, :]
    final_onehot = final_onehot[index]
    pca(path, final_onehot, 3, onehot2name)