Example #1
def build_q_losses(policy: Policy, model, dist_class,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for SimpleQTorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]

    # q network evaluation
    q_t = compute_q_values(policy,
                           model,
                           train_batch[SampleBatch.CUR_OBS],
                           explore=False,
                           is_training=True)

    # target q network evaluation
    q_tp1 = compute_q_values(
        policy,
        target_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True,
    )

    # q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                  policy.action_space.n)
    q_t_selected = torch.sum(q_t * one_hot_selection, 1)

    # compute estimate of best possible value starting from state at t + 1
    dones = train_batch[SampleBatch.DONES].float()
    q_tp1_best_one_hot_selection = F.one_hot(torch.argmax(q_tp1, 1),
                                             policy.action_space.n)
    q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           policy.config["gamma"] * q_tp1_best_masked)

    # Compute the error (Square/Huber).
    td_error = q_t_selected - q_t_selected_target.detach()
    loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["loss"] = loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    return loss
Example #2
def get_train_op():
    td_error = q_clicked - target_clicked
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    return loss, torch.mean(torch.abs(td_error))
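
Note: Both branches above choose between RLlib's huber_loss helper and a plain squared error. For reference, a minimal sketch of the standard Huber loss such a helper computes elementwise (huber_loss_sketch is our name for illustration, not an RLlib symbol):

import torch

def huber_loss_sketch(x: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    # Quadratic for |x| < delta, linear (slope delta) beyond, so large
    # TD-errors grow more slowly than under a pure squared loss.
    return torch.where(
        torch.abs(x) < delta,
        0.5 * torch.pow(x, 2.0),
        delta * (torch.abs(x) - 0.5 * delta),
    )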
Example #3
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Compute loss for SimpleQ.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            The SimpleQ loss tensor given the input batch.
        """
        target_model = self.target_models[model]

        # q network evaluation
        q_t = self._compute_q_values(model,
                                     train_batch[SampleBatch.CUR_OBS],
                                     is_training=True)

        # target q network evaluation
        q_tp1 = self._compute_q_values(
            target_model,
            train_batch[SampleBatch.NEXT_OBS],
            is_training=True,
        )

        # q scores for actions which we know were selected in the given state.
        one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                                      self.action_space.n)
        q_t_selected = torch.sum(q_t * one_hot_selection, 1)

        # compute estimate of best possible value starting from state at t + 1
        dones = train_batch[SampleBatch.DONES].float()
        q_tp1_best_one_hot_selection = F.one_hot(torch.argmax(q_tp1, 1),
                                                 self.action_space.n)
        q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        q_tp1_best_masked = (1.0 - dones) * q_tp1_best

        # compute RHS of bellman equation
        q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                               self.config["gamma"] * q_tp1_best_masked)

        # Compute the error (Square/Huber).
        td_error = q_t_selected - q_t_selected_target.detach()
        loss = torch.mean(huber_loss(td_error))

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["loss"] = loss
        # TD-error tensor in final stats
        # will be concatenated and retrieved for each individual batch item.
        model.tower_stats["td_error"] = td_error

        return loss
Example #4
    def __init__(
        self,
        q_t_selected: TensorType,
        q_logits_t_selected: TensorType,
        q_tp1_best: TensorType,
        q_probs_tp1_best: TensorType,
        importance_weights: TensorType,
        rewards: TensorType,
        done_mask: TensorType,
        gamma=0.99,
        n_step=1,
        num_atoms=1,
        v_min=-10.0,
        v_max=10.0,
    ):

        if num_atoms > 1:
            # Distributional Q-learning which corresponds to an entropy loss
            z = torch.arange(0, num_atoms,
                             dtype=torch.float32).to(rewards.device)
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            r_tau = torch.unsqueeze(
                rewards, -1) + gamma**n_step * torch.unsqueeze(
                    1.0 - done_mask, -1) * torch.unsqueeze(z, 0)
            r_tau = torch.clamp(r_tau, v_min, v_max)
            b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
            lb = torch.floor(b)
            ub = torch.ceil(b)

            # Important correction that most implementations miss: when b
            # happens to be an integer, lb == ub, so pr_j(s', a*) would be
            # discarded because (ub - b) == (b - lb) == 0.
            floor_equal_ceil = ((ub - lb) < 0.5).float()

            # (batch_size, num_atoms, num_atoms)
            l_project = F.one_hot(lb.long(), num_atoms)
            # (batch_size, num_atoms, num_atoms)
            u_project = F.one_hot(ub.long(), num_atoms)
            ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
            mu_delta = q_probs_tp1_best * (b - lb)
            ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1),
                                 dim=1)
            mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1),
                                 dim=1)
            m = ml_delta + mu_delta

            # Rainbow paper claims that using this cross entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`
            self.td_error = softmax_cross_entropy_with_logits(
                logits=q_logits_t_selected, labels=m.detach())
            self.loss = torch.mean(self.td_error * importance_weights)
            self.stats = {
                # TODO: better Q stats for dist dqn
            }
        else:
            q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

            # compute RHS of bellman equation
            q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

            # compute the error (potentially clipped)
            self.td_error = q_t_selected - q_t_selected_target.detach()
            self.loss = torch.mean(importance_weights.float() *
                                   huber_loss(self.td_error))
            self.stats = {
                "mean_q": torch.mean(q_t_selected),
                "min_q": torch.min(q_t_selected),
                "max_q": torch.max(q_t_selected),
            }
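
Note: To make the categorical projection above concrete, here is a small, self-contained numeric walk-through (ours, not part of the source) showing how the Bellman-shifted support r_tau is mapped to fractional atom indices b and split between lb and ub:

import torch

# Toy settings: 5 atoms spanning [-10, 10], one transition with reward 1.0.
num_atoms, v_min, v_max, gamma, n_step = 5, -10.0, 10.0, 0.99, 1
z = v_min + torch.arange(num_atoms, dtype=torch.float32) * (
    (v_max - v_min) / float(num_atoms - 1))
# z = [-10., -5., 0., 5., 10.]
rewards = torch.tensor([1.0])
done_mask = torch.tensor([0.0])
r_tau = torch.clamp(
    rewards.unsqueeze(-1)
    + gamma ** n_step * (1.0 - done_mask).unsqueeze(-1) * z.unsqueeze(0),
    v_min, v_max)
b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
lb, ub = torch.floor(b), torch.ceil(b)
# b  = [[0.22, 1.21, 2.20, 3.19, 4.00]]
# lb = [[0., 1., 2., 3., 4.]]; ub = [[1., 2., 3., 4., 4.]]
# The last atom has b == 4.0 exactly, so lb == ub there: precisely the case
# the floor_equal_ceil correction in the loss above accounts for.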
Example #5
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True),
        [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = torch.mean(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -torch.mean(model.log_alpha *
                                 (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
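
Note: As a condensed orientation aid (our simplification, not RLlib code), the three returned loss terms have the following shape in the continuous, single-Q case, using a squared critic error instead of the Huber loss above:

import torch

def sac_loss_terms(q_t_selected, q_t_det_policy, q_t_selected_target,
                   log_pis_t, log_alpha, target_entropy):
    alpha = torch.exp(log_alpha)
    # Critic: regress Q(s, a) onto the detached entropy-regularized target.
    critic_loss = torch.mean(
        torch.pow(q_t_selected - q_t_selected_target.detach(), 2.0))
    # Actor: minimize E[alpha * log pi(a|s) - Q(s, pi(s))].
    actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)
    # Temperature: drive the policy's entropy toward target_entropy.
    alpha_loss = -torch.mean(log_alpha * (log_pis_t + target_entropy).detach())
    return actor_loss, critic_loss, alpha_loss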
Example #6
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
                           train_batch: SampleBatch) -> TensorType:

    target_model = policy.target_models[model]

    twin_q = policy.config["twin_q"]
    gamma = policy.config["gamma"]
    n_step = policy.config["n_step"]
    use_huber = policy.config["use_huber"]
    huber_threshold = policy.config["huber_threshold"]
    l2_reg = policy.config["l2_reg"]

    input_dict = SampleBatch(obs=train_batch[SampleBatch.CUR_OBS],
                             _is_training=True)
    input_dict_next = SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS],
                                  _is_training=True)

    model_out_t, _ = model(input_dict, [], None)
    model_out_tp1, _ = model(input_dict_next, [], None)
    target_model_out_tp1, _ = target_model(input_dict_next, [], None)

    # Policy network evaluation.
    # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
    policy_t = model.get_policy_output(model_out_t)
    # policy_batchnorm_update_ops = list(
    #    set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

    policy_tp1 = target_model.get_policy_output(target_model_out_tp1)

    # Action outputs.
    if policy.config["smooth_target_policy"]:
        target_noise_clip = policy.config["target_noise_clip"]
        clipped_normal_sample = torch.clamp(
            torch.normal(mean=torch.zeros(policy_tp1.size()),
                         std=policy.config["target_noise"]).to(
                             policy_tp1.device),
            -target_noise_clip,
            target_noise_clip,
        )

        policy_tp1_smoothed = torch.min(
            torch.max(
                policy_tp1 + clipped_normal_sample,
                torch.tensor(
                    policy.action_space.low,
                    dtype=torch.float32,
                    device=policy_tp1.device,
                ),
            ),
            torch.tensor(policy.action_space.high,
                         dtype=torch.float32,
                         device=policy_tp1.device),
        )
    else:
        # No smoothing, just use deterministic actions.
        policy_tp1_smoothed = policy_tp1

    # Q-net(s) evaluation.
    # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS))
    # Q-values for given actions & observations in given current state.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

    # Q-values for current policy (no noise) in given current state
    q_t_det_policy = model.get_q_values(model_out_t, policy_t)

    actor_loss = -torch.mean(q_t_det_policy)

    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
    # q_batchnorm_update_ops = list(
    #     set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)

    # Target q-net(s) evaluation.
    q_tp1 = target_model.get_q_values(target_model_out_tp1,
                                      policy_tp1_smoothed)

    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1,
                                                    policy_tp1_smoothed)

    q_t_selected = torch.squeeze(q_t, dim=len(q_t.shape) - 1)
    if twin_q:
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=len(q_t.shape) - 1)
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=len(q_tp1.shape) - 1)
    q_tp1_best_masked = (1.0 -
                         train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute RHS of bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           gamma**n_step * q_tp1_best_masked).detach()

    # Compute the error (potentially clipped).
    if twin_q:
        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold) + huber_loss(
                twin_td_error, huber_threshold)
        else:
            errors = 0.5 * (torch.pow(td_error, 2.0) +
                            torch.pow(twin_td_error, 2.0))
    else:
        td_error = q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold)
        else:
            errors = 0.5 * torch.pow(td_error, 2.0)

    critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors)

    # Add l2-regularization if required.
    if l2_reg is not None:
        for name, var in model.policy_variables(as_dict=True).items():
            if "bias" not in name:
                actor_loss += l2_reg * l2_loss(var)
        for name, var in model.q_variables(as_dict=True).items():
            if "bias" not in name:
                critic_loss += l2_reg * l2_loss(var)

    # Model self-supervised losses.
    if policy.config["use_state_preprocessor"]:
        # Expand input_dict in case custom_loss needs these fields.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        [actor_loss,
         critic_loss] = model.custom_loss([actor_loss, critic_loss],
                                          input_dict)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return two loss terms (corresponding to the two optimizers we create).
    return actor_loss, critic_loss
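
Note: The l2-regularization block relies on an l2_loss helper; a minimal sketch of the conventional definition it presumably follows (an assumption, not the verbatim RLlib helper):

import torch

def l2_loss_sketch(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * sum(x^2) over all elements of a single weight tensor.
    return 0.5 * torch.sum(torch.pow(x, 2.0))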
Example #7
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    target_states_in_tp1 = target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        log_pis_t = F.log_softmax(
            action_dist_inputs_t,
            dim=-1,
        )
        policy_t = torch.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        log_pis_tp1 = F.log_softmax(
            action_dist_inputs_tp1,
            -1,
        )
        policy_tp1 = torch.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t,
                                                  states_in_t["twin_q"],
                                                  seq_lens)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        action_dist_t = action_dist_class(
            action_dist_inputs_t,
            model,
        )
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        action_dist_tp1 = action_dist_class(
            action_dist_inputs_tp1,
            model,
        )
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                states_in_t["twin_q"],
                seq_lens,
                train_batch[SampleBatch.ACTIONS],
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, states_in_t["q"],
                                               seq_lens, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1,
                target_states_in_tp1["twin_q"],
                seq_lens,
                policy_tp1,
            )
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))
    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(td_error.reshape([-1, T]),
                                               dim=-1)

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
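
Note: The burn-in masking above hinges on sequence_mask producing a [B, T] boolean mask of valid time steps. A small sketch of that assumed semantics, and of how zeroing the burn-in columns affects it (our illustration, not RLlib's helper):

import torch

def sequence_mask_sketch(lengths: torch.Tensor, maxlen: int) -> torch.Tensor:
    # mask[i, t] is True for the first lengths[i] steps of sequence i.
    return torch.arange(maxlen)[None, :] < lengths[:, None]

seq_lens = torch.tensor([3, 5])
mask = sequence_mask_sketch(seq_lens, maxlen=5)  # shape [B=2, T=5]
burn_in = 2
mask[:, :burn_in] = False  # burn-in steps only warm up the RNN state
# mask -> [[False, False, True, False, False],
#          [False, False, True, True,  True ]]
# Flattened, this mask selects which (b, t) entries enter reduce_mean_valid.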
Example #8
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    target_model = policy.target_models[model]
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(policy,
                                  model,
                                  train_batch,
                                  state_batches=state_batches,
                                  seq_lens=train_batch.get(
                                      SampleBatch.SEQ_LENS),
                                  explore=False,
                                  is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(policy,
                                         target_model,
                                         train_batch,
                                         state_batches=state_batches,
                                         seq_lens=train_batch.get(
                                             SampleBatch.SEQ_LENS),
                                         explore=False,
                                         is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=q.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=q_target.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat([
            q_target_best[1:],
            torch.tensor([0.0], device=q_target_best.device)
        ])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [r + gamma * Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        total_loss = reduce_mean_valid(weights * huber_loss(td_error))

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_q"] = reduce_mean_valid(q_selected)
        model.tower_stats["min_q"] = torch.min(q_selected)
        model.tower_stats["max_q"] = torch.max(q_selected)
        model.tower_stats["mean_td_error"] = reduce_mean_valid(td_error)
        # Store per time chunk (b/c we need only one mean
        # prioritized replay weight per stored sequence).
        model.tower_stats["td_error"] = torch.mean(td_error, dim=-1)

    return total_loss
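
Note: The use_h_function branch applies the invertible value rescaling from the R2D2 paper, h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x, together with its closed-form inverse. A sketch of what h_function and h_inverse presumably compute (our rendering of the published formulas, not the verbatim RLlib helpers):

import torch

def h_function(x: torch.Tensor, epsilon: float = 1.0e-2) -> torch.Tensor:
    # Squashes large returns while remaining invertible.
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x: torch.Tensor, epsilon: float = 1.0e-2) -> torch.Tensor:
    # Closed-form inverse: h_inverse(h_function(v)) recovers v.
    return torch.sign(x) * (
        torch.pow(
            (torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon))
             - 1.0) / (2.0 * epsilon),
            2.0,
        ) - 1.0)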
Example #9
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _,
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        The user-choice- and Q-value loss tensors.
    """

    # B=batch size
    # S=slate size
    # C=num candidates
    # E=embedding size
    # A=number of all possible slates

    # Q-value computations.
    # ---------------------
    # action.shape: [B, S]
    actions = train_batch[SampleBatch.ACTIONS]

    observation = convert_to_torch_tensor(
        train_batch[SampleBatch.OBS], device=actions.device
    )
    # user.shape: [B, E]
    user_obs = observation["user"]
    batch_size, embedding_size = user_obs.shape
    # doc.shape: [B, C, E]
    doc_obs = list(observation["doc"].values())

    A, S = policy.slates.shape

    # click_indicator.shape: [B, S]
    click_indicator = torch.stack(
        [k["click"] for k in observation["response"]], 1
    ).float()
    # item_reward.shape: [B, S]
    item_reward = torch.stack([k["watch_time"] for k in observation["response"]], 1)
    # q_values.shape: [B, C]
    q_values = model.get_q_values(user_obs, doc_obs)
    # slate_q_values.shape: [B, S]
    slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1)
    # Only get the Q from the clicked document.
    # replay_click_q.shape: [B]
    replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1)

    # Target computations.
    # --------------------
    next_obs = convert_to_torch_tensor(
        train_batch[SampleBatch.NEXT_OBS], device=actions.device
    )

    # user.shape: [B, E]
    user_next_obs = next_obs["user"]
    # doc.shape: [B, C, E]
    doc_next_obs = list(next_obs["doc"].values())
    # Only compute the watch time reward of the clicked item.
    reward = torch.sum(item_reward * click_indicator, dim=1)

    # TODO: Find out whether it's correct here to use obs, not next_obs!
    # Dopamine uses obs, then next_obs only for the score.
    # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs)
    next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs)
    scores, score_no_click = score_documents(user_next_obs, doc_next_obs)

    # next_q_values_slate.shape: [B, A, S]
    indices = policy.slates_indices.to(next_q_values.device)
    next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape(
        [-1, A, S]
    )
    # scores_slate.shape [B, A, S]
    scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S])
    # score_no_click_slate.shape: [B, A]
    score_no_click_slate = torch.reshape(
        torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1]
    )

    # next_q_target_slate.shape: [B, A]
    next_q_target_slate = torch.sum(next_q_values_slate * scores_slate, dim=2) / (
        torch.sum(scores_slate, dim=2) + score_no_click_slate
    )
    next_q_target_max, _ = torch.max(next_q_target_slate, dim=1)

    target = reward + policy.config["gamma"] * next_q_target_max * (
        1.0 - train_batch["dones"].float()
    )
    target = target.detach()

    clicked = torch.sum(click_indicator, dim=1)
    mask_clicked_slates = clicked > 0
    clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device)
    clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates)
    # Clicked_indices is a vector and torch.gather selects the batch dimension.
    q_clicked = torch.gather(replay_click_q, 0, clicked_indices)
    target_clicked = torch.gather(target, 0, clicked_indices)

    td_error = torch.where(
        clicked.bool(),
        replay_click_q - target,
        torch.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = torch.pow(td_error, 2.0)
    loss = torch.mean(loss)
    td_error = torch.abs(td_error)
    mean_td_error = torch.mean(td_error)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_values"] = torch.mean(q_values)
    model.tower_stats["q_clicked"] = torch.mean(q_clicked)
    model.tower_stats["scores"] = torch.mean(scores)
    model.tower_stats["score_no_click"] = torch.mean(score_no_click)
    model.tower_stats["slate_q_values"] = torch.mean(slate_q_values)
    model.tower_stats["replay_click_q"] = torch.mean(replay_click_q)
    model.tower_stats["bellman_reward"] = torch.mean(reward)
    model.tower_stats["next_q_values"] = torch.mean(next_q_values)
    model.tower_stats["target"] = torch.mean(target)
    model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate)
    model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max)
    model.tower_stats["target_clicked"] = torch.mean(target_clicked)
    model.tower_stats["q_loss"] = loss
    model.tower_stats["td_error"] = td_error
    model.tower_stats["mean_td_error"] = mean_td_error
    model.tower_stats["mean_actions"] = torch.mean(actions.float())

    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        torch.stack(doc_obs, 1),
        1,
        # index.shape: [batch_size, slate_size, embedding_size]
        actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user_obs, selected_doc)

    # click_indicator.shape: [batch_size, slate_size]
    # no_clicks.shape: [batch_size, 1]
    no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([click_indicator, no_clicks], dim=1)
    choice_loss = nn.functional.cross_entropy(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    model.tower_stats["choice_loss"] = choice_loss

    return choice_loss, loss
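
Note: The target computation above scores each candidate slate by a choice-model-weighted average of its per-item Q-values, normalized by the total item score plus the no-click score. A toy numeric sketch of that formula for a single slate (ours, not part of the source):

import torch

q_items = torch.tensor([[1.0, 2.0, 0.5]])   # per-item Q-values, shape [B=1, S=3]
scores = torch.tensor([[0.2, 0.5, 0.3]])    # choice-model scores v(s, d_i)
score_no_click = torch.tensor([0.4])
slate_q = torch.sum(q_items * scores, dim=1) / (
    torch.sum(scores, dim=1) + score_no_click)
# slate_q = (0.2 + 1.0 + 0.15) / (1.0 + 0.4) ≈ 0.9643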
Example #10
def build_slateq_losses(
    policy: Policy,
    model: ModelV2,
    _: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> TensorType:
    """Constructs the choice- and Q-value losses for the SlateQTorchPolicy.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        train_batch: The training data.

    Returns:
        Tuple consisting of 1) the choice loss and 2) the Q-value loss tensors.
    """
    start = time.time()
    obs = restore_original_dimensions(train_batch[SampleBatch.OBS],
                                      policy.observation_space,
                                      tensorlib=torch)
    # user.shape: [batch_size, embedding_size]
    user = obs["user"]
    # doc.shape: [batch_size, num_docs, embedding_size]
    doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1)
    # action.shape: [batch_size, slate_size]
    actions = train_batch[SampleBatch.ACTIONS]

    next_obs = restore_original_dimensions(train_batch[SampleBatch.NEXT_OBS],
                                           policy.observation_space,
                                           tensorlib=torch)

    # Step 1: Build user choice model loss
    _, _, embedding_size = doc.shape
    # selected_doc.shape: [batch_size, slate_size, embedding_size]
    selected_doc = torch.gather(
        # input.shape: [batch_size, num_docs, embedding_size]
        input=doc,
        dim=1,
        # index.shape: [batch_size, slate_size, embedding_size]
        index=actions.unsqueeze(2).expand(-1, -1, embedding_size).long(),
    )

    scores = model.choice_model(user, selected_doc)
    choice_loss_fn = nn.CrossEntropyLoss()

    # clicks.shape: [batch_size, slate_size]
    clicks = torch.stack(
        [resp["click"][:, 1] for resp in next_obs["response"]], dim=1)
    no_clicks = 1 - torch.sum(clicks, 1, keepdim=True)
    # targets.shape: [batch_size, slate_size+1]
    targets = torch.cat([clicks, no_clicks], dim=1)
    choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1))
    # print(model.choice_model.a.item(), model.choice_model.b.item())

    # Step 2: Build qvalue loss
    # Fields available in train_batch: ['t', 'eps_id', 'agent_index',
    # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions',
    # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights',
    # 'batch_indexes']
    learning_strategy = policy.config["slateq_strategy"]

    # Myopic agent: Don't care about value of next state.
    # Acts only based off immediate reward.
    if learning_strategy == "MYOP":
        next_q_values = torch.tensor(0.0, requires_grad=False)
    # Q-learning: Default setting for SlateQ -> Use DQN-style loss function.
    elif learning_strategy == "QL":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            if policy.config["double_q"]:
                next_target_per_slate_q_values = policy.target_models[
                    model].get_per_slate_q_values(next_user, next_doc)
                _, next_q_values, _ = model.choose_slate(
                    next_user, next_doc, next_target_per_slate_q_values)
            else:
                _, next_q_values, _ = policy.target_models[model].choose_slate(
                    next_user, next_doc)
        next_q_values = next_q_values.detach()
        next_q_values[dones.bool()] = 0.0
    # SARS'A': Use on-policy sarsa loss.
    elif learning_strategy == "SARSA":
        # next_doc.shape: [batch_size, num_docs, embedding_size]
        next_doc = torch.cat(
            [val.unsqueeze(1) for val in next_obs["doc"].values()], 1)
        next_actions = train_batch["next_actions"]
        _, _, embedding_size = next_doc.shape
        # selected_doc.shape: [batch_size, slate_size, embedding_size]
        next_selected_doc = torch.gather(
            # input.shape: [batch_size, num_docs, embedding_size]
            input=next_doc,
            dim=1,
            # index.shape: [batch_size, slate_size, embedding_size]
            index=next_actions.unsqueeze(2).expand(-1, -1,
                                                   embedding_size).long(),
        )
        next_user = next_obs["user"]
        dones = train_batch[SampleBatch.DONES]
        with torch.no_grad():
            # q_values.shape: [batch_size, slate_size+1]
            q_values = model.q_model(next_user, next_selected_doc)
            # raw_scores.shape: [batch_size, slate_size+1]
            raw_scores = model.choice_model(next_user, next_selected_doc)
            max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
            scores = torch.exp(raw_scores - max_raw_scores)
            # next_q_values.shape: [batch_size]
            next_q_values = torch.sum(q_values * scores, dim=1) / torch.sum(
                scores, dim=1)
            next_q_values[dones.bool()] = 0.0
    else:
        raise ValueError(learning_strategy)
    # target_q_values.shape: [batch_size]
    target_q_values = (train_batch[SampleBatch.REWARDS] +
                       policy.config["gamma"] * next_q_values)

    # q_values.shape: [batch_size, slate_size+1].
    q_values = model.q_model(user, selected_doc)
    # raw_scores.shape: [batch_size, slate_size+1].
    raw_scores = model.choice_model(user, selected_doc)
    max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True)
    scores = torch.exp(raw_scores - max_raw_scores)
    q_values = torch.sum(q_values * scores, dim=1) / torch.sum(
        scores, dim=1)  # shape=[batch_size]
    td_error = torch.abs(q_values - target_q_values)
    q_value_loss = torch.mean(huber_loss(td_error))

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_loss"] = q_value_loss
    model.tower_stats["q_values"] = q_values
    model.tower_stats["next_q_values"] = next_q_values
    model.tower_stats["next_q_minus_q"] = next_q_values - q_values
    model.tower_stats["td_error"] = td_error
    model.tower_stats["target_q_values"] = target_q_values
    model.tower_stats["scores"] = scores
    model.tower_stats["raw_scores"] = raw_scores
    model.tower_stats["choice_loss"] = choice_loss
    model.tower_stats["choice_beta"] = model.choice_model.beta
    model.tower_stats[
        "choice_score_no_click"] = model.choice_model.score_no_click

    logger.debug(f"loss calculation took {time.time()-start}s")
    return choice_loss, q_value_loss