def pg_loss_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Report the policy-gradient loss averaged over all towers.

    Args:
        policy (Policy): Policy whose towers recorded a "policy_loss" stat
            (read through ``policy.get_tower_stats``).
        train_batch (SampleBatch): The training data. Not read here; it is
            accepted so the signature matches the stats-function convention.

    Returns:
        Dict[str, TensorType]: A single-entry dict mapping "policy_loss" to the
            mean of the per-tower values.
    """
    # One entry per tower. torch.stack requires every entry to share the
    # same shape (presumably scalar losses -- confirm against the loss fn).
    per_tower = policy.get_tower_stats("policy_loss")
    averaged = torch.stack(per_tower).mean()
    return {"policy_loss": averaged}
def stats_fn(policy: Policy, batch: SampleBatch) -> Dict[str, TensorType]:
    """Report the total loss averaged over all towers.

    Args:
        policy (Policy): Policy whose towers recorded a "loss" stat
            (read through ``policy.get_tower_stats``).
        batch (SampleBatch): The training data. Not read here; it is accepted
            so the signature matches the stats-function convention.

    Returns:
        Dict[str, TensorType]: A single-entry dict mapping "loss" to the mean
            of the per-tower values.
    """
    # torch.stack needs same-shaped entries (presumably one scalar per tower).
    tower_losses = policy.get_tower_stats("loss")
    mean_loss = torch.stack(tower_losses).mean()
    return {"loss": mean_loss}