コード例 #1
0
ファイル: wrappers.py プロジェクト: huaxz1986/pytorch
def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    """Copy ``copy_from`` into ``copy_to`` in place after validating the copy.

    Args:
        copy_from: source tensor.
        copy_to: destination (``out=``) tensor; written in place via ``copy_``.
        exact_dtype: if True the two dtypes must be identical; otherwise the
            cast from ``copy_from.dtype`` to ``copy_to.dtype`` must be accepted
            by ``utils.can_safe_cast_to``.

    Returns:
        The value returned by ``copy_to.copy_(copy_from)``.

    Raises:
        RuntimeError: if the two tensors are on different devices.
        The dtype requirements are enforced through ``utils.check``.
    """
    # Cross-device copies are rejected outright.
    if copy_from.device != copy_to.device:
        msg = (
            f"Attempting to copy from device {copy_from.device} to device "
            f"{copy_to.device}, but cross-device copies are not allowed!"
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            # Both halves need the f-prefix; without it the second half would
            # print "{copy_to.dtype}" literally.
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)
コード例 #2
0
def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)

    Args:
        func_name: name of the public op (not used by this helper).
        input: tensor to transform; promoted to a complex dtype first.
        n: number of real output points along ``dim``. When None it defaults
            to ``2 * (input.shape[dim] - 1)``.
        dim: dimension to transform.
        norm: normalization mode, forwarded to ``_apply_norm``.
        forward: if True the input is conjugated before the c2r transform and
            normalization is applied as for a forward transform.

    Returns:
        The real-valued, normalized transform.
    """
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim), )
    last_dim_size = n if n is not None else 2 * (input.shape[dims[0]] - 1)
    # Report the effective size. When n is None, interpolating n would print
    # "(None)" and hide the actual invalid value.
    check(last_dim_size >= 1,
          lambda: f"Invalid number of data points ({last_dim_size}) specified")

    if n is not None:
        # The c2r input holds last_dim_size // 2 + 1 complex points along dim.
        input = _resize_fft_input(input,
                                  dims=dims,
                                  sizes=(last_dim_size // 2 + 1, ))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output,
                       norm=norm,
                       signal_numel=last_dim_size,
                       forward=forward)
コード例 #3
0
def norm(
    A: TensorLikeType,
    ord: Optional[Union[float, str]] = None,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of ``linalg.norm``.

    Dispatch rules:
      * ``ord`` given and the reduction spans two dims (``dim`` of length 2,
        or ``dim`` is None and ``A`` is 2-D): ``matrix_norm``. ``dim``
        defaults to ``(0, 1)`` in that case.
      * otherwise: ``vector_norm``, with ``ord`` defaulting to ``2.0``.

    Validation is done through ``check``:
      * if ``dim`` is given (an int is treated as a 1-tuple), it must have
        length 1 or 2;
      * if ``dim`` is None but ``ord`` is given, ``A`` must be 1-D or 2-D.
    """
    if dim is not None:
        if isinstance(dim, int):
            dim = (dim, )  # type: ignore[assignment]
        check(
            len(dim) in (1, 2),
            # f-prefix required so the offending dim is interpolated.
            lambda:
            f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
        )
    elif ord is not None:
        check(
            A.ndim in (1, 2),
            lambda:
            f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
        )

    if ord is not None and ((dim is not None and len(dim) == 2) or
                            (dim is None and A.ndim == 2)):
        if dim is None:
            dim = (0, 1)
        return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
    else:
        if ord is None:
            ord = 2.0
        return vector_norm(A, ord, dim, keepdim, dtype=dtype)
コード例 #4
0
def meta_dot(self, tensor):
    """Meta kernel for ``dot``: require two 1-D operands, return a 0-D result."""
    operands_are_vectors = self.dim() == 1 and tensor.dim() == 1
    check(
        operands_are_vectors,
        lambda: f"1D tensors expected, but got {self.dim()}D and {tensor.dim()}D tensors",
    )
    scalar_shape = ()
    return self.new_empty(scalar_shape)
