Example #1
    def _test_non_lstm(self, gpu, name):
        in_size = 2
        out_size = 3
        device = "cuda:{}".format(gpu) if gpu >= 0 else "cpu"
        seqs_x = [
            torch.rand(4, in_size, device=device),
            torch.rand(1, in_size, device=device),
            torch.rand(3, in_size, device=device),
        ]
        seqs_x = torch.nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)
        self.assertTrue(name in ("GRU", "RNN"))
        cls = getattr(nn, name)
        link = cls(num_layers=1, input_size=in_size, hidden_size=out_size)
        link.to(device)

        # Forward twice: with None and non-None recurrent states
        y0, h0 = link(seqs_x, None)
        y1, h1 = link(seqs_x, h0)
        y0, _ = torch.nn.utils.rnn.pad_packed_sequence(y0, batch_first=True)
        y1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1, batch_first=True)
        self.assertEqual(h0.shape, (1, 3, out_size))
        self.assertEqual(h1.shape, (1, 3, out_size))
        self.assertEqual(y0.shape, (3, 4, out_size))
        self.assertEqual(y1.shape, (3, 4, out_size))

        # Masked at 0
        rs0_mask0 = mask_recurrent_state_at(h0, 0)
        y1m0, _ = link(seqs_x, rs0_mask0)
        y1m0, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m0,
                                                         batch_first=True)
        torch_assert_allclose(y1m0[0], y0[0])
        torch_assert_allclose(y1m0[1], y1[1])
        torch_assert_allclose(y1m0[2], y1[2])

        # Masked at (1, 2)
        rs0_mask12 = mask_recurrent_state_at(h0, (1, 2))
        y1m12, _ = link(seqs_x, rs0_mask12)
        y1m12, _ = torch.nn.utils.rnn.pad_packed_sequence(y1m12,
                                                          batch_first=True)
        torch_assert_allclose(y1m12[0], y1[0])
        torch_assert_allclose(y1m12[1], y0[1])
        torch_assert_allclose(y1m12[2], y0[2])

        # Get at 1 and concat with None
        rs0_get1 = get_recurrent_state_at(h0, 1, detach=False)
        assert rs0_get1.requires_grad
        torch_assert_allclose(rs0_get1, h0[:, 1])
        concat_rs_get1 = concatenate_recurrent_states([None, rs0_get1, None])
        y1g1, _ = link(seqs_x, concat_rs_get1)
        y1g1, _ = torch.nn.utils.rnn.pad_packed_sequence(y1g1,
                                                         batch_first=True)
        torch_assert_allclose(y1g1[0], y0[0])
        torch_assert_allclose(y1g1[1], y1[1])
        torch_assert_allclose(y1g1[2], y0[2])

        # Get at 1 with detach=True
        rs0_get1_detach = get_recurrent_state_at(h0, 1, detach=True)
        assert not rs0_get1_detach.requires_grad
        torch_assert_allclose(rs0_get1_detach, h0[:, 1])
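
The test above checks that a masked index behaves exactly like a fresh (None) recurrent state on the next forward pass. A minimal standalone sketch of that effect, assuming mask_recurrent_state_at is importable from pfrl.utils.recurrent and zeroes the selected batch entries of the hidden state:

import torch
from pfrl.utils.recurrent import mask_recurrent_state_at

h = torch.ones(1, 3, 4)                   # (num_layers, batch, hidden_size)
h_masked = mask_recurrent_state_at(h, 0)  # reset only the first sequence
print(h_masked[:, 0])                     # zeros, i.e. a fresh state
print(h_masked[:, 1])                     # unchanged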
Example #2
    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training

        for i, (state, action, reward, next_state, done, reset) in enumerate(
            zip(  # NOQA
                self.batch_last_state,
                self.batch_last_action,
                batch_reward,
                batch_obs,
                batch_done,
                batch_reset,
            )
        ):
            if state is not None:
                assert action is not None
                transition = {
                    "state": state,
                    "action": action,
                    "reward": reward,
                    "next_state": next_state,
                    "nonterminal": 0.0 if done else 1.0,
                }
                if self.recurrent:
                    transition["recurrent_state"] = get_recurrent_state_at(
                        self.train_prev_recurrent_states, i, detach=True
                    )
                    transition["next_recurrent_state"] = get_recurrent_state_at(
                        self.train_recurrent_states, i, detach=True
                    )
                self.batch_last_episode[i].append(transition)
            if done or reset:
                assert self.batch_last_episode[i]
                self.memory.append(self.batch_last_episode[i])
                self.batch_last_episode[i] = []
            self.batch_last_state[i] = None
            self.batch_last_action[i] = None

        self.train_prev_recurrent_states = None

        if self.recurrent:
            # Reset recurrent states when episodes end
            indices_that_ended = [
                i
                for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
                if done or reset
            ]
            if indices_that_ended:
                self.train_recurrent_states = mask_recurrent_state_at(
                    self.train_recurrent_states, indices_that_ended
                )

        self._update_if_dataset_is_ready()
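
Example #2 stores a detached per-environment recurrent state inside each transition. A small sketch of that slice-and-rebatch round trip, assuming the helpers are importable from pfrl.utils.recurrent: get_recurrent_state_at extracts the state of one sequence in the batch, and concatenate_recurrent_states rebuilds a batched state (None entries stand for fresh states, as in Example #1):

import torch
import torch.nn as nn
from pfrl.utils.recurrent import (
    concatenate_recurrent_states,
    get_recurrent_state_at,
)

rnn = nn.GRU(num_layers=1, input_size=2, hidden_size=3)
seqs = [torch.rand(5, 2), torch.rand(4, 2)]
packed = torch.nn.utils.rnn.pack_sequence(seqs, enforce_sorted=False)
_, h = rnn(packed, None)  # h: (num_layers=1, batch=2, hidden_size=3)

# Slice out each sequence's state, e.g. to store it with a transition ...
h_0 = get_recurrent_state_at(h, 0, detach=True)
h_1 = get_recurrent_state_at(h, 1, detach=True)

# ... and later rebuild a batched state; None means "start from scratch".
h_batched = concatenate_recurrent_states([h_0, None, h_1])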
Example #3
    def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert not self.training
        if self.recurrent:
            # Reset recurrent states when episodes end
            indices_that_ended = [
                i
                for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
                if done or reset
            ]
            if indices_that_ended:
                self.test_recurrent_states = mask_recurrent_state_at(
                    self.test_recurrent_states, indices_that_ended
                )
Example #4
def _batch_reset_recurrent_states_when_episodes_end(batch_done, batch_reset,
                                                    recurrent_states):
    """Reset recurrent states when episodes end.

    Args:
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: New recurrent states.
    """
    indices_that_ended = [
        i for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
        if done or reset
    ]
    if indices_that_ended:
        return mask_recurrent_state_at(recurrent_states, indices_that_ended)
    else:
        return recurrent_states
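
A hypothetical call of the helper above, with made-up flags for four parallel environments; it assumes the recurrent state is a (num_layers, batch, hidden_size) tensor and that mask_recurrent_state_at masks along its batch dimension:

import torch

hs = torch.randn(1, 4, 8)                  # state for 4 parallel environments
batch_done = [False, True, False, False]   # env 1 finished its episode
batch_reset = [False, False, True, False]  # env 2 will be reset externally
hs = _batch_reset_recurrent_states_when_episodes_end(batch_done, batch_reset, hs)
# the entries for envs 1 and 2 are now masked; envs 0 and 3 keep their state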
Example #5
def mask01_forward_twice():
    _, rs = one_step_forward(par, x_t0, None)
    rs = mask_recurrent_state_at(rs, [0, 1])
    return one_step_forward(par, x_t1, rs)
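
Example #5 is a fragment: par, x_t0 and x_t1 are defined in the enclosing test. The sketch below fills in plausible stand-ins (an nn.GRU and two one-step input batches); it assumes, as the fragment's own usage suggests, that one_step_forward(module, batch_input, recurrent_state) advances the module by a single time step and returns the output together with the new recurrent state:

import torch
import torch.nn as nn
from pfrl.utils.recurrent import mask_recurrent_state_at, one_step_forward

par = nn.GRU(num_layers=1, input_size=2, hidden_size=3)  # stand-in module
x_t0 = torch.rand(4, 2)  # one-step inputs for a batch of 4 sequences, t=0
x_t1 = torch.rand(4, 2)  # one-step inputs for the same batch, t=1

def mask01_forward_twice():
    _, rs = one_step_forward(par, x_t0, None)
    rs = mask_recurrent_state_at(rs, [0, 1])  # reset sequences 0 and 1
    return one_step_forward(par, x_t1, rs)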