def test_shapes_are_correct_env_with_continuous_action_spaces_vector_sample(
        self, _, n_envs, n_steps):
    env_name = "MountainCarContinuous-v0"
    env = SyncVectorEnv(
        [lambda: gym.make(env_name) for _ in range(n_envs)])
    observation_shape = (n_envs, ) + env.envs[0].observation_space.shape
    action_shape = (n_envs, ) + env.envs[0].action_space.shape

    loop = EnvironmentLoop(env, self._create_continuouse_policy(env))
    batch = loop.sample()
    for _ in range(1, n_steps):
        batch = loop.sample()

    self._assert_has_shapes(
        batch,
        expected={
            SampleBatch.OBSERVATIONS: observation_shape,
            SampleBatch.OBSERVATION_NEXTS: observation_shape,
            SampleBatch.ACTIONS: action_shape,
        },
        default=(n_envs, ),
    )
    self._assert_has_dtype(
        batch,
        expected={
            SampleBatch.EPS_ID: torch.int64,
        },
        default=torch.float32,
    )
def test_shapes_are_correct_env_with_continuous_obs_and_discrete_action_spaces_vector(
        self, _, n_envs, n_steps):
    env = SyncVectorEnv(
        [lambda: gym.make("CartPole-v0") for _ in range(n_envs)])
    observation_shape = env.observation_space.shape

    loop = EnvironmentLoop(env, self._create_discrite_policy(env))
    batch = loop.step()
    for _ in range(1, n_steps):
        batch = loop.step()

    self._assert_has_shapes(
        batch,
        expected={
            SampleBatch.OBSERVATIONS: observation_shape,
            SampleBatch.OBSERVATION_NEXTS: observation_shape,
        },
        default=(n_envs, ),
    )
    self._assert_has_dtype(
        batch,
        expected={
            SampleBatch.ACTIONS: torch.int64,
            SampleBatch.EPS_ID: torch.int64,
        },
        default=torch.float32,
    )
def main(): env_name = "CartPole-v1" env = gym.make(env_name) n_observations = env.observation_space.shape[0] n_actions = env.action_space.n model = DQNModel( n_observations=n_observations, n_actions=n_actions, ) data_module, callbacks = off_policy_dataset( lambda: gym.make(env_name), model.select_online_actions, capacity=100_000, n_populate_steps=10_000, steps_per_epoch=1000, batch_size=32, ) env_loop = EnvironmentLoop(env, model.select_actions) eval_callback = EnvironmentEvaluationCallback(env_loop) trainer = pl.Trainer( gpus=1, fast_dev_run=False, max_epochs=100, callbacks=callbacks + [eval_callback], logger=False, checkpoint_callback=False, ) trainer.fit(model, data_module)
def main(): env_name = "CartPole-v1" # env_name = "MountainCar-v0" env = gym.make(env_name) n_observations = env.observation_space.shape[0] n_actions = env.action_space.n model = PGModel( n_observations=n_observations, n_actions=n_actions, ) data_module, callbacks = on_policy_dataset( lambda: gym.make(env_name), model.select_online_actions, batch_size=4000, steps_per_epoch=25, n_envs=10, post_process_function=model.post_process_function, fetch_agent_info=model.agent_info, ) env_loop = EnvironmentLoop(env, model.select_actions) eval_callback = EnvironmentEvaluationCallback(env_loop) trainer = pl.Trainer( gpus=1, fast_dev_run=False, max_epochs=100, callbacks=callbacks + [eval_callback], logger=False, checkpoint_callback=False, ) trainer.fit(model, data_module)
def test_shapes_are_correct_env_with_discrete_obs_and_action_spaces_sample(
        self, _, n_steps):
    env = gym.make("Taxi-v3")

    loop = EnvironmentLoop(env, self._create_discrite_policy(env))
    batch = loop.sample()
    for _ in range(1, n_steps):
        batch = loop.sample()

    self._assert_has_shapes(batch, default=(1, ))
    self._assert_has_dtype(
        batch,
        expected={
            SampleBatch.REWARDS: torch.float32,
            SampleBatch.DONES: torch.float32,
        },
        default=torch.int64,
    )
def _create_env_loop(self, env_name, n_envs=None, fetch_agent_info=None):
    if n_envs is None:
        env = gym.make(env_name)
    else:
        env = SyncVectorEnv(
            [lambda: gym.make(env_name) for _ in range(n_envs)])

    if env_name == "MountainCarContinuous-v0":
        return EnvironmentLoop(env,
                               self._create_continuouse_policy(env),
                               fetch_agent_info=fetch_agent_info)
    if env_name == "Taxi-v3" or "CartPole" in env_name:
        return EnvironmentLoop(env,
                               self._create_discrite_policy(env),
                               fetch_agent_info=fetch_agent_info)
    raise RuntimeError("Unknown env", env_name)
def main(): env_name = "Taxi-v3" env = gym.make(env_name) n_observations = env.observation_space.n if not isinstance( env, VectorEnv) else env.envs[0].observation_space.n n_actions = env.action_space.n if not isinstance( env, VectorEnv) else env.action_space[0].n model = QTableModel( n_observations=n_observations, n_actions=n_actions, ) """ data_module, callbacks = off_policy_dataset( lambda: gym.make(env_name), model.select_online_actions, capacity=100, n_populate_steps=100, steps_per_epoch=250, ) """ data_module, callbacks = on_policy_dataset( lambda: gym.make(env_name), model.select_online_actions, batch_size=32, steps_per_epoch=250, ) env_loop = EnvironmentLoop(env, model.select_actions) eval_callback = EnvironmentEvaluationCallback(env_loop) trainer = pl.Trainer( gpus=0, fast_dev_run=False, max_epochs=50, callbacks=callbacks + [eval_callback], logger=False, checkpoint_callback=False, ) trainer.fit(model, data_module)
def on_policy_dataset(
    env_builder: EnvBuilder,
    select_online_actions: Policy,
    fetch_agent_info: Optional[FetchAgentInfo] = None,
    # batch
    batch_size: int = 4000,
    # online callback
    n_envs: int = 10,
    steps_per_epoch: int = 5000,
    # post processing
    post_process_function: Optional[PostProcessFunction] = None,
) -> Tuple[OnlineDataModule, List[Callback]]:
    buffer = UniformReplayBuffer(batch_size)
    samples_per_epoch = steps_per_epoch * batch_size
    sampler = EntireBufferSampler(buffer, samples_per_epoch)
    data_module = OnlineDataModule(buffer,
                                   batch_size,
                                   sampler=sampler,
                                   pin_memory=True,
                                   n_workers=0)

    online_env = _build_env(env_builder, n_envs)
    n_samples_per_step = batch_size
    env_loop = EnvironmentLoop(online_env,
                               select_online_actions,
                               fetch_agent_info=fetch_agent_info)
    online_step_callback = OnlineDataCollectionCallback(
        buffer,
        env_loop,
        n_samples_per_step=n_samples_per_step,
        n_populate_steps=0,
        post_process_function=post_process_function,
        clear_buffer_before_gather=True,
    )
    return data_module, [online_step_callback]
def eval_callback(
    env_builder: EnvBuilder,
    select_actions: Policy,
    seed: Optional[int] = None,
    n_envs: int = 1,
    n_eval_episodes: int = 10,
    n_test_episodes: int = 100,
    to_eval: bool = False,
    logging_prefix: str = "Evaluation",
    mean_return_in_progress_bar: bool = True,
) -> EnvironmentEvaluationCallback:
    env = _build_env(env_builder, n_envs)
    env_loop = EnvironmentLoop(env, select_actions)
    return EnvironmentEvaluationCallback(
        env_loop,
        n_eval_episodes=n_eval_episodes,
        n_test_episodes=n_test_episodes,
        to_eval=to_eval,
        seed=seed,
        logging_prefix=logging_prefix,
        mean_return_in_progress_bar=mean_return_in_progress_bar,
    )
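
# Usage sketch (an illustrative addition, not from the original snippets; the
# model and env name are assumptions): eval_callback bundles the
# EnvironmentLoop / EnvironmentEvaluationCallback wiring that the example
# main() scripts above perform by hand.
def _eval_callback_usage_example(model) -> EnvironmentEvaluationCallback:
    return eval_callback(
        lambda: gym.make("CartPole-v1"),
        model.select_actions,
        n_eval_episodes=10,
        n_test_episodes=100,
    )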
    buffer = UniformReplayBuffer(capacity)
    samples_per_epoch = steps_per_epoch * batch_size
    sampler = UniformSampler(buffer, samples_per_epoch)
    data_module = OnlineDataModule(buffer,
                                   batch_size,
                                   sampler=sampler,
                                   pin_memory=True,
                                   n_workers=0)

    online_env = _build_env(env_builder, n_envs)
    n_samples_per_step = batch_size
    env_loop = EnvironmentLoop(online_env,
                               select_online_actions,
                               fetch_agent_info=fetch_agent_info)
    online_step_callback = OnlineDataCollectionCallback(
        buffer,
        env_loop,
        n_samples_per_step=n_samples_per_step,
        n_populate_steps=n_populate_steps,
        post_process_function=post_process_function,
        clear_buffer_before_gather=False,
    )
    return data_module, [online_step_callback]


def eval_callback(