Code Example #1
    def test_shapes_are_correct_env_with_continuous_action_spaces_vector_sample(
            self, _, n_envs, n_steps):
        env_name = "MountainCarContinuous-v0"
        env = SyncVectorEnv(
            [lambda: gym.make(env_name) for _ in range(n_envs)])
        observation_shape = (n_envs, ) + env.envs[0].observation_space.shape
        action_shape = (n_envs, ) + env.envs[0].action_space.shape

        loop = EnvironmentLoop(env, self._create_continuouse_policy(env))
        batch = loop.sample()
        for _ in range(1, n_steps):
            batch = loop.sample()

        self._assert_has_shapes(
            batch,
            expected={
                SampleBatch.OBSERVATIONS: observation_shape,
                SampleBatch.OBSERVATION_NEXTS: observation_shape,
                SampleBatch.ACTIONS: action_shape,
            },
            default=(n_envs, ),
        )
        self._assert_has_dtype(
            batch,
            expected={
                SampleBatch.EPS_ID: torch.int64,
            },
            default=torch.float32,
        )
Code Example #2
    def test_shapes_are_correct_env_with_continuous_obs_and_discrete_action_spaces_vector(
            self, _, n_envs, n_steps):

        env = SyncVectorEnv(
            [lambda: gym.make("CartPole-v0") for _ in range(n_envs)])
        observation_shape = env.observation_space.shape

        loop = EnvironmentLoop(env, self._create_discrite_policy(env))
        batch = loop.step()
        for _ in range(1, n_steps):
            batch = loop.step()

        self._assert_has_shapes(
            batch,
            expected={
                SampleBatch.OBSERVATIONS: observation_shape,
                SampleBatch.OBSERVATION_NEXTS: observation_shape,
            },
            default=(n_envs, ),
        )
        self._assert_has_dtype(
            batch,
            expected={
                SampleBatch.ACTIONS: torch.int64,
                SampleBatch.EPS_ID: torch.int64,
            },
            default=torch.float32,
        )
Code Example #3
File: train_dqn.py  Project: j0rd1smit/lighting-RL
def main():
    env_name = "CartPole-v1"
    env = gym.make(env_name)

    n_observations = env.observation_space.shape[0]
    n_actions = env.action_space.n
    model = DQNModel(
        n_observations=n_observations,
        n_actions=n_actions,
    )

    data_module, callbacks = off_policy_dataset(
        lambda: gym.make(env_name),
        model.select_online_actions,
        capacity=100_000,
        n_populate_steps=10_000,
        steps_per_epoch=1000,
        batch_size=32,
    )

    env_loop = EnvironmentLoop(env, model.select_actions)
    eval_callback = EnvironmentEvaluationCallback(env_loop)

    trainer = pl.Trainer(
        gpus=1,
        fast_dev_run=False,
        max_epochs=100,
        callbacks=callbacks + [eval_callback],
        logger=False,
        checkpoint_callback=False,
    )

    trainer.fit(model, data_module)
Code Example #4
File: train_pg.py  Project: j0rd1smit/lighting-RL
def main():
    env_name = "CartPole-v1"
    # env_name = "MountainCar-v0"
    env = gym.make(env_name)

    n_observations = env.observation_space.shape[0]
    n_actions = env.action_space.n
    model = PGModel(
        n_observations=n_observations,
        n_actions=n_actions,
    )

    data_module, callbacks = on_policy_dataset(
        lambda: gym.make(env_name),
        model.select_online_actions,
        batch_size=4000,
        steps_per_epoch=25,
        n_envs=10,
        post_process_function=model.post_process_function,
        fetch_agent_info=model.agent_info,
    )

    env_loop = EnvironmentLoop(env, model.select_actions)
    eval_callback = EnvironmentEvaluationCallback(env_loop)

    trainer = pl.Trainer(
        gpus=1,
        fast_dev_run=False,
        max_epochs=100,
        callbacks=callbacks + [eval_callback],
        logger=False,
        checkpoint_callback=False,
    )

    trainer.fit(model, data_module)
Code Example #5
    def test_shapes_are_correct_env_with_discrete_obs_and_action_spaces_sample(
            self, _, n_steps):
        env = gym.make("Taxi-v3")

        loop = EnvironmentLoop(env, self._create_discrite_policy(env))
        batch = loop.sample()
        for _ in range(1, n_steps):
            batch = loop.sample()

        self._assert_has_shapes(batch, default=(1, ))
        self._assert_has_dtype(
            batch,
            expected={
                SampleBatch.REWARDS: torch.float32,
                SampleBatch.DONES: torch.float32,
            },
            default=torch.int64,
        )
Code Example #6
    def _create_env_loop(self, env_name, n_envs=None, fetch_agent_info=None):
        if n_envs is None:
            env = gym.make(env_name)
        else:
            env = SyncVectorEnv(
                [lambda: gym.make(env_name) for _ in range(n_envs)])

        if env_name == "MountainCarContinuous-v0":
            return EnvironmentLoop(env,
                                   self._create_continuouse_policy(env),
                                   fetch_agent_info=fetch_agent_info)

        if env_name == "Taxi-v3" or "CartPole" in env_name:
            return EnvironmentLoop(env,
                                   self._create_discrite_policy(env),
                                   fetch_agent_info=fetch_agent_info)

        raise RuntimeError("Unknown env", env_name)
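For orientation, here is a hedged sketch of how a helper like this could be exercised from one of the tests shown earlier. The test name and the parameter values are hypothetical; only `_create_env_loop`, `sample()`, and `_assert_has_shapes` come from the examples above.

    # Hypothetical test method (name and values assumed, not from the source project):
    def test_sample_has_batch_dim_for_vector_cartpole(self):
        loop = self._create_env_loop("CartPole-v0", n_envs=4)
        batch = loop.sample()  # collect one batch of transitions from the vectorized env
        self._assert_has_shapes(batch, default=(4, ))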
Code Example #7
def main():
    env_name = "Taxi-v3"
    env = gym.make(env_name)

    n_observations = env.observation_space.n if not isinstance(
        env, VectorEnv) else env.envs[0].observation_space.n
    n_actions = env.action_space.n if not isinstance(
        env, VectorEnv) else env.action_space[0].n
    model = QTableModel(
        n_observations=n_observations,
        n_actions=n_actions,
    )
    """
    data_module, callbacks = off_policy_dataset(
        lambda: gym.make(env_name),
        model.select_online_actions,
        capacity=100,
        n_populate_steps=100,
        steps_per_epoch=250,
    )
    """
    data_module, callbacks = on_policy_dataset(
        lambda: gym.make(env_name),
        model.select_online_actions,
        batch_size=32,
        steps_per_epoch=250,
    )

    env_loop = EnvironmentLoop(env, model.select_actions)
    eval_callback = EnvironmentEvaluationCallback(env_loop)

    trainer = pl.Trainer(
        gpus=0,
        fast_dev_run=False,
        max_epochs=50,
        callbacks=callbacks + [eval_callback],
        logger=False,
        checkpoint_callback=False,
    )

    trainer.fit(model, data_module)
Code Example #8
File: builders.py  Project: j0rd1smit/lighting-RL
def on_policy_dataset(
    env_builder: EnvBuilder,
    select_online_actions: Policy,
    fetch_agent_info: Optional[FetchAgentInfo] = None,
    # batch
    batch_size: int = 4000,
    # online callback
    n_envs: int = 10,
    steps_per_epoch: int = 5000,
    # post processing
    post_process_function: Optional[PostProcessFunction] = None,
) -> Tuple[OnlineDataModule, List[Callback]]:
    buffer = UniformReplayBuffer(batch_size)

    samples_per_epoch = steps_per_epoch * batch_size
    sampler = EntireBufferSampler(buffer, samples_per_epoch)

    data_module = OnlineDataModule(buffer,
                                   batch_size,
                                   sampler=sampler,
                                   pin_memory=True,
                                   n_workers=0)

    online_env = _build_env(env_builder, n_envs)

    n_samples_per_step = batch_size
    env_loop = EnvironmentLoop(online_env,
                               select_online_actions,
                               fetch_agent_info=fetch_agent_info)

    online_step_callback = OnlineDataCollectionCallback(
        buffer,
        env_loop,
        n_samples_per_step=n_samples_per_step,
        n_populate_steps=0,
        post_process_function=post_process_function,
        clear_buffer_before_gather=True,
    )

    return data_module, [online_step_callback]
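The tuple returned here is meant to be handed straight to a PyTorch Lightning Trainer, as Example #4 does. A minimal wiring sketch, assuming a model like the PGModel from that example (hyperparameters copied from there, trainer options trimmed):

# Sketch only: `model` is assumed to expose select_online_actions, as in Example #4.
data_module, callbacks = on_policy_dataset(
    lambda: gym.make("CartPole-v1"),
    model.select_online_actions,
    batch_size=4000,
    steps_per_epoch=25,
    n_envs=10,
)
trainer = pl.Trainer(max_epochs=100, callbacks=callbacks, logger=False)
trainer.fit(model, data_module)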
Code Example #9
File: builders.py  Project: j0rd1smit/lighting-RL
def eval_callback(
    env_builder: EnvBuilder,
    select_actions: Policy,
    seed: Optional[int] = None,
    n_envs: int = 1,
    n_eval_episodes: int = 10,
    n_test_episodes: int = 100,
    to_eval: bool = False,
    logging_prefix: str = "Evaluation",
    mean_return_in_progress_bar: bool = True,
) -> EnvironmentEvaluationCallback:
    env = _build_env(env_builder, n_envs)

    env_loop = EnvironmentLoop(env, select_actions)
    return EnvironmentEvaluationCallback(
        env_loop,
        n_eval_episodes=n_eval_episodes,
        n_test_episodes=n_test_episodes,
        to_eval=to_eval,
        seed=seed,
        logging_prefix=logging_prefix,
        mean_return_in_progress_bar=mean_return_in_progress_bar,
    )
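This builder packages the manual EnvironmentLoop plus EnvironmentEvaluationCallback construction seen in Examples #3 and #4. A hedged sketch of plugging its result into the trainer's callback list; the environment name and the surrounding `model` and `callbacks` variables are assumptions:

# Sketch: append the evaluation callback to the data-collection callbacks
# returned by one of the dataset builders (env name and counts assumed).
evaluation = eval_callback(
    lambda: gym.make("CartPole-v1"),
    model.select_actions,
    n_eval_episodes=10,
)
trainer = pl.Trainer(max_epochs=100, callbacks=callbacks + [evaluation], logger=False)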
Code Example #10
File: builders.py  Project: j0rd1smit/lighting-RL
    buffer = UniformReplayBuffer(capacity)

    samples_per_epoch = steps_per_epoch * batch_size
    sampler = UniformSampler(buffer, samples_per_epoch)

    data_module = OnlineDataModule(buffer,
                                   batch_size,
                                   sampler=sampler,
                                   pin_memory=True,
                                   n_workers=0)

    online_env = _build_env(env_builder, n_envs)

    n_samples_per_step = batch_size
    env_loop = EnvironmentLoop(online_env,
                               select_online_actions,
                               fetch_agent_info=fetch_agent_info)

    online_step_callback = OnlineDataCollectionCallback(
        buffer,
        env_loop,
        n_samples_per_step=n_samples_per_step,
        n_populate_steps=n_populate_steps,
        post_process_function=post_process_function,
        clear_buffer_before_gather=False,
    )

    return data_module, [online_step_callback]


def eval_callback(