def test_recurrent_and_non_recurrent_equivalence(self):
    """Test equivalence between recurrent and non-recurrent datasets.

    When the same feed-forward model is used, the values of log_prob,
    v_pred, next_v_pred obtained by both recurrent and non-recurrent
    dataset creation functions should be the same.

    The recurrent dataset is indexed as a list of transition sequences
    (``dataset_recurrent[i][j]``); after flattening it, every compared
    per-transition field must match the flat (non-recurrent) dataset
    element by element.
    """
    episodes = make_random_episodes()
    if self.use_obs_normalizer:
        obs_normalizer = chainerrl.links.EmpiricalNormalization(
            2, clip_threshold=5)
        obs_normalizer.experience(
            np.random.uniform(-1, 1, size=(10, 2)))
    else:
        obs_normalizer = None

    def phi(obs):
        return (obs * 0.5).astype(np.float32)

    obs_size = 2
    n_actions = 3
    non_recurrent_model = A3CSeparateModel(
        pi=chainerrl.policies.FCSoftmaxPolicy(obs_size, n_actions),
        v=L.Linear(obs_size, 1),
    )
    # Wrap the very same feed-forward model so that both dataset builders
    # evaluate identical parameters; any difference must then come from
    # the dataset construction logic itself.
    recurrent_model = StatelessRecurrentSequential(
        non_recurrent_model,
    )
    xp = non_recurrent_model.xp

    # Each builder receives its own deep copy of the episodes so that the
    # two results (and the caller's `episodes`) are independent objects.
    dataset = chainerrl.agents.ppo._make_dataset(
        episodes=copy.deepcopy(episodes),
        model=non_recurrent_model,
        phi=phi,
        batch_states=batch_states,
        obs_normalizer=obs_normalizer,
        gamma=self.gamma,
        lambd=self.lambd,
    )
    dataset_recurrent = chainerrl.agents.ppo._make_dataset_recurrent(
        episodes=copy.deepcopy(episodes),
        model=recurrent_model,
        phi=phi,
        batch_states=batch_states,
        obs_normalizer=obs_normalizer,
        gamma=self.gamma,
        lambd=self.lambd,
        max_recurrent_sequence_len=self.max_recurrent_sequence_len,
    )

    # The caller's episodes were only handed over as deep copies, so no
    # 'log_prob' field should have leaked into them; both datasets must
    # have gained that field.
    self.assertNotIn('log_prob', episodes[0][0])
    self.assertIn('log_prob', dataset[0])
    self.assertIn('log_prob', dataset_recurrent[0][0])
    # They are not just shallow copies
    self.assertIsNot(dataset[0]['log_prob'],
                     dataset_recurrent[0][0]['log_prob'])

    # Flatten the recurrent sequences once, then compare field by field.
    # err_msg names the field, so a failure still says which key diverged.
    flat_recurrent = list(itertools.chain.from_iterable(dataset_recurrent))
    compared_keys = (
        'state',
        'action',
        'reward',
        'nonterminal',
        'log_prob',
        'v_pred',
        'next_v_pred',
        'adv',
        'v_teacher',
    )
    for key in compared_keys:
        values = [tr[key] for tr in dataset]
        recurrent_values = [tr[key] for tr in flat_recurrent]
        xp.testing.assert_allclose(values, recurrent_values, err_msg=key)