Code example #1
# Imports assumed from TF-Agents (tf_agents.trajectories / tf_agents.typing).
import tensorflow as tf

from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types


def make_timestep_mask(batched_next_time_step: ts.TimeStep,
                       allow_partial_episodes: bool = False) -> types.Tensor:
  """Create a mask for transitions and optionally final incomplete episodes.

  Args:
    batched_next_time_step: Next timestep, doubly-batched [batch_dim, time_dim,
      ...].
    allow_partial_episodes: If true, then steps on incomplete episodes are
      allowed.

  Returns:
    A mask of type tf.float32 that is 0.0 for all between-episode timesteps
      (where batched_next_time_step is FIRST). If allow_partial_episodes is set
      to False, the mask is also 0.0 for the incomplete episode at the end of
      the sequence.
  """
  if allow_partial_episodes:
    episode_is_complete = None
  else:
    # 1.0 for timesteps of all complete episodes, 0.0 for the incomplete
    #   episode at the end of the sequence.
    episode_is_complete = tf.cumsum(
        tf.cast(batched_next_time_step.is_last(), tf.float32),
        axis=1,
        reverse=True) > 0

  # 1.0 for all valid timesteps. 0.0 where between episodes.
  not_between_episodes = ~batched_next_time_step.is_first()

  if allow_partial_episodes:
    return tf.cast(not_between_episodes, tf.float32)
  else:
    return tf.cast(episode_is_complete & not_between_episodes, tf.float32)
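A minimal usage sketch, not part of the original listing: it builds a hypothetical doubly-batched TimeStep (shape [batch=2, time=5]) from dummy tensors and uses the resulting mask to average a stand-in per-timestep loss. Everything below except make_timestep_mask itself is an illustrative assumption.

import tensorflow as tf
from tf_agents.trajectories import time_step as ts

# Hypothetical [batch=2, time=5] step types: 0=FIRST, 1=MID, 2=LAST.
step_type = tf.constant([[0, 1, 2, 0, 1],
                         [0, 1, 1, 1, 2]], dtype=tf.int32)
dummy_next_time_step = ts.TimeStep(
    step_type=step_type,
    reward=tf.zeros([2, 5], tf.float32),
    discount=tf.ones([2, 5], tf.float32),
    observation=tf.zeros([2, 5, 3], tf.float32))

# 0.0 at between-episode (FIRST) steps and at the trailing incomplete episode.
mask = make_timestep_mask(dummy_next_time_step)
per_step_loss = tf.ones([2, 5], tf.float32)  # stand-in for e.g. TD errors
masked_mean_loss = tf.reduce_sum(per_step_loss * mask) / tf.reduce_sum(mask)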
Code example #2
    def run(
        self, time_step: ts.TimeStep, policy_state: types.NestedArray = ()
    ) -> Tuple[ts.TimeStep, types.NestedArray]:
        """Run policy in environment given initial time_step and policy_state.

        Args:
          time_step: The initial time_step.
          policy_state: The initial policy_state.

        Returns:
          A tuple (final time_step, final policy_state).
        """
        num_steps = 0
        num_episodes = 0
        while num_steps < self._max_steps and num_episodes < self._max_episodes:
            # For now we reset the policy_state for non-batched envs.
            if (not self.env.batched and time_step.is_first()
                and num_episodes > 0):
                policy_state = self._policy.get_initial_state(
                    self.env.batch_size or 1)

            action_step = self.policy.action(time_step, policy_state)
            next_time_step = self.env.step(action_step.action)

            # When using observers (for the purpose of training), only the
            # previous policy_state is useful. Therefore substitute it into the
            # PolicyStep and consume it with the observers.
            action_step_with_previous_state = action_step._replace(
                state=policy_state)
            traj = trajectory.from_transition(time_step,
                                              action_step_with_previous_state,
                                              next_time_step)
            for observer in self._transition_observers:
                observer((time_step, action_step_with_previous_state,
                          next_time_step))
            for observer in self.observers:
                observer(traj)

            num_episodes += np.sum(traj.is_boundary())
            num_steps += np.sum(~traj.is_boundary())

            time_step = next_time_step
            policy_state = action_step.state

        return time_step, policy_state
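Assuming the surrounding class is compatible with TF-Agents' py_driver.PyDriver (same constructor arguments), the sketch below shows how run() is typically invoked; the CartPole-v0 environment, the RandomPyPolicy, and the replay-list observer are illustrative choices rather than part of the original snippet.

from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(
    time_step_spec=env.time_step_spec(), action_spec=env.action_spec())

replay = []  # simple observer: collect every Trajectory produced by run()
driver = py_driver.PyDriver(
    env, policy, observers=[replay.append], max_steps=100)

time_step = env.reset()
final_time_step, final_policy_state = driver.run(time_step)
print('collected', len(replay), 'trajectory steps')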
Code example #3
File: py_driver.py  Project: morgandu/agents
  def run(
      self,
      time_step: ts.TimeStep,
      policy_state: types.NestedArray = ()
  ) -> Tuple[ts.TimeStep, types.NestedArray]:
    """Run policy in environment given initial time_step and policy_state.

    Args:
      time_step: The initial time_step.
      policy_state: The initial policy_state.

    Returns:
      A tuple (final time_step, final policy_state).
    """
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:
      # For now we reset the policy_state for non-batched envs.
      if not self.env.batched and time_step.is_first() and num_episodes > 0:
        policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

      action_step = self.policy.action(time_step, policy_state)
      next_time_step = self.env.step(action_step.action)

      traj = trajectory.from_transition(time_step, action_step, next_time_step)
      for observer in self._transition_observers:
        observer((time_step, action_step, next_time_step))
      for observer in self.observers:
        observer(traj)

      num_episodes += np.sum(traj.is_boundary())
      num_steps += np.sum(~traj.is_boundary())

      time_step = next_time_step
      policy_state = action_step.state

    return time_step, policy_state
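Because run() also hands the raw (time_step, action_step, next_time_step) tuple to each transition observer, any callable accepting that tuple can be plugged in. The RewardTally class below is a hypothetical example of such an observer, not code from the morgandu/agents project.

import numpy as np


class RewardTally(object):
  """Hypothetical transition observer that sums the rewards it sees."""

  def __init__(self):
    self.total_reward = 0.0

  def __call__(self, transition):
    _, _, next_time_step = transition
    self.total_reward += float(np.sum(next_time_step.reward))


# Illustrative wiring, assuming the TF-Agents PyDriver constructor:
#   driver = py_driver.PyDriver(env, policy, observers=[],
#                               transition_observers=[RewardTally()],
#                               max_steps=100)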