def fill_replay_buffer(env, replay_buffer: ReplayBuffer, desired_size: int):
    """Fill ``replay_buffer`` with random-policy transitions up to ``desired_size``.

    Runs episodes with a random policy, appending each transition to the
    buffer via a post-step callback, until the buffer holds at least
    ``desired_size`` transitions or an episode adds nothing new.

    Args:
        env: Environment to roll out in. Assumes it exposes a ``max_steps``
            attribute (may be ``None``) — TODO confirm against callers.
        replay_buffer: Buffer to fill; must have capacity >= ``desired_size``.
        desired_size: Target number of transitions; must be positive and
            fit within the buffer's capacity.

    Raises:
        AssertionError: If ``desired_size`` is out of range or the buffer
            already holds ``desired_size`` or more elements.
    """
    # NOTE(review): a second definition with the same name appears later in
    # this file and shadows this one at import time — confirm which variant
    # is intended to survive.
    assert (
        0 < desired_size and desired_size <= replay_buffer._replay_capacity
    ), f"It's not true that 0 < {desired_size} <= {replay_buffer._replay_capacity}."
    assert replay_buffer.size < desired_size, (
        f"Replay buffer already has {replay_buffer.size} elements. "
        f"(more than desired_size = {desired_size})"
    )
    logger.info(
        f" Starting to fill replay buffer using random policy to size: {desired_size}."
    )
    random_policy = make_random_policy_for_env(env)
    post_step = add_replay_buffer_post_step(replay_buffer, env=env)
    agent = Agent.create_for_env(
        env, policy=random_policy, post_transition_callback=post_step
    )
    max_episode_steps = env.max_steps
    with tqdm(
        total=desired_size - replay_buffer.size,
        desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size} using random policy",
    ) as pbar:
        mdp_id = 0
        while replay_buffer.size < desired_size:
            last_size = replay_buffer.size
            # Cap episode length so we never collect past the target size.
            max_steps = desired_size - replay_buffer.size - 1
            if max_episode_steps is not None:
                max_steps = min(max_episode_steps, max_steps)
            run_episode(env=env, agent=agent, mdp_id=mdp_id, max_steps=max_steps)
            size_delta = replay_buffer.size - last_size
            # The assertion below is commented out because it can't
            # support input samples which has seq_len>1. This should be
            # treated as a bug, and need to be fixed in the future.
            # assert (
            #     size_delta >= 0
            # ), f"size delta is {size_delta} which should be non-negative."
            pbar.update(n=size_delta)
            mdp_id += 1
            if size_delta <= 0:
                # replay buffer size isn't increasing... so stop early
                break

    if replay_buffer.size >= desired_size:
        logger.info(
            f"Successfully filled replay buffer to size: {replay_buffer.size}!"
        )
    else:
        logger.info(
            f"Stopped early and filled replay buffer to size: {replay_buffer.size}."
        )
def fill_replay_buffer(env: Env, replay_buffer: ReplayBuffer, desired_size: int):
    """Fill ``replay_buffer`` with random-policy transitions up to ``desired_size``.

    Runs episodes with a random policy, appending each transition to the
    buffer via a post-step callback, until the buffer holds at least
    ``desired_size`` transitions or an episode adds nothing new.

    Args:
        env: Environment to roll out in; per-episode step limit is taken
            from ``get_max_steps(env)`` (may be ``None``).
        replay_buffer: Buffer to fill; must have capacity >= ``desired_size``.
        desired_size: Target number of transitions; must be positive and
            fit within the buffer's capacity.

    Raises:
        AssertionError: If ``desired_size`` is out of range, the buffer
            already holds ``desired_size`` or more elements, or an episode
            shrinks the buffer (negative size delta).
    """
    assert (
        0 < desired_size and desired_size <= replay_buffer._replay_capacity
    ), f"It's not true that 0 < {desired_size} <= {replay_buffer._replay_capacity}."
    assert replay_buffer.size < desired_size, (
        f"Replay buffer already has {replay_buffer.size} elements. "
        f"(more than desired_size = {desired_size})"
    )
    logger.info(f"Starting to fill replay buffer to size: {desired_size}.")
    random_policy = make_random_policy_for_env(env)
    post_step = add_replay_buffer_post_step(replay_buffer, env=env)
    agent = Agent.create_for_env(
        env, policy=random_policy, post_transition_callback=post_step
    )
    max_episode_steps = get_max_steps(env)
    with tqdm(
        total=desired_size - replay_buffer.size,
        desc=f"Filling replay buffer from {replay_buffer.size} to size {desired_size}",
    ) as pbar:
        mdp_id = 0
        while replay_buffer.size < desired_size:
            last_size = replay_buffer.size
            # Cap episode length so we never collect past the target size.
            max_steps = desired_size - replay_buffer.size - 1
            if max_episode_steps is not None:
                max_steps = min(max_episode_steps, max_steps)
            run_episode(env=env, agent=agent, mdp_id=mdp_id, max_steps=max_steps)
            size_delta = replay_buffer.size - last_size
            assert (
                size_delta >= 0
            ), f"size delta is {size_delta} which should be non-negative."
            pbar.update(n=size_delta)
            mdp_id += 1
            if size_delta == 0:
                # replay buffer size isn't increasing... so stop early
                break

    if replay_buffer.size >= desired_size:
        logger.info(
            f"Successfully filled replay buffer to size: {replay_buffer.size}!"
        )
    else:
        logger.info(
            f"Stopped early and filled replay buffer to size: {replay_buffer.size}."
        )