def observe(self, obs, reward, done, reset):
    if self.training:
        self.t += 1
        assert self.last_state is not None
        assert self.last_action is not None
        # Add a transition to the replay buffer
        transition = {
            "state": self.last_state,
            "action": self.last_action,
            "reward": reward,
            "next_state": obs,
            "is_state_terminal": done,
        }
        if self.recurrent:
            # Store the actor-side recurrent states with the transition so
            # they can be restored when the transition is replayed.
            transition["recurrent_state"] = recurrent_state_as_numpy(
                get_recurrent_state_at(
                    self.train_prev_recurrent_states, 0, detach=True))
            self.train_prev_recurrent_states = None
            transition["next_recurrent_state"] = recurrent_state_as_numpy(
                get_recurrent_state_at(
                    self.train_recurrent_states, 0, detach=True))
        self._send_to_learner(transition, stop_episode=done or reset)
        if (done or reset) and self.recurrent:
            # Clear recurrent states at episode boundaries
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = None
    else:
        if (done or reset) and self.recurrent:
            self.test_recurrent_states = None
Example #2
    def _batch_observe_train(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:

        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
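            # Skip storing a transition when no previous (state, action)
            # pair is cached for this env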
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "feature": self.batch_h[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True))
                    transition[
                        "next_recurrent_state"] = recurrent_state_as_numpy(
                            get_recurrent_state_at(self.train_recurrent_states,
                                                   i,
                                                   detach=True))
                self.replay_buffer.append(env_id=i, **transition)

                self._backup_if_necessary(self.t, self.batch_h[i])

                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.train_recurrent_states,
            )
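
The recurrent branch at the end clears the hidden state of any environment whose episode just ended, while the other environments keep theirs. A rough standalone sketch of that behavior follows, assuming an LSTM-style (h, c) state with the env index on dimension 1; the function name and tensor shapes are illustrative and this is not PFRL's actual helper.

from typing import Optional, Sequence, Tuple

import torch


def reset_recurrent_states_when_episodes_end(
    batch_done: Sequence[bool],
    batch_reset: Sequence[bool],
    recurrent_states: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    # Illustrative only: assumes an LSTM-style (h, c) pair shaped
    # (num_layers, num_envs, hidden_size), with env i at index i on dim 1.
    if recurrent_states is None:
        return None
    ended = [i for i, (d, r) in enumerate(zip(batch_done, batch_reset)) if d or r]
    if not ended:
        return recurrent_states
    h, c = recurrent_states
    h, c = h.clone(), c.clone()
    h[:, ended, :] = 0.0  # zero out hidden states of finished episodes
    c[:, ended, :] = 0.0
    return h, c

Zeroing entries in place, rather than dropping them, keeps the batch layout intact, so environments that are still mid-episode carry their hidden state into the next step.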