def transe(data_path, embedding_dims, margin_value, score_func, batch_size,
           learning_rate, n_generator, n_rank_calculator, max_epoch,
           summary_dir='../summary/', eval_freq=10):
    """Train a TransE model on the knowledge graph found in ``data_path``.

    Args:
        data_path: Directory passed to ``KnowledgeGraph(data_dir=...)``.
        embedding_dims, margin_value, score_func, batch_size, learning_rate,
        n_generator, n_rank_calculator: Forwarded unchanged to ``TransE``.
        max_epoch: Number of training epochs.
        summary_dir: TensorBoard log directory (previously hard-coded).
        eval_freq: ``launch_evaluation`` runs after every ``eval_freq``-th
            epoch (previously hard-coded to 10).

    Returns:
        ``(entity_embedding, relation_embedding)`` as returned by
        ``kge_model.check_norm`` *after* the last training epoch.
    """
    kg = KnowledgeGraph(data_dir=data_path)
    kge_model = TransE(kg=kg,
                       embedding_dim=embedding_dims,
                       margin_value=margin_value,
                       score_func=score_func,
                       batch_size=batch_size,
                       learning_rate=learning_rate,
                       n_generator=n_generator,
                       n_rank_calculator=n_rank_calculator)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        # Report the norms of the freshly initialised embeddings.
        kge_model.check_norm(session=sess)
        summary_writer = tf.summary.FileWriter(logdir=summary_dir,
                                               graph=sess.graph)
        try:
            for epoch in range(max_epoch):
                print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
                kge_model.launch_training(session=sess,
                                          summary_writer=summary_writer)
                if (epoch + 1) % eval_freq == 0:
                    kge_model.launch_evaluation(session=sess)
        finally:
            # Flush pending summary events even if training raises.
            summary_writer.close()
        # Fetch the embeddings after training; returning the values captured
        # before the first epoch would hand back the random initialisation.
        entity_embedding, relation_embedding = kge_model.check_norm(
            session=sess)
    return entity_embedding, relation_embedding
