def _action(self, time_step, policy_state, seed: Optional[types.Seed] = None):
    """Forwards a (possibly unbatched) py time_step to the wrapped TF policy.

    Args:
      time_step: A (nested) time step, batched or not depending on
        `self._batch_time_steps`.
      policy_state: Policy state fed straight through to the underlying
        policy (kept in its native Tensor form; see note on the return).
      seed: Optional seed forwarded to the policy's action function.

    Returns:
      A `PolicyStep`; `action` and `info` are unbatched back to arrays when
      `self._batch_time_steps` is set, while `state` is left untouched.
    """
    if seed is not None and self._use_tf_function:
        logging.warning(
            'Using `seed` may force a retrace for each call to `action`.')

    if self._batch_time_steps:
        time_step = nest_utils.batch_nested_array(time_step)
    # Convert eagerly so the tf.function always sees Tensors; feeding raw
    # numpy arrays would trigger a retrace on every call.
    time_step = tf.nest.map_structure(tf.convert_to_tensor, time_step)

    # Only pass `seed` through when one was supplied, so the no-seed call
    # signature stays identical to the original.
    seed_kwargs = {} if seed is None else {'seed': seed}
    step = self._policy_action_fn(time_step, policy_state, **seed_kwargs)

    if self._batch_time_steps:
        # We intentionally do not convert the `state` so it is outputted as
        # the underlying policy generated it (i.e. in the form of a Tensor)
        # which is not necessarily compatible with a py-policy. However, we
        # do so since the `state` is fed back to the policy. So if it was
        # converted, it'd be required to convert back to the original form
        # before calling the method `action` of the policy again in the next
        # step. If one wants to store the `state` e.g. in replay buffer,
        # then we suggest placing it into the `info` field.
        return step._replace(
            action=nest_utils.unbatch_nested_tensors_to_arrays(step.action),
            info=nest_utils.unbatch_nested_tensors_to_arrays(step.info))
    return step
def _action(self, time_step, policy_state):
    """Batches a py time_step, runs the TF policy, and unbatches the result.

    Args:
      time_step: An unbatched (nested) py time step.
      policy_state: Policy state passed through unmodified to the underlying
        policy (kept as Tensors; see note below).

    Returns:
      A `PolicyStep` whose `action` and `info` have been unbatched back to
      arrays; `state` is returned exactly as the policy produced it.
    """
    batched = nest_utils.batch_nested_array(time_step)
    # Convert eagerly so the tf.function always sees Tensors; feeding raw
    # numpy arrays would trigger a retrace on every call.
    batched = tf.nest.map_structure(tf.convert_to_tensor, batched)
    step = self._policy_action_fn(batched, policy_state)
    unbatched_action = nest_utils.unbatch_nested_tensors_to_arrays(step.action)
    unbatched_info = nest_utils.unbatch_nested_tensors_to_arrays(step.info)
    # We intentionally do not convert the `state` so it is outputted as the
    # underlying policy generated it (i.e. in the form of a Tensor) which is
    # not necessarily compatible with a py-policy. However, we do so since
    # the `state` is fed back to the policy. So if it was converted, it'd be
    # required to convert back to the original form before calling the
    # method `action` of the policy again in the next step. If one wants to
    # store the `state` e.g. in replay buffer, then we suggest placing it
    # into the `info` field.
    return step._replace(action=unbatched_action, info=unbatched_info)