Example #1
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)
        actions = train_batch[SampleBatch.ACTIONS]
        # log\pi_\theta(a|s)
        logprobs = action_dist.logp(actions)

        # Advantage estimation.
        if self.config["beta"] != 0.0:
            cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
            state_values = model.value_function()
            adv = cumulative_rewards - state_values
            adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

            explained_var = explained_variance(cumulative_rewards,
                                               state_values)
            self.explained_variance = torch.mean(explained_var)

            # Policy loss.
            # Update averaged advantage norm.
            rate = self.config["moving_average_sqd_adv_norm_update_rate"]
            self._moving_average_sqd_adv_norm.add_(
                rate * (adv_squared_mean - self._moving_average_sqd_adv_norm))
            # Exponentially weighted advantages.
            exp_advs = torch.exp(
                self.config["beta"] *
                (adv /
                 (1e-8 + torch.pow(self._moving_average_sqd_adv_norm, 0.5)))
            ).detach()
            # Value loss.
            self.v_loss = 0.5 * adv_squared_mean
        else:
            # Policy loss (simple BC loss term).
            exp_advs = 1.0
            # Value loss.
            self.v_loss = 0.0

        # logprob loss alone tends to push action distributions to
        # have very low entropy, resulting in worse performance for
        # unfamiliar situations.
        # A scaled logstd loss term encourages stochasticity, thus
        # alleviating the problem to some extent.
        logstd_coeff = self.config["bc_logstd_coeff"]
        if logstd_coeff > 0.0:
            logstds = torch.mean(action_dist.log_std, dim=1)
        else:
            logstds = 0.0

        self.p_loss = -torch.mean(exp_advs *
                                  (logprobs + logstd_coeff * logstds))

        # Combine both losses.
        self.total_loss = self.p_loss + self.config["vf_coeff"] * self.v_loss

        return self.total_loss
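
The core of the beta != 0 branch above is the exponentially weighted advantage: advantages are normalized by the square root of a running mean of squared advantages, scaled by beta, and exponentiated. A minimal standalone sketch of just that computation (the tensor values and the update rate below are illustration values, not RLlib defaults):

import torch

# Made-up returns and value predictions.
cumulative_rewards = torch.tensor([1.0, 2.0, 0.5])  # Postprocessing.ADVANTAGES column
state_values = torch.tensor([0.8, 1.5, 0.7])        # output of the value head
beta = 1.0
rate = 1e-7                                         # moving-average update rate (hypothetical)
moving_average_sqd_adv_norm = torch.tensor(1.0)

adv = cumulative_rewards - state_values
adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

# Exponential moving average update: x <- x + rate * (new - x).
moving_average_sqd_adv_norm += rate * (adv_squared_mean - moving_average_sqd_adv_norm)

# Exponentially weighted advantages; detached so only the log-probs receive gradients.
exp_advs = torch.exp(
    beta * adv / (1e-8 + torch.pow(moving_average_sqd_adv_norm, 0.5))
).detach()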
Example #2
def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
                train_batch: SampleBatch) -> TensorType:
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)
    actions = train_batch[SampleBatch.ACTIONS]
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)

    # Advantage estimation.
    if policy.config["beta"] != 0.0:
        cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
        state_values = model.value_function()
        adv = cumulative_rewards - state_values
        adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

        explained_var = explained_variance(cumulative_rewards, state_values)
        policy.explained_variance = torch.mean(explained_var)

        # Policy loss.
        # Update averaged advantage norm.
        rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
        policy._moving_average_sqd_adv_norm.add_(
            rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm))
        # Exponentially weighted advantages.
        exp_advs = torch.exp(
            policy.config["beta"] *
            (adv /
             (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5))))
        policy.p_loss = -torch.mean(exp_advs.detach() * logprobs)
        # Value loss.
        policy.v_loss = 0.5 * adv_squared_mean
    else:
        # Policy loss (simple BC loss term).
        policy.p_loss = -1.0 * torch.mean(logprobs)
        # Value loss.
        policy.v_loss = 0.0

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \
        policy.v_loss

    return policy.total_loss
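
When config["beta"] is 0.0, the else-branch above turns MARWIL into plain behavior cloning: the policy loss is the negative mean log-likelihood of the logged actions and the value loss is zero. A tiny sketch of that degenerate case, using torch.distributions.Normal as a stand-in for RLlib's action distribution wrapper (the action values are made up):

import torch
from torch.distributions import Normal

# Hypothetical diagonal-Gaussian policy output and logged dataset actions.
action_dist = Normal(loc=torch.zeros(4), scale=torch.ones(4))
actions = torch.tensor([0.1, -0.3, 0.5, 0.0])

logprobs = action_dist.log_prob(actions)
p_loss = -1.0 * torch.mean(logprobs)   # same form as the beta == 0 branch above
v_loss = 0.0
total_loss = p_loss + 1.0 * v_loss     # vf_coeff is irrelevant when v_loss == 0
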
Example #3
def build_vtrace_loss(policy, model, dist_class, train_batch):
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                list(output_hidden_shape),
                                                dim=1)
        unpacked_outputs = torch.split(model_out,
                                       list(output_hidden_shape),
                                       dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                output_hidden_shape,
                                                dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
    values = model.value_function()

    if policy.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(actions,
                                                                    dim=1)

    # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
    drop_last = policy.config["vtrace_drop_last_ts"]
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=drop_last),
        actions_logp=_make_time_major(action_dist.logp(actions),
                                      drop_last=drop_last),
        actions_entropy=_make_time_major(action_dist.entropy(),
                                         drop_last=drop_last),
        dones=_make_time_major(dones, drop_last=drop_last),
        behaviour_action_logp=_make_time_major(behaviour_action_logp,
                                               drop_last=drop_last),
        behaviour_logits=_make_time_major(unpacked_behaviour_logits,
                                          drop_last=drop_last),
        target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
        discount=policy.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=drop_last),
        values=_make_time_major(values, drop_last=drop_last),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=drop_last),
        config=policy.config,
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        entropy_coeff=policy.entropy_coeff,
        clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"],
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        values,
        drop_last=policy.config["vtrace"] and drop_last,
    )
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]),
        torch.reshape(values_batched, [-1]))

    return loss.total_loss
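
Everything fed into VTraceLoss above first goes through make_time_major, which turns the flat [B * T] rollout tensors into time-major [T, B] tensors (and optionally drops the last timestep). A rough sketch of that reshaping; this is not RLlib's helper — it ignores padding and assumes the flat batch is laid out sequence-by-sequence with equal lengths:

import torch

def to_time_major(flat, seq_len, drop_last=False):
    """Reshape a flat [B * T, ...] tensor into time-major [T, B, ...]."""
    B = flat.shape[0] // seq_len
    x = flat.reshape(B, seq_len, *flat.shape[1:])  # [B, T, ...]
    x = x.transpose(0, 1)                          # [T, B, ...]
    return x[:-1] if drop_last else x              # [(T|T-1), B, ...]

rewards = torch.arange(6.0)                                     # pretend B=2 sequences of T=3 steps
print(to_time_major(rewards, seq_len=3).shape)                  # torch.Size([3, 2])
print(to_time_major(rewards, seq_len=3, drop_last=True).shape)  # torch.Size([2, 2])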
Example #4
    def loss(self, model: ModelV2, dist_class: Type[ActionDistribution],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for Proximal Policy Objective.

        Args:
            model: The Model to calculate the loss for.
            dist_class: The action distr. class.
            train_batch: The training data.

        Returns:
            The PPO loss tensor given the input batch.
        """

        logits, state = model(train_batch)
        curr_action_dist = dist_class(logits, model)

        # RNN case: Mask away 0-padded chunks at end of time axis.
        if state:
            B = len(train_batch[SampleBatch.SEQ_LENS])
            max_seq_len = logits.shape[0] // B
            mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                 max_seq_len,
                                 time_major=model.is_time_major())
            mask = torch.reshape(mask, [-1])
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        # non-RNN case: No masking.
        else:
            mask = None
            reduce_mean_valid = torch.mean

        prev_action_dist = dist_class(
            train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

        logp_ratio = torch.exp(
            curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
            train_batch[SampleBatch.ACTION_LOGP])

        # Only calculate kl loss if necessary (kl-coeff > 0.0).
        if self.config["kl_coeff"] > 0.0:
            action_kl = prev_action_dist.kl(curr_action_dist)
            mean_kl_loss = reduce_mean_valid(action_kl)
        else:
            mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device)

        curr_entropy = curr_action_dist.entropy()
        mean_entropy = reduce_mean_valid(curr_entropy)

        surrogate_loss = torch.min(
            train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
            train_batch[Postprocessing.ADVANTAGES] *
            torch.clamp(logp_ratio, 1 - self.config["clip_param"],
                        1 + self.config["clip_param"]))
        mean_policy_loss = reduce_mean_valid(-surrogate_loss)

        # Compute a value function loss.
        if self.config["use_critic"]:
            prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
            value_fn_out = model.value_function()
            vf_loss1 = torch.pow(
                value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
            vf_clipped = prev_value_fn_out + torch.clamp(
                value_fn_out - prev_value_fn_out,
                -self.config["vf_clip_param"], self.config["vf_clip_param"])
            vf_loss2 = torch.pow(
                vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
            vf_loss = torch.max(vf_loss1, vf_loss2)
            mean_vf_loss = reduce_mean_valid(vf_loss)
        # Ignore the value function.
        else:
            vf_loss = mean_vf_loss = 0.0

        total_loss = reduce_mean_valid(-surrogate_loss +
                                       self.config["vf_loss_coeff"] * vf_loss -
                                       self.entropy_coeff * curr_entropy)

        # Add mean_kl_loss (already processed through `reduce_mean_valid`),
        # if necessary.
        if self.config["kl_coeff"] > 0.0:
            total_loss += self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["vf_explained_var"] = explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS], model.value_function())
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["mean_kl_loss"] = mean_kl_loss

        return total_loss
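
The heart of this loss is the clipped surrogate objective: per timestep it takes min(r * A, clip(r, 1 - eps, 1 + eps) * A), where r is the probability ratio exp(logp_new - logp_old) (held in logp_ratio above despite the name) and A is the advantage; the policy loss is the negated (masked) mean. A standalone sketch with made-up ratios and advantages:

import torch

clip_param = 0.3   # stands in for config["clip_param"]; value chosen for the sketch

ratio = torch.tensor([1.0, 1.6, 0.5])        # pi_new(a|s) / pi_old(a|s)
advantages = torch.tensor([2.0, 1.0, -1.0])

surrogate = torch.min(
    advantages * ratio,
    advantages * torch.clamp(ratio, 1 - clip_param, 1 + clip_param),
)
# torch.min keeps the element-wise smaller (more pessimistic) of the
# unclipped and clipped terms, so overly large policy updates get no extra credit.
policy_loss = -torch.mean(surrogate)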
Example #5
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        """Constructs the loss for APPO.

        With IS modifications and V-trace for Advantage Estimation.

        Args:
            model (ModelV2): The Model to calculate the loss for.
            dist_class (Type[ActionDistribution]): The action distr. class.
            train_batch (SampleBatch): The training data.

        Returns:
            Union[TensorType, List[TensorType]]: A single loss tensor or a list
                of loss tensors.
        """
        target_model = self.target_models[model]

        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)

        if isinstance(self.action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [self.action_space.n]
        elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = self.action_space.nvec.astype(np.int32)
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        def _make_time_major(*args, **kwargs):
            return make_time_major(
                self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs
            )

        actions = train_batch[SampleBatch.ACTIONS]
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

        target_model_out, _ = target_model(train_batch)

        prev_action_dist = dist_class(behaviour_logits, model)
        values = model.value_function()
        values_time_major = _make_time_major(values)

        drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]

        if self.is_recurrent():
            max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
            mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
            mask = torch.reshape(mask, [-1])
            mask = _make_time_major(mask, drop_last=drop_last)
            num_valid = torch.sum(mask)

            def reduce_mean_valid(t):
                return torch.sum(t[mask]) / num_valid

        else:
            reduce_mean_valid = torch.mean

        if self.config["vtrace"]:
            logger.debug(
                "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})"
            )

            old_policy_behaviour_logits = target_model_out.detach()
            old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

            if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
                unpacked_behaviour_logits = torch.split(
                    behaviour_logits, list(output_hidden_shape), dim=1
                )
                unpacked_old_policy_behaviour_logits = torch.split(
                    old_policy_behaviour_logits, list(output_hidden_shape), dim=1
                )
            else:
                unpacked_behaviour_logits = torch.chunk(
                    behaviour_logits, output_hidden_shape, dim=1
                )
                unpacked_old_policy_behaviour_logits = torch.chunk(
                    old_policy_behaviour_logits, output_hidden_shape, dim=1
                )

            # Prepare actions for loss.
            loss_actions = (
                actions if is_multidiscrete else torch.unsqueeze(actions, dim=1)
            )

            # Prepare KL for loss.
            action_kl = _make_time_major(
                old_policy_action_dist.kl(action_dist), drop_last=drop_last
            )

            # Compute vtrace on the CPU for better perf.
            vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=_make_time_major(
                    unpacked_behaviour_logits, drop_last=drop_last
                ),
                target_policy_logits=_make_time_major(
                    unpacked_old_policy_behaviour_logits, drop_last=drop_last
                ),
                actions=torch.unbind(
                    _make_time_major(loss_actions, drop_last=drop_last), dim=2
                ),
                discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float())
                * self.config["gamma"],
                rewards=_make_time_major(rewards, drop_last=drop_last),
                values=values_time_major[:-1] if drop_last else values_time_major,
                bootstrap_value=values_time_major[-1],
                dist_class=TorchCategorical if is_multidiscrete else dist_class,
                model=model,
                clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
                clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
            )

            actions_logp = _make_time_major(
                action_dist.logp(actions), drop_last=drop_last
            )
            prev_actions_logp = _make_time_major(
                prev_action_dist.logp(actions), drop_last=drop_last
            )
            old_policy_actions_logp = _make_time_major(
                old_policy_action_dist.logp(actions), drop_last=drop_last
            )
            is_ratio = torch.clamp(
                torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0
            )
            logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
            self._is_ratio = is_ratio

            advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
            surrogate_loss = torch.min(
                advantages * logp_ratio,
                advantages
                * torch.clamp(
                    logp_ratio,
                    1 - self.config["clip_param"],
                    1 + self.config["clip_param"],
                ),
            )

            mean_kl_loss = reduce_mean_valid(action_kl)
            mean_policy_loss = -reduce_mean_valid(surrogate_loss)

            # The value function loss.
            value_targets = vtrace_returns.vs.to(values_time_major.device)
            if drop_last:
                delta = values_time_major[:-1] - value_targets
            else:
                delta = values_time_major - value_targets
            mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

            # The entropy loss.
            mean_entropy = reduce_mean_valid(
                _make_time_major(action_dist.entropy(), drop_last=drop_last)
            )

        else:
            logger.debug("Using PPO surrogate loss (vtrace=False)")

            # Prepare KL for Loss
            action_kl = _make_time_major(prev_action_dist.kl(action_dist))

            actions_logp = _make_time_major(action_dist.logp(actions))
            prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
            logp_ratio = torch.exp(actions_logp - prev_actions_logp)

            advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
            surrogate_loss = torch.min(
                advantages * logp_ratio,
                advantages
                * torch.clamp(
                    logp_ratio,
                    1 - self.config["clip_param"],
                    1 + self.config["clip_param"],
                ),
            )

            mean_kl_loss = reduce_mean_valid(action_kl)
            mean_policy_loss = -reduce_mean_valid(surrogate_loss)

            # The value function loss.
            value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS])
            delta = values_time_major - value_targets
            mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

            # The entropy loss.
            mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy()))

        # The summed weighted loss
        total_loss = (
            mean_policy_loss
            + mean_vf_loss * self.config["vf_loss_coeff"]
            - mean_entropy * self.entropy_coeff
        )

        # Optional additional KL Loss
        if self.config["use_kl_loss"]:
            total_loss += self.kl_coeff * mean_kl_loss

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["total_loss"] = total_loss
        model.tower_stats["mean_policy_loss"] = mean_policy_loss
        model.tower_stats["mean_kl_loss"] = mean_kl_loss
        model.tower_stats["mean_vf_loss"] = mean_vf_loss
        model.tower_stats["mean_entropy"] = mean_entropy
        model.tower_stats["value_targets"] = value_targets
        model.tower_stats["vf_explained_var"] = explained_variance(
            torch.reshape(value_targets, [-1]),
            torch.reshape(
                values_time_major[:-1] if drop_last else values_time_major, [-1]
            ),
        )

        return total_loss
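
Both the PPO and APPO losses above reduce per-timestep terms with reduce_mean_valid, which in the recurrent case averages only over timesteps that are not zero-padding. A small sketch of that masked mean with a made-up mask:

import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0])
# True for real timesteps, False for zero-padded ones (illustrative mask).
mask = torch.tensor([True, True, False, True])
num_valid = torch.sum(mask)

def reduce_mean_valid(t):
    # Average only over the valid (unpadded) entries.
    return torch.sum(t[mask]) / num_valid

print(reduce_mean_valid(values))   # (1 + 2 + 4) / 3 ≈ 2.33
print(torch.mean(values))          # a plain mean (2.5) would count the padding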