Esempio n. 1
0
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Generate the sharded Python binding sources for autograd functions.

    Writes ``python_functions.h`` and the sharded ``python_functions.cpp``.
    The header declares and calls one ``initialize_autogenerated_functions_<i>``
    per shard. The per-shard ``.cpp`` content comes from ``process_function``
    applied to each selected info.

    Args:
        out: directory the generated files are installed into.
        differentiability_infos: derivative infos to generate bindings for.
            Infos whose ``args_with_derivatives`` is empty/falsy are skipped.
        template_path: directory containing the file templates.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    fm.write('python_functions.h', lambda: {
        # '@' is kept apart from 'generated' (as the other codegen entry points
        # do) so this generator's own source does not contain the marker.
        # Presumably this stops tools that scan for it from treating this file
        # as generated. The emitted string is unchanged.
        'generated_comment': '@' f'generated from {fm.template_dir}/python_functions.h',
        'shard_forward_declare': [
            f"void initialize_autogenerated_functions_{i}();"
            for i in range(num_shards)
        ],
        'shard_call': [
            f"initialize_autogenerated_functions_{i}();"
            for i in range(num_shards)
        ],
    })

    # Only infos with (truthy) args_with_derivatives get Python bindings.
    infos = [info for info in differentiability_infos if info.args_with_derivatives]
    fm.write_sharded(
        'python_functions.cpp',
        infos,
        key_fn=lambda info: info.name,
        base_env={
            'generated_comment': '@' f'generated from {fm.template_dir}/python_functions.cpp',
        },
        env_callable=lambda info: {
            'py_function_initializers': [process_function(info, PY_FUNCTION_DEFINITION)],
            'py_function_props_and_getters': [process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)],
        },
        num_shards=num_shards,
        sharded_keys={'py_function_initializers', 'py_function_props_and_getters'},
    )
Esempio n. 2
0
def create_python_bindings_sharded(
        fm: FileManager, pairs: Sequence[PythonSignatureNativeFunctionPair],
        pred: Callable[[NativeFunction], bool], module: Optional[str],
        filename: str, *, method: bool, num_shards: int) -> None:
    """Generates Python bindings to ATen functions, split across shards.

    ``pairs`` is grouped (and filtered with ``pred``) by
    ``group_filter_overloads`` into (base operator name, overloads) entries.
    Each entry contributes forward declarations, one method implementation and
    one method definition to the sharded ``filename`` output. Entries are
    keyed by the string form of the base operator name.
    """
    grouped = group_filter_overloads(pairs, pred)

    def key_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> str:
        op_name, _overloads = kv
        return str(op_name)

    def env_func(
        kv: Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]
    ) -> Dict[str, List[str]]:
        op_name, overloads = kv
        # Generated in the same order as the template keys are listed.
        forwards = list(forward_decls(op_name, overloads, method=method))
        impl = method_impl(op_name, module, overloads, method=method)
        definition = method_def(op_name, module, overloads, method=method)
        return {
            'py_forwards': forwards,
            'py_methods': [impl],
            'py_method_defs': [definition],
        }

    entries = grouped.items()
    # '@' is concatenated separately so this source file never contains the
    # literal generated-file marker.
    header_comment = '@' + f'generated from {fm.template_dir}/{filename}'
    fm.write_sharded(
        filename,
        entries,
        base_env={'generated_comment': header_comment},
        key_fn=key_func,
        env_callable=env_func,
        num_shards=num_shards,
        sharded_keys={'py_forwards', 'py_methods', 'py_method_defs'},
    )
Esempio n. 3
0
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Generate VariableType.h and the sharded VariableType.cpp body.

    This is the at::Type subclass for differentiable tensors. Each generated
    function dispatches to the base tensor type to compute its output, and a
    grad_fn is attached to differentiable functions.

    Only functions accepted by ``use_derived`` are emitted into the .cpp
    shards, keyed by their C++ name. ``native_yaml_path`` is part of the
    signature but is not read here.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    # '@' is split from 'generated' so this source file itself does not carry
    # the generated-file marker.
    header_comment = "@" f'generated from {template_path}/VariableType.h'
    fm.write('VariableType.h', lambda: {'generated_comment': header_comment})

    derived_fns = [fn for fn in fns_with_diff_infos if use_derived(fn)]

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm.write_sharded(
        'VariableType.cpp',
        derived_fns,
        key_fn=lambda fn: cpp.name(fn.func.func),
        base_env={
            'generated_comment': "@" f'generated from {template_path}/VariableType.cpp',
        },
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys={'type_derived_method_definitions', 'wrapper_registrations'},
    )
Esempio n. 4
0
def gen_trace_type(out: str, native_yaml_path: str,
                   template_path: str) -> None:
    """Generate the sharded ``TraceType.cpp`` files.

    Parses the native functions from ``native_yaml_path``. It emits content
    (via ``gen_trace_type_func``) for every function whose C++ name is not
    listed in ``MANUAL_TRACER``, spread over 5 shards keyed by that C++ name.

    Args:
        out: directory the generated files are installed into.
        native_yaml_path: path to the native functions YAML to parse.
        template_path: directory containing the ``TraceType.cpp`` template.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    native_functions = parse_native_yaml(native_yaml_path).native_functions
    # Functions named in MANUAL_TRACER are excluded from generation
    # (presumably because they are traced by hand-written code).
    traced_functions = [
        fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER
    ]
    fm.write_sharded(
        'TraceType.cpp',
        traced_functions,
        key_fn=lambda fn: cpp.name(fn.func),
        base_env={
            # '@' is kept apart from 'generated', matching the other codegen
            # entry points, so this source does not contain the marker. The
            # emitted string is unchanged.
            'generated_comment': '@' f'generated from {template_path}/TraceType.cpp',
        },
        env_callable=gen_trace_type_func,
        num_shards=5,
        sharded_keys={'trace_method_definitions', 'trace_wrapper_registrations'},
    )