Example No. 1
    @classmethod
    def create_for_trainer(
        cls,
        trainer,
        env: EnvWrapper,
        agent: Agent,
        replay_buffer: ReplayBuffer,
        batch_size: int,
        training_frequency: int = 1,
        num_episodes: Optional[int] = None,
        max_steps: Optional[int] = None,
        trainer_preprocessor=None,
        replay_buffer_inserter=None,
    ):
        device = torch.device("cpu")
        if trainer_preprocessor is None:
            trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
                trainer, device, env
            )

        if replay_buffer_inserter is None:
            replay_buffer_inserter = make_replay_buffer_inserter(env)

        return cls(
            env=env,
            agent=agent,
            replay_buffer=replay_buffer,
            batch_size=batch_size,
            training_frequency=training_frequency,
            num_episodes=num_episodes,
            max_steps=max_steps,
            trainer_preprocessor=trainer_preprocessor,
            replay_buffer_inserter=replay_buffer_inserter,
        )
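A minimal usage sketch for the factory above, assuming an enclosing class (omitted by this excerpt; GymRunner is a placeholder name) and trainer/env/agent/replay_buffer objects built elsewhere:

# Hypothetical call; GymRunner stands in for the unnamed enclosing class,
# and the four collaborator objects are assumed to be constructed already.
runner = GymRunner.create_for_trainer(
    trainer=trainer,
    env=env,
    agent=agent,
    replay_buffer=replay_buffer,
    batch_size=32,
    training_frequency=1,
    num_episodes=10,
)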
Example No. 2
def _create_replay_buffer_and_insert(env: gym.Env):
    env.seed(1)
    replay_buffer = ReplayBuffer.create_from_env(
        env, replay_memory_size=10, batch_size=1
    )
    replay_buffer_inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    inserted = []
    terminal = False
    i = 0
    while not terminal and i < 5:
        logger.info(f"Iteration: {i}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        inserted.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "terminal": terminal,
            }
        )
        log_prob = 0.0
        replay_buffer_inserter(replay_buffer, obs, action, reward, terminal, log_prob)
        obs = next_obs
        i += 1

    return replay_buffer, inserted
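A sketch of how this helper might be exercised; CartPole-v0 is only an illustrative env id, and the size / sample_transition_batch_tensor calls reuse the ReplayBuffer API visible in Example No. 7 below:

env = gym.make("CartPole-v0")  # illustrative env id (assumption)
replay_buffer, inserted = _create_replay_buffer_and_insert(env)
# Everything recorded in `inserted` also went through the inserter, so
# the buffer should hold exactly that many transitions (capacity 10 is
# never exceeded by the 5-step loop, so no wraparound).
assert replay_buffer.size == len(inserted)
batch = replay_buffer.sample_transition_batch_tensor(batch_size=1)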
Example No. 3
def _create_replay_buffer_and_insert(env: EnvWrapper):
    env.seed(1)
    replay_buffer = ReplayBuffer(replay_capacity=6, batch_size=1)
    replay_buffer_inserter = make_replay_buffer_inserter(env)
    obs = env.reset()
    inserted = []
    terminal = False
    i = 0
    while not terminal and i < 5:
        logger.info(f"Iteration: {i}")
        action = env.action_space.sample()
        next_obs, reward, terminal, _ = env.step(action)
        inserted.append(
            {
                "observation": obs,
                "action": action,
                "reward": reward,
                "terminal": terminal,
            }
        )
        transition = Transition(
            mdp_id=0,
            sequence_number=i,
            observation=obs,
            action=action,
            reward=reward,
            terminal=terminal,
            log_prob=0.0,
        )
        replay_buffer_inserter(replay_buffer, transition)
        obs = next_obs
        i += 1

    return replay_buffer, inserted
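Two inserter calling conventions appear in these examples: Example No. 2 passes loose fields, while this version packs them into a Transition first. If loose-field call sites need to drive a Transition-based inserter, a small adapter suffices; a sketch, assuming the Transition constructor shown above:

def flat_inserter_from_transition_inserter(inserter, mdp_id=0):
    # Adapts a Transition-based inserter to the loose-field call style
    # of Example No. 2. `step` is tracked locally to supply
    # sequence_number; mdp_id defaults to a single episode id.
    step = 0

    def insert(replay_buffer, obs, action, reward, terminal, log_prob):
        nonlocal step
        inserter(
            replay_buffer,
            Transition(
                mdp_id=mdp_id,
                sequence_number=step,
                observation=obs,
                action=action,
                reward=reward,
                terminal=terminal,
                log_prob=log_prob,
            ),
        )
        step += 1

    return insert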
Example No. 4
def add_replay_buffer_post_step(replay_buffer: ReplayBuffer,
                                env: gym.Env,
                                replay_buffer_inserter=None):
    """
    Simply add transitions to replay_buffer.
    """

    if replay_buffer_inserter is None:
        replay_buffer_inserter = make_replay_buffer_inserter(env)

    def post_step(transition: Transition) -> None:
        replay_buffer_inserter(replay_buffer, transition)

    return post_step
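The returned closure is meant to be installed as the agent's post-step hook; a sketch, where the Agent.create_for_env factory and the post_transition_callback keyword are assumptions about the Agent API rather than something shown in these excerpts:

post_step = add_replay_buffer_post_step(replay_buffer, env)
# Hypothetical wiring: the factory and keyword name are assumptions.
agent = Agent.create_for_env(env, policy, post_transition_callback=post_step)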
Example No. 5
    def test_cartpole(self):
        env = gym.make("CartPole-v0")
        replay_buffer = ReplayBuffer.create_from_env(
            env, replay_memory_size=10, batch_size=5
        )
        replay_buffer_inserter = make_replay_buffer_inserter(env)
        obs = env.reset()
        terminal = False
        i = 0
        while not terminal and i < 5:
            action = env.action_space.sample()
            next_obs, reward, terminal, _ = env.step(action)
            replay_buffer_inserter(
                replay_buffer, obs, action, reward, terminal, log_prob=0.0
            )
            obs = next_obs
            i += 1
Example No. 6
def add_replay_buffer_post_step(replay_buffer: ReplayBuffer,
                                env: gym.Env,
                                replay_buffer_inserter=None):
    """
    Simply add transitions to replay_buffer.
    """

    if replay_buffer_inserter is None:
        replay_buffer_inserter = make_replay_buffer_inserter(env)

    def post_step(obs: Any, action: Any, reward: float, terminal: bool,
                  log_prob: float) -> None:
        replay_buffer_inserter(replay_buffer, obs, action, reward, terminal,
                               log_prob)

    return post_step
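This is the same factory in its loose-field form: the closure it returns matches the inserter signature from Example No. 2, so it can be driven from a hand-rolled loop. A brief sketch (env and replay_buffer built as in Example No. 5):

post_step = add_replay_buffer_post_step(replay_buffer, env)
obs = env.reset()
action = env.action_space.sample()
next_obs, reward, terminal, _ = env.step(action)
post_step(obs, action, reward, terminal, log_prob=0.0)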
Example No. 7
def train_with_replay_buffer_post_step(
    replay_buffer: ReplayBuffer,
    env: gym.Env,
    trainer: Trainer,
    training_freq: int,
    batch_size: int,
    trainer_preprocessor=None,
    device: Union[str, torch.device] = "cpu",
    replay_buffer_inserter=None,
) -> PostStep:
    """ Called in post_step of agent to train based on replay buffer (RB).
        Args:
            trainer: responsible for having a .train method to train the model
            trainer_preprocessor: format RB output for trainer.train
            training_freq: how many steps in between trains
            batch_size: how big of a batch to sample
    """
    if isinstance(device, str):
        device = torch.device(device)

    if trainer_preprocessor is None:
        trainer_preprocessor = make_replay_buffer_trainer_preprocessor(
            trainer, device, env)

    if replay_buffer_inserter is None:
        replay_buffer_inserter = make_replay_buffer_inserter(env)

    _num_steps = 0

    def post_step(obs: Any, action: Any, reward: float, terminal: bool,
                  log_prob: float) -> None:
        nonlocal _num_steps

        replay_buffer_inserter(replay_buffer, obs, action, reward, terminal,
                               log_prob)

        if _num_steps % training_freq == 0:
            assert replay_buffer.size >= batch_size
            train_batch = replay_buffer.sample_transition_batch_tensor(
                batch_size=batch_size)
            preprocessed_batch = trainer_preprocessor(train_batch)
            trainer.train(preprocessed_batch)
        _num_steps += 1

    return post_step
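Because the inner post_step asserts replay_buffer.size >= batch_size on its very first call, the buffer is expected to be pre-filled before this callback starts firing. A sketch of that warm-up, reusing the loose-field inserter from the earlier examples (the warm-up length of 100 steps is arbitrary):

# Pre-fill the buffer so the first post_step call can sample a full
# batch. Uses the same old-gym step API as the surrounding examples.
inserter = make_replay_buffer_inserter(env)
obs = env.reset()
for _ in range(100):  # arbitrary warm-up length (assumption)
    action = env.action_space.sample()
    next_obs, reward, terminal, _ = env.step(action)
    inserter(replay_buffer, obs, action, reward, terminal, 0.0)
    obs = env.reset() if terminal else next_obs

post_step = train_with_replay_buffer_post_step(
    replay_buffer, env, trainer, training_freq=10, batch_size=32
)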
Example No. 8
def add_replay_buffer_post_step(
    replay_buffer: ReplayBuffer,
    # pyre-fixme[11]: Annotation `Env` is not defined as a type.
    env: gym.Env,
    replay_buffer_inserter=None,
):
    """
    Simply add transitions to replay_buffer.
    """

    if replay_buffer_inserter is None:
        replay_buffer_inserter = make_replay_buffer_inserter(env)

    def post_step(transition: Transition) -> None:
        replay_buffer_inserter(replay_buffer, transition)

    return post_step