示例#1
0
def genCppH(hFilePath, cppFilePath, templateGlslPaths, tmpDirPath, env):
    """Generate a C++ header/source pair that embeds GLSL shader sources.

    Every template in ``templateGlslPaths`` is loaded with
    ``CodeTemplate.from_file`` and substituted with ``env``. The source file
    receives one ``const char*`` definition per template, written as one C
    string literal per non-empty shader line; the header receives the matching
    ``extern`` declaration.

    Args:
        hFilePath: Path of the header file to write.
        cppFilePath: Path of the source file to write.
        templateGlslPaths: Iterable of GLSL template file paths.
        tmpDirPath: Not used by this function; kept so existing callers keep
            working.
        env: Substitution environment passed to ``CodeTemplate.substitute``.
    """
    print("hFilePath:{}".format(hFilePath))
    print("cppFilePath:{}".format(cppFilePath))
    nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
    nsend = "\n} } } //namespace at::native::vulkan\n"

    # Collect fragments and join once at the end instead of repeatedly
    # concatenating onto an ever-growing string.
    h_parts = ["#pragma once\n", nsbegin]
    cpp_parts = ["#include <ATen/native/vulkan/{}>".format(H_NAME), nsbegin]

    for templateGlslPath in templateGlslPaths:
        name = getName(templateGlslPath)
        h_parts.append("extern const char* " + name + ";\n")
        cpp_parts.append("const char* " + name + " = \n")

        content = CodeTemplate.from_file(templateGlslPath).substitute(env)
        for line in content.split("\n"):
            if not line:
                continue
            # Escape backslashes first, then double quotes, so that the shader
            # line stays a valid C++ string literal (e.g. a macro continuation
            # backslash or a quote inside a comment).
            escaped = line.replace("\\", "\\\\").replace("\"", "\\\"")
            cpp_parts.append("\"" + escaped + "\\n\"\n")

        cpp_parts.append(";\n")

    cpp_parts.append(nsend)
    h_parts.append(nsend)

    with open(hFilePath, "w") as f:
        f.write("".join(h_parts))
    with open(cppFilePath, "w") as f:
        f.write("".join(cpp_parts))
示例#2
0
#
#                         Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# python binding for all overloads of a particular function/method
#
# Placeholders to be filled in by the caller: ${name}, ${pycname},
# ${method_header}, ${signatures}, ${traceable}, ${max_args}, ${self_},
# ${check_has_torch_function}, ${dispatch} and ${method_footer}.
# NOTE(review): this is a raw string, so the backslash right after the opening
# quotes is not a line continuation and stays in the template text - confirm
# that this is intended.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

""")

# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures differ only in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
PY_VARIABLE_CASE = CodeTemplate("""\
case ${overload_index}: {
示例#3
0
def should_trace(f: NativeFunction) -> bool:
    """Decide whether the native function ``f`` should be traced.

    Returns ``False`` when any schema-order argument has one of the
    untraceable types, when no return is tensor-like, or when the base name of
    the function is listed in ``DONT_RECORD_TRACE``; ``True`` otherwise.
    """
    # Operations involving Storage or Type are not traceable at the moment
    untraceable_arg_types = {"Storage", "Type", "ConstQuantizerPtr"}
    for schema_arg in f.func.schema_order_arguments():
        if str(schema_arg.type) in untraceable_arg_types:
            return False

    # We can't trace functions which don't have any Tensor or TensorList returns
    has_tensor_like_return = False
    for ret in f.func.returns:
        if ret.type.is_tensor_like():
            has_tensor_like_return = True
            break
    if not has_tensor_like_return:
        return False

    base_name = f.func.name.name.base
    return base_name not in DONT_RECORD_TRACE


# C++ if/else skeleton; ${cond}, ${true} and ${false} are supplied by the caller.
SELECT = CodeTemplate("""\

if (${cond}) {
  ${true}
} else {
  ${false}
}
""")

# C++ statement assigning `op_name` the symbol "aten::${trace_name}".
OP_NAME = CodeTemplate("""\
op_name = c10::Symbol::fromQualString("aten::${trace_name}");
""")

# Functions whose name is replaced by a different one when recorded in a trace:
# maps the original base name to the name recorded instead.
RENAME_TRACE = {
    "zero": "zeros_like",  # replacing aten::zero_ with aten::zeros_like
    "fill": "full_like",  # replacing aten::fill_ with aten::full_like
}

from typing import Set
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.code_template import CodeTemplate

import yaml

# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]

# C++ snippet that returns $dtype_checks when `kernel_tag_sv` compares equal
# to $kernel_tag_name.
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
  return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)

selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>

namespace at {
inline constexpr bool should_include_kernel_dtype(
  const char *kernel_tag_str,
  at::ScalarType scalar_type
) {
  c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str);
  $body
  return false;
}
    instructions = 1
    constants = 2
    types = 3
    operators = 4
    register_size = 5


# Operator names (all overloads of aten::full) kept in an exclusion list.
# NOTE(review): how the list is consumed is not visible here - confirm at the use site.
EXCLUDED_OP_SET = [
    "aten::full.names",
    "aten::full.out",
    "aten::full",
]

# Upgrader names kept in an exclusion list.
# NOTE(review): "EXCLUE" looks like a misspelling of "EXCLUDE"; the name is left
# untouched because code outside this view may reference it.
EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]

# One C++ `Instruction{...}` initializer; needs ${operator_name}, ${X} and ${N}.
ONE_INSTRUCTION = CodeTemplate("""
    Instruction{OpCode::${operator_name}, ${X}, ${N}},""")

# C++ `std::vector<Instruction>` wrapping ${instruction_list}.
INSTRUCTION_LIST = CodeTemplate("""std::vector<Instruction>({
        ${instruction_list}
    }), // instructions list""")

# One C++ `c10::IValue` initializer built from ${constant}.
ONE_CONSTANT = CodeTemplate("""
    c10::IValue(${constant}),""")

# C++ `std::vector<c10::IValue>` wrapping ${constant_list}.
CONSTANT_LIST = CodeTemplate("""std::vector<c10::IValue>({
        ${constant_list}
    }), // constants list""")

# Plain string (not a template) for an empty constants vector.
CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""

