def _create_replay_buffer_and_insert(env: EnvWrapper):
    """
    Fill a small replay buffer with a few random-action transitions.

    The environment is seeded with 1, reset, and then stepped with actions
    sampled from its action space for at most 5 steps, stopping earlier if
    the environment reports a terminal step. Every step is pushed into a
    ReplayBuffer (capacity 6, batch size 1) via the inserter built for `env`.

    Returns:
        A `(replay_buffer, inserted)` tuple, where `inserted` is a list of
        dicts (keys: observation, action, reward, terminal) mirroring what
        was inserted, in insertion order.
    """
    env.seed(1)
    buffer = ReplayBuffer(replay_capacity=6, batch_size=1)
    insert_into_buffer = make_replay_buffer_inserter(env)
    current_obs = env.reset()
    recorded = []
    for i in range(5):
        logger.info(f"Iteration: {i}")
        sampled_action = env.action_space.sample()
        following_obs, step_reward, done, _ = env.step(sampled_action)
        # Keep a plain-dict copy of the step so callers can compare it
        # against what later comes out of the buffer.
        record = {
            "observation": current_obs,
            "action": sampled_action,
            "reward": step_reward,
            "terminal": done,
        }
        recorded.append(record)
        insert_into_buffer(
            buffer,
            Transition(mdp_id=0, sequence_number=i, log_prob=0.0, **record),
        )
        current_obs = following_obs
        if done:
            break

    return buffer, recorded
# Example no. 2
# 0
def run_episode(
    env: EnvWrapper,
    agent: Agent,
    mdp_id: int = 0,
    max_steps: Optional[int] = None,
) -> Trajectory:
    """
    Run a single episode of `agent` in `env` and return its trajectory.

    Args:
        env: Environment to reset and step through.
        agent: Supplies actions via `act` and is notified through
            `post_step` after every step and `post_episode` at the end.
        mdp_id: Identifier stored on every transition of this episode.
        max_steps: If given, the episode is cut off (the last transition is
            marked terminal) once this many steps have been taken. At least
            one step is always taken, even when `max_steps` is 0.

    Returns:
        The Trajectory holding one Transition per environment step taken.
    """
    trajectory = Trajectory()
    obs = env.reset()
    possible_actions_mask = env.possible_actions_mask
    terminal = False
    num_steps = 0
    while not terminal:
        action, log_prob = agent.act(obs, possible_actions_mask)
        next_obs, reward, terminal, _ = env.step(action)
        # Read the mask right after stepping so that it is paired with
        # next_obs on the following iteration.
        next_possible_actions_mask = env.possible_actions_mask
        # num_steps is the index of the step just taken, so num_steps + 1
        # steps have been executed so far. Comparing the index itself
        # against max_steps would let the episode run one step too long.
        if max_steps is not None and num_steps + 1 >= max_steps:
            terminal = True

        # Only partially filled. Agent can fill in more fields.
        transition = Transition(
            mdp_id=mdp_id,
            sequence_number=num_steps,
            observation=obs,
            action=action,
            reward=float(reward),
            terminal=bool(terminal),
            log_prob=log_prob,
            possible_actions_mask=possible_actions_mask,
        )
        agent.post_step(transition)
        trajectory.add_transition(transition)
        SummaryWriterContext.increase_global_step()
        obs = next_obs
        possible_actions_mask = next_possible_actions_mask
        num_steps += 1
    agent.post_episode(trajectory)
    return trajectory