def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]

    # Q-network evaluation.
    q_t = compute_q_values(
        policy,
        model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True)

    # Target Q-network evaluation.
    q_tp1 = compute_q_values(
        policy,
        target_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True,
    )

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # Compute estimate of the best possible value starting from state at t+1.
    dones = train_batch[SampleBatch.DONES].float()
    q_tp1_best_one_hot_selection = F.one_hot(
        torch.argmax(q_tp1, 1), policy.action_space.n)
    q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # Compute RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           policy.config["gamma"] * q_tp1_best_masked)

    # Compute the error (Square/Huber).
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["loss"] = loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss
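# Illustration only (not part of the policy code): a minimal, self-contained
# sketch of the TD-target construction used in `build_q_losses` above, with
# made-up toy tensors standing in for the batch and the two Q-networks'
# outputs.
import torch
import torch.nn.functional as F

if __name__ == "__main__":
    gamma = 0.99
    num_actions = 3
    # Toy batch of 4 transitions.
    actions = torch.tensor([0, 2, 1, 2])
    rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])
    dones = torch.tensor([0.0, 0.0, 1.0, 0.0])
    q_t = torch.randn(4, num_actions)    # online net, Q(s_t, .)
    q_tp1 = torch.randn(4, num_actions)  # target net, Q(s_{t+1}, .)

    # Q(s_t, a_t) via one-hot selection, as in the loss above.
    one_hot_selection = F.one_hot(actions, num_actions).float()
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # max_a' Q_target(s_{t+1}, a'), masked where the episode terminated.
    q_tp1_best = q_tp1.max(dim=1).values
    q_t_selected_target = rewards + gamma * (1.0 - dones) * q_tp1_best

    td_error = q_t_selected - q_t_selected_target.detach()
    print(td_error)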
def get_train_op():
    # Note: `q_clicked`, `target_clicked`, and `policy` are closed over from
    # the enclosing loss-building scope.
    td_error = q_clicked - target_clicked
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    return loss, torch.mean(torch.abs(td_error))
def loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Compute loss for SimpleQ.

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        The SimpleQ loss tensor given the input batch.
    """
    target_model = self.target_models[model]

    # Q-network evaluation.
    q_t = self._compute_q_values(
        model, train_batch[SampleBatch.CUR_OBS], is_training=True)

    # Target Q-network evaluation.
    q_tp1 = self._compute_q_values(
        target_model,
        train_batch[SampleBatch.NEXT_OBS],
        is_training=True,
    )

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  self.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # Compute estimate of the best possible value starting from state at t+1.
    dones = train_batch[SampleBatch.DONES].float()
    q_tp1_best_one_hot_selection = F.one_hot(
        torch.argmax(q_tp1, 1), self.action_space.n)
    q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # Compute RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           self.config["gamma"] * q_tp1_best_masked)

    # Compute the error (Square/Huber).
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["loss"] = loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss
def __init__(
    self,
    q_t_selected: TensorType,
    q_logits_t_selected: TensorType,
    q_tp1_best: TensorType,
    q_probs_tp1_best: TensorType,
    importance_weights: TensorType,
    rewards: TensorType,
    done_mask: TensorType,
    gamma=0.99,
    n_step=1,
    num_atoms=1,
    v_min=-10.0,
    v_max=10.0,
):
    if num_atoms > 1:
        # Distributional Q-learning, which corresponds to an entropy loss.
        # Note: `torch.arange` replaces the deprecated `torch.range` and
        # yields the same `num_atoms` support points.
        z = torch.arange(
            0.0, num_atoms, dtype=torch.float32).to(rewards.device)
        z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

        # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
        r_tau = torch.unsqueeze(
            rewards, -1) + gamma**n_step * torch.unsqueeze(
                1.0 - done_mask, -1) * torch.unsqueeze(z, 0)
        r_tau = torch.clamp(r_tau, v_min, v_max)
        b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
        lb = torch.floor(b)
        ub = torch.ceil(b)

        # Indispensable judgement which is missed in most implementations:
        # when b happens to be an integer, lb == ub, so pr_j(s', a*) would
        # be discarded because (ub-b) == (b-lb) == 0.
        floor_equal_ceil = ((ub - lb) < 0.5).float()

        # (batch_size, num_atoms, num_atoms)
        l_project = F.one_hot(lb.long(), num_atoms)
        # (batch_size, num_atoms, num_atoms)
        u_project = F.one_hot(ub.long(), num_atoms)
        ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
        mu_delta = q_probs_tp1_best * (b - lb)
        ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1), dim=1)
        mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1), dim=1)
        m = ml_delta + mu_delta

        # Rainbow paper claims that using this cross entropy loss for
        # priority is robust and insensitive to `prioritized_replay_alpha`.
        self.td_error = softmax_cross_entropy_with_logits(
            logits=q_logits_t_selected, labels=m.detach())
        self.loss = torch.mean(self.td_error * importance_weights)
        self.stats = {
            # TODO: better Q stats for dist dqn
        }
    else:
        q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

        # Compute RHS of the Bellman equation.
        q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

        # Compute the error (potentially clipped).
        self.td_error = q_t_selected - q_t_selected_target.detach()
        self.loss = torch.mean(
            importance_weights.float() * huber_loss(self.td_error))
        self.stats = {
            "mean_q": torch.mean(q_t_selected),
            "min_q": torch.min(q_t_selected),
            "max_q": torch.max(q_t_selected),
        }
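# Illustration only: a toy, standalone version of the categorical projection
# performed in the distributional (num_atoms > 1) branch above. It projects
# the shifted support r + gamma * z back onto the fixed atoms in
# [v_min, v_max] and distributes the next-state probabilities onto the two
# neighboring atoms. All inputs and parameter values below are made up.
import torch
import torch.nn.functional as F

def project_distribution(rewards, dones, probs_tp1, gamma=0.99,
                         num_atoms=51, v_min=-10.0, v_max=10.0):
    """Returns the projected target distribution m, shape [B, num_atoms]."""
    z = torch.arange(0.0, num_atoms, dtype=torch.float32)
    z = v_min + z * (v_max - v_min) / float(num_atoms - 1)
    # Bellman-shifted support per batch item: [B, num_atoms].
    r_tau = rewards.unsqueeze(-1) + gamma * (1.0 - dones).unsqueeze(-1) * z
    r_tau = torch.clamp(r_tau, v_min, v_max)
    b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
    lb, ub = torch.floor(b), torch.ceil(b)
    # When b is an integer, lb == ub and both weights below would vanish.
    floor_equal_ceil = ((ub - lb) < 0.5).float()
    l_project = F.one_hot(lb.long(), num_atoms).float()  # [B, A, A]
    u_project = F.one_hot(ub.long(), num_atoms).float()  # [B, A, A]
    ml_delta = probs_tp1 * (ub - b + floor_equal_ceil)
    mu_delta = probs_tp1 * (b - lb)
    ml = torch.sum(l_project * ml_delta.unsqueeze(-1), dim=1)
    mu = torch.sum(u_project * mu_delta.unsqueeze(-1), dim=1)
    return ml + mu

if __name__ == "__main__":
    B, A = 2, 51
    probs_tp1 = torch.softmax(torch.randn(B, A), dim=-1)
    m = project_distribution(torch.tensor([1.0, -0.5]),
                             torch.tensor([0.0, 1.0]), probs_tp1)
    print(m.shape, m.sum(dim=-1))  # each row sums to ~1.0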
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True),
        [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)

        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target Q-network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute RHS of the Bellman equation.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked
    ).detach()

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = torch.mean(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -torch.mean(model.log_alpha *
                                 (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
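# Illustration only: a small, standalone sketch of the entropy-regularized
# ("soft") Bellman target that the SAC critic loss above regresses towards:
# y = r + gamma^n * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s')).
# All tensors and hyperparameter values below are made-up toy values.
import torch

if __name__ == "__main__":
    gamma, n_step, alpha = 0.99, 1, 0.2
    rewards = torch.tensor([1.0, 0.0, 0.5])
    dones = torch.tensor([0.0, 1.0, 0.0])
    # Target twin-Q estimates and log-prob of the next sampled action.
    q1_tp1 = torch.tensor([2.0, 1.5, 0.3])
    q2_tp1 = torch.tensor([1.8, 1.7, 0.4])
    log_pis_tp1 = torch.tensor([-1.2, -0.3, -2.0])

    q_tp1 = torch.min(q1_tp1, q2_tp1) - alpha * log_pis_tp1
    q_t_selected_target = (rewards +
                           gamma**n_step * (1.0 - dones) * q_tp1).detach()
    print(q_t_selected_target)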
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
                           train_batch: SampleBatch) -> TensorType:
    target_model = policy.target_models[model]

    twin_q = policy.config["twin_q"]
    gamma = policy.config["gamma"]
    n_step = policy.config["n_step"]
    use_huber = policy.config["use_huber"]
    huber_threshold = policy.config["huber_threshold"]
    l2_reg = policy.config["l2_reg"]

    input_dict = SampleBatch(
        obs=train_batch[SampleBatch.CUR_OBS], _is_training=True)
    input_dict_next = SampleBatch(
        obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True)

    model_out_t, _ = model(input_dict, [], None)
    model_out_tp1, _ = model(input_dict_next, [], None)
    target_model_out_tp1, _ = target_model(input_dict_next, [], None)

    # Policy network evaluation.
    # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
    policy_t = model.get_policy_output(model_out_t)
    # policy_batchnorm_update_ops = list(
    #     set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

    policy_tp1 = target_model.get_policy_output(target_model_out_tp1)

    # Action outputs.
    if policy.config["smooth_target_policy"]:
        target_noise_clip = policy.config["target_noise_clip"]
        clipped_normal_sample = torch.clamp(
            torch.normal(
                mean=torch.zeros(policy_tp1.size()),
                std=policy.config["target_noise"]).to(policy_tp1.device),
            -target_noise_clip,
            target_noise_clip,
        )

        policy_tp1_smoothed = torch.min(
            torch.max(
                policy_tp1 + clipped_normal_sample,
                torch.tensor(
                    policy.action_space.low,
                    dtype=torch.float32,
                    device=policy_tp1.device,
                ),
            ),
            torch.tensor(
                policy.action_space.high,
                dtype=torch.float32,
                device=policy_tp1.device),
        )
    else:
        # No smoothing, just use deterministic actions.
        policy_tp1_smoothed = policy_tp1

    # Q-net(s) evaluation.
    # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Q-values for given actions & observations in given current state.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

    # Q-values for current policy (no noise) in given current state.
    q_t_det_policy = model.get_q_values(model_out_t, policy_t)

    actor_loss = -torch.mean(q_t_det_policy)

    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
    # q_batchnorm_update_ops = list(
    #     set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

    # Target q-net(s) evaluation.
    q_tp1 = target_model.get_q_values(target_model_out_tp1,
                                      policy_tp1_smoothed)

    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1,
                                                    policy_tp1_smoothed)

    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = \
        (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute RHS of Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           gamma**n_step * q_tp1_best_masked).detach()

    # Compute the error (potentially clipped).
    if twin_q:
        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold) + huber_loss(
                twin_td_error, huber_threshold)
        else:
            errors = 0.5 * (torch.pow(td_error, 2.0) +
                            torch.pow(twin_td_error, 2.0))
    else:
        td_error = q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold)
        else:
            errors = 0.5 * torch.pow(td_error, 2.0)

    critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors)

    # Add l2-regularization if required.
    if l2_reg is not None:
        for name, var in model.policy_variables(as_dict=True).items():
            if "bias" not in name:
                actor_loss += l2_reg * l2_loss(var)
        for name, var in model.q_variables(as_dict=True).items():
            if "bias" not in name:
                critic_loss += l2_reg * l2_loss(var)

    # Model self-supervised losses.
    if policy.config["use_state_preprocessor"]:
        # Expand input_dict in case custom_loss' need them.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        [actor_loss, critic_loss] = model.custom_loss(
            [actor_loss, critic_loss], input_dict)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return two loss terms (corresponding to the two optimizers we create).
    return actor_loss, critic_loss
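# Illustration only: the target-policy smoothing used in the
# `smooth_target_policy` branch above (TD3-style), rewritten as a standalone
# helper. The noise/clip settings and action bounds below are made up.
import torch

def smooth_target_actions(policy_tp1, low, high,
                          target_noise=0.2, target_noise_clip=0.5):
    """Adds clipped Gaussian noise to target actions, then clips to bounds."""
    noise = torch.clamp(
        torch.normal(mean=torch.zeros_like(policy_tp1), std=target_noise),
        -target_noise_clip, target_noise_clip)
    low_t = torch.tensor(low, dtype=torch.float32)
    high_t = torch.tensor(high, dtype=torch.float32)
    return torch.min(torch.max(policy_tp1 + noise, low_t), high_t)

if __name__ == "__main__":
    actions = torch.tensor([[0.9, -0.9], [0.1, 0.4]])
    print(smooth_target_actions(actions, low=[-1.0, -1.0], high=[1.0, 1.0]))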
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    # Note: Select from the *target* model's state outputs here (the original
    # passed `state_in_tp1`, which belongs to the online model).
    target_states_in_tp1 = target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)

        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                states_in_t["twin_q"],
                seq_lens,
                train_batch[SampleBatch.ACTIONS],
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, states_in_t["q"],
                                               seq_lens, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target Q-network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1,
                target_states_in_tp1["twin_q"],
                seq_lens,
                policy_tp1,
            )
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute RHS of the Bellman equation.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked
    ).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))

    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(
            alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(
        td_error.reshape([-1, T]), dim=-1)

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
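# Illustration only: a standalone sketch of the sequence/burn-in masking and
# the "mean over valid time steps" reduction used by the recurrent SAC loss
# above. `sequence_mask` is re-implemented locally just for this example; the
# toy sequence lengths and burn-in value are made up.
import torch

def sequence_mask(seq_lens, maxlen):
    """Boolean mask of shape [B, maxlen], True for valid time steps."""
    return torch.arange(maxlen)[None, :] < seq_lens[:, None]

if __name__ == "__main__":
    seq_lens = torch.tensor([5, 3])
    B, T, burn_in = 2, 5, 2
    seq_mask = sequence_mask(seq_lens, T)
    # Mask away the burn-in steps at the start of each sequence as well.
    if 0 < burn_in < T:
        seq_mask[:, :burn_in] = False
    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    td_error = torch.arange(B * T, dtype=torch.float32)
    print(seq_mask, reduce_mean_valid(td_error))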
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=q.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=q_target.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat([
            q_target_best[1:],
            torch.tensor([0.0], device=q_target_best.device)
        ])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"]**config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        total_loss = reduce_mean_valid(weights * huber_loss(td_error))

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_q"] = reduce_mean_valid(q_selected)
        model.tower_stats["min_q"] = torch.min(q_selected)
        model.tower_stats["max_q"] = torch.max(q_selected)
        model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error)
        # Store per time chunk (b/c we need only one mean
        # prioritized replay weight per stored sequence).
        model.tower_stats["td_error"] = torch.mean(td_error, dim=-1)

    return total_loss
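# Illustration only: the invertible value-rescaling pair that the
# `use_h_function` branch above relies on. These are the commonly used
# definitions from Pohlen et al. (2018), given here as an assumption about
# what `h_function` / `h_inverse` compute; the actual helpers live elsewhere
# in the codebase.
import torch

def h_function(x, epsilon=1.0e-2):
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x."""
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x, epsilon=1.0e-2):
    """Inverse of h, so that h_inverse(h_function(x)) == x."""
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon *
                     (torch.abs(x) + 1.0 + epsilon)) - 1.0) /
         (2.0 * epsilon))**2 - 1.0)

if __name__ == "__main__":
    x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
    print(torch.allclose(h_inverse(h_function(x)), x, atol=1e-4))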
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        The user-choice- and Q-value loss tensors.
    """
    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates

    # Q-value computations.
    # ---------------------
    # action.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]

    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=actions.device
    )
    # user.shape: [B, E]
    user_obs = observation["user"]
    batch_size, embedding_size = user_obs.shape
    # doc.shape: [B, C, E]
    doc_obs = list(observation["doc"].values())

    A, S = policy.slates.shape

    # click_indicator.shape: [B, S]
    click_indicator = torch.stack(
        [k["click"] for k in observation["response"]], 1
    ).float()
    # item_reward.shape: [B, S]
    item_reward = torch.stack(
        [k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1)

    # Target computations.
    # --------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=actions.device
    )

    # user.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc.shape: [B, C, E]
    doc_next_obs = list(next_obs["doc"].values())

    # Only compute the watch time reward of the clicked item.
    reward = torch.sum(item_reward * click_indicator, dim=1)

    # TODO: Find out, whether it's correct here to use obs, not next_obs!
    # Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(
    #     user_next_obs, doc_next_obs)
    next_q_values = policy.target_models[model].get_q_values(
        user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    indices = policy.slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(
        next_q_values, indices, dim=1).reshape([-1, A, S])
    # scores_slate.shape: [B, A, S]
    scores_slate = torch.take_along_dim(
        scores, indices, dim=1).reshape([-1, A, S])
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = torch.reshape(
        torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1]
    )

    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = torch.sum(
        next_q_values_slate * scores_slate, dim=2) / (
            torch.sum(scores_slate, dim=2) + score_no_click_slate)
    next_q_target_max, _ = torch.max(next_q_target_slate, dim=1)

    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - train_batch["dones"].float()
    )
    target = target.detach()

    clicked = torch.sum(click_indicator, dim=1)
    mask_clicked_slates = clicked > 0
    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices,
                                          mask_clicked_slates)
    # Clicked_indices is a vector and torch.gather selects the batch
    # dimension.
    q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
    target_clicked = torch.gather(target, 0, clicked_indices)

    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    td_error = torch.abs(td_error)
    mean_td_error = torch.mean(td_error)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_values"] = torch.mean(q_values)
    model.tower_stats["q_clicked"] = torch.mean(q_clicked)
    model.tower_stats["scores"] = torch.mean(scores)
    model.tower_stats["score_no_click"] = torch.mean(score_no_click)
    model.tower_stats["slate_q_values"] = torch.mean(slate_q_values)
    model.tower_stats["replay_click_q"] = torch.mean(replay_click_q)
    model.tower_stats["bellman_reward"] = torch.mean(reward)
    model.tower_stats["next_q_values"] = torch.mean(next_q_values)
    model.tower_stats["target"] = torch.mean(target)
    model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate)
    model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max)
    model.tower_stats["target_clicked"] = torch.mean(target_clicked)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        torch.stack(doc_obs, 1),
        1,
        # index.shape: [batch_size, slate_size, embedding_size]
        actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user_obs, selected_doc)
    # click_indicator.shape: [batch_size, slate_size]
    # no_clicks.shape: [batch_size, 1]
    no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([click_indicator, no_clicks], dim=1)
    choice_loss = nn.functional.cross_entropy(
        scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())
    model.tower_stats["choice_loss"] = choice_loss

    return choice_loss, loss
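# Illustration only: a toy example of how `torch.take_along_dim` gathers the
# per-document Q-values of a chosen slate from the per-candidate Q-values, as
# done for `slate_q_values` above. Batch size, candidate count, and slate
# size are made up.
import torch

if __name__ == "__main__":
    # q_values.shape: [B=2, C=5] -> Q-value per candidate document.
    q_values = torch.tensor([[0.1, 0.4, 0.3, 0.9, 0.2],
                             [0.7, 0.2, 0.5, 0.1, 0.6]])
    # actions.shape: [B=2, S=3] -> indices of the documents in each slate.
    actions = torch.tensor([[3, 1, 0],
                            [0, 4, 2]])
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions, dim=-1)
    print(slate_q_values)  # [[0.9, 0.4, 0.1], [0.7, 0.6, 0.5]]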
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        Tuple consisting of 1) the choice loss- and 2) the Q-value loss
        tensors.
    """
    start = time.time()
    obs = restore_original_dimensions(
        train_batch[SampleBatch.OBS],
        policy.observation_space,
        tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # action.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(
        train_batch[SampleBatch.NEXT_OBS],
        policy.observation_space,
        tensorlib=torch)

    # Step 1: Build user choice model loss.
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    # Step 2: Build Q-value loss.
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    # Myopic agent: Doesn't care about the value of the next state.
    # Acts only based off the immediate reward.
    if learning_strategy == "MYOP":
        next_q_values = torch.tensor(0.0, requires_grad=False)
    # Q-learning: Default setting for SlateQ -> Use DQN-style loss function.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            if policy.config["double_q"]:
                next_target_per_slate_q_values = policy.target_models[
                    model].get_per_slate_q_values(next_user, next_doc)
                _, next_q_values, _ = model.choose_slate(
                    next_user, next_doc, next_target_per_slate_q_values)
            else:
                _, next_q_values, _ = policy.target_models[
                    model].choose_slate(next_user, next_doc)
        next_q_values = next_q_values.detach()
        next_q_values[dones.bool()] = 0.0
    # SARS'A': Use on-policy SARSA loss.
    elif learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # next_selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(
                -1, -1, embedding_size).long(),
        )
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(q_values * scores, dim=1) / torch.sum(
                scores, dim=1)
            next_q_values[dones.bool()] = 0.0
    else:
        raise ValueError(learning_strategy)

    # target_q_values.shape: [batch_size]
    target_q_values = (train_batch[SampleBatch.REWARDS] +
                       policy.config["gamma"] * next_q_values)

    # q_values.shape: [batch_size, slate_size+1].
    q_values = model.q_model(user, selected_doc)
    # raw_scores.shape: [batch_size, slate_size+1].
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    # q_values.shape: [batch_size]
    q_values = torch.sum(q_values * scores, dim=1) / torch.sum(scores, dim=1)

    td_error = torch.abs(q_values - target_q_values)
    q_value_loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_loss"] = q_value_loss
    model.tower_stats["q_values"] = q_values
    model.tower_stats["next_q_values"] = next_q_values
    model.tower_stats["next_q_minus_q"] = next_q_values - q_values
    model.tower_stats["td_error"] = td_error
    model.tower_stats["target_q_values"] = target_q_values
    model.tower_stats["scores"] = scores
    model.tower_stats["raw_scores"] = raw_scores
    model.tower_stats["choice_loss"] = choice_loss
    model.tower_stats["choice_beta"] = model.choice_model.beta
    model.tower_stats["choice_score_no_click"] = \
        model.choice_model.score_no_click

    logger.debug(f"loss calculation took {time.time() - start}s")
    return choice_loss, q_value_loss
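# Illustration only: the score-weighted aggregation of per-item Q-values into
# a single slate value, as used both in the SARSA branch and in the final
# Q-value computation above (softmax-style weights from the choice-model
# scores, with max-subtraction for numerical stability). Toy values only.
import torch

def slate_value(q_values, raw_scores):
    """q_values, raw_scores: [batch_size, slate_size+1] -> [batch_size]."""
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    return torch.sum(q_values * scores, dim=1) / torch.sum(scores, dim=1)

if __name__ == "__main__":
    q = torch.tensor([[1.0, 2.0, 0.0], [0.5, 0.5, 0.5]])
    raw = torch.tensor([[0.1, 2.0, -1.0], [0.0, 0.0, 0.0]])
    print(slate_value(q, raw))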