Example #1
# NOTE (hedged): torch/_refs-style code; `check`, `TensorLike`,
# `TensorLikeType`, `prims`, and `refs` come from the surrounding module,
# and their exact import paths vary across PyTorch versions.
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Zero-dim input tensor is not allowed.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda:
            f"Mismatch between number of parameters and input channel size: "
            f"found {weight.numel()} parameters and channel size = {channel_size}.",
        )

    check(
        weight.ndim in (0, 1),
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, "
        f"but got: ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )

    return refs.where(a > 0, a, a * weight)
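
# A minimal usage sketch (hedged): checking the reference against eager
# torch.nn.functional.prelu. Assumes the ref is importable from
# torch._refs.nn.functional, as in recent PyTorch; paths vary by version.
import torch
import torch._refs.nn.functional as refs_nn

a = torch.randn(2, 3, 4)
weight = torch.randn(3)  # one parameter per channel (a.shape[1])

assert torch.allclose(refs_nn.prelu(a, weight),
                      torch.nn.functional.prelu(a, weight))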
Example #2
# NOTE (hedged): shared reduction template from torch/_refs; it relies on
# helpers from the surrounding module (utils, prims, REDUCTION_OUTPUT_TYPE_KIND,
# _maybe_resize_out, copy_to), whose import paths vary across PyTorch versions.
def _reduction(
    a: Tensor,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
):  # it is usually SAME, but I want
    # ref writers to actually think about what to put here
    assert isinstance(a, TensorLike)
    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, int)
    if isinstance(dims, int):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    # even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here
    inp_dtype = dtype if dtype is not None else a.dtype
    computation_dtype = utils._get_computation_dtype(inp_dtype)
    a_converted = prims.convert_element_type(a, computation_dtype)
    result = prim(a_converted, dims)

    if keepdims:
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
    if out is not None:
        if dtype is None:
            if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
                if out.dtype != a.dtype:
                    raise RuntimeError("Expected the dtype for input and out to match")
            elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL:
                if out.dtype != torch.bool:
                    raise RuntimeError("Expected out.dtype to be torch.bool")
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]

    if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
        result_dtype = dtype if dtype else a.dtype
        result = prims.convert_element_type(result, result_dtype)
    return result
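
# The keepdims branch above is the notable broadcast_in_dim use: it reinserts
# size-1 axes at the reduced positions. A plain-PyTorch sketch of the same
# logic (restore_reduced_dims is an illustrative name, not a torch API):
import torch

def restore_reduced_dims(result, dims, original_shape):
    # equivalent to prims.broadcast_in_dim(result, output_shape, broadcast_dims)
    output_shape = [1 if i in dims else s for i, s in enumerate(original_shape)]
    return result.reshape(output_shape)

a = torch.randn(2, 3, 4)
r = a.sum(dim=(0, 2))  # shape: (3,)
assert restore_reduced_dims(r, {0, 2}, a.shape).shape == torch.Size([1, 3, 1])
assert torch.equal(restore_reduced_dims(r, {0, 2}, a.shape),
                   a.sum(dim=(0, 2), keepdim=True))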
Example #3
    # NOTE: nested helper from torch/_refs; `common_shape` comes from the
    # enclosing function's scope, and `shape` is unused here.
    def _maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            # the +1s cancel: start == len(common_shape) - len(x.shape),
            # i.e. x's axes are right-aligned against the common shape
            common_rank = len(common_shape) + 1
            start = common_rank - (len(x.shape) + 1)
            dims = tuple(range(start, len(x.shape) + start))

            # TODO: add a pass to remove unnecessary broadcast_in_dim calls
            return prims.broadcast_in_dim(x, common_shape, dims)
        else:
            raise RuntimeError(
                f"Unexpected type when broadcasting: {type(x)}!")
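
# The index arithmetic above right-aligns x's axes against common_shape
# (NumPy-style broadcasting). A tiny self-contained check of that mapping
# (the helper name is illustrative, not part of torch):
def right_aligned_dims(x_rank, common_rank):
    start = common_rank - x_rank  # the +1s in the code above cancel out
    return tuple(range(start, x_rank + start))

assert right_aligned_dims(1, 3) == (2,)   # (3,) maps to the last axis of (2, 3, 4)
assert right_aligned_dims(2, 3) == (1, 2)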
Example #4
    # NOTE: nested helper from torch/_refs; `common_shape` and
    # `preserve_cpu_scalar_tensors` come from the enclosing function's scope.
    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x

            if tuple(x.shape) != common_shape:
                common_rank = len(common_shape) + 1
                start = common_rank - (len(x.shape) + 1)
                dims = tuple(range(start, len(x.shape) + start))
                return prims.broadcast_in_dim(x, common_shape, dims)
        else:
            raise RuntimeError(
                f"Unexpected type when broadcasting: {type(x)}!")

        return x  # shape already matches common_shape
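
# In plain PyTorch, the TensorLike branch corresponds to unsqueezing on the
# left and expanding; a hedged equivalence sketch for one concrete case:
import torch

x = torch.randn(3, 4)
# broadcast_in_dim(x, (2, 3, 4), (1, 2)) places x's axes at dims 1 and 2:
x_bc = x.reshape(1, 3, 4).expand(2, 3, 4)
assert x_bc.shape == torch.Size([2, 3, 4])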
Example #5
def _wrapper(a):
    # full reduction over both axes yields a 0-d tensor; broadcasting it
    # to shape [] with no mapped dims is then a no-op
    a_sum = prims.sum(a, [0, 1])
    a_bc = prims.broadcast_in_dim(a_sum, [], [])
    return a_bc
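
# Plain-PyTorch reading of the wrapper above (a hedged sketch):
import torch

a = torch.randn(3, 4)
out = a.sum(dim=(0, 1))  # 0-d result; broadcasting to shape [] changes nothing
assert out.shape == torch.Size([])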
Example #6
def _wrapper(a, b, broadcast_dimensions):
    # expand `a` to b's shape, placing a's axes at the positions given by
    # broadcast_dimensions, then add elementwise
    a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
    return prims.add(a_bc, b)
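
# For one concrete case, the wrapper above matches this plain-PyTorch code
# (broadcast a (3,)-tensor into (2, 3, 4) along dim 1, then add):
import torch

a = torch.randn(3)
b = torch.randn(2, 3, 4)
a_bc = a.reshape(1, 3, 1).expand(2, 3, 4)  # what broadcast_in_dim(a, b.shape, (1,)) does
assert (a_bc + b).shape == b.shape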
Example #7
def _wrapper(a, shape, broadcast_dimensions):
    return prims.broadcast_in_dim(a, shape, broadcast_dimensions)

# NOTE: the Fusion()/FusionDefinition setup that builds fusion1 is elided
# in this snippet; only the inspection and execution half survives.
fusion1.print_ir()

# Execute Fusion
input1 = torch.randn(3, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    o = fusion1.execute([input1, input2])[0]

assert o.shape == torch.Size([2, 3, 4])

# Reference computed with primTorch refs
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
assert ref_o.allclose(o)
assert ref_o.shape == o.shape

fusion2 = Fusion()

input1 = torch.randn(1, 1, 4, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd:
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

    t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
    t2 = fd.ops.add(t0_b, t1)
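    # The snippet is truncated here. Presumably (hedged: following the legacy
    # nvFuser Python frontend used above) the definition would close by
    # registering the output:
    fd.add_output(t2)

# ...and execute like fusion1:
o = fusion2.execute([input1, input2])[0]
assert o.shape == torch.Size([2, 3, 4])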