def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """Rewrite a dispatcher signature's tensor-like arguments as meta tensors.

    For every tensor-like ``Argument`` in ``sig``, emits a C++ statement
    ``auto <name>_meta = to_meta(<name>);`` and swaps the binding's name to
    ``<name>_meta``; all other bindings pass through untouched.

    Returns:
        A pair of (the unwrap statements joined into one string, the full
        binding list with tensor-like entries renamed).
    """
    bindings: List[Binding] = []
    unwrap_stmts: List[str] = []
    for binding in sig.arguments():
        inner = binding.argument
        # Only genuine Argument entries with a tensor-like type get unwrapped.
        if not (isinstance(inner, Argument) and inner.type.is_tensor_like()):
            bindings.append(binding)
            continue
        meta_name = f"{binding.name}_meta"
        unwrap_stmts.append(f"auto {meta_name} = to_meta({binding.name});")
        bindings.append(binding.with_name(meta_name))
    return "\n ".join(unwrap_stmts), bindings
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """Rewrite a dispatcher signature's tensor-like arguments as meta tensors.

    Tensor-like inputs get a generated C++ unwrap statement
    ``auto <name>_meta = to_meta(<name>);`` and their binding is renamed to
    ``<name>_meta``; everything else is forwarded unchanged.

    Returns:
        A pair of (the unwrap statements joined into one string, the full
        binding list with tensor-like entries renamed).
    """
    bindings: List[Binding] = []
    unwrap_stmts: List[str] = []
    for binding in sig.arguments():
        if not is_tensor_like(binding.argument):
            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
            bindings.append(binding)
            continue
        # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
        meta_name = f"{binding.name}_meta"
        unwrap_stmts.append(f"auto {meta_name} = to_meta({binding.name});")
        bindings.append(binding.with_name(meta_name))
    return "\n ".join(unwrap_stmts), bindings
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
    """Rewrite a dispatcher signature's tensor-like arguments as meta tensors.

    For each tensor-like input this emits a C++ statement that materializes a
    meta tensor with the same sizes/strides/dtype/layout via
    ``at::native::empty_strided_meta`` and renames the binding to
    ``<name>_meta``; non-tensor inputs are forwarded unchanged.

    Args:
        sig: The dispatcher signature whose arguments are being converted.

    Returns:
        A pair of (the generated unwrap statements joined into one string,
        the full binding list with tensor-like entries renamed).
    """
    context: List[Binding] = []
    unwrapped_tensor_args: List[str] = []
    for arg in sig.arguments():
        if is_tensor_like(arg.argument):
            # for tensor inputs, we want to unwrap them before passing them
            # into the redispatch calls.
            arg_name = arg.name
            unwrapped_name = f"{arg.name}_meta"
            # NOTE: backslash continuations keep the generated C++ on one
            # logical line; the continuation lines start at column 0 so the
            # emitted code is separated by single spaces only.
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = at::native::empty_strided_meta({arg_name}.sizes(), {arg_name}.strides(), \
/*dtype=*/c10::make_optional({arg_name}.scalar_type()), /*layout=*/c10::make_optional({arg_name}.layout()), \
/*device=*/c10::make_optional(c10::Device(kMeta)), /*pin_memory=*/c10::nullopt);"
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            # for non-tensor inputs, we want to pass them directly into the
            # redispatch calls.
            context.append(arg)
    unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
def gen_composite_view_copy_kernel( g: NativeFunctionsViewGroup) -> Optional[str]: if g.view_copy is None: return None # view_copy is a native signature, since we're generating an at::native:: kernel view_copy_sig = NativeSignature(g.view_copy.func) # view is a dispatcher signature, since we're calling into the at::_ops API view_sig = DispatcherSignature(g.view.func) view_api_name = g.view.func.name.unambiguous_name() exprs = ", ".join([ e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments()) ]) # view ops today always return either a Tensor or a list of Tensors assert len(g.view.func.returns) == 1 assert g.view.func.returns[0].type == BaseType( BaseTy.Tensor) or g.view.func.returns[0].type == ListType( BaseType(BaseTy.Tensor), None) if g.view.func.returns[0].type == BaseType(BaseTy.Tensor): return_cloned_output = """\ return output.clone();""" else: # If the return type is a list, we need to clone each tensor in the list. return_cloned_output = f"""\ {view_copy_sig.returns_type().cpp_type()} out_clone; for (const auto i : c10::irange(output.size())) {{ out_clone.push_back(output[i].clone()); }} return out_clone;""" # The default generated composite kernel for {view}_copy() operators just clones # the input tensor, and runs the underlying view on the clone. return f"""