def forward(self, inputs: TensorType) -> TensorType:
    L = list(inputs.size())[1]  # length of segment
    H = self._num_heads  # number of attention heads
    D = self._head_dim  # attention head dimension

    qkv = self._qkv_layer(inputs)

    queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
    queries = queries[:, -L:]  # only query based on the segment

    queries = torch.reshape(queries, [-1, L, H, D])
    keys = torch.reshape(keys, [-1, L, H, D])
    values = torch.reshape(values, [-1, L, H, D])

    score = torch.einsum("bihd,bjhd->bijh", queries, keys)
    score = score / D**0.5

    # Causal mask of the same length as the sequence.
    mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype)
    mask = mask[None, :, :, None]
    mask = mask.float()

    masked_score = score * mask + 1e30 * (mask - 1.0)
    wmat = nn.functional.softmax(masked_score, dim=2)

    out = torch.einsum("bijh,bjhd->bihd", wmat, values)
    shape = list(out.size())[:2] + [H * D]
    out = torch.reshape(out, shape)
    return self._linear_layer(out)
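# Hedged sketch (illustration only, not part of the policy code above): shows
# how the additive masking trick `score * mask + 1e30 * (mask - 1.)` removes
# attention to future timesteps before the softmax. The causal mask that
# `sequence_mask(torch.arange(1, L + 1))` produces is approximated here with a
# lower-triangular matrix; all names are local to this sketch.
import torch

def _causal_softmax_demo(L: int = 4, H: int = 2) -> torch.Tensor:
    score = torch.randn(1, L, L, H)                      # [B, i, j, h]
    mask = torch.tril(torch.ones(L, L))[None, :, :, None]
    masked_score = score * mask + 1e30 * (mask - 1.0)    # masked entries ~ -1e30
    wmat = torch.softmax(masked_score, dim=2)            # j > i gets ~0 weight
    assert torch.allclose(wmat.sum(dim=2), torch.ones(1, L, H))
    return wmat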
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: ActionDistribution,
    train_batch: SampleBatch,
) -> TensorType:
    logits, _ = model(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS],
                    valid_mask,
                ),
                2.0,
            ))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.entropy_coeff)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
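# Hedged sketch (illustration only) of what the `sequence_mask` utility used in
# the recurrent branch above conceptually returns: a [B, max_seq_len] boolean
# mask whose row b is True for the first seq_lens[b] steps. The real helper
# also accepts dtype and time_major arguments; `_sequence_mask` is a local
# stand-in.
import torch

def _sequence_mask(seq_lens: torch.Tensor, max_seq_len: int) -> torch.Tensor:
    return (torch.arange(max_seq_len, device=seq_lens.device)[None, :]
            < seq_lens[:, None])

# e.g. _sequence_mask(torch.tensor([2, 4]), 4) ->
# [[True, True, False, False],
#  [True, True, True,  True ]]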
def forward(self, inputs: TensorType, memory: TensorType = None) -> TensorType:
    T = list(inputs.size())[1]  # length of segment (time)
    H = self._num_heads  # number of attention heads
    d = self._head_dim  # attention head dimension

    # Add previous memory chunk (as const, w/o gradient) to input.
    # Tau (number of (prev) time slices in each memory chunk).
    Tau = list(memory.shape)[1]
    inputs = torch.cat((memory.detach(), inputs), dim=1)

    # Apply the Layer-Norm.
    if self._input_layernorm is not None:
        inputs = self._input_layernorm(inputs)

    qkv = self._qkv_layer(inputs)

    queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
    # Cut out Tau memory timesteps from query.
    queries = queries[:, -T:]

    queries = torch.reshape(queries, [-1, T, H, d])
    keys = torch.reshape(keys, [-1, Tau + T, H, d])
    values = torch.reshape(values, [-1, Tau + T, H, d])

    R = self._pos_proj(self._rel_pos_embedding(Tau + T))
    R = torch.reshape(R, [Tau + T, H, d])

    # b=batch
    # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
    # h=head
    # d=head-dim (over which we will reduce-sum)
    score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
    pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
    score = score + self.rel_shift(pos_score)
    score = score / d**0.5

    # Causal mask of the same length as the sequence.
    mask = sequence_mask(
        torch.arange(Tau + 1, Tau + T + 1),
        dtype=score.dtype).to(score.device)
    mask = mask[None, :, :, None]

    masked_score = score * mask + 1e30 * (mask.float() - 1.0)
    wmat = nn.functional.softmax(masked_score, dim=2)

    out = torch.einsum("bijh,bjhd->bihd", wmat, values)
    shape = list(out.shape)[:2] + [H * d]
    out = torch.reshape(out, shape)
    return self._linear_layer(out)
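# Hedged sketch (illustration only): `self.rel_shift` is not shown in this
# excerpt. Below is the standard Transformer-XL "pad, fold, slice" shift it
# typically implements for relative-position scores of shape [B, i, j, h];
# the actual method on the class may differ in detail.
import torch

def _rel_shift(x: torch.Tensor) -> torch.Tensor:
    x_size = list(x.shape)  # [B, i, j, h]
    # Pad one zero column at the front of the j-axis ...
    x = torch.nn.functional.pad(x, (0, 0, 1, 0, 0, 0, 0, 0))
    # ... fold it into the i-axis, drop the first row, and restore the
    # original shape, which realigns each query row's relative-position scores.
    x = torch.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]])
    x = x[:, 1:, :, :]
    return torch.reshape(x, x_size)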
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(
            behaviour_logits, list(output_hidden_shape), dim=1)
        unpacked_outputs = torch.split(
            model_out, list(output_hidden_shape), dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(
            behaviour_logits, output_hidden_shape, dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)

    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(
        actions, dim=1)

    # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
    drop_last = policy.config["vtrace_drop_last_ts"]
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=drop_last),
        actions_logp=_make_time_major(
            action_dist.logp(actions), drop_last=drop_last),
        actions_entropy=_make_time_major(
            action_dist.entropy(), drop_last=drop_last),
        dones=_make_time_major(dones, drop_last=drop_last),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=drop_last),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=drop_last),
        target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=drop_last),
        values=_make_time_major(values, drop_last=drop_last),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=drop_last),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"],
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        values,
        drop_last=policy.config["vtrace"] and drop_last,
    )
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]),
        torch.reshape(values_batched, [-1]))

    return loss.total_loss
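# Hedged sketch (illustration only) of the [B * T] -> [T, B] reshaping that
# `make_time_major` performs on the V-trace inputs above, assuming batch-major,
# equal-length sequences. The real helper also infers B and T from seq_lens and
# supports the drop_last flag; `_to_time_major` is a local stand-in.
import torch

def _to_time_major(flat: torch.Tensor, B: int, T: int) -> torch.Tensor:
    folded = flat.reshape([B, T] + list(flat.shape[1:]))  # [B, T, ...]
    return folded.transpose(0, 1)                         # [T, B, ...]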
def loss(self, model: ModelV2, dist_class: Type[ActionDistribution],
         train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        The PPO loss tensor given the input batch.
    """
    logits, state = model(train_batch)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask = sequence_mask(
            train_batch[SampleBatch.SEQ_LENS],
            max_seq_len,
            time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(
        train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])

    # Only calculate kl loss if necessary (kl-coeff > 0.0).
    if self.config["kl_coeff"] > 0.0:
        action_kl = prev_action_dist.kl(curr_action_dist)
        mean_kl_loss = reduce_mean_valid(action_kl)
    else:
        mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
            logp_ratio, 1 - self.config["clip_param"],
            1 + self.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    # Compute a value function loss.
    if self.config["use_critic"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out, -self.config["vf_clip_param"],
            self.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
    # Ignore the value function.
    else:
        vf_loss = mean_vf_loss = 0.0

    total_loss = reduce_mean_valid(-surrogate_loss +
                                   self.config["vf_loss_coeff"] * vf_loss -
                                   self.entropy_coeff * curr_entropy)

    # Add mean_kl_loss (already processed through `reduce_mean_valid`),
    # if necessary.
    if self.config["kl_coeff"] > 0.0:
        total_loss += self.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["total_loss"] = total_loss
    model.tower_stats["mean_policy_loss"] = mean_policy_loss
    model.tower_stats["mean_vf_loss"] = mean_vf_loss
    model.tower_stats["vf_explained_var"] = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
    model.tower_stats["mean_entropy"] = mean_entropy
    model.tower_stats["mean_kl_loss"] = mean_kl_loss

    return total_loss
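# Hedged sketch (illustration only) of the clipped surrogate objective computed
# in `loss` above: min(A * r, A * clip(r, 1 - eps, 1 + eps)) with
# r = exp(logp_new - logp_old). All names are local to this sketch.
import torch

def _ppo_surrogate(advantages: torch.Tensor,
                   logp_new: torch.Tensor,
                   logp_old: torch.Tensor,
                   clip_param: float = 0.3) -> torch.Tensor:
    logp_ratio = torch.exp(logp_new - logp_old)
    return torch.min(
        advantages * logp_ratio,
        advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param))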
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    target_states_in_tp1 = target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)

        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens)

        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                states_in_t["twin_q"],
                seq_lens,
                train_batch[SampleBatch.ACTIONS],
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, states_in_t["q"],
                                               seq_lens, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens,
            policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1,
                target_states_in_tp1["twin_q"],
                seq_lens,
                policy_tp1,
            )
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute RHS of the Bellman equation.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(
            train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(
                train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))
    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(
            alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss

    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(
        td_error.reshape([-1, T]), dim=-1)

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
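# Hedged sketch (illustration only) mirroring the burn-in masking used in the
# loss above: `sequence_mask` marks the valid steps of each sequence, and the
# first `burn_in` timesteps are additionally masked out so they only prime the
# RNN state without contributing to the loss. `_burn_in_mask` is a local
# stand-in, assuming batch-major sequences of length T.
import torch

def _burn_in_mask(seq_lens: torch.Tensor, T: int, burn_in: int) -> torch.Tensor:
    mask = torch.arange(T, device=seq_lens.device)[None, :] < seq_lens[:, None]
    if 0 < burn_in < T:
        mask[:, :burn_in] = False
    return mask.reshape(-1)  # flat [B * T] mask, as used by reduce_mean_valid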
def loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss function.

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        The A3C loss tensor given the input batch.
    """
    logits, _ = model(train_batch)
    values = model.value_function()

    if self.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask))

    # Compute a value function loss.
    if self.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS],
                    valid_mask,
                ),
                2.0,
            ))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * self.config["vf_loss_coeff"] -
                  entropy * self.entropy_coeff)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
def loss(
    self,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = self.target_models[model]

    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(self.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [self.action_space.n]
    elif isinstance(self.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = self.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kwargs):
        return make_time_major(
            self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
        )

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]

    if self.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=drop_last)
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if self.config["vtrace"]:
        logger.debug(
            "Using V-Trace surrogate loss (vtrace=True; "
            f"drop_last={drop_last})"
        )

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1
            )
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1
            )
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1
            )
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1
            )

        # Prepare actions for loss.
        loss_actions = (
            actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
        )

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=drop_last
        )

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=drop_last
            ),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=drop_last
            ),
            actions=torch.unbind(
                _make_time_major(loss_actions, drop_last=drop_last), dim=2
            ),
            discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
            * self.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=drop_last),
            values=values_time_major[:-1] if drop_last else values_time_major,
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
        )

        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=drop_last
        )
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last
        )
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last
        )
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
        )
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        self._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages
            * torch.clamp(
                logp_ratio,
                1 - self.config["clip_param"],
                1 + self.config["clip_param"],
            ),
        )

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        if drop_last:
            delta = values_time_major[:-1] - value_targets
        else:
            delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=drop_last)
        )

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for loss.
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages
            * torch.clamp(
                logp_ratio,
                1 - self.config["clip_param"],
                1 + self.config["clip_param"],
            ),
        )

        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy()))

    # The summed weighted loss.
    total_loss = (
        mean_policy_loss
        + mean_vf_loss * self.config["vf_loss_coeff"]
        - mean_entropy * self.entropy_coeff
    )

    # Optional additional KL loss.
    if self.config["use_kl_loss"]:
        total_loss += self.kl_coeff * mean_kl_loss

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["total_loss"] = total_loss
    model.tower_stats["mean_policy_loss"] = mean_policy_loss
    model.tower_stats["mean_kl_loss"] = mean_kl_loss
    model.tower_stats["mean_vf_loss"] = mean_vf_loss
    model.tower_stats["mean_entropy"] = mean_entropy
    model.tower_stats["value_targets"] = value_targets
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1] if drop_last else values_time_major, [-1]
        ),
    )

    return total_loss
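# Hedged sketch (illustration only) of the importance-ratio composition in the
# V-trace branch above: the behaviour-vs-target ("old policy") ratio is clamped
# to [0, 2] and multiplied with the current-vs-behaviour ratio. Names are local
# to this sketch.
import torch

def _appo_logp_ratio(actions_logp: torch.Tensor,
                     prev_actions_logp: torch.Tensor,
                     old_policy_actions_logp: torch.Tensor) -> torch.Tensor:
    is_ratio = torch.clamp(
        torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
    return is_ratio * torch.exp(actions_logp - prev_actions_logp)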
def r2d2_loss(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=q.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=q_target.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat([
            q_target_best[1:],
            torch.tensor([0.0], device=q_target_best.device)
        ])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        total_loss = reduce_mean_valid(weights * huber_loss(td_error))

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_q"] = reduce_mean_valid(q_selected)
        model.tower_stats["min_q"] = torch.min(q_selected)
        model.tower_stats["max_q"] = torch.max(q_selected)
        model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error)
        # Store per time chunk (b/c we need only one mean
        # prioritized replay weight per stored sequence).
        model.tower_stats["td_error"] = torch.mean(td_error, dim=-1)

    return total_loss
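# Hedged sketch (illustration only) of the value-rescaling pair referenced as
# `h_function` / `h_inverse` above, following the invertible transform from the
# R2D2 paper: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x. The helpers the
# policy actually imports may differ in detail.
import torch

def _h(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + eps * x

def _h_inverse(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # Analytic inverse of _h (solve the quadratic in sqrt(|x| + 1)).
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * eps * (torch.abs(x) + 1.0 + eps)) - 1.0)
         / (2.0 * eps)) ** 2 - 1.0)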