Example #1
class ExperienceReplay(object):
    def __init__(self, agent, environment, batch_size):
        self._replay_buffer = TFUniformReplayBuffer(
            data_spec=agent.collect_data_spec,
            batch_size=environment.batch_size,
            max_length=100000)

        self._random_policy = RandomTFPolicy(environment.time_step_spec(),
                                             environment.action_spec())

        self._fill_buffer(environment, self._random_policy, steps=100)

        self.dataset = self._replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2,
            single_deterministic_pass=False).prefetch(3)

        self.iterator = iter(self.dataset)

    def _fill_buffer(self, environment, policy, steps):
        for _ in range(steps):
            self.timestamp_data(environment, policy)

    def timestamp_data(self, environment, policy):
        time_step = environment.current_time_step()
        action_step = policy.action(time_step)
        next_time_step = environment.step(action_step.action)
        timestamp_trajectory = trajectory.from_transition(
            time_step, action_step, next_time_step)

        self._replay_buffer.add_batch(timestamp_trajectory)
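
A minimal usage sketch, not part of the original example: it assumes a CartPole environment and a standard DQN agent, and that TFUniformReplayBuffer, RandomTFPolicy and trajectory are already imported from tf_agents as the class above requires.

import tensorflow as tf
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks.q_network import QNetwork
from tf_agents.utils import common

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
q_net = QNetwork(env.observation_spec(), env.action_spec(), fc_layer_params=(64,))
agent = DqnAgent(env.time_step_spec(), env.action_spec(), q_network=q_net,
                 optimizer=tf.keras.optimizers.Adam(1e-3),
                 td_errors_loss_fn=common.element_wise_squared_loss)
agent.initialize()

replay = ExperienceReplay(agent, env, batch_size=64)  # pre-fills the buffer with 100 random steps
experience, _ = next(replay.iterator)                 # trajectories of shape (64, 2, ...)
loss = agent.train(experience).loss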
Example #2
def train(agent: DdqnAgent, train_env: TFEnvironment,
          replay_buffer: TFUniformReplayBuffer, num_episodes: int,
          replay_buffer_batch_size: int, save_path: str, randomize_step: int,
          validate_step: int):

    train_dataset = replay_buffer.as_dataset(
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        num_steps=REPLAY_BUFFER_NUM_STEPS,
        sample_batch_size=replay_buffer_batch_size)

    train_iterator = iter(train_dataset)

    # Savers
    checkpointer, saver = create_savers(save_path, agent, replay_buffer)
    policy_path = os.path.join(save_path, "saved", "policy")

    checkpointer.initialize_or_restore()

    global_step = 0
    episode = 0
    tf.print("Aguardando conexão do cliente")

    random_policy = create_random_policy(train_env)
    collect_episode_data(train_env,
                         random_policy,
                         replay_buffer,
                         repeats=3,
                         phase="Random")

    for episode in range(num_episodes):
        tf.print(f"Episódio {episode} iniciado")
        # Colect random data
        episode_info = collect_episode_data(train_env,
                                            agent.collect_policy,
                                            replay_buffer,
                                            repeats=1)

        experience, unused_info = next(train_iterator)
        train_loss = agent.train(experience).loss
        global_step = agent.train_step_counter.numpy()

        tf.print(
            f"Episode {episode} finished with loss {train_loss} and reward {episode_info['reward']}"
        )
        collect_episode_data(train_env,
                             agent.policy,
                             replay_buffer,
                             phase="Inference")

        # TODO: Save metrics to TensorBoard or another tool

        # Save the policy and the agent
        checkpointer.save(global_step=global_step)
        saver.save(policy_path)

    tf.print(
        f"Training finished at episode {episode}, step {global_step}")
Example #3
def main():

    env = suite_gym.load('Trajectory-v0', gym_kwargs={
        'num_dimensions': 2,
        'num_observables': 3,
        'max_targets': 100,
        'max_steps': 5000,
        'max_steps_without_target': 5000,
        'max_position': 100.0,
        'max_acceleration': 10.2,
        'max_velocity': 15.0,
        'collision_epsilon': 10.0
    })
    tf_env = tf_py_environment.TFPyEnvironment(env)

    agent = RandomAgent(tf_env.time_step_spec(), tf_env.action_spec())
    uniform_replay_buffer = TFUniformReplayBuffer(agent.collect_data_spec, batch_size=1)

    transitions = []

    driver = DynamicStepDriver(
        tf_env,
        policy=agent.policy,
        observers=[uniform_replay_buffer.add_batch],
        transition_observers=[transitions.append],
        num_steps=500
    )

    initial_time_step = tf_env.reset()
    final_time_step, final_policy_state = driver.run(initial_time_step)
    dataset = uniform_replay_buffer.as_dataset()

    input_state = []
    input_action = []
    output_state = []
    output_reward = []
    for transition in transitions:
        input_state.append(tf.concat(tf.nest.flatten(transition[0].observation), axis=-1))
        input_action.append(tf.concat(tf.nest.flatten(transition[1].action), axis=-1))
        output_state.append(tf.concat(tf.nest.flatten(transition[2].observation), axis=-1))
        output_reward.append(tf.concat(tf.nest.flatten(transition[2].reward), axis=-1))

    tf_input_state = tf.squeeze(tf.stack(input_state), axis=1)
    tf_input_action = tf.squeeze(tf.stack(input_action), axis=1)
    tf_output_state = tf.squeeze(tf.stack(output_state), axis=1)
    tf_output_reward = tf.stack(output_reward)
     
    # dataset = (features, labels)

    # (time_step_before, policy_step_action, time_step_after) = transitions[0]
    # observation = time_step_before.observation
    # action = policy_step_action.action
    # # (discount_, observation_, reward_, step_type_) = time_step_after
    # observation_ = time_step_after.observation

    pass
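
    # Sketch (assumption, not part of the original): one way to act on the
    # "# dataset = (features, labels)" note above is to pair (state, action)
    # features with (next_state, reward) labels for supervised model learning.
    supervised_dataset = tf.data.Dataset.from_tensor_slices(
        ((tf_input_state, tf_input_action), (tf_output_state, tf_output_reward)))
    supervised_dataset = supervised_dataset.shuffle(500).batch(32)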
Example #4
def main():

    env = suite_gym.load('Trajectory-v0',
                         gym_kwargs={
                             'num_dimensions': 2,
                             'num_observables': 15,
                             'max_targets': 100,
                             'max_steps': 5000,
                             'max_steps_without_target': 5000,
                             'max_position': 100.0,
                             'max_acceleration': 10.2,
                             'max_velocity': 15.0,
                             'collision_epsilon': 10.0
                         })
    tf_env = tf_py_environment.TFPyEnvironment(env)

    agent = RandomAgent(time_step_spec=tf_env.time_step_spec(),
                        action_spec=tf_env.action_spec())

    metric = AverageReturnMetric()
    replay_buffer = []
    # uniform_replay_buffer = PyUniformReplayBuffer(data_spec=agent.collect_data_spec, capacity=2000)
    uniform_replay_buffer = TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec, batch_size=1)
    # observers = [replay_buffer.append, metric]

    # driver = PyDriver(
    #     env,
    #     policy=RandomPyPolicy(env.time_step_spec(), env.action_spec()),
    #     observers=[replay_buffer.append, metric],
    #     max_steps=2000
    # )

    # driver = TFDriver(
    #     tf_env,
    #     # policy=RandomTFPolicy(tf_env.time_step_spec(), tf_env.action_spec()),
    #     policy=agent.policy,
    #     observers=[uniform_replay_buffer],
    #     max_steps=2000
    # )

    driver = DynamicStepDriver(
        tf_env,
        policy=agent.policy,
        observers=[uniform_replay_buffer.add_batch],  #, metric],
        # transition_observers=None,
        num_steps=1000)

    agent.initialize()
    initial_time_step = tf_env.reset()
    final_time_step, final_policy_state = driver.run(initial_time_step)

    dataset = uniform_replay_buffer.as_dataset()
Example #5
class SyncUniformExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training.

    Example algorithms: DDPG, SAC
    """
    def __init__(self, experience_spec, batch_size):
        self._buffer = TFUniformReplayBuffer(experience_spec, batch_size)
        self._data_iter = None

    def observe(self, exp, env_ids=None):
        """
        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs`==1 and `unroll_length`==1. This function always ignores
        `env_ids`.
        """
        self._buffer.add_batch(exp)

    def replay(self, sample_batch_size, mini_batch_length):
        if self._data_iter is None:
            dataset = self._buffer.as_dataset(
                num_parallel_calls=3,
                sample_batch_size=sample_batch_size,
                num_steps=mini_batch_length).prefetch(3)
            self._data_iter = iter(dataset)
        return next(self._data_iter)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    @property
    def batch_size(self):
        return self._buffer._batch_size
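
A minimal usage sketch, not part of the original: it assumes the surrounding project defines `ExperienceReplayer`, and uses a flat experience spec plus dummy tensors purely for illustration; in practice `observe()` receives experience from a synchronous driver.

import tensorflow as tf

experience_spec = tf.TensorSpec(shape=(4,), dtype=tf.float32)   # illustrative spec
replayer = SyncUniformExperienceReplayer(experience_spec, batch_size=8)

for _ in range(10):
    replayer.observe(tf.zeros((8, 4), tf.float32))   # one (env_batch_size, ...) slice per call

experience, info = replayer.replay(sample_batch_size=4, mini_batch_length=2)
# `experience` has shape (4, 2, 4): batch major (B, T, ...) as sampled from the buffer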
Example #6
def collect_step(environment, policy, buffer):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = from_transition(time_step, action_step, next_time_step)

    # Add trajectory to the replay buffer
    buffer.add_batch(traj)


# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(env, agent.policy, 5)
returns = [avg_return]

collect_steps_per_iteration = 1
batch_size = 64
dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                   sample_batch_size=batch_size,
                                   num_steps=2).prefetch(3)
iterator = iter(dataset)
num_iterations = 10000
env.reset()

for _ in range(batch_size):
    collect_step(env, agent.policy, replay_buffer)

for _ in range(num_iterations):
    # Collect a few steps using collect_policy and save to the replay buffer.
    for _ in range(collect_steps_per_iteration):
        collect_step(env, agent.collect_policy, replay_buffer)

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
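
    # The snippet appears to be truncated here. The likely continuation, following the
    # TF-Agents DQN tutorial this code is based on (an assumption, not the original):
    train_loss = agent.train(experience).loss
    step = agent.train_step_counter.numpy()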
Example #7
class DQNAgent:
    def __init__(self) -> None:
        """
        A class for training a TF-agent
        based on https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial
        """

        self.train_env = None  # Training environment
        self.agent = None  # The algorithm used to solve an RL problem is represented by a TF-Agent
        self.replay_buffer = None  # The replay buffer keeps track of data collected from the environment
        self.dataset = None  # The agent needs access to the replay buffer via an iterable tf.data.Dataset
        self.iterator = None  # The iterator of self.dataset

    def compile(self, X_train: np.ndarray, y_train: np.ndarray, lr: float, epsilon: float, gamma: float, imb_ratio: float,
                replay_buffer_max_length: int, layers: dict) -> None:
        """
        Create the Q-network, agent and policy

        Args:
            X_train: A np.ndarray for training samples.
            y_train: A np.ndarray for the class labels of the training samples.
            lr: learning rate for the optimizer (default: Adam)
            epsilon: Used for the default epsilon-greedy policy for choosing a random action.
            gamma: The discount factor for learning Q-values
            imb_ratio: Ratio of imbalance. Used to specify the reward in the environment
            replay_buffer_max_length: Maximum length of the replay memory.
            layers: A dict containing the layers of the Q-Network (e.g., conv, dense, rnn, dropout).
        """

        dense_layers = layers.get("dense")
        conv_layers = layers.get("conv")
        dropout_layers = layers.get("dropout")

        self.train_env = TFPyEnvironment(ClassifyEnv(X_train, y_train, imb_ratio))  # create a custom environment

        q_net = QNetwork(self.train_env.observation_spec(), self.train_env.action_spec(), conv_layer_params=conv_layers,
                         fc_layer_params=dense_layers, dropout_layer_params=dropout_layers)

        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr)

        train_step_counter = tf.Variable(0)

        self.agent = DqnAgent(
            self.train_env.time_step_spec(),
            self.train_env.action_spec(),
            q_network=q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=train_step_counter,
            gamma=gamma,
            epsilon_greedy=epsilon,
        )

        self.agent.initialize()

        self.replay_buffer = TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.train_env.batch_size,
            max_length=replay_buffer_max_length)

    def fit(self, X_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size: int, eval_step: int, log_step: int,
            collect_steps_per_episode: int) -> None:
        """
        Starts the training of the Agent.

        Args:
            X_train: A np.ndarray for training samples.
            y_train: A np.ndarray for the class labels of the training samples.
            epochs: Number of epochs to train the agent
            batch_size: The batch size used when sampling from the replay buffer
            eval_step: Evaluate the model every `eval_step` steps
            log_step: Log training results every `log_step` steps
            collect_steps_per_episode: Number of steps to collect with `collect_policy` and save to the replay buffer each epoch
        """

        self.dataset = self.replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=batch_size,
            num_steps=2).prefetch(3)

        self.iterator = iter(self.dataset)

        def collect_step(environment, policy, buffer):
            time_step = environment.current_time_step()
            action_step = policy.action(time_step)
            next_time_step = environment.step(action_step.action)
            traj = trajectory.from_transition(time_step, action_step, next_time_step)

            # Add trajectory to the replay buffer
            buffer.add_batch(traj)

        def collect_data(env, policy, buffer, steps):
            for _ in range(steps):
                collect_step(env, policy, buffer)

        # (Optional) Optimize by wrapping some of the code in a graph using TF function.
        self.agent.train = common.function(self.agent.train)

        # Reset the train step
        self.agent.train_step_counter.assign(0)

        for _ in range(epochs):
            #print("epoch: ", _)
            # Collect a few steps using collect_policy and save to the replay buffer.
            collect_data(self.train_env, self.agent.collect_policy, self.replay_buffer, collect_steps_per_episode)

            # Sample a batch of data from the buffer and update the agent's network.
            experience, _ = next(self.iterator)
            train_loss = self.agent.train(experience).loss

            step = self.agent.train_step_counter.numpy()

            if step % log_step == 0:
                print('step = {0}: loss = {1}'.format(step, train_loss))

            if step % eval_step == 0:
                metrics = self.compute_metrics(X_train, y_train)
                print(metrics)

    def compute_metrics(self, X: np.ndarray, y_true: list) -> dict:
        """Compute Metrics for Evaluation"""
        # TODO: apply softmax layer for q logits?

        q, _ = self.agent._target_q_network(X, training=False)

        # y_scores = np.max(q.numpy(), axis=1)  # predicted scores (Q-Values)
        y_pred = np.argmax(q.numpy(), axis=1)  # predicted class label

        metrics = custom_metrics(y_true, y_pred)

        return metrics

    def evaluate(self, X: np.ndarray, y: list, X_train=None, y_train=None) -> dict:
        """
         Evaluation of trained Q-network
        """
        metrics = self.compute_metrics(X, y)

        print("evaluation: ", metrics)
        return metrics
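
A hedged usage sketch, not part of the original: the toy data, hyperparameters and layer dict below are illustrative assumptions; `ClassifyEnv` and `custom_metrics` come from the surrounding project.

import numpy as np

X_train = np.random.rand(1000, 10).astype(np.float32)        # toy features
y_train = (np.random.rand(1000) < 0.1).astype(np.int32)      # toy imbalanced labels

model = DQNAgent()
model.compile(X_train, y_train, lr=1e-3, epsilon=0.1, gamma=0.99, imb_ratio=0.1,
              replay_buffer_max_length=10000,
              layers={"dense": (64, 32), "conv": None, "dropout": None})
model.fit(X_train, y_train, epochs=2000, batch_size=64,
          eval_step=500, log_step=100, collect_steps_per_episode=1)
metrics = model.evaluate(X_train, y_train)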
Example #8
class SyncUniformExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training.

    Example algorithms: DDPG, SAC
    """

    def __init__(self, experience_spec, batch_size):
        # TFUniformReplayBuffer does not support list in spec, we have to do
        # some conversion.
        self._experience_spec = experience_spec
        self._exp_has_list = nest_utils.nest_contains_list(experience_spec)
        tuple_experience_spec = nest_utils.nest_list_to_tuple(experience_spec)
        self._buffer = TFUniformReplayBuffer(tuple_experience_spec, batch_size)
        self._data_iter = None

    def _list_to_tuple(self, exp):
        if self._exp_has_list:
            return nest_utils.nest_list_to_tuple(exp)
        else:
            return exp

    def _tuple_to_list(self, exp):
        if self._exp_has_list:
            return nest_utils.nest_tuple_to_list(exp, self._experience_spec)
        else:
            return exp

    def observe(self, exp, env_ids=None):
        """
        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs`==1 and `unroll_length`==1. This function always ignores
        `env_ids`.
        """
        self._buffer.add_batch(self._list_to_tuple(exp))

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence
        Returns:
            Experience: experience batch in batch major (B, T, ...)
            tf_uniform_replay_buffer.BufferInfo: information about the batch
        """
        if self._data_iter is None:
            dataset = self._buffer.as_dataset(
                num_parallel_calls=3,
                sample_batch_size=sample_batch_size,
                num_steps=mini_batch_length).prefetch(3)
            self._data_iter = iter(dataset)
        exp, info = next(self._data_iter)
        return self._tuple_to_list(exp), info

    def replay_all(self):
        return self._tuple_to_list(self._buffer.gather_all())

    def clear(self):
        self._buffer.clear()

    @property
    def batch_size(self):
        return self._buffer._batch_size
Example #9
    def train_implementation(self, train_context: core.TrainContext):
        """Tf-Agents Ppo Implementation of the train loop.

        The implementation follows
        https://colab.research.google.com/github/tensorflow/agents/blob/master/tf_agents/colabs/1_dqn_tutorial.ipynb
        """
        assert isinstance(train_context, core.StepsTrainContext)
        dc: core.StepsTrainContext = train_context

        train_env = self._create_env(discount=dc.reward_discount_gamma)
        observation_spec = train_env.observation_spec()
        action_spec = train_env.action_spec()
        timestep_spec = train_env.time_step_spec()

        # Set up optimizer, networks and DqnAgent
        self.log_api('AdamOptimizer', '()')
        optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=dc.learning_rate)
        self.log_api('QNetwork', '()')
        q_net = q_network.QNetwork(observation_spec,
                                   action_spec,
                                   fc_layer_params=self.model_config.fc_layers)
        self.log_api('DqnAgent', '()')
        tf_agent = dqn_agent.DqnAgent(
            timestep_spec,
            action_spec,
            q_network=q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss)

        self.log_api('tf_agent.initialize', '()')
        tf_agent.initialize()
        self._trained_policy = tf_agent.policy

        # Set up data collection & buffering
        self.log_api('TFUniformReplayBuffer', '()')
        replay_buffer = TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=train_env.batch_size,
            max_length=dc.max_steps_in_buffer)
        self.log_api('RandomTFPolicy', '()')
        random_policy = random_tf_policy.RandomTFPolicy(
            timestep_spec, action_spec)
        self.log_api('replay_buffer.add_batch', '(trajectory)')
        for _ in range(dc.num_steps_buffer_preload):
            self.collect_step(env=train_env,
                              policy=random_policy,
                              replay_buffer=replay_buffer)

        # Train
        tf_agent.train = common.function(tf_agent.train, autograph=False)
        self.log_api(
            'replay_buffer.as_dataset', f'(num_parallel_calls=3, ' +
            f'sample_batch_size={dc.num_steps_sampled_from_buffer}, num_steps=2).prefetch(3)'
        )
        dataset = replay_buffer.as_dataset(
            num_parallel_calls=3,
            sample_batch_size=dc.num_steps_sampled_from_buffer,
            num_steps=2).prefetch(3)
        iter_dataset = iter(dataset)
        self.log_api('for each iteration')
        self.log_api('  replay_buffer.add_batch', '(trajectory)')
        self.log_api('  tf_agent.train', '(experience=trajectory)')
        while True:
            self.on_train_iteration_begin()
            for _ in range(dc.num_steps_per_iteration):
                self.collect_step(env=train_env,
                                  policy=tf_agent.collect_policy,
                                  replay_buffer=replay_buffer)
            trajectories, _ = next(iter_dataset)
            tf_loss_info = tf_agent.train(experience=trajectories)
            self.on_train_iteration_end(tf_loss_info.loss)
            if train_context.training_done:
                break
        return
Example #10
class TrainDDQN():
    """Wrapper for DDQN training, validation, saving etc."""
    def __init__(self,
                 episodes: int,
                 warmup_steps: int,
                 learning_rate: float,
                 gamma: float,
                 min_epsilon: float,
                 decay_episodes: int,
                 model_path: str = None,
                 log_dir: str = None,
                 batch_size: int = 64,
                 memory_length: int = None,
                 collect_steps_per_episode: int = 1,
                 val_every: int = None,
                 target_update_period: int = 1,
                 target_update_tau: float = 1.0,
                 progressbar: bool = True,
                 n_step_update: int = 1,
                 gradient_clipping: float = 1.0,
                 collect_every: int = 1) -> None:
        """
        Wrapper to make training easier.
        Code is partly based on https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial

        :param episodes: Number of training episodes
        :type  episodes: int
        :param warmup_steps: Number of steps to fill the Replay Buffer with random state-action pairs before training starts
        :type  warmup_steps: int
        :param learning_rate: Learning Rate for the Adam Optimizer
        :type  learning_rate: float
        :param gamma: Discount factor for the Q-values
        :type  gamma: float
        :param min_epsilon: Lowest and final value for epsilon
        :type  min_epsilon: float
        :param decay_episodes: Number of episodes over which epsilon decays from 1.0 to `min_epsilon`
        :type  decay_episodes: int
        :param model_path: Location to save the trained model
        :type  model_path: str
        :param log_dir: Location to save the logs, useful for TensorBoard
        :type  log_dir: str
        :param batch_size: Number of samples in minibatch to train on each step
        :type  batch_size: int
        :param memory_length: Maximum size of the Replay Buffer
        :type  memory_length: int
        :param collect_steps_per_episode: Number of steps of data to collect for the Replay Buffer each episode
        :type  collect_steps_per_episode: int
        :param collect_every: Step interval to collect data during training
        :type  collect_every: int
        :param val_every: Validate the model every X episodes using the `collect_metrics()` function
        :type  val_every: int
        :param target_update_period: Update the target Q-network every X episodes
        :type  target_update_period: int
        :param target_update_tau: Soft-update factor used when updating the target Q-network
        :type  target_update_tau: float
        :param progressbar: Enable or disable the progressbar for collecting data and training
        :type  progressbar: bool

        :return: None
        :rtype: NoneType
        """
        self.episodes = episodes  # Total episodes
        self.warmup_steps = warmup_steps  # Number of warmup steps before training
        self.batch_size = batch_size  # Batch size of Replay Memory
        self.collect_steps_per_episode = collect_steps_per_episode  # Number of steps to collect data each episode
        self.collect_every = collect_every  # Step interval to collect data during training
        self.learning_rate = learning_rate  # Learning Rate
        self.gamma = gamma  # Discount factor
        self.min_epsilon = min_epsilon  # Minimum probability of choosing a random action
        self.decay_episodes = decay_episodes  # Number of episodes to decay epsilon from 1.0 to `min_epsilon`
        self.target_update_period = target_update_period  # Period for target network updates
        self.target_update_tau = target_update_tau
        self.progressbar = progressbar  # Enable or disable the progressbar for collecting data and training
        self.n_step_update = n_step_update
        self.gradient_clipping = gradient_clipping  # Clip gradients to this value
        self.compiled = False
        NOW = datetime.now().strftime("%Y%m%d_%H%M%S")

        if memory_length is not None:
            self.memory_length = memory_length  # Max Replay Memory length
        else:
            self.memory_length = warmup_steps

        if val_every is not None:
            self.val_every = val_every  # Validate the policy every `val_every` episodes
        else:
            self.val_every = self.episodes // min(
                50, self.episodes
            )  # Can't validate the model 50 times if self.episodes < 50

        if model_path is not None:
            self.model_path = model_path
        else:
            self.model_path = "./models/" + NOW + ".pkl"

        if log_dir is None:
            log_dir = "./logs/" + NOW
        self.writer = tf.summary.create_file_writer(log_dir)

    def compile_model(self,
                      X_train,
                      y_train,
                      layers: list = [],
                      imb_ratio: float = None,
                      loss_fn=common.element_wise_squared_loss) -> None:
        """Initializes the neural networks, DDQN-agent, collect policies and replay buffer.

        :param X_train: Training data for the model.
        :type  X_train: np.ndarray
        :param y_train: Labels corresponding to `X_train`.  1 for the positive class, 0 for the negative class.
        :type  y_train: np.ndarray
        :param layers: List of layers to feed into the TF-agents custom Sequential(!) layer.
        :type  layers: list
        :param imb_ratio: The imbalance ratio of the data.
        :type  imb_ratio: float
        :param loss_fn: Callable loss function
        :type  loss_fn: tf.compat.v1.losses

        :return: None
        :rtype: NoneType
        """
        if imb_ratio is None:
            imb_ratio = imbalance_ratio(y_train)

        self.train_env = TFPyEnvironment(
            ClassifierEnv(X_train, y_train, imb_ratio))
        self.global_episode = tf.Variable(
            0, name="global_episode", dtype=np.int64,
            trainable=False)  # Global train episode counter

        # Custom epsilon decay: https://github.com/tensorflow/agents/issues/339
        epsilon_decay = tf.compat.v1.train.polynomial_decay(
            1.0,
            self.global_episode,
            self.decay_episodes,
            end_learning_rate=self.min_epsilon)

        self.q_net = Sequential(layers, self.train_env.observation_spec())

        self.agent = DdqnAgent(
            self.train_env.time_step_spec(),
            self.train_env.action_spec(),
            q_network=self.q_net,
            optimizer=Adam(learning_rate=self.learning_rate),
            td_errors_loss_fn=loss_fn,
            train_step_counter=self.global_episode,
            target_update_period=self.target_update_period,
            target_update_tau=self.target_update_tau,
            gamma=self.gamma,
            epsilon_greedy=epsilon_decay,
            n_step_update=self.n_step_update,
            gradient_clipping=self.gradient_clipping)
        self.agent.initialize()

        self.random_policy = RandomTFPolicy(self.train_env.time_step_spec(),
                                            self.train_env.action_spec())
        self.replay_buffer = TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.train_env.batch_size,
            max_length=self.memory_length)

        self.warmup_driver = DynamicStepDriver(
            self.train_env,
            self.random_policy,
            observers=[self.replay_buffer.add_batch],
            num_steps=self.warmup_steps)  # Uses a random policy

        self.collect_driver = DynamicStepDriver(
            self.train_env,
            self.agent.collect_policy,
            observers=[self.replay_buffer.add_batch],
            num_steps=self.collect_steps_per_episode
        )  # Uses the epsilon-greedy policy of the agent

        self.agent.train = common.function(self.agent.train)  # Optimization: wrap in a tf.function
        self.warmup_driver.run = common.function(self.warmup_driver.run)
        self.collect_driver.run = common.function(self.collect_driver.run)

        self.compiled = True

    def train(self, *args) -> None:
        """Starts the training of the model. Includes warmup period, metrics collection and model saving.

        :param *args: All arguments will be passed to `collect_metrics()`.
            This can be useful to pass callables, testing environments or validation data.
            Overwrite the TrainDDQN.collect_metrics() function to use your own *args.
        :type  *args: Any

        :return: None
        :rtype: NoneType; the trained model is saved as a side effect
        """
        assert self.compiled, "Model must be compiled with model.compile_model(X_train, y_train, layers) before training."

        # Warmup period, fill memory with random actions
        if self.progressbar:
            print(
                f"\033[92mCollecting data for {self.warmup_steps:_} steps... This might take a few minutes...\033[0m"
            )

        self.warmup_driver.run(
            time_step=None,
            policy_state=self.random_policy.get_initial_state(
                self.train_env.batch_size))

        if self.progressbar:
            print(
                f"\033[92m{self.replay_buffer.num_frames():_} frames collected!\033[0m"
            )

        dataset = self.replay_buffer.as_dataset(
            sample_batch_size=self.batch_size,
            num_steps=self.n_step_update + 1,
            num_parallel_calls=data.experimental.AUTOTUNE).prefetch(
                data.experimental.AUTOTUNE)
        iterator = iter(dataset)

        def _train():
            experiences, _ = next(iterator)
            return self.agent.train(experiences).loss

        _train = common.function(_train)  # Optimization: wrap in a tf.function

        ts = None
        policy_state = self.agent.collect_policy.get_initial_state(
            self.train_env.batch_size)
        self.collect_metrics(*args)  # Initial collection for step 0
        pbar = tqdm(total=self.episodes,
                    disable=(not self.progressbar),
                    desc="Training the DDQN")  # TQDM progressbar
        for _ in range(self.episodes):
            if not self.global_episode % self.collect_every:
                # Collect a few steps using collect_policy and save to `replay_buffer`
                if self.collect_steps_per_episode != 0:
                    ts, policy_state = self.collect_driver.run(
                        time_step=ts, policy_state=policy_state)
                pbar.update(
                    self.collect_every
                )  # More stable TQDM updates, collecting could take some time

            # Sample a batch of data from `replay_buffer` and update the agent's network
            train_loss = _train()

            if not self.global_episode % self.val_every:
                with self.writer.as_default():
                    tf.summary.scalar("train_loss",
                                      train_loss,
                                      step=self.global_episode)

                self.collect_metrics(*args)
        pbar.close()

    def collect_metrics(self,
                        X_val: np.ndarray,
                        y_val: np.ndarray,
                        save_best: str = None):
        """Collects metrics using the trained Q-network.

        :param X_val: Features of validation data, same shape as X_train
        :type  X_val: np.ndarray
        :param y_val: Labels of validation data, same shape as y_train
        :type  y_val: np.ndarray
        :param save_best: Saving the best model of all validation runs based on given metric:
            Choose one of: {Gmean, F1, Precision, Recall, TP, TN, FP, FN}
            This improves stability since the model at the last episode is not guaranteed to be the best model.
        :type  save_best: str
        """
        y_pred = network_predictions(self.agent._target_q_network, X_val)
        stats = classification_metrics(y_val, y_pred)
        avgQ = np.mean(decision_function(self.agent._target_q_network,
                                         X_val))  # Max action for each x in X

        if save_best is not None:
            if not hasattr(self, "best_score"):  # If no best model yet
                self.best_score = 0.0

            if stats.get(save_best) >= self.best_score:  # Overwrite best model
                self.save_network(
                )  # Saving directly to avoid shallow copy without trained weights
                self.best_score = stats.get(save_best)

        with self.writer.as_default():
            tf.summary.scalar(
                "AverageQ", avgQ,
                step=self.global_episode)  # Average Q-value for this epoch
            for k, v in stats.items():
                tf.summary.scalar(k, v, step=self.global_episode)

    def evaluate(self, X_test, y_test, X_train=None, y_train=None):
        """
        Final evaluation of trained Q-network with X_test and y_test.
        Optional PR and ROC curve comparison to X_train, y_train to ensure no overfitting is taking place.

        :param X_test: Features of test data, same shape as X_train
        :type  X_test: np.ndarray
        :param y_test: Labels of test data, same shape as y_train
        :type  y_test: np.ndarray
        :param X_train: Features of train data
        :type  X_train: np.ndarray
        :param y_train: Labels of train data
        :type  y_train: np.ndarray
        """
        if hasattr(self, "best_score"):
            print(f"\033[92mBest score: {self.best_score:6f}!\033[0m")
            network = self.load_network(
                self.model_path)  # Load best saved model
        else:
            network = self.agent._target_q_network  # Load latest target model

        if (X_train is not None) and (y_train is not None):
            plot_pr_curve(network, X_test, y_test, X_train, y_train)
            plot_roc_curve(network, X_test, y_test, X_train, y_train)

        y_pred = network_predictions(network, X_test)
        return classification_metrics(y_test, y_pred)

    def save_network(self):
        """Saves Q-network as pickle to `model_path`."""
        with open(self.model_path, "wb") as f:  # Save Q-network as pickle
            pickle.dump(self.agent._target_q_network, f)

    @staticmethod
    def load_network(fp: str):
        """Static method to load Q-network pickle from given filepath.

        :param fp: Filepath to the saved pickle of the network
        :type  fp: str

        :returns: The network-object loaded from a pickle file.
        :rtype: tensorflow.keras.models.Model
        """
        with open(fp, "rb") as f:  # Load the Q-network
            network = pickle.load(f)
        return network
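
A hedged usage sketch, not part of the original: the toy data, layer list and hyperparameters below are illustrative assumptions; `ClassifierEnv`, `classification_metrics` and the other helpers come from the surrounding project.

import numpy as np
from tensorflow.keras.layers import Dense

X_train = np.random.rand(2000, 10).astype(np.float32)         # toy features
y_train = (np.random.rand(2000) < 0.05).astype(np.int32)      # toy imbalanced labels
X_val = np.random.rand(500, 10).astype(np.float32)
y_val = (np.random.rand(500) < 0.05).astype(np.int32)

trainer = TrainDDQN(episodes=5_000, warmup_steps=10_000, learning_rate=1e-3,
                    gamma=0.99, min_epsilon=0.05, decay_episodes=1_000,
                    batch_size=64, val_every=500)
trainer.compile_model(X_train, y_train,
                      layers=[Dense(64, activation="relu"),
                              Dense(32, activation="relu"),
                              Dense(2, activation=None)])   # 2 actions: predict class 0 or 1
trainer.train(X_val, y_val, "F1")   # *args are forwarded to collect_metrics()
metrics = trainer.evaluate(X_val, y_val, X_train, y_train)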