def __init__(
        self,
        actions,
        actions_logp,
        actions_entropy,
        dones,
        behaviour_action_logp,
        behaviour_logits,
        target_logits,
        discount,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        model,
        valid_mask,
        config,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
        clip_rho_threshold=1.0,
        clip_pg_rho_threshold=1.0,
    ):
        """Build the V-trace importance-weighted policy gradient loss.

        All time-dependent inputs are expected in time-major layout
        [T, B, ...], with `B` being the batch size. `B` has to be known so
        that V-trace can handle episode cut boundaries properly.

        The constructor stores the V-trace returns, the summed and the mean
        variants of the policy, value-function and entropy terms, and the
        combined loss as attributes on `self`.

        Args:
            actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
            actions_logp: A float32 tensor of shape [T, B].
            actions_entropy: A float32 tensor of shape [T, B].
            dones: A bool tensor of shape [T, B].
            behaviour_action_logp: Tensor of shape [T, B].
            behaviour_logits: A list with length of ACTION_SPACE of float32
                tensors of shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]]
            target_logits: A list with length of ACTION_SPACE of float32
                tensors of shapes
                [T, B, ACTION_SPACE[0]],
                ...,
                [T, B, ACTION_SPACE[-1]]
            discount: A float32 scalar.
            rewards: A float32 tensor of shape [T, B].
            values: A float32 tensor of shape [T, B].
            bootstrap_value: A float32 tensor of shape [B].
            dist_class: action distribution class for logits.
            model: Forwarded unchanged to `vtrace.multi_from_logits`.
            valid_mask: A bool tensor of valid RNN input elements (#2992).
            config: Algorithm config dict. Only the
                "_separate_vf_optimizer" key is read here.
            vf_loss_coeff: Weight of the value-function loss inside
                `total_loss`.
            entropy_coeff: Weight of the entropy bonus that is subtracted
                from the policy loss.
            clip_rho_threshold: Forwarded (cast to float32) to
                `vtrace.multi_from_logits`.
            clip_pg_rho_threshold: Forwarded (cast to float32) to
                `vtrace.multi_from_logits`.
        """

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            per_component_actions = tf.unstack(actions, axis=2)
            # Discount is zeroed on every step flagged as done.
            step_discounts = (
                tf.cast(~tf.cast(dones, tf.bool), tf.float32) * discount
            )
            rho_clip = tf.cast(clip_rho_threshold, tf.float32)
            pg_rho_clip = tf.cast(clip_pg_rho_threshold, tf.float32)

            self.vtrace_returns = vtrace.multi_from_logits(
                behaviour_action_log_probs=behaviour_action_logp,
                behaviour_policy_logits=behaviour_logits,
                target_policy_logits=target_logits,
                actions=per_component_actions,
                discounts=step_discounts,
                rewards=rewards,
                values=values,
                bootstrap_value=bootstrap_value,
                dist_class=dist_class,
                model=model,
                clip_rho_threshold=rho_clip,
                clip_pg_rho_threshold=pg_rho_clip,
            )
            self.value_targets = self.vtrace_returns.vs

        returns = self.vtrace_returns

        # Policy gradient term, restricted to valid timesteps.
        valid_pg_terms = tf.boolean_mask(
            actions_logp * returns.pg_advantages, valid_mask
        )
        self.pi_loss = -tf.reduce_sum(valid_pg_terms)
        self.mean_pi_loss = -tf.reduce_mean(valid_pg_terms)

        # Baseline (value function) term: half the squared error against
        # the V-trace value targets.
        valid_vf_error = tf.boolean_mask(values - returns.vs, valid_mask)
        squared_vf_error = tf.math.square(valid_vf_error)
        self.vf_loss = 0.5 * tf.reduce_sum(squared_vf_error)
        self.mean_vf_loss = 0.5 * tf.reduce_mean(squared_vf_error)

        # Entropy term.
        valid_entropy = tf.boolean_mask(actions_entropy, valid_mask)
        self.entropy = tf.reduce_sum(valid_entropy)
        self.mean_entropy = tf.reduce_mean(valid_entropy)

        # Policy loss with entropy bonus; this is the complete loss when the
        # value function is trained by a separate optimizer/network.
        policy_and_entropy_loss = self.pi_loss - self.entropy * entropy_coeff
        self.loss_wo_vf = policy_and_entropy_loss

        if config["_separate_vf_optimizer"]:
            self.total_loss = policy_and_entropy_loss
        else:
            self.total_loss = (
                policy_and_entropy_loss + self.vf_loss * vf_loss_coeff
            )
