示例#1
0
def expand_with_ignore_as(
    tensor, expand_as, ignore_dim: Union[str, int]
) -> torch.Tensor:
    """
    Expand ``tensor`` to the shape of ``expand_as``, keeping one dimension as-is.

    The wrapped computation first appends ``expand_as.ndim - tensor.ndim``
    singleton dimensions to the END of ``tensor``. It then expands the result to
    ``expand_as``'s shape, except at ``ignore_dim``, where ``tensor``'s own size is
    kept (``-1`` in the expand call). Because the new dimensions are appended,
    dimensions are matched position-by-position from the left.
    The computation runs through ``storch.deterministic`` with ``expand_plates=True``.

    NOTE(review): an earlier version of this docstring gave the example "tensor of
    size a x b, expand_as of size d x a x c, dim=-1 -> d x a x b". That example
    matches dimensions from the right, and with trailing-singleton insertion it
    only works when a == d. Confirm which alignment is intended.

    :param tensor: the tensor to expand.
    :param expand_as: the tensor whose shape is the expansion target.
    :param ignore_dim: integer index of the dimension to keep, or a string
        referring to a plate dimension (forwarded to ``storch.deterministic``
        as its ``dim`` keyword).
    :return: the expanded tensor.
    """

    def _expand_with_ignore(tensor, expand_as, dim: int):
        # Pad with trailing singleton dims so both tensors have the same rank.
        padded = tensor[(...,) + (None,) * (expand_as.ndim - tensor.ndim)]
        target = expand_as.shape
        # For dim == -1, target[dim + 1:] would be target[0:] (the whole shape),
        # so the suffix has to be empty in that case.
        suffix = () if dim == -1 else target[dim + 1 :]
        return padded.expand(target[:dim] + (-1,) + suffix)

    if not isinstance(ignore_dim, str):
        return storch.deterministic(_expand_with_ignore, expand_plates=True)(
            tensor, expand_as, ignore_dim
        )
    # A plate name cannot be used as a positional index. Hand it to
    # storch.deterministic as the `dim` keyword instead (presumably resolved
    # there to the plate's dimension index; confirm).
    by_plate = storch.deterministic(
        _expand_with_ignore, expand_plates=True, dim=ignore_dim
    )
    return by_plate(tensor, expand_as)
示例#2
0
文件: arm.py 项目: HEmile/storchastic
    def estimator(
        self, tensor: StochasticTensor, cost: CostTensor
    ) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
        """Compute the estimator terms for ``tensor`` with respect to ``cost``.

        Looks up the plate named after ``tensor`` and its dimension index. It then
        calls ``self.comp_estimator`` (wrapped in ``storch.deterministic``) with:
        the tensor, the cost, the distribution's logits, that plate index, and
        half the plate size (``plate.n // 2``). Presumably the plate's samples are
        split into two halves; confirm against ``comp_estimator``.

        :return: ``(log_prob, (1 - magic_box(log_prob)) * baseline)``.
        """
        # TODO: No support for alternative plate weighting
        own_plate = tensor.get_plate(tensor.name)
        plate_dim = tensor.get_plate_dim_index(own_plate.name)
        compute = storch.deterministic(self.comp_estimator)
        log_prob, baseline = compute(
            tensor, cost, tensor.distribution.logits, plate_dim, own_plate.n // 2
        )
        # magic_box(log_prob) is scaled by (1 - ...) before multiplying the baseline.
        baseline_term = (1 - magic_box(log_prob)) * baseline
        return log_prob, baseline_term
