def _create_replay_buffer_and_insert(env: EnvWrapper):
    """Fill a fresh replay buffer with a short random-action rollout.

    The environment is seeded with 1 and reset, then stepped with actions
    sampled from ``env.action_space`` until it reports a terminal state or
    five steps have been taken, whichever happens first. Each step is
    wrapped in a ``Transition`` and handed to the inserter built by
    ``make_replay_buffer_inserter(env)``.

    Args:
        env: The environment to roll out. It must expose ``seed``,
            ``reset``, ``step`` and ``action_space.sample``.

    Returns:
        A ``(replay_buffer, inserted)`` tuple, where ``inserted`` is a list
        of dicts (keys ``observation``, ``action``, ``reward``,
        ``terminal``) mirroring what was pushed, in insertion order.
    """
    max_steps = 5

    env.seed(1)
    buffer = ReplayBuffer(replay_capacity=6, batch_size=1)
    insert = make_replay_buffer_inserter(env)

    recorded = []
    current_obs = env.reset()
    for step in range(max_steps):
        logger.info(f"Iteration: {step}")
        sampled_action = env.action_space.sample()
        following_obs, step_reward, done, _ = env.step(sampled_action)

        # The same four fields go into the bookkeeping list and the
        # Transition; the observation stored is the one the action was
        # taken from, not the one the step produced.
        step_fields = {
            "observation": current_obs,
            "action": sampled_action,
            "reward": step_reward,
            "terminal": done,
        }
        recorded.append(step_fields)
        insert(
            buffer,
            Transition(
                mdp_id=0,
                sequence_number=step,
                log_prob=0.0,
                **step_fields,
            ),
        )

        if done:
            break
        current_obs = following_obs

    return buffer, recorded
# Example 2
def run_test_episode_buffer(
    env: EnvWrapper,
    policy: Policy,
    trainer: Trainer,
    num_train_episodes: int,
    passing_score_bar: float,
    num_eval_episodes: int,
    use_gpu: bool = False,
):
    """Train ``policy`` with a post-episode callback, then check its scores.

    The environment and its action space are seeded with ``SEED``. The
    policy is then trained for ``num_train_episodes`` episodes, with the
    callback returned by ``train_post_episode(env, trainer, use_gpu)``
    passed as ``post_episode``. Afterwards the same policy object is
    evaluated for ``num_eval_episodes`` episodes with ``serving=False``.

    Args:
        env: Environment used for both training and evaluation.
        policy: Policy used for both training and evaluation.
        trainer: Trainer passed to ``train_post_episode``.
        num_train_episodes: Number of training episodes to run.
        passing_score_bar: Minimum score required by both checks.
        num_eval_episodes: Number of evaluation episodes to run.
        use_gpu: Forwarded to ``train_post_episode`` and ``train_policy``.

    Raises:
        AssertionError: If the best training reward, or the mean
            evaluation reward, is below ``passing_score_bar``.
    """
    on_episode_end = train_post_episode(env, trainer, use_gpu)

    # pyre-fixme[16]: `EnvWrapper` has no attribute `seed`.
    env.seed(SEED)
    # pyre-fixme[16]: `EnvWrapper` has no attribute `action_space`.
    env.action_space.seed(SEED)

    train_rewards = train_policy(
        env,
        policy,
        num_train_episodes,
        post_step=None,
        post_episode=on_episode_end,
        use_gpu=use_gpu,
    )

    # Only the best training episode is compared against the bar: the policy
    # explores while training, so individual returns can be poor (a source
    # of flakiness in C51 and QRDQN).
    best_train_reward = np.max(train_rewards)
    assert best_train_reward >= passing_score_bar, (
        f"max reward ({best_train_reward}) after training for "
        f"{len(train_rewards)} episodes is less than < {passing_score_bar}.\n"
    )

    eval_rewards = eval_policy(env, policy, num_eval_episodes, serving=False)
    mean_eval_reward = eval_rewards.mean()
    assert (
        mean_eval_reward >= passing_score_bar
    ), f"Eval reward is {mean_eval_reward}, less than < {passing_score_bar}.\n"