예제 #1
0
    def test_get_action(self):
        """Check the values returned by ``get_action`` after a reset.

        Verifies that the returned agent info carries a ``'latent'`` entry,
        and that the returned latent has shape ``(dim,)`` with entries that
        sum to 1.
        """
        dim = 3
        # NOTE(review): env_spec is never handed to the sampler; the
        # construction is kept in case MockEnvSpec() has side effects the
        # test relies on -- confirm and remove if not.
        env_spec = MockEnvSpec()
        sampler = UniformlyRandomLatentSampler(
            scheduler=ConstantIntervalScheduler(), name='test', dim=dim)

        sampler.reset([True])
        obs = [[0, 1]]
        latent, agent_info = sampler.get_action(obs)
        # assertIn reports the actual container on failure, unlike
        # assertTrue(x in y).
        self.assertIn('latent', agent_info)

        sampler.reset([True])
        obs = [[0, 0, 1]]
        latent, agent_info = sampler.get_action(obs)
        # Tie the expected shape to ``dim`` rather than repeating the literal.
        self.assertEqual(latent.shape, (dim, ))
        self.assertEqual(sum(latent), 1)
예제 #2
0
    def test_reset(self):
        """Check ``sampler.latent_values`` after ``reset`` for one and two envs.

        Verifies that ``latent_values`` has one row per entry of ``dones``,
        that every row sums to 1, and that two consecutive full resets yield
        different argmax positions in the returned actions.
        """
        # single env
        dim = 3
        # NOTE(review): env_spec is never handed to the sampler; the
        # construction is kept in case MockEnvSpec() has side effects the
        # test relies on -- confirm and remove if not.
        env_spec = MockEnvSpec()
        sampler = UniformlyRandomLatentSampler(
            scheduler=ConstantIntervalScheduler(), name='test', dim=dim)
        dones = [True]
        sampler.reset(dones)
        # The returned values are not inspected here; the unpacking is kept
        # so the call is still required to return a 2-tuple.
        _, _ = sampler.get_action(None)
        self.assertEqual(sampler.latent_values.shape, (1, dim))
        # Compare the row sums as an array instead of relying on the
        # truthiness of a one-element boolean array.
        np.testing.assert_array_equal(np.sum(sampler.latent_values, axis=1),
                                      [1])

        # multi env
        env_spec = MockEnvSpec(num_envs=2)
        dim = 100
        sampler = UniformlyRandomLatentSampler(
            scheduler=ConstantIntervalScheduler(), name='test', dim=dim)
        dones = [True, True]
        sampler.reset(dones)

        self.assertEqual(sampler.latent_values.shape, (2, dim))

        actions_1, _ = sampler.get_actions([None] * 2)
        sampler.reset(dones)

        actions_2, _ = sampler.get_actions([None] * 2)
        self.assertEqual(sampler.latent_values.shape, (2, dim))
        # NOTE(review): presumably the latents are drawn at random, so this
        # inequality could fail by chance if both envs re-draw the same
        # indices -- confirm the flakiness risk is acceptable.
        self.assertNotEqual(tuple(np.argmax(actions_1, axis=1)),
                            tuple(np.argmax(actions_2, axis=1)))

        # Only the second env is flagged done; both rows must still sum to 1.
        dones = [False, True]
        sampler.reset(dones)
        np.testing.assert_array_equal(np.sum(sampler.latent_values, axis=1),
                                      [1, 1])