Example #1

# Imports used by both examples (TF-Agents ships `types` and `nest_utils`;
# it uses absl's logging, though the stdlib module would also work here).
from typing import Optional

from absl import logging
import tensorflow as tf

from tf_agents.typing import types
from tf_agents.utils import nest_utils


def _action(self,
            time_step,
            policy_state,
            seed: Optional[types.Seed] = None):
  if seed is not None and self._use_tf_function:
    logging.warning(
        'Using `seed` may force a retrace for each call to `action`.')
  if self._batch_time_steps:
    time_step = nest_utils.batch_nested_array(time_step)
  # Convert to tensors up front: passing numpy arrays into the tf.function
  # can force it to retrace.
  time_step = tf.nest.map_structure(tf.convert_to_tensor, time_step)
  if seed is not None:
    policy_step = self._policy_action_fn(time_step, policy_state, seed=seed)
  else:
    policy_step = self._policy_action_fn(time_step, policy_state)
  if not self._batch_time_steps:
    return policy_step
  return policy_step._replace(
      action=nest_utils.unbatch_nested_tensors_to_arrays(policy_step.action),
      # We intentionally do not convert the `state`: it is returned exactly as
      # the underlying policy produced it (i.e. as a Tensor), even though that
      # is not necessarily compatible with a py-policy. The reason is that the
      # `state` is fed straight back to the policy, so converting it here
      # would require converting it back to the original form before calling
      # the policy's `action` method again on the next step. To store the
      # `state`, e.g. in a replay buffer, we suggest placing it in the `info`
      # field instead.
      info=nest_utils.unbatch_nested_tensors_to_arrays(policy_step.info))
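
This `_action` override matches TF-Agents' `py_tf_eager_policy.PyTFEagerPolicyBase`, which exposes an eager-mode TF policy through the py-policy interface. Below is a minimal usage sketch, assuming TF-Agents and Gym are installed; the `RandomTFPolicy` and the CartPole environment are illustrative stand-ins, not part of the example above:

from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import py_tf_eager_policy, random_tf_policy

# An unbatched py-environment, plus a TF policy built over its specs.
py_env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(py_env)
tf_policy = random_tf_policy.RandomTFPolicy(
    tf_env.time_step_spec(), tf_env.action_spec())

# Wrapping the TF policy routes `action()` calls through `_action` above:
# the numpy time_step is batched and converted to tensors on the way in,
# and the action is unbatched back to numpy arrays on the way out.
eager_policy = py_tf_eager_policy.PyTFEagerPolicy(
    tf_policy, use_tf_function=True)

time_step = py_env.reset()
policy_step = eager_policy.action(time_step)
print(policy_step.action)  # unbatched numpy action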
Example #2

def _action(self, time_step, policy_state):
  time_step = nest_utils.batch_nested_array(time_step)
  # Convert to tensors up front: passing numpy arrays into the tf.function
  # can force it to retrace.
  time_step = tf.nest.map_structure(tf.convert_to_tensor, time_step)
  policy_step = self._policy_action_fn(time_step, policy_state)
  return policy_step._replace(
      action=nest_utils.unbatch_nested_tensors_to_arrays(policy_step.action),
      # We intentionally do not convert the `state`: it is returned exactly as
      # the underlying policy produced it (i.e. as a Tensor), even though that
      # is not necessarily compatible with a py-policy. The reason is that the
      # `state` is fed straight back to the policy, so converting it here
      # would require converting it back to the original form before calling
      # the policy's `action` method again on the next step. To store the
      # `state`, e.g. in a replay buffer, we suggest placing it in the `info`
      # field instead.
      info=nest_utils.unbatch_nested_tensors_to_arrays(policy_step.info))
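
Both variants hinge on the same round trip through TF-Agents' `nest_utils`: add a leading batch dimension to the numpy inputs, hand the policy tensors, then strip the batch dimension and convert back to numpy on the way out. Here is a small sketch of that round trip in isolation; the `{'obs': ...}` nest is a made-up stand-in for a real time step:

import numpy as np
import tensorflow as tf

from tf_agents.utils import nest_utils

nest = {'obs': np.zeros((3,), dtype=np.float32)}

# Add a leading batch dimension of 1 to every array in the nest.
batched = nest_utils.batch_nested_array(nest)  # 'obs' becomes shape (1, 3)

# Convert to tensors, as both `_action` methods do before calling the policy.
tensors = tf.nest.map_structure(tf.convert_to_tensor, batched)

# Strip the batch dimension again and convert the tensors to numpy arrays.
unbatched = nest_utils.unbatch_nested_tensors_to_arrays(tensors)
print(unbatched['obs'].shape)  # (3,)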