# Example 1: build TransE and TransR scorers, wrap each for negative-sampling
# training, then pretrain TransE.
# Numeric hyperparameters are named up front. The constructor calls below
# evaluate their arguments in the same order as before: every
# train_dataloader getter is called once per argument, left to right.
TRANSE_DIM = 200
TRANSE_MARGIN = 5.0
TRANSR_DIM_E = 200
TRANSR_DIM_R = 200
TRANSR_MARGIN = 4.0
P_NORM = 1

# TransE model and its negative-sampling wrapper.
transe = TransE(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=TRANSE_DIM,
    p_norm=P_NORM,
    norm_flag=True,
)
model_e = NegativeSampling(
    model=transe,
    loss=MarginLoss(margin=TRANSE_MARGIN),
    batch_size=train_dataloader.get_batch_size(),
)

# TransR model and its wrapper. rand_init=False is passed through unchanged;
# its effect is defined inside TransR and is not visible here.
transr = TransR(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim_e=TRANSR_DIM_E,
    dim_r=TRANSR_DIM_R,
    p_norm=P_NORM,
    norm_flag=True,
    rand_init=False,
)
model_r = NegativeSampling(
    model=transr,
    loss=MarginLoss(margin=TRANSR_MARGIN),
    batch_size=train_dataloader.get_batch_size(),
)

# Pretrain TransE. NOTE(review): train_times=1 is presumably a single epoch.
# Confirm against Trainer.
trainer = Trainer(
    model=model_e,
    data_loader=train_dataloader,
    train_times=1,
    alpha=0.5,
    use_gpu=True,
)
trainer.run()
# ---- Example #2 ----
# 0
# Example 2: smaller TransE and TransR configurations, trained on CPU.
# Numeric hyperparameters are named up front. Arguments are still evaluated
# in the original order, with one train_dataloader getter call per argument.
TRANSE_DIM = 50
TRANSE_MARGIN = 4.0
TRANSR_DIM_E = 50
TRANSR_DIM_R = 100
TRANSR_MARGIN = 4.0
P_NORM = 1

# TransE model and its negative-sampling wrapper.
transe = TransE(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim=TRANSE_DIM,
    p_norm=P_NORM,
    norm_flag=True,
)
model_e = NegativeSampling(
    model=transe,
    loss=MarginLoss(margin=TRANSE_MARGIN),
    batch_size=train_dataloader.get_batch_size(),
)

# TransR model and its wrapper.
# NOTE(review): transr / model_r are rebound further down in this file before
# any visible use. Constructing them may still matter (e.g. RNG state), so
# they are kept.
transr = TransR(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim_e=TRANSR_DIM_E,
    dim_r=TRANSR_DIM_R,
    p_norm=P_NORM,
    norm_flag=True,
    rand_init=False,
)
model_r = NegativeSampling(
    model=transr,
    loss=MarginLoss(margin=TRANSR_MARGIN),
    batch_size=train_dataloader.get_batch_size(),
)

# Trainer for TransE pretraining: 1000 training iterations, alpha=1.0, on CPU.
# An earlier variant of this config used alpha=0.5.
# NOTE(review): trainer.run() is not called within this section. Confirm
# whether the pretraining step was meant to run here.
trainer = Trainer(
    model=model_e,
    data_loader=train_dataloader,
    train_times=1000,
    alpha=1.0,
    use_gpu=False,
)
# Collect one path vector per relation listed in relation2id.txt, in file order.
# The first line of the file is skipped (presumably a count header; TODO
# confirm). On every remaining line, the first tab-separated field is parsed
# as an integer index into `path_vec_list`, which is defined elsewhere.
extract_path_vec_list = []
with open("./benchmarks/FKB/relation2id.txt") as f:
    f.readline()  # skip the header line
    for line in f:  # stream lines instead of materializing readlines()
        if not line.strip():
            # A blank/trailing line would otherwise crash in int('').
            continue
        extract_path_vec_list.append(path_vec_list[int(line.split('\t')[0])])
# No explicit close: the `with` statement already closed the file.

# Turn the collected vectors into a lookup table. Values are parsed as float64
# by NumPy, then cast to float32 by `.float()`.
# nn.Embedding.from_pretrained defaults to freeze=True (see PyTorch docs), so
# this table's weight starts with requires_grad=False.
rel_embedding = nn.Embedding.from_pretrained(
    torch.from_numpy(
        np.array(extract_path_vec_list).astype(dtype='float64')).float())

# Build a TransR model, load the pretrained `rel_embedding` table into it,
# wrap it for negative-sampling training, then freeze the relation-embedding
# weight so training leaves it unchanged.
transr = TransR(
    ent_tot=train_dataloader.get_ent_tot(),
    rel_tot=train_dataloader.get_rel_tot(),
    dim_e=30,
    dim_r=50,
    p_norm=1,
    norm_flag=True,
    rand_init=False,
)

# NOTE(review): presumably rel_embedding's width must match dim_r (50).
# Confirm against TransR.load_rel_embeddings.
transr.load_rel_embeddings(rel_embedding)

model_r = NegativeSampling(
    model=transr,
    loss=MarginLoss(margin=3.0),
    batch_size=train_dataloader.get_batch_size(),
)

# Parameter names from named_parameters() are unique. A dict lookup therefore
# finds the same (at most one) entry as a linear name scan. If the name is
# absent, nothing is frozen and no error is raised.
_frozen_rel_weight = dict(model_r.named_parameters()).get(
    'model.rel_embeddings.weight')
if _frozen_rel_weight is not None:
    _frozen_rel_weight.requires_grad = False

# train transr