# Example no. 1
# 0
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Generates Python bindings to ATen functions.

    Keeps only the pairs whose native function satisfies ``pred``, buckets
    them by base operator name, and renders ``filename`` (template and output
    share the name) with one method implementation, one method-table entry
    and the forward declarations per operator, in ascending ``str(name)``
    order.
    """
    # Bucket the selected overloads by their base operator name.
    grouped: Dict[BaseOperatorName,
                  List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for selected in filter(lambda p: pred(p.function), pairs):
        grouped[selected.function.func.name.name].append(selected)

    py_methods: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []
    for op_name in sorted(grouped, key=str):
        overloads = grouped[op_name]
        py_methods.append(
            method_impl(op_name, module, overloads, method=method))
        py_method_defs.append(
            method_def(op_name, module, overloads, method=method))
        py_forwards.extend(forward_decls(op_name, overloads, method=method))

    # The '@' is split off so this generator file is not itself tagged as
    # generated by tools that grep for the marker.
    fm.write_with_template(filename, filename, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
        'py_forwards': py_forwards,
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
    })
def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
    """Write annotated_fn_args.py listing the annotated arguments of bound ops.

    For each (predicate, Python namespace) mapping, every native function that
    gets a Python binding and satisfies the predicate contributes one
    ``"<namespace>.<annotated args>"`` line. Within a namespace, overloads of
    the same base operator are emitted together, with operators ordered by
    first appearance in the YAML.
    """
    native_functions = parse_native_yaml(native_yaml_path)
    mappings = (
        (is_py_torch_function, 'torch._C._VariableFunctions'),
        (is_py_nn_function, 'torch._C._nn'),
        (is_py_linalg_function, 'torch._C._linalg'),
        (is_py_variable_method, 'torch.Tensor'),
    )
    annotated_args: List[str] = []
    for pred, namespace in mappings:
        # Insertion-ordered buckets: this is what keeps overloads adjacent.
        groups: Dict[BaseOperatorName,
                     List[NativeFunction]] = defaultdict(list)
        for f in native_functions:
            if should_generate_py_binding(f) and pred(f):
                groups[f.func.name.name].append(f)
        annotated_args.extend(
            f'{namespace}.{gen_annotated_args(f)}'
            for overloads in groups.values()
            for f in overloads)

    fm = FileManager(install_dir=out,
                     template_dir=os.path.join(autograd_dir, 'templates'),
                     dry_run=False)
    fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py', lambda: {
        'annotated_args': textwrap.indent('\n'.join(annotated_args), '    '),
    })
def gen_autograd_functions_lib(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Functions.h and Functions.cpp body

    These contain the auto-generated subclasses of torch::autograd::Node
    for every differentiable torch function. Both files are rendered from
    templates of the same name found in ``template_path``.
    """
    # Only emit an autograd Node when a derivative will actually be computed.
    infos = [info for info in differentiability_infos
             if info.args_with_derivatives]
    declarations = [process_function(info, FUNCTION_DECLARATION)
                    for info in infos]
    definitions = [process_function(info, FUNCTION_DEFINITION)
                   for info in infos]

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in ('.h', '.cpp'):
        fname = 'Functions' + suffix
        fm.write_with_template(fname, fname, lambda: {
            'generated_comment': '@' + f'generated from {fm.template_dir}/' + fname,
            'autograd_function_declarations': declarations,
            'autograd_function_definitions': definitions,
        })
