Example #1
    def _test_forward_with_tuple_output(self, gpu):
        in_size = 5
        out_size = 6

        def split_output(x):
            return tuple(torch.split(x, [2, 1, 3], dim=1))

        rseq = RecurrentSequential(
            nn.RNN(num_layers=1, input_size=in_size, hidden_size=out_size),
            Lambda(split_output),
        )

        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            rseq.to(device)
        else:
            device = torch.device("cpu")

        # Input is a list of two variable-length sequences.
        seqs_x = [
            torch.rand(3, in_size, device=device),
            torch.rand(2, in_size, device=device),
        ]

        packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)

        # Concatenated output should be a tuple of three PackedSequences.
        out, _ = rseq(packed_x, None)

        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), 3)
        self.assertEqual(out[0].data.shape, (5, 2))
        self.assertEqual(out[1].data.shape, (5, 1))
        self.assertEqual(out[2].data.shape, (5, 3))
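
A Lambda-style wrapper like the one used above can be as small as the following sketch. pfrl ships its own pfrl.nn.Lambda, which the snippet presumably imports; this standalone version is only illustrative:

import torch.nn as nn

class Lambda(nn.Module):
    """Wraps an arbitrary callable as a stateless module."""

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)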
Example #2
    def _test_forward_with_tuple_input(self, gpu):
        in_size = 5
        out_size = 3

        def concat_input(tensors):
            return torch.cat(tensors, dim=1)

        rseq = RecurrentSequential(
            Lambda(concat_input),
            nn.RNN(num_layers=1, input_size=in_size, hidden_size=out_size),
        )

        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            rseq.to(device)
        else:
            device = torch.device("cpu")

        # Input is a list of tuples, each containing two tensors.
        seqs_x = [
            (torch.rand(3, 2, device=device), torch.rand(3, 3, device=device)),
            (torch.rand(1, 2, device=device), torch.rand(1, 3, device=device)),
        ]
        packed_x = (
            nn.utils.rnn.pack_sequence([seqs_x[0][0], seqs_x[1][0]]),
            nn.utils.rnn.pack_sequence([seqs_x[0][1], seqs_x[1][1]]),
        )

        # Concatenated output should be a single PackedSequence.
        out, _ = rseq(packed_x, None)
        self.assertEqual(out.data.shape, (4, out_size))
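
The expected first dimension of 4 follows from PackedSequence semantics: packing concatenates the time steps of all sequences into one tensor, so lengths 3 and 1 yield 4 rows. A quick standalone check:

import torch
import torch.nn as nn

seqs = [torch.rand(3, 5), torch.rand(1, 5)]  # lengths 3 and 1, already sorted
packed = nn.utils.rnn.pack_sequence(seqs)
assert packed.data.shape == (4, 5)  # 3 + 1 time steps in a single tensor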
Example #3
    def _test_forward(self, gpu):
        in_size = 2
        out0_size = 3
        out1_size = 4
        out2_size = 1

        par = RecurrentBranched(
            nn.LSTM(num_layers=1, input_size=in_size, hidden_size=out0_size),
            RecurrentSequential(
                nn.RNN(num_layers=1, input_size=in_size, hidden_size=out1_size),
            ),
            RecurrentSequential(nn.Linear(in_size, out2_size)),
        )

        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            par.to(device)
        else:
            device = torch.device("cpu")

        seqs_x = [
            torch.rand(1, in_size, device=device),
            torch.rand(3, in_size, device=device),
        ]

        packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)

        # Concatenated output should be a tuple of three PackedSequences.
        out, rs = par(packed_x, None)
        self.assertIsInstance(out, tuple)
        self.assertEqual(len(out), len(par))
        self.assertEqual(out[0].data.shape, (4, out0_size))
        self.assertEqual(out[1].data.shape, (4, out1_size))
        self.assertEqual(out[2].data.shape, (4, out2_size))

        self.assertIsInstance(rs, tuple)
        self.assertEqual(len(rs), len(par))

        # LSTM
        self.assertIsInstance(rs[0], tuple)
        self.assertEqual(len(rs[0]), 2)
        self.assertEqual(rs[0][0].shape, (1, len(seqs_x), out0_size))
        self.assertEqual(rs[0][1].shape, (1, len(seqs_x), out0_size))

        # RecurrentSequential(RNN)
        self.assertIsInstance(rs[1], tuple)
        self.assertEqual(len(rs[1]), 1)
        self.assertEqual(rs[1][0].shape, (1, len(seqs_x), out1_size))

        # RecurrentSequential(Linear)
        self.assertIsInstance(rs[2], tuple)
        self.assertEqual(len(rs[2]), 0)
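
The recurrent-state layout asserted above follows from how a branched container dispatches: every child receives the same input along with its own slice of the state, and a stateless child (like the Linear-only branch) contributes an empty tuple. A hedged sketch of that dispatch, not pfrl's actual implementation:

def branched_forward(children, x, recurrent_state):
    # Assumes each child follows the (input, state) -> (output, new_state)
    # protocol; None expands to one None per branch.
    if recurrent_state is None:
        recurrent_state = [None] * len(children)
    outs, new_states = [], []
    for child, rs in zip(children, recurrent_state):
        out, new_rs = child(x, rs)
        outs.append(out)
        new_states.append(new_rs)
    return tuple(outs), tuple(new_states)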
Example #4
    def make_q_func(self, env):
        n_hidden_channels = 10
        return RecurrentSequential(
            nn.Linear(env.observation_space.low.size, n_hidden_channels),
            nn.ELU(),
            nn.RNN(input_size=n_hidden_channels, hidden_size=n_hidden_channels),
            nn.Linear(n_hidden_channels, env.action_space.n),
            q_functions.DiscreteActionValueHead(),
        )
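
The factory only reads two gym-style attributes, env.observation_space.low.size and env.action_space.n, so a minimal stand-in (hypothetical, just to smoke-test the shapes) is enough:

import numpy as np
from types import SimpleNamespace

env = SimpleNamespace(
    observation_space=SimpleNamespace(low=np.zeros(4, dtype=np.float32)),
    action_space=SimpleNamespace(n=2),
)
# With this stand-in the factory builds a 4 -> 10 -> RNN(10) -> 2 network.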
Example #5
    def make_model(self, env):
        obs_size = env.observation_space.low.size
        action_size = env.action_space.low.size
        hidden_size = 50
        # Model must be recurrent
        policy = RecurrentSequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.LSTM(input_size=hidden_size, hidden_size=hidden_size),
            nn.Linear(hidden_size, action_size),
            BoundByTanh(low=env.action_space.low, high=env.action_space.high),
            DeterministicHead(),
        )
        q_func = RecurrentSequential(
            ConcatObsAndAction(),
            nn.Linear(obs_size + action_size, hidden_size),
            nn.ReLU(),
            nn.LSTM(input_size=hidden_size, hidden_size=hidden_size),
            nn.Linear(hidden_size, 1),
        )
        return policy, q_func
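
ConcatObsAndAction lets the critic consume an (obs, action) pair; conceptually it is just concatenation along the last axis. A minimal illustrative version (pfrl provides the real module as pfrl.nn.ConcatObsAndAction):

import torch
import torch.nn as nn

class ConcatObsAndAction(nn.Module):
    def forward(self, obs_and_action):
        # obs_and_action is an (obs, action) pair of tensors.
        return torch.cat(obs_and_action, dim=-1)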
Example #6
    def _test_forward(self, gpu):
        in_size = 2
        out_size = 6

        rseq = RecurrentSequential(
            nn.Linear(in_size, 3),
            nn.ELU(),
            nn.LSTM(num_layers=1, input_size=3, hidden_size=4),
            nn.Linear(4, 5),
            nn.RNN(num_layers=1, input_size=5, hidden_size=out_size),
            nn.Tanh(),
        )

        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            rseq.to(device)
        else:
            device = torch.device("cpu")

        assert len(rseq.recurrent_children) == 2
        assert rseq.recurrent_children[0] is rseq[2]
        assert rseq.recurrent_children[1] is rseq[4]

        linear1 = rseq[0]
        lstm = rseq[2]
        linear2 = rseq[3]
        rnn = rseq[4]

        seqs_x = [
            torch.rand(4, in_size, requires_grad=True, device=device),
            torch.rand(1, in_size, requires_grad=True, device=device),
            torch.rand(3, in_size, requires_grad=True, device=device),
        ]

        packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)

        out, _ = rseq(packed_x, None)
        self.assertEqual(out.data.shape, (8, out_size))

        # Check if the output matches that of step-by-step execution
        def manual_forward(seqs_x):
            seqs_y = []
            for seq_x in seqs_x:
                lstm_st = None
                rnn_st = None
                seq_y = []
                for i in range(len(seq_x)):
                    h = seq_x[i:i + 1]
                    h = linear1(h)
                    h = F.elu(h)
                    h, lstm_st = _step_lstm(lstm, h, lstm_st)
                    h = linear2(h)
                    h, rnn_st = _step_rnn_tanh(rnn, h, rnn_st)
                    y = torch.tanh(h)
                    seq_y.append(y[0])
                seqs_y.append(torch.stack(seq_y))
            return nn.utils.rnn.pack_sequence(seqs_y, enforce_sorted=False)

        manual_out = manual_forward(seqs_x)
        torch_assert_allclose(out.data, manual_out.data, atol=1e-4)

        # Finally, check the gradient (wrt input)
        grads = torch.autograd.grad([torch.sum(out.data)], seqs_x)
        manual_grads = torch.autograd.grad([torch.sum(manual_out.data)],
                                           seqs_x)
        assert len(grads) == len(manual_grads) == 3
        for grad, manual_grad in zip(grads, manual_grads):
            torch_assert_allclose(grad, manual_grad, atol=1e-4)
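
manual_forward relies on two single-step helpers, _step_lstm and _step_rnn_tanh, defined elsewhere in the test module. A plausible sketch, assuming they simply run one time step through the underlying module by adding a length-1 sequence dimension:

def _step_lstm(lstm, x, state):
    # x: (batch, input_size); nn.LSTM expects (seq, batch, input_size).
    y, state = lstm(x.unsqueeze(0), state)
    return y.squeeze(0), state

def _step_rnn_tanh(rnn, x, state):
    y, state = rnn(x.unsqueeze(0), state)
    return y.squeeze(0), state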
Example #7
    def make_model(self, env):
        hidden_size = 20
        obs_size = env.observation_space.low.size

        def weight_scale(layer, scale):
            with torch.no_grad():
                layer.weight.mul_(scale)
            return layer

        if self.recurrent:
            v = RecurrentSequential(
                nn.LSTM(num_layers=1,
                        input_size=obs_size,
                        hidden_size=hidden_size),
                weight_scale(nn.Linear(hidden_size, 1), 1e-1),
            )
            if self.discrete:
                n_actions = env.action_space.n
                pi = RecurrentSequential(
                    nn.LSTM(num_layers=1,
                            input_size=obs_size,
                            hidden_size=hidden_size),
                    weight_scale(nn.Linear(hidden_size, n_actions), 1e-1),
                    SoftmaxCategoricalHead(),
                )
            else:
                action_size = env.action_space.low.size
                pi = RecurrentSequential(
                    nn.LSTM(num_layers=1,
                            input_size=obs_size,
                            hidden_size=hidden_size),
                    weight_scale(nn.Linear(hidden_size, action_size), 1e-1),
                    GaussianHeadWithStateIndependentCovariance(
                        action_size=action_size,
                        var_type="diagonal",
                        var_func=lambda x: torch.exp(2 * x),
                        var_param_init=0,
                    ),
                )
            return RecurrentBranched(pi, v)
        else:
            v = nn.Sequential(
                nn.Linear(obs_size, hidden_size),
                nn.Tanh(),
                weight_scale(nn.Linear(hidden_size, 1), 1e-1),
            )
            if self.discrete:
                n_actions = env.action_space.n
                pi = nn.Sequential(
                    nn.Linear(obs_size, hidden_size),
                    nn.Tanh(),
                    weight_scale(nn.Linear(hidden_size, n_actions), 1e-1),
                    SoftmaxCategoricalHead(),
                )
            else:
                action_size = env.action_space.low.size
                pi = nn.Sequential(
                    nn.Linear(obs_size, hidden_size),
                    nn.Tanh(),
                    weight_scale(nn.Linear(hidden_size, action_size), 1e-1),
                    GaussianHeadWithStateIndependentCovariance(
                        action_size=action_size,
                        var_type="diagonal",
                        var_func=lambda x: torch.exp(2 * x),
                        var_param_init=0,
                    ),
                )
            return pfrl.nn.Branched(pi, v)
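
A quick sanity check of the weight_scale helper (the same definition repeated standalone for illustration): it scales the weight in place and returns the same module, which is what lets it be used inline inside a Sequential constructor.

import torch
import torch.nn as nn

def weight_scale(layer, scale):
    with torch.no_grad():
        layer.weight.mul_(scale)
    return layer

layer = nn.Linear(4, 2)
before = layer.weight.detach().clone()
assert weight_scale(layer, 1e-1) is layer  # modified in place, same object
assert torch.allclose(layer.weight, before * 1e-1)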
Example #8
def test_ppo_dataset_recurrent_and_non_recurrent_equivalence(
        use_obs_normalizer, gamma, lambd, max_recurrent_sequence_len):
    """Test equivalence between recurrent and non-recurrent datasets.

    When the same feed-forward model is used, the values of
    log_prob, v_pred, next_v_pred obtained by both recurrent and
    non-recurrent dataset creation functions should be the same.
    """
    episodes = make_random_episodes()
    if use_obs_normalizer:
        obs_normalizer = pfrl.nn.EmpiricalNormalization(2, clip_threshold=5)
        obs_normalizer.experience(torch.rand(10, 2))
    else:
        obs_normalizer = None

    def phi(obs):
        return (obs * 0.5).astype(np.float32)

    device = torch.device("cpu")

    obs_size = 2
    n_actions = 3

    non_recurrent_model = pfrl.nn.Branched(
        nn.Sequential(
            nn.Linear(obs_size, n_actions),
            SoftmaxCategoricalHead(),
        ),
        nn.Linear(obs_size, 1),
    )
    recurrent_model = RecurrentSequential(non_recurrent_model)

    dataset = pfrl.agents.ppo._make_dataset(
        episodes=copy.deepcopy(episodes),
        model=non_recurrent_model,
        phi=phi,
        batch_states=batch_states,
        obs_normalizer=obs_normalizer,
        gamma=gamma,
        lambd=lambd,
        device=device,
    )

    dataset_recurrent = pfrl.agents.ppo._make_dataset_recurrent(
        episodes=copy.deepcopy(episodes),
        model=recurrent_model,
        phi=phi,
        batch_states=batch_states,
        obs_normalizer=obs_normalizer,
        gamma=gamma,
        lambd=lambd,
        max_recurrent_sequence_len=max_recurrent_sequence_len,
        device=device,
    )

    assert "log_prob" not in episodes[0][0]
    assert "log_prob" in dataset[0]
    assert "log_prob" in dataset_recurrent[0][0]
    # They are not just shallow copies
    assert dataset[0]["log_prob"] is not dataset_recurrent[0][0]["log_prob"]

    # Every per-transition quantity should match between the two datasets.
    flat_recurrent = list(itertools.chain.from_iterable(dataset_recurrent))
    for key in [
        "state",
        "action",
        "reward",
        "nonterminal",
        "log_prob",
        "v_pred",
        "next_v_pred",
        "adv",
        "v_teacher",
    ]:
        torch_assert_allclose(
            [tr[key] for tr in dataset],
            [tr[key] for tr in flat_recurrent],
        )
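
The recurrent dataset is a list of per-episode lists of transitions, which is why the comparison above first flattens it with itertools.chain.from_iterable. For example:

import itertools

nested = [[1, 2], [3], [4, 5]]
assert list(itertools.chain.from_iterable(nested)) == [1, 2, 3, 4, 5]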
Example #9
    def _test_forward_with_modified_recurrent_state(self, gpu):
        in_size = 2
        out0_size = 2
        out1_size = 3
        par = RecurrentBranched(
            nn.GRU(num_layers=1, input_size=in_size, hidden_size=out0_size),
            RecurrentSequential(
                nn.LSTM(num_layers=1, input_size=in_size, hidden_size=out1_size),
            ),
        )
        if gpu >= 0:
            device = torch.device("cuda:{}".format(gpu))
            par.to(device)
        else:
            device = torch.device("cpu")
        seqs_x = [
            torch.rand(2, in_size, device=device),
            torch.rand(2, in_size, device=device),
        ]
        packed_x = nn.utils.rnn.pack_sequence(seqs_x, enforce_sorted=False)
        x_t0 = torch.stack((seqs_x[0][0], seqs_x[1][0]))
        x_t1 = torch.stack((seqs_x[0][1], seqs_x[1][1]))

        (gru_out, lstm_out), (gru_rs, (lstm_rs,)) = par(packed_x, None)

        # Check that one n-step forward matches two successive one-step forwards
        def no_mask_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            return one_step_forward(par, x_t1, rs)

        (
            (nomask_gru_out, nomask_lstm_out),
            (nomask_gru_rs, (nomask_lstm_rs,)),
        ) = no_mask_forward_twice()

        # GRU
        torch_assert_allclose(gru_out.data[2:], nomask_gru_out, atol=1e-5)
        torch_assert_allclose(gru_rs, nomask_gru_rs)

        # LSTM
        torch_assert_allclose(lstm_out.data[2:], nomask_lstm_out, atol=1e-5)
        torch_assert_allclose(lstm_rs[0], nomask_lstm_rs[0], atol=1e-5)
        torch_assert_allclose(lstm_rs[1], nomask_lstm_rs[1], atol=1e-5)

        # Mask only the 1st sequence's state: only the 2nd output should match
        def mask0_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            rs = mask_recurrent_state_at(rs, 0)
            return one_step_forward(par, x_t1, rs)

        (
            (mask0_gru_out, mask0_lstm_out),
            (mask0_gru_rs, (mask0_lstm_rs,)),
        ) = mask0_forward_twice()

        # GRU
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[2], mask0_gru_out[0], atol=1e-5)
        torch_assert_allclose(gru_out.data[3], mask0_gru_out[1], atol=1e-5)

        # LSTM
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[2], mask0_lstm_out[0], atol=1e-5)
        torch_assert_allclose(lstm_out.data[3], mask0_lstm_out[1], atol=1e-5)

        # Mask only the 2nd sequence's state: only the 1st output should match
        def mask1_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            rs = mask_recurrent_state_at(rs, 1)
            return one_step_forward(par, x_t1, rs)

        (
            (mask1_gru_out, mask1_lstm_out),
            (mask1_gru_rs, (mask1_lstm_rs,)),
        ) = mask1_forward_twice()

        # GRU
        torch_assert_allclose(gru_out.data[2], mask1_gru_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[3], mask1_gru_out[1], atol=1e-5)

        # LSTM
        torch_assert_allclose(lstm_out.data[2], mask1_lstm_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[3], mask1_lstm_out[1], atol=1e-5)

        # Mask both sequences' states: neither output should match
        def mask01_forward_twice():
            _, rs = one_step_forward(par, x_t0, None)
            rs = mask_recurrent_state_at(rs, [0, 1])
            return one_step_forward(par, x_t1, rs)

        (
            (mask01_gru_out, mask01_lstm_out),
            (mask01_gru_rs, (mask01_lstm_rs,)),
        ) = mask01_forward_twice()

        # GRU
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[2], mask01_gru_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(gru_out.data[3], mask01_gru_out[1], atol=1e-5)

        # LSTM
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[2], mask01_lstm_out[0], atol=1e-5)
        with self.assertRaises(AssertionError):
            torch_assert_allclose(lstm_out.data[3], mask01_lstm_out[1], atol=1e-5)

        # Extract each sequence's recurrent state, concatenate, and resume forward
        def get_and_concat_rs_forward():
            _, rs = one_step_forward(par, x_t0, None)
            rs0 = get_recurrent_state_at(rs, 0, detach=True)
            rs1 = get_recurrent_state_at(rs, 1, detach=True)
            concat_rs = concatenate_recurrent_states([rs0, rs1])
            return one_step_forward(par, x_t1, concat_rs)

        (
            (getcon_gru_out, getcon_lstm_out),
            (getcon_gru_rs, (getcon_lstm_rs,)),
        ) = get_and_concat_rs_forward()

        # GRU
        torch_assert_allclose(gru_out.data[2], getcon_gru_out[0], atol=1e-5)
        torch_assert_allclose(gru_out.data[3], getcon_gru_out[1], atol=1e-5)

        # LSTM
        torch_assert_allclose(lstm_out.data[2], getcon_lstm_out[0], atol=1e-5)
        torch_assert_allclose(lstm_out.data[3], getcon_lstm_out[1], atol=1e-5)
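
The masking helpers exercised above reset the recurrent state of the selected batch indices to the initial (zero) state while leaving the others intact. For a single GRU, whose state is a (num_layers, batch, hidden_size) tensor, the effect could be sketched as follows (hypothetical, not pfrl's actual mask_recurrent_state_at):

import torch

def mask_gru_state_at(state, indices):
    # Zero the state of the given batch index or indices; others are kept.
    state = state.clone()
    state[:, indices, :] = 0
    return state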