Example #1
    def _test_forward(self, gpu):
        in_size = 2
        out0_size = 3
        out1_size = 4
        out2_size = 1

        par = RecurrentBranched(
            nn.LSTM(num_layers=1, input_size=in_size, hidden_size=out0_size),
            RecurrentSequential(
                nn.RNN(num_layers=1, input_size=in_size, hidden_size=out1_size),
            ),
            RecurrentSequential(nn.Linear(in_size, out2_size),),
        )

        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            par.to(device)
        else:
            device = torch.device("cpu")

        seqs_x = [
            torch.rand(1, in_size, device=device),
            torch.rand(3, in_size, device=device),
        ]

        packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)

        # The output should be a tuple with one element per branch (three here).
        out, rs = par(packed_x, None)
        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), len(par))
        self.assertEqual(out[0].data.shape, (4, out0_size))
        self.assertEqual(out[1].data.shape, (4, out1_size))
        self.assertEqual(out[2].data.shape, (4, out2_size))

        self.assertIsInstance(rs, tuple)
        self.assertEqual(len(rs), len(par))

        # LSTM
        self.assertIsInstance(rs[0], tuple)
        self.assertEqual(len(rs[0]), 2)
        self.assertEqual(rs[0][0].shape, (1, len(seqs_x), out0_size))
        self.assertEqual(rs[0][1].shape, (1, len(seqs_x), out0_size))

        # RecurrentSequential(RNN)
        self.assertIsInstance(rs[1], tuple)
        self.assertEqual(len(rs[1]), 1)
        self.assertEqual(rs[1][0].shape, (1, len(seqs_x), out1_size))

        # RecurrentSequential(Linear)
        self.assertIsInstance(rs[2], tuple)
        self.assertEqual(len(rs[2]), 0)
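
The snippet above exercises RecurrentBranched and RecurrentSequential from PFRL. Below is a minimal standalone sketch of the same forward pass; the import paths and tensor sizes are assumptions for illustration and not part of the test.

import torch
from torch import nn
from pfrl.nn import RecurrentBranched, RecurrentSequential

# Two branches fed the same packed input: an LSTM and a plain Linear layer
# wrapped in RecurrentSequential so it can live inside a recurrent container.
par = RecurrentBranched(
    nn.LSTM(num_layers=1, input_size=2, hidden_size=3),
    RecurrentSequential(nn.Linear(2, 1)),
)

# Variable-length sequences are packed before being fed to the model.
seqs_x = [torch.rand(1, 2), torch.rand(3, 2)]
packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)

# Passing None as the recurrent state starts from the initial state.
out, rs = par(packed_x, None)
print(len(out), len(rs))  # one output and one recurrent state per branch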
Example #2
    def make_model(self, env):
        """Build a policy/value-function model matching the env's observation and action spaces."""
        hidden_size = 20
        obs_size = env.observation_space.low.size

        # Scale a freshly initialized layer's weights in-place.
        def weight_scale(layer, scale):
            with torch.no_grad():
                layer.weight.mul_(scale)
            return layer

        if self.recurrent:
            # State-value function: LSTM followed by a scaled linear output.
            v = RecurrentSequential(
                nn.LSTM(num_layers=1,
                        input_size=obs_size,
                        hidden_size=hidden_size),
                weight_scale(nn.Linear(hidden_size, 1), 1e-1),
            )
            if self.discrete:
                # Policy for discrete actions: softmax (categorical) head over logits.
                n_actions = env.action_space.n
                pi = RecurrentSequential(
                    nn.LSTM(num_layers=1,
                            input_size=obs_size,
                            hidden_size=hidden_size),
                    weight_scale(nn.Linear(hidden_size, n_actions), 1e-1),
                    SoftmaxCategoricalHead(),
                )
            else:
                # Policy for continuous actions: Gaussian head with state-independent variance.
                action_size = env.action_space.low.size
                pi = RecurrentSequential(
                    nn.LSTM(num_layers=1,
                            input_size=obs_size,
                            hidden_size=hidden_size),
                    weight_scale(nn.Linear(hidden_size, action_size), 1e-1),
                    GaussianHeadWithStateIndependentCovariance(
                        action_size=action_size,
                        var_type="diagonal",
                        var_func=lambda x: torch.exp(2 * x),
                        var_param_init=0,
                    ),
                )
            # Run the policy and value branches in parallel on the same input.
            return RecurrentBranched(pi, v)
        else:
            # Non-recurrent counterparts: feed-forward value function and policy.
            v = nn.Sequential(
                nn.Linear(obs_size, hidden_size),
                nn.Tanh(),
                weight_scale(nn.Linear(hidden_size, 1), 1e-1),
            )
            if self.discrete:
                n_actions = env.action_space.n
                pi = nn.Sequential(
                    nn.Linear(obs_size, hidden_size),
                    nn.Tanh(),
                    weight_scale(nn.Linear(hidden_size, n_actions), 1e-1),
                    SoftmaxCategoricalHead(),
                )
            else:
                action_size = env.action_space.low.size
                pi = nn.Sequential(
                    nn.Linear(obs_size, hidden_size),
                    nn.Tanh(),
                    weight_scale(nn.Linear(hidden_size, action_size), 1e-1),
                    GaussianHeadWithStateIndependentCovariance(
                        action_size=action_size,
                        var_type="diagonal",
                        var_func=lambda x: torch.exp(2 * x),
                        var_param_init=0,
                    ),
                )
            # Branched applies pi and v to the same input and returns both outputs.
            return pfrl.nn.Branched(pi, v)
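
Both branches of make_model return a single module whose forward pass yields the action distribution and the state value together. A hedged sketch of evaluating the non-recurrent variant on a made-up batch (the sizes here are illustrative assumptions):

import torch
from torch import nn
import pfrl
from pfrl.policies import SoftmaxCategoricalHead

obs_size, hidden_size, n_actions = 4, 20, 2
pi = nn.Sequential(
    nn.Linear(obs_size, hidden_size),
    nn.Tanh(),
    nn.Linear(hidden_size, n_actions),
    SoftmaxCategoricalHead(),
)
v = nn.Sequential(
    nn.Linear(obs_size, hidden_size),
    nn.Tanh(),
    nn.Linear(hidden_size, 1),
)
model = pfrl.nn.Branched(pi, v)

obs = torch.rand(8, obs_size)        # batch of 8 observations
action_distrib, value = model(obs)   # Branched returns one output per child module
actions = action_distrib.sample()    # shape (8,), sampled from the categorical head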
Example #3
    def _test_forward_with_modified_recurrent_state(self, gpu):
        in_size = 2
        out0_size = 2
        out1_size = 3
        par = RecurrentBranched(
            nn.GRU(num_layers=1, input_size=in_size, hidden_size=out0_size),
            RecurrentSequential(
                nn.LSTM(num_layers=1, input_size=in_size, hidden_size=out1_size),
            ),
        )
        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            par.to(device)
        else:
            device = torch.device("cpu")
        seqs_x = [
            torch.rand(2, in_size, device=device),
            torch.rand(2, in_size, device=device),
        ]
        packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)
        # Batched inputs at time steps 0 and 1.
        x_t0 = torch.stack((seqs_x[0][0], seqs_x[1][0]))
        x_t1 = torch.stack((seqs_x[0][1], seqs_x[1][1]))

        # Full-sequence forward: outputs and final recurrent states for each branch.
        (gru_out, lstm_out), (gru_rs, (lstm_rs,)) = par(packed_x, None)

        # Check that the full-sequence forward and two one-step forwards give the same results
        def no_mask_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            return one_step_forward(par, x_t1, rs)

        (
            (nomask_gru_out, nomask_lstm_out),
            (nomask_gru_rs, (nomask_lstm_rs,)),
        ) = no_mask_forward_twice()

        # GRU
        torch_assert_allclose(gru_out.data[2:], nomask_gru_out, atol=1e-5)
        torch_assert_allclose(gru_rs, nomask_gru_rs)

        # LSTM
        torch_assert_allclose(lstm_out.data[2:], nomask_lstm_out, atol=1e-5)
        torch_assert_allclose(lstm_rs[0], nomask_lstm_rs[0], atol=1e-5)
        torch_assert_allclose(lstm_rs[1], nomask_lstm_rs[1], atol=1e-5)

        # Mask only the 1st batch element's recurrent state before the second step:
        # only the 2nd element's outputs should match the full-sequence forward.
        def mask0_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            rs = mask_recurrent_state_at(rs, 0)
            return one_step_forward(par, x_t1, rs)

        (
            (mask0_gru_out, mask0_lstm_out),
            (mask0_gru_rs, (mask0_lstm_rs,)),
        ) = mask0_forward_twice()

        # GRU
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[2], mask0_gru_out[0], atol=1e-5)
        torch_assert_allclose(gru_out.data[3], mask0_gru_out[1], atol=1e-5)

        # LSTM
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[2], mask0_lstm_out[0], atol=1e-5)
        torch_assert_allclose(lstm_out.data[3], mask0_lstm_out[1], atol=1e-5)

        # Mask only the 2nd batch element's recurrent state: only the 1st element's
        # outputs should match.
        def mask1_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            rs = mask_recurrent_state_at(rs, 1)
            return one_step_forward(par, x_t1, rs)

        (
            (mask1_gru_out, mask1_lstm_out),
            (mask1_gru_rs, (mask1_lstm_rs,)),
        ) = mask1_forward_twice()

        # GRU
        torch_assert_allclose(gru_out.data[2], mask1_gru_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[3], mask1_gru_out[1], atol=1e-5)

        # LSTM
        torch_assert_allclose(lstm_out.data[2], mask1_lstm_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[3], mask1_lstm_out[1], atol=1e-5)

        # Mask both batch elements' recurrent states: neither output should match.
        def mask01_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            rs = mask_recurrent_state_at(rs, [0, 1])
            return one_step_forward(par, x_t1, rs)

        (
            (mask01_gru_out, mask01_lstm_out),
            (mask01_gru_rs, (mask01_lstm_rs,)),
        ) = mask01_forward_twice()

        # GRU
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[2], mask01_gru_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[3], mask01_gru_out[1], atol=1e-5)

        # LSTM
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[2], mask01_lstm_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[3], mask01_lstm_out[1], atol=1e-5)

        # Get each element's recurrent state, concatenate them, and resume the
        # forward pass; the results should match the full-sequence forward.
        def get_and_concat_rs_forward():
            _, rs = one_step_forward(par, x_t0, None)
            rs0 = get_recurrent_state_at(rs, 0, detach=True)
            rs1 = get_recurrent_state_at(rs, 1, detach=True)
            concat_rs = concatenate_recurrent_states([rs0, rs1])
            return one_step_forward(par, x_t1, concat_rs)

        (
            (getcon_gru_out, getcon_lstm_out),
            (getcon_gru_rs, (getcon_lstm_rs,)),
        ) = get_and_concat_rs_forward()

        # GRU
        torch_assert_allclose(gru_out.data[2], getcon_gru_out[0], atol=1e-5)
        torch_assert_allclose(gru_out.data[3], getcon_gru_out[1], atol=1e-5)

        # LSTM
        torch_assert_allclose(lstm_out.data[2], getcon_lstm_out[0], atol=1e-5)
        torch_assert_allclose(lstm_out.data[3], getcon_lstm_out[1], atol=1e-5)
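
The helpers used above (one_step_forward, mask_recurrent_state_at, get_recurrent_state_at, concatenate_recurrent_states) are what an agent uses to step a recurrent model one transition at a time and to reset state at episode boundaries. A minimal sketch, assuming they are importable from pfrl.utils.recurrent:

import torch
from torch import nn
from pfrl.nn import RecurrentBranched, RecurrentSequential
from pfrl.utils.recurrent import mask_recurrent_state_at, one_step_forward

par = RecurrentBranched(
    nn.GRU(num_layers=1, input_size=2, hidden_size=2),
    RecurrentSequential(nn.LSTM(num_layers=1, input_size=2, hidden_size=3)),
)

# One time step of observations for a batch of two environments.
obs = torch.rand(2, 2)
out, rs = one_step_forward(par, obs, None)

# Suppose env 0 just finished an episode: reset only its recurrent state so the
# next step starts fresh for env 0 while env 1 keeps its history.
rs = mask_recurrent_state_at(rs, 0)
next_obs = torch.rand(2, 2)
out, rs = one_step_forward(par, next_obs, rs)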