Example #1
class SAC(Algorithm):
    """
    Soft Actor-Critic (SAC) variant with a stochastic policy, two Q-functions, and two target Q-functions (no V-function)

    .. seealso::
        [1] T. Haarnoja, A. Zhou, P. Abbeel, S. Levine, "Soft Actor-Critic: Off-Policy Maximum Entropy Deep
        Reinforcement Learning with a Stochastic Actor", ICML, 2018

        [2] This implementation was inspired by https://github.com/pranz24/pytorch-soft-actor-critic
            which seems to be based on https://github.com/vitchyr/rlkit
    """

    name: str = 'sac'

    def __init__(self,
                 save_dir: str,
                 env: Env,
                 policy: TwoHeadedPolicy,
                 q_fcn_1: Policy,
                 q_fcn_2: Policy,
                 memory_size: int,
                 gamma: float,
                 max_iter: int,
                 num_batch_updates: int,
                 tau: float = 0.995,
                 alpha_init: float = 0.2,
                 learn_alpha: bool = True,
                 target_update_intvl: int = 1,
                 standardize_rew: bool = True,
                 batch_size: int = 500,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 num_sampler_envs: int = 4,
                 max_grad_norm: float = 5.,
                 lr: float = 3e-4,
                 lr_scheduler=None,
                 lr_scheduler_hparam: [dict, None] = None,
                 logger: StepLogger = None):
        """
        Constructor

        :param save_dir: directory to save the snapshots i.e. the results in
        :param env: the environment which the policy operates
        :param policy: policy to be updated
        :param q_fcn_1: state-action value function $Q(s,a)$; the associated target Q-function is created as a
                        copy of this one
        :param q_fcn_2: state-action value function $Q(s,a)$; the associated target Q-function is created as a
                        copy of this one
        :param memory_size: number of transitions in the replay memory buffer, e.g. 1000000
        :param gamma: temporal discount factor for the state values
        :param max_iter: number of iterations (policy updates)
        :param num_batch_updates: number of batch updates per algorithm step
        :param tau: interpolation factor used when averaging the target networks' parameters in the soft update,
                    a.k.a. Polyak update, between 0 and 1
        :param alpha_init: initial weighting factor of the entropy term in the loss function
        :param learn_alpha: adapt the weighting factor of the entropy term
        :param target_update_intvl: number of iterations that pass before updating the target network
        :param standardize_rew: bool to flag if the rewards should be standardized
        :param batch_size: number of samples per policy update batch
        :param min_rollouts: minimum number of rollouts sampled per policy update batch
        :param min_steps: minimum number of state transitions sampled per policy update batch
        :param num_sampler_envs: number of environments for parallel sampling
        :param max_grad_norm: maximum L2 norm of the gradients for clipping, set to `None` to disable gradient clipping
        :param lr: (initial) learning rate for the optimizer which can be modified by the scheduler.
                   By default, the learning rate is constant.
        :param lr_scheduler: learning rate scheduler type for the policy and the Q-functions that does one step
                             per `update()` call
        :param lr_scheduler_hparam: hyper-parameters for the learning rate scheduler
        :param logger: logger for every step of the algorithm, if `None` the default logger will be created
        """
        if not isinstance(env, Env):
            raise pyrado.TypeErr(given=env, expected_type=Env)
        if typed_env(env, ActNormWrapper) is None:
            raise pyrado.TypeErr(
                msg='SAC requires an environment wrapped by an ActNormWrapper!'
            )
        if not isinstance(q_fcn_1, Policy):
            raise pyrado.TypeErr(given=q_fcn_1, expected_type=Policy)
        if not isinstance(q_fcn_2, Policy):
            raise pyrado.TypeErr(given=q_fcn_2, expected_type=Policy)

        if logger is None:
            # Create logger that only logs every 100 steps of the algorithm
            logger = StepLogger(print_interval=100)
            logger.printers.append(ConsolePrinter())
            logger.printers.append(
                CSVPrinter(osp.join(save_dir, 'progress.csv')))

        # Call Algorithm's constructor
        super().__init__(save_dir, max_iter, policy, logger)

        # Store the inputs
        self._env = env
        self.q_fcn_1 = q_fcn_1
        self.q_fcn_2 = q_fcn_2
        self.q_targ_1 = deepcopy(self.q_fcn_1)
        self.q_targ_2 = deepcopy(self.q_fcn_2)
        self.q_targ_1.eval()
        self.q_targ_2.eval()
        self.gamma = gamma
        self.tau = tau
        self.learn_alpha = learn_alpha
        self.target_update_intvl = target_update_intvl
        self.standardize_rew = standardize_rew
        self.num_batch_updates = num_batch_updates
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm

        # Initialize
        self._memory = ReplayMemory(memory_size)
        if policy.is_recurrent:
            init_expl_policy = RecurrentDummyPolicy(env.spec,
                                                    policy.hidden_size)
        else:
            init_expl_policy = DummyPolicy(env.spec)
        self.sampler_init = ParallelSampler(
            env,
            init_expl_policy,  # samples uniformly random from the action space
            num_envs=num_sampler_envs,
            min_steps=memory_size,
        )
        self._expl_strat = SACExplStrat(
            self._policy,
            std_init=1.)  # std_init will be overwritten by 2nd policy head
        self.sampler = ParallelSampler(
            env,
            self._expl_strat,
            num_envs=1,
            min_steps=min_steps,  # in [2] this would be 1
            min_rollouts=min_rollouts  # in [2] this would be None
        )
        self.sampler_eval = ParallelSampler(env,
                                            self._policy,
                                            num_envs=num_sampler_envs,
                                            min_steps=100 * env.max_steps,
                                            min_rollouts=None)
        self._optim_policy = to.optim.Adam([{'params': self._policy.parameters()}], lr=lr)
        self._optim_q_fcn_1 = to.optim.Adam([{'params': self.q_fcn_1.parameters()}], lr=lr)
        self._optim_q_fcn_2 = to.optim.Adam([{'params': self.q_fcn_2.parameters()}], lr=lr)
        log_alpha_init = to.log(
            to.tensor(alpha_init, dtype=to.get_default_dtype()))
        if learn_alpha:
            # Automatic entropy tuning
            self._log_alpha = nn.Parameter(log_alpha_init, requires_grad=True)
            self._alpha_optim = to.optim.Adam([{'params': self._log_alpha}], lr=lr)
            self.target_entropy = -to.prod(to.tensor(env.act_space.shape))
        else:
            self._log_alpha = log_alpha_init

        self._lr_scheduler_policy = lr_scheduler
        self._lr_scheduler_hparam = lr_scheduler_hparam
        if lr_scheduler is not None:
            self._lr_scheduler_policy = lr_scheduler(self._optim_policy,
                                                     **lr_scheduler_hparam)
            self._lr_scheduler_q_fcn_1 = lr_scheduler(self._optim_q_fcn_1,
                                                      **lr_scheduler_hparam)
            self._lr_scheduler_q_fcn_2 = lr_scheduler(self._optim_q_fcn_2,
                                                      **lr_scheduler_hparam)

    @property
    def expl_strat(self) -> SACExplStrat:
        return self._expl_strat

    @property
    def memory(self) -> ReplayMemory:
        """ Get the replay memory. """
        return self._memory

    @property
    def alpha(self) -> to.Tensor:
        """ Get the detached entropy coefficient. """
        return to.exp(self._log_alpha.detach())

    def step(self, snapshot_mode: str, meta_info: dict = None):
        if self._memory.isempty:
            # Warm-up phase
            print_cbt('Collecting samples until the replay memory is full.', 'w')
            # Sample steps and store them in the replay memory
            ros = self.sampler_init.sample()
            self._memory.push(ros)
        else:
            # Sample steps and store them in the replay memory
            ros = self.sampler.sample()
            self._memory.push(ros)

        # Log return-based metrics
        if self._curr_iter % self.logger.print_interval == 0:
            ros = self.sampler_eval.sample()
            rets = [ro.undiscounted_return() for ro in ros]
            ret_max = np.max(rets)
            ret_med = np.median(rets)
            ret_avg = np.mean(rets)
            ret_min = np.min(rets)
            ret_std = np.std(rets)
        else:
            ret_max, ret_med, ret_avg, ret_min, ret_std = 5 * [-pyrado.inf]  # dummy values
        self.logger.add_value('max return', np.round(ret_max, 4))
        self.logger.add_value('median return', np.round(ret_med, 4))
        self.logger.add_value('avg return', np.round(ret_avg, 4))
        self.logger.add_value('min return', np.round(ret_min, 4))
        self.logger.add_value('std return', np.round(ret_std, 4))
        self.logger.add_value('avg rollout length',
                              np.round(np.mean([ro.length for ro in ros]), 2))
        self.logger.add_value('num rollouts', len(ros))
        self.logger.add_value('avg memory reward',
                              np.round(self._memory.avg_reward(), 4))

        # Use data in the memory to update the policy and the Q-functions
        self.update()

        # Save snapshot data
        self.make_snapshot(snapshot_mode, float(ret_avg), meta_info)

    @staticmethod
    def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.995):
        """
        Moving average update, a.k.a. Polyak update.
        Modifies the input argument `target`.

        :param target: PyTorch module with parameters to be updated
        :param source: PyTorch module whose parameters the target is updated towards
        :param tau: interpolation factor for averaging, between 0 and 1
        """
        if not 0 < tau < 1:
            raise pyrado.ValueErr(given=tau,
                                  g_constraint='0',
                                  l_constraint='1')

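        # theta_targ <- tau * theta_targ + (1 - tau) * theta_src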
        for targ_param, src_param in zip(target.parameters(), source.parameters()):
            targ_param.data = targ_param.data * tau + src_param.data * (1. - tau)

    def update(self):
        """ Update the policy's and Q-functions' parameters on transitions sampled from the replay memory. """
        # Containers for logging
        policy_losses = to.zeros(self.num_batch_updates)
        expl_strat_stds = to.zeros(self.num_batch_updates)
        q_fcn_1_losses = to.zeros(self.num_batch_updates)
        q_fcn_2_losses = to.zeros(self.num_batch_updates)
        policy_grad_norm = to.zeros(self.num_batch_updates)
        q_fcn_1_grad_norm = to.zeros(self.num_batch_updates)
        q_fcn_2_grad_norm = to.zeros(self.num_batch_updates)

        for b in tqdm(range(self.num_batch_updates), total=self.num_batch_updates,
                      desc='Updating', unit='batches', file=sys.stdout, leave=False):

            # Sample steps and the associated next step from the replay memory
            steps, next_steps = self._memory.sample(self.batch_size)
            steps.torch(data_type=to.get_default_dtype())
            next_steps.torch(data_type=to.get_default_dtype())

            # Standardize rewards
            if self.standardize_rew:
                rewards = standardize(steps.rewards).unsqueeze(1)
            else:
                rewards = steps.rewards.unsqueeze(1)
            rew_scale = 1.
            rewards *= rew_scale

            with to.no_grad():
                # Create masks for the non-final observations
                not_done = to.tensor(1. - steps.done,
                                     dtype=to.get_default_dtype()).unsqueeze(1)

                # Compute the (next)state-(next)action values Q(s',a') from the target networks
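                # Soft TD target: y = r + gamma * (1 - done) * (min_i Q_targ_i(s', a') - alpha * log pi(a'|s'))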
                if self.policy.is_recurrent:
                    next_act_expl, next_log_probs, _ = self._expl_strat(
                        next_steps.observations, next_steps.hidden_states)
                else:
                    next_act_expl, next_log_probs = self._expl_strat(
                        next_steps.observations)
                next_q_val_target_1 = self.q_targ_1(
                    to.cat([next_steps.observations, next_act_expl], dim=1))
                next_q_val_target_2 = self.q_targ_2(
                    to.cat([next_steps.observations, next_act_expl], dim=1))
                next_q_val_target_min = to.min(
                    next_q_val_target_1,
                    next_q_val_target_2) - self.alpha * next_log_probs
                next_q_val = rewards + not_done * self.gamma * next_q_val_target_min

            # Compute the two Q-function losses
            # E_{(s_t, a_t) ~ D} [1/2 * (Q_i(s_t, a_t) - r_t - gamma * E_{s_{t+1} ~ p} [V(s_{t+1})] )^2]
            q_val_1 = self.q_fcn_1(
                to.cat([steps.observations, steps.actions], dim=1))
            q_val_2 = self.q_fcn_2(
                to.cat([steps.observations, steps.actions], dim=1))
            q_1_loss = nn.functional.mse_loss(q_val_1, next_q_val)
            q_2_loss = nn.functional.mse_loss(q_val_2, next_q_val)
            q_fcn_1_losses[b] = q_1_loss.data
            q_fcn_2_losses[b] = q_2_loss.data

            # Compute the policy loss
            # E_{s_t ~ D, eps_t ~ N} [log( pi( f(eps_t; s_t) ) ) - Q(s_t, f(eps_t; s_t))]
            if self.policy.is_recurrent:
                act_expl, log_probs, _ = self._expl_strat(
                    steps.observations, steps.hidden_states)
            else:
                act_expl, log_probs = self._expl_strat(steps.observations)
            q1_pi = self.q_fcn_1(to.cat([steps.observations, act_expl], dim=1))
            q2_pi = self.q_fcn_2(to.cat([steps.observations, act_expl], dim=1))
            min_q_pi = to.min(q1_pi, q2_pi)
            policy_loss = to.mean(self.alpha * log_probs - min_q_pi)
            policy_losses[b] = policy_loss.data
            expl_strat_stds[b] = to.mean(self._expl_strat.std.data)

            # Do one optimization step for each optimizer, and clip the gradients if desired
            # Q-fcn 1
            self._optim_q_fcn_1.zero_grad()
            q_1_loss.backward()
            q_fcn_1_grad_norm[b] = self.clip_grad(self.q_fcn_1, None)
            self._optim_q_fcn_1.step()
            # Q-fcn 2
            self._optim_q_fcn_2.zero_grad()
            q_2_loss.backward()
            q_fcn_2_grad_norm[b] = self.clip_grad(self.q_fcn_2, None)
            self._optim_q_fcn_2.step()
            # Policy
            self._optim_policy.zero_grad()
            policy_loss.backward()
            policy_grad_norm[b] = self.clip_grad(self._expl_strat.policy,
                                                 self.max_grad_norm)
            self._optim_policy.step()

            if self.learn_alpha:
                # Compute entropy coefficient loss
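                # J(alpha) = E[ -log_alpha * (log pi(a|s) + target_entropy) ]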
                alpha_loss = -to.mean(
                    self._log_alpha *
                    (log_probs.detach() + self.target_entropy))
                # Do one optimizer step for the entropy coefficient optimizer
                self._alpha_optim.zero_grad()
                alpha_loss.backward()
                self._alpha_optim.step()

            # Soft-update the target networks
            if (self._curr_iter * self.num_batch_updates +
                    b) % self.target_update_intvl == 0:
                SAC.soft_update(self.q_targ_1, self.q_fcn_1, self.tau)
                SAC.soft_update(self.q_targ_2, self.q_fcn_2, self.tau)

        # Update the learning rate if the schedulers have been specified
        if self._lr_scheduler_policy is not None:
            self._lr_scheduler_policy.step()
            self._lr_scheduler_q_fcn_1.step()
            self._lr_scheduler_q_fcn_2.step()

        # Logging
        self.logger.add_value('Q1 loss', to.mean(q_fcn_1_losses).item())
        self.logger.add_value('Q2 loss', to.mean(q_fcn_2_losses).item())
        self.logger.add_value('policy loss', to.mean(policy_losses).item())
        self.logger.add_value('avg policy grad norm',
                              to.mean(policy_grad_norm).item())
        self.logger.add_value('avg expl strat std',
                              to.mean(expl_strat_stds).item())
        self.logger.add_value('alpha', self.alpha.item())
        if self._lr_scheduler_policy is not None:
            self.logger.add_value('learning rate',
                                  self._lr_scheduler_policy.get_lr())

    def save_snapshot(self, meta_info: dict = None):
        super().save_snapshot(meta_info)

        if meta_info is None:
            # This instance is not a subroutine of a meta-algorithm
            joblib.dump(self._env, osp.join(self._save_dir, 'env.pkl'))
            to.save(self.q_targ_1, osp.join(self._save_dir, 'target1.pt'))
            to.save(self.q_targ_2, osp.join(self._save_dir, 'target2.pt'))
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                to.save(
                    self.q_targ_1,
                    osp.join(
                        self._save_dir,
                        f"{meta_info['prefix']}_target1_{meta_info['suffix']}.pt"
                    ))
                to.save(
                    self.q_targ_2,
                    osp.join(
                        self._save_dir,
                        f"{meta_info['prefix']}_target2_{meta_info['suffix']}.pt"
                    ))
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                to.save(
                    self.q_targ_1,
                    osp.join(self._save_dir,
                             f"{meta_info['prefix']}_target1.pt"))
                to.save(
                    self.q_targ_2,
                    osp.join(self._save_dir,
                             f"{meta_info['prefix']}_target2.pt"))
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                to.save(
                    self.q_targ_1,
                    osp.join(self._save_dir,
                             f"target1_{meta_info['suffix']}.pt"))
                to.save(
                    self.q_targ_2,
                    osp.join(self._save_dir,
                             f"target2_{meta_info['suffix']}.pt"))
            else:
                raise NotImplementedError

    def load_snapshot(self, load_dir: str = None, meta_info: dict = None):
        # Get the directory to load from
        ld = load_dir if load_dir is not None else self._save_dir
        super().load_snapshot(ld, meta_info)

        if meta_info is None:
            # This algorithm instance is not a subroutine of a meta-algorithm
            self._env = joblib.load(osp.join(ld, 'env.pkl'))
            self.q_targ_1.load_state_dict(
                to.load(osp.join(ld, 'target1.pt')).state_dict())
            self.q_targ_2.load_state_dict(
                to.load(osp.join(ld, 'target2.pt')).state_dict())
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                self.q_targ_1.load_state_dict(
                    to.load(
                        osp.join(
                            ld,
                            f"{meta_info['prefix']}_target1_{meta_info['suffix']}.pt"
                        )).state_dict())
                self.q_targ_2.load_state_dict(
                    to.load(
                        osp.join(
                            ld,
                            f"{meta_info['prefix']}_target2_{meta_info['suffix']}.pt"
                        )).state_dict())
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                self.q_targ_1.load_state_dict(
                    to.load(osp.join(
                        ld, f"{meta_info['prefix']}_target1.pt")).state_dict())
                self.q_targ_2.load_state_dict(
                    to.load(osp.join(
                        ld, f"{meta_info['prefix']}_target2.pt")).state_dict())
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                self.q_targ_1.load_state_dict(
                    to.load(osp.join(
                        ld, f"target1_{meta_info['suffix']}.pt")).state_dict())
                self.q_targ_2.load_state_dict(
                    to.load(osp.join(
                        ld, f"target2_{meta_info['suffix']}.pt")).state_dict())
            else:
                raise NotImplementedError

    def reset(self, seed: int = None):
        # Reset the exploration strategy, internal variables and the random seeds
        super().reset(seed)

        # Re-initialize sampler in case env or policy changed
        self.sampler.reinit()

        # Reset the replay memory
        self._memory.reset()

        # Reset the learning rate schedulers
        if self._lr_scheduler_policy is not None:
            # The Q-function schedulers only exist if a scheduler type was passed to the constructor
            self._lr_scheduler_policy.last_epoch = -1
            self._lr_scheduler_q_fcn_1.last_epoch = -1
            self._lr_scheduler_q_fcn_2.last_epoch = -1
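
A minimal usage sketch for the SAC class above. The factory helpers `make_env()`, `make_two_headed_policy()` and `make_q_fcn()` are hypothetical stand-ins for whatever environment and network constructors the surrounding code base provides, the hyper-parameter values are illustrative, and the final `train()` call is assumed to be inherited from the `Algorithm` base class (repeatedly calling `step()` until `max_iter` is reached).

# Hypothetical sketch -- the helper factories and hyper-parameter values are assumptions, not part of the example above
env = make_env()                              # must be wrapped in an ActNormWrapper (see the check in __init__)
policy = make_two_headed_policy(env.spec)     # TwoHeadedPolicy returning the action mean and std
q_fcn_1 = make_q_fcn(env.spec)                # Q(s,a) network fed with the concatenated [obs, act]
q_fcn_2 = make_q_fcn(env.spec)                # second Q-network for the clipped double-Q minimum

algo = SAC(
    save_dir='path/to/experiment',
    env=env,
    policy=policy,
    q_fcn_1=q_fcn_1,
    q_fcn_2=q_fcn_2,
    memory_size=1000000,
    gamma=0.99,
    max_iter=1000,
    num_batch_updates=1000,
    tau=0.995,
    alpha_init=0.2,
    learn_alpha=True,
)
algo.train(snapshot_mode='best', seed=0)      # assumed to be provided by the Algorithm base class
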
Example #2
class DQL(Algorithm):
    """
    Deep Q-Learning (without bells and whistles)

    .. seealso::
        [1] V. Mnih et al., "Human-level control through deep reinforcement learning", Nature, 2015
    """

    name: str = 'dql'

    def __init__(self,
                 save_dir: str,
                 env: Env,
                 policy: DiscrActQValFNNPolicy,
                 memory_size: int,
                 eps_init: float,
                 eps_schedule_gamma: float,
                 gamma: float,
                 max_iter: int,
                 num_batch_updates: int,
                 target_update_intvl: int = 5,
                 min_rollouts: int = None,
                 min_steps: int = None,
                 batch_size: int = 256,
                 num_sampler_envs: int = 4,
                 max_grad_norm: float = 0.5,
                 lr: float = 5e-4,
                 lr_scheduler=None,
                 lr_scheduler_hparam: [dict, None] = None,
                 logger: StepLogger = None):
        """
        Constructor

        :param save_dir: directory to save the snapshots i.e. the results in
        :param env: environment which the policy operates
        :param policy: (current) Q-network updated by this algorithm
        :param memory_size: number of transitions in the replay memory buffer
        :param eps_init: initial value for the probability of taking a random action, constant if `eps_schedule_gamma==1`
        :param eps_schedule_gamma: temporal discount factor for the exponential decay of epsilon
        :param gamma: temporal discount factor for the state values
        :param max_iter: number of iterations (policy updates)
        :param num_batch_updates: number of batch updates per algorithm step
        :param target_update_intvl: number of iterations that pass before updating the target network
        :param min_rollouts: minimum number of rollouts sampled per policy update batch
        :param min_steps: minimum number of state transitions sampled per policy update batch
        :param batch_size: number of samples per policy update batch
        :param num_sampler_envs: number of environments for parallel sampling
        :param max_grad_norm: maximum L2 norm of the gradients for clipping, set to `None` to disable gradient clipping
        :param lr: (initial) learning rate for the optimizer which can be modified by the scheduler.
                   By default, the learning rate is constant.
        :param lr_scheduler: learning rate scheduler that does one step per epoch (pass through the whole data set)
        :param lr_scheduler_hparam: hyper-parameters for the learning rate scheduler
        :param logger: logger for every step of the algorithm
        """
        if not isinstance(env, Env):
            raise pyrado.TypeErr(given=env, expected_type=Env)
        if not isinstance(policy, DiscrActQValFNNPolicy):
            raise pyrado.TypeErr(given=policy, expected_type=DiscrActQValFNNPolicy)

        if logger is None:
            # Create logger that only logs every 100 steps of the algorithm
            logger = StepLogger(print_interval=100)
            logger.printers.append(ConsolePrinter())
            logger.printers.append(CSVPrinter(osp.join(save_dir, 'progress.csv')))

        # Call Algorithm's constructor
        super().__init__(save_dir, max_iter, policy, logger)

        # Store the inputs
        self._env = env
        self.target = deepcopy(self._policy)
        self.target.eval()  # will not be trained using the optimizer
        self._memory_size = memory_size
        self.eps = eps_init
        self.gamma = gamma
        self.target_update_intvl = target_update_intvl
        self.num_batch_updates = num_batch_updates
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm

        # Initialize
        self._expl_strat = EpsGreedyExplStrat(self._policy, eps_init, eps_schedule_gamma)
        self._memory = ReplayMemory(memory_size)
        self.sampler = ParallelSampler(
            env, self._expl_strat,
            num_envs=1,
            min_steps=min_steps,
            min_rollouts=min_rollouts
        )
        self.sampler_eval = ParallelSampler(
            env, self._policy,
            num_envs=num_sampler_envs,
            min_steps=100*env.max_steps,
            min_rollouts=None
        )
        self.optim = to.optim.RMSprop([{'params': self._policy.parameters()}], lr=lr)
        self._lr_scheduler = lr_scheduler
        self._lr_scheduler_hparam = lr_scheduler_hparam
        if lr_scheduler is not None:
            self._lr_scheduler = lr_scheduler(self.optim, **lr_scheduler_hparam)

    @property
    def expl_strat(self) -> EpsGreedyExplStrat:
        return self._expl_strat

    @property
    def memory(self) -> ReplayMemory:
        """ Get the replay memory. """
        return self._memory

    def step(self, snapshot_mode: str, meta_info: dict = None):
        # Sample steps and store them in the replay memory
        ros = self.sampler.sample()
        self._memory.push(ros)

        while len(self._memory) < self.memory.capacity:
            # Warm-up phase
            print_cbt('Collecting samples until the replay memory is full.', 'w')
            # Sample steps and store them in the replay memory
            ros = self.sampler.sample()
            self._memory.push(ros)

        # Log return-based metrics
        if self._curr_iter % self.logger.print_interval == 0:
            ros = self.sampler_eval.sample()
            rets = [ro.undiscounted_return() for ro in ros]
            ret_max = np.max(rets)
            ret_med = np.median(rets)
            ret_avg = np.mean(rets)
            ret_min = np.min(rets)
            ret_std = np.std(rets)
        else:
            ret_max, ret_med, ret_avg, ret_min, ret_std = 5*[-pyrado.inf]  # dummy values
        self.logger.add_value('max return', np.round(ret_max, 4))
        self.logger.add_value('median return', np.round(ret_med, 4))
        self.logger.add_value('avg return', np.round(ret_avg, 4))
        self.logger.add_value('min return', np.round(ret_min, 4))
        self.logger.add_value('std return', np.round(ret_std, 4))
        self.logger.add_value('avg rollout length', np.round(np.mean([ro.length for ro in ros]), 2))
        self.logger.add_value('num rollouts', len(ros))
        self.logger.add_value('avg memory reward', np.round(self._memory.avg_reward(), 4))

        # Use data in the memory to update the policy and the target Q-function
        self.update()

        # Save snapshot data
        self.make_snapshot(snapshot_mode, float(ret_avg), meta_info)

    def loss_fcn(self, q_vals: to.Tensor, expected_q_vals: to.Tensor) -> to.Tensor:
        r"""
        The Huber loss function on the one-step TD error $\delta = Q(s,a) - (r + \gamma \max_a Q(s^\prime, a))$.

        :param q_vals: state-action values $Q(s,a)$, from policy network
        :param expected_q_vals: expected state-action values $r + \gamma \max_a Q(s^\prime, a)$, from target network
        :return: loss value
        """
        return nn.functional.smooth_l1_loss(q_vals, expected_q_vals)

    def update(self):
        """ Update the policy's and target Q-function's parameters on transitions sampled from the replay memory. """
        losses = to.zeros(self.num_batch_updates)
        policy_grad_norm = to.zeros(self.num_batch_updates)

        for b in tqdm(range(self.num_batch_updates), total=self.num_batch_updates,
                      desc='Updating', unit='batches', file=sys.stdout, leave=False):

            # Sample steps and the associated next step from the replay memory
            steps, next_steps = self._memory.sample(self.batch_size)
            steps.torch(data_type=to.get_default_dtype())
            next_steps.torch(data_type=to.get_default_dtype())

            # Create masks for the non-final observations
            not_done = to.tensor(1. - steps.done, dtype=to.get_default_dtype())

            # Compute the state-action values Q(s,a) using the current DQN policy
            q_vals = self.expl_strat.policy.q_values_chosen(steps.observations)

            # Compute the second term of TD-error
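            # TD target: y = r + gamma * (1 - done) * max_a Q_target(s', a), cf. the loss_fcn docstring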
            next_v_vals = self.target.q_values_chosen(next_steps.observations).detach()
            expected_q_val = steps.rewards + not_done*self.gamma*next_v_vals

            # Compute the loss, clip the gradients if desired, and do one optimization step
            loss = self.loss_fcn(q_vals, expected_q_val)
            losses[b] = loss.data
            self.optim.zero_grad()
            loss.backward()
            policy_grad_norm[b] = self.clip_grad(self.expl_strat.policy, self.max_grad_norm)
            self.optim.step()

            # Update the target network by copying all weights and biases from the DQN policy
            if (self._curr_iter*self.num_batch_updates + b)%self.target_update_intvl == 0:
                self.target.load_state_dict(self.expl_strat.policy.state_dict())

        # Schedule the exploration parameter epsilon
        self.expl_strat.schedule_eps(self._curr_iter)

        # Update the learning rate if a scheduler has been specified
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()

        # Logging
        with to.no_grad():
            self.logger.add_value('loss after', to.mean(losses).item())
        self.logger.add_value('expl strat eps', self.expl_strat.eps.item())
        self.logger.add_value('avg policy grad norm', to.mean(policy_grad_norm).item())
        if self._lr_scheduler is not None:
            self.logger.add_value('learning rate', self._lr_scheduler.get_lr())

    def save_snapshot(self, meta_info: dict = None):
        super().save_snapshot(meta_info)

        if meta_info is None:
            # This instance is not a subroutine of a meta-algorithm
            joblib.dump(self._env, osp.join(self._save_dir, 'env.pkl'))
            to.save(self.target, osp.join(self._save_dir, 'target.pt'))
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                to.save(self.target,
                        osp.join(self._save_dir, f"{meta_info['prefix']}_target_{meta_info['suffix']}.pt"))
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                to.save(self.target, osp.join(self._save_dir, f"{meta_info['prefix']}_target.pt"))
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                to.save(self.target, osp.join(self._save_dir, f"target_{meta_info['suffix']}.pt"))
            else:
                raise NotImplementedError

    def load_snapshot(self, load_dir: str = None, meta_info: dict = None):
        # Get the directory to load from
        ld = load_dir if load_dir is not None else self._save_dir
        super().load_snapshot(ld, meta_info)

        if meta_info is None:
            # This algorithm instance is not a subroutine of a meta-algorithm
            self._env = joblib.load(osp.join(ld, 'env.pkl'))
            self.target.load_state_dict(to.load(osp.join(ld, 'target.pt')).state_dict())
        else:
            # This algorithm instance is a subroutine of a meta-algorithm
            if 'prefix' in meta_info and 'suffix' in meta_info:
                self.target.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target_{meta_info['suffix']}.pt")).state_dict()
                )
            elif 'prefix' in meta_info and 'suffix' not in meta_info:
                self.target.load_state_dict(
                    to.load(osp.join(ld, f"{meta_info['prefix']}_target.pt")).state_dict()
                )
            elif 'prefix' not in meta_info and 'suffix' in meta_info:
                self.target.load_state_dict(
                    to.load(osp.join(ld, f"target_{meta_info['suffix']}.pt")).state_dict()
                )
            else:
                raise NotImplementedError

    def reset(self, seed: int = None):
        # Reset the exploration strategy, internal variables and the random seeds
        super().reset(seed)

        # Re-initialize sampler in case env or policy changed
        self.sampler.reinit()

        # Reset the replay memory
        self._memory.reset()

        # Reset the learning rate scheduler
        if self._lr_scheduler is not None:
            self._lr_scheduler.last_epoch = -1
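
To make the TD target and the Huber loss used in `DQL.update()` concrete, here is a small self-contained PyTorch snippet on toy tensors. It is independent of pyrado; the `gather` over the chosen actions mimics what `q_values_chosen` is assumed to return for the current network, and the row-wise maximum stands in for the target network's value of its greedy action.

# Stand-alone sketch of the DQL update math on toy tensors (assumptions noted in the lead-in above)
import torch as to
import torch.nn as nn

to.manual_seed(0)
gamma = 0.99

# Toy batch: 4 transitions, 3 discrete actions
q_all = to.randn(4, 3, requires_grad=True)   # Q(s, .) from the current network
q_targ_all = to.randn(4, 3)                  # Q_target(s', .) from the frozen target network
actions = to.tensor([[0], [2], [1], [0]])    # actions actually taken
rewards = to.tensor([1.0, 0.0, 0.5, -1.0])
done = to.tensor([0.0, 0.0, 1.0, 0.0])       # 1 marks a terminal next state

# Q(s,a) of the chosen actions (what q_values_chosen is assumed to provide)
q_vals = q_all.gather(1, actions).squeeze(1)

# TD target: y = r + gamma * (1 - done) * max_a Q_target(s', a), detached from the graph
with to.no_grad():
    expected_q_vals = rewards + (1.0 - done) * gamma * q_targ_all.max(dim=1).values

# Same Huber loss as DQL.loss_fcn; gradients flow only into q_all
loss = nn.functional.smooth_l1_loss(q_vals, expected_q_vals)
loss.backward()
print(loss.item())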