def get_type13_graph_optimizaton(kbc_path, dataset, dataset_mode, similarity_metric='l2', t_norm='min'):
    """Optimise type "1_3" chains and return the raw guesses.

    Loads the environment through ``preload_env(..., '1_3')``, runs
    ``kbc.model.type1_3chain_optimize`` for at most 1000 steps and then calls
    the ranking metrics (``hits_at_k`` and ``average_percentile_rank``) on
    the result. The metric return values are not used by this function.

    Args:
        kbc_path: Forwarded unchanged to ``preload_env``.
        dataset: Forwarded unchanged to ``preload_env``.
        dataset_mode: Forwarded unchanged to ``preload_env``.
        similarity_metric: Forwarded to ``type1_3chain_optimize``.
        t_norm: Forwarded to ``type1_3chain_optimize``.

    Returns:
        The tuple ``(obj_guess_raw, closest_map)`` from the optimiser, or
        ``None`` when a ``RuntimeError`` is raised anywhere in the pipeline
        (the error is printed).
    """
    try:
        env = preload_env(kbc_path, dataset, dataset_mode, '1_3')
        part1, part2, part3 = env.parts
        target_ids, lhs_norm = env.target_ids, env.lhs_norm
        kbc, chains = env.kbc, env.chains

        optimised = kbc.model.type1_3chain_optimize(
            chains,
            kbc.regularizer,
            max_steps=1000,
            similarity_metric=similarity_metric,
            t_norm=t_norm
        )
        obj_guess_raw, closest_map, ranked_indices = optimised

        # Results are unused, but the call (and its 2-tuple unpacking) is kept.
        _lhs_norm, _guess_norm = norm_comparison(lhs_norm, obj_guess_raw)

        def chain_key(row):
            # One key per chain: two fields from each of the three parts,
            # each converted with .item() and joined by underscores.
            fields = (
                part1[row][0], part1[row][1],
                part2[row][0], part2[row][1],
                part3[row][1], part3[row][2],
            )
            return '_'.join(str(field.item()) for field in fields)

        keys = [chain_key(row) for row in range(len(ranked_indices))]

        # Return values are discarded; the calls themselves are preserved.
        hits_at_k(ranked_indices, target_ids, keys, hits=[1, 3, 5, 10, 20])
        average_percentile_rank(ranked_indices, target_ids, keys)
    except RuntimeError as err:
        print(err)
        return None

    return obj_guess_raw, closest_map
def _load_train_graph_for_eval(add_inverse_edge):
    """Build the training graph from whichever data source the flags select.

    ``FLAGS.clueweb_data`` takes precedence over ``FLAGS.text_kg_file``; with
    neither set, a plain ``graph.Graph`` over ``FLAGS.kg_file`` is built.
    The reverse graph is added exactly when inverse edges are not.
    """
    add_reverse_graph = not add_inverse_edge
    if FLAGS.clueweb_data:
        return clueweb_text_graph.CWTextGraph(
            text_kg_file=FLAGS.clueweb_data,
            embeddings_file=FLAGS.clueweb_embeddings,
            sentence_vocab_file=FLAGS.clueweb_sentences,
            skip_new=True,
            kg_file=FLAGS.kg_file,
            add_reverse_graph=add_reverse_graph,
            add_inverse_edge=add_inverse_edge,
            subsample=FLAGS.subsample_text_rels
        )
    if FLAGS.text_kg_file:
        return text_graph.TextGraph(
            text_kg_file=FLAGS.text_kg_file,
            skip_new=True,
            max_text_len=FLAGS.max_text_len,
            max_vocab_size=FLAGS.max_vocab_size,
            min_word_freq=FLAGS.min_word_freq,
            kg_file=FLAGS.kg_file,
            add_reverse_graph=add_reverse_graph,
            add_inverse_edge=add_inverse_edge,
            max_path_length=FLAGS.max_path_length
        )
    return graph.Graph(
        kg_file=FLAGS.kg_file,
        add_reverse_graph=add_reverse_graph,
        add_inverse_edge=add_inverse_edge,
        max_path_length=FLAGS.max_path_length
    )


