Exemplo n.º 1
0
def get_type13_graph_optimizaton(kbc_path,
                                 dataset,
                                 dataset_mode,
                                 similarity_metric='l2',
                                 t_norm='min'):
    """Optimize '1_3' chain queries and evaluate the resulting ranking.

    The environment is loaded with ``preload_env`` for the chain type
    ``'1_3'``; the model's ``type1_3chain_optimize`` is then run for at most
    1000 steps.  One string key is built per ranked row and handed, together
    with the ranking and the target ids, to ``hits_at_k`` and
    ``average_percentile_rank``.  The values returned by those two helpers
    (and by ``norm_comparison``) are not used here.

    Args:
        kbc_path: forwarded to ``preload_env``.
        dataset: forwarded to ``preload_env``.
        dataset_mode: forwarded to ``preload_env``.
        similarity_metric: forwarded to ``type1_3chain_optimize``.
        t_norm: forwarded to ``type1_3chain_optimize``.

    Returns:
        The tuple ``(obj_guess_raw, closest_map)`` produced by the optimizer,
        or ``None`` when a ``RuntimeError`` was raised (the error is printed).
    """
    try:
        env = preload_env(kbc_path, dataset, dataset_mode, '1_3')
        first_part, second_part, third_part = env.parts
        target_ids = env.target_ids
        lhs_norm = env.lhs_norm
        kbc = env.kbc
        chains = env.chains

        optimized = kbc.model.type1_3chain_optimize(
            chains,
            kbc.regularizer,
            max_steps=1000,
            similarity_metric=similarity_metric,
            t_norm=t_norm,
        )
        obj_guess_raw, closest_map, ranked_indices = optimized

        # Only the call matters here; both returned norms are discarded.
        _lhs_norm, _guess_norm = norm_comparison(lhs_norm, obj_guess_raw)

        # One key per ranked row: six fields taken from the three chain
        # parts, joined with underscores.
        keys = []
        for row in range(len(ranked_indices)):
            head, middle, tail = first_part[row], second_part[row], third_part[row]
            fields = (head[0], head[1], middle[0], middle[1], tail[1], tail[2])
            keys.append('_'.join([str(field.item()) for field in fields]))

        hits_at_k(ranked_indices, target_ids, keys, hits=[1, 3, 5, 10, 20])
        average_percentile_rank(ranked_indices, target_ids, keys)
    except RuntimeError as err:
        print(err)
        return None
    else:
        return obj_guess_raw, closest_map
