def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Emit one C++ lookup statement per distinct named-tuple return type.

    For every overload whose return has named fields, a line of the form
    ``static PyTypeObject* NamedTupleK = get_namedtuple("<cpp name>");`` is
    generated, but only once per key from ``gen_namedtuple_typename_key``.
    The first variable is called ``NamedTuple``; later ones are
    ``NamedTuple1``, ``NamedTuple2``, ... Overloads without named return
    fields are skipped.

    Returns:
        A pair ``(decls, key_to_var)``: the generated C++ lines in emission
        order, and a map from typename key to the C++ variable name chosen
        for it.
    """
    key_to_var: Dict[str, str] = {}  # typename key -> C++ variable name
    decls: List[str] = []  # generated C++ lookup statements

    for pair in overloads:
        native = pair.function
        if not namedtuple_fieldnames(native.func.returns):
            continue

        cpp_name = cpp.name(native.func)  # use @with_native_function?
        key = gen_namedtuple_typename_key(native)
        if key in key_to_var:
            # This return type already has a variable; reuse it.
            continue

        suffix = str(len(decls)) if decls else ""
        var = "NamedTuple" + suffix
        key_to_var[key] = var
        decls.append(f'static PyTypeObject* {var} = get_namedtuple("{cpp_name}");')

    return decls, key_to_var
def emit_namedtuple_typedefs(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate C++ declarations and init code for named-tuple return types.

    Two kinds of snippets are produced:

    * one ``static PyStructSequence_Field NamedTuple_fieldsK[]`` array per
      distinct sequence of field names, and
    * one ``static PyTypeObject NamedTupleK`` plus a flag-guarded
      ``PyStructSequence_InitType`` call per distinct key from
      ``gen_namedtuple_typename_key``.

    Overloads whose return has no named fields are skipped.

    Returns:
        ``(snippets, typenames)``: ``snippets`` is all field-array
        declarations followed by all type-object snippets; ``typenames`` maps
        each typename key to the C++ variable name generated for it.
    """
    # Keyed by the tuple of field names itself. A '_'-joined string would be
    # ambiguous (["a_b"] and ["a", "b"] both give "a_b") and could reuse a
    # fields array with the wrong names and length.
    flddefnames: Dict[Tuple[str, ...], str] = {}  # field names -> fields array name
    flddefs: List[str] = []  # PyStructSequence_Field array declarations
    typenames: Dict[str, str] = {}  # typename key -> PyTypeObject variable name
    typedefs: List[str] = []  # PyTypeObject declarations and init code

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        fn_key = tuple(fieldnames)
        fieldsname = flddefnames.get(fn_key)
        if fieldsname is None:
            fieldsname = f'NamedTuple_fields{"" if not flddefs else len(flddefs)}'
            flddefnames[fn_key] = fieldsname
            fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)
            # The trailing {nullptr} entry terminates the field array.
            flddefs.append(f"""\
static PyStructSequence_Field {fieldsname}[] = {{ {fields}, {{nullptr}} }};
""")

        name = cpp.name(overload.function.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(overload.function)
        typename = typenames.get(tn_key)
        if typename is None:
            typename = f'NamedTuple{"" if not typedefs else len(typedefs)}'
            typenames[tn_key] = typename
            # The emitted C++ runs the type initialization only once, guarded
            # by a static `<typename>_initialized` flag.
            typedefs.append(f"""\
static PyTypeObject {typename};
static bool {typename}_initialized = false;
if (!{typename}_initialized) {{
  {typename}_initialized = true;
  static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, {fieldsname}, {len(fieldnames)} }};
  PyStructSequence_InitType(&{typename}, &desc);
  {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}}
""")

    return flddefs + typedefs, typenames
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate the named-tuple getter functions and map entries for
    `python_return_types.cpp`.

    For each distinct C++ function name whose return has named fields, one
    ``get_<name>_namedtuple()`` definition is produced. On its first call it
    initializes a ``torch.return_types.<name>`` struct-sequence type (guarded
    by a static ``is_initialized`` flag) and returns a pointer to it. A
    matching ``{"<name>", get_<name>_namedtuple()}, `` map entry is produced
    alongside. Overloads without named return fields are skipped.

    Returns:
        ``(definitions, map_entries)``, both in emission order.

    Raises:
        RuntimeError: if two overloads share a C++ name but return named
            tuples with different field names. Only one getter and one map
            entry can exist per name; emitting both would produce two
            definitions of the same C++ function.
    """
    # C++ name -> field names of the getter already emitted for that name.
    # Deduplicating by name (the getter and map key depend only on the name)
    # guarantees exactly one get_<name>_namedtuple() per name. The '_'-joined
    # typename key cannot guarantee this, because distinct (name, fields)
    # pairs can join to the same string.
    emitted_fields: Dict[str, List[str]] = {}
    definitions: List[str] = []  # getter function definitions
    map_entries: List[str] = []  # C++ map entries {"<name>", get_<name>_namedtuple()}

    for overload in overloads:
        fieldnames = namedtuple_fieldnames(overload.function.func.returns)
        if not fieldnames:
            continue

        name = cpp.name(overload.function.func)  # use @with_native_function?
        previous = emitted_fields.get(name)
        if previous is not None:
            if list(previous) != list(fieldnames):
                raise RuntimeError(
                    f"Overloads named '{name}' return named tuples with different "
                    f"fields ({list(previous)} vs {list(fieldnames)}); a single "
                    f"get_{name}_namedtuple() cannot describe both."
                )
            continue  # same name and same fields: getter already emitted
        emitted_fields[name] = fieldnames

        typename = f'{name}NamedTuple{"" if not definitions else len(definitions)}'
        fields = ', '.join(f'{{"{fn}", ""}}' for fn in fieldnames)
        definitions.append(f"""\
PyTypeObject* get_{name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
    static PyTypeObject {typename};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{typename}, &desc);
        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{typename};
}}
""")
        map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    """
    Return the string key used to deduplicate named-tuple return types.

    The key is the C++ name of ``f`` followed by its returned field names,
    all joined with ``'_'``. For example, a C++ name ``'max'`` with fields
    ``['values', 'indices']`` gives ``'max_values_indices'``.

    NOTE(review): since ``'_'`` can also appear inside names, the join is not
    injective. C++ name ``'a_b'`` with fields ``['c']`` and C++ name ``'a'``
    with fields ``['b', 'c']`` both map to ``'a_b_c'``, so dict lookups keyed
    on this value would treat them as the same type.
    """
    return '_'.join([cpp.name(f.func)] + namedtuple_fieldnames(f.func.returns))