def testWrongPolicyInfoType(self):
    dims = (10, 1)
    log_probability = tf.fill(dims, value=-0.5)
    info = policy_step.PolicyInfo(log_probability=log_probability)
    input_tensor = tf.fill(dims, value=_GREEDY)
    result = policy_utilities.set_bandit_policy_type(info, input_tensor)
    self.assertNotIsInstance(result, policy_utilities.PolicyInfo)
    self.assertAllEqual(info.log_probability, result.log_probability)
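
The test above relies on set_bandit_policy_type() passing through any info object that lacks a bandit_policy_type field: a plain policy_step.PolicyInfo is returned as-is rather than converted. A minimal self-contained sketch of that contract, assuming the TF-Agents module paths tf_agents.bandits.policies.policy_utilities and tf_agents.trajectories.policy_step, and assuming _GREEDY is the policy_utilities.BanditPolicyType.GREEDY value:

import tensorflow as tf
from tf_agents.bandits.policies import policy_utilities
from tf_agents.trajectories import policy_step

_GREEDY = policy_utilities.BanditPolicyType.GREEDY  # assumed definition of _GREEDY

# A policy_step.PolicyInfo has no bandit_policy_type field to set...
info = policy_step.PolicyInfo(log_probability=tf.fill((10, 1), value=-0.5))
result = policy_utilities.set_bandit_policy_type(
    info, tf.fill((10, 1), value=_GREEDY))
# ...so the info comes back unchanged, which is what the assertions verify.
assert not isinstance(result, policy_utilities.PolicyInfo)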
Example #2
def testSetBanditPolicyType(self):
   dims = (10, 1)
   bandit_policy_spec = (
       policy_utilities.create_bandit_policy_type_tensor_spec(dims))
   info = policy_utilities.set_bandit_policy_type(None, bandit_policy_spec)
   self.assertIsInstance(info, policy_utilities.PolicyInfo)
   self.assertIsInstance(info.bandit_policy_type,
                         tensor_spec.BoundedTensorSpec)
   self.assertEqual(info.bandit_policy_type.shape, dims)
   self.assertEqual(info.bandit_policy_type.dtype, tf.int32)
   # Set to tensor.
   input_tensor = tf.fill(dims, value=_GREEDY)
   info = policy_utilities.set_bandit_policy_type(info, input_tensor)
   self.assertIsInstance(info.bandit_policy_type, tf.Tensor)
   self.assertEqual(info.bandit_policy_type.shape, input_tensor.shape)
   expected = [[_GREEDY] for _ in range(dims[0])]
   self.assertAllEqual(info.bandit_policy_type, expected)
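
Example #2 shows the two-phase pattern these utilities support: at construction time a policy advertises bandit_policy_type in its info_spec as a BoundedTensorSpec, and at action time it overwrites the field with a concrete int32 tensor of policy-type values. A rough sketch of how the two calls pair up, under the same assumed imports and _GREEDY definition as above:

# Construction time: declare the field so that downstream consumers
# (drivers, replay buffers) know its shape and dtype.
dims = (10, 1)
spec = policy_utilities.create_bandit_policy_type_tensor_spec(dims)
info_spec = policy_utilities.set_bandit_policy_type(None, spec)

# Action time: fill the field with the policy type actually used per row.
chosen_types = tf.fill(dims, value=_GREEDY)
info = policy_utilities.set_bandit_policy_type(info_spec, chosen_types)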
Example #3
    def _action(self, time_step, policy_state, seed):
        seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
        greedy_action = self._greedy_policy.action(time_step, policy_state)
        random_action = self._random_policy.action(time_step, (),
                                                   seed_stream())

        outer_shape = nest_utils.get_outer_shape(time_step,
                                                 self._time_step_spec)
        rng = tf.random.uniform(outer_shape,
                                maxval=1.0,
                                seed=seed_stream(),
                                name='epsilon_rng')
        cond = tf.greater(rng, self._get_epsilon())

        # Selects the action/info from the random policy with probability epsilon.
        # TODO(b/133175894): tf.compat.v1.where only supports a condition which is
        # either a scalar or a vector. Use tf.compat.v2 so that it can support any
        # condition whose leading dimensions are the same as the other operands of
        # tf.where.
        outer_ndims = int(outer_shape.shape[0])
        if outer_ndims >= 2:
            raise ValueError(
                'Only supports batched time steps with a single batch dimension'
            )
        action = tf.nest.map_structure(
            lambda g, r: tf.compat.v1.where(cond, g, r), greedy_action.action,
            random_action.action)

        if greedy_action.info:
            if not random_action.info:
                raise ValueError('Incompatible info field')
            info = nest_utils.where(cond, greedy_action.info,
                                    random_action.info)
            # Overwrite bandit policy info type.
            if policy_utilities.has_bandit_policy_type(info,
                                                       check_for_tensor=True):
                # Generate a mask with the same shape as bandit_policy_type,
                # i.e. (batch_size, 1). The mask is the negation of `cond`, the
                # 1-D bool tensor of shape (batch_size,) that is true where the
                # greedy policy was used and false where the random policy was.
                random_policy_mask = tf.reshape(
                    tf.logical_not(cond), tf.shape(info.bandit_policy_type))
                bandit_policy_type = policy_utilities.bandit_policy_uniform_mask(
                    info.bandit_policy_type, mask=random_policy_mask)
                info = policy_utilities.set_bandit_policy_type(
                    info, bandit_policy_type)
        else:
            if random_action.info:
                raise ValueError('Incompatible info field')
            info = ()

        # The state of the epsilon greedy policy is the state of the underlying
        # greedy policy (the random policy carries no state).
        # It is commonly assumed that the new policy state depends only on the
        # previous state and `time_step`; the action (be it the greedy one or
        # the random one) does not influence the new policy state.
        state = greedy_action.state

        return policy_step.PolicyStep(action, state, info)
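
The core of _action() is the row-wise selection: each batch entry keeps the greedy action with probability 1 - epsilon, and tf.compat.v1.where with a vector condition picks whole rows, which is why the method insists on a single batch dimension. A self-contained toy sketch of that mechanism in plain TensorFlow (hypothetical values, no TF-Agents dependency):

import tensorflow as tf

epsilon = 0.1
batch_size = 4
greedy_actions = tf.constant([[0], [1], [2], [3]], dtype=tf.int32)
random_actions = tf.constant([[9], [9], [9], [9]], dtype=tf.int32)

# One uniform draw per batch entry; True means "keep the greedy action".
rng = tf.random.uniform([batch_size], maxval=1.0)
cond = tf.greater(rng, epsilon)

# With a 1-D condition, tf.compat.v1.where selects entire rows of the
# rank-2 operands, mixing greedy and random actions per batch entry.
mixed = tf.compat.v1.where(cond, greedy_actions, random_actions)

bandit_policy_uniform_mask(), judging by its use here, then rewrites the bandit_policy_type entries to the UNIFORM value wherever the mask (the negation of cond) is true, so consumers of the trajectory can tell which rows came from the random policy.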
Example #4
    def _action(self, time_step, policy_state, seed):
        seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
        greedy_action = self._greedy_policy.action(time_step, policy_state)
        epsilon_action = self._epsilon_policy.action(time_step, (),
                                                     seed_stream())

        outer_shape = nest_utils.get_outer_shape(time_step,
                                                 self._time_step_spec)
        rng = tf.random.uniform(outer_shape,
                                maxval=1.0,
                                seed=seed_stream(),
                                name='epsilon_rng')
        cond = tf.greater(rng, self._get_epsilon())

        # Selects the action/info from the random policy with probability epsilon.
        # TODO(b/133175894): tf.compat.v1.where only supports a condition which is
        # either a scalar or a vector. Use tf.compat.v2 so that it can support any
        # condition whose leading dimensions are the same as the other operands of
        # tf.where.
        outer_ndims = int(outer_shape.shape[0])
        if outer_ndims >= 2:
            raise ValueError(
                'Only supports batched time steps with a single batch dimension'
            )
        action = tf.nest.map_structure(
            lambda g, r: tf.compat.v1.where(cond, g, r), greedy_action.action,
            epsilon_action.action)

        if greedy_action.info:
            if not epsilon_action.info:
                raise ValueError('Incompatible info field')
            # Note that the entries of PolicyInfo may have different shapes, so
            # nest_utils.where() is applied to each entry separately.
            info = tf.nest.map_structure(
                lambda x, y: nest_utils.where(cond, x, y), greedy_action.info,
                epsilon_action.info)
            if self._emit_log_probability:
                # At this point, info.log_probability contains the log prob of the
                # action chosen, conditioned on the policy that was chosen. We want to
                # emit the full log probability of the action, so we'll add in the log
                # probability of choosing the policy.
                random_log_prob = tf.nest.map_structure(
                    lambda t: tf.math.log(
                        tf.zeros_like(t) + self._get_epsilon()),
                    info.log_probability)
                greedy_log_prob = tf.nest.map_structure(
                    lambda t: tf.math.log(
                        tf.ones_like(t) - self._get_epsilon()),
                    random_log_prob)
                log_prob_of_chosen_policy = nest_utils.where(
                    cond, greedy_log_prob, random_log_prob)
                log_prob = tf.nest.map_structure(lambda a, b: a + b,
                                                 log_prob_of_chosen_policy,
                                                 info.log_probability)
                info = policy_step.set_log_probability(info, log_prob)
            # Overwrite bandit policy info type.
            if policy_utilities.has_bandit_policy_type(info,
                                                       check_for_tensor=True):
                # Generate a mask with the same shape as bandit_policy_type,
                # i.e. (batch_size, 1). The mask is the negation of `cond`, the
                # 1-D bool tensor of shape (batch_size,) that is true where the
                # greedy policy was used and false where the epsilon policy was.
                random_policy_mask = tf.reshape(
                    tf.logical_not(cond), tf.shape(info.bandit_policy_type))
                bandit_policy_type = policy_utilities.bandit_policy_uniform_mask(
                    info.bandit_policy_type, mask=random_policy_mask)
                info = policy_utilities.set_bandit_policy_type(
                    info, bandit_policy_type)
        else:
            if epsilon_action.info:
                raise ValueError('Incompatible info field')
            info = ()

        # The state of the epsilon greedy policy is the state of the underlying
        # greedy policy (the random policy carries no state).
        # It is commonly assumed that the new policy state depends only on the
        # previous state and `time_step`; the action (be it the greedy one or
        # the random one) does not influence the new policy state.
        state = greedy_action.state

        return policy_step.PolicyStep(action, state, info)
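
Example #4 differs from Example #3 in the log-probability bookkeeping: for each batch entry it adds the log probability of having picked a branch (epsilon for the random branch, 1 - epsilon for the greedy branch) to the log probability of the action under that branch, i.e. it emits the log of the joint probability of (branch, action). A small numeric sketch of the arithmetic (hypothetical values):

import math

epsilon = 0.1
log_p_action_given_random = math.log(0.25)  # e.g. uniform over four actions
log_p_action_given_greedy = 0.0             # log(1.0): greedy is deterministic

# Rows where cond is False (random branch fired):
log_p_random_row = math.log(epsilon) + log_p_action_given_random        # ~= -3.69
# Rows where cond is True (greedy branch fired):
log_p_greedy_row = math.log(1.0 - epsilon) + log_p_action_given_greedy  # ~= -0.105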