def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be a tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be a tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "prelu: Zero-dim input tensors are not allowed.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )
    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )
    weight = prims.broadcast_in_dim(
        weight, a.shape, tuple() if weight.ndim == 0 else (1,)
    )
    return refs.where(a > 0, a, a * weight)
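# Illustrative usage sketch (not part of the original source): exercises the
# per-channel weight broadcast above. With a 1D `weight`, the slope is applied
# along dim 1 of the input.
#
#   x = torch.randn(2, 3, 4)             # channel dim is dim 1, so channel_size = 3
#   w = torch.full((3,), 0.25)           # one negative slope per channel
#   y = prelu(x, w)                      # equivalent to where(x > 0, x, 0.25 * x)
#   torch.testing.assert_close(y, torch.nn.functional.prelu(x, w))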
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split evenly with the specified outer_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)
    if inner_length != _inner_length:
        msg = "Attempting to split dimension of length {0}, but outer length {1} does not divide it evenly!".format(
            a.shape[dim], outer_length
        )
        raise ValueError(msg)

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
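# Worked example (illustrative, not in the original source) of the shape/stride
# arithmetic above. For a contiguous tensor `t` with shape (2, 6) and strides (6, 1):
#
#   _split_dim_meta(t, dim=1, outer_length=2)
#   -> shape (2, 2, 3), strides (6, 3, 1)
#
# inner_length = 6 / 2 = 3, and the split dim contributes strides
# (old_stride * inner_length, old_stride) = (1 * 3, 1).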
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
            a.numel(), numel
        )
        raise ValueError(msg)

    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
def _maybe_resize_out(out: TensorLikeType, shape):
    if out.numel() == 0:
        return prims.resize(out, shape)

    if out.numel() != reduce(operator.mul, shape, 1):
        msg = (
            "An output with one or more elements was resized since it had shape {0} "
            "which does not match the required output shape {1}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format(
                str(out.shape), str(shape)
            )
        )
        warnings.warn(msg)
        return prims.resize(out, shape)

    return out
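# Behavior sketch (illustrative, not in the original source):
#
#   out = torch.empty(0)
#   _maybe_resize_out(out, (2, 3))   # zero elements: resized silently
#
#   out = torch.empty(5)
#   _maybe_resize_out(out, (2, 3))   # 5 != 6 elements: warns, then resizes
#
#   out = torch.empty(6)
#   _maybe_resize_out(out, (2, 3))   # matching numel: returned unchanged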
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif not isinstance(dim, List) and dim is not None:
        # refs.amin just accepts List rather than DimType (Tuple)
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    to_result_dtype = partial(prims.convert_element_type, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    elif ord == float("-inf"):
        return to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = prims.convert_element_type(x, computation_dtype)
        reduce_sum = partial(refs.sum, dim=dim, keepdim=keepdim)

        # Avoid computing a sqrt in abs and then squaring (more stable)
        # This could potentially be done for complex dtypes as
        # x = torch.real(torch.conj(x) * x)
        # and it should be more stable, but it's not clear whether it'll be faster on, say,
        # CPU (abs is one vectorised operation), so leaving it just for real dtypes for now
        if not (ord % 2.0 == 0.0 and is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))
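# Worked example (illustrative, not in the original source) of the general branch:
# for ord=2.0 and x = [3., 4.], sum(x ** 2) = 25 and 25 ** (1 / 2) = 5. Because
# ord is even and x is real, the torch.abs call is skipped (x ** 2 already
# discards the sign).
#
#   x = torch.tensor([3.0, 4.0])
#   vector_norm(x)                       # tensor(5.)
#   vector_norm(x, ord=float("inf"))     # tensor(4.) -- amax(|x|)
#   vector_norm(x, ord=0.0)              # tensor(2.) -- count of nonzeros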
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    shape = a.shape
    strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    assert end > start

    length = 1
    stride = 1
    for idx in range(start, end):
        if idx != (end - 1):
            assert strides[idx] == strides[idx + 1] * shape[idx + 1]
        length = length * shape[idx]
        # The collapsed dimension steps with the innermost (last) stride
        stride = strides[idx]

    new_shape = shape[:start] + (length,) + shape[end:]
    new_strides = strides[:start] + (stride,) + strides[end:]

    return TensorMeta(a, shape=new_shape, strides=new_strides)
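# Worked example (illustrative, not in the original source): for a contiguous
# tensor `t` with shape (2, 3, 4) and strides (12, 4, 1), collapsing [0, 2):
#
#   _collapse_view_meta(t, start=0, end=2)
#   -> shape (6, 4), strides (4, 1)
#
# The contiguity check holds (12 == 4 * 3), length = 2 * 3 = 6, and the
# collapsed dimension inherits the innermost stride, 4.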
def _reshape_meta(a: TensorLikeType, shape: Sequence):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(lambda acc, x: acc * x, shape, 1)
    assert a.numel() == numel

    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))
def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # Validates the cast is safe
    # TODO: move this as an option on the reference
    # a_typ = utils.dtype_to_type(a.dtype)
    # b_typ = utils.dtype_to_type(b.dtype)
    # if a_typ is not utils.get_higher_type(a_typ, b_typ):
    #     raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")

    # Validates the tensors have the same number of elements
    if a.numel() != b.numel():
        msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
            b.numel(), a.numel()
        )
        raise RuntimeError(msg)

    return a
def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    utils.validate_idx(a.ndim, dim)
    utils.validate_dim_length(outer_length)

    # Verifies the dim can be split evenly with the specified outer_length
    _inner_length = a.shape[dim] / outer_length
    inner_length: int = int(_inner_length)
    assert inner_length == _inner_length

    new_shape: List[int] = []
    new_strides: List[int] = []
    for idx in range(a.ndim):
        if idx == dim:
            new_shape.extend((outer_length, inner_length))
            new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
        else:
            new_shape.append(a.shape[idx])
            new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    for idx in dimensions:
        utils.validate_idx(a.ndim, idx)
        assert a.shape[idx] == 1

    new_shape = []
    new_strides = []
    for idx in range(len(a.shape)):
        if idx in dimensions:
            continue
        new_shape.append(a.shape[idx])
        new_strides.append(a.stride()[idx])

    return TensorMeta(a, shape=new_shape, strides=new_strides)
def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
    if a.ndim != len(permutation):
        msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
            a.ndim, len(permutation)
        )
        raise ValueError(msg)

    if not utils.is_valid_permutation(a.ndim, permutation):
        msg = "Received an invalid permutation, {0}!".format(permutation)
        raise ValueError(msg)

    new_shape = [0] * a.ndim
    new_strides = [0] * a.ndim
    for idx, dim in enumerate(permutation):
        new_shape[idx] = a.shape[dim]
        new_strides[idx] = a.stride()[dim]

    return TensorMeta(a, shape=tuple(new_shape), strides=tuple(new_strides))
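# Worked example (illustrative, not in the original source): permuting a
# contiguous tensor `t` with shape (2, 3, 4) and strides (12, 4, 1):
#
#   _transpose_meta(t, permutation=(2, 0, 1))
#   -> shape (4, 2, 3), strides (1, 12, 4)
#
# Output dim idx simply takes its shape/stride from input dim permutation[idx].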
def _broadcast_in_dim_meta(
    a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
):
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, int)
        assert x > acc
        assert x < len(shape)
        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # each dimension of a's shape must be broadcastable to the new shape
    for idx, new_idx in enumerate(broadcast_dimensions):
        assert a.shape[idx] == 1 or a.shape[idx] == shape[new_idx]

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            new_strides.append(0)

    return TensorMeta(a, shape=shape, strides=new_strides)
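# Worked example (illustrative, not in the original source): broadcasting a
# tensor `t` with shape (3,) and strides (1,) to shape (2, 3) by mapping its
# only dimension to output dim 1:
#
#   _broadcast_in_dim_meta(t, shape=(2, 3), broadcast_dimensions=(1,))
#   -> shape (2, 3), strides (0, 1)
#
# Output dims not named in broadcast_dimensions get stride 0 (the broadcast).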
def _collapse_view_helper(
    a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
    assert isinstance(a, TensorLike)

    # Special-case for zero dimensional tensors
    if a.ndim == 0:
        shape = (1,)
        strides = (1,)
    else:
        shape = a.shape  # type: ignore[assignment]
        strides = a.stride()

    utils.validate_idx(len(shape), start)
    utils.validate_exclusive_idx(len(shape), end)

    # Verifies end is strictly greater than start
    # (Collapse requires a non-empty interval)
    if end <= start:
        msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
            end, start
        )
        raise ValueError(msg)

    length = 1
    stride = 1
    for idx in range(start, end):
        if idx != (end - 1):
            if not (strides[idx] == strides[idx + 1] * shape[idx + 1]):
                return None, None
        length = length * shape[idx]
        # The collapsed dimension steps with the innermost (last) stride
        stride = strides[idx]

    new_shape = shape[:start] + (length,) + shape[end:]
    new_strides = strides[:start] + (stride,) + strides[end:]

    return new_shape, new_strides
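# Worked example (illustrative, not in the original source): the helper returns
# (None, None) when the requested dims cannot be collapsed into a view. For the
# transpose of a contiguous (2, 3) tensor -- shape (3, 2), strides (1, 3):
#
#   _collapse_view_helper(t, start=0, end=2)
#   -> (None, None)        # 1 != 3 * 2, so no single stride covers both dims
#
# A contiguous (2, 3) tensor with strides (3, 1) instead yields ((6,), (1,)).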
def _slice_meta(
    a: TensorLikeType,
    start_indices: DimsSequenceType,
    limit_indices: DimsSequenceType,
    strides: Optional[StrideType] = None,
) -> TensorLikeType:
    _strides = strides if strides is not None else [1] * len(start_indices)

    if a.ndim != len(start_indices):
        msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
            a.ndim, len(start_indices)
        )
        raise ValueError(msg)

    if a.ndim != len(limit_indices):
        msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
            a.ndim, len(limit_indices)
        )
        raise ValueError(msg)

    if a.ndim != len(_strides):
        msg = "Attempting to slice tensor of rank {0} with strides of length {1}!".format(
            a.ndim, len(_strides)
        )
        raise ValueError(msg)

    for x, y in zip(start_indices, a.shape):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative start index of {0}!".format(x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a start index in {0} is greater than"
                " the length of its corresponding dimension in shape {1}".format(
                    start_indices, a.shape
                )
            )
            raise ValueError(msg)

    for x, y, z in zip(limit_indices, a.shape, start_indices):
        if x < 0:
            msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(x)
            raise ValueError(msg)
        if x > y:
            msg = (
                "Attempting to slice a tensor but a stop index in {0} is greater than the length of"
                " its corresponding dimension in shape {1}".format(
                    limit_indices, a.shape
                )
            )
            raise ValueError(msg)
        if x < z:
            msg = (
                "Attempting to slice a tensor but a start index {0} is greater than"
                " its corresponding stop index {1}".format(z, x)
            )
            raise ValueError(msg)

    for x in _strides:
        if x <= 0:
            msg = "Attempting to slice a tensor with a non-positive step of {0}!".format(x)
            raise ValueError(msg)

    new_shape = []
    for x, y, z in zip(start_indices, limit_indices, _strides):
        new_shape.append(math.floor((y - x) / z))

    new_strides = []
    for x, y in zip(a.stride(), _strides):
        new_strides.append(x * y)

    return TensorMeta(a, shape=new_shape, strides=new_strides)
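# Worked example (illustrative, not in the original source): for a contiguous
# tensor `t` with shape (5, 5) and strides (5, 1):
#
#   _slice_meta(t, start_indices=(0, 1), limit_indices=(5, 5), strides=(1, 2))
#   -> shape (5, 2), strides (5, 2)
#
# Per dim: length = floor((limit - start) / step), stride = old_stride * step,
# e.g. dim 1 gives floor((5 - 1) / 2) = 2 elements with stride 1 * 2 = 2.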
def _reshape_view_helper(
    a: TensorLikeType, shape: ShapeType, *, allow_copy: bool
) -> TensorLikeType:
    # NOTE: Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred

    # Creates a valid shape
    for idx in range(len(shape)):
        if shape[idx] == -1:
            # Verifies there's only one dimension of length -1 in the shape
            if shape.count(-1) > 1:
                msg = "Can only infer the length of one dimension, but got shape {0}!".format(
                    str(shape)
                )
                raise ValueError(msg)

            # TODO: improve error message
            if a.numel() > 0:
                length = reduce(
                    operator.floordiv, (x for x in shape if x != -1), a.numel()
                )
            else:
                msg = "Cannot reshape a tensor of zero elements into shape {0} because the unspecified length is ambiguous!".format(
                    str(shape)
                )
                raise ValueError(msg)

            shape = list(shape)
            shape[idx] = length
            break

    # Short-circuits if shape is the same
    utils.validate_shape(shape)
    if tuple(a.shape) == tuple(shape):
        return prims.view_of(a)

    numel = reduce(operator.mul, shape) if len(shape) > 0 else 1
    if a.numel() != numel:
        msg = "Attempting to reshape a tensor with shape {0} and {1} elements to a shape {2} with {3} elements!".format(
            str(a.shape), a.numel(), str(shape), numel
        )
        raise ValueError(msg)

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        return _a

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # when reshaping from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            a_ = unsqueeze(a_, -1)
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, new_strides = prims._collapse_view_helper(a_, idx, end + 1)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = "Cannot view a tensor with shape {0} and strides {1} as a tensor with shape {2}!".format(
                    a.shape, a.stride(), shape
                )
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        assert a_.shape[idx] == 1
        a_ = squeeze(a_, idx)

    return a_
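# Trace of NOTE [Reshape Algorithm] above (illustrative, not in the original
# source): reshaping a contiguous tensor of shape (2, 6) to (4, 3).
#
#   length = 4: accum = 2, and 2 % 4 != 0, so accumulate dim 1: accum = 12.
#               The collapse of dims [0, 2) is a view (strides (6, 1), 6 == 1 * 6),
#               so flatten to (12,), then split_dim(-, 0, 4) -> (4, 3).
#   length = 3: already matches shape[1], skip.
#
# Result: a view with shape (4, 3) and strides (3, 1); no copy was required.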
def _resize_meta(a: TensorLikeType, shape: Union[torch.Size, List[int], Tuple[int, ...]]):
    # Only resizing tensors with zero elements is modeled here
    assert a.numel() == 0
    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))