def make_model(self, env):
    n_hidden_channels = 20
    obs_size = env.observation_space.low.size
    if self.recurrent:
        v = StatelessRecurrentSequential(
            L.NStepLSTM(1, obs_size, n_hidden_channels, 0),
            L.Linear(
                None, 1,
                initialW=chainer.initializers.LeCunNormal(1e-1)),
        )
        if self.discrete:
            n_actions = env.action_space.n
            pi = StatelessRecurrentSequential(
                L.NStepLSTM(1, obs_size, n_hidden_channels, 0),
                policies.FCSoftmaxPolicy(
                    n_hidden_channels, n_actions,
                    n_hidden_layers=0,
                    nonlinearity=F.tanh,
                    last_wscale=1e-1,
                )
            )
        else:
            action_size = env.action_space.low.size
            pi = StatelessRecurrentSequential(
                L.NStepLSTM(1, obs_size, n_hidden_channels, 0),
                policies.FCGaussianPolicy(
                    n_hidden_channels, action_size,
                    n_hidden_layers=0,
                    nonlinearity=F.tanh,
                    mean_wscale=1e-1,
                )
            )
        return StatelessRecurrentBranched(pi, v)
    else:
        v = chainer.Sequential(
            L.Linear(None, n_hidden_channels),
            F.tanh,
            L.Linear(
                None, 1,
                initialW=chainer.initializers.LeCunNormal(1e-1)),
        )
        if self.discrete:
            n_actions = env.action_space.n
            pi = policies.FCSoftmaxPolicy(
                obs_size, n_actions,
                n_hidden_layers=1,
                n_hidden_channels=n_hidden_channels,
                nonlinearity=F.tanh,
                last_wscale=1e-1,
            )
        else:
            action_size = env.action_space.low.size
            pi = policies.FCGaussianPolicy(
                obs_size, action_size,
                n_hidden_layers=1,
                n_hidden_channels=n_hidden_channels,
                nonlinearity=F.tanh,
                mean_wscale=1e-1,
            )
        return A3CSeparateModel(pi=pi, v=v)
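For context, a minimal sketch of how a StatelessRecurrentBranched model like the one returned above consumes a batch of observation sequences. The layer sizes and the chainerrl.links import path are illustrative assumptions, not taken from the test code.

import chainer.links as L
import numpy as np
from chainerrl.links import StatelessRecurrentBranched  # assumed import path

# Two recurrent branches fed the same input, mimicking the pi/v split
# above (plain NStepLSTMs stand in for the policy and value heads).
model = StatelessRecurrentBranched(
    L.NStepLSTM(1, 4, 2, 0),  # pi-like branch: 4-dim obs -> 2 outputs
    L.NStepLSTM(1, 4, 1, 0),  # v-like branch: 4-dim obs -> 1 output
)

# A batch of two variable-length sequences of 4-dim observations.
seqs = [
    np.random.uniform(-1, 1, size=(5, 4)).astype(np.float32),
    np.random.uniform(-1, 1, size=(2, 4)).astype(np.float32),
]

# Each branch yields its own output; recurrent states are kept per branch.
(pi_out, v_out), recurrent_state = model.n_step_forward(
    seqs, None, output_mode='concat')
assert pi_out.shape == (7, 2) and v_out.shape == (7, 1)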
def _test_three_recurrent_children(self, gpu):
    # Test if https://github.com/chainer/chainer/issues/6053 is addressed
    in_size = 2
    out_size = 6
    rseq = StatelessRecurrentSequential(
        L.NStepLSTM(1, in_size, 3, 0),
        L.NStepGRU(2, 3, 4, 0),
        L.NStepRNNTanh(5, 4, out_size, 0),
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        rseq.to_gpu()
    xp = rseq.xp
    seqs_x = [
        xp.random.uniform(-1, 1, size=(4, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(1, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
    ]
    # Make and load a recurrent state to check if the order is correct.
    _, rs = rseq.n_step_forward(seqs_x, None, output_mode='concat')
    _, _ = rseq.n_step_forward(seqs_x, rs, output_mode='concat')
    _, rs = rseq.n_step_forward(seqs_x, None, output_mode='split')
    _, _ = rseq.n_step_forward(seqs_x, rs, output_mode='split')
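The round trip this test exercises can be reproduced standalone. A minimal sketch, assuming StatelessRecurrentSequential is importable from chainerrl.links; as the assertions elsewhere in this file show, the recurrent state is a tuple with one entry per recurrent child, in child order.

import chainer.links as L
import numpy as np
from chainerrl.links import StatelessRecurrentSequential  # assumed import path

rseq = StatelessRecurrentSequential(
    L.NStepLSTM(1, 2, 3, 0),
    L.NStepGRU(2, 3, 4, 0),
    L.NStepRNNTanh(5, 4, 6, 0),
)
seqs = [np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32)]

# The state returned by one call can be fed straight into the next call.
_, rs = rseq.n_step_forward(seqs, None, output_mode='concat')
out, rs = rseq.n_step_forward(seqs, rs, output_mode='concat')
assert out.shape == (4, 6)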
def make_q_func(self, env):
    n_hidden_channels = 10
    return StatelessRecurrentSequential(
        L.Linear(env.observation_space.low.size, n_hidden_channels),
        F.elu,
        L.NStepRNNTanh(1, n_hidden_channels, n_hidden_channels, 0),
        L.Linear(n_hidden_channels, env.action_space.n),
        DiscreteActionValue,
    )
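To show how such a recurrent Q-function is evaluated, here is a hedged sketch with concrete sizes (4-dim observations, 3 actions) substituted for the env-derived ones; DiscreteActionValue, as used in chainerrl, wraps the final layer's Q-values.

import chainer.functions as F
import chainer.links as L
import numpy as np
from chainerrl.action_value import DiscreteActionValue
from chainerrl.links import StatelessRecurrentSequential  # assumed import path

obs_size, n_actions, n_hidden_channels = 4, 3, 10
q_func = StatelessRecurrentSequential(
    L.Linear(obs_size, n_hidden_channels),
    F.elu,
    L.NStepRNNTanh(1, n_hidden_channels, n_hidden_channels, 0),
    L.Linear(n_hidden_channels, n_actions),
    DiscreteActionValue,
)

# One episode of 5 observations; the output wraps per-step Q-values.
seqs = [np.random.uniform(-1, 1, size=(5, obs_size)).astype(np.float32)]
qvals, recurrent_state = q_func.n_step_forward(
    seqs, None, output_mode='concat')
assert isinstance(qvals, DiscreteActionValue)
assert qvals.greedy_actions.shape == (5,)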
def _test_n_step_forward_with_tuple_output(self, gpu):
    in_size = 5
    out_size = 6

    def split_output(x):
        return tuple(F.split_axis(x, [2, 3], axis=1))

    rseq = StatelessRecurrentSequential(
        L.NStepRNNTanh(1, in_size, out_size, 0),
        split_output,
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        rseq.to_gpu()
    xp = rseq.xp
    # Input is a list of two variables.
    seqs_x = [
        xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
    ]
    # Concatenated output should be a tuple of three variables.
    concat_out, concat_state = rseq.n_step_forward(
        seqs_x, None, output_mode='concat')
    self.assertIsInstance(concat_out, tuple)
    self.assertEqual(len(concat_out), 3)
    self.assertEqual(concat_out[0].shape, (5, 2))
    self.assertEqual(concat_out[1].shape, (5, 1))
    self.assertEqual(concat_out[2].shape, (5, 3))
    # Split output should be a list of two tuples, each of which consists
    # of three variables.
    split_out, split_state = rseq.n_step_forward(
        seqs_x, None, output_mode='split')
    self.assertIsInstance(split_out, list)
    self.assertEqual(len(split_out), 2)
    self.assertIsInstance(split_out[0], tuple)
    self.assertIsInstance(split_out[1], tuple)
    for seq_x, seq_out in zip(seqs_x, split_out):
        self.assertEqual(len(seq_out), 3)
        self.assertEqual(seq_out[0].shape, (len(seq_x), 2))
        self.assertEqual(seq_out[1].shape, (len(seq_x), 1))
        self.assertEqual(seq_out[2].shape, (len(seq_x), 3))
    # Check that output_mode='concat' and output_mode='split' are
    # consistent.
    xp.testing.assert_allclose(
        F.concat(
            [F.concat(seq_out, axis=1) for seq_out in split_out],
            axis=0).array,
        F.concat(concat_out, axis=1).array,
    )
def _test_n_step_forward_with_tuple_input(self, gpu):
    in_size = 5
    out_size = 3

    def concat_input(*args):
        return F.concat(args, axis=1)

    rseq = StatelessRecurrentSequential(
        concat_input,
        L.NStepRNNTanh(1, in_size, out_size, 0),
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        rseq.to_gpu()
    xp = rseq.xp
    # Input is a list of tuples; each tuple holds two variables.
    seqs_x = [
        (xp.random.uniform(-1, 1, size=(3, 2)).astype(np.float32),
         xp.random.uniform(-1, 1, size=(3, 3)).astype(np.float32)),
        (xp.random.uniform(-1, 1, size=(1, 2)).astype(np.float32),
         xp.random.uniform(-1, 1, size=(1, 3)).astype(np.float32)),
    ]
    # Concatenated output should be a variable.
    concat_out, concat_state = rseq.n_step_forward(
        seqs_x, None, output_mode='concat')
    self.assertEqual(concat_out.shape, (4, out_size))
    # Split output should be a list of variables.
    split_out, split_state = rseq.n_step_forward(
        seqs_x, None, output_mode='split')
    self.assertIsInstance(split_out, list)
    self.assertEqual(len(split_out), len(seqs_x))
    for seq_x, seq_out in zip(seqs_x, split_out):
        # seq_x is a tuple; its first element carries the sequence length.
        self.assertEqual(seq_out.shape, (len(seq_x[0]), out_size))
    # Check that output_mode='concat' and output_mode='split' are
    # consistent.
    xp.testing.assert_allclose(
        F.concat(split_out, axis=0).array,
        concat_out.array,
    )
def _test_n_step_forward(self, gpu):
    in_size = 2
    out_size = 6
    rseq = StatelessRecurrentSequential(
        L.Linear(in_size, 3),
        F.elu,
        L.NStepLSTM(1, 3, 4, 0),
        L.Linear(4, 5),
        L.NStepRNNTanh(1, 5, out_size, 0),
        F.tanh,
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        rseq.to_gpu()
    xp = rseq.xp
    linear1 = rseq._layers[0]
    lstm = rseq._layers[2]
    linear2 = rseq._layers[3]
    rnn = rseq._layers[4]
    seqs_x = [
        xp.random.uniform(-1, 1, size=(4, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(1, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
    ]

    concat_out, concat_state = rseq.n_step_forward(
        seqs_x, None, output_mode='concat')
    self.assertEqual(concat_out.shape, (8, out_size))

    split_out, split_state = rseq.n_step_forward(
        seqs_x, None, output_mode='split')
    self.assertIsInstance(split_out, list)
    self.assertEqual(len(split_out), len(seqs_x))
    for seq_x, seq_out in zip(seqs_x, split_out):
        self.assertEqual(seq_out.shape, (len(seq_x), out_size))

    # Check that output_mode='concat' and output_mode='split' are
    # consistent.
    xp.testing.assert_allclose(
        F.concat(split_out, axis=0).array,
        concat_out.array,
    )

    (concat_lstm_h, concat_lstm_c), concat_rnn_h = concat_state
    (split_lstm_h, split_lstm_c), split_rnn_h = split_state
    xp.testing.assert_allclose(concat_lstm_h.array, split_lstm_h.array)
    xp.testing.assert_allclose(concat_lstm_c.array, split_lstm_c.array)
    xp.testing.assert_allclose(concat_rnn_h.array, split_rnn_h.array)

    # Check that the output matches that of step-by-step execution.
    # _step_lstm and _step_rnn_tanh are single-step helpers defined
    # elsewhere in this file.
    def manual_n_step_forward(seqs_x):
        sorted_seqs_x = sorted(seqs_x, key=len, reverse=True)
        transposed_x = F.transpose_sequence(sorted_seqs_x)
        lstm_h = None
        lstm_c = None
        rnn_h = None
        ys = []
        for batch in transposed_x:
            if lstm_h is not None:
                lstm_h = lstm_h[:len(batch)]
                lstm_c = lstm_c[:len(batch)]
                rnn_h = rnn_h[:len(batch)]
            h = linear1(batch)
            h = F.elu(h)
            h, (lstm_h, lstm_c) = _step_lstm(lstm, h, (lstm_h, lstm_c))
            h = linear2(h)
            h, rnn_h = _step_rnn_tanh(rnn, h, rnn_h)
            y = F.tanh(h)
            ys.append(y)
        sorted_seqs_y = F.transpose_sequence(ys)
        # Undo the sort by length (original order: lengths 4, 1, 3).
        seqs_y = [sorted_seqs_y[0], sorted_seqs_y[2], sorted_seqs_y[1]]
        return seqs_y

    manual_split_out = manual_n_step_forward(seqs_x)
    for man_seq_out, seq_out in zip(manual_split_out, split_out):
        xp.testing.assert_allclose(
            man_seq_out.array, seq_out.array, rtol=1e-5)

    # Finally, check the gradient (wrt linear1.W).
    concat_grad, = chainer.grad([F.sum(concat_out)], [linear1.W])
    split_grad, = chainer.grad(
        [F.sum(F.concat(split_out, axis=0))], [linear1.W])
    manual_split_grad, = chainer.grad(
        [F.sum(F.concat(manual_split_out, axis=0))], [linear1.W])

    xp.testing.assert_allclose(
        concat_grad.array, split_grad.array, rtol=1e-5)
    xp.testing.assert_allclose(
        concat_grad.array, manual_split_grad.array, rtol=1e-5)
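The single-step helpers _step_lstm and _step_rnn_tanh used above are not part of this excerpt. For reference, a plausible sketch of the RNN one, assuming chainer's documented n_step_rnn parameterization h' = tanh(W0 x + b0 + W1 h + b1) and the ws/bs weight accessors of the NStep links:

import chainer.functions as F
import numpy as np

def _step_rnn_tanh(rnn, x, h):
    # One step of a one-layer L.NStepRNNTanh:
    #   h' = tanh(W0 x + b0 + W1 h + b1)
    w0, w1 = rnn.ws[0]
    b0, b1 = rnn.bs[0]
    if h is None:
        # Chainer initializes a missing hidden state to zeros.
        h = rnn.xp.zeros((len(x), w1.shape[0]), dtype=np.float32)
    new_h = F.tanh(F.linear(x, w0, b0) + F.linear(h, w1, b1))
    # For a vanilla RNN the output and the new hidden state coincide.
    return new_h, new_h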
def _test_mask_recurrent_state_at(self, gpu):
    in_size = 2
    out_size = 4
    rseq = StatelessRecurrentSequential(
        L.Linear(in_size, 3),
        F.elu,
        L.NStepGRU(1, 3, out_size, 0),
        F.softmax,
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        rseq.to_gpu()
    xp = rseq.xp
    seqs_x = [
        xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
    ]
    transposed_x = F.transpose_sequence(seqs_x)

    def no_mask_n_step_forward():
        nomask_nstep_out, nstep_rs = rseq.n_step_forward(
            seqs_x, None, output_mode='concat')
        return F.reshape(nomask_nstep_out, (2, 2, out_size)), nstep_rs

    nstep_out, nstep_rs = no_mask_n_step_forward()

    # Check that n_step_forward and two successive forward calls give the
    # same results.
    def no_mask_forward_twice():
        _, rs = rseq(transposed_x[0], None)
        return rseq(transposed_x[1], rs)

    nomask_out, nomask_rs = no_mask_forward_twice()
    xp.testing.assert_allclose(
        nstep_out.array[:, 1],
        nomask_out.array,
    )
    xp.testing.assert_allclose(nstep_rs[0].array, nomask_rs[0].array)

    # Mask only the 1st sequence: only the 2nd output should be the same.
    def mask0_forward_twice():
        _, rs = rseq(transposed_x[0], None)
        rs = rseq.mask_recurrent_state_at(rs, 0)
        return rseq(transposed_x[1], rs)

    mask0_out, mask0_rs = mask0_forward_twice()
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out.array[0, 1],
            mask0_out.array[0],
        )
    xp.testing.assert_allclose(
        nstep_out.array[1, 1],
        mask0_out.array[1],
    )

    # Mask only the 2nd sequence: only the 1st output should be the same.
    def mask1_forward_twice():
        _, rs = rseq(transposed_x[0], None)
        rs = rseq.mask_recurrent_state_at(rs, 1)
        return rseq(transposed_x[1], rs)

    mask1_out, mask1_rs = mask1_forward_twice()
    xp.testing.assert_allclose(
        nstep_out.array[0, 1],
        mask1_out.array[0],
    )
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out.array[1, 1],
            mask1_out.array[1],
        )

    # Mask both sequences: both outputs should differ.
    def mask01_forward_twice():
        _, rs = rseq(transposed_x[0], None)
        rs = rseq.mask_recurrent_state_at(rs, [0, 1])
        return rseq(transposed_x[1], rs)

    mask01_out, mask01_rs = mask01_forward_twice()
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out.array[0, 1],
            mask01_out.array[0],
        )
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out.array[1, 1],
            mask01_out.array[1],
        )

    # Get and concatenate recurrent states, then resume the forward pass.
    def get_and_concat_rs_forward():
        _, rs = rseq(transposed_x[0], None)
        rs0 = rseq.get_recurrent_state_at(rs, 0, unwrap_variable=True)
        rs1 = rseq.get_recurrent_state_at(rs, 1, unwrap_variable=True)
        concat_rs = rseq.concatenate_recurrent_states([rs0, rs1])
        return rseq(transposed_x[1], concat_rs)

    getcon_out, getcon_rs = get_and_concat_rs_forward()
    xp.testing.assert_allclose(getcon_rs[0].array, nomask_rs[0].array)
    xp.testing.assert_allclose(nstep_out.array[0, 1], getcon_out.array[0])
    xp.testing.assert_allclose(nstep_out.array[1, 1], getcon_out.array[1])
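This masking API is what recurrent agents use at episode boundaries: when environment i in a synchronous batch reports done, only its slice of the batched recurrent state is reset. A minimal sketch, under the same assumed import:

import chainer.links as L
import numpy as np
from chainerrl.links import StatelessRecurrentSequential  # assumed import path

rseq = StatelessRecurrentSequential(L.NStepGRU(1, 2, 4, 0))

# One synchronous step for a batch of two "environments".
batch_obs = np.random.uniform(-1, 1, size=(2, 2)).astype(np.float32)
_, rs = rseq(batch_obs, None)

# Suppose environment 0 just finished its episode: zero only its state,
# then keep stepping the whole batch.
rs = rseq.mask_recurrent_state_at(rs, 0)
_, rs = rseq(batch_obs, rs)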
def _test_n_step_forward(self, gpu):
    in_size = 2
    out0_size = 3
    out1_size = 4
    out2_size = 1
    par = StatelessRecurrentBranched(
        L.NStepLSTM(1, in_size, out0_size, 0),
        StatelessRecurrentSequential(
            L.NStepRNNReLU(1, in_size, out1_size, 0),
        ),
        StatelessRecurrentSequential(
            L.Linear(in_size, out2_size),
        ),
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        par.to_gpu()
    xp = par.xp
    seqs_x = [
        xp.random.uniform(-1, 1, size=(1, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(3, in_size)).astype(np.float32),
    ]

    # Concatenated output should be a tuple of three variables.
    concat_out, concat_rs = par.n_step_forward(
        seqs_x, None, output_mode='concat')
    self.assertIsInstance(concat_out, tuple)
    self.assertEqual(len(concat_out), len(par))
    self.assertEqual(concat_out[0].shape, (4, out0_size))
    self.assertEqual(concat_out[1].shape, (4, out1_size))
    self.assertEqual(concat_out[2].shape, (4, out2_size))

    self.assertIsInstance(concat_rs, tuple)
    self.assertEqual(len(concat_rs), len(par))
    # NStepLSTM
    self.assertIsInstance(concat_rs[0], tuple)
    self.assertEqual(len(concat_rs[0]), 2)
    self.assertEqual(concat_rs[0][0].shape, (1, len(seqs_x), out0_size))
    self.assertEqual(concat_rs[0][1].shape, (1, len(seqs_x), out0_size))
    # StatelessRecurrentSequential(NStepRNNReLU)
    self.assertEqual(len(concat_rs[1]), 1)
    self.assertEqual(concat_rs[1][0].shape, (1, len(seqs_x), out1_size))
    # StatelessRecurrentSequential(Linear)
    self.assertEqual(len(concat_rs[2]), 0)

    # Split output should be a list of two tuples, each of which consists
    # of three variables.
    split_out, split_rs = par.n_step_forward(
        seqs_x, None, output_mode='split')
    self.assertIsInstance(split_out, list)
    self.assertEqual(len(split_out), len(seqs_x))
    self.assertEqual(len(split_out[0]), len(par))
    self.assertEqual(len(split_out[1]), len(par))
    self.assertEqual(split_out[0][0].shape, (1, out0_size))
    self.assertEqual(split_out[0][1].shape, (1, out1_size))
    self.assertEqual(split_out[0][2].shape, (1, out2_size))
    self.assertEqual(split_out[1][0].shape, (3, out0_size))
    self.assertEqual(split_out[1][1].shape, (3, out1_size))
    self.assertEqual(split_out[1][2].shape, (3, out2_size))

    # Check that output_mode='concat' and output_mode='split' are
    # consistent.
    xp.testing.assert_allclose(
        F.concat(
            [F.concat(seq_out, axis=1) for seq_out in split_out],
            axis=0).array,
        F.concat(concat_out, axis=1).array,
    )
def _test_mask_recurrent_state_at(self, gpu):
    in_size = 2
    out0_size = 2
    out1_size = 3
    par = StatelessRecurrentBranched(
        L.NStepGRU(1, in_size, out0_size, 0),
        StatelessRecurrentSequential(
            L.NStepLSTM(1, in_size, out1_size, 0),
        ),
    )
    if gpu >= 0:
        chainer.cuda.get_device_from_id(gpu).use()
        par.to_gpu()
    xp = par.xp
    seqs_x = [
        xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
        xp.random.uniform(-1, 1, size=(2, in_size)).astype(np.float32),
    ]
    transposed_x = F.transpose_sequence(seqs_x)

    nstep_out, nstep_rs = par.n_step_forward(
        seqs_x, None, output_mode='concat')

    # Check that n_step_forward and two successive forward calls give the
    # same results.
    def no_mask_forward_twice():
        _, rs = par(transposed_x[0], None)
        return par(transposed_x[1], rs)

    nomask_out, nomask_rs = no_mask_forward_twice()
    # GRU
    xp.testing.assert_allclose(
        nstep_out[0].array[[1, 3]],
        nomask_out[0].array,
    )
    # LSTM
    xp.testing.assert_allclose(
        nstep_out[1].array[[1, 3]],
        nomask_out[1].array,
    )
    xp.testing.assert_allclose(nstep_rs[0].array, nomask_rs[0].array)
    self.assertIsInstance(nomask_rs[1], tuple)
    self.assertEqual(len(nomask_rs[1]), 1)
    self.assertEqual(len(nomask_rs[1][0]), 2)
    xp.testing.assert_allclose(
        nstep_rs[1][0][0].array, nomask_rs[1][0][0].array)
    xp.testing.assert_allclose(
        nstep_rs[1][0][1].array, nomask_rs[1][0][1].array)

    # Mask only the 1st sequence: only the 2nd output should be the same.
    def mask0_forward_twice():
        _, rs = par(transposed_x[0], None)
        rs = par.mask_recurrent_state_at(rs, 0)
        return par(transposed_x[1], rs)

    mask0_out, mask0_rs = mask0_forward_twice()
    # GRU
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[0].array[1],
            mask0_out[0].array[0],
        )
    xp.testing.assert_allclose(
        nstep_out[0].array[3],
        mask0_out[0].array[1],
    )
    # LSTM
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[1].array[1],
            mask0_out[1].array[0],
        )
    xp.testing.assert_allclose(
        nstep_out[1].array[3],
        mask0_out[1].array[1],
    )

    # Mask only the 2nd sequence: only the 1st output should be the same.
    def mask1_forward_twice():
        _, rs = par(transposed_x[0], None)
        rs = par.mask_recurrent_state_at(rs, 1)
        return par(transposed_x[1], rs)

    mask1_out, mask1_rs = mask1_forward_twice()
    # GRU
    xp.testing.assert_allclose(
        nstep_out[0].array[1],
        mask1_out[0].array[0],
    )
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[0].array[3],
            mask1_out[0].array[1],
        )
    # LSTM
    xp.testing.assert_allclose(
        nstep_out[1].array[1],
        mask1_out[1].array[0],
    )
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[1].array[3],
            mask1_out[1].array[1],
        )

    # Mask both sequences: both outputs should differ.
    def mask01_forward_twice():
        _, rs = par(transposed_x[0], None)
        rs = par.mask_recurrent_state_at(rs, [0, 1])
        return par(transposed_x[1], rs)

    mask01_out, mask01_rs = mask01_forward_twice()
    # GRU
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[0].array[1],
            mask01_out[0].array[0],
        )
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[0].array[3],
            mask01_out[0].array[1],
        )
    # LSTM
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[1].array[1],
            mask01_out[1].array[0],
        )
    with self.assertRaises(AssertionError):
        xp.testing.assert_allclose(
            nstep_out[1].array[3],
            mask01_out[1].array[1],
        )

    # Get and concatenate recurrent states, then resume the forward pass.
    def get_and_concat_rs_forward():
        _, rs = par(transposed_x[0], None)
        rs0 = par.get_recurrent_state_at(rs, 0, unwrap_variable=True)
        rs1 = par.get_recurrent_state_at(rs, 1, unwrap_variable=True)
        concat_rs = par.concatenate_recurrent_states([rs0, rs1])
        return par(transposed_x[1], concat_rs)

    getcon_out, getcon_rs = get_and_concat_rs_forward()
    # GRU
    xp.testing.assert_allclose(
        nstep_out[0].array[1],
        getcon_out[0].array[0],
    )
    xp.testing.assert_allclose(
        nstep_out[0].array[3],
        getcon_out[0].array[1],
    )
    # LSTM
    xp.testing.assert_allclose(
        nstep_out[1].array[1],
        getcon_out[1].array[0],
    )
    xp.testing.assert_allclose(
        nstep_out[1].array[3],
        getcon_out[1].array[1],
    )
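get_recurrent_state_at and concatenate_recurrent_states form the inverse pair used when sequences join or leave a batch, and the same protocol works for branched and sequential containers alike. A minimal sketch with a single GRU, under the same assumed import:

import chainer.links as L
import numpy as np
from chainerrl.links import StatelessRecurrentSequential  # assumed import path

rseq = StatelessRecurrentSequential(L.NStepGRU(1, 2, 3, 0))
batch_obs = np.random.uniform(-1, 1, size=(2, 2)).astype(np.float32)
_, rs = rseq(batch_obs, None)

# Pull out per-sequence states, then rebuild the batched state and
# resume stepping as if nothing happened.
rs0 = rseq.get_recurrent_state_at(rs, 0, unwrap_variable=True)
rs1 = rseq.get_recurrent_state_at(rs, 1, unwrap_variable=True)
batched_rs = rseq.concatenate_recurrent_states([rs0, rs1])
_, _ = rseq(batch_obs, batched_rs)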
def test_recurrent_and_non_recurrent_equivalence(self):
    """Test equivalence between recurrent and non-recurrent datasets.

    When the same feed-forward model is used, the values of log_prob,
    v_pred, and next_v_pred obtained by the recurrent and non-recurrent
    dataset creation functions should be the same.
    """
    episodes = make_random_episodes()
    if self.use_obs_normalizer:
        obs_normalizer = chainerrl.links.EmpiricalNormalization(
            2, clip_threshold=5)
        obs_normalizer.experience(np.random.uniform(-1, 1, size=(10, 2)))
    else:
        obs_normalizer = None

    def phi(obs):
        return (obs * 0.5).astype(np.float32)

    obs_size = 2
    n_actions = 3
    non_recurrent_model = A3CSeparateModel(
        pi=chainerrl.policies.FCSoftmaxPolicy(obs_size, n_actions),
        v=L.Linear(obs_size, 1),
    )
    recurrent_model = StatelessRecurrentSequential(non_recurrent_model)
    xp = non_recurrent_model.xp

    dataset = chainerrl.agents.ppo._make_dataset(
        episodes=copy.deepcopy(episodes),
        model=non_recurrent_model,
        phi=phi,
        batch_states=batch_states,
        obs_normalizer=obs_normalizer,
        gamma=self.gamma,
        lambd=self.lambd,
    )

    dataset_recurrent = chainerrl.agents.ppo._make_dataset_recurrent(
        episodes=copy.deepcopy(episodes),
        model=recurrent_model,
        phi=phi,
        batch_states=batch_states,
        obs_normalizer=obs_normalizer,
        gamma=self.gamma,
        lambd=self.lambd,
        max_recurrent_sequence_len=self.max_recurrent_sequence_len,
    )

    self.assertTrue('log_prob' not in episodes[0][0])
    self.assertTrue('log_prob' in dataset[0])
    self.assertTrue('log_prob' in dataset_recurrent[0][0])
    # They are not just shallow copies.
    self.assertTrue(
        dataset[0]['log_prob'] is not dataset_recurrent[0][0]['log_prob'])

    # Every stored quantity should agree between the two datasets.
    keys = [
        'state', 'action', 'reward', 'nonterminal', 'log_prob',
        'v_pred', 'next_v_pred', 'adv', 'v_teacher',
    ]
    for key in keys:
        values = [tr[key] for tr in dataset]
        recurrent_values = [
            tr[key]
            for tr in itertools.chain.from_iterable(dataset_recurrent)
        ]
        xp.testing.assert_allclose(values, recurrent_values)
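The trick that makes this comparison possible is that wrapping a feed-forward model in StatelessRecurrentSequential gives it the recurrent interface with an empty state, so both dataset builders run the very same parameters. A minimal sketch of the wrapping, under the same assumed import:

import chainer.links as L
import numpy as np
from chainerrl.links import StatelessRecurrentSequential  # assumed import path

# A purely feed-forward link acquires n_step_forward when wrapped.
wrapped = StatelessRecurrentSequential(L.Linear(2, 3))
seqs = [np.random.uniform(-1, 1, size=(4, 2)).astype(np.float32)]
out, rs = wrapped.n_step_forward(seqs, None, output_mode='concat')
assert out.shape == (4, 3)
assert len(rs) == 0  # no recurrent children, hence an empty state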