Example #1
    def __init__(
        self,
        actions,
        actions_logp,
        actions_entropy,
        dones,
        behaviour_action_logp,
        behaviour_logits,
        target_logits,
        discount,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        model,
        valid_mask,
        config,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
        clip_rho_threshold=1.0,
        clip_pg_rho_threshold=1.0,
    ):
        """Policy gradient loss with vtrace importance weighting.

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Args:
            actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
            actions_logp: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_action_logp: Tensor of shape [T, B].
            behaviour_logits: A list of float32 tensors (one per action-space
                component) with shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]].
            target_logits: A list of float32 tensors (one per action-space
                component) with shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: Action distribution class for the logits.
            model: The backing ModelV2 instance.
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            config: Trainer config dict.
            vf_loss_coeff (float): Coefficient of the value function loss.
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_rho_threshold (float): V-trace clip threshold for the
                importance weights (rho).
            clip_pg_rho_threshold (float): V-trace clip threshold for rho as
                used in the policy-gradient advantage.
        """

        if valid_mask is None:
            valid_mask = torch.ones_like(actions_logp)

        # Compute vtrace on the CPU for better perf
        # (devices handled inside `vtrace.multi_from_logits`).
        device = behaviour_action_logp[0].device
        self.vtrace_returns = vtrace.multi_from_logits(
            behaviour_action_log_probs=behaviour_action_logp,
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=target_logits,
            actions=torch.unbind(actions, dim=2),
            discounts=(1.0 - dones.float()) * discount,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            dist_class=dist_class,
            model=model,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold,
        )
        # Move v-trace results back to GPU for actual loss computing.
        self.value_targets = self.vtrace_returns.vs.to(device)

        # The policy gradients loss.
        self.pi_loss = -torch.sum(
            actions_logp * self.vtrace_returns.pg_advantages.to(device) *
            valid_mask)

        # The baseline loss.
        delta = (values - self.value_targets) * valid_mask
        self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))

        # The entropy loss.
        self.entropy = torch.sum(actions_entropy * valid_mask)
        self.mean_entropy = self.entropy / torch.sum(valid_mask)

        # The summed weighted loss.
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
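
For reference, the `vs` (value targets) and `pg_advantages` returned by `vtrace.multi_from_logits` follow the V-trace recursion from Espeholt et al. (2018, IMPALA). Below is a minimal, single-action-space sketch of that computation (the helper name `vtrace_targets` is hypothetical; it assumes `log_rhos = target_logp - behaviour_logp` is precomputed and omits the multi-action splitting and device handling the real RLlib helper performs):

import torch

@torch.no_grad()
def vtrace_targets(log_rhos, discounts, rewards, values, bootstrap_value,
                   clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0):
    # log_rhos, discounts, rewards, values are [T, B]; bootstrap_value is [B].
    rhos = torch.exp(log_rhos)
    clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
    cs = torch.clamp(rhos, max=1.0)

    # V(x_{s+1}), with the bootstrap value appended for the last step.
    values_tp1 = torch.cat([values[1:], bootstrap_value.unsqueeze(0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_tp1 - values)

    # Backward recursion:
    # v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})).
    vs_minus_v = torch.zeros_like(values)
    acc = torch.zeros_like(bootstrap_value)
    for t in reversed(range(values.shape[0])):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v[t] = acc
    vs = vs_minus_v + values

    # Policy-gradient advantages use v_{s+1} as the backup target.
    vs_tp1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
    clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
    pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_tp1 - values)
    return vs, pg_advantages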
Example #2
def appo_surrogate_loss(policy: Policy, model: ModelV2,
                        dist_class: Type[TorchDistributionWrapper],
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: The total APPO loss tensor.
    """
    target_model = policy.target_models[model]

    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = target_model.from_batch(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                    list(output_hidden_shape),
                                                    dim=1)
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                    output_hidden_shape,
                                                    dim=1)
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(old_policy_action_dist.kl(action_dist),
                                     drop_last=True)

        # Compute vtrace on the CPU for better perf.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits,
                                                     drop_last=True),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            actions=torch.unbind(_make_time_major(loss_actions,
                                                  drop_last=True),
                                 dim=2),
            discounts=(1.0 - _make_time_major(dones, drop_last=True).float()) *
            policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=values_time_major[:-1],  # drop_last=True
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]
        )

        actions_logp = _make_time_major(action_dist.logp(actions),
                                        drop_last=True)
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions),
                                             drop_last=True)
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)
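        # `is_ratio` corrects for the gap between the behaviour policy that
        # sampled the data and the (frozen) old/target policy; it is clamped
        # to [0.0, 2.0] to limit the variance of this correction.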
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        delta = values_time_major[:-1] - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=True))

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for loss.
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy()))

    # The summed weighted loss
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.config["entropy_coeff"]

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets
    policy._vf_explained_var = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1]
            if policy.config["vtrace"] else values_time_major, [-1]),
    )

    return total_loss
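
The `_make_time_major` helper above folds the flat [B * T, ...] sample batch into the time-major [T, B, ...] layout that V-trace expects. A minimal sketch of that reshape for fixed-length sequences (the real `make_time_major` also consults `seq_lens` and supports `drop_last`):

import torch

def to_time_major(folded, seq_len, num_seqs):
    # `folded` is batch-major, i.e. [num_seqs * seq_len, ...] with the
    # sequences concatenated. Reshape to [num_seqs, seq_len, ...] and swap
    # the first two axes to obtain [seq_len (T), num_seqs (B), ...].
    shaped = folded.reshape((num_seqs, seq_len) + folded.shape[1:])
    return shaped.transpose(0, 1)

# Two sequences of length three:
x = torch.arange(6)
print(to_time_major(x, seq_len=3, num_seqs=2))
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])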
Example #3
    def __init__(self,
                 actions,
                 prev_actions_logp,
                 actions_logp,
                 old_policy_actions_logp,
                 action_kl,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 old_policy_behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 dist_class,
                 model,
                 valid_mask,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 clip_param=0.3,
                 cur_kl_coeff=None,
                 use_kl_loss=False):
        """APPO Loss, with IS modifications and V-trace for Advantage Estimation

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Args:
            actions: An int|float32 tensor of shape [T, B, logit_dim].
            prev_actions_logp: A float32 tensor of shape [T, B].
            actions_logp: A float32 tensor of shape [T, B].
            old_policy_actions_logp: A float32 tensor of shape [T, B].
            action_kl: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
            old_policy_behaviour_logits: A float32 tensor of shape
                [T, B, logit_dim].
            target_logits: A float32 tensor of shape [T, B, logit_dim].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: action distribution class for logits.
            model: backing ModelV2 instance.
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            vf_loss_coeff (float): Coefficient of the value function loss.
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_param (float): Clip parameter.
            cur_kl_coeff (float): Coefficient for KL loss.
            use_kl_loss (bool): If true, use KL loss.
        """

        if valid_mask is not None:
            num_valid = torch.sum(valid_mask)

            def reduce_mean_valid(t):
                return torch.sum(t * valid_mask) / num_valid

        else:

            def reduce_mean_valid(t):
                return torch.mean(t)

        # Compute vtrace on the CPU for better perf.
        self.vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=old_policy_behaviour_logits,
            actions=torch.unbind(actions, dim=2),
            discounts=(1.0 - dones.float()) * discount,
            rewards=rewards,
            values=values,
            bootstrap_value=bootstrap_value,
            dist_class=dist_class,
            model=model,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold)

        self.is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = self.is_ratio * torch.exp(actions_logp -
                                               prev_actions_logp)

        advantages = self.vtrace_returns.pg_advantages
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param))

        self.mean_kl = reduce_mean_valid(action_kl)
        self.pi_loss = -reduce_mean_valid(surrogate_loss)

        # The baseline loss
        delta = values - self.vtrace_returns.vs
        self.value_targets = self.vtrace_returns.vs
        self.vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss
        self.entropy = reduce_mean_valid(actions_entropy)

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)

        # Optional additional KL Loss
        if use_kl_loss:
            self.total_loss += cur_kl_coeff * self.mean_kl
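
`use_kl_loss` and `cur_kl_coeff` add the PPO-style adaptive KL penalty on top of the clipped surrogate. The coefficient itself is usually adapted outside the loss from the measured mean KL; the sketch below shows the common update rule (illustrative only, the exact constants live in the policy's KL-coefficient handling, not in this loss class):

def update_kl_coeff(kl_coeff, sampled_kl, kl_target=0.01):
    # Tighten the penalty when the measured KL overshoots the target,
    # relax it when the policy stays well inside the trust region.
    if sampled_kl > 2.0 * kl_target:
        kl_coeff *= 1.5
    elif sampled_kl < 0.5 * kl_target:
        kl_coeff *= 0.5
    return kl_coeff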
Example #4
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for APPO.

        With IS modifications and V-trace for Advantage Estimation.

        Args:
            model (ModelV2): The Model to calculate the loss for.
            dist_class (Type[ActionDistribution]): The action distr. class.
            train_batch (SampleBatch): The training data.

        Returns:
            Union[TensorType, List[TensorType]]: A single loss tensor or a list
                of loss tensors.
        """
        target_model = self.target_models[model]

        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)

        if isinstance(self.action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [self.action_space.n]
        elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = self.action_space.nvec.astype(np.int32)
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        def _make_time_major(*args, **kwargs):
            return make_time_major(
                self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
            )

        actions = train_batch[SampleBatch.ACTIONS]
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

        target_model_out, _ = target_model(train_batch)

        prev_action_dist = dist_class(behaviour_logits, model)
        values = model.value_function()
        values_time_major = _make_time_major(values)

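        # With V-trace, the value prediction at the final timestep only serves
        # as the bootstrap value, so that timestep can optionally be dropped
        # from the loss terms (`vtrace_drop_last_ts`).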
        drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]

        if self.is_recurrent():
            max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
            mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
            mask = torch.reshape(mask, [-1])
            mask = _make_time_major(mask, drop_last=drop_last)
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        else:
            reduce_mean_valid = torch.mean

        if self.config["vtrace"]:
            logger.debug(
                "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
            )

            old_policy_behaviour_logits = target_model_out.detach()
            old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

            if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
                unpacked_behaviour_logits = torch.split(
                    behaviour_logits, list(output_hidden_shape), dim=1
                )
                unpacked_old_policy_behaviour_logits = torch.split(
                    old_policy_behaviour_logits, list(output_hidden_shape), dim=1
                )
            else:
                unpacked_behaviour_logits = torch.chunk(
                    behaviour_logits, output_hidden_shape, dim=1
                )
                unpacked_old_policy_behaviour_logits = torch.chunk(
                    old_policy_behaviour_logits, output_hidden_shape, dim=1
                )

            # Prepare actions for loss.
            loss_actions = (
                actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
            )

            # Prepare KL for loss.
            action_kl = _make_time_major(
                old_policy_action_dist.kl(action_dist), drop_last=drop_last
            )

            # Compute vtrace on the CPU for better perf.
            vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=_make_time_major(
                    unpacked_behaviour_logits, drop_last=drop_last
                ),
                target_policy_logits=_make_time_major(
                    unpacked_old_policy_behaviour_logits, drop_last=drop_last
                ),
                actions=torch.unbind(
                    _make_time_major(loss_actions, drop_last=drop_last), dim=2
                ),
                discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
                * self.config["gamma"],
                rewards=_make_time_major(rewards, drop_last=drop_last),
                values=values_time_major[:-1] if drop_last else values_time_major,
                bootstrap_value=values_time_major[-1],
                dist_class=TorchCategorical if is_multidiscrete else dist_class,
                model=model,
                clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
                clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
            )

            actions_logp = _make_time_major(
                action_dist.logp(actions), drop_last=drop_last
            )
            prev_actions_logp = _make_time_major(
                prev_action_dist.logp(actions), drop_last=drop_last
            )
            old_policy_actions_logp = _make_time_major(
                old_policy_action_dist.logp(actions), drop_last=drop_last
            )
            is_ratio = torch.clamp(
                torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
            )
            logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
            self._is_ratio = is_ratio

            advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
            surrogate_loss = torch.min(
                advantages * logp_ratio,
                advantages
                * torch.clamp(
                    logp_ratio,
                    1 - self.config["clip_param"],
                    1 + self.config["clip_param"],
                ),
            )

            mean_kl_loss = reduce_mean_valid(action_kl)
            mean_policy_loss = -reduce_mean_valid(surrogate_loss)

            # The value function loss.
            value_targets = vtrace_returns.vs.to(values_time_major.device)
            if drop_last:
                delta = values_time_major[:-1] - value_targets
            else:
                delta = values_time_major - value_targets
            mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

            # The entropy loss.
            mean_entropy = reduce_mean_valid(
                _make_time_major(action_dist.entropy(), drop_last=drop_last)
            )

        else:
            logger.debug("Using PPO surrogate loss (vtrace=False)")

            # Prepare KL for loss.
            action_kl = _make_time_major(prev_action_dist.kl(action_dist))

            actions_logp = _make_time_major(action_dist.logp(actions))
            prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
            logp_ratio = torch.exp(actions_logp - prev_actions_logp)

            advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
            surrogate_loss = torch.min(
                advantages * logp_ratio,
                advantages
                * torch.clamp(
                    logp_ratio,
                    1 - self.config["clip_param"],
                    1 + self.config["clip_param"],
                ),
            )

            mean_kl_loss = reduce_mean_valid(action_kl)
            mean_policy_loss = -reduce_mean_valid(surrogate_loss)

            # The value function loss.
            value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
            delta = values_time_major - value_targets
            mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

            # The entropy loss.
            mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

        # The summed weighted loss
        total_loss = (
            mean_policy_loss
            + mean_vf_loss * self.config["vf_loss_coeff"]
            - mean_entropy * self.entropy_coeff
        )

        # Optional additional KL Loss
        if self.config["use_kl_loss"]:
            total_loss += self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_kl_loss"] = mean_kl_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["value_targets"] = value_targets
        model.tower_stats["vf_explained_var"] = explained_variance(
            torch.reshape(value_targets, [-1]),
            torch.reshape(
                values_time_major[:-1] if drop_last else values_time_major, [-1]
            ),
        )

        return total_loss
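
`explained_variance` reports how much of the variance in the value targets the critic accounts for (1.0 means a perfect fit; values near or below 0 mean the critic does no better than predicting the mean target). A minimal sketch of such a helper, assuming the same semantics as the RLlib utility but not necessarily its exact edge-case handling:

import torch

def explained_variance_sketch(y: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    # 1 - Var[y - pred] / Var[y], bounded below at -1.0 so a badly scaled
    # critic does not produce arbitrarily large negative stats.
    return torch.clamp(1.0 - torch.var(y - pred) / torch.var(y), min=-1.0)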