Example #1
def action_distribution_fn(
    policy: Policy,
    model: ModelV2,
    input_dict: ModelInputDict,
    *,
    state_batches: Optional[List[TensorType]] = None,
    seq_lens: Optional[TensorType] = None,
    prev_action_batch: Optional[TensorType] = None,
    prev_reward_batch=None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    is_training: Optional[bool] = None
) -> Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
    """The action distribution function to be used the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call
    will be made on it to generate actions.

    Args:
        policy: The Policy being queried for actions and calling this
            function.
        model (TorchModelV2): The SAC specific model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_action_model_outputs` method.
        input_dict: The input-dict to be used for the model
            call.
        state_batches (Optional[List[TensorType]]): The list of internal state
            tensor batches.
        seq_lens (Optional[TensorType]): The tensor of sequence lengths used
            in RNNs.
        prev_action_batch (Optional[TensorType]): Optional batch of prev
            actions used by the model.
        prev_reward_batch (Optional[TensorType]): Optional batch of prev
            rewards used by the model.
        explore (Optional[bool]): Whether to activate exploration or not. If
            None, use value of `config.explore`.
        timestep (Optional[int]): An optional timestep.
        is_training (Optional[bool]): An optional is-training flag.

    Returns:
        Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
            The dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
    # Get base-model output (w/o the SAC specific parts of the network).
    model_out, _ = model(input_dict, [], None)
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    action_dist_inputs, _ = model.get_action_model_outputs(model_out)
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)

    return action_dist_inputs, action_dist_class, []
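
A note on how the return value is consumed: the caller instantiates the returned distribution class with the returned inputs and then calls `sample()` on it, as the docstring above describes. The following is a minimal, framework-neutral NumPy sketch of that "parameterize, then sample" pattern for a discrete action space (the function name is illustrative, not part of the RLlib API):

import numpy as np

def sample_discrete_actions(dist_inputs, rng=np.random):
    """Treat `dist_inputs` as categorical logits of shape [batch, num_actions]
    and draw one action per batch row (a sketch of what `sample()` does)."""
    # Numerically stable softmax over the action dimension.
    z = dist_inputs - dist_inputs.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    # Draw one action index per row according to its probabilities.
    return np.array([rng.choice(len(p), p=p) for p in probs])

# Example: a batch of two observations and three discrete actions.
actions = sample_discrete_actions(np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]))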
Example #2
def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    obs_batch: TensorType,
    *,
    explore: bool = True,
    **kwargs
) -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]:
    """The action distribution function to be used the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call
    will be made on it to generate actions.

    Args:
        policy (Policy): The Policy being queried for actions and calling this
            function.
        model (SACTFModel): The SAC specific Model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_action_model_outputs` method.
        obs_batch (TensorType): The observations to be used as inputs to the
            model.
        explore (bool): Whether to activate exploration or not.

    Returns:
        Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: The
            dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
    # Get base-model (forward) output (this should be a noop call).
    forward_out, state_out = model(
        SampleBatch(obs=obs_batch,
                    _is_training=policy._get_is_training_placeholder()),
        [],
        None,
    )
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    distribution_inputs, _ = model.get_action_model_outputs(forward_out)
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)

    return distribution_inputs, action_dist_class, state_out
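
For Box (continuous) action spaces, the distribution class that `_get_dist_class` returns for SAC is typically a tanh-squashed Gaussian, so the `distribution_inputs` computed above are interpreted as concatenated mean and log-std. Below is a hedged NumPy sketch of sampling from such a distribution, including the tanh change-of-variables correction to the log-probability (illustrative only; the real distribution classes live in RLlib's action-distribution modules):

import numpy as np

def squashed_gaussian_sample(dist_inputs, rng=np.random):
    """dist_inputs: [batch, 2 * act_dim], i.e. concatenated mean and log-std."""
    mean, log_std = np.split(dist_inputs, 2, axis=-1)
    # Clamping the log-std is a common stabilization trick in SAC.
    std = np.exp(np.clip(log_std, -20.0, 2.0))
    # Reparameterized Gaussian sample, squashed into (-1, 1) via tanh.
    u = mean + std * rng.standard_normal(mean.shape)
    a = np.tanh(u)
    # Gaussian log-prob of the pre-squash sample ...
    logp = (-0.5 * (((u - mean) / std) ** 2 + 2.0 * np.log(std)
                    + np.log(2.0 * np.pi))).sum(axis=-1)
    # ... corrected for the tanh squashing (change of variables).
    logp -= np.log(1.0 - a ** 2 + 1e-6).sum(axis=-1)
    return a, logp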
Example #3
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action
            distribution class.
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Look up the target model (tower) corresponding to this model tower.
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True),
        [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly; here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = torch.mean(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -torch.mean(model.log_alpha *
                                 (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
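
The `q_t_selected_target` computed above is the soft Bellman backup y = r + gamma^n * (1 - done) * (min(Q1', Q2')(s', a') - alpha * log pi(a'|s')). A small self-contained NumPy sketch of that target for the continuous case (argument names are illustrative):

import numpy as np

def soft_bellman_target(rewards, dones, q_tp1, twin_q_tp1, log_pis_tp1,
                        alpha, gamma, n_step=1):
    """y = r + gamma^n * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))."""
    q_min = np.minimum(q_tp1, twin_q_tp1)     # Clipped double-Q estimate.
    soft_q = q_min - alpha * log_pis_tp1      # Entropy-regularized next value.
    return rewards + (gamma ** n_step) * (1.0 - dones) * soft_q

# Example with a batch of two transitions (second one terminal).
y = soft_bellman_target(
    rewards=np.array([1.0, 0.0]), dones=np.array([0.0, 1.0]),
    q_tp1=np.array([2.0, 3.0]), twin_q_tp1=np.array([1.5, 2.5]),
    log_pis_tp1=np.array([-1.2, -0.8]), alpha=0.2, gamma=0.99)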
Example #4
def sac_actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TFActionDistribution]): The action distribution
            class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    _is_training = policy._get_is_training_placeholder()
    # Get the base model output from the train batch.
    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS],
                    _is_training=_is_training),
        [],
        None,
    )

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS],
                    _is_training=_is_training),
        [],
        None,
    )

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS],
                    _is_training=_is_training),
        [],
        None,
    )

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = tf.nn.log_softmax(action_dist_inputs_t, -1)
        policy_t = tf.math.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = tf.nn.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(train_batch[SampleBatch.ACTIONS],
                             depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, policy.model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1,
                                            policy.model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(
            model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # Target q network evaluation.
        q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1,
                                                    policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    # Compute the RHS of the Bellman equation for the Q-loss (critic(s)).
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly; here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t),
                    -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy),
                ),
                axis=-1,
            ))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t),
                ),
                axis=-1,
            ))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
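
In the discrete branch above, the expectations over actions in the alpha and actor losses are computed in closed form: each per-action term is weighted by the policy's probability for that action and summed, so no sampled actions are needed. A small NumPy sketch of those two losses (names are illustrative; stop-gradients are omitted since NumPy does not track gradients):

import numpy as np

def discrete_sac_actor_alpha_losses(policy_probs, log_pis, q_values,
                                    alpha, log_alpha, target_entropy):
    """All array arguments have shape [batch, num_actions]."""
    # Actor loss: E_{a ~ pi}[alpha * log pi(a|s) - Q(s, a)], with the
    # expectation taken as a probability-weighted sum over all actions.
    actor_loss = np.mean(
        np.sum(policy_probs * (alpha * log_pis - q_values), axis=-1))
    # Alpha loss: E_{a ~ pi}[-log_alpha * (log pi(a|s) + target_entropy)].
    alpha_loss = np.mean(
        np.sum(policy_probs * (-log_alpha * (log_pis + target_entropy)),
               axis=-1))
    return actor_loss, alpha_loss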
Example #5
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
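    """Constructs the loss for Conservative Q-Learning (CQL) on top of SAC.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TFActionDistribution]): The action distribution
            class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """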
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
    rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [],
                             None)

    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = tf.expand_dims(log_pis_t, -1)

    # Unlike in original SAC, the alpha and actor losses are computed first.
    # Alpha Loss
    alpha_loss = -tf.reduce_mean(
        model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy))

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = tf.math.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = tf.math.minimum(min_q, twin_q_)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - min_q)
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - bc_logp)
        # actor_loss = -tf.reduce_mean(bc_logp)

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, actions)
    q_t_selected = tf.squeeze(q_t, axis=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(model_out_t, actions)
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)

    # Target q network evaluation.
    q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1,
                                                policy_tp1)
    if twin_q:
        twin_q_tp1, _ = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
    q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_target = tf.stop_gradient(rewards +
                                  (discount**policy.config["n_step"]) *
                                  q_tp1_best_masked)

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = tf.math.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)

    # CQL loss (we use the entropy version of CQL, the best-performing one).
    rand_actions, _ = policy._random_action_generator.get_exploration_action(
        action_distribution=action_dist_class(
            tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model),
        timestep=0,
        explore=True,
    )
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model,
                                          model_out_t,
                                          curr_actions,
                                          twin=True)
        q2_next_actions = q_values_repeat(model,
                                          model_out_t,
                                          next_actions,
                                          twin=True)

    random_density = np.log(0.5**int(curr_actions.shape[-1]))
    cat_q1 = tf.concat(
        [
            q1_rand - random_density,
            q1_next_actions - tf.stop_gradient(next_logp),
            q1_curr_actions - tf.stop_gradient(curr_logp),
        ],
        1,
    )
    if twin_q:
        cat_q2 = tf.concat(
            [
                q2_rand - random_density,
                q2_next_actions - tf.stop_gradient(next_logp),
                q2_curr_actions - tf.stop_gradient(curr_logp),
            ],
            1,
        )

    min_qf1_loss_ = (
        tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) *
        min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) *
            min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) *
                                        min_q_weight)

    if use_lagrange:
        alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0,
                                       1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return actor_loss + tf.math.add_n(
            critic_loss) + alpha_loss + alpha_prime_loss
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
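
The conservative terms added to the critic losses above implement the CQL(H) regularizer: min_q_weight * temp * E_s[logsumexp_a(Q(s, a) / temp)] - min_q_weight * E[Q(s, a_data)], where the logsumexp runs over the Q-values of random, current-policy, and next-policy actions, each shifted by its sampling log-density, exactly as in `cat_q1`/`cat_q2` above. A hedged NumPy sketch of that regularizer:

import numpy as np

def cql_regularizer(cat_q, q_data, cql_temp, min_q_weight):
    """cat_q: [batch, num_sampled] Q-values of sampled actions, already shifted
    by their sampling log-densities; q_data: Q-values of the dataset actions."""
    # Numerically stable logsumexp over the sampled actions, per batch item.
    z = cat_q / cql_temp
    m = z.max(axis=1, keepdims=True)
    lse = (m + np.log(np.exp(z - m).sum(axis=1, keepdims=True))).squeeze(1)
    # Push Q down on out-of-distribution actions, up on dataset actions.
    return lse.mean() * cql_temp * min_q_weight - q_data.mean() * min_q_weight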
Example #6
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action
            distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    target_states_in_tp1 = target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        log_pis_t = F.log_softmax(
            action_dist_inputs_t,
            dim=-1,
        )
        policy_t = torch.exp(log_pis_t)

        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        log_pis_tp1 = F.log_softmax(
            action_dist_inputs_tp1,
            -1,
        )
        policy_tp1 = torch.exp(log_pis_tp1)

        # Q-values.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t,
                                                  states_in_t["twin_q"],
                                                  seq_lens)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens)
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        action_dist_t = action_dist_class(
            action_dist_inputs_t,
            model,
        )
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        action_dist_tp1 = action_dist_class(
            action_dist_inputs_tp1,
            model,
        )
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                states_in_t["twin_q"],
                seq_lens,
                train_batch[SampleBatch.ACTIONS],
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, states_in_t["q"],
                                               seq_lens, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1,
                target_states_in_tp1["twin_q"],
                seq_lens,
                policy_tp1,
            )
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Also mask out the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))
    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly; here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(td_error.reshape([-1, T]),
                                               dim=-1)

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
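
The burn-in block above builds a [B, T] validity mask from the sequence lengths, additionally zeroes out the first `burn_in` time steps of every sequence, and then averages the loss terms over the remaining valid steps only. A small NumPy sketch of that masking, assuming the loss terms are flattened to [B * T] as in the code above (the helper name is illustrative):

import numpy as np

def masked_mean(values_flat, seq_lens, max_seq_len, burn_in=0):
    """values_flat: [B * T]; seq_lens: [B] actual lengths; max_seq_len: T."""
    # sequence_mask: True for time steps < seq_len of the respective row.
    t = np.arange(max_seq_len)[None, :]
    mask = t < np.asarray(seq_lens)[:, None]          # [B, T], bool.
    if 0 < burn_in < max_seq_len:
        mask[:, :burn_in] = False                     # Drop the burn-in prefix.
    mask = mask.reshape(-1)
    return values_flat[mask].sum() / mask.sum()

# Example: two sequences padded to T=4, lengths 4 and 2, burn-in of 1.
print(masked_mean(np.arange(8.0), seq_lens=[4, 2], max_seq_len=4, burn_in=1))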
Example #7
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
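    """Constructs the loss for Conservative Q-Learning (CQL) on top of SAC.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distribution
            class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """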
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # Look up the target model (tower) corresponding to this model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [],
                             None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike in original SAC, the alpha and actor losses are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    batch_size = tree.flatten(obs)[0].shape[0]
    if batch_size == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if batch_size == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(model_out_t,
                                              train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1, _ = target_model.get_twin_q_values(target_model_out_tp1,
                                                       policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL loss (we use the entropy version of CQL, the best-performing one).
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device,
    )
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(model,
                                          model_out_t,
                                          curr_actions,
                                          twin=True)
        q2_next_actions = q_values_repeat(model,
                                          model_out_t,
                                          next_actions,
                                          twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat(
        [
            q1_rand - random_density,
            q1_next_actions - next_logp.detach(),
            q1_curr_actions - curr_logp.detach(),
        ],
        1,
    )
    if twin_q:
        cat_q2 = torch.cat(
            [
                q2_rand - random_density,
                q2_next_actions - next_logp.detach(),
                q2_curr_actions - curr_logp.detach(),
            ],
            1,
        )

    min_qf1_loss_ = (torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() *
                     min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() *
                         min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(model.log_alpha_prime.exp(),
                                  min=0.0,
                                  max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if batch_size == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()

        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if batch_size == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss] +
                 ([alpha_prime_loss] if use_lagrange else []))
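
When `lagrangian=True`, the penalty weight alpha_prime is itself learned by dual gradient descent: the conservative terms are scaled by alpha_prime, and `alpha_prime_loss` is their negative, so minimizing it raises alpha_prime whenever the conservative gap exceeds `lagrangian_thresh` and lowers it otherwise. A tiny NumPy sketch of one such dual step on log alpha_prime (illustrative only; in the code above, the optimizer in `policy.alpha_prime_optim` performs this update):

import numpy as np

def dual_step_log_alpha_prime(log_alpha_prime, cql_gap, target_action_gap,
                              lr=1e-3):
    """One gradient-descent step on
    alpha_prime_loss = -exp(log_alpha_prime) * (cql_gap - target_action_gap)."""
    alpha_prime = np.exp(log_alpha_prime)
    # d(alpha_prime_loss) / d(log_alpha_prime) = -alpha_prime * (gap - target).
    grad = -alpha_prime * (cql_gap - target_action_gap)
    # If the gap is above the target, log_alpha_prime grows, which strengthens
    # the conservative penalty in the next critic update (and vice versa).
    return log_alpha_prime - lr * grad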