# Example no. 2
# 0
def main():
    """Train TransE from the command line.

    ``launch_evaluation`` runs every ``--eval_freq`` epochs (the default of
    10000000 exceeds the default ``--max_epoch`` of 500, so evaluation is
    effectively off by default) and ``save_embedding`` runs every
    ``--save_freq`` epochs.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='./data/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=5000)
    parser.add_argument('--learning_rate', type=float, default=0.003)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10000000)
    # Previously hard-coded to 10 inside the training loop.
    parser.add_argument('--save_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg, embedding_dim=args.embedding_dim, margin_value=args.margin_value,
                       score_func=args.score_func, batch_size=args.batch_size, learning_rate=args.learning_rate,
                       n_generator=args.n_generator, n_rank_calculator=args.n_rank_calculator)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        kge_model.check_norm(session=sess)
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir, graph=sess.graph)
        try:
            for epoch in range(args.max_epoch):
                print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
                kge_model.launch_training(session=sess, summary_writer=summary_writer)
                if (epoch + 1) % args.eval_freq == 0:
                    kge_model.launch_evaluation(session=sess)
                if (epoch + 1) % args.save_freq == 0:
                    kge_model.save_embedding(session=sess)
        finally:
            # Flush pending events so the last epochs' summaries are not lost.
            summary_writer.close()
def transe(id, data_path, embedding_dims, margin_value, score_func, batch_size,
           learning_rate, n_generator, n_rank_calculator, max_epoch, d,
           summary_dir='../summary/', eval_freq=50):
    """Train one TransE model and publish its final loss and embeddings.

    Args:
        id: Identifier forwarded to ``TransE(id=...)``; also used as the key
            under which results are stored in ``d`` and in the log line.
        data_path: Directory passed to ``KnowledgeGraph(data_dir=...)``.
        embedding_dims, margin_value, score_func, batch_size, learning_rate,
        n_generator, n_rank_calculator: Forwarded unchanged to ``TransE``.
        max_epoch: Number of training epochs.
        d: Mapping that receives ``[loss, entity_embedding,
            relation_embedding]`` under key ``id``. NOTE(review): presumably a
            shared dict (e.g. a multiprocessing Manager dict) — confirm with
            the caller.
        summary_dir: TensorBoard log directory (previously hard-coded).
        eval_freq: ``launch_evaluation`` runs after every ``eval_freq``-th
            epoch (previously hard-coded to 50).

    Returns:
        ``(entity_embedding, relation_embedding)`` after training.
    """
    kg = KnowledgeGraph(data_dir=data_path)
    kge_model = TransE(kg=kg, embedding_dim=embedding_dims, margin_value=margin_value,
                       score_func=score_func, batch_size=batch_size, learning_rate=learning_rate,
                       n_generator=n_generator, n_rank_calculator=n_rank_calculator, id=id)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        tf.global_variables_initializer().run()
        summary_writer = tf.summary.FileWriter(logdir=summary_dir, graph=sess.graph)
        try:
            for epoch in range(max_epoch):
                kge_model.launch_training(session=sess, summary_writer=summary_writer)
                if (epoch + 1) % eval_freq == 0:
                    kge_model.launch_evaluation(session=sess)
        finally:
            # Flush pending summary events even if training raises.
            summary_writer.close()
        loss, entity_embedding, relation_embedding = kge_model.check_norm(session=sess)
        d[id] = [loss, entity_embedding, relation_embedding]
        print('FB15k-{} loss:{}'.format(id, d[id][0]))
    return entity_embedding, relation_embedding
# Example no. 4
# 0
def main():
    """Parse TransE command-line options, build the model and train it.

    All training logic lives in ``TransE.train``; this entry point only maps
    CLI arguments onto the model constructor and opens a TF session.
    """
    parser = argparse.ArgumentParser(description='TransE')
    # (flag, type, default) in the order they appear in --help.
    cli_options = (
        ('--data_dir', str, '../data/FB15k/'),
        ('--embedding_dim', int, 200),
        ('--margin_value', float, 1.0),
        ('--score_func', str, 'L1'),
        ('--batch_size', int, 4800),
        ('--learning_rate', float, 0.001),
        ('--n_generator', int, 24),
        ('--n_rank_calculator', int, 24),
        ('--ckpt_dir', str, '../ckpt/'),
        ('--summary_dir', str, '../summary/'),
        ('--max_epoch', int, 500),
        ('--eval_freq', int, 10),
    )
    for flag, arg_type, default_value in cli_options:
        parser.add_argument(flag, type=arg_type, default=default_value)
    args = parser.parse_args()
    print(args)

    model_kwargs = dict(
        kg=KnowledgeGraph(data_dir=args.data_dir),
        embedding_dim=args.embedding_dim,
        margin_value=args.margin_value,
        score_func=args.score_func,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        n_generator=args.n_generator,
        n_rank_calculator=args.n_rank_calculator,
        max_epoch=args.max_epoch,
        eval_freq=args.eval_freq,
    )
    kge_model = TransE(**model_kwargs)

    sess_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=sess_config) as sess:
        kge_model.train(sess)
# Example no. 5
# 0
def main():
    """Train TransE and checkpoint after every epoch.

    ``launch_evaluation`` runs every ``--eval_freq`` epochs. After each epoch
    a ``tf.train.Saver`` checkpoint ``<model_name><global_step>.ckpt`` and a
    SavedModel export ``model-<global_step>`` are written under
    ``--ckpt_dir``.
    """
    import os

    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/FB15k/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    # Required: it prefixes every checkpoint file name. Without it the first
    # save failed with TypeError (str + None) after a full training epoch.
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg,
                       embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       model_name=args.model_name,
                       ckpt_dir=args.ckpt_dir)
    gpu_config = tf.GPUOptions(allow_growth=False)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        kge_model.check_norm(session=sess)
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                               graph=sess.graph)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=500)
        try:
            for epoch in range(args.max_epoch):
                print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
                kge_model.launch_training(session=sess,
                                          summary_writer=summary_writer)
                if (epoch + 1) % args.eval_freq == 0:
                    kge_model.launch_evaluation(session=sess, saver=saver)

                print('-----Save checkpoint-----')
                step_str = str(kge_model.global_step.eval(session=sess))
                save_path = os.path.join(args.ckpt_dir,
                                         args.model_name + step_str + '.ckpt')
                saver_path = saver.save(sess, save_path)
                tf.saved_model.simple_save(
                    sess,
                    os.path.join(args.ckpt_dir, 'model-' + step_str),
                    inputs={'triple': kge_model.eval_triple},
                    outputs={
                        'entity-embedding': kge_model.entity_embedding,
                        'relation-embedding': kge_model.relation_embedding
                    })

                print("Model saved in path: %s" % saver_path)
        finally:
            # Flush pending summary events even if training or saving raises.
            summary_writer.close()
# Example no. 6
# 0
def run():
    """Load the dataset under ``config.data_path``, build the model named by
    ``config.model`` and train it on ``cuda:0``.

    Besides the train/valid/test splits, this variant loads the
    relation-pattern test sets (symmetry, inversion, composition, other) and
    passes all seven triple lists to ``train``.
    """
    def data_file(name):
        # Path of a file inside the configured dataset directory.
        return os.path.join(config.data_path, name)

    # load data
    ent2id = read_elements(data_file("entities.dict"))
    rel2id = read_elements(data_file("relations.dict"))
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(data_file("train.txt"), ent2id, rel2id)
    valid_triples = read_triples(data_file("valid.txt"), ent2id, rel2id)
    test_triples = read_triples(data_file("test.txt"), ent2id, rel2id)
    symmetry_test = read_triples(data_file("symmetry_test.txt"), ent2id, rel2id)
    inversion_test = read_triples(data_file("inversion_test.txt"), ent2id, rel2id)
    composition_test = read_triples(data_file("composition_test.txt"),
                                    ent2id, rel2id)
    others_test = read_triples(data_file("other_test.txt"), ent2id, rel2id)
    logging.info("#ent_num: %d" % ent_num)
    logging.info("#rel_num: %d" % rel_num)
    logging.info("#train triple num: %d" % len(train_triples))
    logging.info("#valid triple num: %d" % len(valid_triples))
    logging.info("#test triple num: %d" % len(test_triples))
    logging.info("#Model: %s" % config.model)

    # Build the model. Only the selected class is instantiated; previously a
    # TransE was always built first and then discarded. Selection stays an
    # if/elif chain so classes that are not chosen are never looked up.
    model_name = config.model
    if model_name == "TransH":
        model_cls = TransH
    elif model_name == "TransR":
        model_cls = SimpleTransR
    elif model_name == "TransD":
        model_cls = TransD
    elif model_name == "STransE":
        model_cls = STransE
    elif model_name == "LineaRE":
        model_cls = LineaRE
    elif model_name == "DistMult":
        model_cls = DistMult
    elif model_name == "ComplEx":
        model_cls = ComplEx
    elif model_name == "RotatE":
        model_cls = RotatE
    elif model_name == "TransIJ":
        model_cls = TransIJ
    else:
        # Unknown names keep falling back to TransE, but no longer silently.
        if model_name != "TransE":
            logging.warning("Unknown model %r, falling back to TransE",
                            model_name)
        model_cls = TransE
    kge_model = model_cls(ent_num, rel_num)

    kge_model = kge_model.cuda(torch.device("cuda:0"))
    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s" %
                     (name, str(param.size()), str(param.requires_grad)))

    # train
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples, symmetry_test,
                   inversion_test, composition_test, others_test),
          ent_num=ent_num)
# Example no. 7
# 0
def run():
    """Load train/valid/test triples from ``config.data_path``, build the
    model named by ``config.model`` and train it (on GPU when
    ``config.cuda`` is truthy).
    """
    set_logger()

    # load data
    ent_path = os.path.join(config.data_path, "entities.dict")
    rel_path = os.path.join(config.data_path, "relations.dict")
    ent2id = read_elements(ent_path)
    rel2id = read_elements(rel_path)
    ent_num = len(ent2id)
    rel_num = len(rel2id)
    train_triples = read_triples(os.path.join(config.data_path, "train.txt"),
                                 ent2id, rel2id)
    valid_triples = read_triples(os.path.join(config.data_path, "valid.txt"),
                                 ent2id, rel2id)
    test_triples = read_triples(os.path.join(config.data_path, "test.txt"),
                                ent2id, rel2id)
    logging.info("#ent_num: %d" % ent_num)
    logging.info("#rel_num: %d" % rel_num)
    logging.info("#train triple num: %d" % len(train_triples))
    logging.info("#valid triple num: %d" % len(valid_triples))
    logging.info("#test triple num: %d" % len(test_triples))
    logging.info("#Model: %s" % config.model)

    # Build the model. Only the selected class is instantiated; previously a
    # TransE was always built first and then discarded. Selection stays an
    # if/elif chain so classes that are not chosen are never looked up.
    model_name = config.model
    if model_name == "TransH":
        model_cls = TransH
    elif model_name == "TransR":
        model_cls = TransR
    elif model_name == "TransD":
        model_cls = TransD
    elif model_name == "STransE":
        model_cls = STransE
    elif model_name == "LineaRE":
        model_cls = LineaRE
    elif model_name == "DistMult":
        model_cls = DistMult
    elif model_name == "ComplEx":
        model_cls = ComplEx
    elif model_name == "RotatE":
        model_cls = RotatE
    else:
        # Unknown names keep falling back to TransE, but no longer silently.
        if model_name != "TransE":
            logging.warning("Unknown model %r, falling back to TransE",
                            model_name)
        model_cls = TransE
    kge_model = model_cls(ent_num, rel_num)

    if config.cuda:
        kge_model = kge_model.cuda()
    logging.info("Model Parameter Configuration:")
    for name, param in kge_model.named_parameters():
        logging.info("Parameter %s: %s, require_grad = %s" %
                     (name, str(param.size()), str(param.requires_grad)))

    # train
    train(model=kge_model,
          triples=(train_triples, valid_triples, test_triples),
          ent_num=ent_num)
# Example no. 8
# 0
def main():
    """Rebuild the TransE graph and restore its weights from a checkpoint.

    No training is performed. After restoring from ``--restore_path``, the
    embedding norms of the loaded model are reported via ``check_norm``.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/FB15k/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    # Checkpoint *prefix* (no '.index' / '.data-*' suffix), as Saver.restore
    # expects. Previously hard-coded, and with a wrong '.index' suffix.
    parser.add_argument('--restore_path', type=str,
                        default='../checkpoints/gov-g2/GOV-g2130000.ckpt')
    args = parser.parse_args()
    print(args)
    # Reset BEFORE building the model. Resetting afterwards (as before) left
    # the session on an empty default graph, so the initializer ran nothing
    # and Saver() had no variables to restore.
    tf.reset_default_graph()
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg,
                       embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       model_name=args.model_name,
                       ckpt_dir=args.ckpt_dir)
    gpu_config = tf.GPUOptions(allow_growth=False)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        print('-----Initialization accomplished-----')
        saver = tf.train.Saver()
        saver.restore(sess, args.restore_path)
        print('-----Model Loaded-----')
        # Report norms of the restored (not the randomly initialised) weights.
        kge_model.check_norm(session=sess)
