Exemplo n.º 1
0
def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair]
) -> Tuple[List[str], Dict[str, str]]:
    """
    Emit one `static PyTypeObject*` declaration per distinct named-tuple
    return type found among `overloads`.

    Overloads for which `namedtuple_fieldnames` yields nothing are skipped.
    Each emitted declaration initializes its variable with a call to
    `get_namedtuple("<cpp name>")` in the generated code.

    Returns the declarations in first-seen order, together with a map from
    typename key (see `gen_namedtuple_typename_key`) to the variable name.
    """
    var_by_key: Dict[str, str] = {}  # typename key -> generated variable name
    declarations: List[str] = []  # one generated declaration per distinct key

    for pair in overloads:
        native_fn = pair.function
        if not namedtuple_fieldnames(native_fn.func.returns):
            continue

        cpp_name = cpp.name(native_fn.func)  # use @with_native_function?
        key = gen_namedtuple_typename_key(native_fn)
        if key in var_by_key:
            # Already declared for an earlier overload with the same key.
            continue

        # The first variable carries no suffix; later ones are numbered by
        # the count of declarations emitted so far.
        var_name = f'NamedTuple{len(declarations) or ""}'
        var_by_key[key] = var_name
        declarations.append(
            f'static PyTypeObject* {var_name} = get_namedtuple("{cpp_name}");'
        )

    return declarations, var_by_key
Exemplo n.º 2
0
def emit_namedtuple_typedefs(
    overloads: Sequence[PythonSignatureNativeFunctionPair]
) -> Tuple[List[str], Dict[str, str]]:
    """
    Emit the field-array declarations and the `PyTypeObject` definitions
    (with their one-time init code) for every distinct named-tuple return
    type found among `overloads`.

    Overloads for which `namedtuple_fieldnames` yields nothing are skipped.
    Field arrays are shared between overloads whose field name lists match;
    type objects are shared between overloads with the same typename key
    (see `gen_namedtuple_typename_key`).

    Returns all field-array declarations followed by all type definitions,
    together with a map from typename key to the type object's name.
    """
    array_by_fields: Dict[str, str] = {}  # joined field names -> array name
    array_decls: List[str] = []  # PyStructSequence_Field array declarations
    type_by_key: Dict[str, str] = {}  # typename key -> type object name
    type_decls: List[str] = []  # type object definitions and init code

    for pair in overloads:
        native_fn = pair.function
        names = namedtuple_fieldnames(native_fn.func.returns)
        if not names:
            continue

        fields_key = '_'.join(names)
        array_name = array_by_fields.get(fields_key)
        if array_name is None:
            # First array is unsuffixed; later ones are numbered by the
            # count of arrays emitted so far.
            array_name = f'NamedTuple_fields{len(array_decls) or ""}'
            array_by_fields[fields_key] = array_name
            initializers = ', '.join(f'{{"{field}", ""}}' for field in names)
            array_decls.append(
                f'static PyStructSequence_Field {array_name}[] = '
                f'{{ {initializers},  {{nullptr}} }};\n'
            )

        cpp_name = cpp.name(native_fn.func)  # use @with_native_function?
        type_key = gen_namedtuple_typename_key(native_fn)
        type_name = type_by_key.get(type_key)
        if type_name is None:
            # Same numbering scheme as the field arrays above.
            type_name = f'NamedTuple{len(type_decls) or ""}'
            type_by_key[type_key] = type_name
            field_count = len(names)
            type_decls.append(f"""\
static PyTypeObject {type_name};
static bool {type_name}_initialized = false;
if (!{type_name}_initialized) {{
  {type_name}_initialized = true;
  static PyStructSequence_Desc desc = {{ "torch.return_types.{cpp_name}", nullptr, {array_name}, {field_count} }};
  PyStructSequence_InitType(&{type_name}, &desc);
  {type_name}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
}}
""")

    return array_decls + type_decls, type_by_key
Exemplo n.º 3
0
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate, for `python_return_types.cpp`, the getter functions that
    initialize and return the named tuple type of each native function
    returning a named tuple, plus the matching entries for the map in the
    same file.

    Overloads for which `namedtuple_fieldnames` yields nothing are skipped,
    and only the first overload seen for each typename key (see
    `gen_namedtuple_typename_key`) produces output.

    Returns the function definitions and the map entries, both in
    first-seen order.
    """
    type_by_key: Dict[str, str] = {}  # typename key -> type object name
    getters: List[str] = []  # function definitions that register the type
    entries: List[str] = []  # map entries of <function name, getter call>

    for pair in overloads:
        native_fn = pair.function
        names = namedtuple_fieldnames(native_fn.func.returns)
        if not names:
            continue

        cpp_name = cpp.name(native_fn.func)  # use @with_native_function?
        type_key = gen_namedtuple_typename_key(native_fn)
        if type_key in type_by_key:
            # A getter for this key was already generated.
            continue

        # First type object is unsuffixed; later ones are numbered by the
        # count of getters emitted so far.
        type_name = f'{cpp_name}NamedTuple{len(getters) or ""}'
        type_by_key[type_key] = type_name
        initializers = ', '.join(f'{{"{field}", ""}}' for field in names)
        field_count = len(names)
        getters.append(f"""\
PyTypeObject* get_{cpp_name}_namedtuple() {{
    static PyStructSequence_Field NamedTuple_fields[] = {{ {initializers},  {{nullptr}} }};
    static PyTypeObject {type_name};
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = {{ "torch.return_types.{cpp_name}", nullptr, NamedTuple_fields, {field_count} }};
    if (!is_initialized) {{
        PyStructSequence_InitType(&{type_name}, &desc);
        {type_name}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }}
    return &{type_name};
}}
""")
        entries.append(f'{{"{cpp_name}", get_{cpp_name}_namedtuple()}}, ')

    return getters, entries
Exemplo n.º 4
0
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    """
    Build the key identifying a named-tuple return type: the function's
    C++ name followed by its return field names, joined with underscores.
    """
    parts = [cpp.name(f.func)]
    parts.extend(namedtuple_fieldnames(f.func.returns))
    return '_'.join(parts)