Exemplo n.º 1
0
def node_ctor_arg_rvalue_string(arg: NamedCType, schema: LazyIrSchema) -> str:
    """
    Return the C++ expression that materializes ``arg`` as an rvalue for a
    lazy Node constructor call.

    Arguments classified as value types by ``isValueType`` are turned into IR
    values:
      * a ``BaseCType`` listed in ``schema.wrapped_scalar_names`` goes through
        ``LazyGraphExecutor::GetIrValueForScalarFromCodegen``; any other
        ``BaseCType`` uses ``lazy_<name>->GetIrValue()`` (``lazy_<name>`` is
        presumably a lazy-tensor local emitted by the surrounding codegen —
        confirm against the caller);
      * an ``OptionalCType`` produces the same expression wrapped in a
        ``cond ? c10::make_optional(...) : c10::nullopt`` ternary.

    All other arguments are forwarded as plain C++ values:
      * ``VectorCType`` of ``BaseCType`` -> ``std::vector<T>(name.begin(), name.end())``;
      * ``OptionalCType`` of ``VectorCType`` of ``BaseCType`` ->
        ``torch::lazy::ToOptionalVector<T>(name)``;
      * anything else -> the argument name unchanged.

    Raises:
        AssertionError: for a value type that is neither ``BaseCType`` nor
            ``OptionalCType``.
    """
    ctype = arg.type
    name = arg.name

    if not isValueType(ctype):
        # Plain (non-IR) argument: only list-shaped types need a conversion.
        if isinstance(ctype, VectorCType) and isinstance(ctype.elem, BaseCType):
            return f"std::vector<{ctype.elem.type}>({name}.begin(), {name}.end())"
        is_optional_base_vector = (
            isinstance(ctype, OptionalCType)
            and isinstance(ctype.elem, VectorCType)
            and isinstance(ctype.elem.elem, BaseCType)
        )
        if is_optional_base_vector:
            return f"torch::lazy::ToOptionalVector<{ctype.elem.elem.type}>({name})"
        return f"{name}"

    if isinstance(ctype, BaseCType):
        # Scalars that the schema wraps get their IR value from the executor;
        # everything else reads it from the lazy_<name> object.
        return (
            f"torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen({name})"
            if name in schema.wrapped_scalar_names
            else f"lazy_{name}->GetIrValue()"
        )

    if isinstance(ctype, OptionalCType):
        # Same two sources as above, guarded so an absent value maps to nullopt.
        return (
            (
                f"{name} ? "
                f"c10::make_optional(torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(*{name})) : "
                "c10::nullopt"
            )
            if name in schema.wrapped_scalar_names
            else (
                f"lazy_{name} ? "
                f"c10::make_optional(lazy_{name}->GetIrValue()) : "
                "c10::nullopt"
            )
        )

    raise AssertionError(
        "TODO not sure if there are other valid types to handle here")
Exemplo n.º 2
0
def ts_lowering_body(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
    # for now, we just want one IR class decl and soon after also the method defs
    # and we use the functional version not out/inplace.
    func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
    schema = LazyIrSchema(func)

    emplace_arguments = []
    for value in schema.positional_arg_types:
        if isValueType(value.type):
            if isinstance(value.type, OptionalCType):
                emplace_arguments.append(
                    f"has_{value.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
                )
                continue
            emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
            continue
        emplace_arguments.append(f'"{value.name}", {value.name}_')

    emplace_arguments_str = "\n    ".join(
        [f"arguments.emplace_back({a});" for a in emplace_arguments])
    emplace_kwarg_values = [
        f'loctx->GetOutputOp(operand({i}))'
        for i in range(len(schema.keyword_values))
    ]
    emplace_kwarg_scalars = [
        f'"{t.name}", {t.name}_' for t in schema.keyword_scalars
    ]
    assert len(
        schema.keyword_values
    ) == 0, "TODO the logic for operand(i) is broken if there are kw values"
    emplace_kwarguments = "\n    ".join([
        f"kwarguments.emplace_back({a});"
        for a in emplace_kwarg_values + emplace_kwarg_scalars
    ])
    return f"""\