def expand_with_ignore_as(
    tensor, expand_as, ignore_dim: Union[str, int]
) -> torch.Tensor:
    """
    Expands the tensor like expand_as, but keeps the size of one dimension of
    ``tensor`` (the ignored dimension) instead of copying it from ``expand_as``.

    Ie, if tensor is of size a x b, expand_as of size d x a x c and dim=-1,
    then the return will be of size d x a x b. Plate dimensions are expanded
    automatically (the computation runs through ``storch.deterministic`` with
    ``expand_plates=True``).

    NOTE(review): after plate expansion, any extra dimensions that ``expand_as``
    has over ``tensor`` are added as trailing singleton dimensions of ``tensor``
    before expanding. The example above therefore presumably relies on ``d``
    being a plate dimension — confirm.

    :param tensor: The tensor to expand.
    :param expand_as: The tensor whose shape is copied (except for ``ignore_dim``).
    :param ignore_dim: Can be a string referring to the plate dimension, or an
        integer dimension index.
    """

    def _expand_with_ignore(tensor, expand_as, dim: int):
        # Pad with trailing singleton dims so both tensors have the same rank.
        missing = expand_as.ndim - tensor.ndim
        padded = tensor[(...,) + (None,) * missing]
        leading = expand_as.shape[:dim]
        # For dim == -1 the slice shape[dim + 1:] would be shape[0:] (everything),
        # so there is no trailing part in that case.
        trailing = () if dim == -1 else expand_as.shape[dim + 1 :]
        # -1 keeps the ignored dimension at its current size.
        return padded.expand(leading + (-1,) + trailing)

    if not isinstance(ignore_dim, str):
        return storch.deterministic(_expand_with_ignore, expand_plates=True)(
            tensor, expand_as, ignore_dim
        )
    # A plate name: let storch.deterministic resolve it and pass it as `dim`.
    return storch.deterministic(
        _expand_with_ignore, expand_plates=True, dim=ignore_dim
    )(tensor, expand_as)
def estimator(
    self, tensor: StochasticTensor, cost: CostTensor
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
    """Compute the log-probability term and the baseline term for ``tensor``.

    Both values come from ``self.comp_estimator``, which is wrapped in
    ``storch.deterministic``. It receives, in order:

    - the tensor;
    - the cost;
    - the distribution's logits;
    - the dimension index of the tensor's own plate;
    - ``plate.n // 2`` (presumably half the number of samples in that plate —
      confirm).

    Returns:
        ``(log_prob, (1 - magic_box(log_prob)) * baseline)``.

        NOTE(review): the second term presumably evaluates to zero in the
        forward pass and only contributes a baseline gradient — confirm
        against ``magic_box``.
    """
    # TODO: No support for alternative plate weighting
    own_plate = tensor.get_plate(tensor.name)
    plate_dim = tensor.get_plate_dim_index(own_plate.name)
    wrapped = storch.deterministic(self.comp_estimator)
    log_prob, baseline = wrapped(
        tensor,
        cost,
        tensor.distribution.logits,
        plate_dim,
        own_plate.n // 2,
    )
    control_variate = (1 - magic_box(log_prob)) * baseline
    return log_prob, control_variate
def b_binary_cross_entropy(
    input: storch.Tensor,
    target: storch.Tensor,
    dims: Union[str, List[str]] = None,
    weight=None,
    reduction: str = "mean",
):
    r"""Function that measures the Binary Cross Entropy in a batched way between
    the target and the output.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor of the same shape as input (it is expanded to the shape
            of ``input`` before the loss is computed)
        dims (str or list of str, optional): extra dimensions (presumably plate
            names) to reduce over, in addition to the event dimensions of the
            unreduced loss. Any iterable of dimensions is accepted.
        weight (Tensor, optional): a manual rescaling weight
            if provided it's repeated to match input tensor shape
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is averaged over the event dimensions and ``dims``,
            ``'sum'``: the output is summed over the event dimensions and ``dims``.
            Default: ``'mean'``

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'``, ``'sum'``.

    Examples::

        >>> input = torch.randn((3, 2), requires_grad=True)
        >>> target = torch.rand((3, 2), requires_grad=False)
        >>> loss = b_binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    # Validate first: previously an unknown reduction fell through every branch
    # and silently returned None.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(
            f"{reduction} is not a valid value for reduction; "
            "expected 'none', 'mean' or 'sum'"
        )
    if not dims:
        dims = []
    elif isinstance(dims, str):
        dims = [dims]
    else:
        # Accept tuples and other iterables; list + tuple would raise TypeError.
        dims = list(dims)
    target = target.expand_as(input)
    unreduced = deterministic(F.binary_cross_entropy)(
        input, target, weight, reduction="none"
    )
    if reduction == "none":
        return unreduced
    indices = list(unreduced.event_dim_indices) + dims
    if reduction == "mean":
        return storch.mean(unreduced, indices)
    return storch.sum(unreduced, indices)
def gather(input: storch.Tensor, dim: str, index: storch.Tensor):
    """Plate-aware ``torch.gather`` of ``input`` at ``index``.

    ``dim`` is passed to ``storch.deterministic`` rather than to ``torch.gather``
    directly. It is presumably a plate name that storch resolves to a dimension
    index — confirm. Plates are expanded (``expand_plates=True``) before
    gathering.
    """
    # TODO: Should be allowed to accept int and storch.Plate as well
    plate_aware_gather = storch.deterministic(
        torch.gather, dim=dim, expand_plates=True
    )
    return plate_aware_gather(input, index=index)
def expand_as(tensor: AnyTensor, expand_as: AnyTensor) -> AnyTensor:
    """Expand ``tensor`` to the shape of ``expand_as`` through ``storch.deterministic``.

    PyTorch only exposes ``expand_as`` as a ``Tensor`` method; there is no
    ``torch.expand_as`` function, so referencing it raised ``AttributeError``.
    The unbound method ``torch.Tensor.expand_as`` is wrapped instead. It is
    invoked as ``torch.Tensor.expand_as(tensor, expand_as)``.
    """
    return storch.deterministic(torch.Tensor.expand_as)(tensor, expand_as)
def __and__(self, other):
    """Implement the ``&`` operator by delegating to the wrapped tensor.

    The wrapped ``self._tensor.__and__`` is called through
    ``storch.deterministic``.

    Raises:
        IllegalStorchExposeError: if ``other`` is a Python ``bool``.
    """
    if not isinstance(other, bool):
        return storch.deterministic(self._tensor.__and__)(other)
    raise IllegalStorchExposeError(
        "Calling 'and' with a bool exposes the underlying tensor as a bool."
    )