def _create_replay_buffer_and_insert(env: EnvWrapper):
    """Roll out up to 5 steps of a seeded episode, inserting each transition
    into a small ReplayBuffer; returns the buffer and the raw inserted data."""
    env.seed(1)
    replay_buffer = ReplayBuffer(replay_capacity=6, batch_size=1)
    replay_buffer_inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    inserted = []
    terminal = False
    i = 0
    while not terminal and i < 5:
        logger.info(f"Iteration: {i}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        inserted.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "terminal": terminal,
            }
        )
        transition = Transition(
            mdp_id=0,
            sequence_number=i,
            observation=obs,
            action=action,
            reward=reward,
            terminal=terminal,
            log_prob=0.0,
        )
        replay_buffer_inserter(replay_buffer, transition)
        obs = next_obs
        i += 1

    return replay_buffer, inserted
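

# A minimal usage sketch (an assumption, not part of the original module):
# any seedable EnvWrapper works here, and the assertion simply reflects the
# 5-step cap in the helper above.
def _example_fill_buffer(env: EnvWrapper):
    replay_buffer, inserted = _create_replay_buffer_and_insert(env)
    # The seeded rollout runs at most 5 steps, so at most 5 transitions were
    # inserted; `inserted` mirrors them for later assertions.
    assert len(inserted) <= 5
    return replay_buffer, inserted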
def run_test_episode_buffer(
    env: EnvWrapper,
    policy: Policy,
    trainer: Trainer,
    num_train_episodes: int,
    passing_score_bar: float,
    num_eval_episodes: int,
    use_gpu: bool = False,
):
    """Train `policy` on `env` for `num_train_episodes`, then evaluate it,
    asserting that both the best training return and the mean eval return
    clear `passing_score_bar`."""
    training_policy = policy
    post_episode_callback = train_post_episode(env, trainer, use_gpu)
    # pyre-fixme[16]: `EnvWrapper` has no attribute `seed`.
    env.seed(SEED)
    # pyre-fixme[16]: `EnvWrapper` has no attribute `action_space`.
    env.action_space.seed(SEED)

    train_rewards = train_policy(
        env,
        training_policy,
        num_train_episodes,
        post_step=None,
        post_episode=post_episode_callback,
        use_gpu=use_gpu,
    )

    # Check whether the max score passed the score bar; since we explore during
    # training, the return could be bad (leading to flakiness in C51 and QRDQN).
    assert np.max(train_rewards) >= passing_score_bar, (
        f"max reward ({np.max(train_rewards)}) after training for "
        f"{len(train_rewards)} episodes is less than {passing_score_bar}.\n"
    )

    serving_policy = policy
    eval_rewards = eval_policy(env, serving_policy, num_eval_episodes, serving=False)
    assert (
        eval_rewards.mean() >= passing_score_bar
    ), f"Eval reward is {eval_rewards.mean()}, less than {passing_score_bar}.\n"
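

# A minimal sketch of how the runner above might be invoked from a test
# (an assumption, not part of the original suite): the episode budget, score
# bar, and eval size below are illustrative placeholders, not real thresholds.
def _example_run(env: EnvWrapper, policy: Policy, trainer: Trainer):
    run_test_episode_buffer(
        env=env,
        policy=policy,
        trainer=trainer,
        num_train_episodes=60,  # illustrative training budget
        passing_score_bar=100.0,  # illustrative score bar
        num_eval_episodes=20,  # illustrative eval size
        use_gpu=False,
    )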