예제 #1
0
from config import config, overwrite_config_with_args
from logger_init import logger_init
from data_utils import inplace_shuffle, heads_tails
from select_gpu import select_gpu
from trans_e import TransE
from trans_d import TransD
from distmult import DistMult
from compl_ex import ComplEx

# Process-wide setup. These calls run at import time, in this order.
# NOTE(review): logger_init presumably configures logging — confirm in logger_init.py.
logger_init()
# Make the device returned by select_gpu() the current CUDA device.
torch.cuda.set_device(select_gpu())
# NOTE(review): presumably applies command-line overrides to the object
# returned by config() — confirm in the config module.
overwrite_config_with_args()

# Directory of the current task, taken from the configuration.
task_dir = config().task.dir

# Index entities/relations over the train, validation and test split files
# (passed in that order), then read the graph dimensions from the index.
kb_index = index_ent_rel(
    *(os.path.join(task_dir, split_file)
      for split_file in ('train.txt', 'valid.txt', 'test.txt'))
)
n_ent, n_rel = graph_size(kb_index)

# Read the training split with the shared kb_index and shuffle it in place
# before anything else is derived from it.
train_data = read_data(os.path.join(task_dir, 'train.txt'), kb_index)
inplace_shuffle(*train_data)
valid_data = read_data(os.path.join(task_dir, 'valid.txt'), kb_index)
test_data = read_data(os.path.join(task_dir, 'test.txt'), kb_index)
# heads_tails is computed from all three splits while they are still in the
# form returned by read_data (i.e. before the tensor conversion below).
heads, tails = heads_tails(n_ent, train_data, valid_data, test_data)
valid_data = [torch.LongTensor(vec) for vec in valid_data]
test_data = [torch.LongTensor(vec) for vec in test_data]


def tester():
    """Call ``gen.test_link`` on the validation split and return its result.

    ``gen`` (like ``valid_data``, ``n_ent``, ``heads`` and ``tails``) is a
    module-level name looked up when this function is called, not when it is
    defined, so ``gen`` may be created later in the script.
    """
    return gen.test_link(valid_data, n_ent, heads, tails)


train_data = [torch.LongTensor(vec) for vec in train_data]

# Value of the `pretrain_config` setting; it is used below as a key into the
# configuration. NOTE(review): presumably the name of the model to pretrain
# (one of the model classes imported above) — confirm.
mdl_type = config().pretrain_config
# Configuration section stored under that key.
gen_config = config()[mdl_type]
예제 #2
0
파일: gan_train.py 프로젝트: cai-lw/KBGAN
from trans_d import TransD
from distmult import DistMult
from compl_ex import ComplEx
from logger_init import logger_init
from select_gpu import select_gpu
from corrupter import BernCorrupterMulti


# Process-wide setup. These calls run at import time, in this order.
# NOTE(review): logger_init presumably configures logging — confirm in logger_init.py.
logger_init()
# Make the device returned by select_gpu() the current CUDA device.
torch.cuda.set_device(select_gpu())
# NOTE(review): presumably applies command-line overrides to the object
# returned by config() — confirm in the config module.
overwrite_config_with_args()
# NOTE(review): presumably records the effective configuration after the
# overrides above; destination not visible here — confirm.
dump_config()

# Directory of the current task, taken from the configuration.
task_dir = config().task.dir

# Index entities/relations over the train, validation and test split files
# (passed in that order), then read the graph dimensions from the index.
kb_index = index_ent_rel(
    *(os.path.join(task_dir, split_file)
      for split_file in ('train.txt', 'valid.txt', 'test.txt'))
)
n_ent, n_rel = graph_size(kb_index)

# Lookup table from model name to model class.
models = {'TransE': TransE, 'TransD': TransD, 'DistMult': DistMult, 'ComplEx': ComplEx}
# config().g_config and config().d_config are used twice each: as the key of
# a configuration section (here) and as the key into `models` (below).
# NOTE(review): gen/dis presumably stand for GAN generator/discriminator — confirm.
gen_config = config()[config().g_config]
dis_config = config()[config().d_config]
# Both models are constructed with the same entity and relation counts and
# their own configuration section.
gen = models[config().g_config](n_ent, n_rel, gen_config)
dis = models[config().d_config](n_ent, n_rel, dis_config)
# Each model loads from the file named by its config's `model_file`, resolved
# relative to task_dir. NOTE(review): presumably pretrained weights — confirm.
gen.load(os.path.join(task_dir, gen_config.model_file))
dis.load(os.path.join(task_dir, dis_config.model_file))

# Read the training split using the shared kb_index.
train_data = read_data(os.path.join(task_dir, 'train.txt'), kb_index)
# The components of train_data are passed as separate arguments.
# NOTE(review): presumably shuffled in place with one shared permutation so
# they stay aligned — confirm in data_utils.inplace_shuffle.
inplace_shuffle(*train_data)
# Validation and test splits are read the same way and are not shuffled here.
valid_data = read_data(os.path.join(task_dir, 'valid.txt'), kb_index)
test_data = read_data(os.path.join(task_dir, 'test.txt'), kb_index)