# Example no. 4
# 0
def main() -> None:
    """
    Generate dataset.pyi by injecting the method definitions collected for the
    built-in IterDataPipe and MapDataPipe classes into template dataset.pyi.in.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
          interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    # get_method_definitions takes, positionally: the source directory, the
    # files to exclude, the deprecated files, the DataPipe class name, and a
    # mapping of method name -> overridden output type.
    iter_method_definitions = get_method_definitions(
        "datapipes/iter",
        {"__init__.py", "utils.py"},
        set(),
        "IterDataPipe",
        {
            "demux": "List[IterDataPipe]",
            "fork": "List[IterDataPipe]",
        },
    )
    map_method_definitions = get_method_definitions(
        "datapipes/map",
        {"__init__.py", "utils.py"},
        set(),
        "MapDataPipe",
        {},
    )

    env = {
        'IterDataPipeMethods': iter_method_definitions,
        'MapDataPipeMethods': map_method_definitions,
    }
    fm = FileManager(install_dir='.', template_dir='.', dry_run=False)
    fm.write_with_template(filename="dataset.pyi",
                           template_fn="dataset.pyi.in",
                           env_callable=lambda: env)
# Example no. 5
# 0
def gen_variable_factories(out: str, native_yaml_path: str, template_path: str) -> None:
    """Render variable_factories.h from the template of the same name.

    ``function_definitions`` holds the results of ``process_function`` over the
    parsed native functions, as filtered by ``mapMaybe`` (presumably dropping
    functions for which no definition is produced — confirm in its helper).
    """
    template = 'variable_factories.h'
    native_functions = parse_native_yaml(native_yaml_path)
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def make_env():
        return {
            'generated_comment': '@' + f'generated from {fm.template_dir}/{template}',
            'function_definitions': list(mapMaybe(process_function, native_functions)),
        }

    fm.write_with_template(template, template, make_env)
# Example no. 6
# 0
def gen_trace_type_shard(
    fm: FileManager, native_functions: Sequence[NativeFunction], suffix: str
) -> None:
    """Render TraceType{suffix}.cpp from the TraceType.cpp template.

    The environment carries the tracing method definitions and the wrapper
    registrations produced (via ``mapMaybe``) for ``native_functions``.
    """
    template = 'TraceType.cpp'

    def make_env():
        definitions = list(mapMaybe(method_definition, native_functions))
        registrations = list(mapMaybe(method_registration, native_functions))
        return {
            'generated_comment': f'@generated from {fm.template_dir}/{template}',
            'trace_method_definitions': definitions,
            'trace_wrapper_registrations': registrations,
        }

    fm.write_with_template(f'TraceType{suffix}.cpp', template, make_env)
# Example no. 7
# 0
def gen_variable_type_shard(
    fm: FileManager,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_name: str,
    output_name: str,
) -> None:
    """Render one VariableType shard from ``template_name`` into ``output_name``.

    Every function contributes a method declaration. A method definition and a
    wrapper registration are added only when the op is not in MANUAL_AUTOGRAD
    and its dispatch strategy is 'use_derived'. Two codegen invariants are
    checked with ``assert`` along the way (see the comments at each site).
    """
    type_declarations: List[str] = []
    type_definitions: List[str] = []
    wrapper_registrations: List[str] = []

    for fn_info in fns_with_infos:
        f = fn_info.func
        name = cpp.name(f.func)
        formals = gen_formals(f)

        declaration = METHOD_DECLARATION.substitute(
            return_type=cpp.returns_type(f.func.returns),
            type_wrapper_name=type_wrapper_name(f),
            formals=formals,
        )
        type_declarations.append(declaration)

        # dispatch_strategy is only consulted for ops outside MANUAL_AUTOGRAD.
        emits_definition = (name not in MANUAL_AUTOGRAD
                            and dispatch_strategy(fn_info) == 'use_derived')
        if emits_definition:
            definition = METHOD_DEFINITION.substitute(
                return_type=cpp.returns_type(f.func.returns),
                type_wrapper_name=type_wrapper_name(f),
                type_definition_body=emit_body(fn_info),
                formals=formals,
            )
            type_definitions.append(definition)
            wrapper_registrations.append(gen_wrapper_registration(f))

        # See Note [Manual Backend kernels]
        assert (name in MANUAL_BACKEND) == f.manual_kernel_registration

        # Registering a kernel to Autograd requires the op to be abstract,
        # i.e. to have a dispatch section in native_functions.yaml.
        has_formula = bool(fn_info.info and fn_info.info.has_derivatives)
        if name in MANUAL_AUTOGRAD_AND_TRACER or has_formula:
            msg = (
                f"There's a formula for {name}(or its functional variant) in derivatives.yaml. "
                "It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA "
                "or DefaultBackend in native_functions.yaml. Please see "
                "https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword "
                "for instructions to choose the right dispatch keyword.")
            assert f.is_abstract, msg

    fm.write_with_template(output_name, template_name, lambda: {
        'generated_comment': '@' + f'generated from {fm.template_dir}/{template_name}',
        'type_derived_method_declarations': type_declarations,
        'type_derived_method_definitions': type_definitions,
        'wrapper_registrations': wrapper_registrations,
    })
# Example no. 8
# 0
def gen_inplace_or_view_type_shard(
    fm: FileManager, fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo], suffix: str
) -> None:
    """Render ADInplaceOrViewType{suffix}.cpp from the ADInplaceOrViewType.cpp template.

    Only functions accepted by ``use_derived`` contribute. Each one feeds the
    method-definition and wrapper-registration generators through ``mapMaybe``.
    """
    template = 'ADInplaceOrViewType.cpp'
    selected = [fn for fn in fns_with_infos if use_derived(fn)]

    def make_env():
        definitions = list(mapMaybe(inplace_or_view_method_definition, selected))
        registrations = list(mapMaybe(inplace_or_view_method_registration, selected))
        return {
            'generated_comment': f'@generated from {fm.template_dir}/{template}',
            'inplace_or_view_method_definitions': definitions,
            'inplace_or_view_wrapper_registrations': registrations,
        }

    fm.write_with_template(f'ADInplaceOrViewType{suffix}.cpp', template, make_env)
# Example no. 9
# 0
def main() -> None:
    """
    .pyi generation for functional DataPipes.

    Steps:
      1. Find the files to process, skipping excluded and deprecated ones.
      2. Parse each functional method name and its signature.
      3. Drop the first argument after self (unless it is "*datapipes"),
         default args, and spaces (done by the parsing helper).
      4. Inject the resulting definitions into template dataset.pyi.in.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
          interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    deprecated_files = {
        "httpreader.py", "linereader.py", "tararchivereader.py",
        "ziparchivereader.py"
    }
    excluded = {"__init__.py", "utils.py"} | deprecated_files

    # Paths below are relative to this script's directory.
    os.chdir(str(pathlib.Path(__file__).parent.resolve()))
    file_paths = find_file_paths(["datapipes/iter"], files_to_exclude=excluded)
    methods_and_signatures, methods_and_class_names = parse_datapipe_files(file_paths)

    def render(method_name, signature):
        owner = methods_and_class_names[method_name]
        return f"# Functional form of '{owner}'\ndef {method_name}({signature}): ..."

    # Order by the second line ("def <name>(...)"), i.e. by method name.
    method_definitions = sorted(
        (render(method_name, signature)
         for method_name, signature in methods_and_signatures.items()),
        key=lambda s: s.split('\n')[1])

    fm = FileManager(install_dir='.', template_dir='.', dry_run=False)
    fm.write_with_template(
        filename="dataset.pyi",
        template_fn="dataset.pyi.in",
        env_callable=lambda: {'IterableDataPipeMethods': method_definitions})
