def _create_replay_buffer_and_insert(env: EnvWrapper):
    """Seed ``env``, roll out up to 5 random steps, and insert each step into
    a fresh ReplayBuffer via the env-specific inserter.

    Returns the populated buffer together with a list of dicts recording the
    (observation, action, reward, terminal) of every inserted step.
    """
    env.seed(1)
    buffer = ReplayBuffer(replay_capacity=6, batch_size=1)
    inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    recorded = []
    terminal = False
    step = 0
    # Stop on a terminal signal from the env, or after 5 steps at most.
    while not terminal and step < 5:
        logger.info(f"Iteration: {step}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        recorded.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "terminal": terminal,
            }
        )
        inserter(
            buffer,
            Transition(
                mdp_id=0,
                sequence_number=step,
                observation=obs,
                action=action,
                reward=reward,
                terminal=terminal,
                log_prob=0.0,
            ),
        )
        obs = next_obs
        step += 1
    return buffer, recorded
def run_episode(
    env: EnvWrapper, agent: Agent, mdp_id: int = 0, max_steps: Optional[int] = None
) -> Trajectory:
    """Roll out one episode of ``agent`` acting in ``env``.

    Returns the episode as a Trajectory of transitions. The episode ends when
    the environment reports terminal, or after ``max_steps`` steps (if given),
    in which case the final transition is marked terminal. ``mdp_id`` tags
    every transition of the episode.
    """
    episode = Trajectory()
    obs = env.reset()
    done = False
    step = 0
    while not done:
        action, log_prob = agent.act(obs)
        next_obs, reward, done, _ = env.step(action)
        # Treat hitting the step budget as episode termination.
        if max_steps is not None and step >= max_steps:
            done = True
        # Only partially filled here; the agent may fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=step,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(done),
            log_prob=log_prob,
        )
        agent.post_step(transition)
        episode.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        step += 1
    return episode
async def async_run_episode(
    env: EnvWrapper,
    agent: Agent,
    mdp_id: int = 0,
    max_steps: Optional[int] = None,
    fill_info: bool = False,
) -> Trajectory:
    """Roll out one episode of ``agent`` acting in ``env``, supporting both
    regular and async ``env.step()``.

    NOTE: this function is an async coroutine in order to support async
    ``env.step()``. If you are using it with a regular ``env.step()`` method,
    use the non-async ``run_episode()`` instead.

    Returns the episode as a Trajectory. The episode ends when the environment
    reports terminal, or after ``max_steps`` steps (if given), in which case
    the final transition is marked terminal. ``mdp_id`` tags every transition;
    ``fill_info`` controls whether the env's ``info`` dict is stored on each
    transition (otherwise ``None``).
    """
    trajectory = Trajectory()
    obs = env.reset()
    mask = env.possible_actions_mask
    done = False
    step = 0
    # Decide once whether env.step must be awaited.
    step_needs_await = asyncio.iscoroutinefunction(env.step)
    while not done:
        action, log_prob = agent.act(obs, mask)
        result = env.step(action)
        if step_needs_await:
            result = await result
        next_obs, reward, done, info = result
        if not fill_info:
            info = None
        # Capture the mask for the *next* state before possibly terminating.
        next_mask = env.possible_actions_mask
        # Treat hitting the step budget as episode termination.
        if max_steps is not None and step >= max_steps:
            done = True
        # Only partially filled here; the agent may fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=step,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(done),
            log_prob=log_prob,
            possible_actions_mask=mask,
            info=info,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs, mask = next_obs, next_mask
        step += 1
    agent.post_episode(trajectory)
    return trajectory