Code example #1
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import string
import re
import math

from args import read_args

# Command-line arguments are parsed once at module import time.
args = read_args()


class HetAgg(nn.Module):
	def __init__(self, args, feature_list, a_neigh_list_train, p_neigh_list_train, v_neigh_list_train,\
		 a_train_id_list, p_train_id_list, v_train_id_list):
		super(HetAgg, self).__init__()
		embed_d = args.embed_d
		in_f_d = args.in_f_d
		self.args = args
		# Node counts per type: papers (P), authors (A), venues (V).
		self.P_n = args.P_n
		self.A_n = args.A_n
		self.V_n = args.V_n
		# Pre-computed input features and sampled neighbor lists per node type.
		self.feature_list = feature_list
		self.a_neigh_list_train = a_neigh_list_train
		self.p_neigh_list_train = p_neigh_list_train
		self.v_neigh_list_train = v_neigh_list_train
		# Ids of the training nodes for each node type.
		self.a_train_id_list = a_train_id_list
		self.p_train_id_list = p_train_id_list
		self.v_train_id_list = v_train_id_list

		#self.fc_a_agg = nn.Linear(embed_d * 4, embed_d)
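
The excerpt stops inside __init__, just before the aggregation layers would be defined; the commented-out fc_a_agg line hints at HetGNN's final combining step, where a node's content embedding and the three per-type neighbor aggregations (author, paper, venue) are concatenated and projected back down to embed_d. A minimal, self-contained sketch of that idea (hypothetical shapes and names, not the project's exact code):

import torch
import torch.nn as nn

embed_d = 128  # stands in for args.embed_d

# Hypothetical: a content embedding plus three per-type neighbor
# aggregations for a batch of 32 author nodes.
c_agg, a_agg, p_agg, v_agg = (torch.randn(32, embed_d) for _ in range(4))

fc_a_agg = nn.Linear(embed_d * 4, embed_d)  # the layer the comment suggests
combined = fc_a_agg(torch.cat([c_agg, a_agg, p_agg, v_agg], dim=1))
print(combined.shape)  # torch.Size([32, 128])
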
Code example #2
File: main.py Project: mywasd/MHGNN_2020
import torch
import dgl
from sklearn.model_selection import train_test_split

import dataprocess_han
from args import read_args
# The imports below are not shown in the snippet; the paths for HAN,
# EarlyStopping, get_binary_mask, score and evaluate follow the DGL HAN
# example this project is modeled on, so treat them as assumptions.
from model import HAN
from utils import EarlyStopping, get_binary_mask, score, evaluate


def main(args):
    # If args['hetero'] is True, g would be a heterogeneous graph;
    # otherwise it is a list of homogeneous (one-per-metapath) graphs.
    args_academic = read_args()
    data = dataprocess_han.input_data_han(args_academic)
    #g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
    #val_mask, test_mask = load_data(args['dataset'])
    # Author text embeddings serve as input features; labels are the
    # authors' class ids.
    features = torch.tensor(data.a_text_embed, dtype=torch.float32)
    labels = torch.tensor(data.a_class)

    # One homogeneous author-author graph per metapath (APA, APVPA, APPA);
    # the ntype/etype keyword arguments follow the older DGL 0.4-style API.
    APA_g = dgl.graph(data.APA_matrix, ntype='author', etype='coauthor')
    APVPA_g = dgl.graph(data.APVPA_matrix, ntype='author', etype='attendance')
    APPA_g = dgl.graph(data.APPA_matrix, ntype='author', etype='reference')

    #g = [APA_g, APPA_g]
    g = [APA_g, APVPA_g, APPA_g]

    num_classes = 4
    features = features.to(args['device'])
    labels = labels.to(args['device'])

    #if args['hetero']:
    #from model_hetero import HAN
    #model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']],
    #in_size=features.shape[1],
    #hidden_size=args['hidden_units'],
    #out_size=num_classes,
    #num_heads=args['num_heads'],
    #dropout=args['dropout']).to(args['device'])
    #else:
    model = HAN(num_meta_paths=len(g),
                in_size=features.shape[1],
                hidden_size=args['hidden_units'],
                out_size=num_classes,
                num_heads=args['num_heads'],
                dropout=args['dropout']).to(args['device'])

    stopper = EarlyStopping(patience=args['patience'])
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args['lr'],
                                 weight_decay=args['weight_decay'])
    # Warm-start from a previously saved checkpoint; this raises
    # FileNotFoundError if ./model_para.pt does not exist yet.
    model.load_state_dict(torch.load("./model_para.pt"))

    for epoch in range(args['num_epochs']):

        # Re-sample the splits each epoch: keep 20% of all author ids, then
        # hold out 20% of that subset as the test set.
        X = [[i] for i in range(args_academic.A_n)]
        train_X, test_X, _, _ = train_test_split(X, X, test_size=0.8)
        train_X, test_X, _, _ = train_test_split(train_X,
                                                 train_X,
                                                 test_size=0.2)

        train_mask = get_binary_mask(args_academic.A_n, train_X)
        test_mask = get_binary_mask(args_academic.A_n, test_X)

        #train_mask = torch.tensor(data.train_mask)
        #test_mask = torch.tensor(data.test_mask)
        val_mask = test_mask  # validation reuses the test split
        train_mask = train_mask.to(args['device'])
        val_mask = val_mask.to(args['device'])
        test_mask = test_mask.to(args['device'])
        model.train()
        logits, _ = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc, train_micro_f1, train_macro_f1 = score(
            logits[train_mask], labels[train_mask])
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
            model, g, features, labels, val_mask, loss_fcn)
        early_stop = stopper.step(val_loss.item(), val_acc, model)

        print(
            'Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
            'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.
            format(epoch + 1, loss.item(), train_micro_f1, train_macro_f1,
                   val_loss.item(), val_micro_f1, val_macro_f1))

        if early_stop:
            break

    stopper.load_checkpoint(model)
    model.eval()
    _, embedding = model(g, features)
    # Dump the final node embeddings, one author per line: "a<id> v0 v1 ... vd".
    with open("./node_embedding.txt", "w") as embed_file:
        for k in range(embedding.shape[0]):
            values = [str(v.item()) for v in embedding[k]]
            embed_file.write('a' + str(k) + " " + " ".join(values) + "\n")
    #test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)
    #print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
    #test_loss.item(), test_micro_f1, test_macro_f1))
    # Persist the trained parameters for the next run's warm start.
    torch.save(model.state_dict(), "./model_para.pt")
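
The loop above depends on a get_binary_mask helper that the snippet never defines. A plausible sketch, modeled on the utility of the same name in the DGL HAN example (an assumption, not necessarily this project's exact code):

import torch

def get_binary_mask(total_size, indices):
    # Boolean mask with True at the given index positions; nested index
    # lists like [[0], [3]] are flattened first.
    mask = torch.zeros(total_size, dtype=torch.bool)
    mask[torch.tensor(indices).flatten()] = True
    return mask

# get_binary_mask(5, [[0], [3]]) -> tensor([ True, False, False,  True, False])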