Example #1
    def observe(self, obs, reward, done, reset):
        if self.training:
            self.t += 1
            assert self.last_state is not None
            assert self.last_action is not None
            # Add a transition to the replay buffer
            transition = {
                "state": self.last_state,
                "action": self.last_action,
                "reward": reward,
                "next_state": obs,
                "is_state_terminal": done,
            }
            if self.recurrent:
                # Store the recurrent state before and after this step
                # (detached, as NumPy) so the learner can later replay
                # the step with a recurrent model.
                transition["recurrent_state"] = recurrent_state_as_numpy(
                    get_recurrent_state_at(
                        self.train_prev_recurrent_states, 0, detach=True
                    )
                )
                self.train_prev_recurrent_states = None
                transition["next_recurrent_state"] = recurrent_state_as_numpy(
                    get_recurrent_state_at(
                        self.train_recurrent_states, 0, detach=True
                    )
                )
            self._send_to_learner(transition, stop_episode=done or reset)
            if (done or reset) and self.recurrent:
                # Episode boundary: drop the stale recurrent states
                self.train_prev_recurrent_states = None
                self.train_recurrent_states = None
        else:
            if (done or reset) and self.recurrent:
                self.test_recurrent_states = None
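
The two helpers above appear to be PFRL's recurrent-state utilities (pfrl.utils.recurrent). As a rough mental model, get_recurrent_state_at picks one sample's slice out of a batched recurrent state, and recurrent_state_as_numpy converts the result to NumPy for storage. The sketch below is illustrative only, not PFRL's actual implementation (the _sketch names are hypothetical); it assumes tensor states are laid out as (num_layers, batch, hidden) and LSTM states are (h, c) tuples.

import torch

def get_recurrent_state_at_sketch(recurrent_state, index, detach):
    # LSTM states are (h, c) tuples: recurse into each element.
    if isinstance(recurrent_state, tuple):
        return tuple(
            get_recurrent_state_at_sketch(s, index, detach)
            for s in recurrent_state
        )
    assert isinstance(recurrent_state, torch.Tensor)
    # Tensor states are (num_layers, batch, hidden): select one batch entry.
    if detach:
        recurrent_state = recurrent_state.detach()
    return recurrent_state[:, index]

def recurrent_state_as_numpy_sketch(recurrent_state):
    # Convert each tensor to NumPy so the stored transition holds no
    # GPU memory or autograd graph.
    if isinstance(recurrent_state, tuple):
        return tuple(recurrent_state_as_numpy_sketch(s) for s in recurrent_state)
    return recurrent_state.cpu().numpy()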
Example #2
    def _test_non_lstm(self, gpu, name):
        in_size = 2
        out_size = 3
        device = "cuda:{}".format(gpu) if gpu >= 0 else "cpu"
        seqs_x = [
            torch.rand(4, in_size, device=device),
            torch.rand(1, in_size, device=device),
            torch.rand(3, in_size, device=device),
        ]
        seqs_x = torch.nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)
        self.assertIn(name, ("GRU", "RNN"))
        cls = getattr(nn, name)
        link = cls(num_layers=1, input_size=in_size, hidden_size=out_size)
        link.to(device)

        # Forward twice: first from the zero initial state (None), then from h0
        y0, h0 = link(seqs_x, None)
        y1, h1 = link(seqs_x, h0)
        y0, _ = torch.nn.utils.rnn.pad_packed_sequence(y0, batch_first=True)
        y1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1, batch_first=True)
        # h: (num_layers, batch, hidden); y (padded): (batch, max_len, hidden)
        self.assertEqual(h0.shape, (1, 3, out_size))
        self.assertEqual(h1.shape, (1, 3, out_size))
        self.assertEqual(y0.shape, (3, 4, out_size))
        self.assertEqual(y1.shape, (3, 4, out_size))

        # Masked at 0
        rs0_mask0 = mask_recurrent_state_at(h0, 0)
        y1m0, _ = link(seqs_x, rs0_mask0)
        y1m0, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m0,
                                                         batch_first=True)
        torch_assert_allclose(y1m0[0], y0[0])
        torch_assert_allclose(y1m0[1], y1[1])
        torch_assert_allclose(y1m0[2], y1[2])

        # Masked at (1, 2)
        rs0_mask12 = mask_recurrent_state_at(h0, (1, 2))
        y1m12, _ = link(seqs_x, rs0_mask12)
        y1m12, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m12,
                                                          batch_first=True)
        torch_assert_allclose(y1m12[0], y1[0])
        torch_assert_allclose(y1m12[1], y0[1])
        torch_assert_allclose(y1m12[2], y0[2])

        # Get at 1 and concat with None
        rs0_get1 = get_recurrent_state_at(h0, 1, detach=False)
        assert rs0_get1.requires_grad
        torch_assert_allclose(rs0_get1, h0[:, 1])
        concat_rs_get1 = concatenate_recurrent_states([None, rs0_get1, None])
        y1g1, _ = link(seqs_x, concat_rs_get1)
        y1g1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1g1,
                                                         batch_first=True)
        torch_assert_allclose(y1g1[0], y0[0])
        torch_assert_allclose(y1g1[1], y1[1])
        torch_assert_allclose(y1g1[2], y0[2])

        # Get at 1 with detach=True
        rs0_get1_detach = get_recurrent_state_at(h0, 1, detach=True)
        assert not rs0_get1_detach.requires_grad
        torch_assert_allclose(rs0_get1_detach, h0[:, 1])
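
These assertions rely on mask_recurrent_state_at zeroing the hidden state of the selected batch entries: a GRU or RNN given an all-zero state behaves exactly as if it were given None, which is why y1m0[0] matches y0[0]. A minimal sketch of that behavior, assuming the same (num_layers, batch, hidden) layout as above (the _sketch name is hypothetical, not PFRL's exact code):

import torch

def mask_recurrent_state_at_sketch(recurrent_state, indices):
    # Zero the state of the given batch entries; indices may be a
    # single int or a sequence of ints.
    if isinstance(recurrent_state, tuple):  # LSTM: mask both h and c
        return tuple(
            mask_recurrent_state_at_sketch(s, indices)
            for s in recurrent_state
        )
    mask = torch.ones_like(recurrent_state)
    mask[:, torch.as_tensor(indices)] = 0
    return recurrent_state * mask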
Example #3
    def _batch_observe_train(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:
        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "feature": self.batch_h[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True
                        )
                    )
                    transition["next_recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_recurrent_states, i, detach=True
                        )
                    )
                self.replay_buffer.append(env_id=i, **transition)

                self._backup_if_necessary(self.t, self.batch_h[i])

                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.train_recurrent_states,
            )
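
_batch_reset_recurrent_states_when_episodes_end is not shown in this snippet, but Example #4 below inlines the same logic. A plausible sketch, assuming it reuses the mask_recurrent_state_at helper seen in Example #2 to zero the state of every index whose episode just finished (the _sketch name is hypothetical):

def batch_reset_recurrent_states_sketch(batch_done, batch_reset, recurrent_states):
    # Zero the recurrent state of environments whose episode ended this
    # step (done or reset), leaving the other environments untouched.
    indices_that_ended = [
        i
        for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
        if done or reset
    ]
    if recurrent_states is not None and indices_that_ended:
        return mask_recurrent_state_at(recurrent_states, indices_that_ended)
    return recurrent_states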
Example #4
    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training

        for i, (state, action, reward, next_state, done, reset) in enumerate(
            zip(  # NOQA
                self.batch_last_state,
                self.batch_last_action,
                batch_reward,
                batch_obs,
                batch_done,
                batch_reset,
            )
        ):
            if state is not None:
                assert action is not None
                transition = {
                    "state": state,
                    "action": action,
                    "reward": reward,
                    "next_state": next_state,
                    "nonterminal": 0.0 if done else 1.0,
                }
                if self.recurrent:
                    transition["recurrent_state"] = get_recurrent_state_at(
                        self.train_prev_recurrent_states, i, detach=True
                    )
                    transition["next_recurrent_state"] = get_recurrent_state_at(
                        self.train_recurrent_states, i, detach=True
                    )
                self.batch_last_episode[i].append(transition)
            if done or reset:
                assert self.batch_last_episode[i]
                self.memory.append(self.batch_last_episode[i])
                self.batch_last_episode[i] = []
            self.batch_last_state[i] = None
            self.batch_last_action[i] = None

        self.train_prev_recurrent_states = None

        if self.recurrent:
            # Reset recurrent states when episodes end
            indices_that_ended = [
                i
                for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
                if done or reset
            ]
            if indices_that_ended:
                self.train_recurrent_states = mask_recurrent_state_at(
                    self.train_recurrent_states, indices_that_ended
                )

        self._update_if_dataset_is_ready()
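
Note that Example #4 stores nonterminal as a float (0.0 or 1.0) rather than a bool. A common reason, sketched below with illustrative names (this is not this agent's actual update code), is that a float flag lets a bootstrapped target be written as a single multiply:

def one_step_target(reward, gamma, nonterminal, next_value):
    # nonterminal is 0.0 at termination, so the bootstrap term vanishes
    # without an explicit branch on done; this also vectorizes cleanly
    # over a batch of transitions.
    return reward + gamma * nonterminal * next_value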
Example #5
def get_and_concat_rs_forward():
    # Step once from the zero initial state, split the resulting batched
    # recurrent state into per-sample states, re-batch them, and use the
    # result for the next step.
    _, rs = one_step_forward(par, x_t0, None)
    rs0 = get_recurrent_state_at(rs, 0, detach=True)
    rs1 = get_recurrent_state_at(rs, 1, detach=True)
    concat_rs = concatenate_recurrent_states([rs0, rs1])
    return one_step_forward(par, x_t1, concat_rs)
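
Example #5 shows the round trip: run one step from the zero initial state (None), split the batched recurrent state per sample with get_recurrent_state_at, then rebuild a batch with concatenate_recurrent_states. A minimal sketch of the concatenation, assuming each element is either None (meaning "start from the zero state") or a per-sample slice shaped (num_layers, hidden) as produced by the sketch after Example #1 (the _sketch name is hypothetical, not PFRL's exact code):

import torch

def concatenate_recurrent_states_sketch(split_states):
    # Stack per-sample states into a (num_layers, batch, hidden) batch,
    # substituting zeros for None entries so those samples start from
    # the initial state.
    if all(s is None for s in split_states):
        return None
    template = next(s for s in split_states if s is not None)
    if isinstance(template, tuple):  # LSTM: rebuild h and c separately
        return tuple(
            concatenate_recurrent_states_sketch(
                [None if s is None else s[k] for s in split_states]
            )
            for k in range(len(template))
        )
    filled = [torch.zeros_like(template) if s is None else s for s in split_states]
    return torch.stack(filled, dim=1)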