def __init__(self):
    """
    Initialize ``self`` with the same configuration as ``Trainer.__init__``.

    Delegates to ``Trainer.__init__`` instead of carrying a second, hand-copied
    version of the whole setup (hyper-parameters, schedules, replay buffer,
    models, workers, initial observations, loss and optimizer), so the two
    cannot drift apart.

    NOTE(review): this module-level function looks like a stray copy of the
    ``Trainer.__init__`` method; consider removing it if nothing calls it.
    """
    # `Trainer` is looked up when this runs, not when it is defined, so the
    # class being declared further down the module is fine.
    Trainer.__init__(self)
class Trainer:
    r"""
    ## Trainer

    Collects experience from several parallel `Worker`s using an
    $\epsilon$-greedy policy, stores it in a prioritized replay buffer and
    trains a Q-network with Double Q-learning against a target network that
    is periodically synchronized with it.
    """

    def __init__(self):
        # #### Configurations

        # number of workers
        self.n_workers = 8
        # steps sampled on each update
        self.worker_steps = 4
        # number of training iterations
        self.train_epochs = 8

        # number of updates
        self.updates = 1_000_000
        # size of mini batch for training
        self.mini_batch_size = 32

        # exploration as a function of updates
        self.exploration_coefficient = Piecewise(
            [
                (0, 1.0),
                (25_000, 0.1),
                (self.updates / 2, 0.01)
            ], outside_value=0.01)

        # update target network every 250 updates
        self.update_target_model = 250

        # $\beta$ for replay buffer as a function of updates
        self.prioritized_replay_beta = Piecewise(
            [
                (0, 0.4),
                (self.updates, 1)
            ], outside_value=1)

        # Replay buffer with $\alpha = 0.6$. Capacity of the replay buffer must be a power of 2.
        self.replay_buffer = ReplayBuffer(2 ** 14, 0.6)

        # Model for sampling and training
        self.model = Model().to(device)
        # target model to get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
        self.target_model = Model().to(device)

        # create workers (each gets a distinct seed: 47, 48, ...)
        self.workers = [Worker(47 + i) for i in range(self.n_workers)]

        # initialize tensors for observations
        self.obs = np.zeros((self.n_workers, 4, 84, 84), dtype=np.uint8)
        # Ask every worker to reset first, then collect the initial observations
        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            self.obs[i] = worker.child.recv()

        # loss function
        self.loss_func = QFuncLoss(0.99)
        # optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=2.5e-4)

    def _sample_action(self, q_value: torch.Tensor, exploration_coefficient: float):
        r"""
        #### $\epsilon$-greedy Sampling

        When sampling actions we use an $\epsilon$-greedy strategy: we
        take the greedy action with probability $1 - \epsilon$ and
        a uniformly random action with probability $\epsilon$.
        We refer to $\epsilon$ as `exploration_coefficient`.

        `q_value` holds Q-values with actions along the last dimension; one
        action is chosen per leading index. Returns the chosen action indices
        as a NumPy array on the CPU.
        """

        # Sampling doesn't need gradients
        with torch.no_grad():
            # Sample the action with highest Q-value. This is the greedy action.
            greedy_action = torch.argmax(q_value, dim=-1)
            # Uniformly sample an action
            random_action = torch.randint(q_value.shape[-1], greedy_action.shape, device=q_value.device)
            # Whether to choose the greedy action or the random action
            is_choose_rand = torch.rand(greedy_action.shape, device=q_value.device) < exploration_coefficient
            # Pick the action based on `is_choose_rand`
            return torch.where(is_choose_rand, random_action, greedy_action).cpu().numpy()

    def sample(self, exploration_coefficient: float):
        """
        ### Sample data

        Runs `worker_steps` steps on every worker with the current model and the
        given exploration coefficient. Each transition is added to the replay
        buffer, and `self.obs` is advanced to the next observations.
        """

        # This doesn't need gradients
        with torch.no_grad():
            # Sample `worker_steps`
            for _ in range(self.worker_steps):
                # Get Q-values for the current observation
                q_value = self.model(obs_to_torch(self.obs))
                # Sample actions
                actions = self._sample_action(q_value, exploration_coefficient)

                # Send every action before waiting on any reply, presumably so the
                # workers can step their environments in parallel.
                for w, worker in enumerate(self.workers):
                    worker.child.send(("step", actions[w]))

                # Collect information from each worker
                for w, worker in enumerate(self.workers):
                    # Get results after executing the actions
                    next_obs, reward, done, info = worker.child.recv()

                    # Add transition to replay buffer.
                    # NOTE(review): `self.obs[w]` is a view that is overwritten in place
                    # below; this assumes `ReplayBuffer.add` copies the data — confirm.
                    self.replay_buffer.add(self.obs[w], actions[w], reward, next_obs, done)

                    # Update episode information.
                    # Episode info is available only if an episode finished;
                    # it includes total reward and length of the episode -
                    # look at `Game` to see how it works.
                    if info:
                        tracker.add('reward', info['reward'])
                        tracker.add('length', info['length'])

                    # update current observation
                    self.obs[w] = next_obs

    def train(self, beta: float):
        r"""
        ### Train the model

        Performs `train_epochs` optimization steps. Each step draws a mini-batch
        of `mini_batch_size` transitions from the replay buffer (passing `beta`
        to its `sample`), computes the Double Q-learning loss, updates the
        priorities of the sampled transitions to $|\delta_i| + \epsilon$, and
        applies a gradient step with gradients clipped to norm 0.5.
        """
        for _ in range(self.train_epochs):
            # Sample from priority replay buffer
            samples = self.replay_buffer.sample(self.mini_batch_size, beta)
            # Get the predicted Q-value
            q_value = self.model(obs_to_torch(samples['obs']))

            # Get the Q-values of the next state for [Double Q-learning](index.html).
            # Gradients shouldn't propagate for these
            with torch.no_grad():
                # Convert the next observations once and feed them to both networks
                next_obs = obs_to_torch(samples['next_obs'])
                # Get $\color{cyan}Q(s';\color{cyan}{\theta_i})$
                double_q_value = self.model(next_obs)
                # Get $\color{orange}Q(s';\color{orange}{\theta_i^{-}})$
                target_q_value = self.target_model(next_obs)

            # Compute Temporal Difference (TD) errors, $\delta$, and the loss, $\mathcal{L}(\theta)$.
            td_errors, loss = self.loss_func(q_value,
                                             q_value.new_tensor(samples['action']),
                                             double_q_value, target_q_value,
                                             q_value.new_tensor(samples['done']),
                                             q_value.new_tensor(samples['reward']),
                                             q_value.new_tensor(samples['weights']))

            # Calculate priorities for replay buffer $p_i = |\delta_i| + \epsilon$.
            # `detach()` guards against `td_errors` still being attached to the
            # autograd graph, in which case `.numpy()` would raise.
            new_priorities = np.abs(td_errors.detach().cpu().numpy()) + 1e-6
            # Update replay buffer priorities
            self.replay_buffer.update_priorities(samples['indexes'], new_priorities)

            # Zero out the previously calculated gradients
            self.optimizer.zero_grad()
            # Calculate gradients
            loss.backward()
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            # Update parameters based on gradients
            self.optimizer.step()

    def run_training_loop(self):
        """
        ### Run training loop

        For each of `updates` iterations, this does the following:
        computes the scheduled exploration coefficient and `beta`,
        samples fresh experience, and, once the replay buffer is full,
        trains the model and copies its weights to the target network
        every `update_target_model` updates.
        """

        # Last 100 episode information
        tracker.set_queue('reward', 100, True)
        tracker.set_queue('length', 100, True)

        # Copy to target network initially
        self.target_model.load_state_dict(self.model.state_dict())

        for update in monit.loop(self.updates):
            # $\epsilon$, exploration fraction
            exploration = self.exploration_coefficient(update)
            tracker.add('exploration', exploration)
            # $\beta$ for prioritized replay
            beta = self.prioritized_replay_beta(update)
            tracker.add('beta', beta)

            # Sample with current policy
            self.sample(exploration)

            # Start training after the buffer is full
            if self.replay_buffer.is_full():
                # Train the model
                self.train(beta)

                # Periodically update target network
                if update % self.update_target_model == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            # Save tracked indicators.
            tracker.save()
            # Add a new line to the screen periodically
            if (update + 1) % 1_000 == 0:
                logger.log()