Example #1
def appo_surrogate_loss(
        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TFActionDistribution]): The action distribution
            class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    # TODO: (sven) deprecate this when trajectory view API gets activated.
    def make_time_major(*args, **kw):
        return _make_time_major(policy, train_batch.get("seq_lens"), *args,
                                **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = policy.target_model.from_batch(train_batch)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()
    values_time_major = make_time_major(values)

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - 1
        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])
        mask = make_time_major(mask, drop_last=policy.config["vtrace"])

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    else:
        reduce_mean_valid = tf.reduce_mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else tf.expand_dims(actions,
                                                                       axis=1)

        old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        # Prepare KL for Loss
        mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist),
                                  drop_last=True)

        unpacked_behaviour_logits = tf.split(behaviour_logits,
                                             output_hidden_shape,
                                             axis=1)
        unpacked_old_policy_behaviour_logits = tf.split(
            old_policy_behaviour_logits, output_hidden_shape, axis=1)

        # Compute vtrace on the CPU for better perf.
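        # `vtrace.multi_from_logits` returns the V-trace value targets (`vs`)
        # and the importance-weighted policy-gradient advantages
        # (`pg_advantages`) consumed below.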
        with tf.device("/cpu:0"):
            vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=make_time_major(
                    unpacked_behaviour_logits, drop_last=True),
                target_policy_logits=make_time_major(
                    unpacked_old_policy_behaviour_logits, drop_last=True),
                actions=tf.unstack(make_time_major(loss_actions,
                                                   drop_last=True),
                                   axis=2),
                discounts=tf.cast(
                    ~make_time_major(tf.cast(dones, tf.bool), drop_last=True),
                    tf.float32) * policy.config["gamma"],
                rewards=make_time_major(rewards, drop_last=True),
                values=values_time_major[:-1],  # drop-last=True
                bootstrap_value=values_time_major[-1],
                dist_class=Categorical if is_multidiscrete else dist_class,
                model=model,
                clip_rho_threshold=tf.cast(
                    policy.config["vtrace_clip_rho_threshold"], tf.float32),
                clip_pg_rho_threshold=tf.cast(
                    policy.config["vtrace_clip_pg_rho_threshold"], tf.float32),
            )

        actions_logp = make_time_major(action_dist.logp(actions),
                                       drop_last=True)
        prev_actions_logp = make_time_major(prev_action_dist.logp(actions),
                                            drop_last=True)
        old_policy_actions_logp = make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)

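        # Clamped importance ratio between the behaviour (sampling-time)
        # policy and the frozen target policy; clipping to [0, 2] bounds the
        # off-policy correction that scales the PPO ratio below.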
        is_ratio = tf.clip_by_value(
            tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        delta = values_time_major[:-1] - vtrace_returns.vs
        value_targets = vtrace_returns.vs
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        actions_entropy = make_time_major(action_dist.multi_entropy(),
                                          drop_last=True)
        mean_entropy = reduce_mean_valid(actions_entropy)

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

        logp_ratio = tf.math.exp(
            make_time_major(action_dist.logp(actions)) -
            make_time_major(prev_action_dist.logp(actions)))

        advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            make_time_major(action_dist.multi_entropy()))

    # The summed weighted loss
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.config["entropy_coeff"]

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets

    return total_loss
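
The clipped surrogate above layers PPO-style ratio clipping on top of an extra importance-sampling correction between the behaviour (sampling) policy and the frozen target policy. A minimal NumPy sketch of that computation, with toy log-probability values (hypothetical numbers, not RLlib API):

import numpy as np

# Toy per-timestep log-probs (hypothetical values).
actions_logp = np.array([-0.9, -1.2, -0.7])              # current policy
prev_actions_logp = np.array([-1.0, -1.0, -1.0])         # behaviour (sampling) policy
old_policy_actions_logp = np.array([-1.1, -0.8, -1.0])   # frozen target policy
advantages = np.array([0.5, -0.3, 1.2])
clip_param = 0.4

# IS correction between behaviour and target policy, clamped to [0, 2].
is_ratio = np.clip(np.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
# PPO-style ratio of current vs. behaviour policy, scaled by the IS correction.
logp_ratio = is_ratio * np.exp(actions_logp - prev_actions_logp)

surrogate = np.minimum(
    advantages * logp_ratio,
    advantages * np.clip(logp_ratio, 1 - clip_param, 1 + clip_param))
policy_loss = -surrogate.mean()
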
Example #2
    def __init__(self,
                 actions,
                 prev_actions_logp,
                 actions_logp,
                 old_policy_actions_logp,
                 action_kl,
                 actions_entropy,
                 dones,
                 behaviour_logits,
                 old_policy_behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 dist_class,
                 model,
                 valid_mask,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0,
                 clip_param=0.3,
                 cur_kl_coeff=None,
                 use_kl_loss=False):
        """APPO Loss, with IS modifications and V-trace for Advantage Estimation

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Arguments:
            actions: An int|float32 tensor of shape [T, B, logit_dim].
            prev_actions_logp: A float32 tensor of shape [T, B].
            actions_logp: A float32 tensor of shape [T, B].
            old_policy_actions_logp: A float32 tensor of shape [T, B].
            action_kl: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
            old_policy_behaviour_logits: A float32 tensor of shape
                [T, B, logit_dim].
            target_logits: A float32 tensor of shape [T, B, logit_dim].
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: action distribution class for logits.
            model: backing ModelV2 instance
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            vf_loss_coeff (float): Coefficient of the value function loss.
            entropy_coeff (float): Coefficient of the entropy regularizer.
            clip_rho_threshold (float): Clip threshold for the V-trace
                importance weights (rho) when computing the value targets.
            clip_pg_rho_threshold (float): Clip threshold for the importance
                weights used in the policy-gradient advantage term.
            clip_param (float): Clip parameter.
            cur_kl_coeff (float): Coefficient for KL loss.
            use_kl_loss (bool): If true, use KL loss.
        """

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, valid_mask))

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            self.vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=old_policy_behaviour_logits,
                actions=tf.unstack(actions, axis=2),
                discounts=tf.cast(~dones, tf.float32) * discount,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                dist_class=dist_class,
                model=model,
                clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
                clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
                                              tf.float32))

        self.is_ratio = tf.clip_by_value(
            tf.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = self.is_ratio * tf.exp(actions_logp - prev_actions_logp)

        advantages = self.vtrace_returns.pg_advantages
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages * tf.clip_by_value(logp_ratio, 1 - clip_param,
                                          1 + clip_param))

        self.mean_kl = reduce_mean_valid(action_kl)
        self.pi_loss = -reduce_mean_valid(surrogate_loss)

        # The baseline loss
        delta = values - self.vtrace_returns.vs
        self.value_targets = self.vtrace_returns.vs
        self.vf_loss = 0.5 * reduce_mean_valid(tf.square(delta))

        # The entropy loss
        self.entropy = reduce_mean_valid(actions_entropy)

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)

        # Optional additional KL Loss
        if use_kl_loss:
            self.total_loss += cur_kl_coeff * self.mean_kl
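
Both examples above operate on time-major [T, B, ...] tensors; Example #1 builds them with RLlib's `_make_time_major` helper before handing them to V-trace. A minimal sketch of that layout change for fixed-length sequences (hypothetical helper name; the real helper also handles dynamic `seq_lens` and an optional `drop_last`):

import tensorflow as tf

def to_time_major(tensor, batch_size, seq_len):
    # Fold a flat [B * T, feat] batch of concatenated, fixed-length sequences
    # into [B, T, feat], then swap the batch and time axes to get [T, B, feat].
    folded = tf.reshape(tensor, [batch_size, seq_len, -1])
    return tf.transpose(folded, perm=[1, 0, 2])

flat = tf.reshape(tf.range(12, dtype=tf.float32), [6, 2])   # B*T = 6 rows, 2 features
time_major = to_time_major(flat, batch_size=2, seq_len=3)   # shape [3, 2, 2]
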
Example #3
    def __init__(self,
                 actions,
                 actions_logp,
                 actions_entropy,
                 dones,
                 behaviour_action_logp,
                 behaviour_logits,
                 target_logits,
                 discount,
                 rewards,
                 values,
                 bootstrap_value,
                 dist_class,
                 model,
                 valid_mask,
                 config,
                 vf_loss_coeff=0.5,
                 entropy_coeff=0.01,
                 clip_rho_threshold=1.0,
                 clip_pg_rho_threshold=1.0):
        """Policy gradient loss with vtrace importance weighting.

        VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
        batch_size. The reason we need to know `B` is for V-trace to properly
        handle episode cut boundaries.

        Args:
            actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
            actions_logp: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_action_logp: Tensor of shape [T, B].
            behaviour_logits: A list with length of ACTION_SPACE of float32
                tensors of shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]]
            target_logits: A list with length of ACTION_SPACE of float32
                tensors of shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]]
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: action distribution class for logits.
            model: backing ModelV2 instance.
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            config: Trainer config dict.
        """

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            self.vtrace_returns = vtrace.multi_from_logits(
                behaviour_action_log_probs=behaviour_action_logp,
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=target_logits,
                actions=tf.unstack(actions, axis=2),
                discounts=tf.cast(~dones, tf.float32) * discount,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                dist_class=dist_class,
                model=model,
                clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
                clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
                                              tf.float32))
            self.value_targets = self.vtrace_returns.vs

        # The policy gradients loss
        self.pi_loss = -tf.reduce_sum(
            tf.boolean_mask(actions_logp * self.vtrace_returns.pg_advantages,
                            valid_mask))

        # The baseline loss
        delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
        self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))

        # The entropy loss
        self.entropy = tf.reduce_sum(
            tf.boolean_mask(actions_entropy, valid_mask))

        # The summed weighted loss
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
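
Unlike the APPO losses above, this IMPALA-style V-trace loss keeps the plain score-function term (`actions_logp * pg_advantages`, no clipping) and reduces masked sums instead of masked means. A tiny sketch of the masked reduction under `valid_mask`, with toy values (hypothetical numbers):

import tensorflow as tf

# Toy time-major [T, B] loss terms and the RNN validity mask.
per_step_loss = tf.constant([[0.2, 0.5], [0.1, 0.4], [0.3, 0.6]])
valid_mask = tf.constant([[True, True], [True, True], [False, True]])

masked = tf.boolean_mask(per_step_loss, valid_mask)
sum_loss = tf.reduce_sum(masked)    # Example #3 style: scales with rollout/batch size
mean_loss = tf.reduce_mean(masked)  # Examples #1, #2, #4 style: size-invariant
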
Example #4
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
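            """Constructs the loss for APPO.

            With IS modifications and V-trace for Advantage Estimation.

            Args:
                model: The Model to calculate the loss for.
                dist_class: The action distribution class.
                train_batch: The training data.

            Returns:
                A single loss tensor, or a (policy + KL loss, value function
                loss) pair when `_separate_vf_optimizer` is set.
            """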
            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)

            if isinstance(self.action_space, gym.spaces.Discrete):
                is_multidiscrete = False
                output_hidden_shape = [self.action_space.n]
            elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
                is_multidiscrete = True
                output_hidden_shape = self.action_space.nvec.astype(np.int32)
            else:
                is_multidiscrete = False
                output_hidden_shape = 1

            # TODO: (sven) deprecate this when trajectory view API gets activated.
            def make_time_major(*args, **kw):
                return _make_time_major(
                    self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
                )

            actions = train_batch[SampleBatch.ACTIONS]
            dones = train_batch[SampleBatch.DONES]
            rewards = train_batch[SampleBatch.REWARDS]
            behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

            target_model_out, _ = self.target_model(train_batch)
            prev_action_dist = dist_class(behaviour_logits, self.model)
            values = self.model.value_function()
            values_time_major = make_time_major(values)

            self.model_vars = self.model.variables()
            self.target_model_vars = self.target_model.variables()

            if self.is_recurrent():
                max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
                mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
                mask = tf.reshape(mask, [-1])
                mask = make_time_major(
                    mask,
                    drop_last=self.config["vtrace"]
                    and self.config["vtrace_drop_last_ts"],
                )

                def reduce_mean_valid(t):
                    return tf.reduce_mean(tf.boolean_mask(t, mask))

            else:
                reduce_mean_valid = tf.reduce_mean

            if self.config["vtrace"]:
                drop_last = self.config["vtrace_drop_last_ts"]
                logger.debug(
                    "Using V-Trace surrogate loss (vtrace=True; "
                    f"drop_last={drop_last})"
                )

                # Prepare actions for loss.
                loss_actions = (
                    actions if is_multidiscrete else tf.expand_dims(actions, axis=1)
                )

                old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
                old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

                # Prepare KL for Loss
                mean_kl = make_time_major(
                    old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
                )

                unpacked_behaviour_logits = tf.split(
                    behaviour_logits, output_hidden_shape, axis=1
                )
                unpacked_old_policy_behaviour_logits = tf.split(
                    old_policy_behaviour_logits, output_hidden_shape, axis=1
                )

                # Compute vtrace on the CPU for better perf.
                with tf.device("/cpu:0"):
                    vtrace_returns = vtrace.multi_from_logits(
                        behaviour_policy_logits=make_time_major(
                            unpacked_behaviour_logits, drop_last=drop_last
                        ),
                        target_policy_logits=make_time_major(
                            unpacked_old_policy_behaviour_logits, drop_last=drop_last
                        ),
                        actions=tf.unstack(
                            make_time_major(loss_actions, drop_last=drop_last), axis=2
                        ),
                        discounts=tf.cast(
                            ~make_time_major(
                                tf.cast(dones, tf.bool), drop_last=drop_last
                            ),
                            tf.float32,
                        )
                        * self.config["gamma"],
                        rewards=make_time_major(rewards, drop_last=drop_last),
                        values=values_time_major[:-1]
                        if drop_last
                        else values_time_major,
                        bootstrap_value=values_time_major[-1],
                        dist_class=Categorical if is_multidiscrete else dist_class,
                        model=model,
                        clip_rho_threshold=tf.cast(
                            self.config["vtrace_clip_rho_threshold"], tf.float32
                        ),
                        clip_pg_rho_threshold=tf.cast(
                            self.config["vtrace_clip_pg_rho_threshold"], tf.float32
                        ),
                    )

                actions_logp = make_time_major(
                    action_dist.logp(actions), drop_last=drop_last
                )
                prev_actions_logp = make_time_major(
                    prev_action_dist.logp(actions), drop_last=drop_last
                )
                old_policy_actions_logp = make_time_major(
                    old_policy_action_dist.logp(actions), drop_last=drop_last
                )

                is_ratio = tf.clip_by_value(
                    tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
                )
                logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
                self._is_ratio = is_ratio

                advantages = vtrace_returns.pg_advantages
                surrogate_loss = tf.minimum(
                    advantages * logp_ratio,
                    advantages
                    * tf.clip_by_value(
                        logp_ratio,
                        1 - self.config["clip_param"],
                        1 + self.config["clip_param"],
                    ),
                )

                action_kl = (
                    tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
                )
                mean_kl_loss = reduce_mean_valid(action_kl)
                mean_policy_loss = -reduce_mean_valid(surrogate_loss)

                # The value function loss.
                if drop_last:
                    delta = values_time_major[:-1] - vtrace_returns.vs
                else:
                    delta = values_time_major - vtrace_returns.vs
                value_targets = vtrace_returns.vs
                mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

                # The entropy loss.
                actions_entropy = make_time_major(
                    action_dist.multi_entropy(), drop_last=drop_last
                )
                mean_entropy = reduce_mean_valid(actions_entropy)

            else:
                logger.debug("Using PPO surrogate loss (vtrace=False)")

                # Prepare KL for Loss
                mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

                logp_ratio = tf.math.exp(
                    make_time_major(action_dist.logp(actions))
                    - make_time_major(prev_action_dist.logp(actions))
                )

                advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
                surrogate_loss = tf.minimum(
                    advantages * logp_ratio,
                    advantages
                    * tf.clip_by_value(
                        logp_ratio,
                        1 - self.config["clip_param"],
                        1 + self.config["clip_param"],
                    ),
                )

                action_kl = (
                    tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
                )
                mean_kl_loss = reduce_mean_valid(action_kl)
                mean_policy_loss = -reduce_mean_valid(surrogate_loss)

                # The value function loss.
                value_targets = make_time_major(
                    train_batch[Postprocessing.VALUE_TARGETS]
                )
                delta = values_time_major - value_targets
                mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

                # The entropy loss.
                mean_entropy = reduce_mean_valid(
                    make_time_major(action_dist.multi_entropy())
                )

            # The summed weighted loss.
            total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
            # Optional KL loss.
            if self.config["use_kl_loss"]:
                total_loss += self.kl_coeff * mean_kl_loss
            # Optional vf loss (or in a separate term due to separate
            # optimizers/networks).
            loss_wo_vf = total_loss
            if not self.config["_separate_vf_optimizer"]:
                total_loss += mean_vf_loss * self.config["vf_loss_coeff"]

            # Store stats in policy for stats_fn.
            self._total_loss = total_loss
            self._loss_wo_vf = loss_wo_vf
            self._mean_policy_loss = mean_policy_loss
            # Backward compatibility: Deprecate policy._mean_kl.
            self._mean_kl_loss = self._mean_kl = mean_kl_loss
            self._mean_vf_loss = mean_vf_loss
            self._mean_entropy = mean_entropy
            self._value_targets = value_targets

            # Return one total loss or two losses: vf vs rest (policy + kl).
            if self.config["_separate_vf_optimizer"]:
                return loss_wo_vf, mean_vf_loss
            else:
                return total_loss
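
When `_separate_vf_optimizer` is enabled, the method returns two losses so the policy/KL term and the value-function term can be minimized independently (e.g. with a separate value network). A minimal sketch of how a caller might wire the pair to two optimizers (hypothetical names; not RLlib's internal update path):

import tensorflow as tf

def train_step(loss_fn, policy_vars, vf_vars, policy_optimizer, vf_optimizer):
    # `loss_fn` is assumed to return the (policy + KL, value function) pair,
    # as the `_separate_vf_optimizer` branch above does.
    with tf.GradientTape(persistent=True) as tape:
        loss_wo_vf, vf_loss = loss_fn()
    policy_optimizer.apply_gradients(
        zip(tape.gradient(loss_wo_vf, policy_vars), policy_vars))
    vf_optimizer.apply_gradients(
        zip(tape.gradient(vf_loss, vf_vars), vf_vars))
    del tape  # release resources held by the persistent tape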