Example 1
    def update(self, timestep: dm_env.TimeStep, action: base.Action,
               new_timestep: dm_env.TimeStep):
        """Receives a transition and performs a learning update."""

        # Insert this step into our rolling window 'batch'.
        items = [
            timestep.observation, action, new_timestep.reward,
            new_timestep.discount,
            float(not timestep.first())
        ]
        for buf, item in zip(self._buffer, items):
            buf[self._num_transitions_in_buffer % self._sequence_length,
                0] = item
        self._num_transitions_in_buffer += 1

        # When the batch is full, do a step of SGD.
        if self._num_transitions_in_buffer % self._sequence_length != 0:
            return

        transitions = (
            self._buffer + [
                tf.expand_dims(new_timestep.observation, axis=0),  # final_obs
                tf.expand_dims(float(not new_timestep.first()),
                               axis=0),  # final_mask
            ])
        self._rollout_initial_state = self._step(transitions)
Example 2
    def add_first(self, timestep: dm_env.TimeStep):
        """Record the first observation of a trajectory."""
        if not timestep.first():
            raise ValueError(
                'adder.add_first must be called with an initial timestep, '
                'i.e. one for which timestep.first() is True.')

        # Record the next observation but leave the history buffer row open by
        # passing `partial_step=True`.
        self._writer.append(dict(observation=timestep.observation,
                                 start_of_episode=timestep.first()),
                            partial_step=True)
Example 3
 def policy(self, timestep: dm_env.TimeStep) -> base.Action:
   """Selects actions according to the latest softmax policy."""
   action, self.state = self._policy_fn(
       np.expand_dims(timestep.observation, 0),
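       # Mask: 0.0 on the first step of an episode, 1.0 otherwise.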
       np.expand_dims(float(not timestep.first()), 0),
       self.state)
   return np.int32(action)
Example 4
    def _track(self, timestep: dm_env.TimeStep):
        # Count transitions only.
        if not timestep.first():
            self._steps += 1
            self._episode_len += 1
        if timestep.last():
            self._episode += 1
        self._episode_return += timestep.reward or 0.0
        self._total_return += timestep.reward or 0.0

        # Log statistics periodically, either by step or by episode.
        if self._log_by_step:
            if _logarithmic_logging(self._steps) or self._log_every:
                self._log_bsuite_data()

        elif timestep.last():
            if _logarithmic_logging(self._episode) or self._log_every:
                self._log_bsuite_data()

        # Perform bookkeeping at the end of episodes.
        if timestep.last():
            self._episode_len = 0
            self._episode_return = 0.0

        if self._episode == self._env.bsuite_num_episodes:
            self.flush()
Example 5
 def policy(self, timestep: dm_env.TimeStep) -> base.Action:
     """Selects actions according to the latest softmax policy."""
     observation = tf.expand_dims(timestep.observation, axis=0)
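     # The mask is 0.0 on the first step of an episode and 1.0 otherwise.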
     mask = tf.expand_dims(float(not timestep.first()), axis=0)
     (logits, _), self._state = self._forward((observation, mask),
                                              self._state)
     return tf.random.categorical(logits, num_samples=1).numpy().squeeze()
Example 6
    def step(
        self,
        environment: Optional[dm_env.Environment],
        timestep_t: dm_env.TimeStep,
        agent: Optional[Agent],
        a_t: Optional[Action],
    ) -> None:
        """Accumulates statistics from timestep."""
        del (environment, agent, a_t)

        if timestep_t.first():
            if self._current_episode_rewards:
                raise ValueError(
                    'Current episode reward list should be empty.')
            if self._current_episode_step != 0:
                raise ValueError('Current episode step should be zero.')
        else:
            # First reward is invalid, all other rewards are appended.
            self._current_episode_rewards.append(timestep_t.reward)

        self._num_steps_since_reset += 1
        self._current_episode_step += 1

        if timestep_t.last():
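            # Episode ended: record the episode return, accumulate the episode
            # length, and reset the per-episode accumulators.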
            self._episode_returns.append(sum(self._current_episode_rewards))
            self._current_episode_rewards = []
            self._num_steps_over_episodes += self._current_episode_step
            self._current_episode_step = 0
Example 7
File: base.py Project: srsohn/mtsgi
    def observe(
        self,
        action: types.NestedArray,
        next_timestep: dm_env.TimeStep,
    ):
        action = np.expand_dims(action, axis=-1)
        next_rewards = np.expand_dims(next_timestep.reward, axis=-1)
        is_first = np.expand_dims(next_timestep.first(), axis=-1)  # for mask
        avg_rewards = np.take_along_axis(self._avg_rewards, action, axis=-1)
        counts = np.take_along_axis(self._counts, action, axis=-1)

        # Compute & update avg rewards.
        update_values = 1 / counts * (next_rewards - avg_rewards)
        next_avg_rewards = avg_rewards + np.where(
            is_first, 0, update_values)  # skip first timestep.
        np.put_along_axis(self._avg_rewards,
                          action,
                          values=next_avg_rewards,
                          axis=-1)

        # Update counts.
        np.put_along_axis(self._counts,
                          action,
                          values=counts + (1 - is_first),
                          axis=-1)
        self._total_counts += (1 - is_first).squeeze()
Example 8
 def _add_reward_noise(self, timestep: dm_env.TimeStep):
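     # The first timestep of an episode carries no reward, so return it
     # unchanged.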
     if timestep.first():
         return timestep
     reward = timestep.reward + self._noise_scale * self._rng.randn()
     return dm_env.TimeStep(step_type=timestep.step_type,
                            reward=reward,
                            discount=timestep.discount,
                            observation=timestep.observation)
Example 9
 def _rescale_rewards(self, timestep: dm_env.TimeStep):
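     # There is no reward to rescale on the first timestep of an episode.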
     if timestep.first():
         return timestep
     reward = timestep.reward * self._reward_scale
     return dm_env.TimeStep(step_type=timestep.step_type,
                            reward=reward,
                            discount=timestep.discount,
                            observation=timestep.observation)
Example 10
 def select_action(self, timestep: dm_env.TimeStep) -> base.Action:
   """Selects actions according to the latest softmax policy."""
   if timestep.first():
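     # Re-initialize the recurrent network state (and the stored rollout
     # initial state) at the start of each new episode.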
     self._state = self._network.initial_state(1)
     self._rollout_initial_state = self._network.initial_state(1)
   observation = tf.expand_dims(timestep.observation, axis=0)
   (logits, _), self._state = self._forward(observation, self._state)
   return tf.random.categorical(logits, num_samples=1).numpy().squeeze()
Example 11
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, maybe yields transitions."""
        if timestep_t.first():
            self.reset()

        # There are no transitions on the first timestep.
        if self._timestep_tm1 is None:
            assert self._a_tm1 is None
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterable.

        self._transitions.append(
            Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
                a_t=a_t,
                mc_return_tm1=None,
            ))

        self._timestep_tm1 = timestep_t
        self._a_tm1 = a_t

        if timestep_t.last():
            # Annotate all episode transitions with their MC returns.
            mc_return = 0
            mc_transitions = []
            while self._transitions:
                transition = self._transitions.pop()
                mc_return = transition.discount_t * mc_return + transition.r_t
                mc_transitions.append(
                    transition._replace(mc_return_tm1=mc_return))
            for transition in reversed(mc_transitions):
                yield transition

        else:
            # Wait for episode end before yielding anything.
            return
Example 12
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, yields transitions."""
        if timestep_t.first():
            self.reset()

        # There are no transitions on the first timestep.
        if self._timestep_tm1 is None:
            assert self._a_tm1 is None
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterator.

        self._transitions.append(
            Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
            ))

        self._timestep_tm1 = timestep_t
        self._a_tm1 = a_t

        if timestep_t.last():
            # Yield any remaining n, n-1, ..., 1-step transitions at episode end.
            while self._transitions:
                yield _build_n_step_transition(self._transitions)
                self._transitions.popleft()
        else:
            # Wait for n transitions before yielding anything.
            if len(self._transitions) < self._transitions.maxlen:
                return  # Empty iterator.

            assert len(self._transitions) == self._transitions.maxlen

            # This is the typical case, yield a single n-step transition.
            yield _build_n_step_transition(self._transitions)
Example 13
  def observe_first(self, timestep: dm_env.TimeStep):
    if self._queue is not None:
      self._queue.add_first(timestep)

    assert timestep.step_type.shape[0] == 1, \
        'Currently only supports single worker.'

    # Reset the hidden state at the start of every episode so that it is
    # re-initialized at the next policy call.
    if timestep.first():
      self._state = None
Example 14
    def _postprocess_observation(self,
                                 timestep: dm_env.TimeStep) -> dm_env.TimeStep:
        """Observation processing applied after action repeat consolidation."""

        if timestep.first():
            return dm_env.restart(timestep.observation)

        reward = np.clip(timestep.reward, -self._max_abs_reward,
                         self._max_abs_reward)

        return timestep._replace(reward=reward)
Example 15
 def _add_reward_noise(self, timestep: dm_env.TimeStep):
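     # No reward is present on the first timestep of an episode, so return it
     # unchanged.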
     if timestep.first():
         return timestep
     if self._bernoulli:
         reward = self._rng.binomial(p=timestep.reward, n=1,
                                     size=1).astype(np.float64)[0]
     else:
         reward = timestep.reward + self._noise_scale * self._rng.randn()
     return dm_env.TimeStep(step_type=timestep.step_type,
                            reward=reward,
                            discount=timestep.discount,
                            observation=timestep.observation)
Example 16
File: base.py Project: zzp110/acme
  def add_first(self, timestep: dm_env.TimeStep):
    """Record the first observation of a trajectory."""
    if not timestep.first():
      raise ValueError('adder.add_first must be called with an initial '
                       'timestep, i.e. one for which timestep.first() is True.')

    if self._next_observation is not None:
      raise ValueError('adder.reset must be called before adder.add_first '
                       '(called automatically if `next_timestep.last()` is '
                       'true when `add` is called).')

    # Record the next observation.
    self._next_observation = timestep.observation
Example 17
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, maybe yield a transition."""
        if timestep_t.first():
            self.reset()

        if self._timestep_tm1 is None:
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterable.
        else:
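            # Build a one-step transition from the stored previous timestep and
            # action together with the newly received timestep.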
            transition = Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
            )
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            yield transition
Example 18
  def observe(
      self,
      action: types.NestedArray,
      next_timestep: dm_env.TimeStep,
  ):
    assert next_timestep.step_type.shape[0] == 1, \
        'Currently only supports single worker.'

    # Reset hidden state every new episode.
    if next_timestep.first():
      self._state = None

    if self._queue is None:
      return

    extras = {'logits': self._prev_logits, 'core_state': self._prev_state}
    extras = tf2_utils.to_numpy(extras)
    self._queue.add(action, next_timestep, extras)
Example 19
  def step(self, timestep: dm_env.TimeStep) -> None:
    """Accumulates statistics from timestep."""
    if timestep.first():
      if self._current_episode_rewards:
        raise ValueError('Current episode reward list should be empty.')
      if self._current_episode_step != 0:
        raise ValueError('Current episode step should be zero.')
    else:
      # First reward is invalid, all other rewards are appended.
      self._current_episode_rewards.append(timestep.reward)

    self._num_steps_since_reset += 1
    self._current_episode_step += 1

    if timestep.last():
      self._episode_returns.append(sum(self._current_episode_rewards))
      self._current_episode_rewards = []
      self._num_steps_over_episodes += self._current_episode_step
      self._current_episode_step = 0
Example 20
    def add_first(self,
                  timestep: dm_env.TimeStep,
                  extras: Dict[str, types.NestedArray] = {}) -> None:
        """Record the first observation of a trajectory."""
        if not timestep.first():
            raise ValueError(
                "adder.add_first must be called with an initial timestep, "
                "i.e. one for which timestep.first() is True.")

        if self._next_observations is not None:
            raise ValueError(
                "adder.reset must be called before adder.add_first "
                "(called automatically if `next_timestep.last()` is "
                "true when `add` is called).")

        # Record the next observation.
        self._next_observations = timestep.observation
        self._start_of_episode = True

        if self._use_next_extras:
            self._next_extras = extras
Example 21
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""

        if not self._add_first_called:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        # Add the timestep to the buffer.
        has_extras = (
            len(extras) > 0 if isinstance(extras, Sized)  # pylint: disable=g-explicit-length-test
            else extras is not None)
        current_step = dict(
            # Observation was passed at the previous add call.
            action=action,
            reward=next_timestep.reward,
            discount=next_timestep.discount,
            # Start of episode indicator was passed at the previous add call.
            **({
                'extras': extras
            } if has_extras else {}))
        self._writer.append(current_step)

        # Record the next observation and write.
        self._writer.append(dict(observation=next_timestep.observation,
                                 start_of_episode=next_timestep.first()),
                            partial_step=True)
        self._write()

        if next_timestep.last():
            # Complete the row by appending zeros to remaining open fields.
            # TODO(b/183945808): remove this when fields are no longer expected to be
            # of equal length on the learner side.
            dummy_step = tree.map_structure(np.zeros_like, current_step)
            self._writer.append(dummy_step)
            self._write_last()
            self.reset()
Example 22
    def step(self, timestep: dm_env.TimeStep, loss, shaped_reward,
             penalties) -> None:
        """Accumulates statistics from timestep."""

        if timestep.first():
            if self._current_episode_rewards:
                raise ValueError(
                    'Current episode reward list should be empty.')
            if self._current_episode_step != 0:
                raise ValueError('Current episode step should be zero.')
        else:
            # First reward is invalid, all other rewards are appended.
            self._current_episode_rewards.append(timestep.reward)

        if shaped_reward is not None:
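            # Shaped rewards may arrive either as a single value or as a list
            # of per-step values.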
            if isinstance(shaped_reward, list):
                self._current_episode_shaped_rewards.extend(shaped_reward)
            else:
                self._current_episode_shaped_rewards.append(shaped_reward)
        if loss is not None:
            self._current_episode_loss += loss

        if penalties is not None:
            if isinstance(penalties, list):
                self._current_episode_penalties.extend(penalties)
            else:
                self._current_episode_penalties.append(penalties)

        self._num_steps_since_reset += 1
        self._current_episode_step += 1

        if timestep.last():
            self._episode_returns.append(sum(self._current_episode_rewards))
            self._current_episode_rewards = []
            self._current_episode_shaped_rewards = []
            self._current_episode_penalties = []
            self._current_episode_loss = 0
            self._num_steps_over_episodes += self._current_episode_step
            self._current_episode_step = 0