Esempio n. 1
0
    def test_eos_masking(self):
        """Check that `_mask_probs` rewrites only previously finished beams.

        Expectations asserted by this test:
          * a beam flagged in `previously_finished` comes back as 0 at
            `eos_token` and the float32 minimum at every other vocab index;
          * every other beam comes back unchanged.
        The expected values are derived from `eos_token`, the vocab size and
        the `previously_finished` mask, so the assertions stay aligned with
        the fixture if any of those values are edited.
        """
        # Fixture indexed as [batch][beam][vocab]: 2 x 3 x 5.
        probs = tf.constant([
            [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
            [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
        ])

        eos_token = 0
        # One flag per (batch, beam); True marks a beam the masking should
        # treat as finished.
        previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
        masked = beam_search_decoder._mask_probs(probs, eos_token,
                                                 previously_finished)

        with self.cached_session() as sess:
            probs = sess.run(probs)
            masked = sess.run(masked)

            # Row every finished beam is expected to collapse to: all
            # float32-min except a zero at the EOS index.
            vocab_size = probs.shape[-1]
            finished_row = np.full(vocab_size, np.finfo('float32').min,
                                   dtype=np.float32)
            finished_row[eos_token] = 0

            for batch, beam in np.ndindex(*previously_finished.shape):
                if previously_finished[batch, beam]:
                    # EOS score must be exactly zero, not merely close.
                    self.assertEqual(masked[batch][beam][eos_token], 0)
                    self.assertAllClose(masked[batch][beam], finished_row)
                else:
                    self.assertAllEqual(probs[batch][beam],
                                        masked[batch][beam])
Esempio n. 2
0
def test_eos_masking():
    """Check that `_mask_probs` rewrites only previously finished beams.

    Expectations asserted by this test:
      * a beam flagged in `previously_finished` comes back as 0 at
        `eos_token` and the float32 minimum at every other vocab index;
      * every other beam comes back unchanged.
    The expected values are derived from `eos_token`, the vocab size and the
    `previously_finished` mask, so the assertions stay aligned with the
    fixture if any of those values are edited.
    """
    # Fixture indexed as [batch][beam][vocab]: 2 x 3 x 5.
    probs = tf.constant([
        [
            [-0.2, -0.2, -0.2, -0.2, -0.2],
            [-0.3, -0.3, -0.3, 3, 0],
            [5, 6, 0, 0, 0],
        ],
        [
            [-0.2, -0.2, -0.2, -0.2, 0],
            [-0.3, -0.3, -0.1, 3, 0],
            [5, 6, 3, 0, 0],
        ],
    ])

    eos_token = 0
    # One flag per (batch, beam); True marks a beam the masking should treat
    # as finished.
    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    masked = beam_search_decoder._mask_probs(probs, eos_token,
                                             previously_finished)
    masked = masked.numpy()
    # Compare plain ndarrays instead of slicing the tf tensor inside every
    # assertion and relying on implicit tensor -> array conversion.
    probs = probs.numpy()

    # Row every finished beam is expected to collapse to: all float32-min
    # except a zero at the EOS index.
    vocab_size = probs.shape[-1]
    finished_row = np.full(vocab_size, np.finfo("float32").min,
                           dtype=np.float32)
    finished_row[eos_token] = 0

    for batch, beam in np.ndindex(*previously_finished.shape):
        actual = masked[batch, beam]
        if previously_finished[batch, beam]:
            # EOS score must be exactly zero, not merely close.
            np.testing.assert_equal(actual[eos_token], 0)
            np.testing.assert_allclose(actual, finished_row)
        else:
            np.testing.assert_equal(probs[batch, beam], actual)