Code Example #1
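Both snippets in this section read like parameterized test methods for TF-Agents' BernoulliActionMaskTFEnvironment. They rely on module-level imports and a test-class wrapper that the excerpt does not show; the sketch below reconstructs that scaffolding under the assumption of the usual tf_agents module layout (the module paths, class name, and decorator parameters are assumptions, not part of the excerpt).

# Minimal scaffolding sketch for the two snippets in this section. Module
# paths, class name and parameter values are assumptions based on typical
# TF-Agents test layout, not taken from the original excerpt.
from absl.testing import parameterized
import tensorflow as tf
import tensorflow_probability as tfp

from tf_agents.bandits.environments import bernoulli_action_mask_tf_environment as masked_tf_env
from tf_agents.bandits.environments import random_bandit_environment
from tf_agents.specs import tensor_spec

tfd = tfp.distributions


class BernoulliActionMaskTFEnvironmentTest(tf.test.TestCase,
                                           parameterized.TestCase):
  # The two test methods shown below would live here, each decorated with
  # something like @parameterized.parameters((1, 4), (8, 3)) to supply
  # (batch_size, num_actions).
  pass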
  def testDisallowedAction(self, batch_size, num_actions):
    # Batched 2-dimensional Gaussian observations and Gaussian rewards.
    observation_distribution = tfd.Independent(
        tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
    reward_distribution = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
    action_spec = tensor_spec.BoundedTensorSpec(
        shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)

    env = random_bandit_environment.RandomBanditEnvironment(
        observation_distribution, reward_distribution, action_spec)
    # Wrap the environment with a Bernoulli action mask: the identity lambda
    # joins context and mask into the observation, and the last argument (0.0)
    # is the probability with which each action is marked allowed.
    masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
        env, lambda x, y: (x, y), 0.0)
    _, mask = self.evaluate(masked_env.reset().observation)
    # With allow-probability 0.0, argmin over the 0/1 mask picks an action
    # whose mask entry is 0, i.e. a disallowed action.
    surely_disallowed_actions = tf.argmin(mask, axis=-1, output_type=tf.int32)
    # Stepping with a disallowed action is expected to fail inside the
    # masked environment.
    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                'not in allowed'):
      self.evaluate(masked_env.step(surely_disallowed_actions).reward)
Code Example #2
  def testMaskedEnvironment(self, batch_size, num_actions):
    # Batched 2-dimensional Gaussian observations and Gaussian rewards.
    observation_distribution = tfd.Independent(
        tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
    reward_distribution = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
    action_spec = tensor_spec.BoundedTensorSpec(
        shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)

    env = random_bandit_environment.RandomBanditEnvironment(
        observation_distribution, reward_distribution, action_spec)
    # Each action is marked allowed with probability 0.5; the identity lambda
    # joins context and mask into a (context, mask) observation tuple.
    masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
        env, lambda x, y: (x, y), 0.5)
    context, mask = self.evaluate(masked_env.reset().observation)
    # The observation splits into a [batch_size, 2] context and a
    # [batch_size, num_actions] action mask.
    self.assertAllEqual(tf.shape(context), [batch_size, 2])
    self.assertAllEqual(tf.shape(mask), [batch_size, num_actions])
    # argmax over the 0/1 mask returns the index of an allowed action whenever
    # one exists, so stepping succeeds and yields one reward per batch element.
    surely_allowed_actions = tf.argmax(mask, axis=-1, output_type=tf.int32)
    rewards = self.evaluate(masked_env.step(surely_allowed_actions).reward)
    self.assertAllEqual(tf.shape(rewards), [batch_size])
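Outside the test harness, the same pieces can be exercised directly. The sketch below is a minimal, hypothetical usage example, assuming TF2 eager execution and the imports from the scaffolding sketch above; it only uses calls that already appear in the snippets (reset(), step(), and the (context, mask) observation tuple).

# Standalone usage sketch (assumes eager execution and the imports above).
batch_size, num_actions = 8, 5

observation_distribution = tfd.Independent(
    tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
reward_distribution = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
action_spec = tensor_spec.BoundedTensorSpec(
    shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)

env = random_bandit_environment.RandomBanditEnvironment(
    observation_distribution, reward_distribution, action_spec)
masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
    env, lambda x, y: (x, y), 0.5)

context, mask = masked_env.reset().observation
# Pick, per batch element, an action whose mask entry is 1 (allowed).
allowed_actions = tf.argmax(mask, axis=-1, output_type=tf.int32)
reward = masked_env.step(allowed_actions).reward
print(reward.shape)  # (batch_size,)

Sampling the mask with probability 0.5 keeps the sketch non-trivial; as the first example shows, stepping with an action whose mask entry is 0 would instead raise InvalidArgumentError.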