Ejemplo n.º 1
0
def make_timestep_mask(batched_next_time_step: ts.TimeStep,
                       allow_partial_episodes: bool = False) -> types.Tensor:
    """Builds a float mask over transitions, optionally dropping a trailing partial episode.

    Args:
      batched_next_time_step: Next timestep, doubly-batched [batch_dim, time_dim,
        ...].
      allow_partial_episodes: If true, then steps on incomplete episodes are
        allowed.

    Returns:
      A tf.float32 mask that is 0.0 at every between-episode timestep (where
        batched_next_time_step is FIRST). When allow_partial_episodes is False,
        the mask is additionally 0.0 for the incomplete episode at the end of
        the sequence.
    """
    if allow_partial_episodes:
        # Only the FIRST steps (between episodes) are masked out.
        return tf.cast(~batched_next_time_step.is_first(), tf.float32)

    # Reverse cumulative sum of the LAST flags along the time axis (axis=1):
    # it is positive exactly where a LAST step occurs at or after this
    # position, i.e. the step belongs to an episode that finishes inside the
    # sequence. Steps of the trailing unfinished episode get 0.
    last_flags = tf.cast(batched_next_time_step.is_last(), tf.float32)
    in_finished_episode = tf.cumsum(last_flags, axis=1, reverse=True) > 0

    # FIRST steps mark the boundary between episodes and are masked out.
    inside_episode = ~batched_next_time_step.is_first()

    return tf.cast(in_finished_episode & inside_episode, tf.float32)
Ejemplo n.º 2
0
    def _step(self, action):
        """Steps the parent environment and forces `LAST` once the step limit is hit.

        The per-entry step counter `self._num_steps` is incremented before
        delegating to the parent `_step`. Every entry whose counter has reached
        `self._duration` has its step type overwritten with `StepType.LAST` and
        its discount set to 0. The resulting `TimeStep` is also assigned to
        `self._env._time_step`. Counters of all entries that end up `LAST`
        (whether from the time limit or from the parent step) are set to -1.
        The increment at the start of the next call therefore brings them
        back to 0.

        Args:
          action: Action forwarded unchanged to the parent `_step`.

        Returns:
          The `TimeStep` from the parent, with step type / discount overridden
          for entries that hit the time limit.
        """
        # Count this step before the limit check below.
        self._num_steps.assign_add(tf.ones_like(self._num_steps))

        time_step = super()._step(action)

        time_limit_terminations = tf.math.greater_equal(
            self._num_steps, self._duration)
        step_types = tf.where(condition=time_limit_terminations,
                              x=StepType.LAST,
                              y=time_step.step_type)
        # Entries cut off by the time limit get a discount of 0.
        discounts = tf.where(condition=time_limit_terminations,
                             x=0,
                             y=time_step.discount)
        new_time_step = TimeStep(step_types, time_step.reward, discounts,
                                 time_step.observation)
        # Overwrite `_env._time_step` so it matches the step returned here.
        self._env._time_step = new_time_step  # pylint: disable=protected-access

        # We convert the TF Tensors to numpy first for performance reasons, and
        # reduce with the vectorized `ndarray.any()`. The builtin `any` would
        # iterate the array element by element in Python and raises TypeError
        # on a 0-d (scalar) result.
        if new_time_step.is_last().numpy().any():
            terminates = step_types == StepType.LAST
            termination_indexes = tf.where(terminates)
            number_terminations = tf.math.count_nonzero(terminates)
            # Set terminated counters to -1 so the increment at the top of the
            # next call yields 0.
            # we use dtype tf.int32 because this avoids a GPU bug detected by Dongho
            self._num_steps.scatter_nd_update(
                termination_indexes,
                tf.constant(-1, shape=(number_terminations, ), dtype=tf.int32),
            )

        return new_time_step
Ejemplo n.º 3
0
    def should_reset(self, current_time_step: ts.TimeStep) -> bool:
        """Whether the Environment should reset given the current timestep.

        A reset is requested only when the environment opts in through a truthy
        `_handle_auto_reset` attribute (a missing attribute counts as False)
        and every entry of `current_time_step.is_last()` is true, i.e. all
        time_steps are `LAST`.

        Args:
          current_time_step: The current `TimeStep`.

        Returns:
          A built-in bool indicating whether the Environment should reset or not.
        """
        handle_auto_reset = getattr(self, '_handle_auto_reset', False)
        # Coerce to a built-in bool to honor the `-> bool` annotation: `np.all`
        # returns `np.bool_`, and when `_handle_auto_reset` is falsy the `and`
        # would otherwise return that raw attribute value (e.g. None).
        # Short-circuiting is preserved: `is_last()` is not evaluated unless
        # auto-reset handling is enabled.
        return bool(handle_auto_reset and np.all(current_time_step.is_last()))