def meta_cdist_forward(x1, x2, p, compute_mode):
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: "cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: "cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )
    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
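
As a sanity check of the shape logic above, the output shape is the broadcast of the two batch shapes followed by the two row counts. A minimal sketch against the eager torch.cdist (the shapes below are illustrative, not from the source):

import torch

# Hypothetical shapes: x1 has batch dims (2, 1) with 4 rows, x2 has batch dims (5,) with 6 rows.
x1 = torch.randn(2, 1, 4, 3)
x2 = torch.randn(5, 6, 3)

batch = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])   # (2, 5)
expected = (*batch, x1.size(-2), x2.size(-2))                  # (2, 5, 4, 6)
assert torch.cdist(x1, x2).shape == expected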
Example #2
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype,
                     fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions
    """
    if dtype is not None:
        check(
            utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
            lambda:
            f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        check(
            utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
            lambda:
            "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".
            format(
                fn_name=fn_name,
                d="complex" if utils.is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        check(
            utils.get_higher_dtype(dtype, x_dtype) == dtype,
            lambda:
            f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            "without narrowing to the specified dtype ({dtype})",
        )
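
The last check enforces that converting the input to the requested dtype never narrows it. A minimal sketch of that rule, using torch.promote_types as a stand-in for utils.get_higher_dtype (an assumption about its semantics):

import torch

# No narrowing: promoting the pair keeps the requested dtype, so the check passes.
assert torch.promote_types(torch.float64, torch.float32) == torch.float64

# Narrowing: a float64 input with dtype=torch.float32 promotes to float64 != float32,
# so check_norm_dtype would raise.
assert torch.promote_types(torch.float32, torch.float64) != torch.float32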
Example #3
def norm(self: Tensor,
         p: float = 2,
         dim: Optional[List[int]] = None,
         keepdim: bool = False):
    if dim is None:
        dim = []

    if p == 0:
        return (self != 0).sum(dim, keepdim=keepdim)
    elif p == float("inf"):
        return self.abs().amax(dim, keepdim=keepdim)
    elif p == -float("inf"):
        return self.abs().amin(dim, keepdim=keepdim)

    def fast_pow(x, ord):
        if ord == 1.0:
            return x
        elif ord == 2.0:
            return x.square()
        elif ord == 0.5:
            return x.sqrt()
        else:
            return x.pow(ord)

    if not (p % 2.0 == 0.0 and utils.is_float_dtype(self.dtype)):
        self = self.abs()

    return fast_pow(fast_pow(self, p).sum(dim, keepdim=keepdim), 1.0 / p)
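
The decomposition above should match the built-in vector norm; a minimal sketch comparing two of its branches against torch.linalg.vector_norm (used here as an assumed reference for the same convention):

import torch

x = torch.randn(4, 5)

# p == 2: abs() is skipped (2 % 2.0 == 0 on a float tensor), squares are summed,
# then fast_pow(_, 1.0 / 2) takes the sqrt branch.
ref = x.square().sum(dim=[1]).sqrt()
assert torch.allclose(ref, torch.linalg.vector_norm(x, ord=2, dim=1))

# p == inf reduces to amax of the absolute value.
assert torch.allclose(x.abs().amax(dim=1),
                      torch.linalg.vector_norm(x, ord=float("inf"), dim=1))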
Example #4
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda:
        f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(num_bags >= 1,
              lambda: "include_last_offset: numBags should be at least 1")
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda:
            "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda:
            f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda:
            f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            lambda:
            (f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
             f"to be the same as indices.numel() ({indices.numel()})"),
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        return (is_fast_path_index_select(src, output, padding_idx)
                and scale.stride(0) == 1)

    def is_fast_path_index_select(src, output, padding_idx):
        return ((src.dtype == torch.float or src.dtype == torch.half)
                and src.stride(1) == 1 and output.stride(1) == 1
                and padding_idx < 0)

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output,
                                                   padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output,
                                     padding_idx)
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())
    return output, offset2bag, bag_size, max_indices
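
For the output shape this meta function has to reproduce, a minimal sketch against the eager torch.nn.functional.embedding_bag (which returns only the aggregated output, not the full four-tensor tuple):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 3)                      # 10 embeddings of dimension 3
indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])                   # 2 bags, so num_bags = offsets.size(0)

out = F.embedding_bag(indices, weight, offsets, mode="sum")
assert out.shape == (offsets.size(0), weight.size(1))   # (num_bags, embedding_dim)

# With include_last_offset=True the final offset is only an end marker,
# which is why the meta function subtracts one from num_bags.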
Example #5
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, list) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda:
            f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda:
                f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype)

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0),
                        dim=dim,
                        keepdim=keepdim,
                        dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(
            refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(
            refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(
            torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
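
The general branch reduces to abs -> pow(ord) -> sum -> pow(1/ord); a minimal sketch checking that, and the inf short-circuit, against torch.linalg.vector_norm:

import torch

x = torch.randn(3, 4)
ord = 3.0

# abs -> pow(ord) -> sum -> pow(1 / ord), the general branch above.
ref = x.abs().pow(ord).sum(dim=-1).pow(1.0 / ord)
assert torch.allclose(ref, torch.linalg.vector_norm(x, ord=ord, dim=-1))

# ord == +/-inf short-circuits to amax/amin of the absolute value.
assert torch.allclose(x.abs().amax(dim=-1),
                      torch.linalg.vector_norm(x, ord=float("inf"), dim=-1))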