class EntropyTargetAlgorithm(Algorithm):
    """Algorithm for adjusting entropy regularization.

    It tries to adjust the entropy regularization (i.e. alpha) so that the
    entropy is not smaller than `target_entropy`.

    The algorithm has two stages:

    1. init stage. During this stage, the alpha is not changed. It transitions
       to adjust_stage once entropy drops below `target_entropy`.
    2. adjust stage. During this stage, log_alpha is adjusted using this
       formula:

         ((below + 0.5 * above) * decreasing
          - (above + 0.5 * below) * increasing) * update_rate

       Note that log_alpha will always be decreased if entropy is increasing,
       even when the entropy is below the target entropy. This is to prevent
       overshooting log_alpha to a too big value. For the same reason,
       log_alpha is always increased if entropy is decreasing, even when the
       entropy is above the target entropy.

       `update_rate` is initialized to `fast_update_rate` and is reduced by a
       factor of 0.9 whenever the entropy crosses `target_entropy`.
       `update_rate` is reset to `fast_update_rate` if entropy drops too much
       below `target_entropy` (i.e., below `fast_stage_thresh` in the code,
       which is half of `target_entropy` if it is positive, and twice
       `target_entropy` if it is negative).
    """

    def __init__(self,
                 action_spec,
                 initial_alpha=0.01,
                 target_entropy=None,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 debug_summaries=False):
        """Create an EntropyTargetAlgorithm.

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used.
            slow_update_rate (float): minimal update rate for log_alpha.
            fast_update_rate (float): maximum update rate for log_alpha.
            min_alpha (float): the minimal value of alpha. If <= 0, exp(-100)
                is used.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(
            debug_summaries=debug_summaries, name="EntropyTargetAlgorithm")

        self._log_alpha = tf.Variable(
            name='log_alpha',
            initial_value=np.log(initial_alpha),
            dtype=tf.float32,
            trainable=False)
        self._stage = tf.Variable(
            name='stage', initial_value=-1, dtype=tf.int32, trainable=False)
        self._avg_entropy = ScalarWindowAverager(2)
        self._update_rate = tf.Variable(
            name='update_rate',
            initial_value=fast_update_rate,
            dtype=tf.float32,
            trainable=False)
        self._action_spec = action_spec

        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)

        if target_entropy is None:
            flat_action_spec = tf.nest.flatten(self._action_spec)
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
        if target_entropy > 0:
            self._fast_stage_thresh = 0.5 * target_entropy
        else:
            self._fast_stage_thresh = 2.0 * target_entropy
        self._target_entropy = target_entropy
        self._slow_update_rate = slow_update_rate
        self._fast_update_rate = fast_update_rate
        logging.info("target_entropy=%s" % target_entropy)

    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgorithmStep. `info` field is LossInfo, other fields are empty.
        """
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            distribution, self._action_spec)
        return AlgorithmStep(
            outputs=(),
            state=(),
            info=EntropyTargetInfo(
                step_type=step_type,
                loss=LossInfo(
                    loss=-entropy_for_gradient,
                    extra=EntropyTargetLossInfo(entropy_loss=-entropy))))

    def calc_loss(self, training_info: EntropyTargetInfo):
        """Calculate the alpha-scaled entropy loss and adjust alpha."""
        loss_info = training_info.loss
        mask = tf.cast(training_info.step_type != StepType.LAST, tf.float32)
        entropy = -loss_info.extra.entropy_loss * mask
        num = tf.reduce_sum(mask)
        entropy2 = tf.reduce_sum(tf.square(entropy)) / num
        entropy = tf.reduce_sum(entropy) / num
        entropy_std = tf.sqrt(tf.maximum(0.0, entropy2 - entropy * entropy))
        prev_avg_entropy = self._avg_entropy.get()
        avg_entropy = self._avg_entropy.average(entropy)

        def _init():
            crossing = avg_entropy < self._target_entropy
            self._stage.assign_add(tf.cast(crossing, tf.int32))

        def _adjust():
            previous_above = tf.cast(self._stage, tf.bool)
            above = avg_entropy > self._target_entropy
            self._stage.assign(tf.cast(above, tf.int32))
            crossing = above != previous_above
            update_rate = self._update_rate
            update_rate = tf.where(crossing, 0.9 * update_rate, update_rate)
            update_rate = tf.maximum(update_rate, self._slow_update_rate)
            update_rate = tf.where(entropy < self._fast_stage_thresh,
                                   np.float32(self._fast_update_rate),
                                   update_rate)
            self._update_rate.assign(update_rate)
            above = tf.cast(above, tf.float32)
            below = 1 - above
            increasing = tf.cast(avg_entropy > prev_avg_entropy, tf.float32)
            decreasing = 1 - increasing
            log_alpha = self._log_alpha + (
                (below + 0.5 * above) * decreasing -
                (above + 0.5 * below) * increasing) * update_rate
            log_alpha = tf.maximum(log_alpha, np.float32(self._min_log_alpha))
            self._log_alpha.assign(log_alpha)

        run_if(self._stage == -1, _init)
        run_if(self._stage >= 0, _adjust)
        alpha = tf.exp(self._log_alpha)

        def _summarize():
            with self.name_scope:
                tf.summary.scalar("alpha", alpha)
                tf.summary.scalar("entropy_std", entropy_std)
                tf.summary.scalar("avg_entropy", avg_entropy)
                tf.summary.scalar("stage", self._stage)
                tf.summary.scalar("update_rate", self._update_rate)

        if self._debug_summaries:
            run_if(should_record_summaries(), _summarize)

        return loss_info._replace(loss=loss_info.loss * alpha)
class EntropyTargetAlgorithm(Algorithm):
    """Algorithm for adjusting entropy regularization.

    It tries to adjust the entropy regularization (i.e. alpha) so that the
    entropy is not smaller than `target_entropy`.

    The algorithm has three stages:

    0. init stage. This is an optional stage. If the initial entropy is
       already below `max_entropy`, this stage is skipped. Otherwise, alpha
       is slowly decreased so that the entropy lands at `max_entropy`, which
       triggers the next `free stage`. Basically, this stage lets the user
       choose an arbitrarily large initial alpha without considering each
       specific case.
    1. free stage. During this stage, the alpha is not changed. It transitions
       to adjust_stage once entropy drops below `target_entropy`.
    2. adjust stage. During this stage, log_alpha is adjusted using this
       formula:

         ((below + 0.5 * above) * decreasing
          - (above + 0.5 * below) * increasing) * update_rate

       Note that log_alpha will always be decreased if entropy is increasing,
       even when the entropy is below the target entropy. This is to prevent
       overshooting log_alpha to a too big value. For the same reason,
       log_alpha is always increased if entropy is decreasing, even when the
       entropy is above the target entropy.

       `update_rate` is initialized to `fast_update_rate` and is reduced by a
       factor of 0.9 whenever the entropy crosses `target_entropy`.
       `update_rate` is reset to `fast_update_rate` if entropy drops too much
       below `target_entropy` (i.e., below `fast_stage_thresh` in the code,
       which is half of `target_entropy` if it is positive, and twice
       `target_entropy` if it is negative).
    """

    def __init__(self,
                 action_spec,
                 initial_alpha=0.1,
                 skip_free_stage=False,
                 max_entropy=None,
                 target_entropy=None,
                 very_slow_update_rate=0.001,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 average_window=2,
                 debug_summaries=False):
        """Create an EntropyTargetAlgorithm.

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha; make sure that it
                is large enough for meaningful initial exploration.
            skip_free_stage (bool): If True, directly goes to the adjust stage.
            max_entropy (float): the upper bound of the entropy. If not
                provided, min(initial_entropy * 0.8, initial_entropy / 0.8) is
                used, where initial_entropy is estimated from the first
                `average_window` steps. The factor 0.8 ensures that we get a
                policy less random than the initial policy before starting the
                free stage.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used. This value should be less than or equal to
                `max_entropy`.
            very_slow_update_rate (float): a tiny update rate for log_alpha;
                used in stage 0.
            slow_update_rate (float): minimal update rate for log_alpha; used
                in stage 2.
            fast_update_rate (float): maximum update rate for log_alpha; used
                in stage 2.
            min_alpha (float): the minimal value of alpha. If <= 0, exp(-100)
                is used.
            average_window (int): window size for averaging past entropies.
            debug_summaries (bool): True if debug summaries should be created.
        """
        super().__init__(
            debug_summaries=debug_summaries, name="EntropyTargetAlgorithm")

        self._log_alpha = tf.Variable(
            name='log_alpha',
            initial_value=np.log(initial_alpha),
            dtype=tf.float32,
            trainable=False)
        self._stage = tf.Variable(
            name='stage', initial_value=-2, dtype=tf.int32, trainable=False)
        self._avg_entropy = ScalarWindowAverager(average_window)
        self._update_rate = tf.Variable(
            name='update_rate',
            initial_value=fast_update_rate,
            dtype=tf.float32,
            trainable=False)
        self._action_spec = action_spec

        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)

        flat_action_spec = tf.nest.flatten(self._action_spec)
        if target_entropy is None:
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
            logging.info("target_entropy=%s" % target_entropy)

        if max_entropy is None:
            # max_entropy will be estimated in the first `average_window`
            # steps.
            max_entropy = 0.
            self._stage.assign(-2 - average_window)
        else:
            assert target_entropy <= max_entropy, (
                "Target entropy %s should be less than or equal to max "
                "entropy %s!" % (target_entropy, max_entropy))
        self._max_entropy = tf.Variable(
            name='max_entropy',
            initial_value=max_entropy,
            dtype=tf.float32,
            trainable=False)

        if skip_free_stage:
            self._stage.assign(1)

        if target_entropy > 0:
            self._fast_stage_thresh = 0.5 * target_entropy
        else:
            self._fast_stage_thresh = 2.0 * target_entropy
        self._target_entropy = target_entropy
        self._very_slow_update_rate = very_slow_update_rate
        self._slow_update_rate = slow_update_rate
        self._fast_update_rate = fast_update_rate

    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgorithmStep. `info` field is LossInfo, other fields are empty.
        """
        entropy, entropy_for_gradient = dist_utils.entropy_with_fallback(
            distribution, self._action_spec)
        return AlgorithmStep(
            outputs=(),
            state=(),
            info=EntropyTargetInfo(
                step_type=step_type,
                loss=LossInfo(
                    loss=-entropy_for_gradient,
                    extra=EntropyTargetLossInfo(neg_entropy=-entropy))))

    def calc_loss(self, training_info: EntropyTargetInfo, valid_mask=None):
        """Calculate the alpha-scaled entropy loss and adjust alpha."""
        loss_info = training_info.loss
        mask = tf.cast(training_info.step_type != StepType.LAST, tf.float32)
        if valid_mask is not None:
            mask = mask * tf.cast(valid_mask, tf.float32)
        entropy = -loss_info.extra.neg_entropy * mask
        num = tf.reduce_sum(mask)
        not_empty = num > 0
        num = tf.maximum(num, 1)
        entropy2 = tf.reduce_sum(tf.square(entropy)) / num
        entropy = tf.reduce_sum(entropy) / num
        entropy_std = tf.sqrt(tf.maximum(0.0, entropy2 - entropy * entropy))

        run_if(not_empty, lambda: self.adjust_alpha(entropy))

        def _summarize():
            with self.name_scope:
                tf.summary.scalar("entropy_std", entropy_std)

        if self._debug_summaries:
            run_if(
                tf.logical_and(not_empty, should_record_summaries()),
                _summarize)

        alpha = tf.exp(self._log_alpha)
        return loss_info._replace(loss=loss_info.loss * alpha)

    def adjust_alpha(self, entropy):
        """Adjust alpha according to the current entropy.

        Args:
            entropy (scalar Tensor): the current entropy.
        Returns:
            adjusted entropy regularization
        """
        prev_avg_entropy = self._avg_entropy.get()
        avg_entropy = self._avg_entropy.average(entropy)

        def _init_entropy():
            self._max_entropy.assign(
                tf.minimum(0.8 * avg_entropy, avg_entropy / 0.8))
            self._stage.assign_add(1)

        def _init():
            below = avg_entropy < self._max_entropy
            increasing = tf.cast(avg_entropy > prev_avg_entropy, tf.float32)
            # -1 * increasing + 0.5 * (1 - increasing)
            update_rate = (
                0.5 - 1.5 * increasing) * self._very_slow_update_rate
            self._stage.assign_add(tf.cast(below, tf.int32))
            self._log_alpha.assign(
                tf.maximum(self._log_alpha + update_rate,
                           np.float32(self._min_log_alpha)))

        def _free():
            crossing = avg_entropy < self._target_entropy
            self._stage.assign_add(tf.cast(crossing, tf.int32))

        def _adjust():
            previous_above = tf.cast(self._stage, tf.bool)
            above = avg_entropy > self._target_entropy
            self._stage.assign(tf.cast(above, tf.int32))
            crossing = above != previous_above
            update_rate = self._update_rate
            update_rate = tf.where(crossing, 0.9 * update_rate, update_rate)
            update_rate = tf.maximum(update_rate, self._slow_update_rate)
            update_rate = tf.where(entropy < self._fast_stage_thresh,
                                   np.float32(self._fast_update_rate),
                                   update_rate)
            self._update_rate.assign(update_rate)
            above = tf.cast(above, tf.float32)
            below = 1 - above
            increasing = tf.cast(avg_entropy > prev_avg_entropy, tf.float32)
            decreasing = 1 - increasing
            log_alpha = self._log_alpha + (
                (below + 0.5 * above) * decreasing -
                (above + 0.5 * below) * increasing) * update_rate
            log_alpha = tf.maximum(log_alpha, np.float32(self._min_log_alpha))
            self._log_alpha.assign(log_alpha)

        run_if(self._stage < -2, _init_entropy)
        run_if(self._stage == -2, _init)
        run_if(self._stage == -1, _free)
        run_if(self._stage >= 0, _adjust)
        alpha = tf.exp(self._log_alpha)

        def _summarize():
            with self.name_scope:
                tf.summary.scalar("alpha", alpha)
                tf.summary.scalar("avg_entropy", avg_entropy)
                tf.summary.scalar("stage", self._stage)
                tf.summary.scalar("update_rate", self._update_rate)

        if self._debug_summaries:
            run_if(should_record_summaries(), _summarize)

        return alpha
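# A minimal sketch of the `run_if` helper used above, under the assumption
# that it is a thin wrapper around tf.cond which calls `func` purely for its
# side effects (variable assignments, summaries) when `cond` is True. This
# illustrates the pattern; it is not the actual helper from the codebase.
def _sketch_run_if(cond, func):
    def _if_true():
        func()
        # tf.cond requires both branches to return the same structure; only
        # the side effects of `func` matter here, so return a dummy constant.
        return tf.constant(True)

    return tf.cond(cond, _if_true, lambda: tf.constant(True))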
class EntropyTargetAlgorithm(Algorithm):
    """Algorithm for adjusting entropy regularization.

    It tries to adjust the entropy regularization (i.e. alpha) so that the
    entropy is not smaller than ``target_entropy``.

    The algorithm has three stages:

    0. init stage. This is an optional stage. If the initial entropy is
       already below ``max_entropy``, this stage is skipped. Otherwise, alpha
       is slowly decreased so that the entropy lands at ``max_entropy``,
       which triggers the next ``free_stage``. Basically, this stage lets the
       user choose an arbitrarily large initial alpha without considering
       each specific case.
    1. free stage. During this stage, the alpha is not changed. It
       transitions to adjust_stage once entropy drops below
       ``target_entropy``.
    2. adjust stage. During this stage, ``log_alpha`` is adjusted using this
       formula:

       .. code-block:: python

            ((below + 0.5 * above) * decreasing -
             (above + 0.5 * below) * increasing) * update_rate

       Note that ``log_alpha`` will always be decreased if entropy is
       increasing, even when the entropy is below the target entropy. This is
       to prevent overshooting ``log_alpha`` to a too big value. For the same
       reason, ``log_alpha`` is always increased if entropy is decreasing,
       even when the entropy is above the target entropy.

       ``update_rate`` is initialized to ``fast_update_rate`` and is reduced
       by a factor of 0.9 whenever the entropy crosses ``target_entropy``.
       ``update_rate`` is reset to ``fast_update_rate`` if entropy drops too
       much below ``target_entropy`` (i.e., below ``fast_stage_thresh`` in
       the code, which is half of ``target_entropy`` if it is positive, and
       twice ``target_entropy`` if it is negative).

    ``EntropyTargetAlgorithm`` can be used to approximately reproduce the
    learning of temperature in `Soft Actor-Critic Algorithms and Applications
    <https://arxiv.org/abs/1812.05905>`_. To do so, you need to use the same
    ``target_entropy``, set ``skip_free_stage`` to True, and set both
    ``slow_update_rate`` and ``fast_update_rate`` to 4 times the learning
    rate used for the temperature.
    """

    def __init__(self,
                 action_spec,
                 initial_alpha=0.1,
                 skip_free_stage=False,
                 max_entropy=None,
                 target_entropy=None,
                 very_slow_update_rate=0.001,
                 slow_update_rate=0.01,
                 fast_update_rate=np.log(2),
                 min_alpha=1e-4,
                 average_window=2,
                 debug_summaries=False,
                 name="EntropyTargetAlgorithm"):
        """
        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            initial_alpha (float): initial value for alpha; make sure that it
                is large enough for meaningful initial exploration.
            skip_free_stage (bool): If True, directly goes to the adjust stage.
            max_entropy (float): the upper bound of the entropy. If not
                provided, ``min(initial_entropy * 0.8, initial_entropy / 0.8)``
                is used, where initial_entropy is estimated from the first
                ``average_window`` steps. The factor 0.8 ensures that we get a
                policy less random than the initial policy before starting the
                free stage.
            target_entropy (float): the lower bound of the entropy. If not
                provided, a default value proportional to the action dimension
                is used. This value should be less than or equal to
                ``max_entropy``.
            very_slow_update_rate (float): a tiny update rate for
                ``log_alpha``; used in stage 0.
            slow_update_rate (float): minimal update rate for ``log_alpha``;
                used in stage 2.
            fast_update_rate (float): maximum update rate for ``log_alpha``;
                used in stage 2.
            min_alpha (float): the minimal value of alpha. If <= 0,
                :math:`e^{-100}` is used.
            average_window (int): window size for averaging past entropies.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): name of the algorithm.
        """
        super().__init__(debug_summaries=debug_summaries, name=name)

        self.register_buffer(
            '_log_alpha',
            torch.tensor(np.log(initial_alpha), dtype=torch.float32))
        self.register_buffer('_stage', torch.tensor(-2, dtype=torch.int32))
        self._avg_entropy = ScalarWindowAverager(average_window)
        self.register_buffer(
            "_update_rate",
            torch.tensor(fast_update_rate, dtype=torch.float32))
        self._action_spec = action_spec

        self._min_log_alpha = -100.
        if min_alpha > 0.:
            self._min_log_alpha = np.log(min_alpha)
        self._min_log_alpha = torch.tensor(self._min_log_alpha)

        flat_action_spec = alf.nest.flatten(self._action_spec)
        if target_entropy is None:
            target_entropy = np.sum(
                list(map(calc_default_target_entropy, flat_action_spec)))
            logging.info("target_entropy=%s" % target_entropy)

        if not isinstance(target_entropy, Callable):
            target_entropy = ConstantScheduler(target_entropy)

        if max_entropy is None:
            # max_entropy will be estimated in the first `average_window`
            # steps.
            max_entropy = 0.
            self._stage.fill_(-2 - average_window)
        else:
            assert target_entropy() <= max_entropy, (
                "Target entropy %s should be less than or equal to max "
                "entropy %s!" % (target_entropy(), max_entropy))
        self.register_buffer(
            "_max_entropy", torch.tensor(max_entropy, dtype=torch.float32))

        if skip_free_stage:
            self._stage.fill_(1)

        self._target_entropy = target_entropy
        self._very_slow_update_rate = very_slow_update_rate
        self._slow_update_rate = torch.tensor(slow_update_rate)
        self._fast_update_rate = torch.tensor(fast_update_rate)

    def rollout_step(self, distribution, step_type, on_policy_training):
        """Rollout step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
            on_policy_training (bool): If False, this step does nothing.
        Returns:
            AlgStep: ``info`` field is ``LossInfo``, other fields are empty.
                All fields are empty if ``on_policy_training=False``.
        """
        if on_policy_training:
            return self.train_step(distribution, step_type)
        else:
            return AlgStep()

    def train_step(self, distribution, step_type):
        """Train step.

        Args:
            distribution (nested Distribution): action distribution from the
                policy.
            step_type (StepType): the step type for the distributions.
        Returns:
            AlgStep: ``info`` field is ``LossInfo``, other fields are empty.
        """
        entropy, entropy_for_gradient = entropy_with_fallback(distribution)
        return AlgStep(
            output=(),
            state=(),
            info=EntropyTargetInfo(
                loss=LossInfo(
                    loss=-entropy_for_gradient,
                    extra=EntropyTargetLossInfo(neg_entropy=-entropy))))

    def calc_loss(self, experience, info: EntropyTargetInfo, valid_mask=None):
        """Calculate loss.

        Args:
            experience (Experience): experience for gradient update.
            info (EntropyTargetInfo): for computing loss.
            valid_mask (tensor): valid mask to be applied to time steps.
        Returns:
            LossInfo:
        """
        loss_info = info.loss
        mask = (experience.step_type != StepType.LAST).type(torch.float32)
        if valid_mask is not None:
            mask = mask * valid_mask.type(torch.float32)
        entropy = -loss_info.extra.neg_entropy * mask
        num = torch.sum(mask)
        not_empty = num > 0
        num = max(num, 1)
        entropy2 = torch.sum(entropy**2) / num
        entropy = torch.sum(entropy) / num
        entropy_std = torch.sqrt(
            torch.max(torch.tensor(0.0), entropy2 - entropy * entropy))

        if not_empty:
            self.adjust_alpha(entropy)

        if self._debug_summaries and should_record_summaries():
            with alf.summary.scope(self.name):
                alf.summary.scalar("entropy_std", entropy_std)

        alpha = torch.exp(self._log_alpha)
        return loss_info._replace(loss=loss_info.loss * alpha)

    def adjust_alpha(self, entropy):
        """Adjust alpha according to the current entropy.

        Args:
            entropy (scalar Tensor): the current entropy.
        Returns:
            adjusted entropy regularization
        """
        prev_avg_entropy = self._avg_entropy.get()
        avg_entropy = self._avg_entropy.average(entropy)

        target_entropy = self._target_entropy()
        if target_entropy > 0:
            fast_stage_thresh = 0.5 * target_entropy
        else:
            fast_stage_thresh = 2.0 * target_entropy

        def _init_entropy():
            self._max_entropy.fill_(
                torch.min(0.8 * avg_entropy, avg_entropy / 0.8))
            self._stage.add_(1)

        def _init():
            below = avg_entropy < self._max_entropy
            decreasing = (avg_entropy < prev_avg_entropy).type(torch.float32)
            # -1 * (1 - decreasing) + 0.5 * decreasing
            update_rate = (
                -1 + 1.5 * decreasing) * self._very_slow_update_rate
            self._stage.add_(below.type(torch.int32))
            self._log_alpha.fill_(
                torch.max(self._log_alpha + update_rate, self._min_log_alpha))

        def _free():
            crossing = avg_entropy < target_entropy
            self._stage.add_(crossing.type(torch.int32))

        def _adjust():
            previous_above = self._stage.type(torch.bool)
            above = avg_entropy > target_entropy
            self._stage.fill_(above.type(torch.int32))
            crossing = above != previous_above
            update_rate = self._update_rate
            update_rate = torch.where(crossing, 0.9 * update_rate,
                                      update_rate)
            update_rate = torch.max(update_rate, self._slow_update_rate)
            update_rate = torch.where(entropy < fast_stage_thresh,
                                      self._fast_update_rate, update_rate)
            self._update_rate.fill_(update_rate)
            above = above.type(torch.float32)
            below = 1 - above
            decreasing = (avg_entropy < prev_avg_entropy).type(torch.float32)
            increasing = 1 - decreasing
            log_alpha = self._log_alpha + (
                (below + 0.5 * above) * decreasing -
                (above + 0.5 * below) * increasing) * update_rate
            log_alpha = torch.max(log_alpha, self._min_log_alpha)
            self._log_alpha.fill_(log_alpha)

        if self._stage < -2:
            _init_entropy()
        if self._stage == -2:
            _init()
        if self._stage == -1:
            _free()
        if self._stage >= 0:
            _adjust()

        alpha = torch.exp(self._log_alpha)

        if self._debug_summaries and should_record_summaries():
            with alf.summary.scope(self.name):
                alf.summary.scalar("alpha", alpha)
                alf.summary.scalar("avg_entropy", avg_entropy)
                alf.summary.scalar("stage", self._stage)
                alf.summary.scalar("update_rate", self._update_rate)
                if not isinstance(self._target_entropy, ConstantScheduler):
                    alf.summary.scalar("target_entropy", target_entropy)

        return alpha
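# Hedged usage sketch: configuring the algorithm to approximately reproduce
# SAC's learned temperature, following the recipe in the class docstring.
# The `BoundedTensorSpec` constructor arguments and the 3e-4 temperature
# learning rate are illustrative assumptions, not values from this module.
def _example_sac_style_config():
    import alf.tensor_specs as ts

    # A 3-dim continuous action in [-1, 1] (hypothetical environment).
    action_spec = ts.BoundedTensorSpec(
        shape=(3, ), dtype='float32', minimum=-1., maximum=1.)
    temperature_lr = 3e-4  # assumed SAC temperature learning rate
    return EntropyTargetAlgorithm(
        action_spec,
        skip_free_stage=True,  # go straight to the adjust stage, as SAC does
        target_entropy=-3.,  # e.g. -dim(A), the common SAC heuristic
        # With slow == fast, update_rate stays constant at 4 * lr, matching
        # the docstring's reproduction recipe.
        slow_update_rate=4 * temperature_lr,
        fast_update_rate=4 * temperature_lr)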