model_r = NegativeSampling(model=transr,
                           loss=MarginLoss(margin=3.0),
                           batch_size=train_dataloader.get_batch_size())

# freeze the relation embeddings so they are not updated during training
for k, v in model_r.named_parameters():
    if k == 'model.rel_embeddings.weight':
        v.requires_grad = False

# train transr
# transr.set_parameters(parameters)
trainer = Trainer(model=model_r,
                  data_loader=train_dataloader,
                  train_times=200,
                  alpha=1.0,
                  use_gpu=False)
trainer.run()
transr.save_checkpoint('./checkpoint/fault_dataset_transr.ckpt')

epoch = trainer.epoch
loss = trainer.loss
line = plt.plot(epoch, loss, label=u'TransR')
plt.xlabel(u'epoch')
plt.ylabel(u'loss')
plt.legend()
# save before show(); saving after the window is closed can produce a blank image
plt.savefig("./embedding/TransR/loss_epoch.png")
plt.show()

# pull the learned embedding tensors out of the trained TransR model
entity = trainer.model.model.ent_embeddings.weight.data
relationship = trainer.model.model.rel_embeddings.weight.data
transfer_matrix = trainer.model.model.transfer_matrix.weight.data

entity_np = np.array(entity)
np.savetxt('./embedding/TransR/fault_dataset_entity_result.txt', entity_np)
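# Illustrative follow-up (not part of the original snippet): the relation
# embeddings and projection matrices extracted above can be persisted the same
# way as the entity embeddings; the file names below are placeholders.
relationship_np = np.array(relationship)
transfer_matrix_np = np.array(transfer_matrix)
np.savetxt('./embedding/TransR/fault_dataset_relation_result.txt', relationship_np)
np.savetxt('./embedding/TransR/fault_dataset_transfer_matrix_result.txt', transfer_matrix_np)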
Example #2
                  alpha=1.0,
                  use_gpu=False)
trainer.run()
parameters = transe.get_parameters()
transe.save_parameters("./result/transr_transe.json")

# train transr
# transr.set_parameters(parameters)
# reuse the pretrained TransE entity embeddings to initialize TransR
transr.ent_embeddings = transe.ent_embeddings
trainer = Trainer(model=model_r,
                  data_loader=train_dataloader,
                  train_times=1000,
                  alpha=0.1,
                  use_gpu=False)
trainer.run()
transr.save_checkpoint('./checkpoint/transr.ckpt')

epoch = trainer.epoch
loss = trainer.loss
line = plt.plot(epoch, loss, label=u'TransR')
plt.xlabel(u'epoch')
plt.ylabel(u'loss')
plt.legend()
# save before show(); saving after the window is closed can produce a blank image
plt.savefig("./embedding/TransR/loss_epoch.png")
plt.show()

entity = trainer.model.model.ent_embeddings.weight.data
relationship = trainer.model.model.rel_embeddings.weight.data

entity_np = np.array(entity)
np.savetxt('./embedding/TransR/entity_result.txt', entity_np)
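# Possible downstream use (illustrative, not from the original example): the
# embeddings saved with np.savetxt can be reloaded as a NumPy array for
# clustering, visualization, or nearest-neighbour queries.
entity_loaded = np.loadtxt('./embedding/TransR/entity_result.txt')
print(entity_loaded.shape)  # expected shape: (num_entities, embedding_dim)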
Example #3
transr = TransR(
	ent_tot = train_dataloader.get_ent_tot(),
	rel_tot = train_dataloader.get_rel_tot(),
	dim_e = 20,
	dim_r = 20,
	p_norm = 1,
	norm_flag = True,
	rand_init = False)

model_r = NegativeSampling(
	model = transr,
	loss = MarginLoss(margin = 4.0),
	batch_size = train_dataloader.get_batch_size()
)

# pretrain transe
trainer = Trainer(model = model_e, data_loader = train_dataloader, train_times = 1, alpha = 0.5, use_gpu = False)
trainer.run()
parameters = transe.get_parameters()
transe.save_parameters("./result/transr_transe_lumb.json")

# train transr
transr.set_parameters(parameters)
trainer = Trainer(model = model_r, data_loader = train_dataloader, train_times = 100, alpha = 1.0, use_gpu = False)
trainer.run()
transr.save_checkpoint('./checkpoint/transr_lumb2.ckpt')

# test the model
# transr.load_checkpoint('./checkpoint/transr.ckpt')
# tester = Tester(model = transr, data_loader = test_dataloader, use_gpu = False)
# tester.run_link_prediction(type_constrain = False)
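# A runnable sketch of the evaluation step outlined in the comments above,
# assuming a test_dataloader has been built (its construction is not shown in
# this excerpt) and using the checkpoint saved by this example:
transr.load_checkpoint('./checkpoint/transr_lumb2.ckpt')
tester = Tester(model = transr, data_loader = test_dataloader, use_gpu = False)
tester.run_link_prediction(type_constrain = False)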