def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    """
    Validates that `copy_from` may be written into `copy_to` and performs the copy.

    Args:
        copy_from: source of the copy.
        copy_to: destination of the copy (the `out` tensor).
        exact_dtype: when True the two dtypes must be equal; otherwise the
            source dtype only has to pass `utils.can_safe_cast_to` for the
            destination dtype.

    Returns:
        The result of `prims.copy_to(copy_to, copy_from)`.

    Raises:
        RuntimeError: if the two tensors live on different devices.
        Dtype violations are reported through `utils.check`.
    """
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            # Every piece of an implicitly concatenated literal needs its own
            # `f` prefix, otherwise the placeholder is emitted verbatim.
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return prims.copy_to(copy_to, copy_from)
def _fn(*args, out=None, **kwargs):
    """
    Calls the wrapped `fn` and, when `out` is supplied, writes the result into it.

    `fn`, `is_tensor`, `out_names`, `exact_dtype` and `return_type` are taken
    from the enclosing scope. With `out=None` the computed result is returned
    directly; otherwise every destination is resized and filled in place.
    """
    result = fn(*args, **kwargs)
    assert (
        isinstance(result, TensorLike)
        and is_tensor
        or isinstance(result, Tuple)  # type: ignore[arg-type]
        and len(result) == len(out_names)
    )

    if out is None:
        # Nothing to copy into: the fresh result is the output
        out = result
    else:
        assert type(out) == type(result)
        if is_tensor:
            assert isinstance(out, TensorLike)
            sources, targets = (result,), (out,)
        else:
            assert isinstance(out, Tuple)  # type: ignore[arg-type]
            utils.check(
                len(out) == len(result),
                lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                TypeError,
            )
            sources, targets = result, out

        for src, dst in zip(sources, targets):
            # These two operations are done in-place
            _maybe_resize_out(dst, src.shape)
            _safe_copy_out(copy_from=src, copy_to=dst, exact_dtype=exact_dtype)  # type: ignore[arg-type]

    # mypy does not see through the definition of out_type given that it's in a different scope
    if is_tensor:
        return out
    return return_type(*out)  # type: ignore[operator]
