def run_test_episode_buffer(
    env: EnvWrapper,
    policy: Policy,
    trainer: Trainer,
    num_train_episodes: int,
    passing_score_bar: float,
    num_eval_episodes: int,
    use_gpu: bool = False,
):
    training_policy = policy
    post_episode_callback = train_post_episode(env, trainer, use_gpu)
    # Seed the environment and its action space for reproducibility.
    # pyre-fixme[16]: `EnvWrapper` has no attribute `seed`.
    env.seed(SEED)
    # pyre-fixme[16]: `EnvWrapper` has no attribute `action_space`.
    env.action_space.seed(SEED)
    train_rewards = train_policy(
        env,
        training_policy,
        num_train_episodes,
        post_step=None,
        post_episode=post_episode_callback,
        use_gpu=use_gpu,
    )

    # Check whether the max score passed the score bar; because we explore
    # during training, the return can be low (leading to flakiness in C51
    # and QRDQN).
    assert np.max(train_rewards) >= passing_score_bar, (
        f"max reward ({np.max(train_rewards)}) after training for "
        f"{len(train_rewards)} episodes is less than {passing_score_bar}.\n"
    )

    serving_policy = policy
    eval_rewards = eval_policy(env, serving_policy, num_eval_episodes, serving=False)
    assert (
        eval_rewards.mean() >= passing_score_bar
    ), f"Eval reward is {eval_rewards.mean()}, less than {passing_score_bar}.\n"
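# Usage sketch (an assumption, not from this file): how run_test_episode_buffer
# might be driven from a test. The make_env / make_trainer / make_policy helpers
# are hypothetical fixture builders; only the call signature of
# run_test_episode_buffer itself comes from the definition above.
def _example_episode_buffer_test():
    env = make_env("CartPole-v0")  # hypothetical: builds an EnvWrapper
    trainer = make_trainer(env)  # hypothetical: builds a Trainer
    policy = make_policy(trainer)  # hypothetical: builds a Policy
    run_test_episode_buffer(
        env=env,
        policy=policy,
        trainer=trainer,
        num_train_episodes=100,
        passing_score_bar=100.0,
        num_eval_episodes=20,
        use_gpu=False,
    )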
def run_test_online_episode(
    env: Env__Union,
    model: ModelManager__Union,
    num_train_episodes: int,
    passing_score_bar: float,
    num_eval_episodes: int,
    use_gpu: bool,
):
    """
    Run an online learning test. At the end of each episode, training is run
    on the trajectory.
    """
    env = env.value
    # pyre-fixme[16]: Module `pl` has no attribute `seed_everything`.
    pl.seed_everything(SEED)
    env.seed(SEED)
    env.action_space.seed(SEED)

    normalization = build_normalizer(env)
    logger.info(f"Normalization is: \n{pprint.pformat(normalization)}")

    manager = model.value
    trainer = manager.initialize_trainer(
        use_gpu=use_gpu,
        reward_options=RewardOptions(),
        normalization_data_map=normalization,
    )
    policy = manager.create_policy(serving=False)

    device = torch.device("cuda") if use_gpu else torch.device("cpu")
    agent = Agent.create_for_env(env, policy, device=device)

    # Lightning-based trainers are fit on an EpisodicDataset for one epoch;
    # legacy trainers fall back to the per-episode training callback.
    # pyre-fixme[16]: Module `pl` has no attribute `LightningModule`.
    if isinstance(trainer, pl.LightningModule):
        # pyre-fixme[16]: Module `pl` has no attribute `Trainer`.
        pl_trainer = pl.Trainer(max_epochs=1, gpus=int(use_gpu), deterministic=True)
        dataset = EpisodicDataset(
            env=env, agent=agent, num_episodes=num_train_episodes, seed=SEED
        )
        pl_trainer.fit(trainer, dataset)
    else:
        post_episode_callback = train_post_episode(env, trainer, use_gpu)
        _ = train_policy(
            env,
            policy,
            num_train_episodes,
            post_step=None,
            post_episode=post_episode_callback,
            use_gpu=use_gpu,
        )

    eval_rewards = evaluate_for_n_episodes(
        n=num_eval_episodes,
        env=env,
        agent=agent,
        max_steps=env.max_steps,
        num_processes=1,
    ).squeeze(1)
    assert (
        eval_rewards.mean() >= passing_score_bar
    ), f"Eval reward is {eval_rewards.mean()}, less than {passing_score_bar}.\n"
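# Usage sketch (an assumption, not from this file): how run_test_online_episode
# might be invoked. Env__Union and ModelManager__Union are tagged-union config
# wrappers; the env_union and model_union arguments below are hypothetical
# stand-ins for whatever a test's config deserialization produces. Only the
# keyword arguments follow the signature defined above.
def _example_online_episode_test(
    env_union: Env__Union, model_union: ModelManager__Union
):
    run_test_online_episode(
        env=env_union,
        model=model_union,
        num_train_episodes=100,
        passing_score_bar=100.0,
        num_eval_episodes=20,
        use_gpu=False,
    )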