# Example #2
# 0
        def loss(
            self,
            model: Union[ModelV2, "tf.keras.Model"],
            dist_class: Type[TFActionDistribution],
            train_batch: SampleBatch,
        ) -> Union[TensorType, List[TensorType]]:
            """Compute the clipped surrogate loss for one train batch.

            Two code paths exist, selected by `config["vtrace"]`:

            * V-trace: advantages and value targets are computed here via
              `vtrace.multi_from_logits`, using the behaviour logits stored
              in the batch and the (gradient-stopped) output of
              `self.target_model`.
            * Otherwise: advantages and value targets are read from
              `train_batch[Postprocessing.ADVANTAGES]` and
              `train_batch[Postprocessing.VALUE_TARGETS]`.

            In both cases the KL, value-function and entropy terms are
            computed, and all individual terms are stored on `self`
            (`_total_loss`, `_loss_wo_vf`, `_mean_policy_loss`,
            `_mean_kl_loss`, `_mean_vf_loss`, `_mean_entropy`,
            `_value_targets`).

            Args:
                model: Model that is called on `train_batch` to produce the
                    inputs of the current action distribution.
                dist_class: Action distribution class used to wrap the
                    different logits.
                train_batch: Batch of training data.

            Returns:
                The total loss tensor, or - if
                `config["_separate_vf_optimizer"]` is set - the pair
                `(loss_wo_vf, mean_vf_loss)`.
            """
            model_out, _ = model(train_batch)
            action_dist = dist_class(model_out, model)

            # Sizes used to split the flat logits into one chunk per action
            # component.
            if isinstance(self.action_space, gym.spaces.Discrete):
                is_multidiscrete = False
                output_hidden_shape = [self.action_space.n]
            elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
                is_multidiscrete = True
                output_hidden_shape = self.action_space.nvec.astype(np.int32)
            else:
                is_multidiscrete = False
                output_hidden_shape = 1

            # TODO: (sven) deprecate this when trajectory view API gets activated.
            def make_time_major(*args, **kw):
                return _make_time_major(
                    self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw
                )

            actions = train_batch[SampleBatch.ACTIONS]
            dones = train_batch[SampleBatch.DONES]
            rewards = train_batch[SampleBatch.REWARDS]
            behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

            target_model_out, _ = self.target_model(train_batch)
            prev_action_dist = dist_class(behaviour_logits, self.model)
            values = self.model.value_function()
            values_time_major = make_time_major(values)

            # Single source of truth for dropping the last timestep: it only
            # happens on the V-trace path, and only if requested. The
            # validity mask and the entropy term must use the very same flag
            # as all other time-major tensors, otherwise their leading
            # dimensions disagree (T-1 vs. T) when
            # `vtrace_drop_last_ts=False`.
            drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]

            if self.is_recurrent():
                max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
                mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
                mask = tf.reshape(mask, [-1])
                mask = make_time_major(mask, drop_last=drop_last)

                def reduce_mean_valid(t):
                    # Average only over the valid (non-padded) timesteps.
                    return tf.reduce_mean(tf.boolean_mask(t, mask))

            else:
                reduce_mean_valid = tf.reduce_mean

            if self.config["vtrace"]:
                logger.debug(
                    "Using V-Trace surrogate loss (vtrace=True; "
                    f"drop_last={drop_last})"
                )

                # Prepare actions for loss.
                loss_actions = (
                    actions if is_multidiscrete else tf.expand_dims(actions, axis=1)
                )

                # The target model's output acts as the "old policy"; no
                # gradients flow through it.
                old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
                old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

                # Prepare KL for Loss
                mean_kl = make_time_major(
                    old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last
                )

                unpacked_behaviour_logits = tf.split(
                    behaviour_logits, output_hidden_shape, axis=1
                )
                unpacked_old_policy_behaviour_logits = tf.split(
                    old_policy_behaviour_logits, output_hidden_shape, axis=1
                )

                # Compute vtrace on the CPU for better perf.
                with tf.device("/cpu:0"):
                    vtrace_returns = vtrace.multi_from_logits(
                        behaviour_policy_logits=make_time_major(
                            unpacked_behaviour_logits, drop_last=drop_last
                        ),
                        target_policy_logits=make_time_major(
                            unpacked_old_policy_behaviour_logits, drop_last=drop_last
                        ),
                        actions=tf.unstack(
                            make_time_major(loss_actions, drop_last=drop_last), axis=2
                        ),
                        # Discount is zeroed on every step flagged as done.
                        discounts=tf.cast(
                            ~make_time_major(
                                tf.cast(dones, tf.bool), drop_last=drop_last
                            ),
                            tf.float32,
                        )
                        * self.config["gamma"],
                        rewards=make_time_major(rewards, drop_last=drop_last),
                        values=values_time_major[:-1]
                        if drop_last
                        else values_time_major,
                        bootstrap_value=values_time_major[-1],
                        dist_class=Categorical if is_multidiscrete else dist_class,
                        model=model,
                        clip_rho_threshold=tf.cast(
                            self.config["vtrace_clip_rho_threshold"], tf.float32
                        ),
                        clip_pg_rho_threshold=tf.cast(
                            self.config["vtrace_clip_pg_rho_threshold"], tf.float32
                        ),
                    )

                actions_logp = make_time_major(
                    action_dist.logp(actions), drop_last=drop_last
                )
                prev_actions_logp = make_time_major(
                    prev_action_dist.logp(actions), drop_last=drop_last
                )
                old_policy_actions_logp = make_time_major(
                    old_policy_action_dist.logp(actions), drop_last=drop_last
                )

                # Importance ratio between behaviour policy and old (target)
                # policy, clipped to [0, 2].
                is_ratio = tf.clip_by_value(
                    tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
                )
                logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
                self._is_ratio = is_ratio

                advantages = vtrace_returns.pg_advantages
                surrogate_loss = tf.minimum(
                    advantages * logp_ratio,
                    advantages
                    * tf.clip_by_value(
                        logp_ratio,
                        1 - self.config["clip_param"],
                        1 + self.config["clip_param"],
                    ),
                )

                action_kl = (
                    tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
                )
                mean_kl_loss = reduce_mean_valid(action_kl)
                mean_policy_loss = -reduce_mean_valid(surrogate_loss)

                # The value function loss.
                if drop_last:
                    delta = values_time_major[:-1] - vtrace_returns.vs
                else:
                    delta = values_time_major - vtrace_returns.vs
                value_targets = vtrace_returns.vs
                mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

                # The entropy loss. Uses the same `drop_last` as every other
                # term so that it lines up with the validity mask.
                actions_entropy = make_time_major(
                    action_dist.multi_entropy(), drop_last=drop_last
                )
                mean_entropy = reduce_mean_valid(actions_entropy)

            else:
                logger.debug("Using PPO surrogate loss (vtrace=False)")

                # Prepare KL for Loss
                mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

                logp_ratio = tf.math.exp(
                    make_time_major(action_dist.logp(actions))
                    - make_time_major(prev_action_dist.logp(actions))
                )

                advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
                surrogate_loss = tf.minimum(
                    advantages * logp_ratio,
                    advantages
                    * tf.clip_by_value(
                        logp_ratio,
                        1 - self.config["clip_param"],
                        1 + self.config["clip_param"],
                    ),
                )

                action_kl = (
                    tf.reduce_mean(mean_kl, axis=0) if is_multidiscrete else mean_kl
                )
                mean_kl_loss = reduce_mean_valid(action_kl)
                mean_policy_loss = -reduce_mean_valid(surrogate_loss)

                # The value function loss.
                value_targets = make_time_major(
                    train_batch[Postprocessing.VALUE_TARGETS]
                )
                delta = values_time_major - value_targets
                mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

                # The entropy loss.
                mean_entropy = reduce_mean_valid(
                    make_time_major(action_dist.multi_entropy())
                )

            # The summed weighted loss.
            total_loss = mean_policy_loss - mean_entropy * self.entropy_coeff
            # Optional KL loss.
            if self.config["use_kl_loss"]:
                total_loss += self.kl_coeff * mean_kl_loss
            # Optional vf loss (or in a separate term due to separate
            # optimizers/networks).
            loss_wo_vf = total_loss
            if not self.config["_separate_vf_optimizer"]:
                total_loss += mean_vf_loss * self.config["vf_loss_coeff"]

            # Store stats in policy for stats_fn.
            self._total_loss = total_loss
            self._loss_wo_vf = loss_wo_vf
            self._mean_policy_loss = mean_policy_loss
            # Backward compatibility: Deprecate policy._mean_kl.
            self._mean_kl_loss = self._mean_kl = mean_kl_loss
            self._mean_vf_loss = mean_vf_loss
            self._mean_entropy = mean_entropy
            self._value_targets = value_targets

            # Return one total loss or two losses: vf vs rest (policy + kl).
            if self.config["_separate_vf_optimizer"]:
                return loss_wo_vf, mean_vf_loss
            else:
                return total_loss