예제 #1
0
 def valid_desc(self, stop_metric):
     """Score the description embeddings on the validation links.

     The two columns of ``self.kgs.valid_links`` are used to look up rows
     of ``self.e_desc``; those rows index ``self.word_em`` and the results
     are fed to the session to obtain both description embeddings, which
     are then handed to ``valid``.

     Args:
         stop_metric: ``'hits1'`` selects the first value returned by
             ``valid``; any other value selects the second one.

     Returns:
         The selected score (named ``hits1``/``mrr`` after the variables
         the result of ``valid`` is unpacked into).
     """
     print("valid desc")
     link_table = pd.DataFrame(self.kgs.valid_links)
     left_desc = self.e_desc.loc[link_table.values[:, 0]].values
     right_desc = self.e_desc.loc[link_table.values[:, 1]].values
     left_em = self.word_em[left_desc]
     right_em = self.word_em[right_desc]

     targets = [self.desc_embedding1, self.desc_embedding2]
     feed = {self.desc1: left_em, self.desc2: right_em}
     left_out, right_out = self.session.run(fetches=targets, feed_dict=feed)

     hits1, mrr = valid(
         left_out,
         right_out,
         None,
         self.args.top_k,
         self.args.test_threads_num,
         metric=self.args.eval_metric,
         normalize=self.args.eval_norm,
         csls_k=0,
         accurate=False,
     )
     if stop_metric == 'hits1':
         return hits1
     return mrr
예제 #2
0
def test(model, embed_choice='avg', w=(1, 1, 1)):
    """Evaluate one choice of entity embeddings on the test entities.

    Args:
        model: Object exposing ``session``, ``kgs``, ``args`` and the
            embedding objects ``name_embeds``, ``rv_ent_embeds``,
            ``av_ent_embeds`` and ``ent_embeds``.
        embed_choice: ``'nv'``, ``'rv'``, ``'av'`` or ``'final'`` evaluate a
            single embedding table; ``'avg'`` evaluates the ``w``-weighted
            sum of the ``nv``/``rv``/``av`` tables; any other value falls
            back to ``model.ent_embeds``.
        w: Three weights applied, in order, to the ``nv``, ``rv`` and ``av``
            tables. Only read when ``embed_choice == 'avg'``.

    Returns:
        The second value returned by ``eva.valid`` (unpacked here as
        ``mrr_12``).
    """
    if embed_choice == 'nv':
        ent_embeds = model.name_embeds.eval(session=model.session)
    elif embed_choice == 'rv':
        ent_embeds = model.rv_ent_embeds.eval(session=model.session)
    elif embed_choice == 'av':
        ent_embeds = model.av_ent_embeds.eval(session=model.session)
    elif embed_choice == 'final':
        ent_embeds = model.ent_embeds.eval(session=model.session)
    elif embed_choice == 'avg':
        ent_embeds = (w[0] * model.name_embeds.eval(session=model.session)
                      + w[1] * model.rv_ent_embeds.eval(session=model.session)
                      + w[2] * model.av_ent_embeds.eval(session=model.session))
    else:  # wavg
        ent_embeds = model.ent_embeds
        # The 'final' branch shows ``model.ent_embeds`` can be a graph object
        # that must be evaluated before the array-style row indexing below.
        # Evaluate it here too; an already-materialised array (no ``eval``
        # attribute) passes through untouched.
        if hasattr(ent_embeds, 'eval'):
            ent_embeds = ent_embeds.eval(session=model.session)

    print(embed_choice, 'test results:')
    embeds1 = ent_embeds[model.kgs.test_entities1, ]
    embeds2 = ent_embeds[model.kgs.test_entities2, ]
    hits1_12, mrr_12 = eva.valid(embeds1,
                                 embeds2,
                                 None,
                                 model.args.top_k,
                                 model.args.test_threads_num,
                                 normalize=True)
    # Drop the gathered slices before forcing a collection.
    del embeds1, embeds2
    gc.collect()
    return mrr_12
예제 #3
0
파일: rdgcn.py 프로젝트: zhenglinyi/OpenEA
 def valid_(self, stop_metric):
     """Run validation on the current output of ``self.output``.

     Rows for ``kgs.valid_entities1`` form the first embedding set; rows
     for ``kgs.valid_entities2`` followed by ``kgs.test_entities2`` form
     the second set that is passed to ``valid``.

     Args:
         stop_metric: ``'hits1'`` returns the first value produced by
             ``valid``; anything else returns the second.

     Returns:
         The selected validation score.
     """
     embedding = self.sess.run(self.output)

     def gather(entity_ids):
         # One row per entity id, stacked into a single array.
         return np.array([embedding[idx] for idx in entity_ids])

     source_rows = gather(self.kgs.valid_entities1)
     target_rows = gather(self.kgs.valid_entities2 + self.kgs.test_entities2)
     hits1, mrr = valid(
         source_rows,
         target_rows,
         None,
         self.args.top_k,
         self.args.test_threads_num,
         metric=self.args.eval_metric,
     )
     return hits1 if stop_metric == 'hits1' else mrr
예제 #4
0
 def valid(self, stop_metric):
     """Validate the embeddings produced by ``_eval_valid_embeddings``.

     Args:
         stop_metric: ``'hits1'`` returns the first value produced by the
             module-level ``valid``; anything else returns the second.

     Returns:
         The selected validation score.
     """
     source, target, mapping = self._eval_valid_embeddings()
     args = self.args
     hits1, mrr = valid(
         source,
         target,
         mapping,
         args.top_k,
         args.test_threads_num,
         metric=args.eval_metric,
         normalize=args.eval_norm,
         csls_k=0,
         accurate=False,
     )
     if stop_metric == 'hits1':
         return hits1
     return mrr