Exemplo n.º 2
0
def _analyze_predictions(train_graph, iterator, candidate_scores, labels,
                         model, is_train_ph, inputs, local_init_op):
  """Write per-example predictions and top attended neighbors to text files.

  Restores the checkpoint ``model.ckpt-<FLAGS.global_step>`` from
  ``FLAGS.model_path``, iterates the evaluation dataset once and, for every
  example, appends a record to ``analyze_correct.txt`` or
  ``analyze_incorrect.txt`` in ``FLAGS.output_dir`` depending on whether the
  arg-max candidate equals the first label.  Each record lists the (up to)
  five neighbors with the highest attention probability.

  Args:
    train_graph: graph object providing ``inverse_entity_vocab``,
      ``inverse_relation_vocab`` and (for text neighbors)
      ``get_relation_text``.
    iterator: initializable dataset iterator feeding the model.
    candidate_scores: tensor of scores, one row per example.
    labels: tensor of targets; only ``labels[i][0]`` is compared.
    model: mapping containing the ``"attention_encoder"`` entry.
    is_train_ph: placeholder fed with ``False`` for every run.
    inputs: tuple of model input tensors; its arity depends on whether
      ``FLAGS.clueweb_data`` / ``FLAGS.text_kg_file`` are set.
    local_init_op: op initializing local variables.

  Raises:
    ValueError: if an attention index points past ``FLAGS.max_neighbors``
      (i.e. to a text neighbor) while text neighbors were not fetched.
  """
  entity_names = utils.read_entity_name_mapping(FLAGS.entity_names_file)

  def display_name(entity_id):
    """Map an entity id to its readable name when one is available."""
    name = train_graph.inverse_entity_vocab[entity_id]
    if name in entity_names:
      name = entity_names[name]
    return name

  num_attention = 5
  # The context manager guarantees the session is closed on any exit path.
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    session.run(local_init_op)
    saver = tf.train.Saver(tf.trainable_variables())
    ckpt_path = FLAGS.model_path + "/model.ckpt-%d" % FLAGS.global_step
    attention_probs = model["attention_encoder"].get_from_collection(
        "attention_probs"
    )
    # `candidates` is deliberately taken from the raw inputs here.
    if FLAGS.clueweb_data:
      s, nbrs_s, text_nbrs_s, _, r, candidates, _ = inputs
    elif FLAGS.text_kg_file:
      s, nbrs_s, text_nbrs_s, r, candidates, _ = inputs
    else:
      s, nbrs_s, r, candidates, _ = inputs
    saver.restore(session, ckpt_path)
    session.run(iterator.initializer)

    analyze_outputs = [candidate_scores, s, nbrs_s, r, candidates, labels,
                       attention_probs]
    # NOTE(review): text neighbors are fetched only for FLAGS.text_kg_file,
    # not for FLAGS.clueweb_data -- confirm whether that is intended.
    if FLAGS.text_kg_file:
      analyze_outputs.append(text_nbrs_s)

    nsteps = 0
    ncorrect = 0
    with open(FLAGS.output_dir + "/analyze_correct.txt", "w+") \
        as outf_correct, \
        open(FLAGS.output_dir + "/analyze_incorrect.txt", "w+") \
        as outf_incorrect:
      while True:
        # Only the fetch can signal the end of the dataset.
        try:
          analyze_vals = session.run(analyze_outputs, {is_train_ph: False})
        except tf.errors.OutOfRangeError:
          break
        if FLAGS.text_kg_file:
          (cscores, se, nbrs, qr, cands, te, nbr_attention_probs,
           text_nbrs) = analyze_vals
        else:
          cscores, se, nbrs, qr, cands, te, nbr_attention_probs = analyze_vals
          text_nbrs = None

        pred_ids = cscores.argmax(1)
        for i in range(se.shape[0]):
          sname = display_name(se[i])
          rname = train_graph.inverse_relation_vocab[qr[i]]
          pred_target = cands[i, pred_ids[i]]
          pred_name = display_name(pred_target)
          tname = display_name(te[i][0])
          if te[i][0] == pred_target:
            outf = outf_correct
            ncorrect += 1
          else:
            outf = outf_incorrect
          outf.write("\n(%d) %s, %s, ? \t Pred: %s \t Target: %s" %
                     (nsteps+i+1, sname, rname, pred_name, tname))

          # Attention slots sorted by decreasing probability; slicing keeps
          # this safe when fewer than `num_attention` slots exist.
          top_nbrs_index = np.argsort(nbr_attention_probs[i, :])[::-1]
          outf.write("\nTop Nbrs:")
          for nbr_index in top_nbrs_index[:num_attention]:
            if nbr_index < FLAGS.max_neighbors:
              # KG neighbor: row holds alternating (relation, entity) ids.
              nbr_id = nbrs[i, nbr_index, :]
              hops = []
              for k in range(0, nbrs.shape[-1], 2):
                ent_name = display_name(nbr_id[k+1])
                rel_name = train_graph.inverse_relation_vocab[nbr_id[k]]
                hops.append("(%s, %s)" % (rel_name, ent_name))
              nbr_name = "".join(hops)
            else:
              # Text relation: indices past max_neighbors address the text
              # neighbor array.
              if text_nbrs is None:
                raise ValueError(
                    "Attention index %d refers to a text neighbor, but text "
                    "neighbors are only fetched when FLAGS.text_kg_file is "
                    "set." % nbr_index)
              text_nbr_ids = text_nbrs[i, nbr_index - FLAGS.max_neighbors, :]
              ent_name = display_name(text_nbr_ids[0])
              rel_name = train_graph.get_relation_text(text_nbr_ids[1:])
              nbr_name = "(%s, %s)" % (rel_name, ent_name)
            outf.write("\n\t\t %s Prob: %.4f" %
                       (nbr_name, nbr_attention_probs[i, nbr_index]))
        nsteps += se.shape[0]
        tf.logging.info("Current hits@1: %.3f", ncorrect * 1.0 / (nsteps))


