def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "prelu: zero-dim input tensor is not allowed.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )
    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )
    return refs.where(a > 0, a, a * weight)
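# Usage sketch for the prelu ref above (a hedged example, not part of the
# module: it assumes this ref matches torch.nn.functional.prelu, with one
# weight per channel taken from dim 1):
#
#   import torch
#   a = torch.randn(2, 3, 4)
#   weight = torch.randn(3)  # channel_size == a.shape[1] == 3
#   out = prelu(a, weight)
#   torch.testing.assert_close(out, torch.nn.functional.prelu(a, weight))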
def _reduction(
    a: Tensor,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,  # it is usually SAME, but I want
    # ref writers to actually think about what to put here
):
    assert isinstance(a, TensorLike)
    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError("dtype argument and out dtype must match in reduction")
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, int)
    if isinstance(dims, int):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    # Even though some reductions, like amin or amax, don't strictly require type promotion,
    # all the math ops (including comparisons) are still defined only for a computation type,
    # so promotion will still happen. We are doing it explicitly here.
    inp_dtype = dtype if dtype is not None else a.dtype
    computation_dtype = utils._get_computation_dtype(inp_dtype)
    a_converted = prims.convert_element_type(a, computation_dtype)
    result = prim(a_converted, dims)

    if keepdims:
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
    if out is not None:
        if dtype is None:
            if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
                if out.dtype != a.dtype:
                    raise RuntimeError("Expected the dtype for input and out to match")
            elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL:
                if out.dtype != torch.bool:
                    raise RuntimeError("Expected out to have a bool dtype")
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]
    if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
        result_dtype = dtype if dtype else a.dtype
        result = prims.convert_element_type(result, result_dtype)
    return result
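# A sketch of how a concrete reduction ref is built on _reduction (hedged:
# `_amax_example` is an illustrative name, not part of this module; it wires
# up prims.amax, which has no identity element, so reducing over a zero-size
# dimension raises):
def _amax_example(a: Tensor, dim=None, keepdim: bool = False, *, out: Optional[Tensor] = None):
    return _reduction(
        a,
        prims.amax,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )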
def _maybe_broadcast(x, shape):
    if x is None:
        return None
    elif isinstance(x, Number):
        return x
    elif isinstance(x, TensorLike):
        # Right-align x's dimensions against the target shape
        common_rank = len(shape) + 1
        start = common_rank - (len(x.shape) + 1)
        dims = tuple(range(start, len(x.shape) + start))
        # TODO: add a pass to remove unnecessary broadcast_in_dim calls
        return prims.broadcast_in_dim(x, shape, dims)
    else:
        raise RuntimeError(
            "Unexpected type when broadcasting: " + str(type(x)) + "!"
        )
def __maybe_broadcast(x, shape, preserve_cpu_scalar_tensors=True):
    if x is None:
        return None
    elif isinstance(x, Number):
        return x
    elif isinstance(x, TensorLike):
        # CPU scalar tensors participate in type promotion but keep their shape
        if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
            return x
        if tuple(x.shape) != tuple(shape):
            common_rank = len(shape) + 1
            start = common_rank - (len(x.shape) + 1)
            dims = tuple(range(start, len(x.shape) + start))
            return prims.broadcast_in_dim(x, shape, dims)
        return x  # shapes already match; no broadcast needed
    else:
        raise RuntimeError(
            "Unexpected type when broadcasting: " + str(type(x)) + "!"
        )
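# Worked example of the right-aligned dims computation above (hypothetical
# values): broadcasting x of shape (3, 4) into shape (2, 3, 4) gives
#
#   common_rank = 3 + 1 = 4
#   start = 4 - (2 + 1) = 1
#   dims = (1, 2)   # x's two dims map onto the last two target dims
#
#   prims.broadcast_in_dim(x, (2, 3, 4), (1, 2))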
def _wrapper(a):
    # Reduce over both dims (yielding a 0-d tensor), then broadcast to shape []
    a_sum = prims.sum(a, [0, 1])
    a_bc = prims.broadcast_in_dim(a_sum, [], [])
    return a_bc


def _wrapper(a, b, broadcast_dimensions):
    # Broadcast `a` to b's shape along `broadcast_dimensions`, then add
    a_bc = prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)
    return prims.add(a_bc, b)


def _wrapper(a, shape, broadcast_dimensions):
    # Thin wrapper so broadcast_in_dim can be traced directly
    return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
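# These _wrapper fixtures are meant to be traced and executed, e.g. (hedged:
# assumes the torch._prims.executor.make_traced helper used by the prims
# tests; the executor name is an assumption):
#
#   from torch._prims.executor import make_traced
#   traced = make_traced(_wrapper)
#   out = traced(a, shape, broadcast_dimensions, executor="aten")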
fusion1.print_ir()

# Execute Fusion
input1 = torch.randn(3, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

# Kernel compilation should be cached for the 2nd iteration
# with input tensors of the same shape
for _ in range(5):
    o = fusion1.execute([input1, input2])[0]

assert o.shape == torch.Size([2, 3, 4])

# Reference in prim torch
ref_o = refs.add(prims.broadcast_in_dim(input1, [2, 3, 4], [1]), input2)
assert ref_o.allclose(o)
assert ref_o.shape == o.shape

fusion2 = Fusion()

input1 = torch.randn(1, 1, 4, device='cuda')
input2 = torch.randn(2, 3, 4, device='cuda')

with FusionDefinition(fusion2) as fd:
    t0 = fd.define_tensor(sizes=input1.size(), strides=input1.stride())
    t1 = fd.define_tensor(sizes=input2.size(), strides=input2.stride())

    t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [0, 1, 2])
    t2 = fd.ops.add(t0_b, t1)
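    # Hedged completion: register the fusion output (assumes this frontend
    # exposes fd.add_output; the definition-only snippet above stops short)
    fd.add_output(t2)

# Run fusion2 the same way as fusion1 and check the broadcast result shape
o2 = fusion2.execute([input1, input2])[0]
assert o2.shape == torch.Size([2, 3, 4])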