# Example no. 10
# 0
def _record_namedtuple(namedtuples: Dict[str, str], group) -> None:
    """Add the NamedTuple return type of ``group`` (if any) to ``namedtuples``.

    Groups whose signature is marked deprecated are skipped. A tuple name that
    is already present must map to an identical definition (asserted).
    """
    named_tuple = group.signature.returns.named_tuple_pyi()
    if named_tuple is None or group.signature.deprecated:
        return
    tuple_name, tuple_def = named_tuple
    if tuple_name in namedtuples:
        assert namedtuples[tuple_name] == tuple_def
    else:
        namedtuples[tuple_name] = tuple_def


def _flatten_hints(unsorted_hints: Dict[str, List[str]]) -> List[str]:
    """Concatenate hints in name order, prefixing '@overload' when a name has several."""
    flat_hints: List[str] = []
    for _, hints in sorted(unsorted_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        flat_hints += hints
    return flat_hints


def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str,
            fm: FileManager) -> None:
    """gen_pyi()

    Generate the torch type stubs: torch/_C/__init__.pyi,
    torch/_C/_VariableFunctions.pyi and torch/_VF.pyi (the latter two from the
    same template), then the nn.functional stubs via ``gen_nn_functional``.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates our custom format for argument
    # checking.  If you update this, consider whether your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str,
                                  List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal':
        ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'numel': ['def numel(self: Tensor) -> _int: ...'],
        'clamp': [
            "def clamp(self, min: _float=-inf, max: _float=inf,"
            " *, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
        # dtype defaults to None, so it must be spelled Optional explicitly
        # (implicit Optional is rejected by current mypy).
        'as_tensor': [
            "def as_tensor(data: Any, dtype: Optional[_dtype]=None, device: Optional[_device]=None) -> Tensor: ..."
        ],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'init_num_threads': ['def init_num_threads() -> None: ...'],
        'get_num_interop_threads':
        ['def get_num_interop_threads() -> _int: ...'],
        'set_num_interop_threads':
        ['def set_num_interop_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor':
        ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': [
            'def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
            ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
            ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'
        ],
        '_sparse_coo_tensor_unsafe': [
            'def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
            ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
            ' requires_grad: bool = False) -> Tensor: ...'
        ],
        'range': [
            'def range(start: Number, end: Number,'
            ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
            .format(FACTORY_PARAMS)
        ],
        'arange': [
            'def arange(start: Number, end: Number, step: Number, *,'
            ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS),
            'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
            .format(FACTORY_PARAMS),
            'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
            .format(FACTORY_PARAMS)
        ],
        'randint': [
            'def randint(low: _int, high: _int, size: _size, *,'
            ' generator: Optional[Generator]=None, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS), 'def randint(high: _int, size: _size, *,'
            ' generator: Optional[Generator]=None, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS)
        ],
        'full': [
            'def full(size: _size, fill_value: Number, *,'
            ' out: Optional[Tensor]=None,'
            ' layout: _layout=strided, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS), 'def full(size: _size, fill_value: Number, *,'
            ' names: List[Union[str, None]],'
            ' layout: _layout=strided, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS)
        ],
        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
        'nonzero': [
            'def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
            'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'
        ],
    })
    for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    for binop in ['add', 'sub']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'
            .format(binop))

    function_signatures = load_signatures(native_yaml_path,
                                          deprecated_yaml_path,
                                          method=False,
                                          pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)
        # deprecated namedtuples are currently not included for torch functions
        _record_namedtuple(namedtuples, group)

    function_hints = _flatten_hints(unsorted_function_hints)

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[
        str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        'size':
        ['def size(self) -> Size: ...', 'def size(self, _int) -> _int: ...'],
        # stride() returns one entry per dimension, so the tuple is variadic.
        'stride': [
            'def stride(self) -> Tuple[_int, ...]: ...',
            'def stride(self, _int) -> _int: ...'
        ],
        'new_ones': [
            'def new_ones(self, size: _size, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS)
        ],
        'new_tensor': [
            "def new_tensor(self, data: Any, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        # new and __init__ have the same signatures differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': [
            'def new(self, *args: Any, {}) ->Tensor: ...'.format(DEVICE_PARAM),
            'def new(self, storage: Storage) -> Tensor: ...',
            'def new(self, other: Tensor) -> Tensor: ...',
            'def new(self, size: _size, *, {}) -> Tensor: ...'.format(
                DEVICE_PARAM),
        ],
        '__init__': [
            'def __init__(self, *args: Any, {}) -> None: ...'.format(
                DEVICE_PARAM),
            'def __init__(self, storage: Storage) -> None: ...',
            'def __init__(self, other: Tensor) -> None: ...',
            'def __init__(self, size: _size, *, {}) -> None: ...'.format(
                DEVICE_PARAM),
        ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        # clamp has no default values in the Declarations
        'clamp': [
            "def clamp(self, min: _float=-inf, max: _float=inf,"
            " *, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
        'clamp_':
        ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
        '__getitem__':
        ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': [
            "def __setitem__(self, {}, val: Union[Tensor, Number])"
            " -> None: ...".format(INDICES)
        ],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_':
        ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero':
        ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
        'numel': ['def numel(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': [
            'def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'
        ],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_':
        ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'storage': ['def storage(self) -> Storage: ...'],
        'type': [
            'def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
            'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
        ],
        'get_device': ['def get_device(self) -> _int: ...'],
        'contiguous': [
            'def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'
        ],
        'is_contiguous': [
            'def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'
        ],
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': [
            'def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
            'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
            'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
            'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
        ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': [
            "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
        ],
        'set_': [
            'def set_(self, storage: Storage, offset: _int, size: _size, stride: _size) -> Tensor: ...',
            'def set_(self, storage: Storage) -> Tensor: ...'
        ],
        'split': [
            'def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
            'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'
        ],
    })
    # Out-of-place variants accept out=; the in-place ('_'-suffixed) ones do not.
    for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
        for inplace in [False, True]:
            method_name = binop + '_' if inplace else binop
            out_suffix = '' if inplace else ', *, out: Optional[Tensor]=None'
            unsorted_tensor_method_hints[method_name].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(method_name, out_suffix))
    for binop in ['add', 'sub']:
        for inplace in [False, True]:
            method_name = binop + '_' if inplace else binop
            out_suffix = '' if inplace else ', out: Optional[Tensor]=None'
            unsorted_tensor_method_hints[method_name].append(
                'def {}(self, other: Union[Tensor, Number], '
                '*, alpha: Optional[Number]=1{})'
                ' -> Tensor: ...'.format(method_name, out_suffix))
    simple_conversions = [
        'byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long',
        'short', 'bool', 'bfloat16'
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(
            'def {}(self) -> Tensor: ...'.format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_yaml_path,
                                               deprecated_yaml_path,
                                               method=True,
                                               skip_deprecated=True,
                                               pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures,
                                                      method=True)

    for group in sorted(tensor_method_sig_groups,
                        key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)
        # deprecated namedtuples are skipped here as well
        _record_namedtuple(namedtuples, group)

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = _flatten_hints(unsorted_tensor_method_hints)

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = [
        '{} = {}'.format(name, defn) for name, defn in namedtuples.items()
    ]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_storage_base_hints = []
    dt = ('Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Bool',
          'Half', 'BFloat16', 'ComplexDouble', 'ComplexFloat', 'QUInt8',
          'QInt8', 'QInt32', 'QUInt4x2')
    for c in dt:
        legacy_storage_base_hints.append(
            'class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append(
            'class Cuda{}StorageBase(object): ...'.format(c))

    legacy_class_hints = []
    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor',
              'BoolTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        '{}: dtype = ...'.format(n) for n in [
            'float32', 'float', 'float64', 'double', 'float16', 'bfloat16',
            'half', 'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64',
            'long', 'complex32', 'complex64', 'cfloat', 'complex128',
            'cdouble', 'quint8', 'qint8', 'qint32', 'bool', 'quint4x2'
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that contain hints, to prevent undefined
    # symbols to be included in the `__all__` directive.
    hinted_function_names = [
        name for name, hint in unsorted_function_hints.items() if hint
    ]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
    all_directive[0] = '__all__ = {}'.format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'namedtuple_defs': namedtuple_defs,
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'legacy_storage_base_hints': legacy_storage_base_hints,
        'dtype_class_hints': dtype_class_hints,
        'all_directive': all_directive
    }
    fm.write_with_template(
        'torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
            'generated_comment': '@' +
            'generated from torch/_C/__init__.pyi.in',
            **env,
        })
    fm.write_with_template(
        'torch/_C/_VariableFunctions.pyi',
        'torch/_C/_VariableFunctions.pyi.in', lambda: {
            'generated_comment': '@' +
            'generated from torch/_C/_VariableFunctions.pyi.in',
            **env,
        })
    fm.write_with_template(
        'torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
            'generated_comment': '@' +
            'generated from torch/_C/_VariableFunctions.pyi.in',
            **env,
        })
    gen_nn_functional(fm)
# Example no. 11
# 0
def gen_nn_functional(fm: FileManager) -> None:
    """Write the stubs torch/nn/functional.pyi and torch/_C/_nn.pyi.

    Both stubs get the same ``imported_hints`` (re-exports from ``torch``).
    Their ``dispatched_hints`` differ: torch/_C/_nn.pyi additionally declares
    ``hardtanh``, ``leaky_relu`` and ``hardsigmoid``, which functional.pyi
    already defines itself.

    Each env callable closes over its own, never-rebound list, so the output
    does not depend on when ``fm`` evaluates the callable.
    """
    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    imports = [
        'conv1d',
        'conv2d',
        'conv3d',
        'conv_transpose1d',
        'conv_transpose2d',
        'conv_transpose3d',
        'conv_tbc',
        'avg_pool1d',
        'relu_',
        'selu_',
        'celu_',
        'rrelu_',
        'pixel_shuffle',
        'pixel_unshuffle',
        'channel_shuffle',
        'pdist',
        'cosine_similarity',
    ]
    # Functions generated by `torch._jit_internal.boolean_dispatch`
    dispatches = [
        'fractional_max_pool2d',
        'fractional_max_pool3d',
        'max_pool1d',
        'max_pool2d',
        'max_pool3d',
        'adaptive_max_pool1d',
        'adaptive_max_pool2d',
        'adaptive_max_pool3d',
    ]
    # Functions directly imported from `torch._C`
    from_c = [
        'avg_pool2d',
        'avg_pool3d',
        'hardtanh_',
        'elu_',
        'leaky_relu_',
        'logsigmoid',
        'softplus',
        'softshrink',
        'one_hot',
    ]
    # functional.pyi already contains the definitions for these functions,
    # so they are declared only in _nn.pyi.
    nn_only = ['hardtanh', 'leaky_relu', 'hardsigmoid']

    import_code = [f"from .. import {fn_name} as {fn_name}" for fn_name in imports]
    # TODO make these types more precise
    functional_dispatch_code = [f"{fn_name}: Callable" for fn_name in dispatches + from_c]
    nn_dispatch_code = [f"{fn_name}: Callable" for fn_name in dispatches + from_c + nn_only]

    fm.write_with_template(
        'torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: {
            'imported_hints': import_code,
            'dispatched_hints': functional_dispatch_code,
        })
    fm.write_with_template(
        'torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: {
            'imported_hints': import_code,
            'dispatched_hints': nn_dispatch_code,
        })
# Example no. 12
# 0
def gen_external(native_functions_path, external_path):
    """Generate external_functions_codegen.cpp from the ``external_path`` template.

    A native function is emitted when all of these hold:
      - its schema is an out= function (``schema.is_out_fn()``);
      - it has at most one out argument;
      - it has no keyword-only arguments;
      - every positional argument (including self) is a plain Tensor.

    For each emitted function, two things are generated:
      - an ``nnc_aten_<name>`` C++ wrapper that rebuilds the tensors from the
        raw buffers (slot 0 is the result, slots 1.. are the inputs) and calls
        ``at::<name>_out``;
      - a ``RegisterNNCExternalFunction`` registration.

    NOTE(review): the generated wrapper catches and discards every C++
    exception from the ATen call; confirm that this is intended.
    """
    native_functions = parse_native_yaml(native_functions_path)
    func_decls = []
    func_registrations = []
    for func in native_functions:
        schema = func.func
        sig_args = schema.arguments
        # Only out= variants with a single out parameter are supported.
        if not schema.is_out_fn() or len(sig_args.out) > 1:
            continue
        # Keyword-only arguments are not supported.
        if (len(sig_args.pre_tensor_options_kwarg_only) > 0
                or len(sig_args.post_tensor_options_kwarg_only) > 0):
            continue

        self_arg = [] if sig_args.self_arg is None else [sig_args.self_arg.argument]
        positional = [*sig_args.pre_self_positional, *self_arg,
                      *sig_args.post_self_positional]
        # Every positional argument must be a plain Tensor.
        if not all(isinstance(arg.type, model.BaseType)
                   and arg.type.name == model.BaseTy.Tensor
                   for arg in positional):
            continue

        name = schema.name.name.base
        tensor_decls = '\n'.join(
            f"const at::Tensor& {arg.name} = tensors[{slot}];"
            for slot, arg in enumerate(positional, start=1))
        call_args = ', '.join(['r'] + [arg.name for arg in positional])

        func_decls.append(f"""\
void nnc_aten_{name}(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {{
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
  at::Tensor& r = tensors[0];
  {tensor_decls}
  try {{
    at::{name}_out({call_args});
  }} catch (...) {{
  }}
}}""")
        func_registrations.append(f"""\
const static RegisterNNCExternalFunction nnc_{name}(
    "nnc_aten_{name}",
    nnc_aten_{name});""")

    fm = FileManager(install_dir='.', template_dir='.', dry_run=False)
    fm.write_with_template(
        'external_functions_codegen.cpp', external_path, lambda: {
            'external_registrations': func_registrations,
            'external_functions': func_decls
        })