from typing import List, Type, Union

import torch

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import TensorType


def spl_torch_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Supervised (SPL) loss function for learning actions and/or rewards.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Pass the training data through our model to get distribution parameters.
    dist_inputs, _ = model.from_batch(train_batch)

    # Create an action distribution object.
    action_dist = dist_class(dist_inputs, model)

    if policy.config["explore"]:
        # Workaround for a bug in TorchCategorical, which modifies
        # `dist_inputs` in place through `action_dist`: run the exploration
        # step first, then rebuild the distribution from the inputs.
        _, _ = policy.exploration.get_exploration_action(
            action_distribution=action_dist,
            timestep=policy.global_timestep,
            explore=policy.config["explore"],
        )
        action_dist = dist_class(dist_inputs, policy.model)

    targets = []
    if policy.config["learn_action"]:
        targets.append(train_batch[SampleBatch.ACTIONS])
    if policy.config["learn_reward"]:
        targets.append(train_batch[SampleBatch.REWARDS])
    assert len(targets) > 0, (
        "In config, use learn_action=True and/or learn_reward=True to "
        "specify which target to use in supervised learning."
    )
    targets = torch.cat(targets, dim=0)

    # Save the loss in the policy object for the spl_stats below.
    policy.spl_loss = policy.config["loss_fn"](action_dist.dist.probs, targets)
    policy.entropy = action_dist.dist.entropy().mean()
    return policy.spl_loss
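
# --- Usage sketch ---
# A minimal sketch of how this loss could be wired into a trainable policy,
# assuming RLlib 1.x's `build_policy_class` API. The body of `spl_stats`
# (referenced by the "for the spl_stats below" comment) is not shown in this
# section, so the version here is a hypothetical reconstruction that simply
# reports the tensors `spl_torch_loss` stashes on the policy object.

from typing import Dict

from ray.rllib.policy.policy_template import build_policy_class


def spl_stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Report the loss and entropy saved on the policy by spl_torch_loss."""
    return {
        "spl_loss": policy.spl_loss,
        "entropy": policy.entropy,
    }


# Hypothetical registration; the policy class name is illustrative only.
SPLTorchPolicy = build_policy_class(
    name="SPLTorchPolicy",
    framework="torch",
    loss_fn=spl_torch_loss,
    stats_fn=spl_stats,
)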