def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.nn.functional.prelu

    Computes `a` where `a > 0` and `a * weight` elsewhere.

    Args:
        a: input tensor.
        weight: a scalar (0-D) or 1-D tensor. When it holds more than one
            element, its element count must equal `a.shape[1]` (or 1 when `a`
            has fewer than two dimensions).

    Returns:
        The result of `refs.where(a > 0, a, a * weight)`.

    Argument violations are reported through `check`.
    """
    check(
        isinstance(a, TensorLike),
        lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}",
    )
    check(
        isinstance(weight, TensorLike),
        lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}",
    )

    if weight.numel() != 1:
        check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.")
        channel_size = a.shape[1] if a.ndim >= 2 else 1
        check(
            weight.numel() == channel_size,
            lambda: f"Mismatch of parameter numbers and input channel size. Found parameter numbers ="
            f" {weight.numel()} and channel size = {channel_size}.",
        )

    check(
        weight.ndim == 0 or weight.ndim == 1,
        lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: "
        f"ndim = {weight.ndim}",
    )

    if a.ndim == 0:
        # A 0-D input has no dimension to broadcast along; reaching this point
        # means weight has exactly one element, so reduce it to a scalar.
        weight = weight[0] if weight.ndim == 1 else weight
    else:
        # A 1-D weight maps onto the channel dimension (1), except for 1-D
        # inputs, whose only dimension is 0.
        if weight.ndim == 0:
            broadcast_dims = tuple()
        else:
            broadcast_dims = (0 if a.ndim == 1 else 1,)
        weight = prims.broadcast_in_dim(weight, a.shape, broadcast_dims)

    return refs.where(a > 0, a, a * weight)
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """
    Computes the `ord`-norm of `x` reduced over `dim`.

    Args:
        x: input tensor; its dtype is validated by `check_fp_or_complex`.
        ord: order of the norm. 0, +inf and -inf are handled as special cases.
        dim: a single int, a sequence of ints, or None.
        keepdim: forwarded to the underlying reductions.
        dtype: optional dtype, validated by `check_norm_dtype` and forwarded to
            `reduction_dtypes`.
    """
    pos_inf = float("inf")

    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")

    # Normalise `dim` to a list:
    # refs.amin just accepts List rather than DimType (Tuple)
    if isinstance(dim, int):
        dim = [dim]  # type: ignore[assignment]
    elif dim is not None and not isinstance(dim, List):
        dim = list(dim)  # type: ignore[assignment]

    if x.numel() == 0 and (ord < 0.0 or ord == pos_inf):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )

    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = reduction_dtypes(
        x, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )

    def _to_result_dtype(t):
        return prims.convert_element_type(t, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return refs.sum(refs.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    if ord == pos_inf:
        return _to_result_dtype(refs.amax(torch.abs(x), dim=dim, keepdim=keepdim))
    if ord == -pos_inf:
        return _to_result_dtype(refs.amin(torch.abs(x), dim=dim, keepdim=keepdim))

    # From here on the computation dtype is important as the reduction is non-trivial
    x = prims.convert_element_type(x, computation_dtype)

    # Avoid computing a sqrt in abs and then squaring (more stable)
    # This could potentially be done for complex dtypes as
    # x = torch.real(torch.conj(x) * x))
    # and it should be more stable, but it's not clear whether it'll be faster on, say
    # CPU (abs is 1 vectorised operation), so leaving it just for real dtypes for now
    abs_is_redundant = ord % 2.0 == 0.0 and is_float_dtype(x.dtype)
    if not abs_is_redundant:
        x = torch.abs(x)

    powered = torch.pow(x, ord)
    summed = refs.sum(powered, dim=dim, keepdim=keepdim)
    return _to_result_dtype(torch.pow(summed, 1.0 / ord))
def meta_linalg_qr_helper(input, mode):
    """
    Builds empty Q and R outputs for a QR decomposition of `input`.

    `mode` must be "reduced", "complete" or "r"; anything else raises
    RuntimeError. In "r" mode Q is allocated with `input.new_empty(0)`.
    Both outputs are allocated with their last two sizes swapped and then
    transposed in place.
    """
    # "reduced" -> (Q, reduced), "complete" -> (Q, full), "r" -> (no Q, reduced)
    compute_q = mode in ("reduced", "complete")
    reduced_mode = mode in ("reduced", "r")
    if not (compute_q or reduced_mode):
        raise RuntimeError(f"qr received unrecognized mode {mode}")

    check(input.ndim >= 2, lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor")
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}"
    )

    rows, cols = input.size(-2), input.size(-1)
    smaller = min(rows, cols)

    if compute_q:
        q_t_shape = list(input.size())
        q_t_shape[-2:] = [smaller if reduced_mode else rows, rows]
        Q = input.new_empty(q_t_shape)
        Q.transpose_(-2, -1)
    else:
        Q = input.new_empty(0)

    r_t_shape = list(input.size())
    r_t_shape[-2:] = [cols, smaller if reduced_mode or not compute_q else rows]
    R = input.new_empty(r_t_shape)
    R.transpose_(-2, -1)

    return (Q, R)
def meta_diag(self, dim=0):
    """
    Builds an empty output for `diag` with diagonal offset `dim`.

    A 1-D input yields a square output with side `self.size(0) + abs(dim)`;
    a 2-D input yields a 1-D output holding the selected diagonal.
    """
    check(self.dim() in (1, 2), lambda: "matrix or a vector expected")

    if self.dim() == 1:
        side = self.size(0) + abs(dim)
        return self.new_empty((side, side))

    # case: self is a matrix
    n_rows, n_cols = self.size(0), self.size(1)
    diag_len = min(n_rows, n_cols - dim) if dim >= 0 else min(n_rows + dim, n_cols)
    return self.new_empty((diag_len,))
def softshrink(a: TensorLikeType, lambd: float = 0.5):
    """
    Applies the soft-shrinkage function elementwise.

    Formula for reference:
        softshrink(x) = x - lambd    if x > lambd
                      = x + lambd    if x < -lambd
                      = 0            otherwise

    `lambd` must be non-negative; a violation is reported through `check`.
    """
    check(
        lambd >= 0,
        lambda: f"lambda must be greater or equal to 0, but found to be {lambd}",
    )

    above = a > lambd
    below = a < -lambd

    # Shift the two tails towards zero, leaving the middle band untouched for now
    shrunk = refs.where(above, a - lambd, a)
    shrunk = refs.where(below, a + lambd, shrunk)

    # Everything in neither tail is zeroed
    in_band = torch.logical_not(refs.logical_or(above, below))
    return refs.where(in_band, 0, shrunk)
def meta_pad2d(self, padding):
    """
    Builds an empty output for a 2D padding op.

    Args:
        self: a 3D tensor, or a 4D (batch mode) tensor, in which every
            dimension after the first is non-empty.
        padding: four values, unpacked as (left, right, top, bottom).

    Returns:
        An empty tensor with the same leading dimensions as `self` and the
        last two dimensions grown by the padding amounts.

    An invalid input is reported through `check`.
    """
    # Validate the rank before touching individual sizes so that low-rank
    # inputs reach the descriptive error instead of failing inside size().
    valid_input = self.ndim in (3, 4) and all(
        self.size(d) != 0 for d in range(1, self.ndim)
    )
    # The message is wrapped in a lambda, as in the other `check` call sites
    # of this file, so it is only formatted when the check fails.
    check(
        valid_input,
        lambda: f"3D or 4D (batch mode) tensor expected for input, but got: {self}",
    )

    pad_l, pad_r, pad_t, pad_b = padding
    # leading_dims is (nplane,) for 3D input and (nbatch, nplane) for 4D input
    *leading_dims, input_h, input_w = self.shape

    output_h = input_h + pad_t + pad_b
    output_w = input_w + pad_l + pad_r

    return self.new_empty((*leading_dims, output_h, output_w))
def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    """
    Builds an empty output for `addbmm`.

    Args:
        self: tensor expandable to `(batch1.size(1), batch2.size(2))`.
        batch1: 3D tensor.
        batch2: 3D tensor with the same `size(0)` as `batch1` and with
            `size(1)` equal to `batch1.size(2)`.
        beta, alpha: accepted for signature compatibility; not used here.

    Returns:
        An empty tensor of shape `(batch1.size(1), batch2.size(2))`.

    Shape violations are reported through `check`.
    """
    # Rank checks come first: size(1)/size(2) below are only meaningful for
    # 3D inputs and would otherwise fail before these messages are reached.
    check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
    check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
    check(
        batch1.size(0) == batch2.size(0),
        lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
    )
    check(
        batch1.size(2) == batch2.size(1),
        lambda: (
            f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
            f"and {batch2.size(1)}x{batch2.size(2)})"
        ),
    )

    dim1 = batch1.size(1)
    dim2 = batch2.size(2)
    self = self.expand((dim1, dim2))
    check(
        self.size(0) == dim1 and self.size(1) == dim2,
        lambda: "self tensor does not match matmul output shape",
    )
    return self.new_empty(self.size())
def meta_linalg_cholesky_ex(input, upper=False, check_errors=False):
    """
    Builds empty `(L, info)` outputs for `linalg.cholesky_ex`.

    `input` must be a float or complex square matrix, or a batch of them.
    `L` has the shape of `input` and is transposed in place over its last two
    dimensions; `info` is a `torch.int` tensor with the batch shape of `input`.
    `upper` and `check_errors` are not used here.
    """
    check(input.ndim >= 2, lambda: f"expected matrix or batch of matrices, but got {input.ndim}-D tensor")
    check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: f"expected float or complex tensor, but got {input.dtype}"
    )
    check(input.size(-1) == input.size(-2), lambda: f"expected square matrix but got {input.shape}")

    full_shape = input.size()

    L = input.new_empty(full_shape)
    L.transpose_(-2, -1)

    # One info entry per matrix in the batch
    info = input.new_empty(full_shape[:-2], dtype=torch.int)
    return L, info
def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions

    Args:
        dtype: the requested dtype, or None (in which case nothing is checked).
        x_dtype: dtype of the input tensor.
        fn_name: name of the calling function, used as the message prefix.

    When `dtype` is given it must be floating point or complex, match the
    complexness of `x_dtype`, and be the higher of the two dtypes according
    to `get_higher_dtype`. Violations are reported through `check`.
    """
    if dtype is None:
        return

    check(
        is_float_dtype(dtype) or is_complex_dtype(dtype),
        lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
    )
    check(
        is_complex_dtype(dtype) == is_complex_dtype(x_dtype),
        lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
            fn_name=fn_name,
            d="complex" if is_complex_dtype(x_dtype) else "real",
            dtype=dtype,
        ),
    )
    check(
        get_higher_dtype(dtype, x_dtype) == dtype,
        # Both pieces need the `f` prefix; without it `{dtype}` is printed
        # verbatim instead of the offending dtype.
        lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
        f"without narrowing to the specified dtype ({dtype})",
    )