# One C++ `c10::parseType` call on ${type_str}.
ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")
示例#6
0
def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
    """Render ``template`` for one differentiability entry.

    Collects the C++ snippets the template needs - saved-variable member
    declarations, release statements, list-size members, unpack statements,
    asserts, index-range computations, Python getter definitions and getset
    structs, and the derivative body - and substitutes them into ``template``.

    Args:
        info: The differentiability entry to generate code for.
        template: The template the collected snippets are substituted into.

    Returns:
        The text produced by ``template.substitute``.
    """
    saved_variables: List[str] = []
    release_variables: List[str] = []
    saved_list_sizes: List[str] = []
    unpack: List[str] = []
    asserts: List[str] = []
    compute_index_ranges: List[str] = []
    getter_definitions: List[str] = []
    py_getsetdef_structs: List[str] = []

    # One index range per differentiable argument: list-like arguments use a
    # `<name>_size_` member (declared in saved_list_sizes), everything else
    # uses a range of size 1.
    for arg in info.args_with_derivatives:
        if (
            arg.type == "at::TensorList"
            or arg.type == "const c10::List<c10::optional<at::Tensor>> &"
        ):
            size = f"{arg.name}_size_"
            saved_list_sizes.append(f"size_t {arg.name}_size_;")
        else:
            size = "1"
        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")

    def save_var(var: SavedAttribute, is_output: bool) -> None:
        """Emit the snippets for one saved attribute, chosen by its C++ type.

        Appends to the enclosing function's lists: the member declaration,
        and - depending on the type - release/unpack statements, asserts,
        getter definitions and getset structs.
        """
        name = var.nctype.name
        type = var.nctype.type
        should_append_getsetdef = True
        should_append_raw_getsetdef = False

        if (
            type == BaseCType(tensorT)
            or type == OptionalCType(BaseCType(tensorT))
            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
            or (type == BaseCType(scalarT) and is_output)
        ):
            saved_variables.append(f"SavedVariable {name}_;")
            release_variables.append(f"{name}_.reset_data();")
            # Saved outputs are unpacked with shared_from_this() as the
            # argument, saved inputs with no argument.
            ptr = "shared_from_this()" if is_output else ""
            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
            getter_definitions.append(
                GETTER_DEFINITION_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
        elif type == BaseCType(tensorListT):
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
            # Same as the tensor-list branch above, except that unpacking
            # goes through unpack_opt_list.
            saved_variables.append(f"std::vector<SavedVariable> {name}_;")
            saved_variables.append(f"bool {name}_released_ = false;")
            # Just clear() is sufficient, we don't need to loop and clear each variable.
            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
            release_variables.append(f"{name}_.clear();")
            release_variables.append(f"{name}_released_ = true;")
            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
            getter_definitions.append(
                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
                )
            )
            getter_definitions.append(
                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
                )
            )
            should_append_raw_getsetdef = True
        elif type == BaseCType(intArrayRefT):
            saved_variables.append(f"std::vector<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == BaseCType(optionalIntArrayRefT):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(BaseCType(intArrayRefT)):
            saved_variables.append(f"c10::OptionalArray<int64_t> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
                )
            )
        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
            saved_variables.append(f"c10::OptionalArray<double> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
                )
            )
        elif type == BaseCType(longT):
            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_INT64_T
                )
            )
        elif type == BaseCType(stringT):
            saved_variables.append(f"std::string {name};")
            getter_definitions.append(
                GETTER_DEFINITION.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        elif type == OptionalCType(BaseCType(stringT)):
            saved_variables.append(f"c10::optional<std::string> {name};")
            getter_definitions.append(
                GETTER_DEFINITION_OPT.substitute(
                    op=info.op, name=name, body=GETTER_BODY_STRING
                )
            )
        else:
            # Any other type is stored as a plain member of its C++ type; a
            # getter is emitted only if MISC_GETTER_DEFS has an entry for it.
            saved_variables.append(f"{type.cpp_type()} {name};")

            if type in MISC_GETTER_DEFS:
                getter_def, body = MISC_GETTER_DEFS[type]
                getter_definitions.append(
                    getter_def.substitute(op=info.op, name=name, body=body)
                )
            else:
                # Types we don't expose python bindings to yet:
                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
                #   std::vector<std::vector<int64_t>>, std::vector<at::ScalarType>
                should_append_getsetdef = False

        if should_append_getsetdef:
            py_getsetdef_structs.append(
                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )
        if should_append_raw_getsetdef:
            py_getsetdef_structs.append(
                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
            )

    for var in info.all_saved_inputs:
        save_var(var, is_output=False)
    for var in info.all_saved_outputs:
        save_var(var, is_output=True)

    # lock the mutex when we release variables and in Node::apply to protect thread safety
    # see Note [Thread Safety on Autograd Node]
    if len(release_variables) > 0:
        thread_lock = "std::lock_guard<std::mutex> lock(mutex_);"
    else:
        thread_lock = ""

    if uses_retain_variables(info):
        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
    else:
        will_release_variables = ""

    body: List[str] = []

    if uses_single_grad(info):
        body.append("const auto& grad = grads[0];")
    else:
        # Generate aliases for gradients named for returned values.
        body.extend(
            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
            for name in info.used_named_gradients
        )

    def emit_derivative(
        derivative: Derivative,
        args_with_derivatives: Sequence[Binding],
    ) -> Tuple[bool, str]:
        """Render one derivative formula.

        Returns a pair ``(checks_any_grad_defined, text)``: the flag is true
        only when the emitted formula references ``any_grad_defined``; the
        text is the substituted single- or multi-output derivative template.
        """
        formula = derivative.formula
        var_names = derivative.var_names
        if len(var_names) == 1:
            checks_any_grad_defined = False
            if "not_implemented" not in formula:
                matching_args = [
                    arg for arg in args_with_derivatives if arg.name == var_names[0]
                ]
                if len(matching_args) == 1:
                    # We can add undefined grad support if the input variable is a Tensor
                    arg = matching_args[0]
                    if isinstance(arg.argument, Argument) and str(
                        arg.argument.type
                    ) in ("Tensor", "Tensor?"):
                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
                        checks_any_grad_defined = True
            return (
                checks_any_grad_defined,
                DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula),
            )
        else:
            # Formulas that mention grad_input_mask get a mask built from one
            # should_compute_output(...) entry per variable.
            if "grad_input_mask" in formula:
                masks = [f"should_compute_output({{ {n}_ix }})," for n in var_names]
                grad_input_mask = GRAD_INPUT_MASK.substitute(
                    masks=masks, n=len(var_names)
                )
            else:
                grad_input_mask = ""
            idx_ranges = ", ".join(f"{n}_ix" for n in var_names)
            copy_ranges: List[str] = []
            for i, n in enumerate(var_names):
                copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
            return False, DERIVATIVE_MULTI.substitute(
                idx_ranges=idx_ranges,
                copy_ranges=copy_ranges,
                derivative=formula,
                grad_input_mask=grad_input_mask,
            )

    body.extend(unpack)
    need_any_grad_defined_var = False
    for derivative in info.derivatives:
        checks_any_grad_defined, derivative_text = emit_derivative(
            derivative, info.args_with_derivatives
        )
        body.append(derivative_text)
        need_any_grad_defined_var |= checks_any_grad_defined
    # Since single-output derivative formulas need to check if grads are
    # defined, only perform the check once, before all the formulas
    if need_any_grad_defined_var:
        # The derivative texts are the last len(info.derivatives) entries of
        # body, so this negative index places the declaration right before them.
        body.insert(
            -len(info.derivatives),
            "bool any_grad_defined = any_variable_defined(grads);",
        )

    if info.name in UNTRACEABLE_FUNCTIONS:
        superclass = "Node"
    else:
        superclass = "TraceableFunction"

    # Comma-separated structs with a trailing comma, or empty when there are none.
    all_getsetdef_structs = (
        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
    )
    all_getter_definitions = "\n".join(getter_definitions)

    return template.substitute(
        op=info.op,
        compute_index_ranges=compute_index_ranges,
        saved_variables=saved_variables,
        release_variables=release_variables,
        saved_list_sizes=saved_list_sizes,
        asserts=asserts,
        thread_lock=thread_lock,
        will_release_variables=will_release_variables,
        body=body,
        superclass=superclass,
        all_getter_definitions=all_getter_definitions,
        all_getsetdef_structs=all_getsetdef_structs,
    )
