def testInvalidRewardEventShape(self, overall_reward_shape):
  """Checks that a reward distribution with non-scalar events is rejected.

  Builds a scalar-event observation distribution and a reward distribution
  whose event shape is `overall_reward_shape[1:]` (non-scalar, via
  `tfd.Independent`), then expects `RandomBanditEnvironment` to raise.

  Args:
    overall_reward_shape: list of ints; `[batch_size] + reward_event_shape`
      used to build the (invalid) reward distribution.
  """
  observation_distribution = tfd.Normal(
      tf.zeros(overall_reward_shape[0]), tf.ones(overall_reward_shape[0]))
  # Wrapping in Independent folds the trailing dims into the event shape,
  # which makes this reward distribution invalid for the environment.
  reward_distribution = tfd.Independent(
      tfd.Normal(tf.zeros(overall_reward_shape),
                 tf.ones(overall_reward_shape)))
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed in Python 3.12). The parentheses are escaped so they match
  # the literal "()" in the message rather than forming an empty regex
  # group that matches anywhere.
  with self.assertRaisesRegex(
      ValueError, r'`reward_distribution` must have event_shape \(\)'):
    random_bandit_environment.RandomBanditEnvironment(
        observation_distribution, reward_distribution)
def get_gaussian_random_environment(observation_shape, action_shape, batch_size):
  """Builds a RandomBanditEnvironment with Gaussian observations and rewards.

  Args:
    observation_shape: list of ints; per-example observation shape.
    action_shape: shape of the (float) action spec.
    batch_size: int; leading batch dimension shared by observations and
      rewards.

  Returns:
    A `random_bandit_environment.RandomBanditEnvironment` whose observations
    are standard-normal with shape `[batch_size] + observation_shape` and
    whose rewards are standard-normal with shape `[batch_size]`.
  """
  batched_shape = [batch_size] + observation_shape
  # Independent folds the per-example dims into the event shape, leaving a
  # single batch dimension of size `batch_size`.
  observation_distribution = tfd.Independent(
      tfd.Normal(loc=tf.zeros(batched_shape), scale=tf.ones(batched_shape)))
  reward_distribution = tfd.Normal(
      loc=tf.zeros(batch_size), scale=tf.ones(batch_size))
  action_spec = tensor_spec.TensorSpec(shape=action_shape, dtype=tf.float32)
  return random_bandit_environment.RandomBanditEnvironment(
      observation_distribution, reward_distribution, action_spec)
def get_bounded_reward_random_environment(
    observation_shape, action_shape, batch_size, num_actions):
  """Builds a RandomBanditEnvironment with U(0, 1) observations and rewards.

  Args:
    observation_shape: list of ints; per-example observation shape.
    action_shape: shape of the (integer) action spec.
    batch_size: int; leading batch dimension shared by observations and
      rewards.
    num_actions: int; actions are integers in `[0, num_actions - 1]`.

  Returns:
    A `random_bandit_environment.RandomBanditEnvironment` with uniform [0, 1)
    observations of shape `[batch_size] + observation_shape`, uniform [0, 1)
    rewards of shape `[batch_size]`, and a bounded integer action spec.
  """
  batched_shape = [batch_size] + observation_shape
  # Independent reinterprets the per-example dims as the event shape so the
  # distribution has a single batch dimension of size `batch_size`.
  observation_distribution = tfd.Independent(
      tfd.Uniform(low=tf.zeros(batched_shape), high=tf.ones(batched_shape)))
  reward_distribution = tfd.Uniform(
      low=tf.zeros(batch_size), high=tf.ones(batch_size))
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=action_shape, dtype=tf.int32, minimum=0, maximum=num_actions - 1)
  return random_bandit_environment.RandomBanditEnvironment(
      observation_distribution, reward_distribution, action_spec)
