from typing import Tuple

import numpy as np
import tensorflow as tf

from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory
from tf_agents.typing import types


def make_timestep_mask(
    batched_next_time_step: ts.TimeStep,
    allow_partial_episodes: bool = False,
) -> types.Tensor:
  """Create a mask for transitions and optionally final incomplete episodes.

  Args:
    batched_next_time_step: Next timestep, doubly-batched
      [batch_dim, time_dim, ...].
    allow_partial_episodes: If True, steps on incomplete episodes are allowed.

  Returns:
    A mask, type tf.float32, that is 0.0 for all between-episode timesteps
    (where batched_next_time_step is FIRST). If allow_partial_episodes is
    False, the mask is also 0.0 for the incomplete episode at the end of the
    sequence.
  """
  if allow_partial_episodes:
    episode_is_complete = None
  else:
    # 1.0 for timesteps of all complete episodes; 0.0 for the incomplete
    # episode at the end of the sequence.
    episode_is_complete = tf.cumsum(
        tf.cast(batched_next_time_step.is_last(), tf.float32),
        axis=1,
        reverse=True) > 0

  # 1.0 for all valid timesteps; 0.0 where between episodes.
  not_between_episodes = ~batched_next_time_step.is_first()

  if allow_partial_episodes:
    return tf.cast(not_between_episodes, tf.float32)
  else:
    return tf.cast(episode_is_complete & not_between_episodes, tf.float32)
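# A minimal usage sketch (an illustration, not part of the original source):
# build a doubly-batched TimeStep of shape [batch_dim=1, time_dim=4] where the
# episode ends at t=1 and a new, still-incomplete episode starts at t=2. The
# FIRST step at t=2 is between episodes and t=3 belongs to the trailing
# incomplete episode, so both are masked out.
step_type = tf.constant(
    [[ts.StepType.MID, ts.StepType.LAST, ts.StepType.FIRST, ts.StepType.MID]])
batched_next_time_step = ts.TimeStep(
    step_type=step_type,
    reward=tf.zeros([1, 4], tf.float32),
    discount=tf.ones([1, 4], tf.float32),
    observation=tf.zeros([1, 4, 2], tf.float32))
print(make_timestep_mask(batched_next_time_step))  # -> [[1. 1. 0. 0.]]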
def run(
    self,
    time_step: ts.TimeStep,
    policy_state: types.NestedArray = (),
) -> Tuple[ts.TimeStep, types.NestedArray]:
  """Run policy in environment given initial time_step and policy_state.

  Args:
    time_step: The initial time_step.
    policy_state: The initial policy_state.

  Returns:
    A tuple (final time_step, final policy_state).
  """
  num_steps = 0
  num_episodes = 0
  while num_steps < self._max_steps and num_episodes < self._max_episodes:
    # For now we reset the policy_state for non-batched envs.
    if not self.env.batched and time_step.is_first() and num_episodes > 0:
      policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

    action_step = self.policy.action(time_step, policy_state)
    next_time_step = self.env.step(action_step.action)

    # When using an observer (for the purpose of training), only the previous
    # policy_state is useful. Therefore substitute it in the PolicyStep and
    # consume it with the observer.
    action_step_with_previous_state = action_step._replace(state=policy_state)
    traj = trajectory.from_transition(
        time_step, action_step_with_previous_state, next_time_step)
    for observer in self._transition_observers:
      observer((time_step, action_step_with_previous_state, next_time_step))
    for observer in self.observers:
      observer(traj)

    num_episodes += np.sum(traj.is_boundary())
    num_steps += np.sum(~traj.is_boundary())

    time_step = next_time_step
    policy_state = action_step.state

  return time_step, policy_state
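# Why the _replace above? PolicyStep is a namedtuple, and action_step.state
# already holds the *next* policy state. For training, observers need the
# state that was fed *into* the policy, so the loop swaps the previous state
# back in before recording. A tiny stand-alone illustration (the states 'h0'
# and 'h1' are made-up placeholders):
from tf_agents.trajectories import policy_step

step = policy_step.PolicyStep(action=1, state=('h1',), info=())
previous_state = ('h0',)
step_for_observer = step._replace(state=previous_state)
assert step_for_observer.state == ('h0',)  # What the observer records.
assert step.state == ('h1',)               # What the loop carries forward.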
def run(
    self,
    time_step: ts.TimeStep,
    policy_state: types.NestedArray = (),
) -> Tuple[ts.TimeStep, types.NestedArray]:
  """Run policy in environment given initial time_step and policy_state.

  Args:
    time_step: The initial time_step.
    policy_state: The initial policy_state.

  Returns:
    A tuple (final time_step, final policy_state).
  """
  num_steps = 0
  num_episodes = 0
  while num_steps < self._max_steps and num_episodes < self._max_episodes:
    # For now we reset the policy_state for non-batched envs.
    if not self.env.batched and time_step.is_first() and num_episodes > 0:
      policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

    action_step = self.policy.action(time_step, policy_state)
    next_time_step = self.env.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    for observer in self._transition_observers:
      observer((time_step, action_step, next_time_step))
    for observer in self.observers:
      observer(traj)

    num_episodes += np.sum(traj.is_boundary())
    num_steps += np.sum(~traj.is_boundary())

    time_step = next_time_step
    policy_state = action_step.state

  return time_step, policy_state
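# A hedged end-to-end sketch of driving a run() loop like the ones above
# (assumes the surrounding class is TF-Agents' PyDriver; the environment name
# and the random policy are illustrative choices, not from the original
# source):
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(
    time_step_spec=env.time_step_spec(), action_spec=env.action_spec())
collected = []  # Each observer call appends one Trajectory.
driver = py_driver.PyDriver(
    env, policy, observers=[collected.append], max_steps=20)
final_time_step, final_policy_state = driver.run(env.reset())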