def __init__(
    self,
    actions,
    actions_logp,
    actions_entropy,
    dones,
    behaviour_action_logp,
    behaviour_logits,
    target_logits,
    discount,
    rewards,
    values,
    bootstrap_value,
    dist_class,
    model,
    valid_mask,
    config,
    vf_loss_coeff=0.5,
    entropy_coeff=0.01,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
):
    """Policy gradient loss with vtrace importance weighting.

    VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
    batch_size. The reason we need to know `B` is for V-trace to properly
    handle episode cut boundaries.

    Args:
        actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
        actions_logp: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        dones: A bool tensor of shape [T, B].
        behaviour_action_logp: Tensor of shape [T, B].
        behaviour_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]].
        target_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]].
        discount: A float32 scalar.
        rewards: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        bootstrap_value: A float32 tensor of shape [B].
        dist_class: Action distribution class for the logits.
        valid_mask: A bool tensor of valid RNN input elements (#2992).
        config: Algorithm config dict.
    """

    # Compute vtrace on the CPU for better perf.
    with tf.device("/cpu:0"):
        self.vtrace_returns = vtrace.multi_from_logits(
            behaviour_action_log_probs=behaviour_action_logp,
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=target_logits,
            actions=tf.unstack(actions, axis=2),
            discounts=tf.cast(~tf.cast(dones, tf.bool), tf.float32) * discount,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            dist_class=dist_class,
            model=model,
            clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
            clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, tf.float32),
        )
        self.value_targets = self.vtrace_returns.vs

    # The policy gradients loss.
    masked_pi_loss = tf.boolean_mask(
        actions_logp * self.vtrace_returns.pg_advantages, valid_mask
    )
    self.pi_loss = -tf.reduce_sum(masked_pi_loss)
    self.mean_pi_loss = -tf.reduce_mean(masked_pi_loss)

    # The baseline loss.
    delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
    delta_squared = tf.math.square(delta)
    self.vf_loss = 0.5 * tf.reduce_sum(delta_squared)
    self.mean_vf_loss = 0.5 * tf.reduce_mean(delta_squared)

    # The entropy loss.
    masked_entropy = tf.boolean_mask(actions_entropy, valid_mask)
    self.entropy = tf.reduce_sum(masked_entropy)
    self.mean_entropy = tf.reduce_mean(masked_entropy)

    # The summed weighted loss.
    self.total_loss = self.pi_loss - self.entropy * entropy_coeff

    # Optional vf loss (or in a separate term due to separate
    # optimizers/networks).
    self.loss_wo_vf = self.total_loss
    if not config["_separate_vf_optimizer"]:
        self.total_loss += self.vf_loss * vf_loss_coeff
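
# NOTE: for reference, the V-trace value targets computed above follow
# IMPALA (Espeholt et al., 2018):
#   v_s = V(x_s) + sum_{t>=s} gamma^{t-s} * (prod_{i=s..t-1} c_i) * delta_t,
#   delta_t = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t)),
# where rho_t is the importance ratio pi/mu clipped at clip_rho_threshold and
# c_i is a similarly clipped trace coefficient.
#
# A minimal, hypothetical construction sketch (the shapes, the single Discrete
# action dimension, and the `Categorical`/`model` objects are assumptions for
# illustration, not taken from this file):
#
#     T, B, num_actions = 50, 4, 3
#     loss_obj = VTraceLoss(
#         actions=tf.zeros([T, B, 1], tf.int32),
#         actions_logp=tf.zeros([T, B]),
#         actions_entropy=tf.zeros([T, B]),
#         dones=tf.zeros([T, B], tf.bool),
#         behaviour_action_logp=tf.zeros([T, B]),
#         behaviour_logits=[tf.zeros([T, B, num_actions])],
#         target_logits=[tf.zeros([T, B, num_actions])],
#         discount=0.99,
#         rewards=tf.zeros([T, B]),
#         values=tf.zeros([T, B]),
#         bootstrap_value=tf.zeros([B]),
#         dist_class=Categorical,
#         model=model,
#         valid_mask=tf.ones([T, B], tf.bool),
#         config={"_separate_vf_optimizer": False},
#     )
#     # loss_obj.total_loss, loss_obj.vf_loss, and loss_obj.entropy are the
#     # scalars the surrounding policy then combines and optimizes.
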
def loss(
    self,
    model: Union[ModelV2, "tf.keras.Model"],
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(self.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [self.action_space.n]
    elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = self.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    # TODO: (sven) deprecate this when trajectory view API gets activated.
    def make_time_major(*args, **kw):
        return _make_time_major(
            self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
        )

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = self.target_model(train_batch)
    prev_action_dist = dist_class(behaviour_logits, self.model)
    values = self.model.value_function()
    values_time_major = make_time_major(values)

    if self.is_recurrent():
        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = tf.reshape(mask, [-1])
        mask = make_time_major(mask, drop_last=self.config["vtrace"])

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    else:
        reduce_mean_valid = tf.reduce_mean

    if self.config["vtrace"]:
        drop_last = self.config["vtrace_drop_last_ts"]
        logger.debug(
            "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
        )

        # Prepare actions for loss.
        loss_actions = (
            actions if is_multidiscrete else tf.expand_dims(actions, axis=1)
        )

        old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        # Prepare KL for loss.
        mean_kl = make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
        )

        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1
        )
        unpacked_old_policy_behaviour_logits = tf.split(
            old_policy_behaviour_logits, output_hidden_shape, axis=1
        )

        # Compute vtrace on the CPU for better perf.
with tf.device("/cpu:0"): vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=make_time_major( unpacked_behaviour_logits, drop_last=drop_last ), target_policy_logits=make_time_major( unpacked_old_policy_behaviour_logits, drop_last=drop_last ), actions=tf.unstack( make_time_major(loss_actions, drop_last=drop_last), axis=2 ), discounts=tf.cast( ~make_time_major( tf.cast(dones, tf.bool), drop_last=drop_last ), tf.float32, ) * self.config["gamma"], rewards=make_time_major(rewards, drop_last=drop_last), values=values_time_major[:-1] if drop_last else values_time_major, bootstrap_value=values_time_major[-1], dist_class=Categorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=tf.cast( self.config["vtrace_clip_rho_threshold"], tf.float32 ), clip_pg_rho_threshold=tf.cast( self.config["vtrace_clip_pg_rho_threshold"], tf.float32 ), ) actions_logp = make_time_major( action_dist.logp(actions), drop_last=drop_last ) prev_actions_logp = make_time_major( prev_action_dist.logp(actions), drop_last=drop_last ) old_policy_actions_logp = make_time_major( old_policy_action_dist.logp(actions), drop_last=drop_last ) is_ratio = tf.clip_by_value( tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 ) logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp) self._is_ratio = is_ratio advantages = vtrace_returns.pg_advantages surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value( logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"], ), ) action_kl = ( tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl ) mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. if drop_last: delta = values_time_major[:-1] - vtrace_returns.vs else: delta = values_time_major - vtrace_returns.vs value_targets = vtrace_returns.vs mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. actions_entropy = make_time_major( action_dist.multi_entropy(), drop_last=True ) mean_entropy = reduce_mean_valid(actions_entropy) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist)) logp_ratio = tf.math.exp( make_time_major(action_dist.logp(actions)) - make_time_major(prev_action_dist.logp(actions)) ) advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES]) surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value( logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"], ), ) action_kl = ( tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl ) mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = make_time_major( train_batch[Postprocessing.VALUE_TARGETS] ) delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. mean_entropy = reduce_mean_valid( make_time_major(action_dist.multi_entropy()) ) # The summed weighted loss. total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff # Optional KL loss. if self.config["use_kl_loss"]: total_loss += self.kl_coeff * mean_kl_loss # Optional vf loss (or in a separate term due to separate # optimizers/networks). loss_wo_vf = total_loss if not self.config["_separate_vf_optimizer"]: total_loss += mean_vf_loss * self.config["vf_loss_coeff"] # Store stats in policy for stats_fn. 
    self._total_loss = total_loss
    self._loss_wo_vf = loss_wo_vf
    self._mean_policy_loss = mean_policy_loss
    # Backward compatibility: Deprecate policy._mean_kl.
    self._mean_kl_loss = self._mean_kl = mean_kl_loss
    self._mean_vf_loss = mean_vf_loss
    self._mean_entropy = mean_entropy
    self._value_targets = value_targets

    # Return one total loss or two losses: vf vs rest (policy + kl).
    if self.config["_separate_vf_optimizer"]:
        return loss_wo_vf, mean_vf_loss
    else:
        return total_loss
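
# NOTE: illustrative sketch of the return convention (the `policy`, `model`,
# `dist_class`, and `train_batch` names below are assumed, not defined here):
#
#     out = policy.loss(model, dist_class, train_batch)
#     # With "_separate_vf_optimizer": True  -> (loss_wo_vf, mean_vf_loss),
#     #   typically matched, in order, against separate policy/value optimizers.
#     # With "_separate_vf_optimizer": False -> a single scalar total_loss.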