Example #1
    def make_model(self, env):
        n_hidden_channels = 20
        obs_size = env.observation_space.low.size

        if self.recurrent:
            v = StatelessRecurrentSequential(
                L.NStepLSTM(1, obs_size, n_hidden_channels, 0),
                L.Linear(
                    None, 1, initialW=chainer.initializers.LeCunNormal(1e-1)),
            )
            if self.discrete:
                n_actions = env.action_space.n
                pi = StatelessRecurrentSequential(
                    L.NStepLSTM(1, obs_size, n_hidden_channels, 0),
                    policies.FCSoftmaxPolicy(
                        n_hidden_channels, n_actions,
                        n_hidden_layers=0,
                        nonlinearity=F.tanh,
                        last_wscale=1e-1,
                    )
                )
            else:
                action_size = env.action_space.low.size
                pi = StatelessRecurrentSequential(
                    L.NStepLSTM(1, obs_size, n_hidden_channels, 0),
                    policies.FCGaussianPolicy(
                        n_hidden_channels, action_size,
                        n_hidden_layers=0,
                        nonlinearity=F.tanh,
                        mean_wscale=1e-1,
                    )
                )
            return StatelessRecurrentBranched(pi, v)
        else:
            v = chainer.Sequential(
                L.Linear(None, n_hidden_channels),
                F.tanh,
                L.Linear(
                    None, 1, initialW=chainer.initializers.LeCunNormal(1e-1)),
            )
            if self.discrete:
                n_actions = env.action_space.n
                pi = policies.FCSoftmaxPolicy(
                    obs_size, n_actions,
                    n_hidden_layers=1,
                    n_hidden_channels=n_hidden_channels,
                    nonlinearity=F.tanh,
                    last_wscale=1e-1,
                )
            else:
                action_size = env.action_space.low.size
                pi = policies.FCGaussianPolicy(
                    obs_size, action_size,
                    n_hidden_layers=1,
                    n_hidden_channels=n_hidden_channels,
                    nonlinearity=F.tanh,
                    mean_wscale=1e-1,
                )
            return A3CSeparateModel(pi=pi, v=v)
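
Note: make_model builds either a recurrent actor-critic (policy and value function bundled in StatelessRecurrentBranched) or a feed-forward one (A3CSeparateModel). Below is a minimal sketch of exercising the feed-forward branch outside the test class; the sizes are arbitrary, and model(obs) returning an action distribution and a value follows the convention ChainerRL's A3C/PPO agents expect, so treat it as an assumption rather than part of the original test.

# Minimal usage sketch (assumes the same ChainerRL imports as this file).
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainerrl import policies
from chainerrl.agents.a3c import A3CSeparateModel

obs_size, n_actions, n_hidden_channels = 4, 2, 20
pi = policies.FCSoftmaxPolicy(
    obs_size, n_actions,
    n_hidden_layers=1,
    n_hidden_channels=n_hidden_channels,
    nonlinearity=F.tanh,
)
v = chainer.Sequential(
    L.Linear(None, n_hidden_channels),
    F.tanh,
    L.Linear(None, 1),
)
model = A3CSeparateModel(pi=pi, v=v)

obs = np.random.uniform(-1, 1, (1, obs_size)).astype(np.float32)
action_distrib, value = model(obs)  # policy distribution and state value
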
    def _test_three_recurrent_children(self, gpu):
        # Test if https://github.com/chainer/chainer/issues/6053 is addressed
        in_size = 2
        out_size = 6

        rseq = StatelessRecurrentSequential(
            L.NStepLSTM(1, in_size, 3, 0),
            L.NStepGRU(2, 3, 4, 0),
            L.NStepRNNTanh(5, 4, out_size, 0),
        )

        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            rseq.to_gpu()
        xp = rseq.xp

        seqs_x = [
            xp.random.uniform(-1, 1, size=(4, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(1, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
        ]

        # Create a recurrent state, then feed it back, to check that the
        # per-child state ordering is preserved.
        _, rs = rseq.n_step_forward(seqs_x, None, output_mode='concat')
        _, _ = rseq.n_step_forward(seqs_x, rs, output_mode='concat')

        _, rs = rseq.n_step_forward(seqs_x, None, output_mode='split')
        _, _ = rseq.n_step_forward(seqs_x, rs, output_mode='split')
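
The recurrent state returned by n_step_forward is a tuple with one entry per recurrent child, in definition order, which is exactly what the double forward above exercises. A CPU-only sketch of inspecting that structure; the unpacking pattern follows the assertions used later in this file:

# Sketch: per-child recurrent states, in definition order. The LSTM
# contributes an (h, c) pair; GRU and RNN contribute a single h each.
import numpy as np
import chainer.links as L
from chainerrl.links import StatelessRecurrentSequential

rseq = StatelessRecurrentSequential(
    L.NStepLSTM(1, 2, 3, 0),
    L.NStepGRU(2, 3, 4, 0),
    L.NStepRNNTanh(5, 4, 6, 0),
)
xs = [np.random.uniform(-1, 1, (4, 2)).astype(np.float32)]
_, rs = rseq.n_step_forward(xs, None, output_mode='concat')
(lstm_h, lstm_c), gru_h, rnn_h = rs
# Shapes are (n_layers, batch, out_size):
print(lstm_h.shape, gru_h.shape, rnn_h.shape)  # (1, 1, 3) (2, 1, 4) (5, 1, 6)
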
Example #3
    def make_q_func(self, env):
        n_hidden_channels = 10
        return StatelessRecurrentSequential(
            L.Linear(env.observation_space.low.size, n_hidden_channels),
            F.elu,
            L.NStepRNNTanh(1, n_hidden_channels, n_hidden_channels, 0),
            L.Linear(n_hidden_channels, env.action_space.n),
            DiscreteActionValue,
        )
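
Since the chain ends with DiscreteActionValue, n_step_forward yields an action-value object rather than a bare Variable. A sketch of evaluating such a recurrent Q-function on a single sequence; the sizes below are placeholders standing in for a concrete env's observation and action spaces:

# Sketch: a recurrent Q-function evaluated over one 5-step sequence.
import numpy as np
import chainer.functions as F
import chainer.links as L
from chainerrl.action_value import DiscreteActionValue
from chainerrl.links import StatelessRecurrentSequential

q_func = StatelessRecurrentSequential(
    L.Linear(4, 10),
    F.elu,
    L.NStepRNNTanh(1, 10, 10, 0),
    L.Linear(10, 2),
    DiscreteActionValue,
)
seqs = [np.random.uniform(-1, 1, (5, 4)).astype(np.float32)]
qout, rs = q_func.n_step_forward(seqs, None, output_mode='concat')
print(qout.greedy_actions)  # argmax action at each of the 5 steps
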
    def _test_n_step_forward_with_tuple_output(self, gpu):
        in_size = 5
        out_size = 6

        def split_output(x):
            return tuple(F.split_axis(x, [2, 3], axis=1))

        rseq = StatelessRecurrentSequential(
            L.NStepRNNTanh(1, in_size, out_size, 0),
            split_output,
        )

        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            rseq.to_gpu()
        xp = rseq.xp

        # Input is a list of two arrays, one per sequence.
        seqs_x = [
            xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
        ]

        # Concatenated output should be a tuple of three variables.
        concat_out, concat_state = rseq.n_step_forward(seqs_x,
                                                       None,
                                                       output_mode='concat')
        self.assertIsInstance(concat_out, tuple)
        self.assertEqual(len(concat_out), 3)
        self.assertEqual(concat_out[0].shape, (5, 2))
        self.assertEqual(concat_out[1].shape, (5, 1))
        self.assertEqual(concat_out[2].shape, (5, 3))

        # Split output should be a list of two tuples, each consisting of
        # three variables.
        split_out, split_state = rseq.n_step_forward(seqs_x,
                                                     None,
                                                     output_mode='split')
        self.assertIsInstance(split_out, list)
        self.assertEqual(len(split_out), 2)
        self.assertIsInstance(split_out[0], tuple)
        self.assertIsInstance(split_out[1], tuple)
        for seq_x, seq_out in zip(seqs_x, split_out):
            self.assertEqual(len(seq_out), 3)
            self.assertEqual(seq_out[0].shape, (len(seq_x), 2))
            self.assertEqual(seq_out[1].shape, (len(seq_x), 1))
            self.assertEqual(seq_out[2].shape, (len(seq_x), 3))

        # Check if output_mode='concat' and output_mode='split' are consistent
        xp.testing.assert_allclose(
            F.concat([F.concat(seq_out, axis=1) for seq_out in split_out],
                     axis=0).array,
            F.concat(concat_out, axis=1).array,
        )
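
The split_output helper relies on F.split_axis with section indices [2, 3]: axis 1 is cut at columns 2 and 3, producing widths 2, 1, and 3, which is where the (5, 2), (5, 1), (5, 3) shapes above come from. A standalone check of that behavior:

# F.split_axis(x, [2, 3], axis=1) cuts a width-6 axis into widths 2, 1, 3.
import numpy as np
import chainer.functions as F

x = np.arange(12, dtype=np.float32).reshape(2, 6)
a, b, c = F.split_axis(x, [2, 3], axis=1)
print(a.shape, b.shape, c.shape)  # (2, 2) (2, 1) (2, 3)
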
    def _test_n_step_forward_with_tuple_input(self, gpu):
        in_size = 5
        out_size = 3

        def concat_input(*args):
            return F.concat(args, axis=1)

        rseq = StatelessRecurrentSequential(
            concat_input,
            L.NStepRNNTanh(1, in_size, out_size, 0),
        )

        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            rseq.to_gpu()
        xp = rseq.xp

        # Input is a list of tuples, each containing two arrays.
        seqs_x = [
            (xp.random.uniform(-1, 1, size=(3, 2)).astype(np.float32),
             xp.random.uniform(-1, 1, size=(3, 3)).astype(np.float32)),
            (xp.random.uniform(-1, 1, size=(1, 2)).astype(np.float32),
             xp.random.uniform(-1, 1, size=(1, 3)).astype(np.float32)),
        ]

        # Concatenated output should be a variable.
        concat_out, concat_state = rseq.n_step_forward(seqs_x,
                                                       None,
                                                       output_mode='concat')
        self.assertEqual(concat_out.shape, (4, out_size))

        # Split output should be a list of variables.
        split_out, split_state = rseq.n_step_forward(seqs_x,
                                                     None,
                                                     output_mode='split')
        self.assertIsInstance(split_out, list)
        self.assertEqual(len(split_out), len(seqs_x))
        for seq_x, seq_out in zip(seqs_x, split_out):
            self.assertEqual(seq_out.shape, (len(seq_x), out_size))

        # Check if output_mode='concat' and output_mode='split' are consistent
        xp.testing.assert_allclose(
            F.concat(split_out, axis=0).array,
            concat_out.array,
        )
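
Tuple inputs are useful for multi-modal observations: each per-step tuple is merged by the first callable before reaching the RNN. A stripped-down version of the pattern above:

# Sketch: per-step tuple observations merged by a leading callable.
import numpy as np
import chainer.functions as F
import chainer.links as L
from chainerrl.links import StatelessRecurrentSequential

def concat_input(*args):
    return F.concat(args, axis=1)  # 2 + 3 -> 5 input channels

rseq = StatelessRecurrentSequential(
    concat_input,
    L.NStepRNNTanh(1, 5, 3, 0),
)
seqs = [(np.random.uniform(-1, 1, (2, 2)).astype(np.float32),
         np.random.uniform(-1, 1, (2, 3)).astype(np.float32))]
out, _ = rseq.n_step_forward(seqs, None, output_mode='concat')
print(out.shape)  # (2, 3)
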
    def _test_n_step_forward(self, gpu):
        in_size = 2
        out_size = 6

        rseq = StatelessRecurrentSequential(
            L.Linear(in_size, 3),
            F.elu,
            L.NStepLSTM(1, 3, 4, 0),
            L.Linear(4, 5),
            L.NStepRNNTanh(1, 5, out_size, 0),
            F.tanh,
        )

        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            rseq.to_gpu()
        xp = rseq.xp

        linear1 = rseq._layers[0]
        lstm = rseq._layers[2]
        linear2 = rseq._layers[3]
        rnn = rseq._layers[4]

        seqs_x = [
            xp.random.uniform(-1, 1, size=(4, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(1, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
        ]

        concat_out, concat_state = rseq.n_step_forward(
            seqs_x, None, output_mode='concat')
        self.assertEqual(concat_out.shape, (8, out_size))

        split_out, split_state = rseq.n_step_forward(
            seqs_x, None, output_mode='split')
        self.assertIsInstance(split_out, list)
        self.assertEqual(len(split_out), len(seqs_x))
        for seq_x, seq_out in zip(seqs_x, split_out):
            self.assertEqual(seq_out.shape, (len(seq_x), out_size))

        # Check if output_mode='concat' and output_mode='split' are consistent
        xp.testing.assert_allclose(
            F.concat(split_out, axis=0).array,
            concat_out.array,
        )

        (concat_lstm_h, concat_lstm_c), concat_rnn_h = concat_state
        (split_lstm_h, split_lstm_c), split_rnn_h = split_state
        xp.testing.assert_allclose(concat_lstm_h.array, split_lstm_h.array)
        xp.testing.assert_allclose(concat_lstm_c.array, split_lstm_c.array)
        xp.testing.assert_allclose(concat_rnn_h.array, split_rnn_h.array)

        # Check if the output matches that of step-by-step execution
        def manual_n_step_forward(seqs_x):
            sorted_seqs_x = sorted(seqs_x, key=len, reverse=True)
            transposed_x = F.transpose_sequence(sorted_seqs_x)
            lstm_h = None
            lstm_c = None
            rnn_h = None
            ys = []
            for batch in transposed_x:
                if lstm_h is not None:
                    lstm_h = lstm_h[:len(batch)]
                    lstm_c = lstm_c[:len(batch)]
                    rnn_h = rnn_h[:len(batch)]
                h = linear1(batch)
                h = F.elu(h)
                h, (lstm_h, lstm_c) = _step_lstm(lstm, h, (lstm_h, lstm_c))
                h = linear2(h)
                h, rnn_h = _step_rnn_tanh(rnn, h, rnn_h)
                y = F.tanh(h)
                ys.append(y)
            sorted_seqs_y = F.transpose_sequence(ys)
            # Undo the sort: lengths were [4, 1, 3], so the descending-length
            # sort order was [0, 2, 1].
            seqs_y = [sorted_seqs_y[0], sorted_seqs_y[2], sorted_seqs_y[1]]
            return seqs_y

        manual_split_out = manual_n_step_forward(seqs_x)
        for man_seq_out, seq_out in zip(manual_split_out, split_out):
            xp.testing.assert_allclose(
                man_seq_out.array, seq_out.array, rtol=1e-5)

        # Finally, check the gradient (wrt linear1.W)
        concat_grad, = chainer.grad([F.sum(concat_out)], [linear1.W])
        split_grad, = chainer.grad(
            [F.sum(F.concat(split_out, axis=0))], [linear1.W])
        manual_split_grad, = chainer.grad(
            [F.sum(F.concat(manual_split_out, axis=0))], [linear1.W])
        xp.testing.assert_allclose(
            concat_grad.array, split_grad.array, rtol=1e-5)
        xp.testing.assert_allclose(
            concat_grad.array, manual_split_grad.array, rtol=1e-5)
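
The manual re-implementation above hinges on F.transpose_sequence, which turns length-sorted sequences into per-time-step batches whose batch size shrinks as shorter sequences run out; that is why the recurrent states are truncated with [:len(batch)] at each step. A small illustration:

# F.transpose_sequence: list of sequences (sorted by length, descending)
# -> list of per-step batches with shrinking batch size.
import numpy as np
import chainer.functions as F

seqs = [np.zeros((4, 2), np.float32),
        np.zeros((3, 2), np.float32),
        np.zeros((1, 2), np.float32)]
steps = F.transpose_sequence(seqs)
print([s.shape for s in steps])  # [(3, 2), (2, 2), (2, 2), (1, 2)]
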
    def _test_mask_recurrent_state_at(self, gpu):
        in_size = 2
        out_size = 4
        rseq = StatelessRecurrentSequential(
            L.Linear(in_size, 3),
            F.elu,
            L.NStepGRU(1, 3, out_size, 0),
            F.softmax,
        )
        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            rseq.to_gpu()
        xp = rseq.xp
        seqs_x = [
            xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
        ]
        transposed_x = F.transpose_sequence(seqs_x)

        def no_mask_n_step_forward():
            nomask_nstep_out, nstep_rs = rseq.n_step_forward(
                seqs_x, None, output_mode='concat')
            return F.reshape(nomask_nstep_out, (2, 2, out_size)), nstep_rs
        nstep_out, nstep_rs = no_mask_n_step_forward()

        # Check that n_step_forward and two successive single-step forwards
        # give the same results
        def no_mask_forward_twice():
            _, rs = rseq(transposed_x[0], None)
            return rseq(transposed_x[1], rs)
        nomask_out, nomask_rs = no_mask_forward_twice()
        xp.testing.assert_allclose(
            nstep_out.array[:, 1],
            nomask_out.array,
        )
        xp.testing.assert_allclose(nstep_rs[0].array, nomask_rs[0].array)

        # Mask only the 1st sample's state: only the 2nd output should match
        def mask0_forward_twice():
            _, rs = rseq(transposed_x[0], None)
            rs = rseq.mask_recurrent_state_at(rs, 0)
            return rseq(transposed_x[1], rs)
        mask0_out, mask0_rs = mask0_forward_twice()
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out.array[0, 1],
                mask0_out.array[0],
            )
        xp.testing.assert_allclose(
            nstep_out.array[1, 1],
            mask0_out.array[1],
        )

        # Mask only the 2nd sample's state: only the 1st output should match
        def mask1_forward_twice():
            _, rs = rseq(transposed_x[0], None)
            rs = rseq.mask_recurrent_state_at(rs, 1)
            return rseq(transposed_x[1], rs)
        mask1_out, mask1_rs = mask1_forward_twice()
        xp.testing.assert_allclose(
            nstep_out.array[0, 1],
            mask1_out.array[0],
        )
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out.array[1, 1],
                mask1_out.array[1],
            )

        # Mask both samples' states: both outputs should differ
        def mask01_forward_twice():
            _, rs = rseq(transposed_x[0], None)
            rs = rseq.mask_recurrent_state_at(rs, [0, 1])
            return rseq(transposed_x[1], rs)
        mask01_out, mask01_rs = mask01_forward_twice()
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out.array[0, 1],
                mask01_out.array[0],
            )
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out.array[1, 1],
                mask01_out.array[1],
            )

        # Get each sample's recurrent state, concatenate them, and resume
        # the forward pass
        def get_and_concat_rs_forward():
            _, rs = rseq(transposed_x[0], None)
            rs0 = rseq.get_recurrent_state_at(rs, 0, unwrap_variable=True)
            rs1 = rseq.get_recurrent_state_at(rs, 1, unwrap_variable=True)
            concat_rs = rseq.concatenate_recurrent_states([rs0, rs1])
            return rseq(transposed_x[1], concat_rs)
        getcon_out, getcon_rs = get_and_concat_rs_forward()
        xp.testing.assert_allclose(getcon_rs[0].array, nomask_rs[0].array)
        xp.testing.assert_allclose(
            nstep_out.array[0, 1], getcon_out.array[0])
        xp.testing.assert_allclose(
            nstep_out.array[1, 1], getcon_out.array[1])
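
In an agent, mask_recurrent_state_at is how the recurrent state of just-finished episodes is reset while other episodes in the batch keep theirs. A minimal sketch of that episode-boundary use, with an illustrative two-sequence batch:

# Sketch: reset only sample 0's recurrent state at an episode boundary,
# then resume stepping the whole batch.
import numpy as np
import chainer.links as L
from chainerrl.links import StatelessRecurrentSequential

rseq = StatelessRecurrentSequential(L.NStepGRU(1, 2, 4, 0))
xs = [np.random.uniform(-1, 1, (1, 2)).astype(np.float32)
      for _ in range(2)]
_, rs = rseq.n_step_forward(xs, None, output_mode='concat')
rs = rseq.mask_recurrent_state_at(rs, 0)   # episode 0 just ended
out, rs = rseq.n_step_forward(xs, rs, output_mode='concat')
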
    def _test_n_step_forward(self, gpu):
        in_size = 2
        out0_size = 3
        out1_size = 4
        out2_size = 1

        par = StatelessRecurrentBranched(
            L.NStepLSTM(1, in_size, out0_size, 0),
            StatelessRecurrentSequential(
                L.NStepRNNReLU(1, in_size, out1_size, 0)),
            StatelessRecurrentSequential(L.Linear(in_size, out2_size)),
        )

        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            par.to_gpu()
        xp = par.xp

        seqs_x = [
            xp.random.uniform(-1, 1, size=(1, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
        ]

        # Concatenated output should be a tuple of three variables.
        concat_out, concat_rs = par.n_step_forward(seqs_x,
                                                   None,
                                                   output_mode='concat')
        self.assertIsInstance(concat_out, tuple)
        self.assertEqual(len(concat_out), len(par))
        self.assertEqual(concat_out[0].shape, (4, out0_size))
        self.assertEqual(concat_out[1].shape, (4, out1_size))
        self.assertEqual(concat_out[2].shape, (4, out2_size))

        self.assertIsInstance(concat_rs, tuple)
        self.assertEqual(len(concat_rs), len(par))
        self.assertIsInstance(concat_rs[0], tuple)
        # NStepLSTM
        self.assertEqual(len(concat_rs[0]), 2)
        self.assertEqual(concat_rs[0][0].shape, (1, len(seqs_x), out0_size))
        self.assertEqual(concat_rs[0][1].shape, (1, len(seqs_x), out0_size))
        # StatelessRecurrentSequential(NStepRNNReLU)
        self.assertEqual(len(concat_rs[1]), 1)
        self.assertEqual(concat_rs[1][0].shape, (1, len(seqs_x), out1_size))
        # StatelessRecurrentSequential(Linear)
        self.assertEqual(len(concat_rs[2]), 0)

        # Split output should be a list of two tuples, each consisting of
        # three variables.
        split_out, split_rs = par.n_step_forward(seqs_x,
                                                 None,
                                                 output_mode='split')
        self.assertIsInstance(split_out, list)
        self.assertEqual(len(split_out), len(seqs_x))
        self.assertEqual(len(split_out[0]), len(par))
        self.assertEqual(len(split_out[1]), len(par))
        self.assertEqual(split_out[0][0].shape, (1, out0_size))
        self.assertEqual(split_out[0][1].shape, (1, out1_size))
        self.assertEqual(split_out[0][2].shape, (1, out2_size))
        self.assertEqual(split_out[1][0].shape, (3, out0_size))
        self.assertEqual(split_out[1][1].shape, (3, out1_size))
        self.assertEqual(split_out[1][2].shape, (3, out2_size))

        # Check if output_mode='concat' and output_mode='split' are consistent
        xp.testing.assert_allclose(
            F.concat([F.concat(seq_out, axis=1) for seq_out in split_out],
                     axis=0).array,
            F.concat(concat_out, axis=1).array,
        )
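
StatelessRecurrentBranched feeds the same input sequences to every child in parallel, so both the output and the recurrent state come back as per-child tuples. A minimal two-branch sketch:

# Sketch: two parallel branches over the same input; outputs and states
# are per-child tuples (the LSTM's state is itself an (h, c) pair).
import numpy as np
import chainer.links as L
from chainerrl.links import StatelessRecurrentBranched

par = StatelessRecurrentBranched(
    L.NStepGRU(1, 2, 3, 0),
    L.NStepLSTM(1, 2, 4, 0),
)
xs = [np.random.uniform(-1, 1, (2, 2)).astype(np.float32)]
(gru_out, lstm_out), (gru_h, lstm_hc) = par.n_step_forward(
    xs, None, output_mode='concat')
print(gru_out.shape, lstm_out.shape)  # (2, 3) (2, 4)
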
    def _test_mask_recurrent_state_at(self, gpu):
        in_size = 2
        out0_size = 2
        out1_size = 3
        par = StatelessRecurrentBranched(
            L.NStepGRU(1, in_size, out0_size, 0),
            StatelessRecurrentSequential(
                L.NStepLSTM(1, in_size, out1_size, 0)),
        )
        if gpu >= 0:
            chainer.cuda.get_device_from_id(gpu).use()
            par.to_gpu()
        xp = par.xp
        seqs_x = [
            xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
            xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
        ]
        transposed_x = F.transpose_sequence(seqs_x)

        nstep_out, nstep_rs = par.n_step_forward(seqs_x,
                                                 None,
                                                 output_mode='concat')

        # Check that n_step_forward and two successive single-step forwards
        # give the same results
        def no_mask_forward_twice():
            _, rs = par(transposed_x[0], None)
            return par(transposed_x[1], rs)

        nomask_out, nomask_rs = no_mask_forward_twice()
        # GRU
        xp.testing.assert_allclose(
            nstep_out[0].array[[1, 3]],
            nomask_out[0].array,
        )
        # LSTM
        xp.testing.assert_allclose(
            nstep_out[1].array[[1, 3]],
            nomask_out[1].array,
        )
        xp.testing.assert_allclose(nstep_rs[0].array, nomask_rs[0].array)
        self.assertIsInstance(nomask_rs[1], tuple)
        self.assertEqual(len(nomask_rs[1]), 1)
        self.assertEqual(len(nomask_rs[1][0]), 2)
        xp.testing.assert_allclose(nstep_rs[1][0][0].array,
                                   nomask_rs[1][0][0].array)
        xp.testing.assert_allclose(nstep_rs[1][0][1].array,
                                   nomask_rs[1][0][1].array)

        # Mask only the 1st sample's state: only the 2nd output should match
        def mask0_forward_twice():
            _, rs = par(transposed_x[0], None)
            rs = par.mask_recurrent_state_at(rs, 0)
            return par(transposed_x[1], rs)

        mask0_out, mask0_rs = mask0_forward_twice()
        # GRU
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[0].array[1],
                mask0_out[0].array[0],
            )
        xp.testing.assert_allclose(
            nstep_out[0].array[3],
            mask0_out[0].array[1],
        )
        # LSTM
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[1].array[1],
                mask0_out[1].array[0],
            )
        xp.testing.assert_allclose(
            nstep_out[1].array[3],
            mask0_out[1].array[1],
        )

        # Mask only the 2nd sample's state: only the 1st output should match
        def mask1_forward_twice():
            _, rs = par(transposed_x[0], None)
            rs = par.mask_recurrent_state_at(rs, 1)
            return par(transposed_x[1], rs)

        mask1_out, mask1_rs = mask1_forward_twice()
        # GRU
        xp.testing.assert_allclose(
            nstep_out[0].array[1],
            mask1_out[0].array[0],
        )
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[0].array[3],
                mask1_out[0].array[1],
            )
        # LSTM
        xp.testing.assert_allclose(
            nstep_out[1].array[1],
            mask1_out[1].array[0],
        )
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[1].array[3],
                mask1_out[1].array[1],
            )

        # Mask both samples' states: both outputs should differ
        def mask01_forward_twice():
            _, rs = par(transposed_x[0], None)
            rs = par.mask_recurrent_state_at(rs, [0, 1])
            return par(transposed_x[1], rs)

        mask01_out, mask01_rs = mask01_forward_twice()
        # GRU
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[0].array[1],
                mask01_out[0].array[0],
            )
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[0].array[3],
                mask01_out[0].array[1],
            )
        # LSTM
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[1].array[1],
                mask01_out[1].array[0],
            )
        with self.assertRaises(AssertionError):
            xp.testing.assert_allclose(
                nstep_out[1].array[3],
                mask01_out[1].array[1],
            )

        # Get each sample's recurrent state, concatenate them, and resume
        # the forward pass
        def get_and_concat_rs_forward():
            _, rs = par(transposed_x[0], None)
            rs0 = par.get_recurrent_state_at(rs, 0, unwrap_variable=True)
            rs1 = par.get_recurrent_state_at(rs, 1, unwrap_variable=True)
            concat_rs = par.concatenate_recurrent_states([rs0, rs1])
            return par(transposed_x[1], concat_rs)

        getcon_out, getcon_rs = get_and_concat_rs_forward()
        # GRU
        xp.testing.assert_allclose(
            nstep_out[0].array[1],
            getcon_out[0].array[0],
        )
        xp.testing.assert_allclose(
            nstep_out[0].array[3],
            getcon_out[0].array[1],
        )
        # LSTM
        xp.testing.assert_allclose(
            nstep_out[1].array[1],
            getcon_out[1].array[0],
        )
        xp.testing.assert_allclose(
            nstep_out[1].array[3],
            getcon_out[1].array[1],
        )
Example #10
    def test_recurrent_and_non_recurrent_equivalence(self):
        """Test equivalence between recurrent and non-recurrent datasets.

        When the same feed-forward model is used, the values of
        log_prob, v_pred, and next_v_pred obtained by the recurrent and
        non-recurrent dataset creation functions should be the same.
        """
        episodes = make_random_episodes()
        if self.use_obs_normalizer:
            obs_normalizer = chainerrl.links.EmpiricalNormalization(
                2, clip_threshold=5)
            obs_normalizer.experience(np.random.uniform(-1, 1, size=(10, 2)))
        else:
            obs_normalizer = None

        def phi(obs):
            return (obs * 0.5).astype(np.float32)

        obs_size = 2
        n_actions = 3

        non_recurrent_model = A3CSeparateModel(
            pi=chainerrl.policies.FCSoftmaxPolicy(obs_size, n_actions),
            v=L.Linear(obs_size, 1),
        )
        recurrent_model = StatelessRecurrentSequential(non_recurrent_model)
        xp = non_recurrent_model.xp

        dataset = chainerrl.agents.ppo._make_dataset(
            episodes=copy.deepcopy(episodes),
            model=non_recurrent_model,
            phi=phi,
            batch_states=batch_states,
            obs_normalizer=obs_normalizer,
            gamma=self.gamma,
            lambd=self.lambd,
        )

        dataset_recurrent = chainerrl.agents.ppo._make_dataset_recurrent(
            episodes=copy.deepcopy(episodes),
            model=recurrent_model,
            phi=phi,
            batch_states=batch_states,
            obs_normalizer=obs_normalizer,
            gamma=self.gamma,
            lambd=self.lambd,
            max_recurrent_sequence_len=self.max_recurrent_sequence_len,
        )

        self.assertTrue('log_prob' not in episodes[0][0])
        self.assertTrue('log_prob' in dataset[0])
        self.assertTrue('log_prob' in dataset_recurrent[0][0])
        # They are not just shallow copies
        self.assertTrue(
            dataset[0]['log_prob'] is not dataset_recurrent[0][0]['log_prob'])

        states = [tr['state'] for tr in dataset]
        recurrent_states = [
            tr['state']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(states, recurrent_states)

        actions = [tr['action'] for tr in dataset]
        recurrent_actions = [
            tr['action']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(actions, recurrent_actions)

        rewards = [tr['reward'] for tr in dataset]
        recurrent_rewards = [
            tr['reward']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(rewards, recurrent_rewards)

        nonterminals = [tr['nonterminal'] for tr in dataset]
        recurrent_nonterminals = [
            tr['nonterminal']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(nonterminals, recurrent_nonterminals)

        log_probs = [tr['log_prob'] for tr in dataset]
        recurrent_log_probs = [
            tr['log_prob']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(log_probs, recurrent_log_probs)

        vs_pred = [tr['v_pred'] for tr in dataset]
        recurrent_vs_pred = [
            tr['v_pred']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(vs_pred, recurrent_vs_pred)

        next_vs_pred = [tr['next_v_pred'] for tr in dataset]
        recurrent_next_vs_pred = [
            tr['next_v_pred']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(next_vs_pred, recurrent_next_vs_pred)

        advs = [tr['adv'] for tr in dataset]
        recurrent_advs = [
            tr['adv']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(advs, recurrent_advs)

        vs_teacher = [tr['v_teacher'] for tr in dataset]
        recurrent_vs_teacher = [
            tr['v_teacher']
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(vs_teacher, recurrent_vs_teacher)
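
The key trick in this test is that wrapping a feed-forward model in StatelessRecurrentSequential makes it usable wherever a recurrent model is expected, while its recurrent state stays empty (compare the len-0 state asserted for the Linear-only child in Example #1 above). A minimal sketch of that wrapping:

# Sketch: a feed-forward link wrapped as a (stateless) recurrent model.
# With no recurrent children, the returned state has length 0.
import numpy as np
import chainer.links as L
from chainerrl.links import StatelessRecurrentSequential

wrapped = StatelessRecurrentSequential(L.Linear(2, 1))
xs = [np.random.uniform(-1, 1, (3, 2)).astype(np.float32)]
out, rs = wrapped.n_step_forward(xs, None, output_mode='concat')
print(out.shape, len(rs))  # (3, 1) 0
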