# Example no. 9
# 0
def main():
    """Train TransE on ``--data_dir`` and mirror progress messages to ``--log_file``.

    The log file is truncated at start-up. Every progress message is printed
    and appended to it with a ``\\r\\n`` terminator. A checkpoint
    ``after_big.ckpt-<epoch>`` is saved under ``--ckpt_dir`` after every
    epoch (``max_to_keep=1``), and the model is evaluated every
    ``--eval_freq`` epochs.
    """
    import os

    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/after_big/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--eval_batch_size', type=int, default=200)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=1000)
    parser.add_argument('--eval_freq', type=int, default=200)
    parser.add_argument('--log_file',
                        type=str,
                        default='../log/log_after_big.txt')

    args = parser.parse_args()
    print(args)

    def report(*lines):
        # Print each line, then append all of them ('\r\n'-terminated) to the log file.
        for line in lines:
            print(line)
        with open(args.log_file, 'a') as log_fh:
            log_fh.writelines(line + '\r\n' for line in lines)

    # Start from an empty log file.
    with open(args.log_file, 'w'):
        pass

    fb15k = Dataset(data_dir=args.data_dir, log_file=args.log_file)
    kge_model = TransE(dataset=fb15k,
                       embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       eval_batch_size=args.eval_batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       log_file=args.log_file)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    with tf.Session(config=sess_config) as sess:
        report('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()
        report('-----Initialization accomplished-----', '----Check norm----')
        entity_embedding = kge_model.entity_embedding.eval(session=sess)
        relation_embedding = kge_model.relation_embedding.eval(session=sess)
        # Per-row L2 norms of the embedding matrices.
        entity_norm = np.linalg.norm(entity_embedding, ord=2, axis=1)
        relation_norm = np.linalg.norm(relation_embedding, ord=2, axis=1)
        report('entity norm: {} relation norm: {}'.format(
            entity_norm, relation_norm))
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                               graph=sess.graph)
        saver = tf.train.Saver(max_to_keep=1)
        # Honour --ckpt_dir (previously ignored); the default yields the
        # same '../ckpt/after_big.ckpt' prefix as before.
        ckpt_prefix = os.path.join(args.ckpt_dir, 'after_big.ckpt')
        try:
            for epoch in range(args.max_epoch):
                report('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
                kge_model.launch_training(session=sess,
                                          summary_writer=summary_writer)
                saver.save(sess, ckpt_prefix, global_step=epoch + 1)
                if (epoch + 1) % args.eval_freq == 0:
                    kge_model.launch_evaluation(session=sess)
        finally:
            # Flush pending summary events even if training raises.
            summary_writer.close()
# Example no. 10
# 0
def main():
    """Parse TransE CLI options, build the knowledge graph and model, and open a TF session.

    NOTE(review): this function is truncated in this file. Its body ends at an
    unterminated triple-quoted string, so whatever training loop follows is
    not visible here.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--data_dir', type=str, default='../data/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    # The no-op string below ("passing parameters") shows a sample of the
    # parsed Namespace. Note its data_dir differs from this parser's default.
    '''
    传递参数
    Namespace(
        batch_size=4800, 
        ckpt_dir='../ckpt/', 
        data_dir='../data/FB15k/', 
        embedding_dim=200, 
        eval_freq=10, learning_rate=0.001, 
        margin_value=1.0, 
        max_epoch=500, 
        n_generator=24, 
        n_rank_calculator=24, 
        score_func='L1', 
        //summary_dir='../summary/')
    '''
    kg = KnowledgeGraph(data_dir=args.data_dir)  # build the knowledge graph from data_dir
    kge_model = TransE(
        kg=kg,
        embedding_dim=args.embedding_dim,
        margin_value=args.margin_value,
        score_func=args.score_func,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        n_generator=args.n_generator,
        n_rank_calculator=args.n_rank_calculator
        )  # original note: "init embd..." — embeddings presumably created in TransE; confirm there
    gpu_config = tf.GPUOptions(allow_growth=True)
    # The no-op string below says: tf.GPUOptions can be passed as an option
    # when building tf.ConfigProto, usually to limit GPU resource usage;
    # allow_growth=True allocates GPU memory on demand.
    '''
    tf.GPUOptions:可以作为设置tf.ConfigProto时的一个参数选项,一般用于限制GPU资源的使用
    allow_growth=True:动态申请现显存
    '''
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    # The no-op string below says: tf.ConfigProto configures the session's
    # parameters when the session is created.
    '''
    tf.ConfigProto:
    创建session的时候,用来对session进行参数配置
    '''
    with tf.Session(config = sess_config) as sess:
        # The no-op string below says: a Session is TensorFlow's statement for
        # controlling and executing the graph; session.run() returns the
        # results you ask for.
        '''
        Session 是 Tensorflow 为了控制,和输出文件的执行的语句. 运行 session.run() 可以获得你要得知的运算结果
        '''
        print('-----Initializing tf graph-----')
        tf.global_variables_initializer().run()  # runs the assign ops of all global variables - this is what parameter initialization actually is
        print('-----Initialization accomplished-----')
        kge_model.check_norm(session=sess)
        summary_writer = tf.summary.FileWriter(logdir=args.summary_dir, graph=sess.graph)
        # print(type(sess.graph))
        # print(type(summary_writer))
        # NOTE(review): the triple-quoted string opened on the next line is
        # never closed in this file; the rest of this function appears truncated.
        '''
# Example no. 11
# 0
        link_dic = {}
        entity_idx = 0
        link_idx = 0
        for row in train_tsv_reader:
            head = row[0]
            link = row[1]
            tail = row[2]
            if not head in entity_dic:
                entity_dic[head] = entity_idx
                entity_idx += 1

            if not tail in entity_dic:
                entity_dic[tail] = entity_idx
                entity_idx += 1

            if not link in link_dic:
                link_dic[link] = link_idx
                link_idx += 1
            data.append((entity_dic[head], link_dic[link], entity_dic[tail]))
        return data, entity_dic, link_dic


import os

train_data, entity_dic, link_dic = load(TRAIN_DATASET_PATH)

# Create the output directory up front, so a missing 'models/' folder cannot
# make the final pickle.dump fail after the whole training run.
os.makedirs('models', exist_ok=True)

# training
# NOTE(review): the positional TransE arguments (1, 50, 0.01, 50) are passed
# through unchanged; confirm their meaning against TransE's signature.
model = TransE(len(entity_dic), len(link_dic), 1, 50, 0.01, 50)
model.fit(np.array(train_data))

# NOTE: unpickling executes arbitrary code, so only load this file from a
# trusted source.
with open('models/transe.pkl', 'wb') as f:
    pickle.dump(model, f)
# Example no. 12
# 0
 def prepareModel(self):
     """Create ``self.model`` as a TransE with 100-dimensional embeddings over
     ``self.n_entities`` entities and ``self.n_relations`` relations, and call
     ``.cuda()`` on it when the ``GPU`` flag is truthy."""
     print("Perpare model")
     model = TransE(self.n_entities, self.n_relations, embDim=100)
     self.model = model
     if GPU:
         model.cuda()
# Example no. 13
# 0
def main():
    """Train, evaluate, or run prediction with a TransE model.

    ``--mode`` selects the action:
      * ``train``   initialise the graph and call ``launch_training`` for
                    ``--max_epoch`` epochs (the saver is handed to it);
      * ``eval``    restore the latest checkpoint in ``--ckpt_dir`` and run
                    ``launch_evaluation``;
      * ``predict`` restore the latest checkpoint and run
                    ``launch_prediction``.
    Any other mode prints an error and returns before any data is loaded.

    Raises:
        FileNotFoundError: in ``eval``/``predict`` mode when ``--ckpt_dir``
            contains no checkpoint.
    """
    parser = argparse.ArgumentParser(description='TransE')
    parser.add_argument('--mode', type=str, default='eval')
    parser.add_argument('--data_dir', type=str, default='../data/FB15k/')
    parser.add_argument('--embedding_dim', type=int, default=200)
    parser.add_argument('--margin_value', type=float, default=1.0)
    parser.add_argument('--score_func', type=str, default='L1')
    parser.add_argument('--batch_size', type=int, default=4800)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--n_generator', type=int, default=24)
    parser.add_argument('--n_rank_calculator', type=int, default=24)
    parser.add_argument('--hit_at_n', type=int, default=10)
    parser.add_argument('--ckpt_dir', type=str, default='../ckpt/')
    parser.add_argument('--summary_dir', type=str, default='../summary/')
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--eval_freq', type=int, default=10)
    args = parser.parse_args()
    print(args)
    if args.mode not in ('train', 'eval', 'predict'):
        # Reject bad modes before the slow knowledge-graph load. The old
        # message advertised a 'test' mode that was never handled.
        print('Wrong mode!! (mode = train|eval|predict)')
        return
    kg = KnowledgeGraph(data_dir=args.data_dir)
    kge_model = TransE(kg=kg,
                       model_path=args.ckpt_dir,
                       embedding_dim=args.embedding_dim,
                       margin_value=args.margin_value,
                       score_func=args.score_func,
                       batch_size=args.batch_size,
                       learning_rate=args.learning_rate,
                       n_generator=args.n_generator,
                       n_rank_calculator=args.n_rank_calculator,
                       hit_at_n=args.hit_at_n)
    gpu_config = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_config)
    saver = tf.train.Saver(tf.global_variables())

    def restore_latest(sess):
        # Restore the newest checkpoint in --ckpt_dir into sess.
        print('-----Loading from checkpoints-----')
        ckpt_file = tf.train.latest_checkpoint(args.ckpt_dir)
        if ckpt_file is None:
            raise FileNotFoundError(
                'no checkpoint found in {!r}'.format(args.ckpt_dir))
        saver.restore(sess, ckpt_file)

    with tf.Session(config=sess_config) as sess:
        if args.mode == 'eval':
            restore_latest(sess)
            kge_model.launch_evaluation(session=sess)
        elif args.mode == 'train':
            print('-----Initializing tf graph-----')
            tf.global_variables_initializer().run()
            print('-----Initialization accomplished-----')
            kge_model.check_norm(session=sess)
            summary_writer = tf.summary.FileWriter(logdir=args.summary_dir,
                                                   graph=sess.graph)
            try:
                for epoch in range(args.max_epoch):
                    print('=' * 30 + '[EPOCH {}]'.format(epoch) + '=' * 30)
                    kge_model.launch_training(epoch=epoch,
                                              session=sess,
                                              summary_writer=summary_writer,
                                              saver=saver)
            finally:
                # Flush pending summary events even if training raises.
                summary_writer.close()
        else:  # 'predict'
            restore_latest(sess)
            kge_model.launch_prediction(session=sess)
# Example no. 14
# 0
def main():
    """Train a TransE or DistMult model on ``data/modified_triples.txt``.

    Uses a margin-ranking loss (plus the model's ``dist`` term) on CUDA.
    ``model.pt`` is written every ``opts.save_step`` epochs and at the end,
    followed by ``save_embeddings``.

    Raises:
        ValueError: if ``opts.model_type`` or ``opts.optimizer`` is not one
            of the supported values (previously an UnboundLocalError).
    """
    opts = get_train_args()
    print("load data ...")
    data = DataSet('data/modified_triples.txt')
    dataloader = DataLoader(data, shuffle=True, batch_size=opts.batch_size)
    print("load model ...")
    if opts.model_type == 'transe':
        model = TransE(opts, data.ent_tot, data.rel_tot)
    elif opts.model_type == "distmult":
        model = DistMult(opts, data.ent_tot, data.rel_tot)
    else:
        raise ValueError('unsupported model_type: {!r}'.format(opts.model_type))
    # Move to the GPU before building the optimizer, as PyTorch recommends.
    model.cuda()
    if opts.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=opts.lr)
    elif opts.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opts.lr)
    else:
        raise ValueError('unsupported optimizer: {!r}'.format(opts.optimizer))
    model.relation_normalize()
    loss = torch.nn.MarginRankingLoss(margin=opts.margin)
    n_samples = len(data)

    print("start training")
    for epoch in range(1, opts.epochs + 1):
        print("epoch : " + str(epoch))
        model.train()
        epoch_start = time.time()
        epoch_loss = 0.0
        tot = 0
        cnt = 0
        for batch_data in dataloader:
            optimizer.zero_grad()
            batch_h, batch_r, batch_t, batch_n = batch_data
            batch_h = torch.LongTensor(batch_h).cuda()
            batch_r = torch.LongTensor(batch_r).cuda()
            batch_t = torch.LongTensor(batch_t).cuda()
            batch_n = torch.LongTensor(batch_n).cuda()
            pos_score, neg_score, dist = model.forward(batch_h, batch_r,
                                                       batch_t, batch_n)
            pos_score = pos_score.cpu()
            neg_score = neg_score.cpu()
            dist = dist.cpu()
            # Target 1 means pos_score should rank above neg_score.
            train_loss = loss(pos_score, neg_score,
                              torch.ones(pos_score.size(-1))) + dist
            train_loss.backward()
            optimizer.step()
            # .item() yields a plain float; accumulating the tensor itself
            # kept every batch's autograd graph alive for the whole epoch.
            batch_loss = torch.sum(train_loss).item()
            epoch_loss += batch_loss
            tot += batch_h.size(0)
            cnt += 1
            print('\r{:>10} epoch {} progress {} loss: {}\n'.format(
                '', epoch, tot / n_samples, batch_loss),
                  end='')
        time_used = time.time() - epoch_start
        if cnt:  # an empty dataloader would otherwise divide by zero
            epoch_loss /= cnt
        print('one epoch time: {} minutes'.format(time_used / 60))
        print('{} epochs'.format(epoch))
        print('epoch {} loss: {}'.format(epoch, epoch_loss))

        if epoch % opts.save_step == 0:
            print("save model...")
            model.entity_normalize()
            torch.save(model.state_dict(), 'model.pt')

    print("save model...")
    model.entity_normalize()
    torch.save(model.state_dict(), 'model.pt')
    print("[Saving embeddings of whole entities & relations...]")
    save_embeddings(model, opts, data.id2ent, data.id2rel)
    print("[Embedding results are saved successfully.]")