예제 #5
0
def valid_WVA(model):
    """Validate a weighted combination of the three embedding views.

    For the source side (``kgs.valid_entities1``) and the target side
    (``kgs.valid_entities2`` followed by ``kgs.test_entities2``) the rows of
    ``name_embeds``, ``rv_ent_embeds`` and ``av_ent_embeds`` are looked up
    and evaluated. ``wva`` yields three weights per side; the per-view
    weights of both sides are summed, normalised so that the three add up to
    one, and used to mix the views before calling ``eva.valid``.

    Args:
        model: Object exposing ``session``, ``kgs``, ``args`` and the three
            embedding tables named above.

    Returns:
        The second value returned by ``eva.valid`` (unpacked as ``mrr``).
    """

    def fetch(table, entity_ids):
        # Look the rows up in the graph and evaluate them in the session.
        return tf.nn.embedding_lookup(table, entity_ids).eval(
            session=model.session)

    # Source side.
    src_ids = model.kgs.valid_entities1
    nv_src = fetch(model.name_embeds, src_ids)
    rv_src = fetch(model.rv_ent_embeds, src_ids)
    av_src = fetch(model.av_ent_embeds, src_ids)
    nv_w_src, rv_w_src, av_w_src = wva(nv_src, rv_src, av_src)

    # Target side: validation entities followed by test entities.
    test_list = model.kgs.valid_entities2 + model.kgs.test_entities2
    nv_tgt = fetch(model.name_embeds, test_list)
    rv_tgt = fetch(model.rv_ent_embeds, test_list)
    av_tgt = fetch(model.av_ent_embeds, test_list)
    nv_w_tgt, rv_w_tgt, av_w_tgt = wva(nv_tgt, rv_tgt, av_tgt)

    # Combine both sides per view, then normalise in place.
    weight1 = nv_w_src + nv_w_tgt
    weight2 = rv_w_src + rv_w_tgt
    weight3 = av_w_src + av_w_tgt
    all_weight = weight1 + weight2 + weight3
    weight1 /= all_weight
    weight2 /= all_weight
    weight3 /= all_weight

    print('weights', weight1, weight2, weight3)

    embeds1 = weight1 * nv_src + weight2 * rv_src + weight3 * av_src
    embeds2 = weight1 * nv_tgt + weight2 * rv_tgt + weight3 * av_tgt
    print('wvag valid results:')
    hits1, mrr = eva.valid(
        embeds1,
        embeds2,
        None,
        model.args.top_k,
        model.args.test_threads_num,
        normalize=True,
    )

    # Release the evaluated arrays before forcing a collection.
    del nv_src, rv_src, av_src
    del nv_tgt, rv_tgt, av_tgt
    del embeds1, embeds2
    gc.collect()

    return mrr
예제 #6
0
 def valid(self):
     """Validate entity embeddings derived from attribute embeddings.

     The entity matrix comes from ``get_ent_embeds_from_attributes``; its
     rows for ``kgs.valid_entities1`` and ``kgs.valid_entities2`` are
     passed to ``evaluation.valid``.

     Returns:
         The first value produced by ``evaluation.valid`` when
         ``self.args.stop_metric`` is ``'hits1'``, otherwise the second.
     """
     entity_matrix = get_ent_embeds_from_attributes(
         self.kgs,
         self.eval_attribute_embeddings(),
         self.selected_attributes,
     )
     left = entity_matrix[self.kgs.valid_entities1, ]
     right = entity_matrix[self.kgs.valid_entities2, ]
     hits1, mrr = evaluation.valid(
         left,
         right,
         None,
         self.args.top_k,
         self.args.test_threads_num,
         metric=self.args.eval_metric,
     )
     return hits1 if self.args.stop_metric == 'hits1' else mrr
예제 #7
0
 def valid_(self, stop_metric):
     """Run validation on the ``model_se`` (and optionally ``model_ae``) outputs.

     When ``self.args.test_method`` is ``"sa"`` the two outputs are scaled
     by ``beta`` and ``1 - beta`` respectively and concatenated along
     axis 1; otherwise only the ``model_se`` output is used. Rows for
     ``kgs.valid_entities1`` are compared against rows for
     ``kgs.valid_entities2`` followed by ``kgs.test_entities2``.

     Args:
         stop_metric: ``'hits1'`` returns the first value produced by
             ``valid``; anything else returns the second.

     Returns:
         The selected validation score.
     """
     se = self.session.run(self.model_se.outputs, feed_dict=self.feed_dict_se)
     embeddings = se
     if self.args.test_method == "sa":
         ae = self.session.run(self.model_ae.outputs, feed_dict=self.feed_dict_ae)
         beta = self.args.beta
         embeddings = np.concatenate([se*beta, ae*(1.0-beta)], axis=1)

     def gather(entity_ids):
         # One row per entity id, stacked into a single array.
         return np.array([embeddings[idx] for idx in entity_ids])

     source_rows = gather(self.kgs.valid_entities1)
     target_rows = gather(self.kgs.valid_entities2 + self.kgs.test_entities2)
     hits1, mrr = valid(
         source_rows,
         target_rows,
         None,
         self.args.top_k,
         self.args.test_threads_num,
         metric=self.args.eval_metric,
     )
     return hits1 if stop_metric == 'hits1' else mrr