示例#1
0
def fill_replay_buffer(env, replay_buffer: ReplayBuffer, desired_size: int):
    """Populate ``replay_buffer`` with transitions produced by a random policy.

    Episodes are run back to back with an agent that acts randomly. The agent's
    post-transition callback comes from ``add_replay_buffer_post_step``. Filling
    continues until ``replay_buffer.size`` reaches ``desired_size`` or an episode
    no longer grows the buffer.

    Args:
        env: Environment used to build the random policy, the callback and the
            agent, and passed to ``run_episode``. When its ``max_steps``
            attribute is not ``None``, it caps each episode's ``max_steps``.
        replay_buffer: Buffer being filled. Its ``size`` and
            ``_replay_capacity`` attributes are read.
        desired_size: Target size. It must be positive, no larger than
            ``replay_buffer._replay_capacity``, and strictly larger than the
            current ``replay_buffer.size``.

    Raises:
        AssertionError: If ``desired_size`` violates the constraints above.
            These checks are ``assert`` statements, so ``python -O`` removes
            them.
    """
    assert 0 < desired_size <= replay_buffer._replay_capacity, (
        f"It's not true that 0 < {desired_size} <= {replay_buffer._replay_capacity}."
    )
    assert replay_buffer.size < desired_size, (
        f"Replay buffer already has {replay_buffer.size} elements. "
        f"(more than desired_size = {desired_size})"
    )
    logger.info(
        f" Starting to fill replay buffer using random policy to size: {desired_size}."
    )

    agent = Agent.create_for_env(
        env,
        policy=make_random_policy_for_env(env),
        post_transition_callback=add_replay_buffer_post_step(replay_buffer, env=env),
    )
    max_episode_steps = env.max_steps

    with tqdm(
        total=desired_size - replay_buffer.size,
        desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size} using random policy",
    ) as progress:
        episode_id = 0
        while replay_buffer.size < desired_size:
            size_before = replay_buffer.size
            # NOTE(review): the "- 1" presumably reflects run_episode emitting one
            # transition more than its max_steps argument; confirm against run_episode.
            remaining = desired_size - size_before - 1
            step_cap = (
                remaining
                if max_episode_steps is None
                else min(max_episode_steps, remaining)
            )
            run_episode(env=env, agent=agent, mdp_id=episode_id, max_steps=step_cap)

            growth = replay_buffer.size - size_before
            # This loop deliberately does not check that `growth` is
            # non-negative: samples with seq_len > 1 can make it negative.
            # The original authors flag this as a bug still to be fixed.
            progress.update(n=growth)
            episode_id += 1
            if growth <= 0:
                # The buffer is no longer growing, so further episodes are pointless.
                break

    if replay_buffer.size < desired_size:
        logger.info(
            f"Stopped early and filled replay buffer to size: {replay_buffer.size}."
        )
    else:
        logger.info(
            f"Successfully filled replay buffer to size: {replay_buffer.size}!"
        )
示例#2
0
def fill_replay_buffer(env: Env, replay_buffer: ReplayBuffer,
                       desired_size: int):
    """Populate ``replay_buffer`` with transitions produced by a random policy.

    Episodes are run back to back with an agent that acts randomly. The agent's
    post-transition callback comes from ``add_replay_buffer_post_step``. Filling
    continues until ``replay_buffer.size`` reaches ``desired_size`` or an episode
    adds nothing to the buffer.

    Args:
        env: Environment used to build the random policy, the callback and the
            agent, and passed to ``run_episode``. When ``get_max_steps(env)`` is
            not ``None``, it caps each episode's ``max_steps``.
        replay_buffer: Buffer being filled. Its ``size`` and
            ``_replay_capacity`` attributes are read.
        desired_size: Target size. It must be positive, no larger than
            ``replay_buffer._replay_capacity``, and strictly larger than the
            current ``replay_buffer.size``.

    Raises:
        AssertionError: If ``desired_size`` violates the constraints above, or
            if the buffer shrinks during an episode. These checks are
            ``assert`` statements, so ``python -O`` removes them.
    """
    assert 0 < desired_size <= replay_buffer._replay_capacity, (
        f"It's not true that 0 < {desired_size} <= {replay_buffer._replay_capacity}."
    )
    assert replay_buffer.size < desired_size, (
        f"Replay buffer already has {replay_buffer.size} elements. "
        f"(more than desired_size = {desired_size})"
    )
    logger.info(f"Starting to fill replay buffer to size: {desired_size}.")

    agent = Agent.create_for_env(
        env,
        policy=make_random_policy_for_env(env),
        post_transition_callback=add_replay_buffer_post_step(replay_buffer, env=env),
    )
    max_episode_steps = get_max_steps(env)

    with tqdm(
        total=desired_size - replay_buffer.size,
        desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size}",
    ) as progress:
        episode_id = 0
        while replay_buffer.size < desired_size:
            size_before = replay_buffer.size
            # NOTE(review): the "- 1" presumably reflects run_episode emitting one
            # transition more than its max_steps argument; confirm against run_episode.
            remaining = desired_size - size_before - 1
            step_cap = (
                remaining
                if max_episode_steps is None
                else min(max_episode_steps, remaining)
            )
            run_episode(env=env, agent=agent, mdp_id=episode_id, max_steps=step_cap)

            growth = replay_buffer.size - size_before
            assert growth >= 0, f"size delta is {growth} which should be non-negative."
            progress.update(n=growth)
            episode_id += 1
            if growth == 0:
                # The episode added nothing, so further episodes are pointless.
                break

    if replay_buffer.size < desired_size:
        logger.info(
            f"Stopped early and filled replay buffer to size: {replay_buffer.size}."
        )
    else:
        logger.info(
            f"Successfully filled replay buffer to size: {replay_buffer.size}!"
        )