def test_visitations(self):
    """Check per-state visitation frequencies on a noiseless 3-state Cliffwalk.

    Soft Q-iteration is run with zero entropy weight, then
    ``compute_visitation`` is queried for several episode time limits.  The
    result is summed over axis 1 and compared against hand-computed
    per-state frequencies for each time limit.
    """
    env = tabular_env.CliffwalkEnv(num_states=3, transition_noise=0.00)
    params = {
        'num_itrs': 50,
        'ent_wt': 0.0,
        'discount': 0.99,
    }
    qvals_py = q_iteration_py.softq_iteration(env, **params)

    # (env_time_limit, expected per-state visitation frequency)
    cases = (
        (1, np.array([1, 0, 0])),
        (3, np.array([1, 1, 1]) / 3.0),
        (5, np.array([2, 2, 1]) / 5.0),
    )
    for time_limit, tru_visits in cases:
        visitations = q_iteration_py.compute_visitation(
            env, qvals_py, ent_wt=0.0, env_time_limit=time_limit)
        # Summing over axis 1 (presumably the action axis -- confirm) leaves
        # one visitation frequency per state.
        s_visitations = np.sum(visitations, axis=1)
        # Unlike assertTrue(np.allclose(...)), assert_allclose reports the
        # mismatching values on failure.  The tolerances are np.allclose's
        # defaults, so the pass/fail criterion is unchanged.
        np.testing.assert_allclose(
            s_visitations, tru_visits, rtol=1e-5, atol=1e-8,
            err_msg='env_time_limit=%d' % time_limit)
def test_qevaluation_ent(self):
    """Evaluate a policy that strongly prefers action 1 on a 2-state Cliffwalk.

    Every state gets a Q-value of 1e10 for action 1 and 0 for the others.
    ``softq_evaluation`` is run with a small entropy weight, and the returned
    value is expected to be 2/3.
    """
    env = tabular_env.CliffwalkEnv(num_states=2, transition_noise=0.00)
    params = {
        'num_itrs': 100,
        'ent_wt': 0.001,
        'discount': 0.5,
    }
    q_values = np.zeros((env.num_states, env.num_actions))
    # A huge Q-margin presumably makes the soft policy put virtually all of
    # its probability on action 1.
    q_values[:, 1] = 1e10
    returns, _ = q_iteration.softq_evaluation(env, q_values, **params)
    # Exact 2/3 rather than the truncated literal 0.66666666, whose own
    # truncation error ate into assertAlmostEqual's 7-place tolerance.
    self.assertAlmostEqual(returns, 2.0 / 3.0)
def setUp(self):
    """Build the wrapped env, an initialized linear Q-network and default alg args."""
    # Wrapper stack, innermost first: tabular Cliffwalk -> random observations
    # -> time limit.
    self.env_tab = tabular_env.CliffwalkEnv(10)
    self.env_obs = random_obs_wrapper.RandomObsWrapper(self.env_tab, 8)
    self.env = time_limit_wrapper.TimeLimitWrapper(self.env_obs, 50)

    linear_net = q_networks.LinearNetwork(self.env)
    ptu.initialize_network(linear_net)
    self.network = linear_net

    stop_modes = (stopping.AtolStop(), stopping.RtolStop())
    self.alg_args = dict(
        min_project_steps=10,
        max_project_steps=20,
        lr=5e-3,
        discount=0.95,
        n_steps=1,
        num_samples=32,
        stop_modes=stop_modes,
        backup_mode='sampling',
        ent_wt=0.01,
    )
def setUp(self):
    """Provide ``self.env``: CliffwalkEnv(10) wrapped in StochasticActionWrapper(eps=0.1)."""
    base_env = tabular_env.CliffwalkEnv(10)
    self.env = env_wrapper.StochasticActionWrapper(base_env, eps=0.1)
def setUp(self):
    """Build a CliffwalkEnv(10), its AbsorbingStateWrapper, and both last-state indices."""
    base = tabular_env.CliffwalkEnv(10)
    absorbing = env_wrapper.AbsorbingStateWrapper(base, absorb_reward=123.0)
    self.env = base
    self.absorb_env = absorbing
    # Index of the last state in the unwrapped env and in the wrapped env.
    self.reward_state = base.num_states - 1
    self.absorb_state = absorbing.num_states - 1
