def actor_critic_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    policy.pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask))
    policy.value_err = 0.5 * torch.sum(
        torch.pow(
            torch.masked_select(
                values.reshape(-1) -
                train_batch[Postprocessing.VALUE_TARGETS], valid_mask), 2.0))
    policy.entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    overall_err = (
        policy.pi_err +
        policy.value_err * policy.config["vf_loss_coeff"] -
        policy.entropy * policy.config["entropy_coeff"])
    return overall_err
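# All loss functions in this collection rely on a `sequence_mask` helper to
# zero out the 0-padded tail of each RNN sequence chunk. A minimal sketch of
# its assumed semantics (RLlib ships its own implementation, which also
# accepts `dtype` and `time_major` arguments; this is only an illustration):
import torch

def sequence_mask(lengths, maxlen):
    """Return a [B, maxlen] bool mask, True at valid (non-padded) timesteps."""
    return torch.arange(maxlen, device=lengths.device)[None, :] < lengths[:, None]

# Example: lengths [2, 3] with maxlen 4 ->
# [[True, True, False, False],
#  [True, True, True,  False]]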
def forward(self, inputs: TensorType) -> TensorType:
    L = list(inputs.size())[1]  # length of segment
    H = self._num_heads  # number of attention heads
    D = self._head_dim  # attention head dimension

    qkv = self._qkv_layer(inputs)

    queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
    queries = queries[:, -L:]  # only query based on the segment

    queries = torch.reshape(queries, [-1, L, H, D])
    keys = torch.reshape(keys, [-1, L, H, D])
    values = torch.reshape(values, [-1, L, H, D])

    score = torch.einsum("bihd,bjhd->bijh", queries, keys)
    score = score / D**0.5

    # Causal mask of the same length as the sequence.
    mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype)
    mask = mask[None, :, :, None]
    mask = mask.float()

    masked_score = score * mask + 1e30 * (mask - 1.)
    wmat = nn.functional.softmax(masked_score, dim=2)

    out = torch.einsum("bijh,bjhd->bihd", wmat, values)
    shape = list(out.size())[:2] + [H * D]
    out = torch.reshape(out, shape)

    return self._linear_layer(out)
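# A quick shape check of the batched attention einsum used above (illustration
# only, with made-up sizes): contracting the head dimension `d` of queries and
# keys yields one score per (query position i, key position j, head h).
import torch

B, L, H, D = 2, 5, 4, 8
q = torch.randn(B, L, H, D)
k = torch.randn(B, L, H, D)
score = torch.einsum("bihd,bjhd->bijh", q, k)
assert score.shape == (B, L, L, H)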
def actor_critic_loss(policy, model, dist_class, train_batch):
    # If policy is recurrent, mask out padded sequences
    # and calculate batch size.
    if policy.is_recurrent():
        seq_lens = train_batch['seq_lens']
        max_seq_len = torch.max(seq_lens)
        mask_orig = sequence_mask(seq_lens, max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
        batch_size = seq_lens.shape[0] * max_seq_len
    else:
        mask = torch.ones_like(train_batch[SampleBatch.REWARDS])
        batch_size = mask.shape[0]

    logits, _ = model.from_batch(train_batch)
    values = model.value_function()
    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])

    # Note: `policy.entropy` is stored negated, so adding it (weighted by
    # `entropy_coeff`) to the total loss below encourages exploration.
    policy.entropy = -torch.sum(dist.entropy() * mask) / batch_size
    policy.pi_err = -torch.sum(
        train_batch[Postprocessing.ADVANTAGES] * log_probs.reshape(-1) *
        mask) / batch_size
    policy.value_err = torch.sum(
        torch.pow(
            (values.reshape(-1) -
             train_batch[Postprocessing.VALUE_TARGETS]) * mask,
            2.0)) / batch_size

    overall_err = sum([
        policy.pi_err,
        policy.config['vf_loss_coeff'] * policy.value_err,
        policy.config['entropy_coeff'] * policy.entropy,
    ])
    return overall_err
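# The `Postprocessing.ADVANTAGES` and `VALUE_TARGETS` fields consumed by
# these loss functions are assumed to come from GAE postprocessing
# (Schulman et al., 2016):
#   delta_t        = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t            = sum_{l >= 0} (gamma * lambda)^l * delta_{t+l}
#   value_target_t = A_t + V(s_t)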
def ppo_surrogate_loss(policy, model, dist_class, train_batch):
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)

    mask = None
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
    )
    return policy.loss_obj.loss
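# The PPOLoss object constructed above is assumed to implement the standard
# clipped surrogate objective (Schulman et al., 2017):
#   r_t(theta) = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t)
#   L_clip     = E_t[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
# plus KL, value-function, and entropy terms weighted by their coefficients.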
def actor_critic_loss(policy, model, dist_class, train_batch):
    assert policy.is_recurrent(), "policy must be recurrent"

    seq_lens = train_batch['seq_lens']
    batch_size = seq_lens.shape[0]
    max_seq_len = torch.max(seq_lens)
    mask_orig = sequence_mask(seq_lens, max_seq_len)
    mask = torch.reshape(mask_orig, [-1])

    horizon = policy.config['fun_horizon']
    manager_horizon_mask = mask_orig.clone()
    manager_horizon_mask[:, -horizon:] = False
    manager_horizon_mask = manager_horizon_mask.reshape(-1)

    # Hacky way of passing data from sample batch to train batch.
    model.random_select = train_batch['random_select'].reshape(
        (batch_size, -1))
    model.random_goal = train_batch['random_goal'].reshape(
        (batch_size, max_seq_len, -1))

    logits, _ = model.from_batch(train_batch)
    manager_values, worker_values = model.value_function()
    manager_latent_state, manager_goal = model.manager_features()

    manager_latent_state_future = torch.roll(manager_latent_state, -horizon,
                                             1)
    manager_latent_state_diff = (
        manager_latent_state_future - manager_latent_state).detach()

    policy.manager_loss = 10.0 * -torch.sum(
        train_batch['manager_advantages'] * F.cosine_similarity(
            manager_latent_state_diff, manager_goal, dim=-1).reshape(-1) *
        manager_horizon_mask) / (batch_size * max_seq_len)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    policy.entropy = 3e-4 * -torch.sum(dist.entropy() * mask) / (
        batch_size * max_seq_len)
    policy.pi_err = 0.1 * -torch.sum(
        train_batch['worker_advantages'] * log_probs.reshape(-1) * mask) / (
            batch_size * max_seq_len)
    policy.manager_value_err = torch.sum(
        torch.pow(
            (manager_values.reshape(-1) -
             train_batch['manager_value_targets']) * mask,
            2.0)) / (batch_size * max_seq_len)
    policy.worker_value_err = 0.01 * torch.sum(
        torch.pow(
            (worker_values.reshape(-1) -
             train_batch['worker_value_targets']) * mask,
            2.0)) / (batch_size * max_seq_len)

    overall_err = sum([
        policy.pi_err,
        policy.manager_value_err,
        policy.worker_value_err,
        policy.entropy,
        policy.manager_loss,
    ])
    return overall_err
def loss_fn(policy: Policy, model: ModelV2,
            dist_class: TorchDistributionWrapper, sample_batch: SampleBatch):
    max_seq_len = sample_batch['seq_lens'].max().item()
    mask = sequence_mask(sample_batch['seq_lens'], max_seq_len,
                         time_major=model.is_time_major()).view((-1, 1))
    mean_reg = sample_batch['seq_lens'].sum() * model.nbr_agents

    actions = sample_batch['actions'].view(
        (sample_batch['actions'].shape[0], model.nbr_agents,
         -1))[:, :, :1].to(torch.long)
    actions = add_time_dimension(
        actions, max_seq_len=max_seq_len, framework='torch',
        time_major=True).reshape_as(actions)

    logits_pi, _ = model(sample_batch, [
        sample_batch['state_in_0'],
    ], sample_batch['seq_lens'])
    logits_pi = logits_pi.view((logits_pi.shape[0], model.nbr_agents, -1))
    logits_pi_action = logits_pi[:, :, :model.nbr_actions]
    log_pi_action = nn.functional.log_softmax(logits_pi_action, dim=-1)
    pi_action = torch.exp(log_pi_action)
    log_pi_action_selected = torch.gather(log_pi_action, -1,
                                          actions).squeeze(-1)

    q_values = model.q_values(sample_batch, target=False)
    q_values = add_time_dimension(
        q_values, max_seq_len=max_seq_len, framework="torch",
        time_major=True).reshape_as(q_values)
    q_values_selected = torch.gather(q_values, -1, actions).squeeze(-1)
    q_values_target = sample_batch[Postprocessing.VALUE_TARGETS]
    q_values_target = add_time_dimension(
        q_values_target, max_seq_len=max_seq_len, framework="torch",
        time_major=True).reshape_as(q_values_target)

    td_error = q_values_selected - q_values_target
    with torch.no_grad():
        # Counterfactual advantage: Q of the taken action minus the
        # policy-weighted average over all actions (COMA baseline).
        coma_avg = q_values_selected - (pi_action * q_values).sum(-1)
    entropy = -(log_pi_action * pi_action).sum(-1)

    critic_loss = torch.pow(mask * td_error, 2.0)
    actor_loss = mask * coma_avg * log_pi_action_selected
    entropy = mask * entropy

    policy.actor_loss = -actor_loss.sum() / mean_reg
    policy.critic_loss = critic_loss.sum() / mean_reg
    policy.entropy = entropy.sum() / mean_reg

    pi_loss = policy.actor_loss - policy.config[
        'entropy_coeff'] * policy.entropy
    return pi_loss, policy.critic_loss
def forward(self, inputs: TensorType,
            memory: TensorType = None) -> TensorType:
    T = list(inputs.size())[1]  # length of segment (time)
    H = self._num_heads  # number of attention heads
    d = self._head_dim  # attention head dimension

    # Add previous memory chunk (as const, w/o gradient) to input.
    # Tau (number of (prev) time slices in each memory chunk).
    Tau = list(memory.shape)[1] if memory is not None else 0
    if memory is not None:
        memory.requires_grad_(False)
        inputs = torch.cat((memory, inputs), dim=1)

    # Apply the Layer-Norm.
    if self._input_layernorm is not None:
        inputs = self._input_layernorm(inputs)

    qkv = self._qkv_layer(inputs)

    queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
    # Cut out Tau memory timesteps from query.
    queries = queries[:, -T:]

    queries = torch.reshape(queries, [-1, T, H, d])
    keys = torch.reshape(keys, [-1, T + Tau, H, d])
    values = torch.reshape(values, [-1, T + Tau, H, d])

    R = self._pos_proj(self._rel_pos_encoder)
    R = torch.reshape(R, [T + Tau, H, d])

    # b=batch
    # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
    # h=head
    # d=head-dim (over which we will reduce-sum)
    score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
    pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
    score = score + self.rel_shift(pos_score)
    score = score / d**0.5

    # Causal mask of the same length as the sequence.
    mask = sequence_mask(
        torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
    mask = mask[None, :, :, None]

    masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
    wmat = nn.functional.softmax(masked_score, dim=2)

    out = torch.einsum("bijh,bjhd->bihd", wmat, values)
    shape = list(out.shape)[:2] + [H * d]
    out = torch.reshape(out, shape)

    return self._linear_layer(out)
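# The `rel_shift` call above is assumed to implement the Transformer-XL
# relative-position shift trick. A 2-D illustration of the idea (the actual
# helper would operate on the 4-D [batch, i, j, head] score layout):
import torch

def rel_shift_2d(x):
    """Shift a [qlen, klen] score matrix so column j indexes relative offset."""
    qlen, klen = x.shape
    x = torch.nn.functional.pad(x, (1, 0))  # prepend a zero column
    x = x.reshape(klen + 1, qlen)           # fold: rows slide by one position
    return x[1:].reshape(qlen, klen)        # drop the pad row, restore shape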
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
def ppo_surrogate_loss(policy, model, dist_class, train_batch):
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)

    mask = None
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])

    if policy.config["use_aux_loss"]:
        aux_input = {
            "new_obs": train_batch[SampleBatch.NEXT_OBS],
            "actions": train_batch[SampleBatch.ACTIONS],
            "dones": train_batch["dones"],
            "rewards": train_batch["rewards"]
        }
        aux_loss = policy.model.aux_loss(aux_input)
        train_batch["aux_loss"] = aux_loss

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        train_batch.get("aux_loss", None),
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        aux_loss_coeff=policy.config["aux_loss_coeff"],
        use_gae=policy.config["use_gae"],
    )
    return policy.loss_obj.loss
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"])

    policy.entropy = entropy
    policy.pi_err = pi_err
    policy.value_err = value_err

    return total_loss
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(
            behaviour_logits, list(output_hidden_shape), dim=1)
        unpacked_outputs = torch.split(
            model_out, list(output_hidden_shape), dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(
            behaviour_logits, output_hidden_shape, dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)

    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(
        actions, dim=1)

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=True),
        actions_logp=_make_time_major(
            action_dist.logp(actions), drop_last=True),
        actions_entropy=_make_time_major(
            action_dist.entropy(), drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=True),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        values,
        drop_last=policy.config["vtrace"])
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]),
        torch.reshape(values_batched, [-1]))

    return loss.total_loss
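# The VTraceLoss above is assumed to follow the IMPALA paper (Espeholt et
# al., 2018), where the off-policy-corrected value targets are
#   v_s     = V(x_s) + sum_{t >= s} gamma^(t-s) (prod_{i=s}^{t-1} c_i) delta_t
#   delta_t = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t))
# with truncated importance weights rho_t = min(rho_bar, pi/mu) and
# c_i = min(c_bar, pi/mu), clipped by the two thresholds passed in.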
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        B = len(train_batch["seq_lens"])
        max_seq_len = logits.shape[0] // B
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len,
                             time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(
        train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] *
        torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                    1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out,
            -policy.config["vf_clip_param"], policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(
            -surrogate_loss + policy.kl_coeff * action_kl +
            policy.config["vf_loss_coeff"] * vf_loss -
            policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS],
        policy.model.value_function())
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss
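# Illustration of why `reduce_mean_valid` is used instead of a plain mean
# (made-up values): padded timesteps would otherwise bias the average.
import torch

t = torch.tensor([1.0, 2.0, 3.0, 4.0])           # per-timestep losses
mask = torch.tensor([True, True, False, False])  # last two steps are padding
masked_mean = torch.sum(t[mask]) / torch.sum(mask)  # (1 + 2) / 2 = 1.5
plain_mean = torch.mean(t)                          # 2.5, skewed by padding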
def build_CAT_vtrace_loss(policy, model, dist_class, train_batch):
    action_space_parts = model.action_space_parts

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    invalid_action_mask = train_batch['invalid_action_mask']
    autoregressive_actions = policy.config['autoregressive_actions']

    if 'seq_lens' in train_batch:
        max_seq_len = policy.config['rollout_fragment_length']
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    actions_per_step = policy.config["actions_per_step"]

    states = []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states.append(train_batch["state_in_{}".format(i)])
        i += 1
    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else []

    model.observation_features_module(train_batch, states, seq_lens)
    action_features, _ = model.action_features_module(train_batch, states,
                                                      seq_lens)

    previous_action = None
    embedded_action = None
    logp_list = []
    entropy_list = []
    logits_list = []
    multi_actions = torch.chunk(actions, actions_per_step, dim=1)
    multi_invalid_action_mask = torch.chunk(invalid_action_mask,
                                            actions_per_step, dim=1)

    for a in range(actions_per_step):
        if autoregressive_actions:
            if a == 0:
                batch_size = action_features.shape[0]
                previous_action = torch.zeros(
                    [batch_size, len(action_space_parts)]).to(
                        action_features.device)
            else:
                previous_action = multi_actions[a - 1]
            embedded_action = model.embed_action_module(previous_action)

        logits = model.action_module(action_features, embedded_action)
        # Mask out invalid actions: log(0) -> -inf, clamped to finfo.min.
        logits += torch.maximum(
            torch.tensor(torch.finfo().min),
            torch.log(multi_invalid_action_mask[a]))
        cat = TorchMultiCategorical(logits, model, action_space_parts)

        logits_list.append(logits)
        logp_list.append(cat.logp(multi_actions[a]))
        entropy_list.append(cat.entropy())

    logp = torch.stack(logp_list, dim=1).sum(dim=1)
    entropy = torch.stack(entropy_list, dim=1).sum(dim=1)
    target_logits = torch.hstack(logits_list)

    # Repeat the action-space parts per generated action for unpacking.
    unpack_shape = np.tile(action_space_parts, actions_per_step)
    unpacked_behaviour_logits = torch.split(behaviour_logits,
                                            list(unpack_shape), dim=1)
    unpacked_outputs = torch.split(target_logits, list(unpack_shape), dim=1)

    values = model.value_function()

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    policy.loss = VTraceLoss(
        actions=_make_time_major(actions, drop_last=True),
        actions_logp=_make_time_major(logp, drop_last=True),
        actions_entropy=_make_time_major(entropy, drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=True),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    return policy.loss.total_loss
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q,
                    torch.tensor(0.0, device=policy.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)
    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=policy.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
            [q_target_best[1:],
             torch.tensor([0.0], device=policy.device)])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss
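# Sketch of the invertible value-rescaling pair assumed above (`h_function` /
# `h_inverse`, as in the R2D2 paper, following Pohlen et al.); not necessarily
# the exact implementation used:
import torch

def h_function(x, epsilon=1e-2):
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x."""
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x, epsilon=1e-2):
    """Closed-form inverse of h_function."""
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon)) -
          1.0) / (2.0 * epsilon))**2 - 1.0)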
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(
            behaviour_logits, list(output_hidden_shape), dim=1)
        unpacked_outputs = torch.split(
            model_out, list(output_hidden_shape), dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(
            behaviour_logits, output_hidden_shape, dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)

    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(
        actions, dim=1)

    # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=True),
        actions_logp=_make_time_major(
            action_dist.logp(actions), drop_last=True),
        actions_entropy=_make_time_major(
            action_dist.entropy(), drop_last=True),
        dones=_make_time_major(dones, drop_last=True),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=True),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=True),
        target_logits=_make_time_major(unpacked_outputs, drop_last=True),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=True),
        values=_make_time_major(values, drop_last=True),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=True),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"])

    # Store loss object only for multi-GPU tower 0.
    if model is policy.model_gpu_towers[0]:
        policy.loss = loss

    return loss.total_loss
def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = policy.target_model.from_batch(train_batch)
    old_policy_behaviour_logits = target_model_out.detach()

    unpacked_behaviour_logits = torch.split(behaviour_logits,
                                            output_hidden_shape, dim=1)
    unpacked_old_policy_behaviour_logits = torch.split(
        old_policy_behaviour_logits, output_hidden_shape, dim=1)
    unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
    old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"]) - 1
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
    else:
        mask = torch.ones_like(rewards)

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        mean_kl = _make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=True)

        policy.loss = VTraceSurrogateLoss(
            actions=_make_time_major(loss_actions, drop_last=True),
            prev_actions_logp=_make_time_major(
                prev_action_dist.logp(actions), drop_last=True),
            actions_logp=_make_time_major(
                action_dist.logp(actions), drop_last=True),
            old_policy_actions_logp=_make_time_major(
                old_policy_action_dist.logp(actions), drop_last=True),
            action_kl=torch.mean(mean_kl, dim=0)
            if is_multidiscrete else mean_kl,
            actions_entropy=_make_time_major(
                action_dist.multi_entropy(), drop_last=True),
            dones=_make_time_major(dones, drop_last=True),
            behaviour_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            old_policy_behaviour_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            target_logits=_make_time_major(unpacked_outputs, drop_last=True),
            discount=policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=_make_time_major(values, drop_last=True),
            bootstrap_value=_make_time_major(values)[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=policy.model,
            valid_mask=_make_time_major(mask, drop_last=True),
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            entropy_coeff=policy.config["entropy_coeff"],
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
config["vtrace_clip_pg_rho_threshold"], clip_param=policy.config["clip_param"], cur_kl_coeff=policy.kl_coeff, use_kl_loss=policy.config["use_kl_loss"]) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss mean_kl = _make_time_major(prev_action_dist.multi_kl(action_dist)) policy.loss = PPOSurrogateLoss( prev_actions_logp=_make_time_major(prev_action_dist.logp(actions)), actions_logp=_make_time_major(action_dist.logp(actions)), action_kl=torch.mean(mean_kl, dim=0) if is_multidiscrete else mean_kl, actions_entropy=_make_time_major(action_dist.multi_entropy()), values=_make_time_major(values), valid_mask=_make_time_major(mask), advantages=_make_time_major( train_batch[Postprocessing.ADVANTAGES]), value_targets=_make_time_major( train_batch[Postprocessing.VALUE_TARGETS]), vf_loss_coeff=policy.config["vf_loss_coeff"], entropy_coeff=policy.config["entropy_coeff"], clip_param=policy.config["clip_param"], cur_kl_coeff=policy.kl_coeff, use_kl_loss=policy.config["use_kl_loss"]) return policy.loss.total_loss
def appo_surrogate_loss(policy: Policy, model: ModelV2,
                        dist_class: Type[TorchDistributionWrapper],
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model.from_batch(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits,
                                            model)

        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1)
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape),
                dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1)
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=True)

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            actions=torch.unbind(
                _make_time_major(loss_actions, drop_last=True), dim=2),
            discounts=(1.0 -
                       _make_time_major(dones, drop_last=True).float()) *
            policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=values_time_major[:-1],  # drop-last=True
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config[
                "vtrace_clip_pg_rho_threshold"])

        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=True)
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=True)
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)

        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        delta = values_time_major[:-1] - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=True))
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for loss.
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy()))

    # The summed weighted loss.
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.config["entropy_coeff"]

    # Optional additional KL loss.
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets
    policy._vf_explained_var = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1]
            if policy.config["vtrace"] else values_time_major, [-1]),
    )

    return total_loss
def drq_ppo_surrogate_loss(policy, model, dist_class, train_batch):
    """Loss function for PPO with input augmentations.

    Computes the standard PPO surrogate loss once per augmentation of the
    current observations and returns the average over all augmented passes.
    """
    aug_num = policy.config["aug_num"]
    aug_loss = 0
    orig_cur_obs = train_batch[SampleBatch.CUR_OBS].clone()

    for _ in range(aug_num):
        # Do augmentation; overwrite last augmented obs.
        aug_cur_obs = model.trans(
            orig_cur_obs.permute(0, 3, 1, 2).float()).permute(0, 2, 3, 1)
        train_batch[SampleBatch.CUR_OBS] = aug_cur_obs

        # Forward with augmented obs.
        logits, state = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)

        mask = None
        if state:
            max_seq_len = torch.max(train_batch["seq_lens"])
            mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
            mask = torch.reshape(mask, [-1])

        policy.loss_obj = PPOLoss(
            dist_class,
            model,
            train_batch[Postprocessing.VALUE_TARGETS],
            train_batch[Postprocessing.ADVANTAGES],
            train_batch[SampleBatch.ACTIONS],
            train_batch[SampleBatch.ACTION_DIST_INPUTS],
            train_batch[SampleBatch.ACTION_LOGP],
            train_batch[SampleBatch.VF_PREDS],
            action_dist,
            model.value_function(),
            policy.kl_coeff,
            mask,
            entropy_coeff=policy.entropy_coeff,
            clip_param=policy.config["clip_param"],
            vf_clip_param=policy.config["vf_clip_param"],
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            use_gae=policy.config["use_gae"],
        )

        # Accumulate loss with augmented obs.
        aug_loss += policy.loss_obj.loss

    return aug_loss / aug_num
def actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get("seq_lens")

    model_out_t, state_in_t = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "prev_actions": train_batch[SampleBatch.PREV_ACTIONS],
            "prev_rewards": train_batch[SampleBatch.PREV_REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_tp1 = model.select_state(state_in_tp1,
                                       ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    target_states_in_tp1 = \
        policy.target_model.select_state(target_state_in_tp1,
                                         ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = F.log_softmax(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0],
            dim=-1)
        policy_t = torch.exp(log_pis_t)
        log_pis_tp1 = F.log_softmax(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0]
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens)[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens)[0]
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)[0]
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0], policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                 train_batch[SampleBatch.ACTIONS])[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens,
                train_batch[SampleBatch.ACTIONS])[0]

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"],
                                            seq_lens, policy_t)[0]
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0]
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1 = policy.target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens,
            policy_tp1)[0]
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens, policy_tp1)[0]
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute RHS of bellman equation.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # BURN-IN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch["seq_lens"], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False
    seq_mask = seq_mask.reshape(-1)

    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(
            train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(
                train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(
            alpha.detach() * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.q_t = q_t * seq_mask[..., None]
    policy.policy_t = policy_t * seq_mask[..., None]
    policy.log_pis_t = log_pis_t * seq_mask[..., None]

    # Store td-error in model, such that for multi-GPU, we do not override
    # them during the parallel loss phase. TD-error tensor in final stats
    # can then be concatenated and retrieved for each individual batch item.
    model.td_error = td_error * seq_mask

    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy

    # Return all loss terms corresponding to our optimizers.
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
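# Sketch of the `huber_loss` helper assumed by the TD-error terms above
# (quadratic near zero, linear in the tails); RLlib ships an equivalent:
import torch

def huber_loss(x, delta=1.0):
    """Elementwise Huber loss of a (signed or absolute) TD-error tensor."""
    return torch.where(
        torch.abs(x) < delta,
        0.5 * torch.pow(x, 2.0),
        delta * (torch.abs(x) - 0.5 * delta))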
def mu_zero_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)
    mcts_policy = dist_class(train_batch["mcts_policy"], model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len,
                             time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(
        train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

    logp = curr_action_dist.logp(train_batch[SampleBatch.ACTIONS])
    logp_ratio = torch.exp(logp - train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] *
        torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                    1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out,
            -policy.config["vf_clip_param"], policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(
            -surrogate_loss * policy.config["surrogate_coeff"] +
            policy.kl_coeff * action_kl +
            policy.config["vf_loss_coeff"] * vf_loss -
            policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    pred_reward = model.reward_function(policy_logits=logits)
    # Rewards need to be float for proper backpropagation.
    reward_loss = torch.nn.functional.mse_loss(
        pred_reward, train_batch[SampleBatch.REWARDS].float())
    # Cross-entropy between the MCTS visit distribution and the current
    # policy distribution.
    mcts_loss = torch.mean(-mcts_policy.dist.probs *
                           torch.log(curr_action_dist.dist.probs))

    total_loss += reward_loss + mcts_loss

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl
    policy._mean_reward_loss = reward_loss
    policy._mcts_loss = mcts_loss

    return total_loss
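# Tiny illustration of the MCTS policy-distillation term above: the
# cross-entropy between an MCTS visit distribution p and the policy's
# probabilities q (values here are made up):
import torch

p = torch.tensor([0.7, 0.2, 0.1])  # normalized MCTS visit counts
q = torch.tensor([0.5, 0.3, 0.2])  # current policy probs
cross_entropy = torch.sum(-p * torch.log(q))  # lower when q matches p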