Пример #1
0
    def test_visitations(self):
        """Check per-state visitation of the greedy policy on a tiny Cliffwalk.

        Soft Q-iteration is run with zero entropy weight on a deterministic
        3-state Cliffwalk. The resulting Q-values are fed to
        ``compute_visitation`` for several episode time limits, and the
        visitation summed over axis 1 is compared with hand-computed values.
        """
        env = tabular_env.CliffwalkEnv(num_states=3, transition_noise=0.00)
        params = {
            'num_itrs': 50,
            'ent_wt': 0.0,
            'discount': 0.99,
        }
        qvals_py = q_iteration_py.softq_iteration(env, **params)

        # (env_time_limit, expected per-state visitation) pairs.
        cases = [
            (1, np.array([1, 0, 0])),
            (3, np.array([1, 1, 1]) / 3.0),
            (5, np.array([2, 2, 1]) / 5.0),
        ]
        for time_limit, tru_visits in cases:
            # subTest reports every failing time limit, not just the first.
            with self.subTest(env_time_limit=time_limit):
                visitations = q_iteration_py.compute_visitation(
                    env, qvals_py, ent_wt=0.0, env_time_limit=time_limit)
                # Summing over axis 1 leaves one entry per state (compared
                # against 3-element expectations for this 3-state env).
                s_visitations = np.sum(visitations, axis=1)
                # Same rtol/atol as np.allclose's defaults, but on failure
                # the message shows both arrays and the mismatching entries.
                np.testing.assert_allclose(
                    s_visitations, tru_visits, rtol=1e-5, atol=1e-8)
Пример #2
0
 def test_qevaluation_ent(self):
     """Evaluate a fixed Q-table on a deterministic 2-state Cliffwalk.

     Every state gives action 1 a Q-value of 1e10 (all others stay 0), and the
     table is evaluated with a small entropy weight. The returned value is
     compared against a hand-computed constant.
     """
     env = tabular_env.CliffwalkEnv(num_states=2, transition_noise=0.00)
     q_values = np.zeros((env.num_states, env.num_actions))
     # Presumably makes the entropy-regularised policy pick action 1 almost
     # surely in every state.
     q_values[:, 1] = 1e10
     eval_kwargs = dict(num_itrs=100, ent_wt=0.001, discount=0.5)
     returns, _unused = q_iteration.softq_evaluation(env, q_values,
                                                     **eval_kwargs)
     self.assertAlmostEqual(returns, 0.66666666)
Пример #3
0
    def setUp(self):
        """Build the wrapped Cliffwalk env, a linear Q-network and the algorithm kwargs."""
        base_env = tabular_env.CliffwalkEnv(10)
        self.env_tab = base_env
        # Random-observation wrapper (second arg 8, presumably the observation
        # dimension), then a time-limit wrapper (second arg 50).
        self.env_obs = random_obs_wrapper.RandomObsWrapper(base_env, 8)
        self.env = time_limit_wrapper.TimeLimitWrapper(self.env_obs, 50)

        net = q_networks.LinearNetwork(self.env)
        self.network = net
        ptu.initialize_network(net)

        self.alg_args = dict(
            min_project_steps=10,
            max_project_steps=20,
            lr=5e-3,
            discount=0.95,
            n_steps=1,
            num_samples=32,
            stop_modes=(stopping.AtolStop(), stopping.RtolStop()),
            backup_mode='sampling',
            ent_wt=0.01,
        )
Пример #4
0
 def setUp(self):
     """Wrap a Cliffwalk env in StochasticActionWrapper with eps=0.1."""
     base = tabular_env.CliffwalkEnv(10)
     self.env = env_wrapper.StochasticActionWrapper(base, eps=0.1)
Пример #5
0
 def setUp(self):
     """Create a Cliffwalk env and an AbsorbingStateWrapper copy (reward 123.0)."""
     plain = tabular_env.CliffwalkEnv(10)
     absorbing = env_wrapper.AbsorbingStateWrapper(plain, absorb_reward=123.0)
     self.env = plain
     self.absorb_env = absorbing
     # Last state index of each env: the unwrapped one and the wrapped one.
     self.reward_state = plain.num_states - 1
     self.absorb_state = absorbing.num_states - 1
