示例#1
0
    def test_reset(self):
        """Check that ``reset`` populates ``latent_values`` as expected.

        Covers four situations: a single-env reset, a two-env reset, a
        repeated full reset (which is expected to change the returned
        actions), and a partial reset where only one env is flagged done.
        """

        # single env
        dim = 3
        # NOTE(review): env_spec is never read in this test -- confirm whether
        # constructing it is actually required.
        env_spec = MockEnvSpec()
        sampler = UniformlyRandomLatentSampler(
            scheduler=ConstantIntervalScheduler(), name='test', dim=dim)
        dones = [True]
        sampler.reset(dones)
        # The result is not inspected; this only checks that the call succeeds
        # and returns a pair.
        action, _ = sampler.get_action(None)
        # assertEqual (rather than assertTrue(a == b)) reports both shapes on
        # failure; the expected width is tied to ``dim`` instead of a literal.
        self.assertEqual(sampler.latent_values.shape, (1, dim))
        # Every row must sum to exactly 1 (presumably a one-hot latent --
        # TODO confirm against the sampler implementation).
        np.testing.assert_array_equal(
            np.sum(sampler.latent_values, axis=1), [1])

        # multi env
        env_spec = MockEnvSpec(num_envs=2)
        dim = 100
        sampler = UniformlyRandomLatentSampler(
            scheduler=ConstantIntervalScheduler(), name='test', dim=dim)
        dones = [True, True]
        sampler.reset(dones)

        self.assertEqual(sampler.latent_values.shape, (2, dim))

        actions_1, _ = sampler.get_actions([None] * 2)
        sampler.reset(dones)

        actions_2, _ = sampler.get_actions([None] * 2)
        self.assertEqual(sampler.latent_values.shape, (2, dim))
        # A second full reset is expected to yield different actions.
        # NOTE(review): presumably the resampling is random, so an identical
        # draw is possible but unlikely with dim = 100 -- confirm, and consider
        # seeding if this test ever flakes.
        self.assertNotEqual(tuple(np.argmax(actions_1, axis=1)),
                            tuple(np.argmax(actions_2, axis=1)))

        # partial reset: only the second env is flagged done; every row must
        # still sum to 1 afterwards.
        dones = [False, True]
        sampler.reset(dones)
        np.testing.assert_array_equal(np.sum(sampler.latent_values, axis=1),
                                      [1, 1])
示例#2
0
    def test_get_actions(self):
        """Verify ``get_actions`` returns a latent array with one row per env."""
        num_envs, dim = 5, 2
        env_spec = MockEnvSpec(num_envs=num_envs)
        sampler = UniformlyRandomLatentSampler(
            scheduler=ConstantIntervalScheduler(),
            name='test',
            dim=dim,
        )
        # Flag every env as done so the reset applies to all of them.
        all_done = [True] * num_envs
        sampler.reset(all_done)

        # scalar observations case
        observations = np.zeros((env_spec.num_envs, 3))
        latent, _agent_info = sampler.get_actions(observations)
        self.assertEqual(latent.shape, (env_spec.num_envs, dim))