def go(f: NativeFunction) -> str: # header comments deprecated = '[deprecated] ' if ps.deprecated else '' schema_comment = f'// {deprecated}aten::{f.func}' # dispatch lambda signature name = cpp.name(f.func) lambda_formals = ', '.join( map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))) lambda_return = dispatch_lambda_return_str(f) # dispatch lambda body dispatch_callee = cpp_dispatch_target(f) dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps)) # from arg parser outputs to dispatch lambda arguments parser_outputs = arg_parser_output_exprs(ps, f) lambda_arg_exprs = dispatch_lambda_exprs(ps, f) inits = '\n'.join(lambda_arg_exprs.inits) lambda_args = ', '.join(lambda_arg_exprs.exprs) # scatter fields # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky # solution for enabling the 'requires_grad' argument for tensor methods # new_full, new_empty, and new_zeros. A much better but more difficult to # implement solution involves refactoring according to Ed's description here: # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589 need_set_requires_grad = ps.tensor_options_args and ( not has_tensor_options(f) or (ps.method and ('requires_grad' in parser_outputs))) set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \ if need_set_requires_grad else '' if lambda_return == 'void': return f"""\ {schema_comment} {inits} auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ pybind11::gil_scoped_release no_gil; {dispatch_callee}({dispatch_args}); }}; dispatch_{name}({lambda_args}){set_requires_grad}; Py_RETURN_NONE; """ else: typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f)) namedtuple_typeref = f'&{typename}, ' if typename is not None else '' return f"""\
def go(f: NativeFunction) -> str: # header comments deprecated = '[deprecated] ' if ps.deprecated else '' schema_comment = f'// {deprecated}aten::{f.func}' # dispatch lambda signature name = cpp.name(f.func) lambda_formals = ', '.join( map(lambda a: f"{a.type_str} {a.name}", dispatch_lambda_args(ps, f))) lambda_return = dispatch_lambda_return_str(f) # dispatch lambda body dispatch_callee = cpp_dispatch_target(f) dispatch_args = ', '.join(cpp_dispatch_exprs(f, python_signature=ps)) # from arg parser outputs to dispatch lambda arguments parser_outputs = arg_parser_output_exprs(ps, f) lambda_arg_exprs = dispatch_lambda_exprs(ps, f) inits = '\n'.join(lambda_arg_exprs.inits) lambda_args = ', '.join(lambda_arg_exprs.exprs) need_set_requires_grad = ps.tensor_options_args and ( not has_tensor_options(f) or (ps.method and ('requires_grad' in parser_outputs))) set_requires_grad = f'.set_requires_grad({parser_outputs["requires_grad"].expr})' \ if need_set_requires_grad else '' if lambda_return == 'void': return f"""\ {schema_comment} {inits} auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ // pybind11::gil_scoped_release no_gil; {dispatch_callee}({dispatch_args}); }}; dispatch_{name}({lambda_args}){set_requires_grad}; """ else: return f"""\
def dispatch_lambda_exprs(ps: PythonSignature, f: NativeFunction) -> DispatchLambdaArgumentExprs:
    """Bind PythonArgParser outputs to the arguments of the dispatch lambda.

    For every argument of the dispatch lambda this produces the C++ expression
    to pass at the call site, plus any C++ statements ('inits') that must be
    emitted before the call to set those expressions up.

    Args:
        ps: the python signature whose parsed arguments are being bound.
        f: the native function the dispatch lambda forwards to.

    Returns:
        A DispatchLambdaArgumentExprs whose 'exprs' holds one expression per
        entry of dispatch_lambda_args(ps, f), in the same order, and whose
        'inits' holds the C++ setup statements.

    Raises:
        RuntimeError: if 'f' takes tensor options and is also an out function,
            or if the tensor options arguments of 'ps' contain an unknown
            field, a field with an unexpected type, or do not cover every
            entry of TENSOR_OPTIONS_FIELDS.
        KeyError: if an argument has no arg parser output, or a lambda
            argument ends up without a binding expression.
    """
    arg_parser_outputs = arg_parser_output_exprs(ps, f)
    lambda_args = dispatch_lambda_args(ps, f)
    inits: List[str] = []
    lambda_args_exprs: Dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[name].expr

        if has_toptions and name == 'self':
            # TODO: why this needs to be special case?
            inits.append(f'auto self = {arg_parser_expr};')
            lambda_args_exprs[name] = name
        elif isinstance(a, PythonOutArgument) and len(a.outputs) > 1 and f.func.is_out_fn():
            # A packed multi-output argument: parse it once into 'out' and
            # bind each individual output to its index in that value.
            inits.append(f'auto out = {arg_parser_expr};')
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f'out[{i}]'
        elif str(a.type) == 'Dimname[]?':
            # Unwrap the parsed optional into a temporary and rewrap it as
            # c10::optional<DimnameList>. The statement is assembled from
            # adjacent literals so that no source indentation leaks into the
            # generated C++.
            inits.extend([
                f'auto __{name} = {arg_parser_expr};',
                f'c10::optional<DimnameList> {name} = '
                f'__{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;',
            ])
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to python binding, rather than parsed
    if ps.method:
        lambda_args_exprs['self'] = 'self'

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f'{f.func}: tensor options with output arg')
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f'{f.func}: unrecognized tensor options field \'{a.name}\' in python binding arguments'
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f'{f.func}: unrecognized type \'{str(a.type)}\' for tensor options field \'{a.name}\''
                )
        if not all(field in tensor_options_args_names for field in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f'{f.func}: incomplete tensor options args: {tensor_options_args_names}'
            )

        # NOTE: the .device()/.layout() setters and the CUDA initialization
        # call are emitted as C++ line comments, so they do not affect the
        # constructed options even though both fields are validated above.
        # Each '//' line must stay on its own line of the template, otherwise
        # it would comment out the remainder of the statement.
        inits.append(f'''\
const auto options = TensorOptions()
    .dtype({arg_parser_outputs['dtype'].expr})
    //.device({arg_parser_outputs['device'].expr})
    //.layout({arg_parser_outputs['layout'].expr})
    .requires_grad({arg_parser_outputs['requires_grad'].expr})
    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
// torch::utils::maybe_initialize_cuda(options);
''')
        lambda_args_exprs['options'] = 'options'

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )