def testWrongPolicyInfoType(self):
  dims = (10, 1)
  log_probability = tf.fill(dims, value=-0.5)
  # A base policy_step.PolicyInfo has no bandit_policy_type field, so
  # set_bandit_policy_type should leave the info as-is (in particular, not
  # convert it to the bandit PolicyInfo variant) and preserve its fields.
  info = policy_step.PolicyInfo(log_probability=log_probability)
  input_tensor = tf.fill(dims, value=_GREEDY)
  result = policy_utilities.set_bandit_policy_type(info, input_tensor)
  self.assertNotIsInstance(result, policy_utilities.PolicyInfo)
  self.assertAllEqual(info.log_probability, result.log_probability)
def testSetBanditPolicyType(self):
  dims = (10, 1)
  bandit_policy_spec = (
      policy_utilities.create_bandit_policy_type_tensor_spec(dims))
  info = policy_utilities.set_bandit_policy_type(None, bandit_policy_spec)
  self.assertIsInstance(info, policy_utilities.PolicyInfo)
  self.assertIsInstance(info.bandit_policy_type,
                        tensor_spec.BoundedTensorSpec)
  self.assertEqual(info.bandit_policy_type.shape, dims)
  self.assertEqual(info.bandit_policy_type.dtype, tf.int32)
  # Set to tensor.
  input_tensor = tf.fill(dims, value=_GREEDY)
  info = policy_utilities.set_bandit_policy_type(info, input_tensor)
  self.assertIsInstance(info.bandit_policy_type, tf.Tensor)
  self.assertEqual(info.bandit_policy_type.shape, input_tensor.shape)
  expected = [[_GREEDY] for _ in range(dims[0])]
  self.assertAllEqual(info.bandit_policy_type, expected)
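# The two tests above are methods of a tf.test.TestCase subclass. Roughly the
# following module-level setup is assumed; the exact import paths and the
# _GREEDY constant are assumptions based on TF-Agents conventions and are not
# part of the excerpt itself.
import tensorflow as tf

from tf_agents.bandits.policies import policy_utilities
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import policy_step

_GREEDY = policy_utilities.BanditPolicyType.GREEDY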
def _action(self, time_step, policy_state, seed):
  seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
  greedy_action = self._greedy_policy.action(time_step, policy_state)
  random_action = self._random_policy.action(time_step, (), seed_stream())

  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  rng = tf.random.uniform(
      outer_shape, maxval=1.0, seed=seed_stream(), name='epsilon_rng')
  cond = tf.greater(rng, self._get_epsilon())

  # Selects the action/info from the random policy with probability epsilon.
  # TODO(b/133175894): tf.compat.v1.where only supports a condition which is
  # either a scalar or a vector. Use tf.compat.v2 so that it can support any
  # condition whose leading dimensions are the same as the other operands of
  # tf.where.
  outer_ndims = int(outer_shape.shape[0])
  if outer_ndims >= 2:
    raise ValueError(
        'Only supports batched time steps with a single batch dimension')
  action = tf.nest.map_structure(lambda g, r: tf.compat.v1.where(cond, g, r),
                                 greedy_action.action, random_action.action)

  if greedy_action.info:
    if not random_action.info:
      raise ValueError('Incompatible info field')
    info = nest_utils.where(cond, greedy_action.info, random_action.info)
    # Overwrite bandit policy info type.
    if policy_utilities.has_bandit_policy_type(info, check_for_tensor=True):
      # Generate a mask of the same shape as bandit_policy_type
      # (batch_size, 1). This is the opposite of `cond`, which is a 1-D bool
      # tensor (batch_size,) that is True where the greedy policy was used
      # and False otherwise.
      random_policy_mask = tf.reshape(
          tf.logical_not(cond), tf.shape(info.bandit_policy_type))
      bandit_policy_type = policy_utilities.bandit_policy_uniform_mask(
          info.bandit_policy_type, mask=random_policy_mask)
      info = policy_utilities.set_bandit_policy_type(info, bandit_policy_type)
  else:
    if random_action.info:
      raise ValueError('Incompatible info field')
    info = ()

  # The state of the epsilon-greedy policy is the state of the underlying
  # greedy policy (the random policy carries no state). It is commonly
  # assumed that the new policy state depends only on the previous state and
  # `time_step`; the action (be it the greedy one or the random one) does not
  # influence the new policy state.
  state = greedy_action.state

  return policy_step.PolicyStep(action, state, info)
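# Illustrative sketch (plain TensorFlow only) of the bandit-policy-type
# masking step in _action above: entries where the random policy was used are
# overwritten with a "uniform" code, mirroring what
# policy_utilities.bandit_policy_uniform_mask is used for here. The integer
# codes below are assumed values for the example, not taken from the library.
import tensorflow as tf

UNIFORM = 1  # assumed code for actions taken by the uniform/random policy
GREEDY = 2   # assumed code for actions taken by the greedy policy

batch_size = 4
# Per-example policy type carried in info.bandit_policy_type, shaped
# (batch_size, 1); start with everything marked as greedy.
bandit_policy_type = tf.fill([batch_size, 1], GREEDY)

# `cond` is True where the greedy action was kept; the mask marks the rows
# where the random policy was used instead.
cond = tf.constant([True, False, True, True])
random_policy_mask = tf.reshape(tf.logical_not(cond), [batch_size, 1])

# Overwrite the masked entries with the UNIFORM code.
bandit_policy_type = tf.where(
    random_policy_mask, tf.fill([batch_size, 1], UNIFORM), bandit_policy_type)
# bandit_policy_type is now [[GREEDY], [UNIFORM], [GREEDY], [GREEDY]].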
def _action(self, time_step, policy_state, seed):
  seed_stream = tfp.util.SeedStream(seed=seed, salt='epsilon_greedy')
  greedy_action = self._greedy_policy.action(time_step, policy_state)
  epsilon_action = self._epsilon_policy.action(time_step, (), seed_stream())

  outer_shape = nest_utils.get_outer_shape(time_step, self._time_step_spec)
  rng = tf.random.uniform(
      outer_shape, maxval=1.0, seed=seed_stream(), name='epsilon_rng')
  cond = tf.greater(rng, self._get_epsilon())

  # Selects the action/info from the random policy with probability epsilon.
  # TODO(b/133175894): tf.compat.v1.where only supports a condition which is
  # either a scalar or a vector. Use tf.compat.v2 so that it can support any
  # condition whose leading dimensions are the same as the other operands of
  # tf.where.
  outer_ndims = int(outer_shape.shape[0])
  if outer_ndims >= 2:
    raise ValueError(
        'Only supports batched time steps with a single batch dimension')
  action = tf.nest.map_structure(lambda g, r: tf.compat.v1.where(cond, g, r),
                                 greedy_action.action, epsilon_action.action)

  if greedy_action.info:
    if not epsilon_action.info:
      raise ValueError('Incompatible info field')
    # Note that the objects in PolicyInfo may have different shapes, so we
    # need to call nest_utils.where() on each type of object.
    info = tf.nest.map_structure(lambda x, y: nest_utils.where(cond, x, y),
                                 greedy_action.info, epsilon_action.info)
    if self._emit_log_probability:
      # At this point, info.log_probability contains the log prob of the
      # action chosen, conditioned on the policy that was chosen. We want to
      # emit the full log probability of the action, so we add in the log
      # probability of choosing the policy.
      random_log_prob = tf.nest.map_structure(
          lambda t: tf.math.log(tf.zeros_like(t) + self._get_epsilon()),
          info.log_probability)
      greedy_log_prob = tf.nest.map_structure(
          lambda t: tf.math.log(tf.ones_like(t) - self._get_epsilon()),
          random_log_prob)
      log_prob_of_chosen_policy = nest_utils.where(cond, greedy_log_prob,
                                                   random_log_prob)
      log_prob = tf.nest.map_structure(lambda a, b: a + b,
                                       log_prob_of_chosen_policy,
                                       info.log_probability)
      info = policy_step.set_log_probability(info, log_prob)

    # Overwrite bandit policy info type.
    if policy_utilities.has_bandit_policy_type(info, check_for_tensor=True):
      # Generate a mask of the same shape as bandit_policy_type
      # (batch_size, 1). This is the opposite of `cond`, which is a 1-D bool
      # tensor (batch_size,) that is True where the greedy policy was used
      # and False otherwise.
      random_policy_mask = tf.reshape(
          tf.logical_not(cond), tf.shape(info.bandit_policy_type))
      bandit_policy_type = policy_utilities.bandit_policy_uniform_mask(
          info.bandit_policy_type, mask=random_policy_mask)
      info = policy_utilities.set_bandit_policy_type(info, bandit_policy_type)
  else:
    if epsilon_action.info:
      raise ValueError('Incompatible info field')
    info = ()

  # The state of the epsilon-greedy policy is the state of the underlying
  # greedy policy (the random policy carries no state). It is commonly
  # assumed that the new policy state depends only on the previous state and
  # `time_step`; the action (be it the greedy one or the random one) does not
  # influence the new policy state.
  state = greedy_action.state

  return policy_step.PolicyStep(action, state, info)
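# A minimal, self-contained sketch (plain TensorFlow only) of the
# epsilon-greedy mixing implemented by _action above: pick the random action
# with probability epsilon, otherwise keep the greedy action, and combine the
# log prob of choosing the sub-policy with the log prob of the action under
# that sub-policy. All names and values below are illustrative stand-ins, not
# the TF-Agents policies or specs themselves.
import tensorflow as tf

batch_size = 4
num_actions = 3
epsilon = 0.1

# Stand-ins for the actions (and per-action log probs) produced by the greedy
# and the random sub-policies for one batch of time steps.
greedy_actions = tf.constant([2, 0, 1, 2], dtype=tf.int64)
greedy_log_probs = tf.math.log(tf.fill([batch_size], 0.9))
random_actions = tf.random.uniform(
    [batch_size], maxval=num_actions, dtype=tf.int64, seed=1)
random_log_probs = tf.fill([batch_size], -tf.math.log(float(num_actions)))

# `cond` is True where the greedy action is kept, i.e. with prob. 1 - epsilon.
rng = tf.random.uniform([batch_size], maxval=1.0, seed=2)
cond = tf.greater(rng, epsilon)

# Per-example selection between greedy and random, as the
# tf.nest.map_structure(... where ...) call does for nested action structures.
actions = tf.where(cond, greedy_actions, random_actions)

# Full log probability of the emitted action = log prob of choosing the
# sub-policy + log prob of the action under that sub-policy, mirroring the
# _emit_log_probability branch above.
log_prob_of_chosen_policy = tf.where(
    cond,
    tf.fill([batch_size], tf.math.log(1.0 - epsilon)),
    tf.fill([batch_size], tf.math.log(epsilon)))
log_prob_of_action = tf.where(cond, greedy_log_probs, random_log_probs)
full_log_prob = log_prob_of_chosen_policy + log_prob_of_action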