Example #1
def stats(policy, train_batch):
    """Returns a dict with important loss stats."""
    values_batched = _make_time_major(
        policy,
        train_batch.get("seq_lens"),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"])

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "policy_loss": policy.loss.pi_loss,
        "entropy": policy.loss.entropy,
        "var_gnorm": tf.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy.loss.vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy.loss.value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
    }

    if policy.config["vtrace"]:
        # Mean and variance of the V-trace importance-sampling ratios.
        is_stat_mean, is_stat_var = tf.nn.moments(policy.loss.is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy.loss.mean_kl
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict
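All five examples delegate the batch-major to time-major reshaping to RLlib's internal `_make_time_major` helper. Below is a minimal sketch of what that helper does, simplified to a single fixed sequence length (the real helper also derives the layout from the `seq_lens` tensor and the model); the function name is hypothetical:

import tensorflow as tf

def make_time_major_sketch(tensor, seq_len, drop_last=False):
    # Fold the flat batch-major tensor [B*T, ...] into [B, T, ...].
    B = tf.shape(tensor)[0] // seq_len
    folded = tf.reshape(
        tensor, tf.concat([[B, seq_len], tf.shape(tensor)[1:]], axis=0))
    # Swap the batch and time axes -> [T, B, ...].
    res = tf.transpose(
        folded, perm=[1, 0] + list(range(2, folded.shape.ndims)))
    if drop_last:
        # V-trace computes targets only up to t-1, so the last timestep
        # is dropped to keep the values aligned with the targets.
        return res[:-1]
    return res

This is why the examples gate `drop_last` on the `vtrace` (and, in newer versions, `vtrace_drop_last_ts`) config keys: with both set, the [B*T, ...] value predictions come back as [T-1, B, ...], matching the V-trace value targets.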
Example #2
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            values_batched = _make_time_major(
                self,
                train_batch.get(SampleBatch.SEQ_LENS),
                self.model.value_function(),
                drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
            )

            stats_dict = {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "total_loss": self._total_loss,
                "policy_loss": self._mean_policy_loss,
                "entropy": self._mean_entropy,
                "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
                "vf_loss": self._mean_vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self._value_targets, [-1]),
                    tf.reshape(values_batched, [-1]),
                ),
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
            }

            if self.config["vtrace"]:
                is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
                stats_dict["mean_IS"] = is_stat_mean
                stats_dict["var_IS"] = is_stat_var

            if self.config["use_kl_loss"]:
                stats_dict["kl"] = self._mean_kl_loss
                stats_dict["KL_Coeff"] = self.kl_coeff

            return stats_dict
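The `vf_explained_var` entry in both versions comes from RLlib's `explained_variance` utility. A minimal sketch of the quantity it reports, assuming RLlib's convention of clipping the result at -1 (the name below is hypothetical):

import tensorflow as tf

def explained_variance_sketch(y, pred):
    # 1 - Var[y - pred] / Var[y]: 1.0 means the value function predicts
    # the targets perfectly; 0.0 means it is no better than a constant.
    _, y_var = tf.nn.moments(y, axes=[0])
    _, diff_var = tf.nn.moments(y - pred, axes=[0])
    return tf.maximum(-1.0, 1.0 - diff_var / y_var)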
Example #3
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch already used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = _make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"]
        and policy.config["vtrace_drop_last_ts"],
    )

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy._value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }

    if policy.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy._mean_kl_loss
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict
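For orientation, in the older Ray releases this function-style code comes from, `stats` is not called directly; it is handed to the policy template as the `stats_fn` hook. A hedged sketch with a placeholder loss (the real module also wires a model factory, the APPO surrogate loss, and other hooks):

import tensorflow as tf
from ray.rllib.policy.tf_policy_template import build_tf_policy

def placeholder_loss(policy, model, dist_class, train_batch):
    # Stand-in for the real APPO surrogate loss; illustration only.
    return tf.constant(0.0)

AsyncPPOTFPolicy = build_tf_policy(
    name="AsyncPPOTFPolicy",
    loss_fn=placeholder_loss,
    stats_fn=stats,  # the function from Example #3
)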
Example #4
def make_time_major(*args, **kw):
    # `policy` and `train_batch` are captured from the enclosing loss
    # function's scope; see the sketch after Example #5.
    return _make_time_major(policy, train_batch.get("seq_lens"), *args, **kw)
Example #5
def make_time_major(*args, **kw):
    # Same closure as Example #4, but using the SampleBatch.SEQ_LENS
    # constant instead of the raw "seq_lens" string key.
    return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                            *args, **kw)
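Examples #4 and #5 are closures excerpted from inside a loss function: `policy` and `train_batch` are free variables captured from the enclosing scope. A sketch of that context (the call site shown is illustrative):

def loss_fn(policy, model, dist_class, train_batch):
    # `policy` and `train_batch` are in scope here, so the closure lets
    # call sites pass only the tensor that needs to become time-major.
    def make_time_major(*args, **kw):
        return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                                *args, **kw)

    # Illustrative call site: make the [B*T]-shaped rewards time-major.
    rewards_time_major = make_time_major(train_batch[SampleBatch.REWARDS])
    # ... the actual loss computation follows.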