Ejemplo n.º 1
0
def sample_inputs_sparse_coo_masked_reduction(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Yield sparse COO variants of the masked reduction sample inputs.

    Nothing is yielded unless ``op_info.supports_sparse`` is truthy.
    Otherwise every sample produced by ``sample_inputs_masked_reduction``
    has its input converted with ``to_sparse()``; when the sample carries
    a mask, the mask is converted the same way in a copy of the kwargs.
    """
    if not op_info.supports_sparse:
        return

    short_name = op_info.name.replace("_masked.", "")
    # FIXME: for now reductions with non-zero reduction identity and
    # unspecified mask are not supported for sparse COO tensors, see
    # torch._masked.prod implementation for details.
    needs_explicit_mask = {"prod", "amax", "amin"}

    for dense in sample_inputs_masked_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        dense_mask = dense.kwargs.get("mask")
        if dense_mask is None:
            if short_name in needs_explicit_mask:
                continue
            # No mask to convert: the original kwargs are passed through.
            sparse_kwargs = dense.kwargs
        else:
            sparse_kwargs = dense.kwargs.copy()
            sparse_kwargs.update(mask=dense_mask.to_sparse())
        yield SampleInput(
            dense.input.to_sparse(),
            args=dense.args,
            kwargs=sparse_kwargs,
        )
Ejemplo n.º 2
0
def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs):
    """Return a tuple of sample inputs for the ``i0``/``i1`` operator family.

    Two samples are always produced: a 1-D tensor of shape ``(S,)`` and a
    0-D tensor. When ``requires_grad`` is set the samples are adjusted:

    * for ``torch.special.i0e`` every exact zero is replaced by
      ``torch.finfo(dtype).eps``;
    * for every other operator an extra ``(S,)`` sample whose first element
      is ``0`` is appended.
    """
    samples = (
        SampleInput(
            make_tensor((S, ),
                        dtype=dtype,
                        device=device,
                        requires_grad=requires_grad)),
        SampleInput(
            make_tensor((),
                        dtype=dtype,
                        device=device,
                        requires_grad=requires_grad)),
    )

    if not requires_grad:
        return samples

    # The tensors below are leaves that require grad. Autograd rejects
    # in-place writes to such tensors, so the edits are done with grad mode
    # disabled.
    if op_info.op == torch.special.i0e:
        # NOTE: `i0e`'s first-order gradient is not continuous
        # at `0`, hence we don't test `i0e` with any input being `0`.
        # TODO: Remove this when `make_tensor` supports excluding `0`.
        with torch.no_grad():
            for sample in samples:
                t = sample.input
                t[t == 0] = torch.finfo(dtype).eps  # type: ignore[index]
    else:
        # Special Case for gradient
        # Sample with `0` in the input
        t = make_tensor((S, ),
                        dtype=dtype,
                        device=device,
                        requires_grad=requires_grad)
        with torch.no_grad():
            t[0] = 0

        samples += (SampleInput(t), )  # type: ignore[assignment]

    return samples
Ejemplo n.º 3
0
def sample_inputs_sparse_csr_masked_reduction(
    op_info, device, dtype, requires_grad, **kwargs
):
    """Yield sparse CSR variants of the masked reduction sample inputs.

    Nothing is yielded unless ``op_info.supports_sparse_csr`` is truthy.
    Only samples with a 2-D input and a truthy ``keepdim`` kwarg are used.
    Each such sample has its input (and mask, when present) converted with
    ``to_sparse_csr()``. For every yielded sample whose ``dim`` kwarg
    equals ``0`` a second sample with ``dim=1`` is yielded as well.
    """
    if not op_info.supports_sparse_csr:
        return

    short_name = op_info.name.replace("_masked.", "")
    for dense in sample_inputs_masked_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # - sparse CSR tensors are always 2-D tensors
        # - masked reduction on CSR tensors are defined only if keepdim is True.
        if dense.input.ndim != 2 or not dense.kwargs.get("keepdim"):
            continue

        dense_mask = dense.kwargs.get("mask")
        if dense_mask is None:
            # reductions with non-zero reduction identity and
            # unspecified mask is not supported for sparse CSR
            # tensors, see torch._masked.prod implementation
            # for details.
            if short_name in ["prod", "amax", "amin", "mean"]:
                continue
            csr_kwargs = dense.kwargs
        else:
            csr_kwargs = dense.kwargs.copy()
            csr_kwargs.update(mask=dense_mask.to_sparse_csr())

        csr_sample = SampleInput(
            dense.input.to_sparse_csr(),
            args=dense.args,
            kwargs=csr_kwargs,
        )
        yield csr_sample

        if dense.kwargs["dim"] == 0:
            # Reductions of CSR tensors use different implementations for
            # inner and/or outer dimensions. So, as a minimum of testing CSR
            # implementations the following kwargs must be generated:
            #   dict(dim=0, keepdim=True)
            #   dict(dim=1, keepdim=True)
            #   dict(dim=(0, 1), keepdim=True)
            # Here we generate the dim=1 case from the dim=0 case.
            dim1_kwargs = csr_sample.kwargs.copy()
            dim1_kwargs.update(dim=1)
            yield SampleInput(
                csr_sample.input.clone(),
                args=dense.args,
                kwargs=dim1_kwargs,
            )
Ejemplo n.º 4
0
def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
    """Yield sample inputs for fftshift-style operators.

    The first sample has no ``dim`` kwarg; the remaining samples pass
    ``dim`` as an int or as a tuple of ints.
    """
    def fresh(shape):
        # New tensor with the requested device/dtype/requires_grad.
        return make_tensor(shape,
                           device=device,
                           dtype=dtype,
                           requires_grad=requires_grad)

    yield SampleInput(fresh((9, 10)))

    shape_and_dim = (
        ((50, ), 0),
        ((5, 11), (1, )),
        ((5, 6), (0, 1)),
        ((5, 6, 2), (0, 2)),
    )
    for shape, dim in shape_and_dim:
        yield SampleInput(fresh(shape), kwargs={"dim": dim})
Ejemplo n.º 5
0
def sample_inputs_masked_logaddexp(op_info, device, dtype, requires_grad, **kwargs):
    """Return a list of sample inputs for masked logaddexp.

    For each shape, two independent mask lists are generated and paired
    element-wise; every pair yields one sample with a freshly created input
    tensor and a freshly created ``other`` tensor of that shape.
    """
    shapes = [(S,), (S, S), (S, M, S)]
    # All masks for `input` are materialized before any mask for `other`.
    masks_for_input = [
        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
    ]
    masks_for_other = [
        list(_generate_masked_op_mask(shape, device, **kwargs)) for shape in shapes
    ]

    samples: List[SampleInput] = []
    for shape, lhs_masks, rhs_masks in zip(shapes, masks_for_input, masks_for_other):
        for lhs_mask, rhs_mask in zip(lhs_masks, rhs_masks):
            lhs = make_tensor(
                shape, dtype=dtype, device=device, requires_grad=requires_grad
            )
            rhs = make_tensor(
                shape, dtype=dtype, device=device, requires_grad=requires_grad
            )
            sample = SampleInput(
                lhs.clone().requires_grad_(requires_grad),
                args=(rhs.clone().requires_grad_(requires_grad),),
                kwargs={"input_mask": lhs_mask, "other_mask": rhs_mask},
            )
            samples.append(sample)
    return samples
Ejemplo n.º 6
0
def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs):
    """Return a list of sample inputs for masked cumsum and cumprod.

    Built from ``sample_inputs_softmax_variant`` crossed with the generated
    masks. Masks whose type is not exactly ``torch.Tensor`` are skipped, any
    ``keepdim`` kwarg is dropped, and the dimension is always passed
    positionally (samples that provide no dimension are skipped).
    """
    samples: List[SampleInput] = []
    for base in sample_inputs_softmax_variant(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        for mask in _generate_masked_op_mask(base.input.shape, device, **kwargs):
            if type(mask) != torch.Tensor:
                continue

            call_args = base.args
            call_kwargs = dict(mask=mask, **base.kwargs)
            call_kwargs.pop("keepdim", None)

            # dimension is required
            if not call_args:
                if "dim" not in call_kwargs:
                    continue
                call_args = (call_kwargs.pop("dim"),)

            samples.append(
                SampleInput(
                    base.input.clone().requires_grad_(requires_grad),
                    args=call_args,
                    kwargs=call_kwargs,
                )
            )

    return samples
Ejemplo n.º 7
0
def sample_inputs_masked_softmax(
    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
):
    """Return a list of sample inputs for masked softmax, log_softmax, and softmin.

    Masked normalization operator is a reduction operator with
    trailing mask optional argument. A mask is a bool tensor with the
    same shape as input or a shape that is broadcastable to input
    shape.

    Every sample from ``sample_inputs_softmax_variant`` is combined with
    every generated mask; the mask is added to the sample's kwargs under
    the ``mask`` key and the input tensor is cloned.
    """
    return [
        SampleInput(
            base.input.clone().requires_grad_(requires_grad),
            args=base.args,
            kwargs=dict(mask=mask, **base.kwargs),
        )
        for base in sample_inputs_softmax_variant(
            op_info, device, dtype, requires_grad, with_dtype=with_dtype, **kwargs
        )
        for mask in _generate_masked_op_mask(base.input.shape, device, **kwargs)
    ]
Ejemplo n.º 8
0
def sample_inputs_polygamma(op_info, device, dtype, requires_grad, **kwargs):
    """Yield sample inputs for polygamma.

    Each of the two input shapes, ``(S, S)`` and 0-D, is paired with every
    order ``n`` from 1 to 5; ``n`` is passed as the only positional argument.
    """
    def fresh(shape):
        # New tensor with the requested device/dtype/requires_grad.
        return make_tensor(shape,
                           device=device,
                           dtype=dtype,
                           requires_grad=requires_grad)

    for shape in ((S, S), ()):
        for n in (1, 2, 3, 4, 5):
            yield SampleInput(fresh(shape), args=(n, ))
Ejemplo n.º 9
0
def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
    """Return a tuple of two sample inputs (1-D of length ``L`` and 0-D) for entr.

    The lower bound of the generated values is the first element of
    ``op_info.domain``; when gradients are requested it is replaced by
    ``0 + op_info._domain_eps``.
    """
    lower, _upper = op_info.domain
    lower = 0 + op_info._domain_eps if requires_grad else lower

    return tuple(
        SampleInput(
            make_tensor(shape,
                        dtype=dtype,
                        device=device,
                        low=lower,
                        requires_grad=requires_grad))
        for shape in ((L, ), ())
    )
Ejemplo n.º 10
0
def sample_inputs_masked_reduction(op_info, device, dtype, requires_grad, **kwargs):
    """Yield sample inputs for masked reduction operators.

    Masked reduction operator is a reduction operator with trailing
    mask optional argument. A mask is a bool tensor with the same
    shape as input or a shape that is broadcastable to input shape.

    Every sample from ``sample_inputs_reduction`` is combined with every
    generated mask. In addition, when gradients are not requested, the dtype
    is floating point, the input is 2-D and the mask has the input's shape,
    three extra samples are yielded whose diagonal is filled with ``inf``,
    ``-inf`` and ``nan`` respectively.
    """
    kwargs["supports_multiple_dims"] = op_info.supports_multiple_dims

    for sample_input in sample_inputs_reduction(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        for mask in _generate_masked_op_mask(
            sample_input.input.shape, device, **kwargs
        ):
            sample_input_args, sample_input_kwargs = sample_input.args, dict(
                mask=mask, **sample_input.kwargs
            )
            yield SampleInput(
                sample_input.input.detach().requires_grad_(requires_grad),
                args=sample_input_args,
                kwargs=sample_input_kwargs,
            )
            if (
                not requires_grad
                and dtype.is_floating_point
                and sample_input.input.ndim == 2
                and mask is not None
                and mask.shape == sample_input.input.shape
            ):
                for v in [torch.inf, -torch.inf, torch.nan]:
                    # detach() alone returns a tensor sharing storage with the
                    # base input, so the in-place fill below would also rewrite
                    # every sample already yielded from this base and the base
                    # used for the remaining masks. Clone to keep each
                    # special-value sample independent.
                    t = sample_input.input.detach().clone()
                    t.diagonal(0, -2, -1).fill_(v)
                    yield SampleInput(
                        t.requires_grad_(requires_grad),
                        args=sample_input_args,
                        kwargs=sample_input_kwargs,
                    )
Ejemplo n.º 11
0
def sample_inputs_masked_norm(op_info, device, dtype, requires_grad, **kwargs):
    """Yield sample inputs for masked norm.

    For each norm order, the masked reduction samples are regenerated and the
    order is prepended to their positional arguments.
    """
    for order in (2.0, 1, float("inf"), float("-inf"), 0):
        for base in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            args_with_order = (order,) + base.args
            kwargs_copy = base.kwargs.copy()
            yield SampleInput(
                base.input.clone().requires_grad_(requires_grad),
                args=args_with_order,
                kwargs=kwargs_copy,
            )
Ejemplo n.º 12
0
def sample_inputs_masked_std_var(op_info, device, dtype, requires_grad, **kwargs):
    """Yield sample inputs for masked std/var.

    The masked reduction samples are regenerated once with ``unbiased=False``
    and once with ``unbiased=True``. When a sample has positional arguments,
    ``unbiased`` is inserted right after the first one; otherwise it is added
    as a keyword argument. When gradients are requested, samples whose
    smallest per-reduction element count is at most ``int(unbiased) + 1``
    are skipped.
    """
    for unbiased in (False, True):
        for base in sample_inputs_masked_reduction(
            op_info, device, dtype, requires_grad, **kwargs
        ):
            if base.args:
                dim = base.args[0]
                call_args = base.args[:1] + (unbiased,) + base.args[1:]
                call_kwargs = base.kwargs.copy()
            else:
                dim = base.kwargs.get("dim")
                call_args = base.args
                call_kwargs = dict(base.kwargs, unbiased=unbiased)

            if requires_grad:
                # Count how many elements take part in each reduction by
                # summing a tensor of ones under the effective mask.
                if call_kwargs.get("mask") is None:
                    ones = torch.ones(base.input.shape, dtype=torch.int64)
                    counts = torch._masked.sum(ones, dim, keepdim=True)
                else:
                    effective_mask = torch._masked._input_mask(
                        base.input, *call_args, **call_kwargs
                    )
                    ones = effective_mask.new_ones(
                        base.input.shape, dtype=torch.int64
                    )
                    counts = torch._masked.sum(
                        ones, dim, keepdim=True, mask=effective_mask
                    )
                if counts.min() <= int(unbiased) + 1:
                    # Skip samples that lead to singularities in var
                    # computation resulting nan values both in var and
                    # autograd output that test_grad_fn cannot handle
                    # correctly. Also, skip samples when the autograd output
                    # for std could not be handled correctly due to torch.sqrt
                    continue

            yield SampleInput(
                base.input.detach().requires_grad_(requires_grad),
                args=call_args,
                kwargs=call_kwargs,
            )
Ejemplo n.º 13
0
def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwargs):
    """Return a list of sample inputs for masked normalize.

    For each norm order, the softmax-variant samples are regenerated and the
    order is prepended to their positional arguments; inputs are cloned.
    """
    return [
        SampleInput(
            base.input.clone().requires_grad_(requires_grad),
            args=(order,) + base.args,
            kwargs=base.kwargs.copy(),
        )
        for order in (2.0, 1, float("inf"), float("-inf"), 0)
        for base in sample_inputs_softmax_variant(
            op_info, device, dtype, requires_grad, **kwargs
        )
    ]
Ejemplo n.º 14
0
def sample_inputs_softmax_variant(
    op_info, device, dtype, requires_grad, with_dtype=False, **kwargs
):
    """Return a list of sample inputs for softmax-like operators.

    Each sample is a fresh tensor of a given shape with the reduction
    dimension passed as ``args``. When ``with_dtype`` is true every sample
    gets ``dtype=torch.float64`` as its kwargs, otherwise the kwargs are
    ``None``. Extra ``**kwargs`` passed by the caller are not used.
    """
    shape_dim_pairs = [
        ((S,), (0,)),
        ((S, S), (0,)),
        ((S, S), (1,)),
        ((S, S), (-1,)),
        ((S, M, S), (2,)),
    ]

    # PyTorch on XLA throws an error when passed with dim argument for 0d tensor.
    # See https://github.com/pytorch/xla/issues/3061 for more details.
    if torch.device(device).type != "xla":
        shape_dim_pairs.append(((), (0,)))

    # The same kwargs object is shared by all returned samples.
    op_kwargs = dict(dtype=torch.float64) if with_dtype else None

    samples = []
    for shape, dim in shape_dim_pairs:
        tensor = make_tensor(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )
        samples.append(SampleInput(tensor, args=dim, kwargs=op_kwargs))
    return samples