def get_env(name):
    """Construct the benchmark environment registered under ``name``.

    Args:
        name: Environment id. Recognised ids are 'grid16randomobs',
            'grid16onehot', 'grid16sparse', 'grid64randomobs', 'grid64onehot',
            'cliffwalk', 'pendulum', 'mountaincar' and 'sparsegraph'.

    Returns:
        The fully wrapped environment for ``name``.

    Raises:
        NotImplementedError: If ``name`` matches none of the ids above.
    """

    def _cliffwalk():
        # NOTE(review): the original indentation was lost; the wrappers are
        # assumed to be built inside the seeded block as well -- confirm.
        with math_utils.np_seed(0):
            env = tabular_env.CliffwalkEnv(25)
            # Cliffwalk is unsolvable by QI with moderate entropy - up the reward to reduce the effects.
            env = env_wrapper.AbsorbingStateWrapper(env, absorb_reward=10.0)
            env = wrap_obs_time(env, dim_obs=16, time_limit=50)
        return env

    def _pendulum():
        env = tabular_env.InvertedPendulum(state_discretization=32,
                                           action_discretization=5)
        return wrap_time(env, time_limit=50)

    def _mountaincar():
        env = tabular_env.MountainCar(posdisc=56, veldisc=32)
        # MountainCar is unsolvable by QI with moderate entropy - up the reward to reduce the effects.
        env = env_wrapper.AbsorbingStateWrapper(env, absorb_reward=10.0)
        return wrap_time(env, time_limit=100)

    def _sparsegraph():
        with math_utils.np_seed(0):
            env = tabular_env.RandomTabularEnv(num_states=500, num_actions=3,
                                               transitions_per_action=1,
                                               self_loop=True)
            env = env_wrapper.AbsorbingStateWrapper(env, absorb_reward=10.0)
            env = wrap_obs_time(env, dim_obs=4, time_limit=10)
        return env

    # Ordered (id, builder) pairs, compared with ``name == id`` in the same
    # order as the original if/elif chain, so matching behaves identically
    # (including for unhashable ``name`` values).
    builders = (
        ('grid16randomobs',
         lambda: random_grid_env(16, 16, dim_obs=16, time_limit=50,
                                 wall_ratio=0.2, smooth_obs=False, seed=0)),
        ('grid16onehot',
         lambda: random_grid_env(16, 16, time_limit=50, wall_ratio=0.2,
                                 one_hot_obs=True, seed=0)),
        ('grid16sparse',
         lambda: random_grid_env(16, 16, time_limit=50, wall_ratio=0.2,
                                 one_hot_obs=True, seed=0,
                                 distance_reward=False)),
        ('grid64randomobs',
         lambda: random_grid_env(64, 64, dim_obs=64, time_limit=100,
                                 wall_ratio=0.2, smooth_obs=False, seed=0)),
        ('grid64onehot',
         lambda: random_grid_env(64, 64, time_limit=100, wall_ratio=0.2,
                                 one_hot_obs=True, seed=0)),
        ('cliffwalk', _cliffwalk),
        ('pendulum', _pendulum),
        ('mountaincar', _mountaincar),
        ('sparsegraph', _sparsegraph),
    )
    for env_id, build in builders:
        if name == env_id:
            return build()
    raise NotImplementedError('Unknown env id: %s' % name)
def setUp(self):
    """Give each test a fresh ``CliffwalkEnv(10)`` as ``self.env``."""
    fresh_env = tabular_env.CliffwalkEnv(10)
    self.env = fresh_env
def setUp(self):
    """Create a 3-state CliffwalkEnv with transition_noise=0.01 as ``self.env``."""
    env_kwargs = {'num_states': 3, 'transition_noise': 0.01}
    self.env = tabular_env.CliffwalkEnv(**env_kwargs)
def setUp(self):
    """Wrap a CliffwalkEnv(10) in a TimeLimitWrapper whose limit is stored in ``self.T``."""
    self.T = 12
    self.env = wrappers.TimeLimitWrapper(tabular_env.CliffwalkEnv(10),
                                         time_limit=self.T)