예제 #1
0
    def reduce(self, tensor: storch.Tensor, detach_weights=True):
        """Weight ``tensor`` with this plate's weight and reduce over this plate.

        The strategy is picked from ``self.n`` and the number of dimensions
        of the weight:

        * ``self.n == 1``: multiply by the weight through ``storch.reduce``.
        * 0-dimensional weight: sum over the plate, then scale the sum.
        * 1-dimensional weight: align the weight with the plate's dimension
          in ``tensor``, multiply, then sum over the plate.
        * anything else: check that all of ``self.parents`` occur in
          ``tensor.plates``, multiply, then sum over the plate.

        Args:
            tensor: The tensor to reduce.
            detach_weights: If true, use ``self.weight.detach()`` rather than
                ``self.weight`` itself.

        Returns:
            Whatever ``storch.reduce`` / ``storch.sum`` produce for the
            weighted tensor.

        Raises:
            ValueError: In the last case above, when one of ``self.parents``
                is not among ``tensor.plates``.
        """
        weights = self.weight.detach() if detach_weights else self.weight

        # Plate with a single element: only the multiplication is applied here.
        if self.n == 1:
            apply_weight = storch.reduce(lambda x: x * weights, self.name)
            return apply_weight(tensor)

        # One weight shared by the whole plate: sum first, scale afterwards
        # (usually this amounts to taking the mean).
        if weights.ndim == 0:
            plate_total = storch.sum(tensor, self)
            return plate_total * weights

        if weights.ndim == 1:
            # One weight per plate element, independent of other dimensions.
            # Add one size-1 axis for every tensor dimension that comes after
            # the plate's dimension so the product broadcasts along the plate.
            plate_dim = tensor.get_plate_dim_index(self.name)
            trailing_axes = tensor.ndim - plate_dim - 1
            weights = weights[(...,) + (None,) * trailing_axes]
        else:
            # Weight with further dimensions; assumed to be a storch.Tensor.
            # Every parent plate has to be present on the tensor.
            for required in self.parents:
                if required not in tensor.plates:
                    raise ValueError(
                        "Plate missing when reducing tensor: " + required.name
                    )

        return storch.sum(tensor * weights, self)
예제 #2
0
def logsumexp(tensor: storch.Tensor, dims: _indices) -> storch.Tensor:
    """Apply ``torch.logsumexp`` to ``tensor`` over ``dims`` via ``storch.reduce``.

    ``dims`` is first translated by ``_convert_indices`` into the pair
    (argument handed to ``torch.logsumexp``, value for ``plates=``).
    NOTE(review): presumably these are the dimension indices and the plates
    that disappear in the reduction -- confirm against ``_convert_indices``.
    """
    dim_indices, plates_to_reduce = _convert_indices(tensor, dims)
    reducer = storch.reduce(torch.logsumexp, plates=plates_to_reduce)
    return reducer(tensor, dim_indices)
예제 #3
0
def mean(tensor: storch.Tensor, dims: _indices) -> storch.Tensor:
    """Apply ``torch.mean`` to ``tensor`` over ``dims`` via ``storch.reduce``.

    ``dims`` is first translated by ``_convert_indices`` into the pair
    (argument handed to ``torch.mean``, value for ``plates=``).
    NOTE(review): presumably these are the dimension indices and the plates
    that disappear in the reduction -- confirm against ``_convert_indices``.
    """
    dim_indices, plates_to_reduce = _convert_indices(tensor, dims)
    reducer = storch.reduce(torch.mean, plates=plates_to_reduce)
    return reducer(tensor, dim_indices)