def __init__(
    self,
    actions,
    actions_logp,
    actions_entropy,
    dones,
    behaviour_action_logp,
    behaviour_logits,
    target_logits,
    discount,
    rewards,
    values,
    bootstrap_value,
    dist_class,
    model,
    valid_mask,
    config,
    vf_loss_coeff=0.5,
    entropy_coeff=0.01,
    clip_rho_threshold=1.0,
    clip_pg_rho_threshold=1.0,
):
    """Build the V-trace importance-weighted policy gradient loss.

    All sequence inputs are time-major, i.e. shaped [T, B, ...] where `B`
    is the batch size. `B` has to be known so that V-trace can properly
    handle episode cut boundaries.

    Args:
        actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
        actions_logp: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        dones: A bool tensor of shape [T, B].
        behaviour_action_logp: Tensor of shape [T, B].
        behaviour_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes
            [T, B, ACTION_SPACE[0]], ..., [T, B, ACTION_SPACE[-1]].
        target_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes
            [T, B, ACTION_SPACE[0]], ..., [T, B, ACTION_SPACE[-1]].
        discount: A float32 scalar.
        rewards: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        bootstrap_value: A float32 tensor of shape [B].
        dist_class: Action distribution class for logits.
        model: Forwarded unchanged to `vtrace.multi_from_logits`.
        valid_mask: A bool tensor of valid RNN input elements (#2992).
            If None, an all-ones mask shaped like `actions_logp` is used.
        config: Trainer config dict (not read by this constructor).
        vf_loss_coeff: Weight of the value-function loss in `total_loss`.
        entropy_coeff: Weight of the entropy term subtracted in
            `total_loss`.
        clip_rho_threshold: Forwarded to `vtrace.multi_from_logits`.
        clip_pg_rho_threshold: Forwarded to `vtrace.multi_from_logits`.
    """
    if valid_mask is None:
        valid_mask = torch.ones_like(actions_logp)

    # Remember where the loss inputs live; the v-trace outputs are moved
    # back onto this device below.
    loss_device = behaviour_action_logp[0].device

    # One action tensor per action component; dones zero out the discount.
    per_component_actions = torch.unbind(actions, dim=2)
    step_discounts = (1.0 - dones.float()) * discount

    # Compute vtrace on the CPU for better perf
    # (devices handled inside `vtrace.multi_from_logits`).
    self.vtrace_returns = vtrace.multi_from_logits(
        behaviour_action_log_probs=behaviour_action_logp,
        behaviour_policy_logits=behaviour_logits,
        target_policy_logits=target_logits,
        actions=per_component_actions,
        discounts=step_discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        dist_class=dist_class,
        model=model,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold,
    )

    # Move v-trace results back to GPU for actual loss computing.
    self.value_targets = self.vtrace_returns.vs.to(loss_device)
    pg_advantages = self.vtrace_returns.pg_advantages.to(loss_device)

    # Policy gradient term: masked sum of logp * advantage.
    weighted_logp = actions_logp * pg_advantages
    self.pi_loss = -torch.sum(weighted_logp * valid_mask)

    # Baseline (value function) term: half the masked sum of squares.
    value_error = (values - self.value_targets) * valid_mask
    self.vf_loss = 0.5 * torch.sum(torch.pow(value_error, 2.0))

    # Entropy term, plus its mean over the valid elements.
    self.entropy = torch.sum(actions_entropy * valid_mask)
    self.mean_entropy = self.entropy / torch.sum(valid_mask)

    # Weighted combination of the three terms.
    scaled_vf_loss = self.vf_loss * vf_loss_coeff
    scaled_entropy = self.entropy * entropy_coeff
    self.total_loss = self.pi_loss + scaled_vf_loss - scaled_entropy