def _load_eval_data(train_graph, add_inverse_edge):
    """Read the evaluation split(s) named by the flags.

    The dev split is read first when ``FLAGS.dev_kg_file`` is set. When
    ``FLAGS.test_kg_file`` is also set, the test data replaces it as the
    returned value and the dev graph is passed along as ``val_graph``.
    Returns ``None`` when neither flag is set.
    """
    add_reverse_graph = not add_inverse_edge
    val_graph = None
    eval_data = None
    if FLAGS.dev_kg_file:
        val_graph, eval_data = read_graph_data(
            kg_file=FLAGS.dev_kg_file,
            add_reverse_graph=add_reverse_graph,
            add_inverse_edge=add_inverse_edge,
            mode="dev",
            num_epochs=1,
            batchsize=FLAGS.test_batchsize,
            max_neighbors=FLAGS.max_neighbors,
            max_negatives=FLAGS.max_negatives,
            train_graph=train_graph,
            text_kg_file=FLAGS.text_kg_file
        )
    if FLAGS.test_kg_file:
        _, eval_data = read_graph_data(
            kg_file=FLAGS.test_kg_file,
            add_reverse_graph=add_reverse_graph,
            add_inverse_edge=add_inverse_edge,
            mode="test",
            num_epochs=1,
            batchsize=FLAGS.test_batchsize,
            max_neighbors=FLAGS.max_neighbors,
            max_negatives=None,
            train_graph=train_graph,
            text_kg_file=FLAGS.text_kg_file,
            val_graph=val_graph
        )
    return eval_data


def _readable_entity(entity, entity_names):
    """Return the mapped display name for ``entity`` if one exists, else ``entity``."""
    return entity_names[entity] if entity in entity_names else entity