コード例 #5
0
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference n-dimensional inverse FFT of a real-valued input.

    Steps:
      1. A onesided real-to-complex transform over the last transformed dim.
      2. If there are more dims, the intermediate is conjugated
         (``conj_physical``) and an inverse complex-to-complex transform runs
         over the remaining dims.
      3. Normalization uses the total number of transformed points.

    With a single dim, the r2c result is normalized and then conjugated with
    ``prims.conj``.
    """
    check(
        not input.dtype.is_complex,
        lambda:
        f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    promoted = _maybe_promote_tensor_fft(input, require_complex=False)
    resized = _resize_fft_input(promoted, dim, shape)

    half_spectrum = prims.fft_r2c(resized, dim=dim[-1:], onesided=True)

    if len(dim) != 1:
        conjugated = prims.conj_physical(half_spectrum)
        transformed = prims.fft_c2c(conjugated, dim=dim[:-1], forward=False)
        return _apply_norm(transformed,
                           norm=norm,
                           signal_numel=_prod(shape),
                           forward=False)

    normalized = _apply_norm(half_spectrum,
                             norm=norm,
                             signal_numel=shape[0],
                             forward=False)
    return prims.conj(normalized)
コード例 #6
0
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu

    Computes ``where(a > 0, a, a * weight)``.

    ``weight`` must be a 0-D or 1-D tensor. Unless it has exactly one
    element, its length must equal the channel size of ``a``, which is
    ``a.shape[1]`` for inputs with at least two dims and 1 otherwise. A 1-D
    weight is broadcast along that channel dim.

    Validation is done through ``check``. It fails when:
      * either argument is not a tensor;
      * a multi-element weight is paired with a 0-D input;
      * the weight length does not match the channel size;
      * the weight has more than one dim.
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda:
        f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda:
            f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda:
        f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )

    if a.ndim == 0:
        # A 0-D input has no dim to broadcast a 1-D weight into. The checks
        # above guarantee weight has a single element here, so use it as a
        # scalar.
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        # A 1-D weight aligns with the channel dim. That is dim 1 in general,
        # but a 1-D input only has dim 0 (its channel size is 1).
        if weight.ndim == 0:
            broadcast_dims = tuple()
        else:
            broadcast_dims = (0, ) if a.ndim == 1 else (1, )
        weight = prims.broadcast_in_dim(weight, a.shape, broadcast_dims)

    return refs.where(a > 0, a, a * weight)
コード例 #7
0
def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize ``s``/``dim`` for an n-dimensional complex-to-real FFT.

    Returns:
        ``shape``: the per-dim input sizes, with the last entry replaced by
            ``last_dim_size // 2 + 1``.
        ``dim``: the canonical dims.
        ``last_dim_size``: the output length of the last transformed dim.
            It is ``s[-1]`` when given and not -1, otherwise
            ``2 * (input.shape[dim[-1]] - 1)``.
    """
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    use_default_length = s is None or s[-1] == -1
    if use_default_length:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    input_shape = (*shape[:-1], last_dim_size // 2 + 1)
    return _CanonicalizeC2rReturn(shape=input_shape,
                                  dim=dim,
                                  last_dim_size=last_dim_size)
コード例 #8
0
def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType:
    """Gated linear unit: split ``a`` into two halves along ``dim`` and
    multiply the first half by the sigmoid of the second."""
    dim = utils.canonicalize_dims(a.ndim, dim)
    split_size = a.shape[dim]
    check(
        split_size % 2 == 0,
        lambda:
        f"Halving dimension must be even, but dimension {dim} is size {split_size}",
    )
    value, gate = torch.tensor_split(a, 2, dim)
    return value * torch.sigmoid(gate)
コード例 #9
0
        def _fn(*args, out=None, **kwargs):
            """Call ``fn`` and, when ``out`` is given, write the result into it.

            Relies on closure variables from the enclosing scope (not visible
            here): ``fn``, ``is_factory_fn``, ``factory_kwargs``, ``is_tensor``,
            ``out_names``, ``exact_dtype`` and ``return_type``.

            Behavior:
              * Factory function with ``out``: each attribute named in
                ``factory_kwargs`` that the caller did not pass is read from
                ``out`` and forwarded to ``fn``.
              * With ``out``: every result tensor is resized into
                (``_maybe_resize_out``) and copied into (``_safe_copy_out``) the
                matching ``out`` tensor, and ``out`` is returned.
              * Without ``out``: the result of ``fn`` is returned.
              * Non-tensor (tuple) results are wrapped as ``return_type(*out)``.

            Raises:
                TypeError: if a tuple ``out`` has a different number of
                    elements than the result.
            """
            if is_factory_fn and out is not None:
                # Default any missing factory kwargs from the matching
                # attributes of the out tensor; explicit kwargs win.
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr

            result = fn(*args, **kwargs)
            # The result must be a tensor for tensor-returning ops, or a tuple
            # with one entry per output name otherwise.
            assert (isinstance(result, TensorLike) and is_tensor
                    or isinstance(result, Tuple)  # type: ignore[arg-type]
                    and len(result) == len(out_names))
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not the result--which will
                # be a normal meta tensor, but this is perfectly
                # harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(
                        copy_from=result, copy_to=out,
                        exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    utils.check(
                        len(out) == len(result),
                        lambda:
                        f"expected tuple of {len(result)} elements but got {len(out)}",
                        TypeError,
                    )
                    # Pair each result tensor with its out tensor positionally.
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(
                            copy_from=r, copy_to=o,
                            exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of return_type given that it's in a different scope
            return out if is_tensor else return_type(
                *out)  # type: ignore[operator]
コード例 #10
0
def _apply_norm(x: TensorLikeType, norm: NormType, signal_numel: int,
                forward: bool) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result

    Scaling rules, where ``n`` is ``signal_numel``:
      * ``"ortho"``: scale by ``1 / sqrt(n)`` in both directions.
      * ``None`` or ``"backward"``: scale the backward (``forward=False``)
        transform by ``1 / n``.
      * ``"forward"``: scale the forward transform by ``1 / n``.

    Any other pairing of mode and direction returns ``x`` unchanged.
    """
    check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    if forward:
        scale_this_direction = norm == "forward"
    else:
        scale_this_direction = norm is None or norm == "backward"

    if not scale_this_direction:
        return x
    return x * (1 / signal_numel)
コード例 #11
0
def meta_diag(self, dim=0):
    """Meta kernel for ``diag``.

    Args:
        self: 1-D or 2-D tensor.
        dim: diagonal offset (positive = above the main diagonal).

    Returns:
        An empty tensor with the output shape:
          * 1-D input of length ``m``: a square matrix of size ``m + |dim|``.
          * 2-D input: a vector whose length is that of the selected diagonal.
            The length is 0 when the offset lies entirely outside the matrix.
    """
    check(self.dim() in (1, 2), lambda: "matrix or a vector expected")
    if self.dim() == 1:
        sz = self.size(0) + abs(dim)
        return self.new_empty((sz, sz))

    # case: self is 2-D
    if dim >= 0:
        sz = min(self.size(0), self.size(1) - dim)
    else:
        sz = min(self.size(0) + dim, self.size(1))
    # An offset past the edge of the matrix selects an empty diagonal. Clamp
    # so new_empty is never asked for a negative size.
    return self.new_empty((max(sz, 0), ))
コード例 #12
0
def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType:
    """Pairwise p-norm distances between the rows of the 2-D tensor ``a``.

    Builds the full row-to-row distance matrix and returns its strictly
    upper-triangular entries, flattened in row-major order.
    """
    check(a.ndim == 2,
          lambda: f"pdist only supports 2D tensors, got: {a.ndim}D")
    check(p >= 0, lambda: "pdist only supports non-negative p values")
    if p != 2:
        # General p: broadcasts an (n, 1, m) - (n, m) difference tensor, which
        # is much larger than the Gram matrix used for p == 2.
        dist = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2)
    else:
        # p == 2: ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>. Clamp rounding
        # negatives to 0 before taking the square root.
        gram = torch.mm(a, a.T)
        sq_norms = torch.diag(gram)
        dist = torch.sqrt(
            torch.clamp(sq_norms + sq_norms.unsqueeze(-1) - 2 * gram, min=0))
    rows, cols = torch.triu_indices(dist.shape[0], dist.shape[1], offset=1, device=a.device)
    flat_positions = rows * dist.shape[0] + cols
    return dist.flatten().index_select(0, flat_positions)
コード例 #13
0
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    """Reference implementation of ``softshrink``.

        softshrink(x) = x - lambd   if x > lambd
                      = x + lambd   if x < -lambd
                      = 0           otherwise
    """
    check(
        lambd >= 0,
        lambda:
        f"lambda must be greater or equal to 0, but found to be {lambd}",
    )
    above = a > lambd
    below = a < -lambd
    outside_band = refs.logical_or(above, below)
    shrunk = refs.where(above, a - lambd, a)
    shrunk = refs.where(below, a + lambd, shrunk)
    inside_band = torch.logical_not(outside_band)
    return refs.where(inside_band, 0, shrunk)
コード例 #14
0
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    """Reference n-dimensional FFT of a real-valued input.

    Uses a onesided real-to-complex transform (``onesided=True``) over all
    canonicalized dims. Normalization uses the total number of transformed
    points.

    Validation is done through ``check``. It fails when the input is
    complex, or when no axis would be transformed (e.g. a 0-D input, or an
    empty ``s``/``dim``).
    """
    check(
        not input.dtype.is_complex,
        lambda:
        f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    # Same guard as ihfftn: a transform over zero axes is rejected up front
    # rather than reaching prims.fft_r2c with an empty dim tuple.
    check(len(shape) > 0, lambda: "rfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)
コード例 #15
0
def _canonicalize_fft_shape_and_dim_args(
        input: TensorLikeType, shape: Optional[ShapeType],
        dim: Optional[DimsType]) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional

    Args:
        input: tensor being transformed.
        shape: requested signal sizes (``s`` in the public API). Either a
            single int or a sequence; ``-1`` entries keep the input's size
            along that dim.
        dim: dims to transform, as a single int or a sequence.

    Returns:
        ``_ShapeAndDims(shape=..., dims=...)`` with canonical (non-negative)
        dims and one size per transformed dim.

    Defaults:
      * only ``shape`` given: transform the last ``len(shape)`` dims;
      * neither given: transform every dim at its current size.

    Validation is done through ``check``. It fails when:
      * dims repeat after canonicalization;
      * ``dim`` and ``shape`` lengths differ;
      * ``shape`` is longer than ``input.ndim``;
      * any resulting size is not positive.
    """
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim, )
        ret_dims = utils.canonicalize_dims(input_dim, dim)

        # Check uniqueness on the canonicalized dims so that aliases such as
        # (0, -input_dim) are caught as duplicates.
        check(len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique")

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape, )

        # Has shape, might have dim
        check(
            dim is None or len(dim) == len(shape),
            lambda:
            "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)

        check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(s if s != -1 else input_sizes[d]
                          for (s, d) in zip(shape, ret_dims))
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)

    for n in ret_shape:
        check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)