示例#7
0
    tensorT,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import VIEW_FUNCTIONS

# C++ declaration of the struct ${op}, which derives from ${superclass}.
# Remaining placeholders: ${thread_lock} and ${release_variables} (body of
# release_variables()), ${will_release_variables}, ${saved_variables} and
# ${saved_list_sizes} (struct members).
FUNCTION_DECLARATION = CodeTemplate(
    """\
struct TORCH_API ${op} : public ${superclass} {
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${thread_lock}
    ${release_variables}
  }
  ${will_release_variables}
  ${saved_variables}
  ${saved_list_sizes}
};
"""
)

WILL_RELEASE_VARIABLES = CodeTemplate(
    """\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
"""
示例#8
0
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the file at ``template_fn`` as a ``CodeTemplate``."""
    loaded_template = CodeTemplate.from_file(template_fn)
    return loaded_template
示例#9
0
def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction,
                                             NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: "SelectiveBuilder",
    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
    backend_name: str = "",
    eager_registration: bool = True,
) -> None:
    """Write ``Register{dispatch_key}.cpp`` for one backend dispatch key.

    The file is rendered through ``fm`` from the ``RegisterDispatchKey.cpp``
    template; its definitions section is rendered from
    ``RegisterDispatchDefinitions.ini``. Depending on ``eager_registration``
    the registrations are emitted either inside a ``TORCH_LIBRARY_IMPL`` block
    or inside a ``Register{backend_name}{dispatch_key}NativeFunctions()``
    function.
    """
    native_function_headers = [
        f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
    ]
    # In-tree builds use angle-bracket includes, all others quoted includes.
    if build_in_tree:
        include_lines = [f"#include <{path}>" for path in native_function_headers]
    else:
        include_lines = [f'#include "{path}"' for path in native_function_headers]
    external_backend_headers_str = "\n".join(include_lines)

    assert class_name is not None
    backend_index = backend_indices[dispatch_key]

    def _emit(target):
        # Run the RegisterDispatchKey generator for `target` over every
        # grouped native function and flatten the results into one list.
        generator = dest.RegisterDispatchKey(
            backend_index,
            target,
            selector,
            rocm=False,
            class_method_name=f"{class_name}",
            skip_dispatcher_op_registration=False,
        )
        return list(concatMap(generator, grouped_native_functions))

    dispatch_registrations_body = _emit(Target.REGISTRATION)
    ns_helper = NamespaceHelper(namespace_str="at")

    # Exactly one of the two registration strings is non-empty.
    if eager_registration:
        deferred_dispatch_registrations = ""
        static_init_dispatch_registrations = CodeTemplate("""\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
    $dispatch_registrations_body
};""").substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )
    else:
        static_init_dispatch_registrations = ""
        deferred_dispatch_registrations = CodeTemplate("""\
TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
    $dispatch_registrations_body
}""").substitute(
            backend_name=backend_name,
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )

    def _definitions_env():
        # Environment for the RegisterDispatchDefinitions.ini template.
        return {
            "ns_prologue": ns_helper.prologue,
            "ns_epilogue": ns_helper.epilogue,
            "static_init_dispatch_registrations": static_init_dispatch_registrations,
            "deferred_dispatch_registrations": deferred_dispatch_registrations,
            "dispatch_helpers": dest.gen_registration_helpers(backend_index),
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_namespaced_definitions": "",
            "dispatch_anonymous_definitions": _emit(Target.ANONYMOUS_DEFINITION),
        }

    def _registration_env():
        # Environment for the RegisterDispatchKey.cpp template.
        return {
            "extra_cuda_headers": "",
            "external_backend_headers": external_backend_headers_str,
            "ops_headers": "" if per_operator_headers else "#include <ATen/Functions.h>",
            "DispatchKey": dispatch_key,
            "dispatch_namespace": dispatch_key.lower(),
            "dispatch_headers": dest.gen_registration_headers(
                backend_index,
                per_operator_headers=per_operator_headers,
                rocm=False,
            ),
            "dispatch_definitions": fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                _definitions_env,
            ).split("\n"),
        }

    fm.write_with_template(
        f"Register{dispatch_key}.cpp",
        "RegisterDispatchKey.cpp",
        _registration_env,
    )
示例#10
0
    "adjoint",
    "matrix_H",
})

# These are the functions we consider views for the purposes of validating
# StorageImpl and TensorImpl in gen_variable_type.
# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a
# view for the purposes of the ADInplaceOrView kernel: we do not want to call
# as_view for it. See NOTE [Unsafe View] for more info.
ALL_VIEW_FUNCTIONS = {
    **VIEW_FUNCTIONS,
    "_unsafe_view": "self",
}

# C++ statement copying ${arg} into ${vec} via .vec().
ARRAYREF_TO_VEC = CodeTemplate("""\
auto ${vec} = ${arg}.vec();
""")

# C++ statement storing ${arg}.value_or(${default}) in ${val}.
OPTIONAL_TO_VAL = CodeTemplate("""\
auto ${val} = ${arg}.value_or(${default});
""")

# C++ expression calling at::_ops::${unambiguous_name}::call with ${unpacked_args}.
CALL_DISPATCH = CodeTemplate("""\
at::_ops::${unambiguous_name}::call(${unpacked_args})""")

SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
    """\
