def reduce(self, tensor: storch.Tensor, detach_weights=True):
    """Reduce `tensor` over this plate dimension using the plate's weighting.

    Args:
        tensor: the storch.Tensor to reduce over this plate.
        detach_weights: if True, gradients do not flow through the weights.

    Returns:
        The weighted reduction of `tensor` over this plate.
    """
    weighting = self.weight.detach() if detach_weights else self.weight

    if self.n == 1:
        # Degenerate plate of size one: simply scale by the weight.
        return storch.reduce(lambda x: x * weighting, self.name)(tensor)

    if weighting.ndim == 0:
        # Scalar weight: sum over the plate first, then scale
        # (usually this amounts to taking the mean).
        return storch.sum(tensor, self) * weighting

    if weighting.ndim == 1:
        # One weight per plate entry, independent of the other batch
        # dimensions: append singleton dims so the weight vector
        # broadcasts against the plate dimension of `tensor`.
        dim = tensor.get_plate_dim_index(self.name)
        weighting = weighting[(...,) + (None,) * (tensor.ndim - dim - 1)]
        return storch.sum(tensor * weighting, self)

    # Remaining case: the weight is itself a batched storch.Tensor.
    # All parent plates must be present on `tensor` for the broadcast
    # to line up; fail loudly if one is missing.
    for parent_plate in self.parents:
        if parent_plate not in tensor.plates:
            raise ValueError(
                "Plate missing when reducing tensor: " + parent_plate.name
            )
    return storch.sum(tensor * weighting, self)
def logsumexp(tensor: storch.Tensor, dims: _indices) -> storch.Tensor:
    """Take the log-sum-exp of `tensor` over the given dimensions.

    Args:
        tensor: the storch.Tensor to reduce.
        dims: the (plate) dimensions to reduce over.

    Returns:
        A storch.Tensor with the given dimensions reduced away.
    """
    indices, reduced_plates = _convert_indices(tensor, dims)
    reducer = storch.reduce(torch.logsumexp, plates=reduced_plates)
    return reducer(tensor, indices)
def mean(tensor: storch.Tensor, dims: _indices) -> storch.Tensor:
    """Take the mean of `tensor` over the given dimensions.

    Args:
        tensor: the storch.Tensor to reduce.
        dims: the (plate) dimensions to average over.

    Returns:
        A storch.Tensor with the given dimensions reduced away.
    """
    indices, reduced_plates = _convert_indices(tensor, dims)
    reducer = storch.reduce(torch.mean, plates=reduced_plates)
    return reducer(tensor, indices)