def _run_attention_analysis(train_graph, iterator, model, inputs,
                            candidate_scores, labels, is_train_ph,
                            local_init_op):
    """Dump per-example predictions and top-attended neighbors to text files.

    Restores the trainable variables from
    ``FLAGS.model_path/model.ckpt-<FLAGS.global_step>``, iterates once over
    ``iterator`` and writes every example to either ``analyze_correct.txt``
    or ``analyze_incorrect.txt`` under ``FLAGS.output_dir``, depending on
    whether the arg-max candidate equals the first label. A running hits@1
    is logged after each batch.
    """
    entity_names = utils.read_entity_name_mapping(FLAGS.entity_names_file)
    num_attention = 5  # number of highest-attention neighbors written per example

    # The session is a context manager so it is closed on every exit path.
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(local_init_op)
        saver = tf.train.Saver(tf.trainable_variables())
        ckpt_path = FLAGS.model_path + "/model.ckpt-%d" % FLAGS.global_step
        attention_probs = model["attention_encoder"].get_from_collection(
            "attention_probs"
        )

        # The layout of `inputs` depends on which data source built the model.
        text_nbrs_s = None
        if FLAGS.clueweb_data:
            s, nbrs_s, text_nbrs_s, _, r, cand_input, _ = inputs
        elif FLAGS.text_kg_file:
            s, nbrs_s, text_nbrs_s, r, cand_input, _ = inputs
        else:
            s, nbrs_s, r, cand_input, _ = inputs

        saver.restore(session, ckpt_path)
        session.run(iterator.initializer)

        fetches = [candidate_scores, s, nbrs_s, r, cand_input, labels,
                   attention_probs]
        # NOTE(review): text neighbors are fetched only when FLAGS.text_kg_file
        # is set. In clueweb mode without it, an attention index at or beyond
        # FLAGS.max_neighbors would hit the text branch below with no text
        # neighbors available - confirm whether that mode is meant to work.
        fetch_text = bool(FLAGS.text_kg_file)
        if fetch_text:
            fetches.append(text_nbrs_s)

        nsteps = 0
        ncorrect = 0
        correct_path = FLAGS.output_dir + "/analyze_correct.txt"
        incorrect_path = FLAGS.output_dir + "/analyze_incorrect.txt"
        # `with` guarantees both files are flushed and closed even if a batch
        # fails part-way through.
        with open(correct_path, "w+") as outf_correct, \
                open(incorrect_path, "w+") as outf_incorrect:
            while True:
                try:
                    vals = session.run(fetches, {is_train_ph: False})
                except tf.errors.OutOfRangeError:
                    # Iterator exhausted: one full pass is done.
                    break

                cscores, se, nbrs, qr, cands, te, nbr_attention_probs = vals[:7]
                text_nbrs = vals[7] if fetch_text else None

                pred_ids = cscores.argmax(1)
                batch_size = se.shape[0]
                for i in range(batch_size):
                    sname = _readable_entity(
                        train_graph.inverse_entity_vocab[se[i]], entity_names)
                    rname = train_graph.inverse_relation_vocab[qr[i]]
                    pred_target = cands[i, pred_ids[i]]
                    pred_name = _readable_entity(
                        train_graph.inverse_entity_vocab[pred_target],
                        entity_names)
                    target = te[i][0]
                    tname = _readable_entity(
                        train_graph.inverse_entity_vocab[target], entity_names)

                    if target == pred_target:
                        outf = outf_correct
                        ncorrect += 1
                    else:
                        outf = outf_incorrect

                    outf.write(
                        "\n(%d) %s, %s, ? \t Pred: %s \t Target: %s" %
                        (nsteps + i + 1, sname, rname, pred_name, tname))
                    outf.write("\nTop Nbrs:")

                    # Highest attention first; slicing (instead of indexing
                    # 0..4) copes with fewer than `num_attention` slots.
                    top_nbrs_index = np.argsort(nbr_attention_probs[i, :])[::-1]
                    for nbr_index in top_nbrs_index[:num_attention]:
                        if nbr_index < FLAGS.max_neighbors:
                            # Graph neighbor: read as (relation, entity) pairs.
                            nbr_id = nbrs[i, nbr_index, :]
                            pairs = []
                            for k in range(0, nbrs.shape[-1], 2):
                                ent_name = _readable_entity(
                                    train_graph.inverse_entity_vocab[nbr_id[k + 1]],
                                    entity_names)
                                rel_name = train_graph.inverse_relation_vocab[
                                    nbr_id[k]]
                                pairs.append("(%s, %s)" % (rel_name, ent_name))
                            nbr_name = "".join(pairs)
                        else:
                            # Text relation: entity id first, remaining ids are
                            # handed to get_relation_text.
                            text_nbr_ids = text_nbrs[
                                i, nbr_index - FLAGS.max_neighbors, :]
                            ent_name = _readable_entity(
                                train_graph.inverse_entity_vocab[text_nbr_ids[0]],
                                entity_names)
                            rel_name = train_graph.get_relation_text(
                                text_nbr_ids[1:])
                            nbr_name = "(%s, %s)" % (rel_name, ent_name)
                        outf.write(
                            "\n\t\t %s Prob: %.4f" %
                            (nbr_name, nbr_attention_probs[i, nbr_index]))

                nsteps += batch_size
                # max(..., 1) avoids a ZeroDivisionError on an empty first batch.
                tf.logging.info("Current hits@1: %.3f",
                                ncorrect * 1.0 / max(nsteps, 1))


