def reduce(self, tensor: storch.Tensor, detach_weights=True):
    """
    Reduce ``tensor`` over this plate, weighting its elements by the plate weight.

    How the weight is applied depends on ``self.n`` and on the number of
    dimensions of the weight:

    * ``self.n == 1``: the tensor is multiplied by the weight through
      ``storch.reduce`` on this plate's name.
    * 0-dimensional weight: the tensor is summed over the plate first and the
      sum is then multiplied by the weight.
    * 1-dimensional weight: the weight is aligned with the plate dimension,
      multiplied in, and the product is summed over the plate.
    * anything else: every plate in ``self.parents`` must be present on the
      tensor; the tensor is multiplied by the weight and summed over the plate.

    :param tensor: The tensor to reduce over this plate.
    :param detach_weights: When true, ``self.weight.detach()`` is used instead
        of ``self.weight``.
    :return: The weighted reduction of ``tensor`` over this plate.
    :raises ValueError: In the last case, if a plate from ``self.parents`` is
        not in ``tensor.plates``.
    """
    weight = self.weight.detach() if detach_weights else self.weight

    if self.n == 1:
        return storch.reduce(lambda x: x * weight, self.name)(tensor)

    # Case: the weight is a single number. Sum first, then scale once
    # (usually this amounts to taking the mean).
    if weight.ndim == 0:
        return storch.sum(tensor, self) * weight

    # Case: one weight per plate element, independent of the other batch
    # dimensions.
    if weight.ndim == 1:
        plate_dim = tensor.get_plate_dim_index(self.name)
        trailing_dims = tensor.ndim - plate_dim - 1
        # Append one singleton axis for every tensor dimension after the
        # plate dimension, so the weight lines up with the plate dimension
        # when multiplied.
        weight = weight[(...,) + (None,) * trailing_dims]
        return storch.sum(tensor * weight, self)

    # Case: the weight has its own batch dimensions (assumed to be a
    # storch.Tensor). All parent plates have to be on the tensor.
    for parent in self.parents:
        if parent not in tensor.plates:
            raise ValueError("Plate missing when reducing tensor: " + parent.name)
    return storch.sum(tensor * weight, self)
def undo_unique(self, unique_tensor: storch.Tensor) -> torch.Tensor:
    """
    Expand a tensor over this unique plate back to the non-unique layout and
    put the shrunken plates back in front.

    The dimension of this plate is swapped with the last plate dimension,
    entries are picked along dimension 0 with ``self.inv_indexing``, and that
    leading dimension is reshaped into one dimension per plate in
    ``self.shrunken_plates``.

    TODO: Make sure self.shrunken_plates is added
    TODO: What if unique_tensor contains new plates after the unique?

    :param unique_tensor: Tensor that carries this plate (looked up by
        ``self.name``).
    :return: A ``storch.Tensor`` named ``"undo_unique_" + unique_tensor.name``
        with ``unique_tensor`` as its only parent and plates
        ``self.shrunken_plates + unique_tensor.plates``.
    """
    unique_dim = unique_tensor.get_plate_dim_index(self.name)
    with storch.ignore_wrapping():
        last_plate_dim = unique_tensor.plate_dims - 1
        swapped = unique_tensor.transpose(unique_dim, last_plate_dim)
        gathered = torch.index_select(swapped, dim=0, index=self.inv_indexing)
        # Sizes of the plates that were collapsed into the unique plate.
        restored_sizes = tuple(plate.n for plate in self.shrunken_plates)
        restored = gathered.reshape(restored_sizes + gathered.shape[1:])
        result_plates = self.shrunken_plates + unique_tensor.plates
        result_name = "undo_unique_" + unique_tensor.name
        return storch.Tensor(restored, [unique_tensor], result_plates, result_name)
def _convert_indices(tensor: storch.Tensor, dims: _indices) -> (List[int], List[str]): conv_indices = [] red_batches = [] if not isinstance(dims, List): dims = [dims] for index in dims: if isinstance(index, int): if index >= tensor.plate_dims or index < 0 and index >= -tensor.event_dims: conv_indices.append(index) else: print(tensor.shape, index) raise IndexError( "Can only pass indexes for event dimensions." + str(tensor) + ". Index: " + str(index) ) else: if isinstance(index, storch.Plate): index = index.name conv_indices.append(tensor.get_plate_dim_index(index)) red_batches.append(index) return tuple(conv_indices), red_batches