Пример #6
0
def get_env(name):
    """Construct one of the named benchmark environments.

    Each recognised id maps to a zero-argument builder defined below. The ids
    are checked in order with ``==``, and the first match's builder runs.

    Args:
      name: environment id; one of 'grid16randomobs', 'grid16onehot',
        'grid16sparse', 'grid64randomobs', 'grid64onehot', 'cliffwalk',
        'pendulum', 'mountaincar', 'sparsegraph'.

    Returns:
      The constructed (and wrapped) environment.

    Raises:
      NotImplementedError: if ``name`` matches none of the recognised ids.
    """

    def grid16_randomobs():
        return random_grid_env(16, 16, dim_obs=16, time_limit=50,
                               wall_ratio=0.2, smooth_obs=False, seed=0)

    def grid16_onehot():
        return random_grid_env(16, 16, time_limit=50, wall_ratio=0.2,
                               one_hot_obs=True, seed=0)

    def grid16_sparse():
        return random_grid_env(16, 16, time_limit=50, wall_ratio=0.2,
                               one_hot_obs=True, seed=0,
                               distance_reward=False)

    def grid64_randomobs():
        return random_grid_env(64, 64, dim_obs=64, time_limit=100,
                               wall_ratio=0.2, smooth_obs=False, seed=0)

    def grid64_onehot():
        return random_grid_env(64, 64, time_limit=100, wall_ratio=0.2,
                               one_hot_obs=True, seed=0)

    def cliffwalk():
        # Built inside math_utils.np_seed(0), presumably for reproducible
        # numpy randomness.
        with math_utils.np_seed(0):
            base = tabular_env.CliffwalkEnv(25)
            # Cliffwalk is unsolvable by QI with moderate entropy - up the
            # reward to reduce the effects.
            absorbing = env_wrapper.AbsorbingStateWrapper(base,
                                                          absorb_reward=10.0)
            return wrap_obs_time(absorbing, dim_obs=16, time_limit=50)

    def pendulum():
        base = tabular_env.InvertedPendulum(state_discretization=32,
                                            action_discretization=5)
        return wrap_time(base, time_limit=50)

    def mountaincar():
        base = tabular_env.MountainCar(posdisc=56, veldisc=32)
        # MountainCar is unsolvable by QI with moderate entropy - up the
        # reward to reduce the effects.
        absorbing = env_wrapper.AbsorbingStateWrapper(base, absorb_reward=10.0)
        return wrap_time(absorbing, time_limit=100)

    def sparsegraph():
        with math_utils.np_seed(0):
            base = tabular_env.RandomTabularEnv(num_states=500,
                                                num_actions=3,
                                                transitions_per_action=1,
                                                self_loop=True)
            absorbing = env_wrapper.AbsorbingStateWrapper(base,
                                                          absorb_reward=10.0)
            return wrap_obs_time(absorbing, dim_obs=4, time_limit=10)

    # Ordered (id, builder) pairs; a linear scan keeps the exact `name == id`
    # comparison order of an if/elif chain (and works for unhashable names).
    builders = (
        ('grid16randomobs', grid16_randomobs),
        ('grid16onehot', grid16_onehot),
        ('grid16sparse', grid16_sparse),
        ('grid64randomobs', grid64_randomobs),
        ('grid64onehot', grid64_onehot),
        ('cliffwalk', cliffwalk),
        ('pendulum', pendulum),
        ('mountaincar', mountaincar),
        ('sparsegraph', sparsegraph),
    )
    for env_id, build in builders:
        if name == env_id:
            return build()
    raise NotImplementedError('Unknown env id: %s' % name)
 def setUp(self):
     """Create the Cliffwalk env (constructor argument 10) used by each test."""
     cliff = tabular_env.CliffwalkEnv(10)
     self.env = cliff
Пример #8
0
 def setUp(self):
     """Create a 3-state Cliffwalk env with transition_noise=0.01."""
     noisy_env = tabular_env.CliffwalkEnv(num_states=3, transition_noise=0.01)
     self.env = noisy_env
Пример #9
0
 def setUp(self):
     """Create a Cliffwalk env wrapped in TimeLimitWrapper with limit ``self.T``."""
     self.T = 12
     base = tabular_env.CliffwalkEnv(10)
     self.env = wrappers.TimeLimitWrapper(base, time_limit=self.T)