def testInvalidObservationBatchShape(
    self, overall_observation_shape, batch_dims):
  """Checks that an observation distribution with rank-!=1 batch is rejected.

  Builds an observation distribution whose batch shape has length other than
  1 (controlled by `reinterpreted_batch_ndims`), then expects
  `RandomBanditEnvironment` to raise.

  Args:
    overall_observation_shape: list of ints; full shape used to build the
      observation distribution.
    batch_dims: int; number of trailing dims folded into the event shape —
      chosen so the remaining batch shape does not have length 1.
  """
  observation_distribution = tfd.Independent(
      tfd.Normal(tf.zeros(overall_observation_shape),
                 tf.ones(overall_observation_shape)),
      reinterpreted_batch_ndims=batch_dims)
  reward_distribution = tfd.Normal(tf.zeros(overall_observation_shape[0]),
                                   tf.ones(overall_observation_shape[0]))
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed in Python 3.12).
  with self.assertRaisesRegex(
      ValueError,
      '`observation_distribution` must have batch shape with length 1'):
    random_bandit_environment.RandomBanditEnvironment(
        observation_distribution, reward_distribution)
def testMismatchedBatchShape(
    self, overall_observation_shape, overall_reward_shape):
  """Checks that mismatched observation/reward batch shapes are rejected.

  Builds observation and reward distributions whose leading (batch) dims
  differ, then expects `RandomBanditEnvironment` to raise.

  Args:
    overall_observation_shape: list of ints; full observation shape.
    overall_reward_shape: list of ints; full reward shape — chosen so its
      batch dimension disagrees with the observation's.
  """
  observation_distribution = tfd.Independent(
      tfd.Normal(tf.zeros(overall_observation_shape),
                 tf.ones(overall_observation_shape)))
  reward_distribution = tfd.Independent(
      tfd.Normal(tf.zeros(overall_reward_shape),
                 tf.ones(overall_reward_shape)))
  # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
  # (removed in Python 3.12).
  with self.assertRaisesRegex(
      ValueError,
      '`reward_distribution` and `observation_distribution` must have the '
      'same batch shape'):
    random_bandit_environment.RandomBanditEnvironment(
        observation_distribution, reward_distribution)
def testDisallowedAction(self, batch_size, num_actions):
  """Checks that stepping with a masked-out action raises.

  Wraps a random bandit environment in a Bernoulli action-mask environment
  with probability 0.0 (so the per-step mask disallows actions), then steps
  with an action the mask marks as disallowed and expects
  InvalidArgumentError.

  Args:
    batch_size: int; batch dimension of the wrapped environment.
    num_actions: int; size of the action space and of the mask.
  """
  obs_dist = tfd.Independent(
      tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
  reward_dist = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)
  base_env = random_bandit_environment.RandomBanditEnvironment(
      obs_dist, reward_dist, action_spec)
  # Mask probability 0.0: every maskable entry is disallowed.
  masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
      base_env, lambda x, y: (x, y), 0.0)
  _, mask = self.evaluate(masked_env.reset().observation)
  # argmin over the 0/1 mask lands on a disallowed (0) entry.
  disallowed = tf.argmin(mask, axis=-1, output_type=tf.int32)
  with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                              'not in allowed'):
    self.evaluate(masked_env.step(disallowed).reward)
def testMaskedEnvironment(self, batch_size, num_actions):
  """Checks observation/mask/reward shapes of the masked environment.

  Wraps a random bandit environment in a Bernoulli action-mask environment
  with probability 0.5, verifies the context and mask shapes on reset, then
  steps with an allowed action and verifies the reward shape.

  Args:
    batch_size: int; batch dimension of the wrapped environment.
    num_actions: int; size of the action space and of the mask.
  """
  obs_dist = tfd.Independent(
      tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
  reward_dist = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)
  base_env = random_bandit_environment.RandomBanditEnvironment(
      obs_dist, reward_dist, action_spec)
  masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
      base_env, lambda x, y: (x, y), 0.5)
  context, mask = self.evaluate(masked_env.reset().observation)
  self.assertAllEqual(tf.shape(context), [batch_size, 2])
  self.assertAllEqual(tf.shape(mask), [batch_size, num_actions])
  # argmax over the 0/1 mask lands on an allowed (1) entry, so the step
  # must succeed and produce one reward per batch element.
  allowed = tf.argmax(mask, axis=-1, output_type=tf.int32)
  rewards = self.evaluate(masked_env.step(allowed).reward)
  self.assertAllEqual(tf.shape(rewards), [batch_size])