def reranking_eval(base_dir):
  """Outputs top-ranked ending sentences for stories in validation set."""
  stories_text = utils.read_all_stories(FLAGS.rocstories_root_dir)
  valid_dataset, all_story_ids, all_emb_matrix = prepare_dataset()
  num_input_sentences = tf.compat.v1.data.get_output_shapes(
      valid_dataset)[0][1]
  model = models.build_model(
      num_input_sentences=num_input_sentences,
      embedding_matrix=all_emb_matrix)
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint_path = utils.pick_best_checkpoint(
      base_dir, 'valid_nolabel_acc')
  logging.info('LOADING FROM CHECKPOINT: %s', checkpoint_path)
  result = checkpoint.restore(checkpoint_path).expect_partial()
  result.assert_nontrivial_match()

  for x, labels, story_ids in valid_dataset:
    _, output_embedding = model(x, training=False)
    scores = tf.matmul(output_embedding, all_emb_matrix, transpose_b=True)
    sorted_indices = tf.argsort(scores, axis=-1, direction='DESCENDING')
    batch_size = x.shape[0]
    for batch_index in range(batch_size):
      story_id = story_ids[batch_index].numpy()
      story_id = story_id.decode('utf-8')
      story_text = stories_text[story_id]
      logging.info('Groundtruth story start:')
      for i in range(4):
        logging.info(' %d:\t%s', i + 1, story_text[i])
      logging.info('Groundtruth 5th sentence: %s', story_text[4])
      logging.info('Guessed endings:')
      for story_index in sorted_indices[batch_index, :10].numpy():
        chosen_story_id = all_story_ids[story_index].numpy().decode('utf-8')
        story_text = stories_text[chosen_story_id]
        score = scores[batch_index, story_index].numpy()
        logging.info('(%f) %s', score, story_text[-1])
      # Note: the indices in labels are only valid for the validation set
      # because it happens to be first in the embedding matrix.
      gt_score = scores[batch_index, labels[batch_index]].numpy()
      label = labels[batch_index].numpy()
      gt_rank = sorted_indices[batch_index].numpy().tolist().index(label)
      logging.info('Rank of GT: %d', gt_rank)
      logging.info('Score for GT: %f', gt_score)
      logging.info('')
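
# Illustrative sketch only (hypothetical helper, not called by the pipeline):
# reranking_eval scores every candidate ending by a dot product between the
# model's output embedding and each row of the ending-embedding matrix, then
# sorts the scores. The toy tensors below are made-up values for demonstration
# and assume the module-level `import tensorflow as tf` used elsewhere here.
def _reranking_score_demo():
  """Shows the dot-product scoring and ranking pattern used above."""
  # [batch, dim] model outputs and [num_endings, dim] candidate embeddings.
  output_embedding = tf.constant([[1.0, 0.0], [0.0, 1.0]])
  all_emb_matrix = tf.constant([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])
  # Score every candidate ending for every story, then sort so that
  # ranked[i] lists ending indices from best to worst for story i.
  scores = tf.matmul(output_embedding, all_emb_matrix, transpose_b=True)
  ranked = tf.argsort(scores, axis=-1, direction='DESCENDING')
  return scores, ranked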
def run_eval(base_dir):
  """Writes model's predictions in proper format to [base_dir]/answer.txt."""
  best_checkpoint_name = utils.pick_best_checkpoint(base_dir)
  dataset = prepare_dataset()
  checkpoint_path = os.path.join(base_dir, best_checkpoint_name)
  embedding_dim = tf.compat.v1.data.get_output_shapes(dataset)[0][-1]
  num_input_sentences = tf.compat.v1.data.get_output_shapes(dataset)[0][1]
  model = models.build_model(num_input_sentences=num_input_sentences,
                             embedding_dim=embedding_dim)
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint.restore(checkpoint_path).expect_partial()
  logging.info('Evaluating with checkpoint: "%s"', checkpoint_path)
  test_accuracy = eval_single_checkpoint(model, dataset)
  with gfile.GFile(os.path.join(base_dir, 'test_spring2016_acc.txt'),
                   'w') as f:
    f.write(str(test_accuracy))
def all_distractors_eval(base_dir):
  """Computes model accuracy with all possible last sentences as distractors."""
  valid_dataset, all_emb_matrix = prepare_dataset()
  num_input_sentences = tf.compat.v1.data.get_output_shapes(
      valid_dataset)[0][1]
  model = models.build_model(num_input_sentences=num_input_sentences,
                             embedding_matrix=all_emb_matrix)
  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint_path = utils.pick_best_checkpoint(base_dir, 'valid_nolabel_acc')
  logging.info('LOADING FROM CHECKPOINT: %s', checkpoint_path)
  result = checkpoint.restore(checkpoint_path).expect_partial()
  result.assert_nontrivial_match()

  ranks = []
  label_in_top10 = []
  label_is_top1 = []
  for x, labels in valid_dataset:
    _, output_embedding = model(x, training=False)
    scores = tf.matmul(output_embedding, all_emb_matrix, transpose_b=True)
    sorted_indices = tf.argsort(scores, axis=-1, direction='DESCENDING')
    batch_size = x.shape[0]
    for batch_index in range(batch_size):
      # Note: the indices in labels are only valid for the validation set
      # because it happens to be first in the embedding matrix.
      label = labels[batch_index].numpy()
      gt_rank = sorted_indices[batch_index].numpy().tolist().index(label)
      top10predicted = sorted_indices[batch_index, :10].numpy()
      label_in_top10.append(1 if label in top10predicted else 0)
      top_predicted = sorted_indices[batch_index, 0].numpy()
      # Track the top-1 hit in its own list; previously it was appended to
      # label_in_top10, which mixed the two metrics.
      label_is_top1.append(1 if label == top_predicted else 0)
      ranks.append(gt_rank)
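
# Illustrative sketch only (hypothetical helper, not part of the original
# evaluation): one way to aggregate the per-example lists collected by
# all_distractors_eval into summary metrics. It assumes the module-level
# `logging` import used elsewhere in this file.
def _summarize_all_distractors_stats(ranks, label_in_top10, label_is_top1):
  """Reports mean groundtruth rank, recall@10, and recall@1."""
  mean_rank = sum(ranks) / len(ranks)
  recall_at_10 = sum(label_in_top10) / len(label_in_top10)
  recall_at_1 = sum(label_is_top1) / len(label_is_top1)
  logging.info('Mean rank of groundtruth ending: %f', mean_rank)
  logging.info('Recall@10: %f', recall_at_10)
  logging.info('Recall@1: %f', recall_at_1)
  return mean_rank, recall_at_10, recall_at_1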