def evaluate():
    """Run evaluation on dev or test data.

    Builds the training graph and the evaluation input pipeline from FLAGS,
    creates the model plus streaming MRR and Hits@{1,3,10} metrics, and then
    does one of three things:

    * ``FLAGS.analyze``: writes per-example attention analysis files and
      returns.
    * ``FLAGS.test_only``: evaluates the single checkpoint
      ``model.ckpt-<FLAGS.global_step>`` once.
    * otherwise: runs the slim evaluation loop over ``FLAGS.model_path``,
      re-evaluating every 60 seconds.

    Raises:
        ValueError: if neither ``FLAGS.dev_kg_file`` nor
            ``FLAGS.test_kg_file`` is set. This is checked before any data
            is loaded.
    """
    # Fail fast: there is no point loading the training graph without an
    # evaluation split.
    if not FLAGS.dev_kg_file and not FLAGS.test_kg_file:
        raise ValueError("Evalution without a dev or test file!")

    add_inverse_edge = FLAGS.model in \
        ["source_rel_attention", "source_path_attention"]
    train_graph = _load_train_graph_for_eval(add_inverse_edge)
    eval_data = _load_eval_data(train_graph, add_inverse_edge)

    iterator = eval_data.dataset.make_initializable_iterator()
    candidate_scores, candidates, labels, model, is_train_ph, inputs = \
        create_model(train_graph, iterator)

    # Streaming eval metrics.
    batch_rr = metrics.mrr(candidate_scores, candidates, labels)
    mrr, mrr_update = tf.metrics.mean(batch_rr)
    mrr_summary = tf.summary.scalar("MRR", mrr)
    all_hits, all_hits_update, all_hits_summaries = [], [], []
    for k in [1, 3, 10]:
        batch_hits = metrics.hits_at_k(candidate_scores, candidates, labels,
                                       k=k)
        hits_k, hits_k_update = tf.metrics.mean(batch_hits)
        all_hits.append(hits_k)
        all_hits_update.append(hits_k_update)
        all_hits_summaries.append(tf.summary.scalar("Hits_at_%d" % k, hits_k))
    # NOTE(review): `hits` is a tf.group op rather than a tensor, so the
    # LoggingTensorHook below presumably has no per-k values to print for it
    # - confirm this is intended.
    hits = tf.group(*all_hits)
    hits_update = tf.group(*all_hits_update)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    # Local variable, so it is reset by the local-variables initializer and
    # is not part of the checkpoint variables restored below.
    current_step = tf.Variable(0, name="current_step", trainable=False,
                               collections=[tf.GraphKeys.LOCAL_VARIABLES])
    incr_current_step = tf.assign_add(current_step, 1)
    reset_current_step = tf.assign(current_step, 0)
    slim.get_or_create_global_step(graph=tf.get_default_graph())

    nexamples = eval_data.data_graph.tuple_store.shape[0]
    if eval_data.data_graph.add_reverse_graph:
        # Count is doubled when the data graph includes the reverse graph.
        nexamples *= 2
    num_batches = math.ceil(nexamples / float(FLAGS.test_batchsize))
    local_init_op = tf.local_variables_initializer()

    if FLAGS.analyze:
        _run_attention_analysis(
            train_graph, iterator, model, inputs, candidate_scores, labels,
            is_train_ph, local_init_op)
        return

    class DataInitHook(tf.train.SessionRunHook):
        """Re-initialises the data iterator and step counter per session."""

        def after_create_session(self, sess, coord):
            sess.run(iterator.initializer)
            sess.run(reset_current_step)

    # Arguments shared by the one-shot and the looping evaluation.
    eval_kwargs = dict(
        master=FLAGS.master,
        logdir=FLAGS.output_dir,
        variables_to_restore=tf.trainable_variables() + [global_step],
        initial_op=tf.group(local_init_op, iterator.initializer),
        num_evals=num_batches,
        eval_op=tf.group(mrr_update, hits_update, incr_current_step),
        eval_op_feed_dict={is_train_ph: False},
        final_op=tf.group(mrr, hits),
        final_op_feed_dict={is_train_ph: False},
        summary_op=tf.summary.merge([mrr_summary] + all_hits_summaries),
        hooks=[DataInitHook(),
               tf.train.LoggingTensorHook(
                   {"mrr": mrr, "hits": hits, "step": current_step},
                   every_n_iter=1
               )]
    )

    if FLAGS.test_only:
        ckpt_path = FLAGS.model_path + "/model.ckpt-%d" % FLAGS.global_step
        slim.evaluation.evaluate_once(
            checkpoint_path=ckpt_path,
            **eval_kwargs
        )
    else:
        slim.evaluation.evaluation_loop(
            checkpoint_dir=FLAGS.model_path,
            max_number_of_evaluations=None,
            eval_interval_secs=60,
            **eval_kwargs
        )