def stats(policy, train_batch):
    # Bring the value estimates into time-major shape so they line up with
    # the value targets computed in the loss.
    values_batched = _make_time_major(
        policy,
        train_batch.get("seq_lens"),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"])

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "policy_loss": policy.loss.pi_loss,
        "entropy": policy.loss.entropy,
        "var_gnorm": tf.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy.loss.vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy.loss.value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
    }

    if policy.config["vtrace"]:
        # Mean/variance of the importance-sampling ratios over time and batch.
        is_stat_mean, is_stat_var = tf.nn.moments(policy.loss.is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy.loss.mean_kl
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict

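# A hypothetical, minimal illustration (not part of the original module) of
# the IS-ratio statistics computed in the `vtrace` branch above: a single
# `tf.nn.moments` call over axes [0, 1] of a time-major [T, B] tensor yields
# both the mean and the variance of the importance-sampling ratios. The
# tensor below is made-up demo data.
def _demo_is_moments():
    is_ratio = tf.random.uniform([50, 4], minval=0.5, maxval=2.0)  # [T, B]
    mean_is, var_is = tf.nn.moments(is_ratio, axes=[0, 1])
    return mean_is, var_is
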
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Returns a dict of training stats for this (APPO) policy."""
    values_batched = _make_time_major(
        self,
        train_batch.get(SampleBatch.SEQ_LENS),
        self.model.value_function(),
        drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
    )

    stats_dict = {
        "cur_lr": tf.cast(self.cur_lr, tf.float64),
        "total_loss": self._total_loss,
        "policy_loss": self._mean_policy_loss,
        "entropy": self._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
        "vf_loss": self._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(self._value_targets, [-1]),
            tf.reshape(values_batched, [-1]),
        ),
        "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
    }

    if self.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if self.config["use_kl_loss"]:
        stats_dict["kl"] = self._mean_kl_loss
        stats_dict["KL_Coeff"] = self.kl_coeff

    return stats_dict

def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = _make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"] and policy.config["vtrace_drop_last_ts"],
    )

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy._value_targets, [-1]),
            tf.reshape(values_batched, [-1]),
        ),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }

    if policy.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy._mean_kl_loss
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict

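# A minimal sketch of what the `explained_variance` helper used above is
# assumed to compute: 1 - Var[y - pred] / Var[y], clipped below at -1.0.
# This is a hypothetical stand-in, shown only to make the `vf_explained_var`
# metric concrete: values near 1.0 mean the value function tracks its
# targets well; values <= 0 mean it explains nothing.
def explained_variance_sketch(y, pred):
    _, y_var = tf.nn.moments(y, axes=[0])
    _, diff_var = tf.nn.moments(y - pred, axes=[0])
    # Clipping keeps the metric bounded for very poor fits (assumed behavior).
    return tf.maximum(-1.0, 1.0 - diff_var / y_var)
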
# The two variants below are closures defined inside the loss function: they
# capture `policy` and `train_batch` from the enclosing scope. The older one
# uses the raw "seq_lens" string key; the newer one uses the
# SampleBatch.SEQ_LENS constant.
def make_time_major(*args, **kw):
    return _make_time_major(policy, train_batch.get("seq_lens"), *args, **kw)


def make_time_major(*args, **kw):
    return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw)

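# Hypothetical sketch of the batch-major -> time-major reshape that
# `_make_time_major` is assumed to perform (names and details are
# illustrative, not the exact RLlib code): fold a zero-padded [B * T, ...]
# tensor into [B, T, ...], then swap the first two axes to get [T, B, ...].
def make_time_major_sketch(tensor, max_seq_len, drop_last=False):
    new_shape = tf.concat([[-1, max_seq_len], tf.shape(tensor)[1:]], axis=0)
    folded = tf.reshape(tensor, new_shape)  # [B, T, ...]
    perm = [1, 0] + list(range(2, len(tensor.shape) + 1))
    time_major = tf.transpose(folded, perm=perm)  # [T, B, ...]
    if drop_last:
        # With V-trace, the final timestep has no bootstrapped target, which
        # is what the `vtrace_drop_last_ts` flag above controls.
        time_major = time_major[:-1]
    return time_major

# Example: an [8, 5] batch with max_seq_len=4 folds into 2 sequences, so
# make_time_major_sketch(tf.zeros([8, 5]), max_seq_len=4) has shape [4, 2, 5].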