def test_masked_softmax_on_all_invalid_moves(self):
    """Pins the masked-softmax output for a row with no legal action.

    When every action of a row is illegal the behavior is undefined in
    general (it could be NaN or 0). This test documents the current
    behavior -- NaN -- so that any change to it gets noticed.
    """
    np_logits = np.asarray([[
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]])
    # `dtype` has to be passed by keyword: the second positional parameter of
    # `tf.Variable` is `trainable`, so a positional `tf.float32` would not set
    # the variable's dtype.
    logits = tf.Variable(np_logits, dtype=tf.float32)
    np_mask = np.asarray([[
        [1.0, 1.0, 1.0],  # Every action is legal.
        [1.0, 0.0, 1.0],  # The middle action is illegal.
        [0.0, 0.0, 0.0],  # No action is legal.
    ]])
    mask = tf.Variable(np_mask, dtype=tf.float32)

    expected = np.asarray([[
        [1 / 3, 1 / 3, 1 / 3],
        [1 / 2, 0.0, 1 / 2],
        [np.nan, np.nan, np.nan],
    ]])

    policy = masked_softmax.tf_masked_softmax(logits, mask)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      np_policy = sess.run(policy)
    np.testing.assert_array_almost_equal(expected, np_policy)

    # The NumPy implementation is expected to give the same result.
    np.testing.assert_array_almost_equal(
        expected, masked_softmax.np_masked_softmax(np_logits, np_mask))
  def test_tf_masked_softmax(self, np_logits, np_legal_actions, expected):
    """Checks `tf_masked_softmax` against an expected policy.

    Args:
      np_logits: Initial value of the logits variable.
      np_legal_actions: Initial value of the legal-actions mask variable.
      expected: Values the evaluated policy must match, as compared by
        `np.testing.assert_array_almost_equal`.
    """
    # `dtype` has to be passed by keyword: the second positional parameter of
    # `tf.Variable` is `trainable`, so a positional `tf.float32` would not set
    # the variable's dtype.
    logits = tf.Variable(np_logits, dtype=tf.float32)
    mask = tf.Variable(np_legal_actions, dtype=tf.float32)

    policy = masked_softmax.tf_masked_softmax(logits, mask)

    with tf.Session() as sess:
      # Variables must be initialized before the policy can be evaluated.
      sess.run(tf.global_variables_initializer())
      np_policy = sess.run(policy)

    np.testing.assert_array_almost_equal(expected, np_policy)
Esempio n. 3
0
 def masked_softmax(self, logits):
     """Runs the masked softmax on `logits` with the tabular policy's mask.

     Args:
       logits: Forwarded unchanged as the first argument of
         `masked_softmax.tf_masked_softmax`.

     Returns:
       The result of `masked_softmax.tf_masked_softmax` called with `logits`
       and `self.tabular_policy.legal_actions_mask`.
     """
     softmax_fn = masked_softmax.tf_masked_softmax
     legal_mask = self.tabular_policy.legal_actions_mask
     return softmax_fn(logits, legal_mask)