def test_masked_softmax_on_all_invalid_moves(self):
    """Pin the masked softmax output for a row that has no legal action.

    If all actions are illegal, the behavior is undefined (it can be nan or
    it can be 0). This test documents the current behavior (a row of nan)
    so that we know if we change it.
    """
    np_logits = np.asarray([[
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]])
    # The last row of the mask is all zeros: no action is legal there.
    np_mask = np.asarray([[
        [1.0, 1.0, 1.0],
        [1.0, 0.0, 1.0],
        [0.0, 0.0, 0.0],
    ]])
    expected = np.asarray([[
        [1 / 3, 1 / 3, 1 / 3],
        [1 / 2, 0.0, 1 / 2],
        [np.nan, np.nan, np.nan],
    ]])

    # `dtype` has to be passed by keyword: the second positional parameter
    # of `tf.Variable` is `trainable`, so a positional `tf.float32` would
    # be ignored as a dtype.
    logits = tf.Variable(np_logits, dtype=tf.float32)
    mask = tf.Variable(np_mask, dtype=tf.float32)

    policy = masked_softmax.tf_masked_softmax(logits, mask)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        np_policy = sess.run(policy)

    np.testing.assert_array_almost_equal(expected, np_policy)

    # Numpy behaves similarly.
    np.testing.assert_array_almost_equal(
        expected, masked_softmax.np_masked_softmax(np_logits, np_mask))
def test_tf_masked_softmax(self, np_logits, np_legal_actions, expected):
    """Check `tf_masked_softmax` against an expected policy.

    Args:
        np_logits: logits fed to the masked softmax as a TF variable.
        np_legal_actions: legal-actions mask fed alongside the logits.
        expected: array the evaluated policy is compared to (almost-equal).
    """
    # `dtype` has to be passed by keyword: the second positional parameter
    # of `tf.Variable` is `trainable`, so a positional `tf.float32` would
    # be ignored as a dtype.
    logits = tf.Variable(np_logits, dtype=tf.float32)
    mask = tf.Variable(np_legal_actions, dtype=tf.float32)

    policy = masked_softmax.tf_masked_softmax(logits, mask)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        np_policy = sess.run(policy)

    np.testing.assert_array_almost_equal(expected, np_policy)
def masked_softmax(self, logits):
    """Return the masked softmax of `logits`.

    The mask is the `legal_actions_mask` of this object's tabular policy;
    the computation is delegated to `masked_softmax.tf_masked_softmax`.
    """
    legal_actions_mask = self.tabular_policy.legal_actions_mask
    return masked_softmax.tf_masked_softmax(logits, legal_actions_mask)