示例#3
0
def b_binary_cross_entropy(
    input: storch.Tensor,
    target: storch.Tensor,
    dims: Union[str, List[str], None] = None,
    weight=None,
    reduction: str = "mean",
):
    r"""Measure the Binary Cross Entropy between ``input`` and ``target`` in a batched way.

    The element-wise loss is computed by
    :func:`torch.nn.functional.binary_cross_entropy` with ``reduction="none"``,
    wrapped by ``deterministic``. For ``'mean'`` and ``'sum'`` the loss is then
    reduced over its event dimensions together with the extra dimensions in
    ``dims``.

    See :class:`~torch.nn.BCELoss` for details of the element-wise loss.

    Args:
        input: Tensor of arbitrary shape holding probabilities (values in
            ``[0, 1]``, as required by ``F.binary_cross_entropy``).
        target: Tensor expandable to the shape of ``input``; it is expanded with
            ``target.expand_as(input)``.
        dims: A dimension name, or an iterable of dimension names, to reduce over
            in addition to the event dimensions. ``None`` (or empty) reduces over
            the event dimensions only.
        weight (Tensor, optional): a manual rescaling weight, forwarded unchanged
            to ``F.binary_cross_entropy``.
        reduction (string, optional): Specifies the reduction to apply to the
            output: ``'none'`` | ``'mean'`` | ``'sum'``.
            - ``'none'``: no reduction is applied.
            - ``'mean'``: ``storch.mean`` over the reduced dimensions.
            - ``'sum'``: ``storch.sum`` over the reduced dimensions.
            Default: ``'mean'``.

    Returns:
        The unreduced loss for ``'none'``, otherwise the reduced loss.

    Raises:
        ValueError: if ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.

    Examples::

        >>> input = torch.randn((3, 2), requires_grad=True)
        >>> target = torch.rand((3, 2), requires_grad=False)
        >>> loss = b_binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    # Validate up front, so a typo fails loudly instead of the function falling
    # through every branch and silently returning None.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"reduction must be one of 'none', 'mean' or 'sum', got {reduction!r}"
        )

    # Normalise `dims` to a list of names. Tuples and other iterables are
    # accepted too, since they are concatenated onto a list below.
    if not dims:
        dims = []
    elif isinstance(dims, str):
        dims = [dims]
    else:
        dims = list(dims)

    target = target.expand_as(input)
    unreduced = deterministic(F.binary_cross_entropy)(
        input, target, weight, reduction="none"
    )

    if reduction == "none":
        return unreduced

    # Reduce over the loss's event dimensions plus the requested named dimensions.
    indices = list(unreduced.event_dim_indices) + dims
    if reduction == "mean":
        return storch.mean(unreduced, indices)
    return storch.sum(unreduced, indices)
示例#4
0
def gather(input: storch.Tensor, dim: str, index: storch.Tensor):
    """Gather values of ``input`` at ``index`` along the dimension named ``dim``.

    Wraps :func:`torch.gather` in ``storch.deterministic`` with
    ``expand_plates=True``. The dimension name is passed to the wrapper as its
    ``dim`` keyword; presumably the wrapper resolves it to a dimension index
    before calling ``torch.gather`` (confirm).
    """
    # TODO: Should be allowed to accept int and storch.Plate as well
    gather_fn = storch.deterministic(torch.gather, dim=dim, expand_plates=True)
    return gather_fn(input, index=index)
示例#5
0
def expand_as(tensor: AnyTensor, expand_as: AnyTensor) -> AnyTensor:
    """Expand ``tensor`` to the shape of ``expand_as`` through ``storch.deterministic``.

    In PyTorch, ``expand_as`` is a method-only operation: there is no
    module-level ``torch.expand_as``, so wrapping that name raised
    ``AttributeError``. The unbound method ``torch.Tensor.expand_as`` is wrapped
    instead. Calling it as ``torch.Tensor.expand_as(tensor, expand_as)`` is
    equivalent to ``tensor.expand_as(expand_as)``.
    """
    return storch.deterministic(torch.Tensor.expand_as)(tensor, expand_as)
示例#6
0
 def __and__(self, other):
     """Apply ``&`` between the wrapped ``self._tensor`` and ``other``.

     The call goes through ``storch.deterministic(self._tensor.__and__)``.
     Raises ``IllegalStorchExposeError`` when ``other`` is a Python ``bool``.
     """
     if not isinstance(other, bool):
         return storch.deterministic(self._tensor.__and__)(other)
     # Combining with a plain bool would expose the underlying tensor as a bool.
     raise IllegalStorchExposeError(
         "Calling 'and' with a bool exposes the underlying tensor as a bool."
     )