import os

import torch

from config import config, overwrite_config_with_args
from logger_init import logger_init
from data_utils import inplace_shuffle, heads_tails
from select_gpu import select_gpu
from trans_e import TransE
from trans_d import TransD
from distmult import DistMult
from compl_ex import ComplEx

# NOTE(review): index_ent_rel, graph_size and read_data are called below but are
# not imported anywhere in the visible source; they must be brought into scope
# from the project module that defines them — TODO confirm which one.

# --- Process-wide setup -------------------------------------------------------
logger_init()
# Select the CUDA device for this process using the index returned by select_gpu().
torch.cuda.set_device(select_gpu())
# Apply argument-based overrides to the global config (semantics live in the
# config module).
overwrite_config_with_args()

# --- Index and dataset loading -------------------------------------------------
task_dir = config().task.dir
# The index is built from all three split files (train/valid/test).
kb_index = index_ent_rel(os.path.join(task_dir, 'train.txt'),
                         os.path.join(task_dir, 'valid.txt'),
                         os.path.join(task_dir, 'test.txt'))
n_ent, n_rel = graph_size(kb_index)

train_data = read_data(os.path.join(task_dir, 'train.txt'), kb_index)
# Shuffles the training vectors in place (unpacked as separate arguments).
inplace_shuffle(*train_data)
valid_data = read_data(os.path.join(task_dir, 'valid.txt'), kb_index)
test_data = read_data(os.path.join(task_dir, 'test.txt'), kb_index)
# Computed from all three splits while they are still in read_data's format,
# i.e. before the LongTensor conversion below.
heads, tails = heads_tails(n_ent, train_data, valid_data, test_data)

# Convert every vector of each evaluation split to a torch.LongTensor.
valid_data = [torch.LongTensor(vec) for vec in valid_data]
test_data = [torch.LongTensor(vec) for vec in test_data]


def tester():
    """Call ``gen.test_link`` on the validation split and return its result.

    ``gen`` (like ``valid_data``, ``n_ent``, ``heads`` and ``tails``) is looked
    up as a module global at call time. ``gen`` is not defined in this section,
    so it must be bound at module level before ``tester()`` is invoked.
    """
    return gen.test_link(valid_data, n_ent, heads, tails)


train_data = [torch.LongTensor(vec) for vec in train_data]

# The config key naming which model section to use for pretraining.
mdl_type = config().pretrain_config
gen_config = config()[mdl_type]
import os

import torch

from config import config, overwrite_config_with_args
from data_utils import inplace_shuffle
from trans_e import TransE
from trans_d import TransD
from distmult import DistMult
from compl_ex import ComplEx
from logger_init import logger_init
from select_gpu import select_gpu
from corrupter import BernCorrupterMulti

# NOTE(review): dump_config, index_ent_rel, graph_size and read_data are called
# below but are not imported anywhere in the visible source — TODO confirm the
# modules that define them and import them here.
#
# NOTE(review): logger_init(), torch.cuda.set_device(...) and
# overwrite_config_with_args() are also called at the top of this file. If both
# sections execute in the same module, this setup runs twice — confirm these
# sections are meant to be separate scripts.

# --- Process-wide setup -------------------------------------------------------
logger_init()
# Select the CUDA device for this process using the index returned by select_gpu().
torch.cuda.set_device(select_gpu())
overwrite_config_with_args()
dump_config()

# --- Index over all splits -----------------------------------------------------
task_dir = config().task.dir
kb_index = index_ent_rel(os.path.join(task_dir, 'train.txt'),
                         os.path.join(task_dir, 'valid.txt'),
                         os.path.join(task_dir, 'test.txt'))
n_ent, n_rel = graph_size(kb_index)

# --- Generator / discriminator construction ------------------------------------
# Maps the model-name strings used in the config to their classes.
models = {'TransE': TransE, 'TransD': TransD, 'DistMult': DistMult, 'ComplEx': ComplEx}
# config().g_config / config().d_config each name both the model class (a key
# of `models`) and the config section holding that model's settings.
gen_config = config()[config().g_config]
dis_config = config()[config().d_config]
gen = models[config().g_config](n_ent, n_rel, gen_config)
dis = models[config().d_config](n_ent, n_rel, dis_config)
# Initialise both models from the files named by each section's `model_file`,
# resolved relative to task_dir.
gen.load(os.path.join(task_dir, gen_config.model_file))
dis.load(os.path.join(task_dir, dis_config.model_file))

# --- Dataset loading -------------------------------------------------------------
train_data = read_data(os.path.join(task_dir, 'train.txt'), kb_index)
# Shuffles the training vectors in place (unpacked as separate arguments).
inplace_shuffle(*train_data)
valid_data = read_data(os.path.join(task_dir, 'valid.txt'), kb_index)
test_data = read_data(os.path.join(task_dir, 'test.txt'), kb_index)