def meta_embedding_bag(
    weight,
    indices,
    offsets,
    scale_grad_by_freq=False,
    mode=0,
    sparse=False,
    per_sample_weights=None,
    include_last_offset=False,
    padding_idx=-1,
):
    """
    Builds empty outputs for `embedding_bag`.

    Args:
        weight: floating point tensor; `weight.size(1)` is the width of each output row.
        indices: long or int tensor.
        offsets: long or int tensor; its length gives the number of bags
            (minus one when `include_last_offset` is set).
        mode: 0 (sum), 1 (mean) or 2 (max).
        per_sample_weights: optional 1D tensor with the dtype of `weight` and
            as many elements as `indices`; only allowed with mode 0.
        include_last_offset: when set, the last offset does not start a bag.
        padding_idx: a non-negative value disables the fast path.
        scale_grad_by_freq, sparse: not used here.

    Returns:
        Tuple `(output, offset2bag, bag_size, max_indices)` of empty tensors.

    Argument violations are reported through `check`.
    """
    check(
        indices.dtype in (torch.long, torch.int),
        lambda: f"expected indices to be long or int, got {indices.dtype}",
    )
    check(
        offsets.dtype in (torch.long, torch.int),
        lambda: f"expected offsets to be long or int, got {offsets.dtype}",
    )
    check(
        utils.is_float_dtype(weight.dtype),
        lambda: f"expected weight to be floating point type, got {weight.dtype}",
    )

    num_bags = offsets.size(0)
    if include_last_offset:
        check(
            num_bags >= 1, lambda: "include_last_offset: numBags should be at least 1"
        )
        num_bags -= 1

    output = weight.new_empty(num_bags, weight.size(1))
    MODE_SUM, MODE_MEAN, MODE_MAX = range(3)

    if per_sample_weights is not None:
        check(
            mode == MODE_SUM,
            lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
        )
        check(
            per_sample_weights.dtype == weight.dtype,
            lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype",
        )
        check(
            per_sample_weights.ndim == 1,
            lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
        )
        check(
            per_sample_weights.numel() == indices.numel(),
            # The first parenthesis is now closed right after the count
            lambda: (
                f"expected per_sample_weights.numel() ({per_sample_weights.numel()}) "
                f"to be the same as indices.numel() ({indices.numel()})"
            ),
        )

    def is_fast_path_index_select(src, output, padding_idx):
        # float/half data, unit stride along dim 1 on both sides, no padding index
        return (
            (src.dtype == torch.float or src.dtype == torch.half)
            and src.stride(1) == 1
            and output.stride(1) == 1
            and padding_idx < 0
        )

    def is_fast_path_index_select_scale(src, scale, output, padding_idx):
        # same as above, plus a unit-stride scale tensor
        return (
            is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
        )

    def is_fast_path(src, scale, output, padding_idx):
        if scale is not None:
            return is_fast_path_index_select_scale(src, scale, output, padding_idx)
        else:
            return is_fast_path_index_select(src, output, padding_idx)

    if offsets.device.type != "cpu":
        offset2bag = indices.new_empty(indices.size(0))
        bag_size = indices.new_empty(offsets.size())
        if mode == MODE_MAX:
            max_indices = indices.new_empty(num_bags, weight.size(1))
        else:
            max_indices = indices.new_empty(0)
    else:
        fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
        # On the sum fast path offset2bag is allocated empty
        if mode == MODE_MEAN or mode == MODE_MAX or not fast_path_sum:
            offset2bag = offsets.new_empty(indices.size(0))
        else:
            offset2bag = offsets.new_empty(0)
        bag_size = offsets.new_empty(num_bags)
        max_indices = offsets.new_empty(bag_size.size())

    return output, offset2bag, bag_size, max_indices
