def reranking_eval(base_dir):
  """Logs the top-ranked ending sentences for stories in the validation set.

  For every validation example this logs:
    * the first four ground-truth sentences,
    * the ground-truth fifth sentence,
    * the ten highest-scoring candidate endings with their scores,
    * the rank and score of the ground-truth ending.

  Args:
    base_dir: Directory passed to `utils.pick_best_checkpoint` together with
      the metric name 'valid_nolabel_acc' to choose the checkpoint to load.
  """
  stories_text = utils.read_all_stories(FLAGS.rocstories_root_dir)

  valid_dataset, all_story_ids, all_emb_matrix = prepare_dataset()

  num_input_sentences = tf.compat.v1.data.get_output_shapes(
      valid_dataset)[0][1]
  model = models.build_model(
      num_input_sentences=num_input_sentences, embedding_matrix=all_emb_matrix)

  checkpoint = tf.train.Checkpoint(model=model)
  checkpoint_path = utils.pick_best_checkpoint(base_dir, 'valid_nolabel_acc')

  logging.info('LOADING FROM CHECKPOINT: %s', checkpoint_path)

  result = checkpoint.restore(checkpoint_path).expect_partial()
  # Raise if the checkpoint matched nothing beyond the root object.
  result.assert_nontrivial_match()

  for x, labels, story_ids in valid_dataset:
    _, output_embedding = model(x, training=False)

    # Score each predicted embedding against every candidate ending, then
    # order candidate indices from highest to lowest score per example.
    scores = tf.matmul(output_embedding, all_emb_matrix, transpose_b=True)
    sorted_indices = tf.argsort(scores, axis=-1, direction='DESCENDING')

    batch_size = x.shape[0]
    for batch_index in range(batch_size):
      story_id = story_ids[batch_index].numpy().decode('utf-8')
      gt_story_text = stories_text[story_id]

      logging.info('Groundtruth story start:')
      for i in range(4):
        logging.info('  %d:\t%s', i+1, gt_story_text[i])

      logging.info('Groundtruth 5th sentence: %s', gt_story_text[4])

      logging.info('Guessed endings: ')
      for story_index in sorted_indices[batch_index, :10].numpy():
        chosen_story_id = all_story_ids[story_index].numpy().decode('utf-8')
        # Use a separate name so the ground-truth story text is not clobbered.
        chosen_story_text = stories_text[chosen_story_id]
        score = scores[batch_index, story_index].numpy()
        logging.info('(%f)  %s', score, chosen_story_text[-1])

      # Note: the indexes in labels are only valid for the validation set
      # because it happens to be first in the embedding matrix.
      label = labels[batch_index].numpy()
      gt_score = scores[batch_index, label].numpy()
      gt_rank = sorted_indices[batch_index].numpy().tolist().index(label)
      logging.info('Rank of GT: %d', gt_rank)
      logging.info('Score for GT: %f', gt_score)

      logging.info('')


def run_eval(base_dir):
    """Evaluates the best checkpoint and writes its test accuracy to disk.

    The checkpoint selected by `utils.pick_best_checkpoint(base_dir)` is
    restored into a freshly built model. The model is evaluated on the dataset
    returned by `prepare_dataset()` via `eval_single_checkpoint`. The resulting
    accuracy is written as text to `[base_dir]/test_spring2016_acc.txt`.

    Args:
        base_dir: Directory containing the checkpoints; the accuracy file is
            also written here.

    Raises:
        AssertionError: If the checkpoint matches nothing beyond the root
            object, so an accuracy is never reported for an unloaded model.
    """
    best_checkpoint_name = utils.pick_best_checkpoint(base_dir)

    dataset = prepare_dataset()
    checkpoint_path = os.path.join(base_dir, best_checkpoint_name)

    # Shape of component 0 of each dataset element (presumably the model
    # input). Axis 1 gives the sentence count; the last axis gives the
    # embedding size.
    input_shape = tf.compat.v1.data.get_output_shapes(dataset)[0]
    embedding_dim = input_shape[-1]
    num_input_sentences = input_shape[1]
    model = models.build_model(num_input_sentences=num_input_sentences,
                               embedding_dim=embedding_dim)

    logging.info('Evaluating with checkpoint: "%s"', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=model)
    result = checkpoint.restore(checkpoint_path).expect_partial()
    # Same sanity check as the other eval entry points.
    result.assert_nontrivial_match()

    test_accuracy = eval_single_checkpoint(model, dataset)

    with gfile.GFile(os.path.join(base_dir, 'test_spring2016_acc.txt'),
                     'w') as f:
        f.write(str(test_accuracy))


def all_distractors_eval(base_dir):
    """Computes model accuracy with all possible last sentences as distractors.

    Every row of `all_emb_matrix` is a candidate ending. For each validation
    example, candidates are ranked by the score of the model's output
    embedding against each row. Per example, the function accumulates:
      * ranks: 0-based position of the ground-truth ending in that ranking.
      * label_in_top10: 1 if the ground truth is among the 10 best, else 0.
      * label_in_top1: 1 if the ground truth is ranked first, else 0.

    Args:
        base_dir: Directory passed to `utils.pick_best_checkpoint` together
            with the metric name 'valid_nolabel_acc' to choose the checkpoint.
    """
    valid_dataset, all_emb_matrix = prepare_dataset()

    num_input_sentences = tf.compat.v1.data.get_output_shapes(
        valid_dataset)[0][1]
    model = models.build_model(num_input_sentences=num_input_sentences,
                               embedding_matrix=all_emb_matrix)

    checkpoint = tf.train.Checkpoint(model=model)
    checkpoint_path = utils.pick_best_checkpoint(base_dir, 'valid_nolabel_acc')

    logging.info('LOADING FROM CHECKPOINT: %s', checkpoint_path)

    result = checkpoint.restore(checkpoint_path).expect_partial()
    # Raise if the checkpoint matched nothing beyond the root object.
    result.assert_nontrivial_match()

    ranks = []
    label_in_top10 = []
    # Kept separate from label_in_top10 so the two metrics never mix.
    label_in_top1 = []

    for x, labels in valid_dataset:
        _, output_embedding = model(x, training=False)

        scores = tf.matmul(output_embedding, all_emb_matrix, transpose_b=True)
        sorted_indices = tf.argsort(scores, axis=-1, direction='DESCENDING')

        batch_size = x.shape[0]
        for batch_index in range(batch_size):
            # Note: the indexes in labels are only valid for the validation set
            # because it happens to be first in the embedding matrix.
            label = labels[batch_index].numpy()
            # Convert this example's ranking to numpy once; membership tests
            # on an EagerTensor would iterate it element by element.
            sorted_row = sorted_indices[batch_index].numpy()
            gt_rank = sorted_row.tolist().index(label)

            top10predicted = sorted_row[:10]
            label_in_top10.append(1 if label in top10predicted else 0)

            top_predicted = sorted_row[0]
            label_in_top1.append(1 if label == top_predicted else 0)
            ranks.append(gt_rank)