std::function<at::Tensor(const at::Tensor&)> func=nullptr;
if (${is_view_with_metadata_change} || !self.unsafeGetTensorImpl()->support_as_strided()) {
  ${replay_view_func}
}
示例#11
0
def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
    """Compile GLSL templates to SPIR-V and embed the binaries in C++ files.

    Every ``*.glsl`` file found recursively under ``srcDirPath`` is loaded
    with ``CodeTemplate.from_file``, substituted with ``env``, written to
    ``tmpDirPath`` and compiled with ``glslcPath``. The resulting ``.spv``
    files are emitted into the source file as ``uint32_t`` arrays together
    with a ``<name>_len`` byte count; the header receives the matching
    ``extern`` declarations.

    Args:
        hFilePath: Path of the header file to write.
        cppFilePath: Path of the source file to write.
        srcDirPath: Directory searched recursively for ``*.glsl`` templates.
        glslcPath: Path of the glslc executable.
        tmpDirPath: Directory receiving the substituted ``.glsl`` sources and
            the compiled ``.spv`` files.
        env: Substitution environment passed to ``CodeTemplate.substitute``.

    Raises:
        subprocess.CalledProcessError: If glslc exits with a non-zero status.
    """
    print(
        "hFilePath:{} cppFilePath:{} srcDirPath:{} glslcPath:{} tmpDirPath:{}".
        format(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath))

    vexs = glob.glob(os.path.join(srcDirPath, '**', '*.glsl'), recursive=True)
    # Sort once after filtering rather than re-sorting after every append.
    templateSrcPaths = sorted(f for f in vexs if len(f) > 1)
    print("templateSrcPaths:{}".format(templateSrcPaths))

    spvPaths = []
    for templateSrcPath in templateSrcPaths:
        print("templateSrcPath {}".format(templateSrcPath))
        name = getName(templateSrcPath).replace("_glsl", "")
        print("name {}".format(name))

        codeTemplate = CodeTemplate.from_file(templateSrcPath)
        srcPath = tmpDirPath + "/" + name + ".glsl"
        content = codeTemplate.substitute(env)
        with open(srcPath, 'w') as f:
            f.write(content)

        spvPath = tmpDirPath + "/" + name + ".spv"
        print("spvPath {}".format(spvPath))

        cmd = [
            glslcPath, "-fshader-stage=compute", srcPath, "-o", spvPath,
            "--target-env=vulkan1.0", "-Werror"
        ]

        print("\nglslc cmd:", cmd)

        subprocess.check_call(cmd)
        spvPaths.append(spvPath)

    nsbegin = "\nnamespace at { namespace native { namespace vulkan { \n"
    nsend = "\n} } } //namespace at::native::vulkan\n"

    # Collect fragments and join once at the end; the source file holds one
    # line per SPIR-V word, so repeated string concatenation would be costly.
    h_parts = ["#pragma once\n", "#include <stdint.h>\n", nsbegin]
    cpp_parts = ["#include <ATen/native/vulkan/{}>".format(H_NAME), nsbegin]

    for spvPath in spvPaths:
        name = getName(spvPath)
        name_len = name + "_len"
        h_parts.append("extern const uint32_t {}[];\n".format(name))
        h_parts.append("extern const uint32_t {};\n".format(name_len))

        cpp_parts.append("const uint32_t " + name + "[] = {\n")
        print("spvPath:{}".format(spvPath))
        with open(spvPath, 'rb') as f:
            words = array.array('I', f.read())
        cpp_parts.extend("{},\n".format(word) for word in words)
        cpp_parts.append("};\n")
        # Each array element is counted as 4 bytes.
        # NOTE(review): this assumes the 'I' type code is 4 bytes wide on the
        # build platform - confirm.
        sizeBytes = 4 * len(words)
        cpp_parts.append("const uint32_t {} = {};\n".format(name_len, sizeBytes))

    cpp_parts.append(nsend)
    h_parts.append(nsend)

    with open(hFilePath, "w") as f:
        f.write("".join(h_parts))
    with open(cppFilePath, "w") as f:
        f.write("".join(cpp_parts))
示例#12
0
parser.add_argument("--output_prefix", default="", help="")
parser.add_argument(
    "--install_dir", default=".", help="where to put generated file")
parser.add_argument("--aten_root", default="", help="root directory of aten")
args, _ = parser.parse_known_args()

# When --aten_root is given, check that it exists and put its parent directory
# at the front of sys.path before importing torchgen. Both branches import the
# same name; only the search path differs.
if args.aten_root:
    if not os.path.exists(args.aten_root):
        raise ValueError('aten_root ({}) does not exist'.format(
            args.aten_root))
    sys.path.insert(0, os.path.join(args.aten_root, '..'))
    from torchgen.code_template import CodeTemplate as CT
else:
    from torchgen.code_template import CodeTemplate as CT

# Template loaded from aten_op_template.h inside args.template_dir.
OP_TEMPLATE = CT.from_file(
    os.path.join(args.template_dir, 'aten_op_template.h'))


try:
    # use faster C loader if available
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[misc]


def write(filename, s):
    """Write the string ``s`` to ``filename``, replacing any existing content."""
    with open(filename, "w") as out_file:
        out_file.write(s)


def read(filename):