コード例 #16
0
def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)

    ``input`` must already be complex. It is resized to ``shape`` over
    ``dim``, transformed, and normalized by the product of ``shape``.
    """
    check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    resized = _resize_fft_input(input, dim, shape)
    transformed = prims.fft_c2c(resized, dim=dim, forward=forward)
    total_points = _prod(shape)
    return _apply_norm(transformed,
                       norm=norm,
                       signal_numel=total_points,
                       forward=forward)
コード例 #17
0
def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)

    ``input`` must already be complex. When ``n`` is given, the signal along
    ``dim`` is first resized to ``n`` points. Normalization uses the
    (possibly resized) length along ``dim``.
    """
    check(
        input.dtype.is_complex,
        lambda:
        f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim), )
    signal = input if n is None else _resize_fft_input(input, dims, (n, ))
    transformed = prims.fft_c2c(signal, dim=dims, forward=forward)
    return _apply_norm(transformed, norm, signal.shape[dim], forward)
コード例 #18
0
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype,
                     fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions

    Does nothing when ``dtype`` is None. Otherwise, through ``check``,
    ``dtype`` must:
      * be a floating point or complex dtype;
      * be complex exactly when ``x_dtype`` is complex;
      * satisfy ``utils.get_higher_dtype(dtype, x_dtype) == dtype``, i.e. the
        input converts to it without narrowing.

    Args:
        dtype: the user-supplied ``dtype=`` argument, or None.
        x_dtype: dtype of the input tensor.
        fn_name: op name used as the error-message prefix.
    """
    if dtype is not None:
        check(
            utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
            lambda:
            f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        check(
            utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
            lambda:
            "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".
            format(
                fn_name=fn_name,
                d="complex" if utils.is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        check(
            utils.get_higher_dtype(dtype, x_dtype) == dtype,
            # Both halves need the f-prefix; without it the second half would
            # print "({dtype})" literally.
            lambda:
            f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            f"without narrowing to the specified dtype ({dtype})",
        )
コード例 #19
0
def meta_pad2d(self, padding):
    """Meta kernel for 2-D padding: return an empty tensor of the padded shape.

    Args:
        self: a 3-D ``(nplane, H, W)`` or 4-D ``(nbatch, nplane, H, W)``
            input. Every dim after the first must be non-empty.
        padding: ``(pad_left, pad_right, pad_top, pad_bottom)``.

    Returns:
        ``self.new_empty`` with the leading dims unchanged and the last two
        dims grown by the padding.
    """
    # Check the rank before reading any sizes. The previous order evaluated
    # self.size(1)/self.size(2) first, so lower-rank inputs raised IndexError
    # instead of the intended check error.
    valid_dims = self.ndim in (3, 4) and all(
        size != 0 for size in self.shape[1:])
    check(
        valid_dims,
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )

    pad_l, pad_r, pad_t, pad_b = padding
    *leading_dims, input_h, input_w = self.shape

    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    return self.new_empty((*leading_dims, output_h, output_w))
コード例 #20
0
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """Meta kernel for ``addbmm``.

    ``batch1`` is ``(b, n, m)`` and ``batch2`` is ``(b, m, p)``. The result
    has shape ``(n, p)``; ``self`` is expanded to that shape.
    ``beta`` and ``alpha`` are not used here (they do not affect the shape).
    """
    # Validate ranks and matching sizes before indexing size(1)/size(2) or
    # expanding self. The original order could raise IndexError or an
    # expand error before these messages.
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    check(
        batch1.size(0) == batch2.size(0),
        lambda:
        f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda:
        (f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
         f"and {batch2.size(1)}x{batch2.size(2)})"),
    )
    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
コード例 #21
0
def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)

    ``input`` must not be complex. It is promoted via
    ``_maybe_promote_tensor_fft`` and optionally resized to ``n`` points
    along ``dim``, then transformed and normalized. For the inverse direction
    (``forward=False``) the normalized result is conjugated.
    """
    check(
        not input.dtype.is_complex,
        lambda:
        f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    signal = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(signal.ndim, dim), )

    if n is not None:
        signal = _resize_fft_input(signal, dims, (n, ))

    spectrum = prims.fft_r2c(signal, dim=dims, onesided=onesided)
    spectrum = _apply_norm(spectrum, norm, signal.shape[dim], forward)
    if not forward:
        spectrum = torch.conj(spectrum)
    return spectrum
