Example #1
    def reduce(self, tensor: storch.Tensor, detach_weights=True):
        plate_weighting = self.weight
        if detach_weights:
            plate_weighting = self.weight.detach()
        # Case: The plate has a single element, so there is nothing to sum over; just apply the weight
        if self.n == 1:
            return storch.reduce(lambda x: x * plate_weighting, self.name)(tensor)
        # Case: The weight is a single number. First sum, then multiply with the weight (usually taking the mean)
        elif plate_weighting.ndim == 0:
            return storch.sum(tensor, self) * plate_weighting

        # Case: There is a weight for each plate which is not dependent on the other batch dimensions
        elif plate_weighting.ndim == 1:
            index = tensor.get_plate_dim_index(self.name)
            plate_weighting = plate_weighting[
                (...,) + (None,) * (tensor.ndim - index - 1)
            ]
            weighted_tensor = tensor * plate_weighting
            return storch.sum(weighted_tensor, self)

        # Case: The weight is a multi-dimensional tensor that depends on parent plates. Assumes it is a storch.Tensor
        else:
            for parent_plate in self.parents:
                if parent_plate not in tensor.plates:
                    raise ValueError(
                        "Plate missing when reducing tensor: " + parent_plate.name
                    )
            weighted_tensor = tensor * plate_weighting
            return storch.sum(weighted_tensor, self)
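
The ndim == 1 branch relies on a broadcasting trick: the weight vector is padded with trailing singleton axes so it lines up with the plate dimension and then broadcasts over every later dimension before the plate is summed out. A minimal plain-PyTorch sketch of that step (the shapes, plate position and uniform weight are made up for illustration):

import torch

# Hypothetical layout: a plate dimension of size 4 at index 1, followed by two event dimensions.
tensor = torch.randn(3, 4, 5, 6)
plate_weighting = torch.full((4,), 0.25)  # e.g. a uniform weighting, i.e. a mean over the plate
index = 1  # position of the plate dimension inside the tensor

# Same indexing as in reduce(): append one None (new axis) per dimension after the plate,
# so the 1-D weight broadcasts over everything that follows the plate dimension.
aligned = plate_weighting[(...,) + (None,) * (tensor.ndim - index - 1)]
print(aligned.shape)               # torch.Size([4, 1, 1])

weighted = tensor * aligned        # weight applied along the plate dimension
reduced = weighted.sum(dim=index)  # roughly what storch.sum(weighted_tensor, self) does for this plate
print(reduced.shape)               # torch.Size([3, 5, 6])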
Example #2
    def undo_unique(self, unique_tensor: storch.Tensor) -> torch.Tensor:
        """
        Convert the unique tensor back to the non-unique format, then add the old plates back in
        # TODO: Make sure self.shrunken_plates is added
        # TODO: What if unique_tensor contains new plates after the unique?
        :param unique_tensor:
        :return:
        """
        plate_idx = unique_tensor.get_plate_dim_index(self.name)
        with storch.ignore_wrapping():
            dim_swapped = unique_tensor.transpose(
                plate_idx, unique_tensor.plate_dims - 1
            )
            fl_selected = torch.index_select(
                dim_swapped, dim=0, index=self.inv_indexing
            )
            selected = fl_selected.reshape(
                tuple(map(lambda p: p.n, self.shrunken_plates)) + fl_selected.shape[1:]
            )
            return storch.Tensor(
                selected,
                [unique_tensor],
                self.shrunken_plates + unique_tensor.plates,
                "undo_unique_" + unique_tensor.name,
            )
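
The heavy lifting in undo_unique is the index_select with the stored inverse indices: every row of the original (possibly duplicated) tensor is looked up again in the deduplicated tensor, which restores the pre-unique ordering. A minimal sketch using plain torch.unique, with made-up sample data and without storch's plate bookkeeping:

import torch

# Hypothetical samples containing duplicate rows along dim 0.
samples = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2]])

# Roughly what a unique plate stores: the deduplicated rows plus the inverse index map.
unique_rows, inv_indexing = torch.unique(samples, return_inverse=True, dim=0)

# The core of undo_unique: selecting with the inverse indices maps each original row
# back to its deduplicated counterpart, so duplicates reappear in their original places.
restored = torch.index_select(unique_rows, dim=0, index=inv_indexing)
assert torch.equal(restored, samples)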
Example #3
from typing import List

import storch

# Note: _indices is a type alias defined elsewhere in storch for the accepted index types.
def _convert_indices(tensor: storch.Tensor, dims: _indices) -> (List[int], List[str]):
    conv_indices = []
    red_batches = []
    if not isinstance(dims, List):
        dims = [dims]
    for index in dims:
        if isinstance(index, int):
            if index >= tensor.plate_dims or (index < 0 and index >= -tensor.event_dims):
                conv_indices.append(index)
            else:
                print(tensor.shape, index)
                raise IndexError(
                    "Can only pass indexes for event dimensions."
                    + str(tensor)
                    + ". Index: "
                    + str(index)
                )
        else:
            if isinstance(index, storch.Plate):
                index = index.name
            conv_indices.append(tensor.get_plate_dim_index(index))
            red_batches.append(index)
    return tuple(conv_indices), red_batches
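
The integer branch only accepts indices that point at event dimensions, either counting forward past the plate dimensions or counting backward from the end. A small self-contained check of that condition, using made-up values for the plate_dims and event_dims attributes that the snippet reads off the tensor:

# Hypothetical layout: two plate dimensions followed by two event dimensions (four dims in total).
plate_dims, event_dims = 2, 2

def is_event_index(index: int) -> bool:
    # Same test as in _convert_indices: a non-negative index must lie past the plate
    # dimensions, while a negative index must stay within the event dimensions.
    return index >= plate_dims or (index < 0 and index >= -event_dims)

# Only 2, 3 (forward) and -1, -2 (backward) address the two event dimensions.
assert [i for i in range(-4, 4) if is_event_index(i)] == [-2, -1, 2, 3]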