Ejemplo n.º 1
0
  def testComputeTopkScoresAndSeq(self):
    """Checks that the top `beam_size` candidates per batch entry are kept.

    Each of the two batch entries holds four candidate sequences of length 2.
    The expected outputs below are the three highest-scoring candidates per
    row, in descending score order, with the sequences and flags selected at
    the same positions. The same `scores` tensor is passed for both score
    arguments of `compute_topk_scores_and_seq`; the fourth return value is
    ignored.
    """
    beam_size = 3
    batch_size = 2

    candidate_seqs = tf.constant([
        [[2, 3], [4, 5], [6, 7], [19, 20]],
        [[8, 9], [10, 11], [12, 13], [80, 17]],
    ])
    candidate_scores = tf.constant([
        [-0.1, -2.5, 0., -1.5],
        [-100., -5., -0.00789, -1.34],
    ])
    candidate_flags = tf.constant([
        [True, False, False, True],
        [False, False, False, True],
    ])

    seq_t, scores_t, flags_t, _unused = beam_search.compute_topk_scores_and_seq(
        candidate_seqs, candidate_scores, candidate_scores, candidate_flags,
        beam_size, batch_size)

    # NOTE(review): test_session() is deprecated in newer TF 1.x releases in
    # favour of cached_session(); kept as-is for compatibility with older TF.
    with self.test_session():
      got_seq, got_scores, got_flags = [
          tensor.eval() for tensor in (seq_t, scores_t, flags_t)]

    # Row 0 keeps candidate indices 2, 0, 3; row 1 keeps indices 2, 3, 1.
    want_seq = [[[6, 7], [2, 3], [19, 20]],
                [[12, 13], [80, 17], [10, 11]]]
    want_scores = [[0., -0.1, -1.5],
                   [-0.00789, -1.34, -5.]]
    want_flags = [[False, True, True],
                  [False, True, False]]

    self.assertAllEqual(want_seq, got_seq)
    self.assertAllClose(want_scores, got_scores)
    self.assertAllEqual(want_flags, got_flags)
Ejemplo n.º 2
0
  def testComputeTopkScoresAndSeq(self):
    """Checks that the top `beam_size` candidates per batch entry are kept.

    Each of the two batch entries holds four candidate sequences of length 2.
    The expected outputs below are the three highest-scoring candidates per
    row, in descending score order, with the sequences and flags selected at
    the same positions. The same `scores` tensor is passed for both score
    arguments of `compute_topk_scores_and_seq`; the fourth return value is
    ignored.
    """
    beam_size = 3
    batch_size = 2

    candidate_seqs = tf.constant([
        [[2, 3], [4, 5], [6, 7], [19, 20]],
        [[8, 9], [10, 11], [12, 13], [80, 17]],
    ])
    candidate_scores = tf.constant([
        [-0.1, -2.5, 0., -1.5],
        [-100., -5., -0.00789, -1.34],
    ])
    candidate_flags = tf.constant([
        [True, False, False, True],
        [False, False, False, True],
    ])

    seq_t, scores_t, flags_t, _unused = beam_search.compute_topk_scores_and_seq(
        candidate_seqs, candidate_scores, candidate_scores, candidate_flags,
        beam_size, batch_size)

    # NOTE(review): test_session() is deprecated in newer TF 1.x releases in
    # favour of cached_session(); kept as-is for compatibility with older TF.
    with self.test_session():
      got_seq, got_scores, got_flags = [
          tensor.eval() for tensor in (seq_t, scores_t, flags_t)]

    # Row 0 keeps candidate indices 2, 0, 3; row 1 keeps indices 2, 3, 1.
    want_seq = [[[6, 7], [2, 3], [19, 20]],
                [[12, 13], [80, 17], [10, 11]]]
    want_scores = [[0., -0.1, -1.5],
                   [-0.00789, -1.34, -5.]]
    want_flags = [[False, True, True],
                  [False, True, False]]

    self.assertAllEqual(want_seq, got_seq)
    self.assertAllClose(want_scores, got_scores)
    self.assertAllEqual(want_flags, got_flags)