def main():

    env = suite_gym.load('Trajectory-v0', gym_kwargs={
        'num_dimensions': 2,
        'num_observables': 3,
        'max_targets': 100,
        'max_steps': 5000,
        'max_steps_without_target': 5000,
        'max_position': 100.0,
        'max_acceleration': 10.2,
        'max_velocity': 15.0,
        'collision_epsilon': 10.0
    })
    tf_env = tf_py_environment.TFPyEnvironment(env)

    agent = RandomAgent(tf_env.time_step_spec(), tf_env.action_spec())
    uniform_replay_buffer = TFUniformReplayBuffer(agent.collect_data_spec, batch_size=1)

    transitions = []

    driver = DynamicStepDriver(
        tf_env,
        policy=agent.policy,
        observers=[uniform_replay_buffer.add_batch],
        transition_observers=[transitions.append],
        num_steps=500
    )

    initial_time_step = tf_env.reset()
    final_time_step, final_policy_state = driver.run(initial_time_step)
    dataset = uniform_replay_buffer.as_dataset()

    input_state = []
    input_action = []
    output_state = []
    output_reward = []
    for transition in transitions:
        input_state.append(tf.concat(tf.nest.flatten(transition[0].observation), axis=-1))
        input_action.append(tf.concat(tf.nest.flatten(transition[1].action), axis=-1))
        output_state.append(tf.concat(tf.nest.flatten(transition[2].observation), axis=-1))
        output_reward.append(tf.concat(tf.nest.flatten(transition[2].reward), axis=-1))

    tf_input_state = tf.squeeze(tf.stack(input_state), axis=1)
    tf_input_action = tf.squeeze(tf.stack(input_action), axis=1)
    tf_output_state = tf.squeeze(tf.stack(output_state), axis=1)
    tf_output_reward = tf.stack(output_reward)
     
    # dataset = (features, labels)

    # (time_step_before, policy_step_action, time_step_after) = transitions[0]
    # observation = time_step_before.observation
    # action = policy_step_action.action
    # # (discount_, observation_, reward_, step_type_) = time_step_after
    # observation_ = time_step_after.observation

    pass
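The first example stops once the transition tensors have been stacked. As a hedged illustration of the "dataset = (features, labels)" comment above, the stacked tensors could be packed into a tf.data.Dataset for fitting a one-step dynamics model; the sketch below simply reuses the names built inside main() above, and the shuffle and batch sizes are arbitrary choices, not part of the original code.

# Sketch only: pair (state, action) features with (next state, reward) labels.
features = tf.concat([tf_input_state, tf_input_action], axis=-1)
labels = tf.concat([tf_output_state, tf_output_reward], axis=-1)
dynamics_dataset = (tf.data.Dataset
                    .from_tensor_slices((features, labels))
                    .shuffle(buffer_size=500)
                    .batch(32))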
def main():

    env = suite_gym.load('Trajectory-v0',
                         gym_kwargs={
                             'num_dimensions': 2,
                             'num_observables': 15,
                             'max_targets': 100,
                             'max_steps': 5000,
                             'max_steps_without_target': 5000,
                             'max_position': 100.0,
                             'max_acceleration': 10.2,
                             'max_velocity': 15.0,
                             'collision_epsilon': 10.0
                         })
    tf_env = tf_py_environment.TFPyEnvironment(env)

    agent = RandomAgent(time_step_spec=tf_env.time_step_spec(),
                        action_spec=tf_env.action_spec())

    metric = AverageReturnMetric()
    replay_buffer = []
    # uniform_replay_buffer = PyUniformReplayBuffer(data_spec=agent.collect_data_spec, capacity=2000)
    uniform_replay_buffer = TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec, batch_size=1)
    # observers = [replay_buffer.append, metric]

    # driver = PyDriver(
    #     env,
    #     policy=RandomPyPolicy(env.time_step_spec(), env.action_spec()),
    #     observers=[replay_buffer.append, metric],
    #     max_steps=2000
    # )

    # driver = TFDriver(
    #     tf_env,
    #     # policy=RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec()),
    #     policy=agent.policy,
    #     observers=[uniform_replay_buffer],
    #     max_steps=2000
    # )

    driver = DynamicStepDriver(
        tf_env,
        policy=agent.policy,
        observers=[uniform_replay_buffer.add_batch],  #, metric],
        # transition_observers=None,
        num_steps=1000)

    agent.initialize()
    initial_time_step = tf_env.reset()
    final_time_step, final_policy_state = driver.run(initial_time_step)

    dataset = uniform_replay_buffer.as_dataset()
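The second main() above also ends right after as_dataset(). A minimal, hedged sketch of how that dataset is usually consumed follows; the sample_batch_size and num_steps values are illustrative and mirror the patterns that appear further down in this listing.

# Sketch only: rebuild the dataset with sampling parameters and draw one batch.
dataset = uniform_replay_buffer.as_dataset(
    sample_batch_size=32,
    num_steps=2,
    num_parallel_calls=3).prefetch(3)
iterator = iter(dataset)
trajectories, buffer_info = next(iterator)  # Trajectory batch of shape [32, 2, ...]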
Example #3
 def driver(self):
     collect_driver = DynamicStepDriver(
         self.tf_env,
         self.agent.collect_policy,
         observers=[self.replay_buffer_observer] + self.training_metrics,
         num_steps=self.update_period
     )  # collect 4 steps for each training iteration
     return collect_driver
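A brief, hedged note on how a helper like this is typically used: the returned driver is run once per training iteration, threading the time step and policy state through successive calls. The method below is illustrative only; self.agent, self.replay_buffer and num_iterations are assumptions, not part of the original example.

 def train_loop(self, num_iterations):
     # Illustrative sketch; self.agent and self.replay_buffer are assumed attributes.
     collect_driver = self.driver()
     time_step, policy_state = None, None
     iterator = iter(self.replay_buffer.as_dataset(sample_batch_size=64, num_steps=2))
     for _ in range(num_iterations):
         time_step, policy_state = collect_driver.run(time_step, policy_state)
         trajectories, _ = next(iterator)
         self.agent.train(trajectories)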
Example #4
def test_random_shooting_with_dynamic_step_driver(observation_space, action_space):
    """
    This test uses the environment wrapper as an adapter so that a driver from TF-Agents can be used
    to generate a rollout. This also serves as an example of how to construct "random shooting"
    rollouts from an environment model.

    The assertion in this test is that the selected action has the expected log_prob value,
    consistent with sampling from a uniform distribution. All this really checks is that the
    preceding code has run successfully.
    """

    network = LinearTransitionNetwork(observation_space)
    environment = KerasTransitionModel([network], observation_space, action_space)
    wrapped_environment = EnvironmentModel(
        environment,
        ConstantReward(observation_space, action_space, 0.0),
        ConstantFalseTermination(observation_space),
        create_uniform_initial_state_distribution(observation_space),
    )

    random_policy = RandomTFPolicy(
        wrapped_environment.time_step_spec(), action_space, emit_log_probability=True
    )

    transition_observer = _RecordLastLogProbTransitionObserver()

    driver = DynamicStepDriver(
        env=wrapped_environment,
        policy=random_policy,
        transition_observers=[transition_observer],
    )
    driver.run()

    last_log_prob = transition_observer.last_log_probability

    uniform_distribution = create_uniform_distribution_from_spec(action_space)
    action_log_prob = uniform_distribution.log_prob(transition_observer.action)
    expected = np.sum(action_log_prob.numpy().astype(np.float32))
    actual = np.sum(last_log_prob.numpy())

    np.testing.assert_array_almost_equal(actual, expected, decimal=4)
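The _RecordLastLogProbTransitionObserver used above comes from the library under test and is not shown in this listing. As a hedged sketch only, a transition observer of this kind is a callable that receives (time_step, policy_step, next_time_step) tuples from the driver and records the pieces it cares about; the class below is illustrative and may differ from the real implementation.

class _RecordLastLogProbTransitionObserverSketch:
    """Illustrative stand-in: remembers the last action and its log-probability."""

    def __init__(self):
        self.action = None
        self.last_log_probability = None

    def __call__(self, transition):
        # Transition observers receive (time_step, policy_step, next_time_step).
        _, policy_step, _ = transition
        self.action = policy_step.action
        # Available because the policy was built with emit_log_probability=True.
        self.last_log_probability = policy_step.info.log_probability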
def train_dyke_agent(train_env: TFPyEnvironment, eval_env: TFPyEnvironment,
                     agent: DqnAgent, train_steps: int, steps_per_episode: int,
                     eval_episodes: int) -> Dict[str, Any]:
    """
	Trains the DQN agent on the dyke maintenance task.

	:param train_env: The training environment.
	:param eval_env: The environment for testing agent performance.
	:param agent: The agent.
	:param train_steps: The number of training steps to use.
	:param steps_per_episode: The number of time steps that can be taken in a single dyke environment episode.
	:param eval_episodes: The number of episodes to use per evaluation.
	:return: A mapping to various metrics pertaining to the training's results.
	"""
    losses: np.ndarray = np.zeros(shape=(train_steps, steps_per_episode))
    evaluations: np.ndarray = np.zeros(shape=(train_steps, eval_episodes))
    train_metrics: Tuple = (AverageReturnMetric, )
    train_metric_results: np.ndarray = np.zeros(shape=(len(train_metrics),
                                                       train_steps,
                                                       steps_per_episode))
    for step in range(train_steps):
        # we uniformly sample experiences (single time steps) from one episode per train step
        print('STEP %d/%d' % (step + 1, train_steps))
        train_env.reset()
        rep_buf = _dyke_replay_buffer(train_env, agent, steps_per_episode)
        train_metric_inst: Tuple = tuple(
            [metric() for metric in train_metrics])  # instantiate the metrics
        obs: Tuple = (rep_buf.add_batch, ) + train_metric_inst
        _ = DynamicStepDriver(
            env=train_env,
            policy=agent.collect_policy,
            observers=obs,
            num_steps=steps_per_episode
        ).run(
        )  # experience a single episode using the agent's current configuration
        dataset: tf.data.Dataset = rep_buf.as_dataset(
            sample_batch_size=_REP_BUF_BATCH_SIZE,
            num_steps=_REP_BUF_NUM_STEPS)
        iterator = iter(dataset)
        for tr in range(steps_per_episode):
            trajectories, _ = next(iterator)
            losses[step, tr] = agent.train(experience=trajectories).loss
            for met in range(len(train_metrics)):
                train_metric_results[
                    met, step, tr] = train_metric_inst[met].result().numpy()
        evaluations[step, :] = _evaluate_dyke_agent(eval_env, agent,
                                                    eval_episodes)
    return {
        'loss': losses,
        'eval': evaluations,
        'train-metrics': train_metric_results
    }
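The _dyke_replay_buffer helper used above is not included in this example. A hedged sketch of what it presumably does, namely build a TFUniformReplayBuffer sized for one episode of experience, is given below; the exact capacity and any extra arguments in the real helper may differ.

from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer


def _dyke_replay_buffer(train_env: TFPyEnvironment,
                        agent: DqnAgent,
                        steps_per_episode: int) -> TFUniformReplayBuffer:
    # Buffer large enough to hold one episode of collected experience.
    return TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=steps_per_episode)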
Example #6
def main(_):
    # Environment
    env_name = "Breakout-v4"
    train_num_parallel_environments = 5
    max_steps_per_episode = 1000
    # Replay buffer
    replay_buffer_capacity = 50000
    init_replay_buffer = 500
    # Driver
    collect_steps_per_iteration = 1 * train_num_parallel_environments
    # Training
    train_batch_size = 32
    train_iterations = 100000
    train_summary_interval = 200
    train_checkpoint_interval = 200
    # Evaluation
    eval_num_parallel_environments = 5
    eval_summary_interval = 500
    eval_num_episodes = 20
    # File paths
    path = pathlib.Path(__file__)
    parent_dir = path.parent.resolve()
    folder_name = path.stem + time.strftime("_%Y%m%d_%H%M%S")
    train_checkpoint_dir = str(parent_dir / folder_name / "train_checkpoint")
    train_summary_dir = str(parent_dir / folder_name / "train_summary")
    eval_summary_dir = str(parent_dir / folder_name / "eval_summary")

    # Parallel training environment
    tf_env = TFPyEnvironment(
        ParallelPyEnvironment([
            lambda: suite_atari.load(
                env_name,
                env_wrappers=
                [lambda env: TimeLimit(env, duration=max_steps_per_episode)],
                gym_env_wrappers=[AtariPreprocessing, FrameStack4],
            )
        ] * train_num_parallel_environments))
    tf_env.seed([42] * tf_env.batch_size)
    tf_env.reset()

    # Parallel evaluation environment
    eval_tf_env = TFPyEnvironment(
        ParallelPyEnvironment([
            lambda: suite_atari.load(
                env_name,
                env_wrappers=
                [lambda env: TimeLimit(env, duration=max_steps_per_episode)],
                gym_env_wrappers=[AtariPreprocessing, FrameStack4],
            )
        ] * eval_num_parallel_environments))
    eval_tf_env.seed([42] * eval_tf_env.batch_size)
    eval_tf_env.reset()

    # Creating the Deep Q-Network
    preprocessing_layer = keras.layers.Lambda(
        lambda obs: tf.cast(obs, np.float32) / 255.)

    conv_layer_params = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
    fc_layer_params = [512]

    q_net = QNetwork(tf_env.observation_spec(),
                     tf_env.action_spec(),
                     preprocessing_layers=preprocessing_layer,
                     conv_layer_params=conv_layer_params,
                     fc_layer_params=fc_layer_params)

    # Creating the DQN Agent
    optimizer = keras.optimizers.RMSprop(lr=2.5e-4,
                                         rho=0.95,
                                         momentum=0.0,
                                         epsilon=0.00001,
                                         centered=True)

    epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1.0,  # initial ε
        decay_steps=2500000,
        end_learning_rate=0.01)  # final ε

    global_step = tf.compat.v1.train.get_or_create_global_step()

    agent = DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        target_update_period=200,
        td_errors_loss_fn=keras.losses.Huber(reduction="none"),
        gamma=0.99,  # discount factor
        train_step_counter=global_step,
        epsilon_greedy=lambda: epsilon_fn(global_step))
    agent.initialize()

    # Creating the Replay Buffer
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=replay_buffer_capacity)

    # Observer: Replay Buffer Observer
    replay_buffer_observer = replay_buffer.add_batch

    # Observer: Training Metrics
    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(batch_size=tf_env.batch_size),
    ]

    # Creating the Collect Driver
    collect_driver = DynamicStepDriver(tf_env,
                                       agent.collect_policy,
                                       observers=[replay_buffer_observer] +
                                       train_metrics,
                                       num_steps=collect_steps_per_iteration)

    # Initialize replay buffer
    initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                            tf_env.action_spec())
    init_driver = DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer_observer,
                   ShowProgress()],
        num_steps=init_replay_buffer)
    final_time_step, final_policy_state = init_driver.run()

    # Creating the Dataset
    dataset = replay_buffer.as_dataset(sample_batch_size=train_batch_size,
                                       num_steps=2,
                                       num_parallel_calls=3).prefetch(3)

    # Optimize by wrapping some of the code in a graph using TF function.
    collect_driver.run = function(collect_driver.run)
    agent.train = function(agent.train)

    print("\n\n++++++++++++++++++++++++++++++++++\n")

    # Create checkpoint
    train_checkpointer = Checkpointer(
        ckpt_dir=train_checkpoint_dir,
        max_to_keep=1,
        agent=agent,
        # replay_buffer=replay_buffer,
        global_step=global_step,
        # metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics')
    )

    # Restore checkpoint
    # train_checkpointer.initialize_or_restore()

    # Summary writers and metrics
    train_summary_writer = tf.summary.create_file_writer(train_summary_dir)
    eval_summary_writer = tf.summary.create_file_writer(eval_summary_dir)
    eval_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(batch_size=eval_tf_env.batch_size,
                                       buffer_size=eval_num_episodes),
        tf_metrics.AverageEpisodeLengthMetric(
            batch_size=eval_tf_env.batch_size, buffer_size=eval_num_episodes)
    ]

    # Create evaluate callback function
    eval_callback = evaluate(eval_metrics=eval_metrics,
                             eval_tf_env=eval_tf_env,
                             eval_policy=agent.policy,
                             eval_num_episodes=eval_num_episodes,
                             train_step=global_step,
                             eval_summary_writer=eval_summary_writer)

    # Train agent
    train_agent(tf_env=tf_env,
                train_iterations=train_iterations,
                global_step=global_step,
                agent=agent,
                dataset=dataset,
                collect_driver=collect_driver,
                train_metrics=train_metrics,
                train_checkpointer=train_checkpointer,
                train_checkpoint_interval=train_checkpoint_interval,
                train_summary_writer=train_summary_writer,
                train_summary_interval=train_summary_interval,
                eval_summary_interval=eval_summary_interval,
                eval_callback=eval_callback)

    print("\n\n++++++++++ END OF TF_AGENTS RL TRAINING ++++++++++\n\n")
Example #7
 average_return = tf_metrics.AverageReturnMetric(
     prefix='Train',
     buffer_size=num_eval_episodes,
     batch_size=tf_env.batch_size)
 train_metrics = [
     tf_metrics.NumberOfEpisodes(),
     env_steps,
     average_return,
     tf_metrics.AverageEpisodeLengthMetric(prefix='Train',
                                           buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
 ]
 initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                         tf_env.action_spec())
 init_driver = DynamicStepDriver(tf_env,
                                 initial_collect_policy,
                                 observers=[replay_buffer.add_batch],
                                 num_steps=nRandTotalSteps)
 collect_driver = DynamicStepDriver(tf_env,
                                    agent.collect_policy,
                                    observers=[replay_buffer_observer] +
                                    train_metrics,
                                    num_steps=update_period)
 collect_driver.run = function(collect_driver.run)
 init_driver.run = function(init_driver.run)
 agent.train = function(agent.train)
 final_time_step, final_policy_state = init_driver.run()
 time_acc = 0
 env_steps_before = env_steps.result().numpy()
 dataset = replay_buffer.as_dataset(sample_batch_size=64,
                                    num_steps=2,
                                    num_parallel_calls=3).prefetch(3)
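Example #7 breaks off right after the dataset is built. The time_acc and env_steps_before variables above suggest throughput logging, so a hedged sketch of the loop that usually follows is shown here; num_iterations, the logging interval and the loss handling are assumptions rather than part of the original code.

 import time

 iterator = iter(dataset)
 time_step = None
 policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
 for iteration in range(num_iterations):  # num_iterations is assumed
     start = time.time()
     time_step, policy_state = collect_driver.run(time_step, policy_state)
     time_acc += time.time() - start
     trajectories, _ = next(iterator)
     loss_info = agent.train(trajectories)
     if iteration % 100 == 0:
         steps_per_sec = (env_steps.result().numpy() - env_steps_before) / time_acc
         print('iter {}: loss={:.4f}, env steps/sec={:.1f}'.format(
             iteration, loss_info.loss.numpy(), steps_per_sec))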
Example #8
def breakout_v4(seed=42):
    env = suite_gym.load("Breakout-v4")
    env.seed(seed)
    env.reset()

    repeating_env = ActionRepeat(env, times=4)
    for name in dir(tf_agents.environments.wrappers):
        obj = getattr(tf_agents.environments.wrappers, name)
        if hasattr(obj, "__base__") and issubclass(
                obj, tf_agents.environments.wrappers.PyEnvironmentBaseWrapper):
            print("{:27s} {}".format(name, obj.__doc__.split("\n")[0]))

    limited_repeating_env = suite_gym.load(
        "Breakout-v4",
        gym_env_wrappers=[partial(TimeLimit, max_episode_steps=10000)],
        env_wrappers=[partial(ActionRepeat, times=4)],
    )

    max_episode_steps = 27000  # <=> 108k ALE frames since 1 step = 4 frames
    environment_name = "BreakoutNoFrameskip-v4"

    env = suite_atari.load(
        environment_name,
        max_episode_steps=max_episode_steps,
        gym_env_wrappers=[AtariPreprocessing, FrameStack4],
    )

    env.seed(42)
    env.reset()
    time_step = env.step(np.array(1))  # FIRE
    for _ in range(4):
        time_step = env.step(np.array(3))  # LEFT

    def plot_observation(obs):
        # Since there are only 3 color channels, you cannot display 4 frames
        # with one primary color per frame. So this code computes the delta between
        # the current frame and the mean of the other frames, and it adds this delta
        # to the red and blue channels to get a pink color for the current frame.
        obs = obs.astype(np.float32)
        img_ = obs[..., :3]
        current_frame_delta = np.maximum(
            obs[..., 3] - obs[..., :3].mean(axis=-1), 0.0)
        img_[..., 0] += current_frame_delta
        img_[..., 2] += current_frame_delta
        img_ = np.clip(img_ / 150, 0, 1)
        plt.imshow(img_)
        plt.axis("off")

    plt.figure(figsize=(6, 6))
    plot_observation(time_step.observation)
    plt.tight_layout()
    plt.savefig("./images/preprocessed_breakout_plot.png",
                format="png",
                dpi=300)
    plt.show()

    tf_env = TFPyEnvironment(env)

    preprocessing_layer = keras.layers.Lambda(
        lambda obs: tf.cast(obs, np.float32) / 255.0)
    conv_layer_params = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
    fc_layer_params = [512]

    q_net = QNetwork(
        tf_env.observation_spec(),
        tf_env.action_spec(),
        preprocessing_layers=preprocessing_layer,
        conv_layer_params=conv_layer_params,
        fc_layer_params=fc_layer_params,
    )

    # see TF-agents issue #113
    # optimizer = keras.optimizers.RMSprop(lr=2.5e-4, rho=0.95, momentum=0.0,
    #                                     epsilon=0.00001, centered=True)

    train_step = tf.Variable(0)
    update_period = 4  # run a training step every 4 collect steps
    optimizer = tf.compat.v1.train.RMSPropOptimizer(learning_rate=2.5e-4,
                                                    decay=0.95,
                                                    momentum=0.0,
                                                    epsilon=0.00001,
                                                    centered=True)
    epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=1.0,  # initial ε
        decay_steps=250000 // update_period,  # <=> 1,000,000 ALE frames
        end_learning_rate=0.01,
    )  # final ε
    agent = DqnAgent(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        q_network=q_net,
        optimizer=optimizer,
        target_update_period=2000,  # <=> 32,000 ALE frames
        td_errors_loss_fn=keras.losses.Huber(reduction="none"),
        gamma=0.99,  # discount factor
        train_step_counter=train_step,
        epsilon_greedy=lambda: epsilon_fn(train_step),
    )
    agent.initialize()

    from tf_agents.replay_buffers import tf_uniform_replay_buffer

    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=tf_env.batch_size,
        max_length=1000000)

    replay_buffer_observer = replay_buffer.add_batch

    class ShowProgress:
        def __init__(self, total):
            self.counter = 0
            self.total = total

        def __call__(self, trajectory):
            if not trajectory.is_boundary():
                self.counter += 1
            if self.counter % 100 == 0:
                print("\r{}/{}".format(self.counter, self.total), end="")

    from tf_agents.metrics import tf_metrics

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
    ]

    from tf_agents.eval.metric_utils import log_metrics
    import logging

    logging.getLogger().setLevel(logging.INFO)
    log_metrics(train_metrics)

    from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

    collect_driver = DynamicStepDriver(
        tf_env,
        agent.collect_policy,
        observers=[replay_buffer_observer] + train_metrics,
        num_steps=update_period,
    )  # collect 4 steps for each training iteration

    from tf_agents.policies.random_tf_policy import RandomTFPolicy

    initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                            tf_env.action_spec())
    init_driver = DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=[replay_buffer.add_batch,
                   ShowProgress(20000)],
        num_steps=20000,
    )  # <=> 80,000 ALE frames
    final_time_step, final_policy_state = init_driver.run()
Example #9
Additionally, the data encountered by the driver at each step is saved 
in a named tuple called Trajectory and broadcast to a set of observers 
such as replay buffers and metrics. This data includes the observation 
from the environment, the action recommended by the policy, the reward 
obtained, the type of the current and the next step, etc.

We currently have two TensorFlow drivers: DynamicStepDriver, which terminates after a
given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates
after a given number of episodes. We want to collect experience for 4 steps for each
training iteration (as was done in the 2015 DQN paper).
'''

collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=config.UPDATE_PERIOD
)  # collect 4 steps for each training iteration
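As the prose above notes, every Trajectory the driver produces is broadcast to each observer, so a custom observer is simply a callable that accepts one Trajectory. The minimal sketch below is not part of the original text; it is written in the spirit of the ShowProgress helpers used elsewhere in this listing.

# Illustrative only: counts the non-boundary steps seen by the driver.
class StepCounter:
    def __init__(self):
        self.count = 0

    def __call__(self, trajectory):
        # Boundary trajectories mark episode restarts and carry no environment step.
        if not trajectory.is_boundary():
            self.count += 1

# It would be registered next to the replay buffer observer, e.g.:
# observers=[replay_buffer_observer, StepCounter()] + train_metrics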
'''
We could now run the collect_driver by calling its run() method, but it is best
to warm up the replay buffer with experiences collected using a purely random policy.
For this, we can use the RandomTFPolicy class and create a second driver that will run
the policy for 20000 steps (which is equivalent to 80000 simulator frames, as was done in 
the 2015 DQN paper). We use ShowProgress to display the progress:
'''
initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                        tf_env.action_spec())
init_driver = DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer_observer, ShowProgress(20000)],
    num_steps=20000)  # 20,000 steps <=> 80,000 simulator frames
Example #10
)
"""
print("After  replay_buffer")
replay_buffer_observer = replay_buffer.add_batch

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]
print("Before  DynamicStepDriver")

collect_driver = DynamicStepDriver(tf_env,
                                   agent.collect_policy,
                                   observers=[replay_buffer_observer] +
                                   train_metrics,
                                   num_steps=update_period)


class ShowProgress:
    def __init__(self, total):
        self.counter = 0
        self.total = total

    def __call__(self, trajectory):
        if not trajectory.is_boundary():
            self.counter += 1
        if self.counter % 100 == 0:
            print("\r{}/{}".format(self.counter, self.total), end="")
Example #11
        frames.append(tf_env.pyenv.envs[0].render(mode="rgb_array"))

    prev_lives = tf_env.pyenv.envs[0].ale.lives()

    def reset_and_fire_on_life_lost(trajectory):
        global prev_lives
        lives = tf_env.pyenv.envs[0].ale.lives()
        if prev_lives != lives:
            tf_env.reset()
            tf_env.step(1)
            prev_lives = lives

    watch_driver = DynamicStepDriver(tf_env,
                                     saved_policy,
                                     observers=[
                                         save_frames,
                                         reset_and_fire_on_life_lost,
                                         ShowProgress(1000)
                                     ],
                                     num_steps=1000)

    tf_env.reset()  # reset the env
    time_step = tf_env.step(1)  # fire the ball to begin playing
    policy_state = saved_policy.get_initial_state()  # empty state ()
    final_time_step, final_policy_state = watch_driver.run(
        time_step, policy_state)

    # render a window that shows the agent playing (works in a Jupyter notebook)
    renderingUtils = RenderingUtils(frames)

    renderingUtils.plot_animation()
Example #12
    tf_metrics.AverageEpisodeLengthMetric()
]

training_metrics_2 = [
    tf_metrics.MaxReturnMetric(),
    tf_metrics.MinReturnMetric()
]

logging.getLogger().setLevel(logging.INFO)

# %% Driver
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=replay_buffer_observer + training_metrics + training_metrics_2,
    num_steps=update_period,
)

from tf_agents.policies.random_tf_policy import RandomTFPolicy

# initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())
# init_driver = DynamicStepDriver(
#     tf_env,
#     initial_collect_policy,
#     observers=replay_buffer_observer + training_metrics + training_metrics_2,
#     num_steps=update_period,
# )
# final_time_step, final_policy_state = init_driver.run()

# %% Dataset
Example #13
class TrainDDQN():
    """Wrapper for DDQN training, validation, saving etc."""
    def __init__(self,
                 episodes: int,
                 warmup_steps: int,
                 learning_rate: float,
                 gamma: float,
                 min_epsilon: float,
                 decay_episodes: int,
                 model_path: str = None,
                 log_dir: str = None,
                 batch_size: int = 64,
                 memory_length: int = None,
                 collect_steps_per_episode: int = 1,
                 val_every: int = None,
                 target_update_period: int = 1,
                 target_update_tau: float = 1.0,
                 progressbar: bool = True,
                 n_step_update: int = 1,
                 gradient_clipping: float = 1.0,
                 collect_every: int = 1) -> None:
        """
        Wrapper to make training easier.
        Code is partly based on https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial

        :param episodes: Number of training episodes
        :type  episodes: int
        :param warmup_steps: Number of steps to fill the Replay Buffer with random state-action pairs before training starts
        :type  warmup_steps: int
        :param learning_rate: Learning Rate for the Adam Optimizer
        :type  learning_rate: float
        :param gamma: Discount factor for the Q-values
        :type  gamma: float
        :param min_epsilon: Lowest and final value for epsilon
        :type  min_epsilon: float
        :param decay_episodes: Number of episodes over which epsilon decays from 1 to `min_epsilon`
        :type  decay_episodes: int
        :param model_path: Location to save the trained model
        :type  model_path: str
        :param log_dir: Location to save the logs, useful for TensorBoard
        :type  log_dir: str
        :param batch_size: Number of samples in minibatch to train on each step
        :type  batch_size: int
        :param memory_length: Maximum size of the Replay Buffer
        :type  memory_length: int
        :param collect_steps_per_episode: Amount of data to collect for the Replay Buffer each episode
        :type  collect_steps_per_episode: int
        :param collect_every: Step interval to collect data during training
        :type  collect_every: int
        :param val_every: Validate the model every X episodes using the `collect_metrics()` function
        :type  val_every: int
        :param target_update_period: Update the target Q-network every X episodes
        :type  target_update_period: int
        :param target_update_tau: Soft-update coefficient for the target network (1.0 corresponds to a hard update)
        :type  target_update_tau: float
        :param progressbar: Enable or disable the progressbar for collecting data and training
        :type  progressbar: bool

        :return: None
        :rtype: NoneType
        """
        self.episodes = episodes  # Total episodes
        self.warmup_steps = warmup_steps  # Amount of warmup steps before training
        self.batch_size = batch_size  # Minibatch size sampled from the Replay Memory each training step
        self.collect_steps_per_episode = collect_steps_per_episode  # Amount of steps to collect data each episode
        self.collect_every = collect_every  # Step interval to collect data during training
        self.learning_rate = learning_rate  # Learning Rate
        self.gamma = gamma  # Discount factor
        self.min_epsilon = min_epsilon  # Minimal chance of choosing random action
        self.decay_episodes = decay_episodes  # Number of episodes to decay from 1.0 to `EPSILON`
        self.target_update_period = target_update_period  # Period for soft updates
        self.target_update_tau = target_update_tau
        self.progressbar = progressbar  # Enable or disable the progressbar for collecting data and training
        self.n_step_update = n_step_update
        self.gradient_clipping = gradient_clipping  # Clip gradients during training
        self.compiled = False
        NOW = datetime.now().strftime("%Y%m%d_%H%M%S")

        if memory_length is not None:
            self.memory_length = memory_length  # Max Replay Memory length
        else:
            self.memory_length = warmup_steps

        if val_every is not None:
            self.val_every = val_every  # Validate the policy every `val_every` episodes
        else:
            self.val_every = self.episodes // min(
                50, self.episodes
            )  # Can't validate the model 50 times if self.episodes < 50

        if model_path is not None:
            self.model_path = model_path
        else:
            self.model_path = "./models/" + NOW + ".pkl"

        if log_dir is None:
            log_dir = "./logs/" + NOW
        self.writer = tf.summary.create_file_writer(log_dir)

    def compile_model(self,
                      X_train,
                      y_train,
                      layers: list = [],
                      imb_ratio: float = None,
                      loss_fn=common.element_wise_squared_loss) -> None:
        """Initializes the neural networks, DDQN-agent, collect policies and replay buffer.

        :param X_train: Training data for the model.
        :type  X_train: np.ndarray
        :param y_train: Labels corresponding to `X_train`.  1 for the positive class, 0 for the negative class.
        :type  y_train: np.ndarray
        :param layers: List of layers to feed into the TF-agents custom Sequential(!) layer.
        :type  layers: list
        :param imb_ratio: The imbalance ratio of the data.
        :type  imb_ratio: float
        :param loss_fn: Callable loss function
        :type  loss_fn: tf.compat.v1.losses

        :return: None
        :rtype: NoneType
        """
        if imb_ratio is None:
            imb_ratio = imbalance_ratio(y_train)

        self.train_env = TFPyEnvironment(
            ClassifierEnv(X_train, y_train, imb_ratio))
        self.global_episode = tf.Variable(
            0, name="global_episode", dtype=np.int64,
            trainable=False)  # Global train episode counter

        # Custom epsilon decay: https://github.com/tensorflow/agents/issues/339
        epsilon_decay = tf.compat.v1.train.polynomial_decay(
            1.0,
            self.global_episode,
            self.decay_episodes,
            end_learning_rate=self.min_epsilon)

        self.q_net = Sequential(layers, self.train_env.observation_spec())

        self.agent = DdqnAgent(
            self.train_env.time_step_spec(),
            self.train_env.action_spec(),
            q_network=self.q_net,
            optimizer=Adam(learning_rate=self.learning_rate),
            td_errors_loss_fn=loss_fn,
            train_step_counter=self.global_episode,
            target_update_period=self.target_update_period,
            target_update_tau=self.target_update_tau,
            gamma=self.gamma,
            epsilon_greedy=epsilon_decay,
            n_step_update=self.n_step_update,
            gradient_clipping=self.gradient_clipping)
        self.agent.initialize()

        self.random_policy = RandomTFPolicy(self.train_env.time_step_spec(),
                                            self.train_env.action_spec())
        self.replay_buffer = TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.train_env.batch_size,
            max_length=self.memory_length)

        self.warmup_driver = DynamicStepDriver(
            self.train_env,
            self.random_policy,
            observers=[self.replay_buffer.add_batch],
            num_steps=self.warmup_steps)  # Uses a random policy

        self.collect_driver = DynamicStepDriver(
            self.train_env,
            self.agent.collect_policy,
            observers=[self.replay_buffer.add_batch],
            num_steps=self.collect_steps_per_episode
        )  # Uses the epsilon-greedy policy of the agent

        self.agent.train = common.function(self.agent.train)  # Optimization
        self.warmup_driver.run = common.function(self.warmup_driver.run)
        self.collect_driver.run = common.function(self.collect_driver.run)

        self.compiled = True

    def train(self, *args) -> None:
        """Starts the training of the model. Includes warmup period, metrics collection and model saving.

        :param *args: All arguments will be passed to `collect_metrics()`.
            This can be useful to pass callables, testing environments or validation data.
            Overwrite the TrainDDQN.collect_metrics() function to use your own *args.
        :type  *args: Any

        :return: None
        :rtype: NoneType, last step is saving the model as a side-effect
        """
        assert self.compiled, "Model must be compiled with model.compile_model(X_train, y_train, layers) before training."

        # Warmup period, fill memory with random actions
        if self.progressbar:
            print(
                f"\033[92mCollecting data for {self.warmup_steps:_} steps... This might take a few minutes...\033[0m"
            )

        self.warmup_driver.run(
            time_step=None,
            policy_state=self.random_policy.get_initial_state(
                self.train_env.batch_size))

        if self.progressbar:
            print(
                f"\033[92m{self.replay_buffer.num_frames():_} frames collected!\033[0m"
            )

        dataset = self.replay_buffer.as_dataset(
            sample_batch_size=self.batch_size,
            num_steps=self.n_step_update + 1,
            num_parallel_calls=data.experimental.AUTOTUNE).prefetch(
                data.experimental.AUTOTUNE)
        iterator = iter(dataset)

        def _train():
            experiences, _ = next(iterator)
            return self.agent.train(experiences).loss

        _train = common.function(_train)  # Optimization

        ts = None
        policy_state = self.agent.collect_policy.get_initial_state(
            self.train_env.batch_size)
        self.collect_metrics(*args)  # Initial collection for step 0
        pbar = tqdm(total=self.episodes,
                    disable=(not self.progressbar),
                    desc="Training the DDQN")  # TQDM progressbar
        for _ in range(self.episodes):
            if not self.global_episode % self.collect_every:
                # Collect a few steps using collect_policy and save to `replay_buffer`
                if self.collect_steps_per_episode != 0:
                    ts, policy_state = self.collect_driver.run(
                        time_step=ts, policy_state=policy_state)
                pbar.update(
                    self.collect_every
                )  # More stable TQDM updates, collecting could take some time

            # Sample a batch of data from `replay_buffer` and update the agent's network
            train_loss = _train()

            if not self.global_episode % self.val_every:
                with self.writer.as_default():
                    tf.summary.scalar("train_loss",
                                      train_loss,
                                      step=self.global_episode)

                self.collect_metrics(*args)
        pbar.close()

    def collect_metrics(self,
                        X_val: np.ndarray,
                        y_val: np.ndarray,
                        save_best: str = None):
        """Collects metrics using the trained Q-network.

        :param X_val: Features of validation data, same shape as X_train
        :type  X_val: np.ndarray
        :param y_val: Labels of validation data, same shape as y_train
        :type  y_val: np.ndarray
        :param save_best: Saving the best model of all validation runs based on given metric:
            Choose one of: {Gmean, F1, Precision, Recall, TP, TN, FP, FN}
            This improves stability since the model at the last episode is not guaranteed to be the best model.
        :type  save_best: str
        """
        y_pred = network_predictions(self.agent._target_q_network, X_val)
        stats = classification_metrics(y_val, y_pred)
        avgQ = np.mean(decision_function(self.agent._target_q_network,
                                         X_val))  # Mean of the highest Q-value for each x in X_val

        if save_best is not None:
            if not hasattr(self, "best_score"):  # If no best model yet
                self.best_score = 0.0

            if stats.get(save_best) >= self.best_score:  # Overwrite best model
                self.save_network(
                )  # Saving directly to avoid shallow copy without trained weights
                self.best_score = stats.get(save_best)

        with self.writer.as_default():
            tf.summary.scalar(
                "AverageQ", avgQ,
                step=self.global_episode)  # Average Q-value for this epoch
            for k, v in stats.items():
                tf.summary.scalar(k, v, step=self.global_episode)

    def evaluate(self, X_test, y_test, X_train=None, y_train=None):
        """
        Final evaluation of trained Q-network with X_test and y_test.
        Optional PR and ROC curve comparison to X_train, y_train to ensure no overfitting is taking place.

        :param X_test: Features of test data, same shape as X_train
        :type  X_test: np.ndarray
        :param y_test: Labels of test data, same shape as y_train
        :type  y_test: np.ndarray
        :param X_train: Features of train data
        :type  X_train: np.ndarray
        :param y_train: Labels of train data
        :type  y_train: np.ndarray
        """
        if hasattr(self, "best_score"):
            print(f"\033[92mBest score: {self.best_score:6f}!\033[0m")
            network = self.load_network(
                self.model_path)  # Load best saved model
        else:
            network = self.agent._target_q_network  # Load latest target model

        if (X_train is not None) and (y_train is not None):
            plot_pr_curve(network, X_test, y_test, X_train, y_train)
            plot_roc_curve(network, X_test, y_test, X_train, y_train)

        y_pred = network_predictions(network, X_test)
        return classification_metrics(y_test, y_pred)

    def save_network(self):
        """Saves Q-network as pickle to `model_path`."""
        with open(self.model_path, "wb") as f:  # Save Q-network as pickle
            pickle.dump(self.agent._target_q_network, f)

    @staticmethod
    def load_network(fp: str):
        """Static method to load Q-network pickle from given filepath.

        :param fp: Filepath to the saved pickle of the network
        :type  fp: str

        :returns: The network-object loaded from a pickle file.
        :rtype: tensorflow.keras.models.Model
        """
        with open(fp, "rb") as f:  # Load the Q-network
            network = pickle.load(f)
        return network
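To round off this example, here is a hedged usage sketch of the TrainDDQN wrapper defined above. The dataset variables (X_train, y_train, X_val, y_val, X_test, y_test), the layer sizes, the hyperparameters and the "F1" selection metric are all illustrative assumptions and not part of the original code.

from tensorflow.keras.layers import Dense

# Illustrative usage only; the data arrays are assumed to belong to a binary task.
trainer = TrainDDQN(episodes=2_000, warmup_steps=10_000, learning_rate=1e-3,
                    gamma=0.99, min_epsilon=0.05, decay_episodes=500,
                    batch_size=64, val_every=100)
trainer.compile_model(X_train, y_train,
                      layers=[Dense(64, activation="relu"),
                              Dense(64, activation="relu"),
                              Dense(2, activation=None)])  # one Q-value per action
trainer.train(X_val, y_val, "F1")  # *args are forwarded to collect_metrics()
results = trainer.evaluate(X_test, y_test)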
Example #14
#An observer to write the trajectories to the buffer
replay_buffer_observer = replay_buffer.add_batch

#Training metrics
train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]
logging.getLogger().setLevel(logging.INFO)

#Create the driver
collect_driver = DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=collect_steps_per_iteration
)  # collect # steps for each training iteration

#Collect some initial experience with random policy
initial_collect_policy = MinYardScenarioPolicy(
    train_env.time_step_spec(), train_env.action_spec(
    ))  #RandomTFPolicy(train_env.time_step_spec(),train_env.action_spec())
#initial_collect_policy = RandomTFPolicy(train_env.time_step_spec(),train_env.action_spec())

init_driver = DynamicStepDriver(
    train_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch,
               ShowProgress(pretrain_steps)],
    num_steps=pretrain_steps)
        max_to_keep=1,
        agent=ppo_agent,
        policy=ppo_agent.policy,
        replay_buffer=replay_buffer,
        global_step=global_step)
    # Initialize the checkpointer
    checkpointer.initialize_or_restore()
    # Update the global step
    global_step = tf.compat.v1.train.get_global_step()

    # Create policy saver
    policy_saver = PolicySaver(ppo_agent.policy)

    # Create training driver
    train_driver = DynamicStepDriver(train_env,
                                     ppo_agent.collect_policy,
                                     observers=[replay_buffer.add_batch],
                                     num_steps=collect_steps_per_iter)
    # Wrap run function in TF graph
    train_driver.run = common.function(train_driver.run)
    print('Collecting initial data...')
    train_driver.run()

    # Reset the training step
    ppo_agent.train_step_counter.assign(0)

    # Evaluate the policy once before training
    print('Initial evaluation...')
    reward = compute_total_reward(eval_env, ppo_agent.policy)
    rewards = [reward]
    print('Initial total reward: {0}'.format(reward))
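The PPO snippet ends after this initial evaluation. A hedged sketch of the training loop it appears to be building toward follows; num_iterations and eval_every are assumed hyperparameters, and the real loop may differ.

    # Sketch only: alternate collection, training and periodic evaluation.
    for it in range(num_iterations):
        train_driver.run()                        # collect on-policy experience
        experience = replay_buffer.gather_all()   # PPO trains on whole rollouts
        loss_info = ppo_agent.train(experience)
        replay_buffer.clear()                     # discard stale on-policy data
        if (it + 1) % eval_every == 0:
            reward = compute_total_reward(eval_env, ppo_agent.policy)
            rewards.append(reward)
            print('Iteration {0}: total reward {1}, loss {2}'.format(
                it + 1, reward, loss_info.loss.numpy()))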
Example #17
class ShowProgress:
    def __init__(self, total):
        self.counter = 0
        self.total = total

    def __call__(self, trajectory):
        if not trajectory.is_boundary():
            self.counter += 1
        if self.counter % 1000 == 0:
            print("\r{}/{}".format(self.counter, self.total), end="")


init_driver = DynamicStepDriver(
    env=train_env,
    policy=initial_collect_policy,
    observers=[ replay_buffer.add_batch, ShowProgress(initial_collect_steps) ],
    num_steps=initial_collect_steps
)

# Collecting experiences.
print('Collecting random initial experiences...')
init_driver.run()

# 6. Training the agent.
dataset = replay_buffer.as_dataset(sample_batch_size=batch_size, num_steps=n_steps+1, num_parallel_calls=3).prefetch(3)

all_train_loss = []
all_metrics = []

collect_driver = DynamicStepDriver(
    env=train_env,
Example #18
saved_policy.info_spec = ()
saved_policy.emit_log_probability = True

saved_policy = EpsilonGreedyPolicy(saved_policy, epsilon=0.005)

#saved_policy = tf_agents.policies.gaussian_policy.GaussianPolicy(saved_policy)

#agent = tf.saved_model.load('policy_100')
#agent = tf.keras.models.load_model('policy_100')
#agent = tf.keras.models.load_model('policy_100')
#policy = tf.saved_model.load('')
#print(type(agent))
tf_env.pyenv.envs[0].step(np.array(1, dtype=np.int32))
watch_driver = DynamicStepDriver(
    tf_env,
    saved_policy,
    observers=[save_frames, reset_and_fire_on_life_lost, ShowProgress(max_episode_steps), debug_trajectory],
    num_steps=max_episode_steps)
#tf_env.pyenv.envs[0].step(np.array(1, dtype=np.int32))
final_time_step, final_policy_state = watch_driver.run()
#obs, reward, done, info = tf_env.pyenv.envs[0].step(np.array(1, dtype=np.int32))
print('rewards earned', rewards_per_game)

def update_scene(num, frames, patch):
    patch.set_data(frames[num])
    return patch,

def plot_animation(frames, repeat=False, interval=40):
    fig = plt.figure()
    patch = plt.imshow(frames[0])
    plt.axis('off')
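    # The listing cuts off here; a plausible completion (an assumption, based on the
    # update_scene callback defined above and the standard matplotlib animation
    # pattern; it requires: from matplotlib import animation):
    anim = animation.FuncAnimation(
        fig, update_scene, fargs=(frames, patch),
        frames=len(frames), repeat=repeat, interval=interval)
    plt.close()
    return anim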
Example #19



train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]

logging.getLogger().setLevel(logging.INFO)

collect_driver = DynamicStepDriver(
    tf_env,
    agent.collect_policy,
    observers=[replay_buffer_observer] + train_metrics,
    num_steps=update_period) # collect 4 steps for each training iteration


# Warm up the buffer first with 20,000 steps, or 80,000 simulator frames in this case
# Use our custom show progress class as an observer
initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec())

init_driver = DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch, ShowProgress(20000)],
    num_steps=20000) # <=> 80,000 ALE frames

final_time_step, final_policy_state = init_driver.run()
Example #20
    # Create game environments: training and evaluation
    train_env = TFPyEnvironment(NineMensMorris(agent.policy, discount=DISCOUNT))
    eval_env = TFPyEnvironment(NineMensMorris(agent.policy, discount=DISCOUNT))

    # Random policy for data collection
    random_policy = RandomTFPolicy(time_step_spec=train_env.time_step_spec(),
                                   action_spec=train_env.action_spec())

    # Create replay buffer for data collection
    replay_buffer = TFUniformReplayBuffer(data_spec=agent.collect_data_spec,
                                          batch_size=train_env.batch_size,
                                          max_length=BUFFER_LENGTH)

    # Create driver for the agent
    driver = DynamicStepDriver(env=train_env,
                               policy=agent.collect_policy,
                               observers=[replay_buffer.add_batch],
                               num_steps=STEPS_PER_ITER)
    # Wrap the run function in a TF graph
    driver.run = common.function(driver.run)
    # Create driver for the random policy
    random_driver = DynamicStepDriver(env=train_env,
                                      policy=random_policy,
                                      observers=[replay_buffer.add_batch],
                                      num_steps=STEPS_PER_ITER)
    # Wrap the run function in a TF graph
    random_driver.run = common.function(random_driver.run)

    # Create a checkpointer
    checkpointer = common.Checkpointer(ckpt_dir=os.path.relpath('checkpoint'),
                                       max_to_keep=1,
                                       agent=agent,
Example #21
train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric(),
]

logging.getLogger().setLevel(logging.INFO)

## ------------------------------------------------------------------------------
## ------------------------------------------------------------------------------
## ------------------------------------------------------------------------------

collect_driver = DynamicStepDriver(
    tf_env,  # Env to play with
    agent.collect_policy,  # Collect policy of the agent
    observers=[replay_buffer_observer] +
    train_metrics,  # pass to all observers
    num_steps=1)
# Speed up as tensorflow function
collect_driver.run = function(collect_driver.run)

initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                        tf_env.action_spec())
init_driver = DynamicStepDriver(
    tf_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch,
               ShowProgress(init_replay_buffer)],
    num_steps=init_replay_buffer)
final_time_step, final_policy_state = init_driver.run()
Example #22
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return

        return total_return / num_episodes

    # Replay buffer
    replay_buffer = TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=REPLAY_BUFFER_MAX
    )

    driver = DynamicStepDriver(
        train_env,
        agent.collect_policy,
        observers=[replay_buffer.add_batch],
        num_steps=1
    )

    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=BATCH_SIZE,
        num_steps=2).prefetch(3)
    iterator = iter(dataset)

    agent.train_step_counter.assign(0)
    avg_return = compute_avg_return(eval_env, agent.policy)
    returns = [avg_return]

    # Pre-populate replay buffer
    for _ in range(PRETRAIN_LEN):