def process_function(f: NativeFunction) -> Optional[str]: name = cpp.name(f.func) has_tensor_options = python.has_tensor_options(f) is_factory = has_tensor_options or name.endswith("_like") if Variant.function not in f.variants or not is_factory: return None sig = CppSignatureGroup.from_native_function(f, method=False).signature formals: List[str] = [] exprs: List[str] = [] requires_grad = "false" for arg in sig.arguments(): qualified_type = fully_qualified_type(arg.type) if arg.default: formals.append(f"{qualified_type} {arg.name} = {arg.default}") else: formals.append(f"{qualified_type} {arg.name}") if isinstance(arg.argument, TensorOptionsArguments): # note: we remove the requires_grad setting from the TensorOptions because # it is ignored anyways (and we actually have an assertion that it isn't set # which would fail otherwise). We handle requires_grad explicitly here # instead of passing it through to the kernel. exprs.append( f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)") # Manually set the requires_grad bit on the result tensor. requires_grad = f"{arg.name}.requires_grad()" else: exprs.append(arg.name) return f"""\
def is_factory_function(f: NativeFunction) -> bool:
    """Decide whether ``f`` is a factory function that gets a generated wrapper.

    A native function qualifies only when ``Variant.function`` is among its
    variants, and then only if ``python.has_tensor_options(f)`` is truthy or
    its C++ name (``cpp.name(f.func)``) ends in ``_like``.

    Args:
        f: the native function to classify.

    Returns:
        True for factory functions, False otherwise.
    """
    if Variant.function in f.variants:
        # Both lookups are performed (in this order) before deciding, exactly
        # as before; only the final combination short-circuits.
        op_name = cpp.name(f.func)
        takes_options = python.has_tensor_options(f)
        return takes_options or op_name.endswith("_like")
    return False
def go(f: NativeFunction) -> str: # header comments if isinstance(ps, PythonSignatureDeprecated): schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}" else: schema_comment = f"// aten::{f.func}" deprecated = "[deprecated] " if ps.deprecated else "" # dispatch lambda signature name = cpp.name(f.func) lambda_formals = ", ".join( map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))) lambda_return = dispatch_lambda_return_str(f) # dispatch lambda body dispatch_callee = cpp_dispatch_target(f) dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps)) # from arg parser outputs to dispatch lambda arguments parser_outputs = arg_parser_output_exprs(ps, f) lambda_arg_exprs = dispatch_lambda_exprs(ps, f) inits = "\n".join(lambda_arg_exprs.inits) lambda_args = ", ".join(lambda_arg_exprs.exprs) # scatter fields # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky # solution for enabling the 'requires_grad' argument for tensor methods # new_full, new_empty, and new_zeros. A much better but more difficult to # implement solution involves refactoring according to Ed's description here: # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 need_set_requires_grad = ps.tensor_options_args and ( not has_tensor_options(f) or (ps.method and ("requires_grad" in parser_outputs))) set_requires_grad = ( f'.set_requires_grad({parser_outputs["requires_grad"].expr})' if need_set_requires_grad else "") if lambda_return == "void": return f"""\ {schema_comment} {inits} auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ pybind11::gil_scoped_release no_gil; {dispatch_callee}({dispatch_args}); }}; dispatch_{name}({lambda_args}){set_requires_grad}; Py_RETURN_NONE; """ else: typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f)) namedtuple_typeref = f"{typename}, " if typename is not None else "" return f"""\
def process_function(f: NativeFunction) -> Optional[str]:
    """Generate C++ factory wrappers that produce autograd variables.

    For every C++ signature of ``f`` (the regular signature plus, when it
    exists, the SymInt signature) this emits an ``inline at::Tensor`` function
    of the same name. The wrapper calls ``at::<name>`` while an
    ``at::AutoDispatchBelowADInplaceOrView`` guard is active, then wraps the
    result with ``autograd::make_variable``. The ``requires_grad`` flag comes
    from the TensorOptions argument when there is one, and is ``false``
    otherwise.

    Args:
        f: the native function to process.

    Returns:
        The generated C++ source, or None when ``f`` is not a factory function
        according to ``is_factory_function``.
    """
    # Reuse the shared predicate so the definition of "factory" lives in one place.
    if not is_factory_function(f):
        return None

    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [cpp_sigs.signature]
    if cpp_sigs.symint_signature is not None:
        sigs.append(cpp_sigs.symint_signature)

    # One wrapper per signature; concatenated once at the end.
    wrappers: List[str] = []
    for sig in sigs:
        formals: List[str] = []
        exprs: List[str] = []
        # Used when the signature has no TensorOptions argument.
        requires_grad = "false"
        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            if arg.default:
                formals.append(f"{qualified_type} {arg.name} = {arg.default}")
            else:
                formals.append(f"{qualified_type} {arg.name}")

            if isinstance(arg.argument, TensorOptionsArguments):
                # note: we remove the requires_grad setting from the TensorOptions because
                # it is ignored anyways (and we actually have an assertion that it isn't set
                # which would fail otherwise). We handle requires_grad explicitly here
                # instead of passing it through to the kernel.
                exprs.append(
                    f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)"
                )
                # Manually set the requires_grad bit on the result tensor.
                requires_grad = f"{arg.name}.requires_grad()"
            else:
                exprs.append(arg.name)

        wrappers.append(
            f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
        )
    return "".join(wrappers)