コード例 #22
0
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Reference implementation of ``linalg.vector_norm``.

    Special orders:
      * ``0``: count of non-zero entries, computed in ``result_dtype``.
      * ``inf`` / ``-inf``: max / min of ``abs(x)``.
      * any other ``ord``: ``sum(|x| ** ord) ** (1 / ord)``, computed in the
        computation dtype and cast to the result dtype.

    For an empty ``x`` with ``ord < 0`` or ``ord == inf`` the reduction has
    no identity. In that case ``dim`` must be given, non-empty, and name only
    non-empty dims.
    """
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    # Normalize dim to None or a list; refs.amin accepts a List, not a Tuple.
    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif dim is not None and not isinstance(dim, List):
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda:
            f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        sizes = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for reduced in dim:
            check(
                sizes[reduced] != 0,
                lambda:
                f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {reduced} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype)
    cast_to_result = partial(prims.convert_element_type, dtype=result_dtype)

    if ord == 0.0:
        nonzero_mask = refs.ne(x, 0.0)
        return refs.sum(nonzero_mask, dim=dim, keepdim=keepdim, dtype=result_dtype)
    if ord == float("inf"):
        return cast_to_result(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    if ord == float("-inf"):
        return cast_to_result(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))

    # General ord: the reduction is non-trivial, so work in the computation dtype.
    x = prims.convert_element_type(x, computation_dtype)
    # For an even ord on a real dtype, x ** ord already equals |x| ** ord,
    # so abs is skipped.
    if ord % 2.0 != 0.0 or not utils.is_float_dtype(x.dtype):
        x = torch.abs(x)
    powered_sum = refs.sum(torch.pow(x, ord), dim=dim, keepdim=keepdim)
    return cast_to_result(torch.pow(powered_sum, 1.0 / ord))
コード例 #23
0
def dot_check(self, other):
    """Validate that both operands of a dot product are 1-D tensors."""
    is_vector_pair = self.dim() == 1 and other.dim() == 1
    check(
        is_vector_pair,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )
コード例 #24
0
def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference implementation of ``linalg.matrix_norm``.

    Supported orders:
      * ``"fro"``: 2-norm of all entries over ``dim`` (via ``vector_norm``).
      * ``"nuc"``: sum of singular values.
      * ``2`` / ``-2``: largest / smallest singular value.
      * ``1`` / ``-1`` / ``inf`` / ``-inf``: max / min over one matrix dim of
        the 1-norm over the other dim. For ``inf`` the roles of the two dims
        are swapped.

    ``dim`` must name two distinct dims; ``keepdim`` keeps them as size-1
    dims. Validation is done through ``check``.
    """
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, int):
        dim = (dim, )  # type: ignore[assignment]
    # The messages below need the f-prefix so dim/ord are interpolated.
    check(
        len(dim) == 2,
        lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
    check(
        dim[0] != dim[1],
        lambda:
        f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        check(
            ord in ("fro", "nuc"),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(A.dtype,
                            "linalg.matrix_norm",
                            allow_low_precision_dtypes=ord != "nuc")

        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            # Move the two matrix dims to the end so svdvals sees batches of
            # matrices; with keepdim, restore them as size-1 dims afterwards.
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        check(
            abs_ord in (2, 1, float("inf")),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(A.dtype,
                            "linalg.matrix_norm",
                            allow_low_precision_dtypes=ord != 2)

        # Positive orders take the max, negative orders the min.
        max_min = partial(torch.amax if ord > 0.0 else torch.amin,
                          keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = prims.convert_element_type(A, dtype)
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            if abs_ord == float("inf"):
                dim0, dim1 = dim1, dim0
            # Without keepdim, reducing dim0 first shifts later dims down by one.
            if not keepdim and (dim0 < dim1):
                dim1 -= 1
            return max_min(
                vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype),
                dim1)
コード例 #25
0
def meta_adaptive_avg_pool2d(self, output_size):
    """Meta kernel: keep the leading dims of a 3-D/4-D input and replace the
    last two with ``output_size``."""
    check(
        self.ndim in (3, 4),
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-2]
    return self.new_empty(leading_dims + tuple(output_size))
コード例 #26
0
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """Meta kernel for embedding_bag: validate arguments and return empty
    ``(output, offset2bag, bag_size, max_indices)`` tensors.

    ``mode`` is 0 (sum), 1 (mean) or 2 (max). ``output`` has shape
    ``(num_bags, weight.size(1))``. ``num_bags`` is ``offsets.size(0)``,
    minus one when ``include_last_offset`` is set. ``scale_grad_by_freq``
    and ``sparse`` do not affect the shapes computed here.

    ``per_sample_weights`` is only allowed with sum mode. It must be 1-D,
    share ``weight``'s dtype, and have one entry per index.
    """
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda:
        f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(num_bags >= 1,
              lambda: "include_last_offset: numBags should be at least 1")
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda:
            "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda:
            f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda:
            f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda:
            (f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
             f"to be the same as indices.numel() ({indices.numel()})"),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        # Fast path with per-sample scaling also needs a unit-stride scale.
        return (is_fast_path_index_select(src, output, padding_idx)
                and scale.stride(0) == 1)

    def is_fast_path_index_select(src, output, padding_idx):
        # float/half weights, unit stride along dim 1 for weight and output,
        # and no padding index.
        return ((src.dtype == torch.float or src.dtype == torch.half)
                and src.stride(1) == 1 and output.stride(1) == 1
                and padding_idx < 0)

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output,
                                                   padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output,
                                     padding_idx)
        # offset2bag gets one entry per index for mean/max, or when the sum
        # cannot take the fast path. Otherwise it is empty.
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        if mode == MODE_MAX:
            # Max mode records one index per (bag, embedding column), as in the
            # non-CPU branch. A 1-D (num_bags,) buffer has the wrong shape.
            max_indices = offsets.new_empty(num_bags, weight.size(1))
        else:
            max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
コード例 #27
0
def meta_cdist_forward(x1, x2, p, compute_mode):
    """Meta kernel for cdist's forward pass.

    ``x1`` is ``(*B1, r1, c)`` and ``x2`` is ``(*B2, r2, c)``, both
    floating point with at least 2 dims. The result has shape
    ``(*broadcast(B1, B2), r1, r2)``. ``p`` must be non-negative and
    ``compute_mode`` must be 0, 1 or 2. Validation is done through ``check``.
    """
    check(
        x1.dim() >= 2,
        lambda:
        f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda:
        f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda:
        f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # The dtype messages need the f-prefix so the offending dtype is printed.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda:
        f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda:
        f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
コード例 #28
0
def meta_index_Tensor(self, indices):
    """Meta kernel for ``aten::index.Tensor`` (advanced indexing).

    Computes the shape of ``self[indices]`` and returns ``self.new_empty`` of
    that shape; no data is read.

    Args:
        self: tensor being indexed.
        indices: non-empty sequence of index tensors or ``None``. A ``None``
            entry keeps that dim of ``self`` whole. int8/bool mask tensors
            are expanded into one index tensor per mask dim, taken from the
            columns of ``mask.nonzero()``.

    Raises:
        IndexError: if a mask covers more dims than ``self`` has left, or a
            mask's shape does not match the dims it indexes.
        Other failures go through ``check`` (default exception type): empty
        ``indices``, an unsupported index dtype, or more indices than dims.
    """
    check(indices, lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            # NOTE(review): the message says "byte", but the dtype accepted
            # here is torch.int8 (ATen's "byte" is usually uint8). Confirm
            # which is intended.
            check(
                index.dtype in [torch.long, torch.int8, torch.bool],
                lambda:
                "tensors used as indices must be long, byte or bool tensors",
            )
            if index.dtype in [torch.int8, torch.bool]:
                # Mask index: replace it with one index tensor per mask dim.
                nonzero = index.nonzero()
                # k = number of dims of self already consumed by earlier entries.
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda:
                    f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda:
                        f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda:
        f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    # Broadcast the index tensors against each other. Presumably None
    # entries pass through unchanged; confirm against refs._maybe_broadcast.
    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    # Pad with None so there is exactly one entry per dim of self.
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    # State machine:
    #   0 = no index tensor seen yet, 1 = inside the run of index tensors,
    #   2 = past that run.
    # An index tensor seen in state 2 breaks adjacency; the loop then
    # breaks and the for-else (which marks the subspace contiguous) is skipped.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    # Permute self so the indexed dims come first (in order), followed by the
    # None dims, and reorder indices to match.
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    # None dims before the index run go into before_shape and those after it
    # go into after_shape. The run itself contributes one index shape;
    # replacement_shape holds the last non-None index's shape, which
    # presumably equals all of them after broadcasting.
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)
コード例 #29
0
def meta_adaptive_avg_pool3d(self, output_size):
    """Meta kernel: keep the leading dims of a 4-D/5-D input and replace the
    last three with ``output_size``."""
    check(
        self.ndim in (4, 5),
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-3]
    return self.new_empty(leading_dims + tuple(output_size))