Пример #1
0
  def test_masked_softmax_on_all_invalid_moves(self):
    """Documents masked softmax output when a row has no legal action.

    If all actions are illegal, the behavior is undefined (it can be NaN or
    0). This test pins the current behavior (NaN for the all-illegal row) for
    both the TensorFlow and the NumPy implementation, so that we notice if it
    changes.
    """
    np_logits = np.asarray([[
        [1.0, 1.0, 1.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]])
    # `dtype` must be passed by keyword: the second positional parameter of
    # `tf.Variable` is `trainable`, so `tf.Variable(x, tf.float32)` would
    # silently keep the float64 dtype inferred from the NumPy array.
    logits = tf.Variable(np_logits, dtype=tf.float32)
    np_mask = np.asarray([[
        [1.0, 1.0, 1.0],  # every action legal
        [1.0, 0.0, 1.0],  # middle action illegal
        [0.0, 0.0, 0.0],  # no legal action at all
    ]])
    mask = tf.Variable(np_mask, dtype=tf.float32)

    expected = np.asarray([[
        [1 / 3, 1 / 3, 1 / 3],  # equal logits, all legal -> uniform
        [1 / 2, 0.0, 1 / 2],  # mass split over the two legal actions
        [np.nan, np.nan, np.nan],  # all illegal -> NaN (current behavior)
    ]])

    policy = masked_softmax.tf_masked_softmax(logits, mask)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      np_policy = sess.run(policy)
    # assert_array_almost_equal treats NaNs in matching positions as equal.
    np.testing.assert_array_almost_equal(expected, np_policy)

    # Numpy behaves similarly.
    np.testing.assert_array_almost_equal(
        expected, masked_softmax.np_masked_softmax(np_logits, np_mask))
Пример #2
0
  def value_and_prior(self, state):
    """Evaluates `state` with the model and returns per-player values and a prior.

    Args:
      state: game state; this method calls its `legal_actions_mask()`,
        `legal_actions()` and `current_player()`.

    Returns:
      A tuple `(values, prior)`. `values` is a 2-element NumPy array where the
      model's value sits in the current player's slot and its negation in the
      other slot (any player other than 0 is treated as player 1). `prior` is
      a list of `(action, probability)` pairs over the legal actions, taken
      from the model's policy output renormalized with
      `masked_softmax.np_masked_softmax` and the legal-actions mask.
    """
    features = self.feature_extractor(state)
    with self.device:
      raw_value, raw_policy = self.model(features)

    # Keep the first row of the policy output (presumably a batch of one --
    # TODO confirm) and renormalize it over the legal actions only.
    policy_row = np.array(raw_policy)[0]
    legal_mask = np.array(state.legal_actions_mask())
    probs = masked_softmax.np_masked_softmax(policy_row, legal_mask)
    prior = [(action, probs[action]) for action in state.legal_actions()]

    # Callers need one value per player: put the model's value in the
    # current player's slot and its negation in the opponent's slot.
    scalar = raw_value[0, 0].numpy()
    mirrored = -scalar
    ordered = ([scalar, mirrored] if state.current_player() == 0
               else [mirrored, scalar])
    values = np.array(ordered)

    return values, prior
Пример #3
0
 def inference(self, obs, mask):
   """Runs the Keras model on `obs` and masks its policy output.

   Args:
     obs: model input, forwarded unchanged to `self._keras_model`.
     mask: legal-actions mask; converted with `np.array` before use.

   Returns:
     A tuple `(value, policy)`: `value` is the model's value output as
     returned by the model, and `policy` is the model's policy output passed
     through `masked_softmax.np_masked_softmax` together with `mask`.
   """
   with self._device:
     value, policy_out = self._keras_model(obs)
   policy_array = np.array(policy_out)
   mask_array = np.array(mask)
   masked_policy = masked_softmax.np_masked_softmax(policy_array, mask_array)
   return value, masked_policy
Пример #4
0
 def test_np_masked_softmax(self, logits, legal_actions, expected):
   """Checks `np_masked_softmax(logits, legal_actions)` against `expected`."""
   actual = masked_softmax.np_masked_softmax(logits, legal_actions)
   np.testing.assert_array_almost_equal(expected, actual)