def loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    actions = train_batch[SampleBatch.ACTIONS]
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)

    # Advantage estimation.
    if self.config["beta"] != 0.0:
        cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
        state_values = model.value_function()
        adv = cumulative_rewards - state_values
        adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

        explained_var = explained_variance(cumulative_rewards, state_values)
        self.explained_variance = torch.mean(explained_var)

        # Policy loss.
        # Update averaged advantage norm.
        rate = self.config["moving_average_sqd_adv_norm_update_rate"]
        self._moving_average_sqd_adv_norm.add_(
            rate * (adv_squared_mean - self._moving_average_sqd_adv_norm)
        )
        # Exponentially weighted advantages.
        exp_advs = torch.exp(
            self.config["beta"]
            * (adv / (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5)))
        ).detach()
        # Value loss.
        self.v_loss = 0.5 * adv_squared_mean
    else:
        # Policy loss (simple BC loss term).
        exp_advs = 1.0
        # Value loss.
        self.v_loss = 0.0

    # A pure logprob loss tends to push the action distribution towards very
    # low entropy, which hurts performance in unfamiliar situations.
    # A scaled logstd loss term encourages stochasticity and thus alleviates
    # this problem to some extent.
    logstd_coeff = self.config["bc_logstd_coeff"]
    if logstd_coeff > 0.0:
        logstds = torch.mean(action_dist.log_std, dim=1)
    else:
        logstds = 0.0

    self.p_loss = -torch.mean(exp_advs * (logprobs + logstd_coeff * logstds))

    # Combine both losses.
    self.total_loss = self.p_loss + self.config["vf_coeff"] * self.v_loss

    return self.total_loss
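# Illustration (not part of the policy code above): a minimal, standalone
# sketch of MARWIL's exponentially weighted advantages as computed in `loss()`
# above, i.e. exp(beta * A / sqrt(E[A^2])), using plain torch tensors with
# made-up values. All variable names below are illustrative only.
import torch

beta = 1.0
returns = torch.tensor([10.0, 2.0, 5.0])      # cumulative rewards R(s)
state_values = torch.tensor([8.0, 3.0, 5.0])  # critic estimates V(s)

adv = returns - state_values                  # advantages A = R - V
adv_squared_mean = torch.mean(adv ** 2)       # E[A^2] (a single batch here,
                                              # a moving average in the policy)

# Weight each log-prob term by exp(beta * A / sqrt(E[A^2])): positive
# advantages are up-weighted, negative ones down-weighted.
exp_advs = torch.exp(beta * adv / (1e-8 + adv_squared_mean.sqrt())).detach()
print(exp_advs)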
def marwil_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> TensorType:
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    actions = train_batch[SampleBatch.ACTIONS]
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)

    # Advantage estimation.
    if policy.config["beta"] != 0.0:
        cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
        state_values = model.value_function()
        adv = cumulative_rewards - state_values
        adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

        explained_var = explained_variance(cumulative_rewards, state_values)
        policy.explained_variance = torch.mean(explained_var)

        # Policy loss.
        # Update averaged advantage norm.
        rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
        policy._moving_average_sqd_adv_norm.add_(
            rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm)
        )
        # Exponentially weighted advantages.
        exp_advs = torch.exp(
            policy.config["beta"]
            * (adv / (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5)))
        )
        policy.p_loss = -torch.mean(exp_advs.detach() * logprobs)

        # Value loss.
        policy.v_loss = 0.5 * adv_squared_mean
    else:
        # Policy loss (simple BC loss term).
        policy.p_loss = -torch.mean(logprobs)

        # Value loss.
        policy.v_loss = 0.0

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * policy.v_loss

    return policy.total_loss
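# Illustration (not part of the policy code above): a rough sketch of the
# quantity an explained-variance diagnostic such as
# `explained_variance(y, y_pred)` typically reports,
# 1 - Var[y - y_pred] / Var[y]. The actual RLlib helper may differ in details
# (e.g., clamping); the function name and values below are illustrative only.
import torch

def explained_variance_sketch(y: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
    # 1.0 means the value function explains the returns perfectly;
    # values <= 0.0 mean it is no better than predicting the mean return.
    return 1.0 - torch.var(y - y_pred) / torch.var(y)

returns = torch.tensor([10.0, 2.0, 5.0, 7.0])
values = torch.tensor([9.0, 3.0, 5.0, 6.0])
print(explained_variance_sketch(returns, values))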
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(
            policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
        )

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(
            behaviour_logits, list(output_hidden_shape), dim=1
        )
        unpacked_outputs = torch.split(model_out, list(output_hidden_shape), dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(
            behaviour_logits, output_hidden_shape, dim=1
        )
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)

    # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
    drop_last = policy.config["vtrace_drop_last_ts"]
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=drop_last),
        actions_logp=_make_time_major(action_dist.logp(actions), drop_last=drop_last),
        actions_entropy=_make_time_major(action_dist.entropy(), drop_last=drop_last),
        dones=_make_time_major(dones, drop_last=drop_last),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=drop_last
        ),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=drop_last
        ),
        target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=drop_last),
        values=_make_time_major(values, drop_last=drop_last),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=drop_last),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"],
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        values,
        drop_last=policy.config["vtrace"] and drop_last,
    )
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]), torch.reshape(values_batched, [-1])
    )

    return loss.total_loss
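# Illustration (not part of the loss above): a simplified sketch of the
# batch-major -> time-major reshaping that a helper like `make_time_major`
# performs on the flat [B * T] sample-batch tensors before the V-trace
# calculation, assuming equal-length sequences and ignoring the `drop_last`
# option.
import torch

B, T = 2, 3                                       # 2 sequences of 3 timesteps each
flat = torch.arange(B * T)                        # [B * T], grouped by sequence
time_major = flat.reshape(B, T).transpose(0, 1)   # -> [T, B]
print(time_major)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])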
def loss(
    self,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Proximal Policy Optimization objective.

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        The PPO loss tensor given the input batch.
    """
    logits, state = model(train_batch)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at the end of the time axis.
    if state:
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask = sequence_mask(
            train_batch[SampleBatch.SEQ_LENS],
            max_seq_len,
            time_major=model.is_time_major(),
        )
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
        - train_batch[SampleBatch.ACTION_LOGP]
    )

    # Only calculate the KL loss if necessary (kl_coeff > 0.0).
    if self.config["kl_coeff"] > 0.0:
        action_kl = prev_action_dist.kl(curr_action_dist)
        mean_kl_loss = reduce_mean_valid(action_kl)
    else:
        mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES]
        * torch.clamp(
            logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"]
        ),
    )
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    # Compute a value function loss.
    if self.config["use_critic"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0
        )
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out,
            -self.config["vf_clip_param"],
            self.config["vf_clip_param"],
        )
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0
        )
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
    # Ignore the value function.
    else:
        vf_loss = mean_vf_loss = 0.0

    total_loss = reduce_mean_valid(
        -surrogate_loss
        + self.config["vf_loss_coeff"] * vf_loss
        - self.entropy_coeff * curr_entropy
    )

    # Add mean_kl_loss (already processed through `reduce_mean_valid`),
    # if necessary.
    if self.config["kl_coeff"] > 0.0:
        total_loss += self.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["total_loss"] = total_loss
    model.tower_stats["mean_policy_loss"] = mean_policy_loss
    model.tower_stats["mean_vf_loss"] = mean_vf_loss
    model.tower_stats["vf_explained_var"] = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS], model.value_function()
    )
    model.tower_stats["mean_entropy"] = mean_entropy
    model.tower_stats["mean_kl_loss"] = mean_kl_loss

    return total_loss
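# Illustration (not part of the policy code above): a minimal, standalone
# sketch of the clipped PPO surrogate objective used in `loss()` above, with
# made-up advantage and probability-ratio tensors.
import torch

clip_param = 0.3
advantages = torch.tensor([1.0, -1.0, 2.0])
# ratio = pi_theta(a|s) / pi_theta_old(a|s), i.e. exp(logp - old_logp).
logp_ratio = torch.tensor([1.5, 0.5, 1.1])

surrogate = torch.min(
    advantages * logp_ratio,
    advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param),
)
# The policy loss is the negative mean of the surrogate; clipping caps the
# incentive to push the ratio far outside [1 - clip_param, 1 + clip_param].
policy_loss = -torch.mean(surrogate)
print(policy_loss)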
def loss(
    self,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for APPO.

    Uses importance-sampling (IS) modifications and V-trace for advantage
    estimation.

    Args:
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = self.target_models[model]

    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(self.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [self.action_space.n]
    elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = self.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kwargs):
        return make_time_major(
            self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
        )

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]

    if self.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=drop_last)
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if self.config["vtrace"]:
        logger.debug(
            "Using V-Trace surrogate loss (vtrace=True; "
            f"drop_last={drop_last})"
        )

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1
            )
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1
            )
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1
            )
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1
            )

        # Prepare actions for loss.
        loss_actions = (
            actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
        )

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=drop_last
        )

        # Compute V-trace on the CPU for better performance.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=drop_last
            ),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=drop_last
            ),
            actions=torch.unbind(
                _make_time_major(loss_actions, drop_last=drop_last), dim=2
            ),
            discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
            * self.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=drop_last),
            values=values_time_major[:-1] if drop_last else values_time_major,
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
        )

        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=drop_last
        )
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last
        )
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last
        )
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
        )
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        self._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages
            * torch.clamp(
                logp_ratio,
                1 - self.config["clip_param"],
                1 + self.config["clip_param"],
            ),
        )

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        if drop_last:
            delta = values_time_major[:-1] - value_targets
        else:
            delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=drop_last)
        )

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for loss.
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages
            * torch.clamp(
                logp_ratio,
                1 - self.config["clip_param"],
                1 + self.config["clip_param"],
            ),
        )

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

    # The summed, weighted loss.
    total_loss = (
        mean_policy_loss
        + mean_vf_loss * self.config["vf_loss_coeff"]
        - mean_entropy * self.entropy_coeff
    )

    # Optional additional KL loss.
    if self.config["use_kl_loss"]:
        total_loss += self.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["total_loss"] = total_loss
    model.tower_stats["mean_policy_loss"] = mean_policy_loss
    model.tower_stats["mean_kl_loss"] = mean_kl_loss
    model.tower_stats["mean_vf_loss"] = mean_vf_loss
    model.tower_stats["mean_entropy"] = mean_entropy
    model.tower_stats["value_targets"] = value_targets
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1] if drop_last else values_time_major, [-1]
        ),
    )

    return total_loss