def evaluate():
  """Run evaluation on dev or test data.

  Builds the training graph (ClueWeb text graph, text graph or plain KG,
  depending on FLAGS), reads the dev and/or test data, creates the model and
  the MRR / Hits@{1,3,10} streaming metrics, and then either

  * runs the attention analysis dump when ``FLAGS.analyze`` is set,
  * evaluates a single checkpoint with ``slim.evaluation.evaluate_once``
    when ``FLAGS.test_only`` is set, or
  * runs ``slim.evaluation.evaluation_loop`` over ``FLAGS.model_path``.

  When both ``FLAGS.dev_kg_file`` and ``FLAGS.test_kg_file`` are set, the
  test data is the one evaluated (it overwrites ``eval_data``).

  Raises:
    ValueError: if neither a dev nor a test file is configured.
  """
  add_inverse_edge = FLAGS.model in (
      "source_rel_attention", "source_path_attention"
  )
  if FLAGS.clueweb_data:
    train_graph = clueweb_text_graph.CWTextGraph(
        text_kg_file=FLAGS.clueweb_data,
        embeddings_file=FLAGS.clueweb_embeddings,
        sentence_vocab_file=FLAGS.clueweb_sentences,
        skip_new=True,
        kg_file=FLAGS.kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        subsample=FLAGS.subsample_text_rels
    )
  elif FLAGS.text_kg_file:
    train_graph = text_graph.TextGraph(
        text_kg_file=FLAGS.text_kg_file,
        skip_new=True,
        max_text_len=FLAGS.max_text_len,
        max_vocab_size=FLAGS.max_vocab_size,
        min_word_freq=FLAGS.min_word_freq,
        kg_file=FLAGS.kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        max_path_length=FLAGS.max_path_length
    )
  else:
    train_graph = graph.Graph(
        kg_file=FLAGS.kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        max_path_length=FLAGS.max_path_length
    )

  val_graph = None
  if FLAGS.dev_kg_file:
    val_graph, eval_data = read_graph_data(
        kg_file=FLAGS.dev_kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        mode="dev", num_epochs=1, batchsize=FLAGS.test_batchsize,
        max_neighbors=FLAGS.max_neighbors,
        max_negatives=FLAGS.max_negatives, train_graph=train_graph,
        text_kg_file=FLAGS.text_kg_file
    )
  if FLAGS.test_kg_file:
    _, eval_data = read_graph_data(
        kg_file=FLAGS.test_kg_file,
        add_reverse_graph=not add_inverse_edge,
        add_inverse_edge=add_inverse_edge,
        mode="test", num_epochs=1, batchsize=FLAGS.test_batchsize,
        max_neighbors=FLAGS.max_neighbors,
        max_negatives=None, train_graph=train_graph,
        text_kg_file=FLAGS.text_kg_file,
        val_graph=val_graph
    )
  if not FLAGS.dev_kg_file and not FLAGS.test_kg_file:
    raise ValueError("Evalution without a dev or test file!")

  iterator = eval_data.dataset.make_initializable_iterator()
  candidate_scores, candidates, labels, model, is_train_ph, inputs = \
    create_model(train_graph, iterator)

  # Streaming evaluation metrics: MRR plus Hits@k for k in {1, 3, 10}.
  batch_rr = metrics.mrr(candidate_scores, candidates, labels)
  mrr, mrr_update = tf.metrics.mean(batch_rr)
  mrr_summary = tf.summary.scalar("MRR", mrr)

  all_hits, all_hits_update, all_hits_summaries = [], [], []
  for k in [1, 3, 10]:
    batch_hits = metrics.hits_at_k(candidate_scores, candidates, labels, k=k)
    hits_k, hits_k_update = tf.metrics.mean(batch_hits)
    all_hits.append(hits_k)
    all_hits_update.append(hits_k_update)
    all_hits_summaries.append(tf.summary.scalar("Hits_at_%d" % k, hits_k))
  hits = tf.group(*all_hits)
  hits_update = tf.group(*all_hits_update)

  global_step = tf.Variable(0, name="global_step", trainable=False)
  # Local (non-checkpointed) counter of evaluated batches.
  current_step = tf.Variable(0, name="current_step", trainable=False,
                             collections=[tf.GraphKeys.LOCAL_VARIABLES])
  incr_current_step = tf.assign_add(current_step, 1)
  reset_current_step = tf.assign(current_step, 0)

  slim.get_or_create_global_step(graph=tf.get_default_graph())

  # Number of batches needed to cover the evaluation set once; the example
  # count doubles when the evaluation graph also contains reverse edges.
  nexamples = eval_data.data_graph.tuple_store.shape[0]
  if eval_data.data_graph.add_reverse_graph:
    nexamples *= 2
  num_batches = math.ceil(nexamples / float(FLAGS.test_batchsize))
  local_init_op = tf.local_variables_initializer()

  if FLAGS.analyze:
    _analyze_predictions(train_graph, iterator, candidate_scores, labels,
                         model, is_train_ph, inputs, local_init_op)
    return

  class DataInitHook(tf.train.SessionRunHook):
    """Re-initializes the data iterator and step counter per session."""

    def after_create_session(self, sess, coord):
      sess.run(iterator.initializer)
      sess.run(reset_current_step)

  # Arguments shared by both slim evaluation entry points.
  eval_kwargs = dict(
      master=FLAGS.master,
      logdir=FLAGS.output_dir,
      variables_to_restore=tf.trainable_variables() + [global_step],
      initial_op=tf.group(local_init_op, iterator.initializer),
      num_evals=num_batches,
      eval_op=tf.group(mrr_update, hits_update, incr_current_step),
      eval_op_feed_dict={is_train_ph: False},
      final_op=tf.group(mrr, hits),
      final_op_feed_dict={is_train_ph: False},
      summary_op=tf.summary.merge([mrr_summary] + all_hits_summaries),
      hooks=[DataInitHook(),
             tf.train.LoggingTensorHook(
                 {"mrr": mrr, "hits": hits, "step": current_step},
                 every_n_iter=1
             )]
  )

  if FLAGS.test_only:
    ckpt_path = FLAGS.model_path + "/model.ckpt-%d" % FLAGS.global_step
    slim.evaluation.evaluate_once(checkpoint_path=ckpt_path, **eval_kwargs)
  else:
    slim.evaluation.evaluation_loop(
        checkpoint_dir=FLAGS.model_path,
        max_number_of_evaluations=None,
        eval_interval_secs=60,
        **eval_kwargs
    )