def test_eos_masking(self):
    """Check `_mask_probs` against a hand-written `previously_finished` mask.

    Positions whose flag is False must come back unchanged. Positions whose
    flag is True must hold 0 at index 0 (the value of `eos_token` here) and
    the float32 minimum at every other index.
    """
    scores = tf.constant([
        [
            [-.2, -.2, -.2, -.2, -.2],
            [-.3, -.3, -.3, 3, 0],
            [5, 6, 0, 0, 0],
        ],
        [
            [-.2, -.2, -.2, -.2, 0],
            [-.3, -.3, -.1, 3, 0],
            [5, 6, 3, 0, 0],
        ],
    ])
    eos_token = 0
    finished_mask = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    masked_op = beam_search_decoder._mask_probs(scores, eos_token,
                                                finished_mask)

    # (outer, inner) index pairs, split by the flag set in `finished_mask`.
    untouched = ((0, 0), (0, 2), (1, 0))
    finished = ((0, 1), (1, 1), (1, 2))
    float32_min = np.finfo('float32').min

    with self.cached_session() as sess:
        scores_val = sess.run(scores)
        masked_val = sess.run(masked_op)

        # Unflagged positions pass through unchanged.
        for outer, inner in untouched:
            self.assertAllEqual(scores_val[outer][inner],
                                masked_val[outer][inner])

        # Flagged positions: index 0 is exactly 0 ...
        for outer, inner in finished:
            self.assertEqual(masked_val[outer][inner][0], 0)

        # ... and indices 1..4 are pushed to the float32 minimum.
        for idx in range(1, 5):
            for outer, inner in finished:
                self.assertAllClose(masked_val[outer][inner][idx],
                                    float32_min)
def test_eos_masking():
    """Check `_mask_probs` against a hand-written `previously_finished` mask.

    Positions whose flag is False must come back unchanged. Positions whose
    flag is True must hold 0 at index 0 (the value of `eos_token` here) and
    the float32 minimum at every other index.
    """
    probs = tf.constant([
        [
            [-0.2, -0.2, -0.2, -0.2, -0.2],
            [-0.3, -0.3, -0.3, 3, 0],
            [5, 6, 0, 0, 0],
        ],
        [
            [-0.2, -0.2, -0.2, -0.2, 0],
            [-0.3, -0.3, -0.1, 3, 0],
            [5, 6, 3, 0, 0],
        ],
    ])
    eos_token = 0
    finished_mask = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    masked_values = beam_search_decoder._mask_probs(
        probs, eos_token, finished_mask
    ).numpy()

    # (outer, inner) index pairs, split by the flag set in `finished_mask`.
    untouched = ((0, 0), (0, 2), (1, 0))
    finished = ((0, 1), (1, 1), (1, 2))
    float32_min = np.finfo("float32").min

    # Unflagged positions pass through unchanged.
    for outer, inner in untouched:
        np.testing.assert_equal(probs[outer][inner], masked_values[outer][inner])

    # Flagged positions: index 0 is exactly 0 ...
    for outer, inner in finished:
        np.testing.assert_equal(masked_values[outer][inner][0], 0)

    # ... and indices 1..4 are pushed to the float32 minimum.
    for idx in range(1, 5):
        for outer, inner in finished:
            np.testing.assert_allclose(masked_values[outer][inner][idx], float32_min)