def testDisallowedAction(self, batch_size, num_actions):
    """Checks that stepping with a masked-out action is rejected.

    Wraps a random bandit environment in a Bernoulli action-mask
    environment, reads the mask produced by `reset()`, and then steps with
    the per-row argmin of that mask. The step is expected to fail with
    `tf.errors.InvalidArgumentError` whose message matches 'not in allowed'.

    Args:
      batch_size: Leading dimension of the observation and reward
        distributions.
      num_actions: Number of discrete actions; the action spec spans
        `[0, num_actions - 1]`.
    """

    def _pair(x, y):
        # Join function handed to the masking wrapper: returns its two
        # inputs unchanged as a tuple.
        return (x, y)

    arm_spec = tensor_spec.BoundedTensorSpec(
        shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)

    context_shape = [batch_size, 2]
    context_dist = tfd.Independent(
        tfd.Normal(tf.zeros(context_shape), tf.ones(context_shape)))
    payoff_dist = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))

    base_env = random_bandit_environment.RandomBanditEnvironment(
        context_dist, payoff_dist, arm_spec)
    # NOTE(review): the trailing 0.0 is presumably the probability of an
    # action being allowed -- confirm against the wrapper's signature.
    env_under_test = masked_tf_env.BernoulliActionMaskTFEnvironment(
        base_env, _pair, 0.0)

    first_step = env_under_test.reset()
    _, mask = self.evaluate(first_step.observation)

    # Per row, pick an index holding the smallest mask value; the test
    # relies on that entry being a disallowed action.
    forbidden_actions = tf.argmin(mask, axis=-1, output_type=tf.int32)

    # Both the step call and its evaluation stay inside the context so the
    # error is caught wherever it surfaces.
    with self.assertRaisesRegex(
            tf.errors.InvalidArgumentError, 'not in allowed'):
        failing_step = env_under_test.step(forbidden_actions)
        self.evaluate(failing_step.reward)
def testMaskedEnvironment(self, batch_size, num_actions):
    """Checks output shapes of the masked environment on reset and step.

    Wraps a random bandit environment in a Bernoulli action-mask
    environment, then verifies that the observation returned by `reset()`
    unpacks into a `[batch_size, 2]` context and a
    `[batch_size, num_actions]` mask, and that stepping with the per-row
    argmax of the mask yields rewards of shape `[batch_size]`.

    Args:
      batch_size: Leading dimension of the observation and reward
        distributions.
      num_actions: Number of discrete actions; the action spec spans
        `[0, num_actions - 1]`.
    """

    def _pair(x, y):
        # Join function handed to the masking wrapper: returns its two
        # inputs unchanged as a tuple.
        return (x, y)

    arm_spec = tensor_spec.BoundedTensorSpec(
        shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)

    context_shape = [batch_size, 2]
    context_dist = tfd.Independent(
        tfd.Normal(tf.zeros(context_shape), tf.ones(context_shape)))
    payoff_dist = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))

    base_env = random_bandit_environment.RandomBanditEnvironment(
        context_dist, payoff_dist, arm_spec)
    # NOTE(review): the trailing 0.5 is presumably the probability of an
    # action being allowed -- confirm against the wrapper's signature.
    env_under_test = masked_tf_env.BernoulliActionMaskTFEnvironment(
        base_env, _pair, 0.5)

    first_step = env_under_test.reset()
    context, mask = self.evaluate(first_step.observation)
    self.assertAllEqual(tf.shape(context), [batch_size, 2])
    self.assertAllEqual(tf.shape(mask), [batch_size, num_actions])

    # Per row, pick an index holding the largest mask value; the test
    # relies on that entry being an allowed action.
    permitted_actions = tf.argmax(mask, axis=-1, output_type=tf.int32)
    next_step = env_under_test.step(permitted_actions)
    rewards = self.evaluate(next_step.reward)
    self.assertAllEqual(tf.shape(rewards), [batch_size])