Example #1
def where(condition, x, y):
    """Returns elements chosen from x or y depending on a boolean tensor condition.

    The input tensors ``condition``, ``x``, and ``y`` must all be broadcastable to the same shape.

    Args:
        condition (tensor_like[bool]): A boolean tensor. Where ``True``, elements from
            ``x`` will be chosen, otherwise ``y``.
        x (tensor_like): values from which to choose if the condition evaluates to ``True``
        y (tensor_like): values from which to choose if the condition evaluates to ``False``

    Returns:
        tensor_like: A tensor with elements from ``x`` where the condition is ``True``, and
        ``y`` otherwise. The output tensor has the broadcasted shape of the input tensors.

    **Example**

    >>> a = torch.tensor([0.6, 0.23, 0.7, 1.5, 1.7], requires_grad=True)
    >>> b = torch.tensor([-1., -2., -3., -4., -5.], requires_grad=True)
    >>> math.where(a < 1, a, b)
    tensor([ 0.6000,  0.2300,  0.7000, -4.0000, -5.0000], grad_fn=<SWhereBackward>)
    """
    return np.where(condition, x, y, like=_multi_dispatch([condition, x, y]))
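
The function above relies on autoray-style dispatch: ``np`` here appears to be autoray's dispatching ``numpy`` wrapper (the ``like`` keyword is autoray's), and ``_multi_dispatch`` picks the backend from the input types. Below is a minimal, self-contained sketch of the same dispatch idea using ``autoray.do`` directly, assuming autoray and torch are installed; it illustrates the pattern and is not the module's own code.

# Sketch of the dispatch pattern used above (assumes autoray and torch
# are installed). autoray.do routes the call to the backend named by
# the `like` argument, which we infer from the inputs.
import torch
from autoray import do, infer_backend

a = torch.tensor([0.6, 0.23, 0.7, 1.5, 1.7])
b = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0])

backend = infer_backend(a)                    # "torch"
out = do("where", a < 1, a, b, like=backend)  # dispatches to torch.where
print(out)  # tensor([ 0.6000,  0.2300,  0.7000, -4.0000, -5.0000])
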
Example #2
def where(condition, x=None, y=None):
    """Returns elements chosen from x or y depending on a boolean tensor condition,
    or the indices of entries satisfying the condition.

    The input tensors ``condition``, ``x``, and ``y`` must all be broadcastable to the same shape.

    Args:
        condition (tensor_like[bool]): A boolean tensor. Where ``True``, elements from
            ``x`` will be chosen, otherwise ``y``. If ``x`` and ``y`` are ``None``, the
            indices where ``condition==True`` holds will be returned.
        x (tensor_like): values from which to choose if the condition evaluates to ``True``
        y (tensor_like): values from which to choose if the condition evaluates to ``False``

    Returns:
        tensor_like or tuple[tensor_like]: If ``x is None`` and ``y is None``, a tensor
        or tuple of tensors with the indices where ``condition`` is ``True``.
        Otherwise, a tensor with elements from ``x`` where ``condition`` is ``True``
        and from ``y`` elsewhere. In this case, the output tensor has the broadcasted
        shape of the input tensors.

    **Example with three arguments**

    >>> a = torch.tensor([0.6, 0.23, 0.7, 1.5, 1.7], requires_grad=True)
    >>> b = torch.tensor([-1., -2., -3., -4., -5.], requires_grad=True)
    >>> math.where(a < 1, a, b)
    tensor([ 0.6000,  0.2300,  0.7000, -4.0000, -5.0000], grad_fn=<SWhereBackward>)

    .. warning::

        The output format for ``x=None`` and ``y=None`` depends on the interface
        and differs between TensorFlow and all other interfaces:
        for TensorFlow, the output is a tensor with shape
        ``(num_true, len(condition.shape))``, where ``num_true`` is the number
        of entries in ``condition`` that are ``True``.
        The entry at position ``(i, j)`` is the ``j`` th component of the
        ``i`` th ``True`` index.
        For all other interfaces, the output is a tuple of tensor-like objects,
        with the ``j`` th object containing the ``j`` th components of all
        ``True`` indices.
        Also see the examples below.

    **Example with single argument**

    For Torch, Autograd, JAX, and NumPy, the output format is as follows:

    >>> a = [[0.6, 0.23, 1.7],[1.5, 0.7, -0.2]]
    >>> math.where(torch.tensor(a) < 1)
    (tensor([0, 0, 1, 1]), tensor([0, 1, 1, 2]))

    This is not a single tensor-like object but corresponds to the shape
    ``(2, 4)``. For TensorFlow, on the other hand:

    >>> math.where(tf.constant(a) < 1)
    tf.Tensor(
    [[0 0]
     [0 1]
     [1 1]
     [1 2]], shape=(4, 2), dtype=int64)

    As we can see, the dimensions are swapped and the output is a single tensor.
    Note that the number of dimensions of the output does *not* depend on the input
    shape; it is always two-dimensional.

    """
    if x is None and y is None:
        interface = _multi_dispatch([condition])
        return np.where(condition, like=interface)

    return np.where(condition, x, y, like=_multi_dispatch([condition, x, y]))
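
The single-argument branch simply forwards ``condition`` to the backend's own ``where``, which is why the output formats in the warning above differ. Below is a short standalone sketch of that difference, calling torch and TensorFlow directly (assuming both are installed); it illustrates the backends themselves, not the dispatching function.

# Sketch of the backend difference described in the warning above
# (assumes torch and tensorflow are installed; the frameworks are
# called directly rather than through the dispatching `where`).
import tensorflow as tf
import torch

a = [[0.6, 0.23, 1.7], [1.5, 0.7, -0.2]]

# Torch (like NumPy, Autograd, and JAX) returns a tuple of index
# tensors, one per dimension of the condition.
print(torch.where(torch.tensor(a) < 1))
# (tensor([0, 0, 1, 1]), tensor([0, 1, 1, 2]))

# TensorFlow returns a single tensor of shape (num_true, ndim),
# one row of coordinates per True entry.
print(tf.where(tf.constant(a) < 1))
# [[0 0]
#  [0 1]
#  [1 1]
#  [1 2]], shape=(4, 2)
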