def _reduction(
    a: Tensor,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    # No default on purpose: it is usually SAME, but ref writers should
    # actually think about what to put here.
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
):
    """Shared implementation of reference reductions built on a reduction prim.

    Validates the ``dims`` / ``dtype`` / ``out`` arguments, converts ``a`` to
    the computation dtype derived from ``dtype`` (or ``a.dtype``), applies
    ``prim`` over the normalized reduction dimensions, optionally re-inserts
    the reduced dimensions with size 1, and finally either writes into ``out``
    or converts the result to the output dtype.

    Args:
        a: input tensor.
        prim: reduction primitive, called as ``prim(converted_a, dims)``.
        has_identity: when False, reducing over a zero-size dimension raises.
        accepts_dim_tuple: when False, ``dims`` must be ``None`` or a single
            int (for ops such as min/argmin that accept only one dim).
        dims: dimension(s) to reduce; normalized with ``utils.reduction_dims``.
        keepdims: keep reduced dimensions as size-1 dimensions.
        dtype: requested dtype; if ``out`` is also given, the two must match.
        out: optional output tensor; it is resized to the result shape and
            filled via ``copy_to`` with cross-device copies disallowed.
        output_dtype_kind: how the output dtype relates to the input dtype.
            Without ``out``, a ``SAME`` result is converted back to ``dtype``
            (or ``a.dtype``). For other kinds, the result is returned as
            ``prim`` produced it from the computation-dtype input.

    Returns:
        The reduced tensor, or the value returned by ``copy_to`` when ``out``
        is provided.

    Raises:
        RuntimeError: if ``dtype`` and ``out.dtype`` disagree, if ``out`` has
            the wrong dtype for ``output_dtype_kind``, or if a reduction
            without identity is applied over a zero-size dimension.
    """
    assert isinstance(a, TensorLike)
    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, int)
    if isinstance(dims, int):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        # Ops without an identity element (e.g. amax) have no value to
        # produce for an empty slice, so reject zero-size reduced dims.
        valid_shape = all(a.shape[i] for i in range(a.ndim) if i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    # Even though some reductions, like amin or amax, don't strictly require
    # type promotion, all the math ops (including comparisons) are still
    # defined only for a computation type, so promotion will still happen.
    # We are doing it explicitly here.
    inp_dtype = dtype if dtype is not None else a.dtype
    computation_dtype = utils._get_computation_dtype(inp_dtype)
    a_converted = prims.convert_element_type(a, computation_dtype)
    result = prim(a_converted, dims)
    if keepdims:
        # Re-insert every reduced dim as size 1 and broadcast the result
        # along the kept dims.
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)
    if out is not None:
        if dtype is None:
            if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
                if out.dtype != a.dtype:
                    raise RuntimeError("Expected the dtype for input and out to match")
            elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL:
                if out.dtype != torch.bool:
                    # The input dtype is irrelevant here: out must be bool.
                    raise RuntimeError(
                        f"Expected out tensor to have dtype torch.bool, but got {out.dtype} instead"
                    )
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)  # type: ignore[arg-type]
    if output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME:
        # inp_dtype is exactly `dtype` when given, else `a.dtype`.
        result = prims.convert_element_type(result, inp_dtype)
    return result
def meta_var_mean_correction(self, dim, *, correction, keepdim=False):
    """Meta kernel for ``var_mean`` with a ``correction`` argument.

    Only allocates empty outputs of the reduced shape; nothing is computed,
    and ``correction`` does not influence the output metadata.

    Returns:
        A pair of empty tensors. The first uses the real-valued counterpart
        of ``self.dtype`` (via ``toRealValueType``); the second uses
        ``self``'s own dtype (the ``new_empty`` default).

    NOTE(review): this source defines ``meta_var_mean_correction`` twice. If
    both definitions share a module namespace, the later one rebinds the
    name; confirm whether this copy is still needed.
    """
    reduced = utils.reduction_dims(self.shape, dim)
    if not keepdim:
        shape = utils.compute_reduction_output_shape(self.shape, reduced)
    else:
        # Keep rank: every reduced axis collapses to size 1.
        shape = tuple(
            1 if axis in reduced else size
            for axis, size in enumerate(self.shape)
        )
    var_out = self.new_empty(shape, dtype=toRealValueType(self.dtype))
    mean_out = self.new_empty(shape)
    return var_out, mean_out
def meta_nanmedian_dim(input, dim=-1, keepdim=False):
    """Meta kernel for ``nanmedian`` over a single dimension.

    Returns a pair of empty tensors with the reduced shape. The first has
    ``input``'s dtype; the second has dtype ``torch.long``.
    """
    reduced_dims = utils.reduction_dims(input.shape, (dim,))
    shape = _compute_reduction_shape(input, reduced_dims, keepdim)
    values = input.new_empty(shape)
    indices = input.new_empty(shape, dtype=torch.long)
    return values, indices
def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
    """Meta kernel for ``nansum``: an empty tensor of the reduced shape.

    The output dtype comes from ``_get_reduction_dtype`` with
    ``promote_int_to_long=True``. Presumably this widens integer inputs to
    int64, as eager sum-like reductions do; TODO confirm against the helper.
    """
    result_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
    reduced = utils.reduction_dims(input.shape, dims)
    return input.new_empty(
        _compute_reduction_shape(input, reduced, keepdim),
        dtype=result_dtype,
    )
def meta_var_mean_correction(self, dim, *, correction, keepdim=False):
    """Meta kernel for ``var_mean`` with a ``correction`` argument.

    Allocates two empty tensors of the reduced shape; ``correction`` is not
    used because it does not affect output metadata.

    Returns:
        A pair whose first element has the real-valued counterpart of
        ``self.dtype`` (via ``toRealValueType``) and whose second element
        keeps ``self``'s dtype.
    """
    reduced = utils.reduction_dims(self.shape, dim)
    shape = _compute_reduction_shape(self, reduced, keepdim)
    real_dtype = toRealValueType(self.dtype)
    return (
        self.new_empty(shape, dtype=real_dtype),
        self.new_empty(shape),
    )