def _fn(*args, **kwargs):
    """Call ``fn`` with its type-promoting arguments cast to a common dtype.

    The arguments named in ``self.type_promoting_arg_names`` that were actually
    supplied in this call are flattened and passed to
    ``utils.elementwise_dtypes`` to obtain ``(compute_dtype, result_dtype)``.
    Those arguments are converted to ``compute_dtype`` before calling ``fn``.
    The tensor result (or each tensor of a tuple/list result) is then
    converted to ``result_dtype``.

    Raises:
        AssertionError: if ``fn`` returns something other than a TensorLike
            or a tuple/list of TensorLikes.
    """
    # NOTE(review): presumably ``sig`` is ``inspect.signature(fn)``; confirm.
    bound = sig.bind(*args, **kwargs)

    # ``bind`` does not apply defaults, so only explicitly supplied
    # type-promoting parameters take part in promotion.
    present_names = tuple(
        name
        for name in self.type_promoting_arg_names  # type: ignore[union-attr]
        if name in bound.arguments
    )

    type_promoting_args = tuple(bound.arguments[name] for name in present_names)
    flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
    compute_dtype, result_dtype = utils.elementwise_dtypes(
        *flattened_type_promoting_args,
        type_promotion_kind=self.type_promotion_kind,
    )

    bound.arguments.update(
        {
            name: _maybe_convert_to_dtype(bound.arguments[name], compute_dtype)
            for name in present_names
        }
    )

    result = fn(**bound.arguments)

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if isinstance(result, TensorLike):
        return _maybe_convert_to_dtype(result, result_dtype)
    if isinstance(result, (tuple, list)) and all(
        isinstance(item, TensorLike) for item in result
    ):
        # Multi-output ops: convert every output tensor to the result dtype.
        return tuple(_maybe_convert_to_dtype(item, result_dtype) for item in result)
    raise AssertionError(f"Unhandled result type: {type(result)}")
def inner(*args, **kwargs):
    """Run ``f`` in a promoted computation dtype, then cast outputs back.

    All Tensor leaves of ``(args, kwargs)`` (found via ``tree_flatten``) are
    passed to ``utils.elementwise_dtypes`` with the enclosing
    ``type_promotion`` kind.  The result is a ``(computation_dtype,
    result_dtype)`` pair:

    * every Tensor leaf in ``args``/``kwargs`` is cast to ``computation_dtype``
      before ``f`` is called;
    * every Tensor leaf of ``f``'s output is cast to ``result_dtype``.

    Non-tensor leaves pass through untouched in both directions.
    """
    # TODO: pretty sure this is not quite right
    # NOTE(review): every tensor output, including any integer-valued ones,
    # is cast to ``result_dtype`` — confirm this is intended.
    tensor_leaves = [
        leaf for leaf in tree_flatten((args, kwargs))[0] if isinstance(leaf, Tensor)
    ]
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        *tensor_leaves, type_promotion_kind=type_promotion
    )

    def cast_to(dtype):
        # Build a leaf-wise converter that only touches tensors.
        def convert(leaf):
            return leaf.to(dtype) if isinstance(leaf, Tensor) else leaf

        return convert

    upcast = cast_to(computation_dtype)
    promoted_args = tree_map(upcast, args)
    promoted_kwargs = tree_map(upcast, kwargs)
    output = f(*promoted_args, **promoted_kwargs)
    return tree_map(cast_to(result_dtype), output)
def meta_angle(self):
    """Return an empty tensor with the size and dtype that ``angle(self)`` produces.

    Presumably registered as the meta kernel for ``angle``; confirm at the
    registration site.

    * Complex input: the result dtype is the matching real dtype, because the
      angle of a complex number is real.
    * Any other input: the dtype follows INT_TO_FLOAT elementwise promotion,
      so integer and bool inputs become the default float dtype.
    """
    if self.is_complex():
        import torch

        # INT_TO_FLOAT promotion would keep a complex dtype here, which is
        # wrong for angle(): map complex -> corresponding real dtype.
        result_dtype = {
            torch.complex32: torch.float16,
            torch.complex64: torch.float32,
            torch.complex128: torch.float64,
        }[self.dtype]
    else:
        _, result_dtype = elementwise_dtypes(
            self, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )
    return self.new_empty(self.size(), dtype=result_dtype)