コード例 #1
0
 def testInvalidRewardEventShape(self, overall_reward_shape):
   """Constructor must reject reward distributions with non-scalar event shape."""
   observation_distribution = tfd.Normal(
       tf.zeros(overall_reward_shape[0]),
       tf.ones(overall_reward_shape[0]))
   # Wrapping in `Independent` folds batch dims into a non-empty event shape,
   # which `RandomBanditEnvironment` disallows for rewards.
   reward_distribution = tfd.Independent(
       tfd.Normal(tf.zeros(overall_reward_shape),
                  tf.ones(overall_reward_shape)))
   # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex`
   # (removed in Python 3.12). The parentheses are escaped so the pattern
   # matches the literal '()' in the error message instead of an empty group.
   with self.assertRaisesRegex(
       ValueError, r'`reward_distribution` must have event_shape \(\)'):
     random_bandit_environment.RandomBanditEnvironment(
         observation_distribution, reward_distribution)
コード例 #2
0
def get_gaussian_random_environment(observation_shape, action_shape,
                                    batch_size):
    """Returns a RandomBanditEnvironment with Gaussian observation and reward."""
    full_shape = [batch_size] + observation_shape
    # Standard-normal observations; `Independent` treats the trailing dims as
    # a single event so the distribution has a length-1 batch shape.
    loc = tf.zeros(full_shape)
    scale = tf.ones(full_shape)
    observation_distribution = tfd.Independent(tfd.Normal(loc=loc, scale=scale))
    # Standard-normal, scalar-event reward: one reward per batch entry.
    reward_distribution = tfd.Normal(
        loc=tf.zeros(batch_size), scale=tf.ones(batch_size))
    action_spec = tensor_spec.TensorSpec(shape=action_shape, dtype=tf.float32)
    return random_bandit_environment.RandomBanditEnvironment(
        observation_distribution, reward_distribution, action_spec)
コード例 #3
0
def get_bounded_reward_random_environment(
    observation_shape, action_shape, batch_size, num_actions):
  """Returns a RandomBanditEnvironment with U(0, 1) observation and reward."""
  full_shape = [batch_size] + observation_shape
  # U(0, 1) observations; `Independent` collapses the trailing dims into the
  # event shape, leaving a length-1 batch shape.
  observation_distribution = tfd.Independent(
      tfd.Uniform(low=tf.zeros(full_shape), high=tf.ones(full_shape)))
  # Scalar-event U(0, 1) reward per batch entry.
  reward_distribution = tfd.Uniform(
      low=tf.zeros(batch_size), high=tf.ones(batch_size))
  # Discrete actions in [0, num_actions - 1].
  action_spec = tensor_spec.BoundedTensorSpec(
      shape=action_shape, dtype=tf.int32, minimum=0, maximum=num_actions - 1)
  return random_bandit_environment.RandomBanditEnvironment(
      observation_distribution, reward_distribution, action_spec)
コード例 #4
0
 def testInvalidObservationBatchShape(
     self, overall_observation_shape, batch_dims):
   """Constructor must reject observation distributions whose batch shape is not rank 1."""
   # `reinterpreted_batch_ndims=batch_dims` leaves a batch shape whose length
   # differs from 1, which `RandomBanditEnvironment` disallows.
   observation_distribution = tfd.Independent(
       tfd.Normal(tf.zeros(overall_observation_shape),
                  tf.ones(overall_observation_shape)),
       reinterpreted_batch_ndims=batch_dims)
   reward_distribution = tfd.Normal(tf.zeros(overall_observation_shape[0]),
                                    tf.ones(overall_observation_shape[0]))
   # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex`
   # (removed in Python 3.12); use the modern spelling.
   with self.assertRaisesRegex(
       ValueError,
       '`observation_distribution` must have batch shape with length 1'):
     random_bandit_environment.RandomBanditEnvironment(
         observation_distribution, reward_distribution)
コード例 #5
0
 def testMismatchedBatchShape(
     self, overall_observation_shape, overall_reward_shape):
   """Constructor must reject observation/reward distributions with different batch shapes."""
   observation_distribution = tfd.Independent(
       tfd.Normal(tf.zeros(overall_observation_shape),
                  tf.ones(overall_observation_shape)))
   reward_distribution = tfd.Independent(
       tfd.Normal(tf.zeros(overall_reward_shape),
                  tf.ones(overall_reward_shape)))
   # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex`
   # (removed in Python 3.12); use the modern spelling.
   with self.assertRaisesRegex(
       ValueError,
       '`reward_distribution` and `observation_distribution` must have the '
       'same batch shape'):
     random_bandit_environment.RandomBanditEnvironment(
         observation_distribution, reward_distribution)
コード例 #6
0
  def testDisallowedAction(self, batch_size, num_actions):
    """Stepping a masked environment with a masked-out action must raise."""
    # Build a plain random bandit environment with 2-d Gaussian observations
    # and scalar Gaussian rewards.
    obs_dist = tfd.Independent(
        tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
    rew_dist = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
    spec = tensor_spec.BoundedTensorSpec(
        shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)
    base_env = random_bandit_environment.RandomBanditEnvironment(
        obs_dist, rew_dist, spec)

    # Mask probability 0.0: every action is disallowed except the one the
    # wrapper force-enables, so argmin over the mask picks a forbidden action.
    masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
        base_env, lambda x, y: (x, y), 0.0)
    _, mask = self.evaluate(masked_env.reset().observation)
    forbidden_actions = tf.argmin(mask, axis=-1, output_type=tf.int32)
    with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                'not in allowed'):
      self.evaluate(masked_env.step(forbidden_actions).reward)
コード例 #7
0
  def testMaskedEnvironment(self, batch_size, num_actions):
    """Masked environment yields (context, mask) observations and accepts allowed actions."""
    # Underlying bandit environment: 2-d Gaussian observations, scalar
    # Gaussian rewards, integer actions in [0, num_actions - 1].
    obs_dist = tfd.Independent(
        tfd.Normal(tf.zeros([batch_size, 2]), tf.ones([batch_size, 2])))
    rew_dist = tfd.Normal(tf.zeros(batch_size), tf.ones(batch_size))
    spec = tensor_spec.BoundedTensorSpec(
        shape=(), minimum=0, maximum=num_actions - 1, dtype=tf.int32)
    base_env = random_bandit_environment.RandomBanditEnvironment(
        obs_dist, rew_dist, spec)

    # Wrap with a Bernoulli(0.5) action mask; the observation becomes a
    # (context, mask) pair.
    masked_env = masked_tf_env.BernoulliActionMaskTFEnvironment(
        base_env, lambda x, y: (x, y), 0.5)
    context, mask = self.evaluate(masked_env.reset().observation)
    self.assertAllEqual(tf.shape(context), [batch_size, 2])
    self.assertAllEqual(tf.shape(mask), [batch_size, num_actions])

    # argmax over the mask always lands on an enabled action, so stepping
    # succeeds and returns one reward per batch entry.
    allowed_actions = tf.argmax(mask, axis=-1, output_type=tf.int32)
    rewards = self.evaluate(masked_env.step(allowed_actions).reward)
    self.assertAllEqual(tf.shape(rewards), [batch_size])