def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Collect this policy's loss statistics into a dict.

    Every loss-related entry is produced by stacking the list returned by
    ``self.get_tower_stats(<name>)`` into one tensor and taking its mean
    (presumably one value per tower -- confirm against
    ``get_tower_stats``). The assembled dict is passed through
    ``convert_to_numpy`` before being returned.

    Args:
        train_batch: Part of the expected signature; not read by this
            method.

    Returns:
        The result of ``convert_to_numpy`` applied to a dict with the keys
        ``cur_lr``, ``total_loss``, ``policy_loss``, ``entropy``,
        ``entropy_coeff``, ``var_gnorm``, ``vf_loss`` and
        ``vf_explained_var``.
    """

    def tower_mean(stat_name):
        # Stack the list of values for `stat_name`, then reduce to a mean.
        stacked = torch.stack(self.get_tower_stats(stat_name))
        return torch.mean(stacked)

    raw_stats = {
        "cur_lr": self.cur_lr,
        "total_loss": tower_mean("total_loss"),
        # Reported as "policy_loss" but read from the "pi_loss" tower stat.
        "policy_loss": tower_mean("pi_loss"),
        "entropy": tower_mean("mean_entropy"),
        "entropy_coeff": self.entropy_coeff,
        # global_norm over whatever the model reports as trainable variables.
        "var_gnorm": global_norm(self.model.trainable_variables()),
        "vf_loss": tower_mean("vf_loss"),
        "vf_explained_var": tower_mean("vf_explained_var"),
    }
    return convert_to_numpy(raw_stats)
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, Any]:
    """Build a dict of loss statistics for ``policy``.

    Loss-related entries are computed by stacking the list returned by
    ``policy.get_tower_stats(<name>)`` and taking the mean of the stacked
    tensor (presumably one value per tower -- confirm against
    ``get_tower_stats``). Values are returned as produced, with no
    further conversion.

    Args:
        policy: Object providing ``cur_lr``, ``entropy_coeff``,
            ``get_tower_stats`` and ``model.trainable_variables()``.
        train_batch: Part of the expected signature; not read by this
            function.

    Returns:
        A dict with the keys ``cur_lr``, ``total_loss``, ``policy_loss``,
        ``entropy``, ``entropy_coeff``, ``var_gnorm``, ``vf_loss`` and
        ``vf_explained_var``, inserted in that order.
    """
    # (reported key, tower-stat key) pairs, grouped so that the insertion
    # order of the resulting dict matches the reported-key order above.
    leading_pairs = (
        ("total_loss", "total_loss"),
        ("policy_loss", "pi_loss"),
        ("entropy", "mean_entropy"),
    )
    trailing_pairs = (
        ("vf_loss", "vf_loss"),
        ("vf_explained_var", "vf_explained_var"),
    )

    collected = {}
    collected["cur_lr"] = policy.cur_lr
    for report_key, tower_key in leading_pairs:
        per_tower = policy.get_tower_stats(tower_key)
        collected[report_key] = torch.mean(torch.stack(per_tower))

    collected["entropy_coeff"] = policy.entropy_coeff
    collected["var_gnorm"] = global_norm(policy.model.trainable_variables())

    for report_key, tower_key in trailing_pairs:
        per_tower = policy.get_tower_stats(tower_key)
        collected[report_key] = torch.mean(torch.stack(per_tower))

    return collected
def stats(policy: Policy, train_batch: SampleBatch):
    """Stats function for APPO: return a dict of loss statistics.

    Loss-related entries are computed by stacking the list returned by
    ``policy.get_tower_stats(<name>)`` and taking the mean of the stacked
    tensor (presumably one value per tower -- confirm against
    ``get_tower_stats``).

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for
            training. Not read by this function.

    Returns:
        Dict[str, TensorType]: The stats dict. It always holds ``cur_lr``,
        ``total_loss``, ``policy_loss``, ``entropy``, ``entropy_coeff``,
        ``var_gnorm``, ``vf_loss`` and ``vf_explained_var``. ``mean_IS``
        and ``var_IS`` are added when ``policy.config["vtrace"]`` is
        truthy; ``kl`` and ``KL_Coeff`` are added when
        ``policy.config["use_kl_loss"]`` is truthy.
    """

    def averaged(stat_name):
        # Stack the list of values for `stat_name`, then reduce to a mean.
        return torch.mean(torch.stack(policy.get_tower_stats(stat_name)))

    result = {
        "cur_lr": policy.cur_lr,
        "total_loss": averaged("total_loss"),
        "policy_loss": averaged("mean_policy_loss"),
        "entropy": averaged("mean_entropy"),
        "entropy_coeff": policy.entropy_coeff,
        "var_gnorm": global_norm(policy.model.trainable_variables()),
        "vf_loss": averaged("mean_vf_loss"),
        "vf_explained_var": averaged("vf_explained_var"),
    }

    if policy.config["vtrace"]:
        # Mean and variance of policy._is_ratio, both reduced over
        # dims 0 and 1.
        ratio = policy._is_ratio
        result.update({
            "mean_IS": torch.mean(ratio, [0, 1]),
            "var_IS": torch.var(ratio, [0, 1]),
        })

    if policy.config["use_kl_loss"]:
        result.update({
            "kl": averaged("mean_kl_loss"),
            "KL_Coeff": policy.kl_coeff,
        })

    return result