Example #1
def loss_fn(policy: Policy, model: ModelV2,
            dist_class: TorchDistributionWrapper, sample_batch: SampleBatch):
    # Mask out the 0-padded timesteps of each sequence; losses are later
    # normalized by the number of valid timesteps times the number of agents.
    max_seq_len = sample_batch['seq_lens'].max().item()
    mask = sequence_mask(sample_batch['seq_lens'],
                         max_seq_len,
                         time_major=model.is_time_major()).view((-1, 1))
    mean_reg = sample_batch['seq_lens'].sum() * model.nbr_agents
    # Per-agent discrete action indices, shaped (batch, n_agents, 1) for gather().
    actions = sample_batch['actions'].view(
        (sample_batch['actions'].shape[0], model.nbr_agents,
         -1))[:, :, :1].to(torch.long)
    actions = add_time_dimension(actions,
                                 max_seq_len=max_seq_len,
                                 framework='torch',
                                 time_major=True).reshape_as(actions)

    # Actor forward pass: per-agent action logits -> probabilities and log-probs.
    logits_pi, _ = model(sample_batch, [
        sample_batch['state_in_0'],
    ], sample_batch['seq_lens'])
    logits_pi = logits_pi.view((logits_pi.shape[0], model.nbr_agents, -1))
    logits_pi_action = logits_pi[:, :, :model.nbr_actions]
    log_pi_action = nn.functional.log_softmax(logits_pi_action, dim=-1)
    pi_action = torch.exp(log_pi_action)
    # Log-probability of the actions that were actually taken.
    log_pi_action_selected = torch.gather(log_pi_action, -1,
                                          actions).squeeze(-1)

    # Critic Q-values; pick out the Q-value of each taken action and form the
    # TD error against the precomputed value targets.
    q_values = model.q_values(sample_batch, target=False)
    q_values = add_time_dimension(q_values,
                                  max_seq_len=max_seq_len,
                                  framework="torch",
                                  time_major=True).reshape_as(q_values)
    q_values_selected = torch.gather(q_values, -1, actions).squeeze(-1)
    q_values_target = sample_batch[Postprocessing.VALUE_TARGETS]
    q_values_target = add_time_dimension(
        q_values_target,
        max_seq_len=max_seq_len,
        framework="torch",
        time_major=True).reshape_as(q_values_target)
    td_error = q_values_selected - q_values_target

    # COMA-style counterfactual advantage: Q of the taken action minus the
    # policy-weighted average over all actions (no gradient through the baseline).
    with torch.no_grad():
        coma_avg = q_values_selected - (pi_action * q_values).sum(-1)
    # Per-agent policy entropy.
    entropy = -(log_pi_action * pi_action).sum(-1)

    critic_loss = torch.pow(mask * td_error, 2.0)
    actor_loss = mask * coma_avg * log_pi_action_selected
    entropy = mask * entropy

    policy.actor_loss = -actor_loss.sum() / mean_reg
    policy.critic_loss = critic_loss.sum() / mean_reg
    policy.entropy = entropy.sum() / mean_reg

    pi_loss = policy.actor_loss - policy.config[
        'entropy_coeff'] * policy.entropy

    return pi_loss, policy.critic_loss
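
The core of this loss is the counterfactual (COMA-style) advantage: the Q-value of the action each agent actually took, minus the policy-weighted average Q-value over all of that agent's actions. A minimal, self-contained PyTorch sketch of just that computation follows; the tensor names, shapes, and random data are illustrative assumptions, not part of the example above or of RLlib's API.

# Standalone sketch of the counterfactual (COMA-style) advantage and the
# resulting actor objective. All data here is randomly generated.
import torch

torch.manual_seed(0)
n_steps, n_agents, n_actions = 5, 3, 4

# Per-agent action log-probabilities and critic Q-values (dummy tensors).
log_pi = torch.log_softmax(torch.randn(n_steps, n_agents, n_actions), dim=-1)
pi = log_pi.exp()
q_values = torch.randn(n_steps, n_agents, n_actions)
actions = torch.randint(n_actions, (n_steps, n_agents, 1))

# Q-value of the taken action vs. the policy-weighted average over actions.
q_taken = q_values.gather(-1, actions).squeeze(-1)       # (n_steps, n_agents)
baseline = (pi * q_values).sum(-1)                       # (n_steps, n_agents)
coma_advantage = q_taken - baseline

# Actor objective: advantage-weighted log-prob of the taken action,
# with no gradient flowing through the baseline term.
logp_taken = log_pi.gather(-1, actions).squeeze(-1)
actor_loss = -(coma_advantage.detach() * logp_taken).mean()
print(actor_loss)
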
Example #2
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distribution class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        B = len(train_batch["seq_lens"])
        max_seq_len = logits.shape[0] // B
        mask = sequence_mask(train_batch["seq_lens"],
                             max_seq_len,
                             time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] *
        torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                    1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl +
                                       policy.config["vf_loss_coeff"] *
                                       vf_loss -
                                       policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS],
        policy.model.value_function())
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss
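
What distinguishes the RNN branch from the non-RNN branch above is only the reduction: instead of a plain mean, every per-timestep term is averaged over the valid (non-padded) timesteps. A small standalone sketch of that masked mean follows; the mask is built by hand here instead of with RLlib's sequence_mask helper, and all numbers are made up.

# Standalone sketch of the "valid-only" mean used in the RNN branch.
import torch

seq_lens = torch.tensor([3, 1, 2])      # three sequences of different length
B = len(seq_lens)
max_seq_len = int(seq_lens.max())

# mask[b, t] is True for real timesteps, False for the 0-padded tail.
mask = torch.arange(max_seq_len)[None, :] < seq_lens[:, None]   # (B, T)
mask = mask.reshape(-1)                  # flatten, like the loss above
num_valid = torch.sum(mask)

per_timestep_loss = torch.randn(B * max_seq_len)   # padded loss terms

def reduce_mean_valid(t):
    # Average only over the valid (non-padded) entries.
    return torch.sum(t[mask]) / num_valid

print(reduce_mean_valid(per_timestep_loss))
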
Example #3
    def loss(self, model: ModelV2, dist_class: Type[ActionDistribution],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for Proximal Policy Objective.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distribution class.
            train_batch: The training data.

        Returns:
            The PPO loss tensor given the input batch.
        """

        logits, state = model(train_batch)
        curr_action_dist = dist_class(logits, model)

        # RNN case: Mask away 0-padded chunks at end of time axis.
        if state:
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                 max_seq_len,
                                 time_major=model.is_time_major())
            mask = torch.reshape(mask, [-1])
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        # non-RNN case: No masking.
        else:
            mask = None
            reduce_mean_valid = torch.mean

        prev_action_dist = dist_class(
            train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

        logp_ratio = torch.exp(
            curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
            train_batch[SampleBatch.ACTION_LOGP])

        # Only calculate kl loss if necessary (kl-coeff > 0.0).
        if self.config["kl_coeff"] > 0.0:
            action_kl = prev_action_dist.kl(curr_action_dist)
            mean_kl_loss = reduce_mean_valid(action_kl)
        else:
            mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)

        curr_entropy = curr_action_dist.entropy()
        mean_entropy = reduce_mean_valid(curr_entropy)

        surrogate_loss = torch.min(
            train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
            train_batch[Postprocessing.ADVANTAGES] *
            torch.clamp(logp_ratio, 1 - self.config["clip_param"],
                        1 + self.config["clip_param"]))
        mean_policy_loss = reduce_mean_valid(-surrogate_loss)

        # Compute a value function loss.
        if self.config["use_critic"]:
            prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
            value_fn_out = model.value_function()
            vf_loss1 = torch.pow(
                value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
            vf_clipped = prev_value_fn_out + torch.clamp(
                value_fn_out - prev_value_fn_out,
                -self.config["vf_clip_param"], self.config["vf_clip_param"])
            vf_loss2 = torch.pow(
                vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
            vf_loss = torch.max(vf_loss1, vf_loss2)
            mean_vf_loss = reduce_mean_valid(vf_loss)
        # Ignore the value function.
        else:
            vf_loss = mean_vf_loss = 0.0

        total_loss = reduce_mean_valid(-surrogate_loss +
                                       self.config["vf_loss_coeff"] * vf_loss -
                                       self.entropy_coeff * curr_entropy)

        # Add mean_kl_loss (already processed through `reduce_mean_valid`),
        # if necessary.
        if self.config["kl_coeff"] > 0.0:
            total_loss += self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["vf_explained_var"] = explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["mean_kl_loss"] = mean_kl_loss

        return total_loss
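
Both PPO losses revolve around the same clipped surrogate objective: the probability ratio between the new and the old policy is clipped to [1 - clip_param, 1 + clip_param], and the more pessimistic of the clipped and unclipped advantage-weighted terms is kept. A short sketch on dummy tensors follows; clip_param and the random data are illustrative assumptions only.

# Standalone sketch of PPO's clipped surrogate objective (dummy data).
import torch

clip_param = 0.3                                   # e.g. config["clip_param"]
advantages = torch.randn(8)
logp_ratio = torch.exp(torch.randn(8) * 0.1)       # pi_new(a|s) / pi_old(a|s)

# Keep the more pessimistic of the unclipped and clipped objectives.
surrogate = torch.min(
    advantages * logp_ratio,
    advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param))
policy_loss = -surrogate.mean()
print(policy_loss)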