def make_timestep_mask(batched_next_time_step: ts.TimeStep,
                       allow_partial_episodes: bool = False) -> types.Tensor:
    """Build a float mask selecting the valid transitions of batched sequences.

    Args:
      batched_next_time_step: Next timestep, doubly-batched
        [batch_dim, time_dim, ...].
      allow_partial_episodes: When true, steps that belong to an incomplete
        episode are kept in the mask.

    Returns:
      A tf.float32 mask that is 0.0 on every between-episode timestep (those
      where batched_next_time_step is FIRST). Unless allow_partial_episodes is
      true, the mask is also 0.0 on the incomplete episode at the end of the
      sequence.
    """
    # True on every valid transition, False on the boundary between episodes.
    within_episode = ~batched_next_time_step.is_first()
    if allow_partial_episodes:
        return tf.cast(within_episode, tf.float32)

    # Reverse cumulative count of LAST steps along the time axis: it is
    # positive wherever a LAST step occurs at or after the current position,
    # i.e. the episode containing this step finishes inside the sequence.
    last_flags = tf.cast(batched_next_time_step.is_last(), tf.float32)
    lasts_remaining = tf.cumsum(last_flags, axis=1, reverse=True)
    finishes_in_sequence = lasts_remaining > 0
    return tf.cast(finishes_in_sequence & within_episode, tf.float32)
def _step(self, action):
    """Step the parent environment and apply the step-count limit.

    Increments `self._num_steps`, delegates to the parent `_step`, and then,
    wherever the counter has reached `self._duration`, overrides the step type
    with `StepType.LAST` and the discount with 0. The resulting time step is
    also written to `self._env._time_step`. Counters at positions whose step
    type is LAST are set to -1 (presumably so the next increment brings them
    back to 0 -- confirm against the reset logic).

    Args:
      action: Action forwarded unchanged to the parent `_step`.

    Returns:
      The `TimeStep` with the adjusted step types and discounts.
    """
    self._num_steps.assign_add(tf.ones_like(self._num_steps))
    inner_step = super()._step(action)

    limit_reached = tf.math.greater_equal(self._num_steps, self._duration)
    step_types = tf.where(
        condition=limit_reached, x=StepType.LAST, y=inner_step.step_type)
    discounts = tf.where(
        condition=limit_reached, x=0, y=inner_step.discount)

    capped_step = TimeStep(
        step_types, inner_step.reward, discounts, inner_step.observation)
    self._env._time_step = capped_step  # pylint: disable=protected-access

    # The TF tensor is converted to numpy first for performance reasons.
    if not any(capped_step.is_last().numpy()):
        return capped_step

    ended = step_types == StepType.LAST
    ended_indexes = tf.where(ended)
    ended_count = tf.math.count_nonzero(ended)
    # dtype tf.int32 is used because it avoids a GPU bug detected by Dongho.
    reset_values = tf.constant(-1, shape=(ended_count, ), dtype=tf.int32)
    self._num_steps.scatter_nd_update(ended_indexes, reset_values)
    return capped_step
def should_reset(self, current_time_step: ts.TimeStep) -> bool:
    """Whether the Environment should reset given the current timestep.

    A reset is requested only when auto-reset handling is enabled on this
    instance (the `_handle_auto_reset` attribute, treated as False when it is
    absent) and all time_steps are `LAST`.

    Args:
      current_time_step: The current `TimeStep`.

    Returns:
      A plain Python bool indicating whether the Environment should reset or
      not.
    """
    # Short-circuit so `is_last()` is not evaluated when auto-reset is off,
    # and so a falsy non-bool attribute value is never returned to the caller.
    if not getattr(self, '_handle_auto_reset', False):
        return False
    # np.all yields a numpy bool; coerce it to honour the declared return type.
    return bool(np.all(current_time_step.is_last()))