def meta_cdist_forward(x1, x2, p, compute_mode):
    """
    Builds an empty output for `cdist`.

    Args:
        x1, x2: floating point tensors with at least two dimensions and the
            same size in the last dimension.
        p: non-negative norm order.
        compute_mode: 0, 1 or 2.

    Returns:
        An empty tensor whose shape is the broadcast of the two batch shapes
        followed by `[x1.size(-2), x2.size(-2)]`.

    Argument violations are reported through `check`.
    """
    check(
        x1.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
    )
    check(
        x2.dim() >= 2,
        lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
    )
    check(
        x1.size(-1) == x2.size(-1),
        lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
    )
    # The two dtype messages were plain strings, so the placeholder text was
    # printed verbatim; they are f-strings now.
    check(
        utils.is_float_dtype(x1.dtype),
        lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
    )
    check(
        utils.is_float_dtype(x2.dtype),
        lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
    )
    check(p >= 0, lambda: "cdist only supports non-negative p values")
    check(
        compute_mode >= 0 and compute_mode <= 2,
        lambda: f"possible modes: 0, 1, 2, but was: {compute_mode}",
    )

    r1 = x1.size(-2)
    r2 = x2.size(-2)
    batch_tensor1 = x1.shape[:-2]
    batch_tensor2 = x2.shape[:-2]
    output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    output_shape.extend([r1, r2])
    return x1.new_empty(output_shape)
