Example #1
0
    def test_recurrent_and_non_recurrent_equivalence(self):
        """Test equivalence between recurrent and non-recurrent datasets.

        When the same feed-forward model is used, the values of
        log_prob, v_pred, next_v_pred obtained by both recurrent and
        non-recurrent dataset creation functions should be the same.

        The recurrent dataset is nested (a list of transition sequences),
        so it is flattened once and then compared field by field against
        the flat non-recurrent dataset.
        """
        episodes = make_random_episodes()
        if self.use_obs_normalizer:
            obs_normalizer = chainerrl.links.EmpiricalNormalization(
                2, clip_threshold=5)
            obs_normalizer.experience(
                np.random.uniform(-1, 1, size=(10, 2)))
        else:
            obs_normalizer = None

        def phi(obs):
            return (obs * 0.5).astype(np.float32)

        obs_size = 2
        n_actions = 3

        non_recurrent_model = A3CSeparateModel(
            pi=chainerrl.policies.FCSoftmaxPolicy(obs_size, n_actions),
            v=L.Linear(obs_size, 1),
        )
        # The recurrent model wraps the very same feed-forward model, so
        # both dataset builders evaluate identical parameters; any
        # difference in outputs must come from dataset construction.
        recurrent_model = StatelessRecurrentSequential(
            non_recurrent_model,
        )
        xp = non_recurrent_model.xp

        # Each builder receives its own deep copy so that the original
        # `episodes` stays unmodified (asserted below).
        dataset = chainerrl.agents.ppo._make_dataset(
            episodes=copy.deepcopy(episodes),
            model=non_recurrent_model,
            phi=phi,
            batch_states=batch_states,
            obs_normalizer=obs_normalizer,
            gamma=self.gamma,
            lambd=self.lambd,
        )

        dataset_recurrent = chainerrl.agents.ppo._make_dataset_recurrent(
            episodes=copy.deepcopy(episodes),
            model=recurrent_model,
            phi=phi,
            batch_states=batch_states,
            obs_normalizer=obs_normalizer,
            gamma=self.gamma,
            lambd=self.lambd,
            max_recurrent_sequence_len=self.max_recurrent_sequence_len,
        )

        self.assertNotIn('log_prob', episodes[0][0])
        self.assertIn('log_prob', dataset[0])
        self.assertIn('log_prob', dataset_recurrent[0][0])
        # They are not just shallow copies
        self.assertIsNot(dataset[0]['log_prob'],
                         dataset_recurrent[0][0]['log_prob'])

        # Flatten the nested recurrent dataset once and reuse it for
        # every field comparison.
        flat_recurrent = list(
            itertools.chain.from_iterable(dataset_recurrent))
        self.assertEqual(len(dataset), len(flat_recurrent))

        compared_keys = (
            'state',
            'action',
            'reward',
            'nonterminal',
            'log_prob',
            'v_pred',
            'next_v_pred',
            'adv',
            'v_teacher',
        )
        for key in compared_keys:
            values = [tr[key] for tr in dataset]
            recurrent_values = [tr[key] for tr in flat_recurrent]
            # err_msg identifies which field diverged on failure.
            xp.testing.assert_allclose(
                values, recurrent_values,
                err_msg='Mismatch in field {!r}'.format(key))