def appo_surrogate_loss(policy: Policy, model: ModelV2,
                        dist_class: Type[TorchDistributionWrapper],
                        train_batch: SampleBatch) -> TensorType:
    """Build the APPO loss for one train batch.

    Uses a clipped surrogate objective. When `policy.config["vtrace"]` is
    set, advantages and value targets come from V-trace (with an extra
    importance-sampling ratio against the target network); otherwise they
    are read from the postprocessed train batch.

    Besides returning the loss, this stores the individual loss terms on
    `policy` (`_total_loss`, `_mean_policy_loss`, `_mean_kl`,
    `_mean_vf_loss`, `_mean_entropy`, `_value_targets`,
    `_vf_explained_var`, and `_is_ratio` in the V-trace case).

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    cfg = policy.config
    use_vtrace = cfg["vtrace"]

    # Determine how the flat logits vector splits per action component.
    space = policy.action_space
    is_multidiscrete = False
    output_hidden_shape = 1
    if isinstance(space, gym.spaces.Discrete):
        output_hidden_shape = [space.n]
    elif isinstance(space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = space.nvec.astype(np.int32)

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    def _clipped_surrogate(adv, ratio):
        # Pessimistic (min) of the unclipped and the clipped objective.
        clipped_ratio = torch.clamp(ratio, 1 - cfg["clip_param"],
                                    1 + cfg["clip_param"])
        return torch.min(adv * ratio, adv * clipped_ratio)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model.from_batch(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values_time_major = _make_time_major(model.value_function())

    # Pick the reduction: masked mean for RNNs, plain mean otherwise.
    if policy.is_recurrent():
        seq_lens = train_batch["seq_lens"]
        flat_mask = torch.reshape(
            sequence_mask(seq_lens, torch.max(seq_lens)), [-1])
        mask = _make_time_major(flat_mask, drop_last=use_vtrace)
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if use_vtrace:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits,
                                            model)

        def _unpack(logits):
            # Explicit sizes -> split; a plain int -> that many chunks.
            if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
                return torch.split(logits, list(output_hidden_shape), dim=1)
            return torch.chunk(logits, output_hidden_shape, dim=1)

        unpacked_behaviour_logits = _unpack(behaviour_logits)
        unpacked_old_policy_behaviour_logits = _unpack(
            old_policy_behaviour_logits)

        # Prepare actions for loss.
        if is_multidiscrete:
            loss_actions = actions
        else:
            loss_actions = torch.unsqueeze(actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=True)

        # Time-major v-trace inputs; the last timestep is dropped and its
        # value estimate serves as the bootstrap value instead.
        tm_behaviour_logits = _make_time_major(
            unpacked_behaviour_logits, drop_last=True)
        tm_target_logits = _make_time_major(
            unpacked_old_policy_behaviour_logits, drop_last=True)
        tm_actions = torch.unbind(
            _make_time_major(loss_actions, drop_last=True), dim=2)
        tm_dones = _make_time_major(dones, drop_last=True)
        tm_discounts = (1.0 - tm_dones.float()) * cfg["gamma"]
        tm_rewards = _make_time_major(rewards, drop_last=True)
        loss_values = values_time_major[:-1]  # drop-last=True

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=tm_behaviour_logits,
            target_policy_logits=tm_target_logits,
            actions=tm_actions,
            discounts=tm_discounts,
            rewards=tm_rewards,
            values=loss_values,
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=cfg["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=cfg["vtrace_clip_pg_rho_threshold"])

        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=True)
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=True)
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)

        # Clamped IS ratio between behaviour and target-network policy.
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0,
            2.0)
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = _clipped_surrogate(advantages, logp_ratio)

        value_targets = vtrace_returns.vs.to(values_time_major.device)
        entropy_time_major = _make_time_major(
            action_dist.entropy(), drop_last=True)
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = _clipped_surrogate(advantages, logp_ratio)

        loss_values = values_time_major
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        entropy_time_major = _make_time_major(action_dist.entropy())

    # Reduce every per-timestep term to a scalar.
    mean_kl = reduce_mean_valid(action_kl)
    mean_policy_loss = -reduce_mean_valid(surrogate_loss)

    # The value function loss.
    delta = loss_values - value_targets
    mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

    # The entropy loss.
    mean_entropy = reduce_mean_valid(entropy_time_major)

    # The summed weighted loss
    weighted_vf_loss = mean_vf_loss * cfg["vf_loss_coeff"]
    weighted_entropy = mean_entropy * cfg["entropy_coeff"]
    total_loss = mean_policy_loss + weighted_vf_loss - weighted_entropy

    # Optional additional KL Loss
    if cfg["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    # Keep the individual terms on the policy object.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets
    policy._vf_explained_var = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(loss_values, [-1]),
    )

    return total_loss
def __init__(self,
             actions,
             prev_actions_logp,
             actions_logp,
             old_policy_actions_logp,
             action_kl,
             actions_entropy,
             dones,
             behaviour_logits,
             old_policy_behaviour_logits,
             target_logits,
             discount,
             rewards,
             values,
             bootstrap_value,
             dist_class,
             model,
             valid_mask,
             vf_loss_coeff=0.5,
             entropy_coeff=0.01,
             clip_rho_threshold=1.0,
             clip_pg_rho_threshold=1.0,
             clip_param=0.3,
             cur_kl_coeff=None,
             use_kl_loss=False):
    """APPO Loss, with IS modifications and V-trace for Advantage Estimation

    VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
    batch_size. The reason we need to know `B` is for V-trace to properly
    handle episode cut boundaries.

    Args:
        actions: An int|float32 tensor of shape [T, B, logit_dim].
        prev_actions_logp: A float32 tensor of shape [T, B].
        actions_logp: A float32 tensor of shape [T, B].
        old_policy_actions_logp: A float32 tensor of shape [T, B].
        action_kl: A float32 tensor of shape [T, B].
        actions_entropy: A float32 tensor of shape [T, B].
        dones: A bool tensor of shape [T, B].
        behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
        old_policy_behaviour_logits: A float32 tensor of shape
            [T, B, logit_dim]. Passed to V-trace as the target policy
            logits.
        target_logits: A float32 tensor of shape [T, B, logit_dim].
            Accepted for interface compatibility; not read here.
        discount: A float32 scalar.
        rewards: A float32 tensor of shape [T, B].
        values: A float32 tensor of shape [T, B].
        bootstrap_value: A float32 tensor of shape [B].
        dist_class: action distribution class for logits.
        model: backing ModelV2 instance
        valid_mask: A bool tensor of valid RNN input elements (#2992).
            If None, all reductions are plain means over every element.
        vf_loss_coeff (float): Coefficient of the value function loss.
        entropy_coeff (float): Coefficient of the entropy regularizer.
        clip_rho_threshold (float): Forwarded to
            `vtrace.multi_from_logits`.
        clip_pg_rho_threshold (float): Forwarded to
            `vtrace.multi_from_logits`.
        clip_param (float): Clip parameter.
        cur_kl_coeff (float): Coefficient for KL loss. Must not be None
            when `use_kl_loss` is True, since it is multiplied with the
            mean KL.
        use_kl_loss (bool): If true, use KL loss.
    """
    if valid_mask is not None:
        num_valid = torch.sum(valid_mask)

        def reduce_mean_valid(t):
            # Masked mean: invalid elements are zeroed, then normalized by
            # the number of valid ones.
            return torch.sum(t * valid_mask) / num_valid

    else:

        def reduce_mean_valid(t):
            return torch.mean(t)

    # Compute vtrace on the CPU for better perf.
    self.vtrace_returns = vtrace.multi_from_logits(
        behaviour_policy_logits=behaviour_logits,
        target_policy_logits=old_policy_behaviour_logits,
        actions=torch.unbind(actions, dim=2),
        discounts=(1.0 - dones.float()) * discount,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        dist_class=dist_class,
        model=model,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    # Clamped importance ratio between the behaviour policy and the old
    # (target network) policy.
    self.is_ratio = torch.clamp(
        torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
    logp_ratio = self.is_ratio * torch.exp(actions_logp - prev_actions_logp)

    # The v-trace outputs may live on a different device than the loss
    # inputs (see the CPU note above), so move them over before combining.
    # `.to()` is a no-op when the devices already match.
    advantages = self.vtrace_returns.pg_advantages.to(logp_ratio.device)
    surrogate_loss = torch.min(
        advantages * logp_ratio,
        advantages * torch.clamp(logp_ratio, 1 - clip_param,
                                 1 + clip_param))

    self.mean_kl = reduce_mean_valid(action_kl)
    self.pi_loss = -reduce_mean_valid(surrogate_loss)

    # The baseline loss
    self.value_targets = self.vtrace_returns.vs.to(values.device)
    delta = values - self.value_targets
    self.vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

    # The entropy loss
    self.entropy = reduce_mean_valid(actions_entropy)

    # The summed weighted loss
    self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                       self.entropy * entropy_coeff)

    # Optional additional KL Loss
    if use_kl_loss:
        self.total_loss += cur_kl_coeff * self.mean_kl
def loss(
    self,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Build the APPO loss for one train batch.

    Uses a clipped surrogate objective. When `self.config["vtrace"]` is
    set, advantages and value targets come from V-trace (with an extra
    importance-sampling ratio against the target network); otherwise they
    are read from the postprocessed train batch. The individual loss terms
    are written to `model.tower_stats`.

    Args:
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = self.target_models[model]

    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    cfg = self.config

    # Determine how the flat logits vector splits per action component.
    space = self.action_space
    is_multidiscrete = False
    output_hidden_shape = 1
    if isinstance(space, gym.spaces.Discrete):
        output_hidden_shape = [space.n]
    elif isinstance(space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = space.nvec.astype(np.int32)

    def _make_time_major(*args, **kwargs):
        return make_time_major(
            self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
        )

    def _clipped_surrogate(adv, ratio):
        # Pessimistic (min) of the unclipped and the clipped objective.
        clipped_ratio = torch.clamp(
            ratio, 1 - cfg["clip_param"], 1 + cfg["clip_param"]
        )
        return torch.min(adv * ratio, adv * clipped_ratio)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values_time_major = _make_time_major(model.value_function())

    # The last timestep is only dropped in V-trace mode, and only if asked.
    drop_last = cfg["vtrace"] and cfg["vtrace_drop_last_ts"]
    # Value estimates that line up with the value targets in the loss.
    loss_values = values_time_major[:-1] if drop_last else values_time_major

    # Pick the reduction: masked mean for RNNs, plain mean otherwise.
    if self.is_recurrent():
        seq_lens = train_batch[SampleBatch.SEQ_LENS]
        flat_mask = torch.reshape(
            sequence_mask(seq_lens, torch.max(seq_lens)), [-1]
        )
        mask = _make_time_major(flat_mask, drop_last=drop_last)
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if cfg["vtrace"]:
        logger.debug(
            "Using V-Trace surrogate loss (vtrace=True; "
            f"drop_last={drop_last})"
        )

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        def _unpack(logits):
            # Explicit sizes -> split; a plain int -> that many chunks.
            if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
                return torch.split(logits, list(output_hidden_shape), dim=1)
            return torch.chunk(logits, output_hidden_shape, dim=1)

        unpacked_behaviour_logits = _unpack(behaviour_logits)
        unpacked_old_policy_behaviour_logits = _unpack(
            old_policy_behaviour_logits
        )

        # Prepare actions for loss.
        if is_multidiscrete:
            loss_actions = actions
        else:
            loss_actions = torch.unsqueeze(actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=drop_last
        )

        # Time-major v-trace inputs.
        tm_behaviour_logits = _make_time_major(
            unpacked_behaviour_logits, drop_last=drop_last
        )
        tm_target_logits = _make_time_major(
            unpacked_old_policy_behaviour_logits, drop_last=drop_last
        )
        tm_actions = torch.unbind(
            _make_time_major(loss_actions, drop_last=drop_last), dim=2
        )
        tm_dones = _make_time_major(dones, drop_last=drop_last)
        tm_discounts = (1.0 - tm_dones.float()) * cfg["gamma"]
        tm_rewards = _make_time_major(rewards, drop_last=drop_last)

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=tm_behaviour_logits,
            target_policy_logits=tm_target_logits,
            actions=tm_actions,
            discounts=tm_discounts,
            rewards=tm_rewards,
            values=loss_values,
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=cfg["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=cfg["vtrace_clip_pg_rho_threshold"],
        )

        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=drop_last
        )
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last
        )
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last
        )

        # Clamped IS ratio between behaviour and target-network policy.
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
        )
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        self._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = _clipped_surrogate(advantages, logp_ratio)

        value_targets = vtrace_returns.vs.to(values_time_major.device)
        entropy_time_major = _make_time_major(
            action_dist.entropy(), drop_last=drop_last
        )
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = _clipped_surrogate(advantages, logp_ratio)

        value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
        entropy_time_major = _make_time_major(action_dist.entropy())

    # Reduce every per-timestep term to a scalar.
    mean_kl_loss = reduce_mean_valid(action_kl)
    mean_policy_loss = -reduce_mean_valid(surrogate_loss)

    # The value function loss.
    delta = loss_values - value_targets
    mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

    # The entropy loss.
    mean_entropy = reduce_mean_valid(entropy_time_major)

    # The summed weighted loss
    weighted_vf_loss = mean_vf_loss * cfg["vf_loss_coeff"]
    weighted_entropy = mean_entropy * self.entropy_coeff
    total_loss = mean_policy_loss + weighted_vf_loss - weighted_entropy

    # Optional additional KL Loss
    if cfg["use_kl_loss"]:
        total_loss += self.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    stats = model.tower_stats
    stats["total_loss"] = total_loss
    stats["mean_policy_loss"] = mean_policy_loss
    stats["mean_kl_loss"] = mean_kl_loss
    stats["mean_vf_loss"] = mean_vf_loss
    stats["mean_entropy"] = mean_entropy
    stats["value_targets"] = value_targets
    stats["vf_explained_var"] = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(loss_values, [-1]),
    )

    return total_loss