def meta_index_Tensor(self, indices):
    """
    Builds an empty output for advanced indexing (`aten::index`).

    Args:
        self: the tensor being indexed.
        indices: non-empty sequence whose entries are None or index tensors of
            dtype long, byte (uint8; int8 is also accepted) or bool. Byte/bool
            entries are treated as masks and expanded through `nonzero()`.

    Returns:
        An empty tensor of shape
        `before_shape + broadcast index shape + after_shape`.

    Violations are reported through `check`, with IndexError for mask shape
    and mask rank problems.
    """
    check(indices, lambda: "at least one index must be provided")

    # The message advertises "byte" masks, which are uint8 tensors; int8 stays
    # accepted so that callers relying on the previous behaviour keep working.
    mask_dtypes = [torch.int8, torch.uint8, torch.bool]

    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: List[Optional[Tensor]] = []
    for i, index in enumerate(indices):
        if index is not None:
            check(
                index.dtype in [torch.long] + mask_dtypes,
                lambda: "tensors used as indices must be long, byte or bool tensors",
            )
            if index.dtype in mask_dtypes:
                nonzero = index.nonzero()
                # position in `self` of the first dimension covered by this mask
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    # one index tensor per mask dimension
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )

    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    #
    # state 0: no tensor seen yet; state 1: inside the run of tensors;
    # state 2: run ended (a further tensor means the run is not contiguous)
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them.  If we write a ref for this, probably that logic should
    # get implemented
    before_shape: List[int] = []
    after_shape: List[int] = []
    replacement_shape: List[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            replacement_shape = list(index.shape)
    return self.new_empty(before_shape + replacement_shape + after_shape)
def meta_adaptive_avg_pool3d(self, output_size):
    """
    Builds an empty output for adaptive 3D average pooling.

    `self` must be 4D or 5D. The output keeps every dimension of `self` except
    the last three, which are replaced by `output_size`.
    """
    rank = self.ndim
    check(
        rank == 4 or rank == 5,
        lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-3]
    return self.new_empty(leading_dims + tuple(output_size))
def meta_adaptive_avg_pool2d(self, output_size):
    """
    Builds an empty output for adaptive 2D average pooling.

    `self` must be 3D or 4D. The output keeps every dimension of `self` except
    the last two, which are replaced by `output_size`.
    """
    rank = self.ndim
    check(
        rank == 3 or rank == 4,
        lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
    )
    leading_dims = self.shape[:-2]
    return self.new_empty(leading_dims + tuple(output_size))
def meta_dot(self, tensor):
    """
    Builds an empty 0-D output for `dot`.

    Both `self` and `tensor` must be 1D; a violation is reported through `check`.
    """
    both_are_vectors = self.dim() == 1 and tensor.dim() == 1
    check(
        both_are_vectors,
        lambda: f"1D tensors expected, but got {self.dim()}D and {tensor.dim()}D tensors",
    )
    scalar_shape = ()
    return self.new_empty(scalar_shape)