Example no. 1
    def __init__(self,
                 action_spec,
                 initial_alpha=0.01,
                 target_entropy=None,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 debug_summaries=False):
        """Create an EntropyTargetAlgorithm

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used.
            slow_update_rate (float): minimal update rate for log_alpha
            fast_update_rate (float): maximum update rate for log_alpha
            min_alpha (float): the minimal value of alpha. If <=0, exp(-100) is
                used.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(
            debug_summaries=debug_summaries, name="EntropyTargetAlgorithm")

        self._log_alpha = tf.Variable(
            name='log_alpha',
            initial_value=np.log(initial_alpha),
            dtype=tf.float32,
            trainable=False)
        self._stage = tf.Variable(
            name='stage', initial_value=-1, dtype=tf.int32, trainable=False)
        self._avg_entropy = ScalarWindowAverager(2)
        self._update_rate = tf.Variable(
            name='update_rate',
            initial_value=fast_update_rate,
            dtype=tf.float32,
            trainable=False)
        self._action_spec = action_spec
        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)

        if target_entropy is None:
            flat_action_spec = tf.nest.flatten(self._action_spec)
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
        if target_entropy > 0:
            self._fast_stage_thresh = 0.5 * target_entropy
        else:
            self._fast_stage_thresh = 2.0 * target_entropy
        self._target_entropy = target_entropy
        self._slow_update_rate = slow_update_rate
        self._fast_update_rate = fast_update_rate
        logging.info("target_entropy=%s" % target_entropy)
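
The constructor above relies on ScalarWindowAverager, which the surrounding library provides. As a rough sketch of the interface it depends on (get() and average(value)), a hypothetical stand-in could look like the following; this is an illustration of the assumed interface, not the library's implementation.

from collections import deque

import numpy as np


class WindowAveragerSketch:
    """Hypothetical stand-in for ScalarWindowAverager: a moving average
    over the most recent `window_size` scalar values."""

    def __init__(self, window_size):
        self._values = deque(maxlen=window_size)

    def get(self):
        # Current average (0.0 before any value has been added).
        return float(np.mean(self._values)) if self._values else 0.0

    def average(self, value):
        # Record a new value and return the updated average.
        self._values.append(float(value))
        return self.get()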
Example no. 2
class EntropyTargetAlgorithm(Algorithm):
    """Algorithm for adjust entropy regularization.

    It tries to adjust the entropy regularization (i.e. alpha) so that the
    the entropy is not smaller than `target_entropy`.

    The algorithm has two stages:
    1. init stage. During this stage, the alpha is not changed. It transitions
       to adjust_stage once entropy drops below `target_entropy`.
    2. adjust stage. During this stage, log_alpha is adjusted using this formula:
       ((below + 0.5 * above) * decreasing - (above + 0.5 * below) * increasing) * update_rate
       Note that log_alpha will always be decreased if entropy is increasing,
       even when the entropy is below the target entropy. This is to prevent
       overshooting log_alpha to too large a value. For the same reason,
       log_alpha will always be increased if entropy is decreasing, even when
       the entropy is above the target entropy.
       `update_rate` is initialized to `fast_update_rate` and is reduced by a
       factor of 0.9 whenever the entropy crosses `target_entropy`. `update_rate`
       is reset to `fast_update_rate` if entropy drops too much below
       `target_entropy` (i.e., below fast_stage_thresh in the code, which is
       half of `target_entropy` if it is positive, and twice `target_entropy`
       if it is negative).
    """

    def __init__(self,
                 action_spec,
                 initial_alpha=0.01,
                 target_entropy=None,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 debug_summaries=False):
        """Create an EntropyTargetAlgorithm

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used.
            slow_update_rate (float): minimal update rate for log_alpha
            fast_update_rate (float): maximum update rate for log_alpha
            min_alpha (float): the minimal value of alpha. If <=0, exp(-100) is
                used.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(
            debug_summaries=debug_summaries, name="EntropyTargetAlgorithm")

        self._log_alpha = tf.Variable(
            name='log_alpha',
            initial_value=np.log(initial_alpha),
            dtype=tf.float32,
            trainable=False)
        self._stage = tf.Variable(
            name='stage', initial_value=-1, dtype=tf.int32, trainable=False)
        self._avg_entropy = ScalarWindowAverager(2)
        self._update_rate = tf.Variable(
            name='update_rate',
            initial_value=fast_update_rate,
            dtype=tf.float32,
            trainable=False)
        self._action_spec = action_spec
        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)

        if target_entropy is None:
            flat_action_spec = tf.nest.flatten(self._action_spec)
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
        if target_entropy > 0:
            self._fast_stage_thresh = 0.5 * target_entropy
        else:
            self._fast_stage_thresh = 2.0 * target_entropy
        self._target_entropy = target_entropy
        self._slow_update_rate = slow_update_rate
        self._fast_update_rate = fast_update_rate
        logging.info("target_entropy=%s" % target_entropy)

    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgorithmStep. `info` field is LossInfo, other fields are empty.
        """
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            distribution, self._action_spec)
        return AlgorithmStep(
            outputs=(),
            state=(),
            info=EntropyTargetInfo(
                step_type=step_type,
                loss=LossInfo(
                    loss=-entropy_for_gradient,
                    extra=EntropyTargetLossInfo(entropy_loss=-entropy))))

    def calc_loss(self, training_info: EntropyTargetInfo):
        loss_info = training_info.loss
        mask = tf.cast(training_info.step_type != StepType.LAST, tf.float32)
        entropy = -loss_info.extra.entropy_loss * mask
        num = tf.reduce_sum(mask)
        entropy2 = tf.reduce_sum(tf.square(entropy)) / num
        entropy = tf.reduce_sum(entropy) / num
        entropy_std = tf.sqrt(tf.maximum(0.0, entropy2 - entropy * entropy))
        prev_avg_entropy = self._avg_entropy.get()
        avg_entropy = self._avg_entropy.average(entropy)

        def _init():
            crossing = avg_entropy < self._target_entropy
            self._stage.assign_add(tf.cast(crossing, tf.int32))

        def _adjust():
            previous_above = tf.cast(self._stage, tf.bool)
            above = avg_entropy > self._target_entropy
            self._stage.assign(tf.cast(above, tf.int32))
            crossing = above != previous_above
            update_rate = self._update_rate
            update_rate = tf.where(crossing, 0.9 * update_rate, update_rate)
            update_rate = tf.maximum(update_rate, self._slow_update_rate)
            update_rate = tf.where(entropy < self._fast_stage_thresh,
                                   np.float32(self._fast_update_rate),
                                   update_rate)
            self._update_rate.assign(update_rate)
            above = tf.cast(above, tf.float32)
            below = 1 - above
            increasing = tf.cast(avg_entropy > prev_avg_entropy, tf.float32)
            decreasing = 1 - increasing
            log_alpha = self._log_alpha + (
                (below + 0.5 * above) * decreasing -
                (above + 0.5 * below) * increasing) * update_rate
            log_alpha = tf.maximum(log_alpha, np.float32(self._min_log_alpha))
            self._log_alpha.assign(log_alpha)

        run_if(self._stage == -1, _init)
        run_if(self._stage >= 0, _adjust)
        alpha = tf.exp(self._log_alpha)

        def _summarize():
            with self.name_scope:
                tf.summary.scalar("alpha", alpha)
                tf.summary.scalar("entropy_std", entropy_std)
                tf.summary.scalar("avg_entropy", avg_entropy)
                tf.summary.scalar("stage", self._stage)
                tf.summary.scalar("update_rate", self._update_rate)

        if self._debug_summaries:
            run_if(should_record_summaries(), _summarize)

        return loss_info._replace(loss=loss_info.loss * alpha)
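
To make the adjust-stage formula in the class docstring above concrete, the following standalone sketch applies a single update to log_alpha given the current and previous averaged entropies. It is illustrative only and uses plain Python floats instead of TF variables; the names are hypothetical.

import numpy as np


def adjust_log_alpha_once(log_alpha, avg_entropy, prev_avg_entropy,
                          target_entropy, update_rate, min_log_alpha=-100.0):
    """One adjust-stage update of log_alpha (illustrative sketch)."""
    above = float(avg_entropy > target_entropy)
    below = 1.0 - above
    increasing = float(avg_entropy > prev_avg_entropy)
    decreasing = 1.0 - increasing
    delta = ((below + 0.5 * above) * decreasing -
             (above + 0.5 * below) * increasing) * update_rate
    return max(log_alpha + delta, min_log_alpha)


# Example: entropy fell below the target and is still decreasing, so alpha grows.
print(adjust_log_alpha_once(log_alpha=0.0, avg_entropy=-3.5, prev_avg_entropy=-3.0,
                            target_entropy=-3.0, update_rate=np.log(2)))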
Example no. 3
    def __init__(self,
                 action_spec,
                 initial_alpha=0.1,
                 skip_free_stage=False,
                 max_entropy=None,
                 target_entropy=None,
                 very_slow_update_rate=0.001,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 average_window=2,
                 debug_summaries=False):
        """Create an EntropyTargetAlgorithm

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha; make sure that it is
                large enough for meaningful initial exploration.
            skip_free_stage (bool): If True, directly goes to the adjust stage.
            max_entropy (float): the upper bound of the entropy. If not provided,
                min(initial_entropy * 0.8, initial_entropy / 0.8) is used, where
                initial_entropy is estimated from the first `average_window`
                steps. The factor 0.8 ensures that the policy becomes somewhat
                less random than the initial policy before the free stage starts.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used. This value should be less than or equal to `max_entropy`.
            very_slow_update_rate (float): a tiny update rate for log_alpha; used
                in stage 0.
            slow_update_rate (float): minimal update rate for log_alpha; used in
                stage 2.
            fast_update_rate (float): maximum update rate for log_alpha; used in
                stage 2.
            min_alpha (float): the minimal value of alpha. If <=0, exp(-100) is
                used.
            average_window (int): window size for averaging past entropies.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(
            debug_summaries=debug_summaries, name="EntropyTargetAlgorithm")

        self._log_alpha = tf.Variable(
            name='log_alpha',
            initial_value=np.log(initial_alpha),
            dtype=tf.float32,
            trainable=False)
        self._stage = tf.Variable(
            name='stage', initial_value=-2, dtype=tf.int32, trainable=False)
        self._avg_entropy = ScalarWindowAverager(average_window)
        self._update_rate = tf.Variable(
            name='update_rate',
            initial_value=fast_update_rate,
            dtype=tf.float32,
            trainable=False)
        self._action_spec = action_spec
        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)

        flat_action_spec = tf.nest.flatten(self._action_spec)
        if target_entropy is None:
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
            logging.info("target_entropy=%s" % target_entropy)

        if max_entropy is None:
            # max_entropy will be estimated in the first `average_window` steps.
            max_entropy = 0.
            self._stage.assign(-2 - average_window)
        else:
            assert target_entropy <= max_entropy, (
                "Target entropy %s should be less or equal than max entropy %s!"
                % (target_entropy, max_entropy))
        self._max_entropy = tf.Variable(
            name='max_entropy',
            initial_value=max_entropy,
            dtype=tf.float32,
            trainable=False)

        if skip_free_stage:
            self._stage.assign(1)

        if target_entropy > 0:
            self._fast_stage_thresh = 0.5 * target_entropy
        else:
            self._fast_stage_thresh = 2.0 * target_entropy
        self._target_entropy = target_entropy
        self._very_slow_update_rate = very_slow_update_rate
        self._slow_update_rate = slow_update_rate
        self._fast_update_rate = fast_update_rate
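
For readability, the meaning of the _stage counter used above (inferred from the constructor logic; not an official API) can be summarized by a small hypothetical helper:

def describe_stage(stage):
    """Human-readable meaning of the _stage counter (inferred from the code above)."""
    if stage < -2:
        # The counter starts at -2 - average_window when max_entropy is estimated
        # online; it is incremented once per step until it reaches -2.
        return "estimating max_entropy (%d estimation steps remaining)" % (-2 - stage)
    if stage == -2:
        return "init stage: slowly decrease alpha until entropy drops below max_entropy"
    if stage == -1:
        return "free stage: keep alpha fixed until entropy drops below target_entropy"
    # In the adjust stage, the counter stores whether the averaged entropy is
    # currently above (1) or below (0) target_entropy.
    return "adjust stage (entropy currently %s target)" % ("above" if stage else "below")


for s in (-4, -2, -1, 0, 1):
    print(s, describe_stage(s))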
Example no. 4
class EntropyTargetAlgorithm(Algorithm):
    """Algorithm for adjust entropy regularization.

    It tries to adjust the entropy regularization (i.e. alpha) so that the
    the entropy is not smaller than `target_entropy`.

    The algorithm has three stages:
    0. init stage. This is an optional stage. If the initial entropy is already
       below `max_entropy`, then this stage is skipped. Otherwise, the alpha will
       be slowly decreased so that the entropy lands at `max_entropy`, which
       triggers the next `free stage`. Basically, this stage lets the user choose
       an arbitrarily large initial alpha without tuning it for every specific case.
    1. free stage. During this stage, the alpha is not changed. It transitions
       to adjust_stage once entropy drops below `target_entropy`.
    2. adjust stage. During this stage, log_alpha is adjusted using this formula:
       ((below + 0.5 * above) * decreasing - (above + 0.5 * below) * increasing) * update_rate
       Note that log_alpha will always be decreased if entropy is increasing,
       even when the entropy is below the target entropy. This is to prevent
       overshooting log_alpha to too large a value. For the same reason,
       log_alpha will always be increased if entropy is decreasing, even when
       the entropy is above the target entropy.
       `update_rate` is initialized to `fast_update_rate` and is reduced by a
       factor of 0.9 whenever the entropy crosses `target_entropy`. `update_rate`
       is reset to `fast_update_rate` if entropy drops too much below
       `target_entropy` (i.e., below fast_stage_thresh in the code, which is
       half of `target_entropy` if it is positive, and twice `target_entropy`
       if it is negative).
    """

    def __init__(self,
                 action_spec,
                 initial_alpha=0.1,
                 skip_free_stage=False,
                 max_entropy=None,
                 target_entropy=None,
                 very_slow_update_rate=0.001,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 average_window=2,
                 debug_summaries=False):
        """Create an EntropyTargetAlgorithm

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha; make sure that it is
                large enough for meaningful initial exploration.
            skip_free_stage (bool): If True, directly goes to the adjust stage.
            max_entropy (float): the upper bound of the entropy. If not provided,
                min(initial_entropy * 0.8, initial_entropy / 0.8) is used, where
                initial_entropy is estimated from the first `average_window`
                steps. The factor 0.8 ensures that the policy becomes somewhat
                less random than the initial policy before the free stage starts.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used. This value should be less than or equal to `max_entropy`.
            very_slow_update_rate (float): a tiny update rate for log_alpha; used
                in stage 0.
            slow_update_rate (float): minimal update rate for log_alpha; used in
                stage 2.
            fast_update_rate (float): maximum update rate for log_alpha; used in
                stage 2.
            min_alpha (float): the minimal value of alpha. If <=0, exp(-100) is
                used.
            average_window (int): window size for averaging past entropies.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(
            debug_summaries=debug_summaries, name="EntropyTargetAlgorithm")

        self._log_alpha = tf.Variable(
            name='log_alpha',
            initial_value=np.log(initial_alpha),
            dtype=tf.float32,
            trainable=False)
        self._stage = tf.Variable(
            name='stage', initial_value=-2, dtype=tf.int32, trainable=False)
        self._avg_entropy = ScalarWindowAverager(average_window)
        self._update_rate = tf.Variable(
            name='update_rate',
            initial_value=fast_update_rate,
            dtype=tf.float32,
            trainable=False)
        self._action_spec = action_spec
        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)

        flat_action_spec = tf.nest.flatten(self._action_spec)
        if target_entropy is None:
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
            logging.info("target_entropy=%s" % target_entropy)

        if max_entropy is None:
            # max_entropy will be estimated in the first `average_window` steps.
            max_entropy = 0.
            self._stage.assign(-2 - average_window)
        else:
            assert target_entropy <= max_entropy, (
                "Target entropy %s should be less or equal than max entropy %s!"
                % (target_entropy, max_entropy))
        self._max_entropy = tf.Variable(
            name='max_entropy',
            initial_value=max_entropy,
            dtype=tf.float32,
            trainable=False)

        if skip_free_stage:
            self._stage.assign(1)

        if target_entropy > 0:
            self._fast_stage_thresh = 0.5 * target_entropy
        else:
            self._fast_stage_thresh = 2.0 * target_entropy
        self._target_entropy = target_entropy
        self._very_slow_update_rate = very_slow_update_rate
        self._slow_update_rate = slow_update_rate
        self._fast_update_rate = fast_update_rate

    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgorithmStep. `info` field is LossInfo, other fields are empty.
        """
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            distribution, self._action_spec)
        return AlgorithmStep(
            outputs=(),
            state=(),
            info=EntropyTargetInfo(
                step_type=step_type,
                loss=LossInfo(
                    loss=-entropy_for_gradient,
                    extra=EntropyTargetLossInfo(neg_entropy=-entropy))))

    def calc_loss(self, training_info: EntropyTargetInfo, valid_mask=None):
        loss_info = training_info.loss
        mask = tf.cast(training_info.step_type != StepType.LAST, tf.float32)
        if valid_mask is not None:
            mask = mask * tf.cast(valid_mask, tf.float32)
        entropy = -loss_info.extra.neg_entropy * mask
        num = tf.reduce_sum(mask)
        not_empty = num > 0
        num = tf.maximum(num, 1)
        entropy2 = tf.reduce_sum(tf.square(entropy)) / num
        entropy = tf.reduce_sum(entropy) / num
        entropy_std = tf.sqrt(tf.maximum(0.0, entropy2 - entropy * entropy))

        run_if(not_empty, lambda: self.adjust_alpha(entropy))

        def _summarize():
            with self.name_scope:
                tf.summary.scalar("entropy_std", entropy_std)

        if self._debug_summaries:
            run_if(
                tf.logical_and(not_empty, should_record_summaries()),
                _summarize)

        alpha = tf.exp(self._log_alpha)
        return loss_info._replace(loss=loss_info.loss * alpha)

    def adjust_alpha(self, entropy):
        """Adjust alpha according to the current entropy.

        Args:
            entropy (scalar Tensor): the current entropy.
        Returns:
            adjusted entropy regularization
        """
        prev_avg_entropy = self._avg_entropy.get()
        avg_entropy = self._avg_entropy.average(entropy)

        def _init_entropy():
            self._max_entropy.assign(
                tf.minimum(0.8 * avg_entropy, avg_entropy / 0.8))
            self._stage.assign_add(1)

        def _init():
            below = avg_entropy < self._max_entropy
            increasing = tf.cast(avg_entropy > prev_avg_entropy, tf.float32)
            # -1 * increasing + 0.5 * (1 - increasing)
            update_rate = (
                0.5 - 1.5 * increasing) * self._very_slow_update_rate
            self._stage.assign_add(tf.cast(below, tf.int32))
            self._log_alpha.assign(
                tf.maximum(self._log_alpha + update_rate,
                           np.float32(self._min_log_alpha)))

        def _free():
            crossing = avg_entropy < self._target_entropy
            self._stage.assign_add(tf.cast(crossing, tf.int32))

        def _adjust():
            previous_above = tf.cast(self._stage, tf.bool)
            above = avg_entropy > self._target_entropy
            self._stage.assign(tf.cast(above, tf.int32))
            crossing = above != previous_above
            update_rate = self._update_rate
            update_rate = tf.where(crossing, 0.9 * update_rate, update_rate)
            update_rate = tf.maximum(update_rate, self._slow_update_rate)
            update_rate = tf.where(entropy < self._fast_stage_thresh,
                                   np.float32(self._fast_update_rate),
                                   update_rate)
            self._update_rate.assign(update_rate)
            above = tf.cast(above, tf.float32)
            below = 1 - above
            increasing = tf.cast(avg_entropy > prev_avg_entropy, tf.float32)
            decreasing = 1 - increasing
            log_alpha = self._log_alpha + (
                (below + 0.5 * above) * decreasing -
                (above + 0.5 * below) * increasing) * update_rate
            log_alpha = tf.maximum(log_alpha, np.float32(self._min_log_alpha))
            self._log_alpha.assign(log_alpha)

        run_if(self._stage < -2, _init_entropy)
        run_if(self._stage == -2, _init)
        run_if(self._stage == -1, _free)
        run_if(self._stage >= 0, _adjust)
        alpha = tf.exp(self._log_alpha)

        def _summarize():
            with self.name_scope:
                tf.summary.scalar("alpha", alpha)
                tf.summary.scalar("avg_entropy", avg_entropy)
                tf.summary.scalar("stage", self._stage)
                tf.summary.scalar("update_rate", self._update_rate)

        if self._debug_summaries:
            run_if(should_record_summaries(), _summarize)

        return alpha
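
run_if and should_record_summaries come from the surrounding library. A minimal TF2 stand-in for run_if (an assumption for illustration, not the library's code) can be written with tf.cond so that the side-effecting callable only executes when the condition holds:

import tensorflow as tf


def run_if_sketch(cond, func):
    """Run `func` (a no-arg callable with side effects) only when `cond` is True."""

    def _true_branch():
        func()
        return tf.constant(True)

    return tf.cond(cond, _true_branch, lambda: tf.constant(False))


# Example: increment a counter only when the condition holds.
counter = tf.Variable(0)
run_if_sketch(tf.constant(True), lambda: counter.assign_add(1))
print(counter.numpy())  # 1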
Example no. 5
    def __init__(self,
                 action_spec,
                 initial_alpha=0.1,
                 skip_free_stage=False,
                 max_entropy=None,
                 target_entropy=None,
                 very_slow_update_rate=0.001,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 average_window=2,
                 debug_summaries=False,
                 name="EntropyTargetAlgorithm"):
        """
        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha; make sure that it is
                large enough for meaningful initial exploration.
            skip_free_stage (bool): If True, directly goes to the adjust stage.
            max_entropy (float): the upper bound of the entropy. If not provided,
                ``min(initial_entropy * 0.8, initial_entropy / 0.8)`` is used, where
                initial_entropy is estimated from the first ``average_window``
                steps. The factor 0.8 ensures that the policy becomes somewhat
                less random than the initial policy before the free stage starts.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used. This value should be less than or equal to ``max_entropy``.
            very_slow_update_rate (float): a tiny update rate for ``log_alpha``;
                used in stage 0.
            slow_update_rate (float): minimal update rate for ``log_alpha``; used
                in stage 2.
            fast_update_rate (float): maximum update rate for ``log_alpha``; used
                in stage 2.
            min_alpha (float): the minimal value of alpha. If <=0, :math:`e^{-100}`
                is used.
            average_window (int): window size for averaging past entropies.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(debug_summaries=debug_summaries, name=name)

        self.register_buffer(
            '_log_alpha',
            torch.tensor(np.log(initial_alpha), dtype=torch.float32))
        self.register_buffer('_stage', torch.tensor(-2, dtype=torch.int32))
        self._avg_entropy = ScalarWindowAverager(average_window)
        self.register_buffer(
            "_update_rate", torch.tensor(fast_update_rate,
                                         dtype=torch.float32))
        self._action_spec = action_spec
        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)
        self._min_log_alpha = torch.tensor(self._min_log_alpha)

        flat_action_spec = alf.nest.flatten(self._action_spec)
        if target_entropy is None:
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
            logging.info("target_entropy=%s" % target_entropy)

        if not isinstance(target_entropy, Callable):
            target_entropy = ConstantScheduler(target_entropy)

        if max_entropy is None:
            # max_entropy will be estimated in the first `average_window` steps.
            max_entropy = 0.
            self._stage.fill_(-2 - average_window)
        else:
            assert target_entropy() <= max_entropy, (
                "Target entropy %s should be less or equal than max entropy %s!"
                % (target_entropy(), max_entropy))
        self.register_buffer("_max_entropy",
                             torch.tensor(max_entropy, dtype=torch.float32))

        if skip_free_stage:
            self._stage.fill_(1)

        self._target_entropy = target_entropy
        self._very_slow_update_rate = very_slow_update_rate
        self._slow_update_rate = torch.tensor(slow_update_rate)
        self._fast_update_rate = torch.tensor(fast_update_rate)
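
The PyTorch variant above accepts either a float or a callable for target_entropy and wraps plain floats in ConstantScheduler. A hypothetical stand-in with the same calling convention, shown only to illustrate the assumed interface:

class ConstantSchedulerSketch:
    """Callable that always returns the same value (illustrative stand-in)."""

    def __init__(self, value):
        self._value = value

    def __call__(self):
        return self._value


target_entropy = ConstantSchedulerSketch(-3.0)
print(target_entropy())  # -3.0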
Example no. 6
class EntropyTargetAlgorithm(Algorithm):
    """Algorithm for adjusting entropy regularization.

    It tries to adjust the entropy regularization (i.e. alpha) so that the
    entropy is not smaller than ``target_entropy``.

    The algorithm has three stages:

    0. init stage. This is an optional stage. If the initial entropy is already
       below ``max_entropy``, then this stage is skipped. Otherwise, the alpha will
       be slowly decreased so that the entropy lands at ``max_entropy``, which
       triggers the next ``free_stage``. Basically, this stage lets the user choose
       an arbitrarily large initial alpha without tuning it for every specific case.
    1. free stage. During this stage, the alpha is not changed. It transitions
       to adjust_stage once entropy drops below ``target_entropy``.
    2. adjust stage. During this stage, ``log_alpha`` is adjusted using this formula:

       .. code-block:: python

            ((below + 0.5 * above) * decreasing - (above + 0.5 * below) * increasing) * update_rate

       Note that ``log_alpha`` will always be decreased if entropy is increasing,
       even when the entropy is below the target entropy. This is to prevent
       overshooting ``log_alpha`` to too large a value. For the same reason,
       ``log_alpha`` will always be increased if entropy is decreasing, even when
       the entropy is above the target entropy.
       ``update_rate`` is initialized to ``fast_update_rate`` and is reduced by a
       factor of 0.9 whenever the entropy crosses ``target_entropy``. ``update_rate``
       is reset to ``fast_update_rate`` if entropy drops too much below
       ``target_entropy`` (i.e., below ``fast_stage_thresh`` in the code, which is
       half of ``target_entropy`` if it is positive, and twice ``target_entropy``
       if it is negative).

    ``EntropyTargetAlgorithm`` can be used to approximately reproduce the learning
    of temperature in `Soft Actor-Critic Algorithms and Applications <https://arxiv.org/abs/1812.05905>`_.
    To do so, you need to use the same ``target_entropy``, set ``skip_free_stage``
    to True, and set ``slow_update_rate`` and ``fast_update_rate`` to 4 times the
    learning rate used for the temperature.
    """
    def __init__(self,
                 action_spec,
                 initial_alpha=0.1,
                 skip_free_stage=False,
                 max_entropy=None,
                 target_entropy=None,
                 very_slow_update_rate=0.001,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 average_window=2,
                 debug_summaries=False,
                 name="EntropyTargetAlgorithm"):
        """
        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha; make sure that it is
                large enough for meaningful initial exploration.
            skip_free_stage (bool): If True, directly goes to the adjust stage.
            max_entropy (float): the upper bound of the entropy. If not provided,
                ``min(initial_entropy * 0.8, initial_entropy / 0.8)`` is used, where
                initial_entropy is estimated from the first ``average_window``
                steps. The factor 0.8 ensures that the policy becomes somewhat
                less random than the initial policy before the free stage starts.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used. This value should be less than or equal to ``max_entropy``.
            very_slow_update_rate (float): a tiny update rate for ``log_alpha``;
                used in stage 0.
            slow_update_rate (float): minimal update rate for ``log_alpha``; used
                in stage 2.
            fast_update_rate (float): maximum update rate for ``log_alpha``; used
                in stage 2.
            min_alpha (float): the minimal value of alpha. If <=0, :math:`e^{-100}`
                is used.
            average_window (int): window size for averaging past entropies.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(debug_summaries=debug_summaries, name=name)

        self.register_buffer(
            '_log_alpha',
            torch.tensor(np.log(initial_alpha), dtype=torch.float32))
        self.register_buffer('_stage', torch.tensor(-2, dtype=torch.int32))
        self._avg_entropy = ScalarWindowAverager(average_window)
        self.register_buffer(
            "_update_rate", torch.tensor(fast_update_rate,
                                         dtype=torch.float32))
        self._action_spec = action_spec
        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)
        self._min_log_alpha = torch.tensor(self._min_log_alpha)

        flat_action_spec = alf.nest.flatten(self._action_spec)
        if target_entropy is None:
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
            logging.info("target_entropy=%s" % target_entropy)

        if not isinstance(target_entropy, Callable):
            target_entropy = ConstantScheduler(target_entropy)

        if max_entropy is None:
            # max_entropy will be estimated in the first `average_window` steps.
            max_entropy = 0.
            self._stage.fill_(-2 - average_window)
        else:
            assert target_entropy() <= max_entropy, (
                "Target entropy %s should be less or equal than max entropy %s!"
                % (target_entropy(), max_entropy))
        self.register_buffer("_max_entropy",
                             torch.tensor(max_entropy, dtype=torch.float32))

        if skip_free_stage:
            self._stage.fill_(1)

        self._target_entropy = target_entropy
        self._very_slow_update_rate = very_slow_update_rate
        self._slow_update_rate = torch.tensor(slow_update_rate)
        self._fast_update_rate = torch.tensor(fast_update_rate)

    def rollout_step(self, distribution, step_type, on_policy_training):
        """Rollout step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
            on_policy_training (bool): If False, this step does nothing.

        Returns:
            AlgStep: ``info`` field is ``LossInfo``, other fields are empty. All
            fields are empty if ``on_policy_training=False``.
        """
        if on_policy_training:
            return self.train_step(distribution, step_type)
        else:
            return AlgStep()

    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgStep: ``info`` field is ``LossInfo``, other fields are empty.
        """
        entropy, entropy_for_gradient = entropy_with_fallback(distribution)
        return AlgStep(
            output=(),
            state=(),
            info=EntropyTargetInfo(loss=LossInfo(loss=-entropy_for_gradient,
                                                 extra=EntropyTargetLossInfo(
                                                     neg_entropy=-entropy))))

    def calc_loss(self, experience, info: EntropyTargetInfo, valid_mask=None):
        """Calculate loss.

        Args:
            experience (Experience): experience for gradient update
            info (EntropyTargetInfo): for computing loss.
            valid_mask (tensor): valid mask to be applied on time steps.

        Returns:
            LossInfo:
        """
        loss_info = info.loss
        mask = (experience.step_type != StepType.LAST).type(torch.float32)
        if valid_mask is not None:
            mask = mask * valid_mask.type(torch.float32)
        entropy = -loss_info.extra.neg_entropy * mask
        num = torch.sum(mask)
        not_empty = num > 0
        num = max(num, 1)
        entropy2 = torch.sum(entropy**2) / num
        entropy = torch.sum(entropy) / num
        entropy_std = torch.sqrt(
            torch.max(torch.tensor(0.0), entropy2 - entropy * entropy))

        if not_empty:
            self.adjust_alpha(entropy)
            if self._debug_summaries and should_record_summaries():
                with alf.summary.scope(self.name):
                    alf.summary.scalar("entropy_std", entropy_std)

        alpha = torch.exp(self._log_alpha)
        return loss_info._replace(loss=loss_info.loss * alpha)

    def adjust_alpha(self, entropy):
        """Adjust alpha according to the current entropy.

        Args:
            entropy (scalar Tensor): the current entropy.
        Returns:
            adjusted entropy regularization
        """
        prev_avg_entropy = self._avg_entropy.get()
        avg_entropy = self._avg_entropy.average(entropy)

        target_entropy = self._target_entropy()

        if target_entropy > 0:
            fast_stage_thresh = 0.5 * target_entropy
        else:
            fast_stage_thresh = 2.0 * target_entropy

        def _init_entropy():
            self._max_entropy.fill_(
                torch.min(0.8 * avg_entropy, avg_entropy / 0.8))
            self._stage.add_(1)

        def _init():
            below = avg_entropy < self._max_entropy
            decreasing = (avg_entropy < prev_avg_entropy).type(torch.float32)
            # -1 * (1 - decreasing) + 0.5 * decreasing
            update_rate = (-1 + 1.5 * decreasing) * self._very_slow_update_rate
            self._stage.add_(below.type(torch.int32))
            self._log_alpha.fill_(
                torch.max(self._log_alpha + update_rate, self._min_log_alpha))

        def _free():
            crossing = avg_entropy < target_entropy
            self._stage.add_(crossing.type(torch.int32))

        def _adjust():
            previous_above = self._stage.type(torch.bool)
            above = avg_entropy > target_entropy
            self._stage.fill_(above.type(torch.int32))
            crossing = above != previous_above
            update_rate = self._update_rate
            update_rate = torch.where(crossing, 0.9 * update_rate, update_rate)
            update_rate = torch.max(update_rate, self._slow_update_rate)
            update_rate = torch.where(entropy < fast_stage_thresh,
                                      self._fast_update_rate, update_rate)
            self._update_rate.fill_(update_rate)
            above = above.type(torch.float32)
            below = 1 - above
            decreasing = (avg_entropy < prev_avg_entropy).type(torch.float32)
            increasing = 1 - decreasing
            log_alpha = self._log_alpha + (
                (below + 0.5 * above) * decreasing -
                (above + 0.5 * below) * increasing) * update_rate
            log_alpha = torch.max(log_alpha, self._min_log_alpha)
            self._log_alpha.fill_(log_alpha)

        if self._stage < -2:
            _init_entropy()
        if self._stage == -2:
            _init()
        if self._stage == -1:
            _free()
        if self._stage >= 0:
            _adjust()
        alpha = torch.exp(self._log_alpha)

        if self._debug_summaries and should_record_summaries():
            with alf.summary.scope(self.name):
                alf.summary.scalar("alpha", alpha)
                alf.summary.scalar("avg_entropy", avg_entropy)
                alf.summary.scalar("stage", self._stage)
                alf.summary.scalar("update_rate", self._update_rate)
                if type(self._target_entropy) != ConstantScheduler:
                    alf.summary.scalar("target_entropy", target_entropy)

        return alpha
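
As a closing illustration, the following self-contained numpy walk-through (an assumption-based toy, not library code) mimics the adjust-stage dynamics implemented above: when the averaged entropy sits below target_entropy, log_alpha is pushed up, and when it recovers above the target, alpha decays again while update_rate shrinks at each crossing and is reset when entropy falls far below the target.

import numpy as np


def simulate_adjust_stage(avg_entropies, target_entropy,
                          fast_update_rate=np.log(2), slow_update_rate=0.01,
                          min_log_alpha=-100.0):
    """Toy reproduction of the adjust-stage update rule (illustrative only)."""
    fast_thresh = (0.5 if target_entropy > 0 else 2.0) * target_entropy
    log_alpha, update_rate = 0.0, fast_update_rate
    prev_avg, prev_above = None, True
    for avg in avg_entropies:
        above = avg > target_entropy
        if above != prev_above:                      # crossed the target
            update_rate = max(0.9 * update_rate, slow_update_rate)
        if avg < fast_thresh:                        # dropped far below target
            update_rate = fast_update_rate
        increasing = float(prev_avg is not None and avg > prev_avg)
        decreasing = 1.0 - increasing
        a, b = float(above), 1.0 - float(above)
        delta = ((b + 0.5 * a) * decreasing -
                 (a + 0.5 * b) * increasing) * update_rate
        log_alpha = max(log_alpha + delta, min_log_alpha)
        prev_avg, prev_above = avg, above
        print("entropy=%6.2f  alpha=%.4f  update_rate=%.3f"
              % (avg, np.exp(log_alpha), update_rate))


simulate_adjust_stage([-1.0, -2.0, -3.5, -3.2, -2.5, -2.0], target_entropy=-3.0)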