Пример #1
0
def main() -> None:
    """Command-line entry point for the PyTorch stub generator.

    Parses the input/output path options and hands them to ``gen_pyi``
    together with a FileManager rooted at ``--out`` (templates are looked up
    relative to the current directory).
    """
    parser = argparse.ArgumentParser(description='Generate type stubs for PyTorch')

    # (flag, metavar, default, help) for every path option, in declaration order.
    path_options = (
        ('--native-functions-path', 'NATIVE',
         'aten/src/ATen/native/native_functions.yaml',
         'path to native_functions.yaml'),
        ('--deprecated-functions-path', 'DEPRECATED',
         'tools/autograd/deprecated.yaml',
         'path to deprecated.yaml'),
        ('--out', 'OUT', '.', 'path to output directory'),
    )
    for flag, metavar, default, help_text in path_options:
        parser.add_argument(flag, metavar=metavar, default=default, help=help_text)

    options = parser.parse_args()
    file_manager = FileManager(install_dir=options.out, template_dir='.', dry_run=False)
    gen_pyi(options.native_functions_path, options.deprecated_functions_path, file_manager)
Пример #2
0
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
    operator_selector: SelectiveBuilder,
) -> None:
    """Write VariableType.h plus the sharded VariableType.cpp bodies.

    VariableType is the at::Type subclass for differentiable tensors: each
    generated function forwards to the base tensor type to compute its output
    and attaches a grad_fn when the function is differentiable.

    Only native functions that ``operator_selector`` selects for training are
    emitted. They are ordered by C++ name, paired with their differentiability
    info, and written to VariableType.h, to five hash-bucketed
    VariableType_<i>.cpp files, and to VariableTypeEverything.cpp (all of them).
    """
    selected = [
        native_fn for native_fn in parse_native_yaml(native_yaml_path)
        if operator_selector.is_native_function_selected_for_training(native_fn)
    ]
    selected.sort(key=lambda native_fn: cpp.name(native_fn.func))
    fns_with_infos = match_differentiability_info(selected, differentiability_infos)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.h', 'VariableType.h')

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [
        [] for _ in range(num_shards)
    ]

    # Bucket by the code-point sum of the C++ name: arbitrary, but stable from
    # run to run, so a given function always lands in the same file.
    for entry in fns_with_infos:
        bucket = sum(map(ord, cpp.name(entry.func.func))) % num_shards
        shards[bucket].append(entry)

    for shard_index, shard in enumerate(shards):
        gen_variable_type_shard(fm, shard, 'VariableType.cpp',
                                f'VariableType_{shard_index}.cpp')

    gen_variable_type_shard(fm, fns_with_infos, 'VariableType.cpp',
                            'VariableTypeEverything.cpp')
Пример #3
0
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
    """Write the C++ sources that bind ATen operators into Python.

    Tensor methods go to python_variable_methods.cpp. Free functions are
    partitioned by namespace predicate into python_torch_functions.cpp,
    python_nn_functions.cpp, python_fft_functions.cpp and
    python_linalg_functions.cpp.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    method_sigs = load_signatures(native_yaml_path, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm, method_sigs, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)

    function_sigs = load_signatures(native_yaml_path, deprecated_yaml_path, method=False)

    # (selection predicate, Python module, output file), emitted in this order.
    namespace_targets = (
        (is_py_torch_function, 'torch', 'python_torch_functions.cpp'),
        (is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp'),
        (is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp'),
        (is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp'),
    )
    for predicate, module_name, filename in namespace_targets:
        create_python_bindings(
            fm, function_sigs, predicate, module_name, filename, method=False)
Пример #4
0
def gen_inplace_or_view_type(
    out: str,
    native_yaml_path: str,
    fns_with_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str
) -> None:
    """Write the in-place/view type sources in two stable shards.

    Each function is assigned to shard ``_0`` or ``_1`` by a hash of its C++
    name, and an ``Everything`` variant containing all functions is written
    as well. ``native_yaml_path`` is not read by this body; it is kept for
    interface compatibility.
    """
    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 2
    shards: List[List[NativeFunctionWithDifferentiabilityInfo]] = [[] for _ in range(num_shards)]

    # Code-point sum of the C++ name: arbitrary but stable bucketing.
    for entry in fns_with_infos:
        shards[sum(map(ord, cpp.name(entry.func.func))) % num_shards].append(entry)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for shard_index, shard in enumerate(shards):
        gen_inplace_or_view_type_shard(fm, shard, f'_{shard_index}')
    gen_inplace_or_view_type_shard(fm, fns_with_infos, 'Everything')
Пример #5
0
def gen_trace_type(out: str, native_yaml_path: str,
                   template_path: str) -> None:
    """Write the trace-type sources for every native function.

    Functions are ordered by C++ name, split into five hash-based shards
    (suffixes ``_0`` .. ``_4``), and additionally written together to an
    ``Everything`` variant.
    """
    native_functions = sorted(parse_native_yaml(native_yaml_path),
                              key=lambda native_fn: cpp.name(native_fn.func))

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    num_shards = 5
    shards: List[List[NativeFunction]] = [[] for _ in range(num_shards)]

    # Code-point sum of the C++ name: arbitrary, but stable across runs.
    for native_fn in native_functions:
        shards[sum(map(ord, cpp.name(native_fn.func))) % num_shards].append(native_fn)

    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for shard_index, shard in enumerate(shards):
        gen_trace_type_shard(fm, shard, f'_{shard_index}')
    gen_trace_type_shard(fm, native_functions, 'Everything')
def gen_autograd_functions_python(
    out: str,
    differentiability_infos: Sequence[DifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write python_functions.h and the sharded python_functions.cpp.

    The header forward-declares and calls one
    ``initialize_autogenerated_functions_<i>()`` per shard. The .cpp shards
    receive per-function initializers and property/getter definitions for
    every info that has at least one argument with a derivative.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    num_shards = 5
    shard_ids = range(num_shards)

    def header_env():
        return {
            'generated_comment':
            f'@generated from {fm.template_dir}/python_functions.h',
            'shard_forward_declare': [
                f"void initialize_autogenerated_functions_{i}();" for i in shard_ids
            ],
            'shard_call': [
                f"initialize_autogenerated_functions_{i}();" for i in shard_ids
            ],
        }

    fm.write('python_functions.h', header_env)

    # Infos without any differentiable argument produce no Python function.
    infos = [info for info in differentiability_infos if info.args_with_derivatives]

    def function_env(info):
        return {
            'py_function_initializers':
            [process_function(info, PY_FUNCTION_DEFINITION)],
            'py_function_props_and_getters':
            [process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)],
        }

    fm.write_sharded(
        'python_functions.cpp',
        infos,
        key_fn=lambda info: info.name,
        base_env={
            'generated_comment':
            f'@generated from {fm.template_dir}/python_functions.cpp',
        },
        env_callable=function_env,
        num_shards=num_shards,
        sharded_keys={'py_function_initializers', 'py_function_props_and_getters'},
    )
Пример #7
0
def gen_variable_type(
    out: str,
    native_yaml_path: str,
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write VariableType.h and the five-way sharded VariableType.cpp.

    VariableType is the at::Type subclass for differentiable tensors: each
    generated function forwards to the base tensor type to compute its output
    and attaches a grad_fn when the function is differentiable. Only
    functions for which ``use_derived`` holds are written to the .cpp shards,
    which are keyed by C++ name.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    # "@" is kept separate from "generated", presumably so this generator
    # source is not itself detected as a generated file.
    header_comment = "@" f'generated from {template_path}/VariableType.h'
    fm.write('VariableType.h', lambda: {'generated_comment': header_comment})

    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
    # template regarding sharding of the generated files.
    derived_fns = [entry for entry in fns_with_diff_infos if use_derived(entry)]
    fm.write_sharded(
        'VariableType.cpp',
        derived_fns,
        key_fn=lambda entry: cpp.name(entry.func.func),
        base_env={
            'generated_comment':
            "@" f'generated from {template_path}/VariableType.cpp',
        },
        env_callable=gen_variable_type_func,
        num_shards=5,
        sharded_keys={'type_derived_method_definitions', 'wrapper_registrations'},
    )
Пример #8
0
def _register_namedtuple(namedtuples: Dict[str, str], group) -> None:
    """Record the NamedTuple return type of a signature group, if it has one.

    Groups whose signature is deprecated are skipped. A tuple name that was
    already recorded must map to the identical definition (asserted).
    """
    named_tuple = group.signature.returns.named_tuple_pyi()
    if named_tuple is None or group.signature.deprecated:
        # deprecated namedtuples are currently not included in the stubs
        return
    tuple_name, tuple_def = named_tuple
    if tuple_name in namedtuples:
        assert namedtuples[tuple_name] == tuple_def
    else:
        namedtuples[tuple_name] = tuple_def


def _flatten_overloaded_hints(unsorted_hints: Dict[str, List[str]]) -> List[str]:
    """Flatten a name -> hints mapping into one list ordered by name.

    When a name has more than one hint, each of its hints is prefixed with
    ``@overload`` so the stub declares them as overloads.
    """
    flattened: List[str] = []
    for _, hints in sorted(unsorted_hints.items()):
        if len(hints) > 1:
            hints = ['@overload\n' + h for h in hints]
        flattened += hints
    return flattened


def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str,
            fm: FileManager) -> None:
    """Generate the type stubs (.pyi) for torch.

    Builds hints for top-level functions, Tensor methods, NamedTuple return
    types, legacy storage/tensor classes and dtypes. Renders them through
    ``fm`` into torch/_C/__init__.pyi, torch/_C/_VariableFunctions.pyi and
    torch/_VF.pyi (the last two from the same template), then calls
    gen_nn_functional for the torch.nn.functional stubs.

    Args:
        native_yaml_path: path to native_functions.yaml.
        deprecated_yaml_path: path to deprecated.yaml.
        fm: FileManager used to render the templates and write the outputs.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates our custom format for argument
    # checking.  If you update this, consider if your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str,
                                  List[str]] = collections.defaultdict(list)
    unsorted_function_hints.update({
        'set_flush_denormal':
        ['def set_flush_denormal(mode: _bool) -> _bool: ...'],
        'get_default_dtype': ['def get_default_dtype() -> _dtype: ...'],
        'from_numpy': ['def from_numpy(ndarray) -> Tensor: ...'],
        'numel': ['def numel(self: Tensor) -> _int: ...'],
        'clamp': [
            "def clamp(self, min: _float=-inf, max: _float=inf,"
            " *, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
        # dtype defaults to None, so it has to be spelled Optional: mypy no
        # longer treats `x: T=None` as implicitly Optional.
        'as_tensor': [
            "def as_tensor(data: Any, dtype: Optional[_dtype]=None, device: Optional[_device]=None) -> Tensor: ..."
        ],
        'get_num_threads': ['def get_num_threads() -> _int: ...'],
        'set_num_threads': ['def set_num_threads(num: _int) -> None: ...'],
        'init_num_threads': ['def init_num_threads() -> None: ...'],
        'get_num_interop_threads':
        ['def get_num_interop_threads() -> _int: ...'],
        'set_num_interop_threads':
        ['def set_num_interop_threads(num: _int) -> None: ...'],
        # These functions are explicitly disabled by
        # SKIP_PYTHON_BINDINGS because they are hand bound.
        # Correspondingly, we must hand-write their signatures.
        'tensor':
        ["def tensor(data: Any, {}) -> Tensor: ...".format(FACTORY_PARAMS)],
        'sparse_coo_tensor': [
            'def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
            ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
            ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'
        ],
        '_sparse_coo_tensor_unsafe': [
            'def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
            ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
            ' requires_grad: bool = False) -> Tensor: ...'
        ],
        'range': [
            'def range(start: Number, end: Number,'
            ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
            .format(FACTORY_PARAMS)
        ],
        'arange': [
            'def arange(start: Number, end: Number, step: Number, *,'
            ' out: Optional[Tensor]=None, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS),
            'def arange(start: Number, end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
            .format(FACTORY_PARAMS),
            'def arange(end: Number, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
            .format(FACTORY_PARAMS)
        ],
        'randint': [
            'def randint(low: _int, high: _int, size: _size, *,'
            ' generator: Optional[Generator]=None, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS),
            'def randint(high: _int, size: _size, *,'
            ' generator: Optional[Generator]=None, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS)
        ],
        'full': [
            'def full(size: _size, fill_value: Number, *,'
            ' out: Optional[Tensor]=None,'
            ' layout: _layout=strided, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS),
            'def full(size: _size, fill_value: Number, *,'
            ' names: List[Union[str, None]],'
            ' layout: _layout=strided, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS)
        ],
        'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'],
        'nonzero': [
            'def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...',
            'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'
        ],
    })
    for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, out: Optional[Tensor]=None) -> Tensor: ...'.format(binop))
    for binop in ['add', 'sub']:
        unsorted_function_hints[binop].append(
            'def {}(input: Union[Tensor, Number],'
            ' other: Union[Tensor, Number],'
            ' *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...'
            .format(binop))

    function_signatures = load_signatures(native_yaml_path,
                                          deprecated_yaml_path,
                                          method=False,
                                          pyi=True)
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)
        _register_namedtuple(namedtuples, group)

    function_hints = _flatten_overloaded_hints(unsorted_function_hints)

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[
        str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update({
        # The dimension argument needs a real name and type: a bare `_int`
        # would declare an untyped parameter literally called "_int".
        # `Tuple[_int]` denotes a 1-tuple; a stride tuple's length varies with
        # the tensor, hence `Tuple[_int, ...]`.
        'size':
        ['def size(self) -> Size: ...', 'def size(self, dim: _int) -> _int: ...'],
        'stride': [
            'def stride(self) -> Tuple[_int, ...]: ...',
            'def stride(self, dim: _int) -> _int: ...'
        ],
        'new_ones': [
            'def new_ones(self, size: _size, {}) -> Tensor: ...'.format(
                FACTORY_PARAMS)
        ],
        'new_tensor': [
            "def new_tensor(self, data: Any, {}) -> Tensor: ...".format(
                FACTORY_PARAMS)
        ],
        # new and __init__ have the same signatures differ only in return type
        # Adapted from legacy_tensor_ctor and legacy_tensor_new
        'new': [
            'def new(self, *args: Any, {}) -> Tensor: ...'.format(DEVICE_PARAM),
            'def new(self, storage: Storage) -> Tensor: ...',
            'def new(self, other: Tensor) -> Tensor: ...',
            'def new(self, size: _size, *, {}) -> Tensor: ...'.format(
                DEVICE_PARAM),
        ],
        '__init__': [
            'def __init__(self, *args: Any, {}) -> None: ...'.format(
                DEVICE_PARAM),
            'def __init__(self, storage: Storage) -> None: ...',
            'def __init__(self, other: Tensor) -> None: ...',
            'def __init__(self, size: _size, *, {}) -> None: ...'.format(
                DEVICE_PARAM),
        ],
        'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
        # clamp has no default values in the Declarations
        'clamp': [
            "def clamp(self, min: _float=-inf, max: _float=inf,"
            " *, out: Optional[Tensor]=None) -> Tensor: ..."
        ],
        'clamp_':
        ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."],
        '__getitem__':
        ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)],
        '__setitem__': [
            "def __setitem__(self, {}, val: Union[Tensor, Number])"
            " -> None: ...".format(INDICES)
        ],
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_':
        ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
        'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero':
        ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
        'numel': ['def numel(self) -> _int: ...'],
        'ndimension': ['def ndimension(self) -> _int: ...'],
        'nelement': ['def nelement(self) -> _int: ...'],
        'cuda': [
            'def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...'
        ],
        'numpy': ['def numpy(self) -> Any: ...'],
        'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'],
        'map_':
        ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'],
        'storage': ['def storage(self) -> Storage: ...'],
        'type': [
            'def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...',
            'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...',
        ],
        'get_device': ['def get_device(self) -> _int: ...'],
        'contiguous': [
            'def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'
        ],
        'is_contiguous': [
            'def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'
        ],
        'is_cuda': ['is_cuda: _bool'],
        'is_leaf': ['is_leaf: _bool'],
        'is_sparse': ['is_sparse: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],
        'to': [
            'def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
            'def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, '
            'non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
            'def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...',
        ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': [
            "def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."
        ],
        'set_': [
            'def set_(self, storage: Storage, offset: _int, size: _size, stride: _size) -> Tensor: ...',
            'def set_(self, storage: Storage) -> Tensor: ...'
        ],
        'split': [
            'def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
            'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'
        ],
    })
    # Each binop gets an out-of-place variant (accepting out=) followed by its
    # trailing-underscore in-place variant (no out=).
    for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
        for method_name, out_suffix in ((binop, ', *, out: Optional[Tensor]=None'),
                                        (binop + '_', '')):
            unsorted_tensor_method_hints[method_name].append(
                'def {}(self, other: Union[Tensor, Number]{})'
                ' -> Tensor: ...'.format(method_name, out_suffix))
    for binop in ['add', 'sub']:
        for method_name, out_suffix in ((binop, ', out: Optional[Tensor]=None'),
                                        (binop + '_', '')):
            unsorted_tensor_method_hints[method_name].append(
                'def {}(self, other: Union[Tensor, Number], '
                '*, alpha: Optional[Number]=1{})'
                ' -> Tensor: ...'.format(method_name, out_suffix))
    simple_conversions = [
        'byte', 'char', 'cpu', 'double', 'float', 'half', 'int', 'long',
        'short', 'bool', 'bfloat16'
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(
            'def {}(self) -> Tensor: ...'.format(name))

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(native_yaml_path,
                                               deprecated_yaml_path,
                                               method=True,
                                               skip_deprecated=True,
                                               pyi=True)
    tensor_method_sig_groups = get_py_torch_functions(tensor_method_signatures,
                                                      method=True)

    for group in sorted(tensor_method_sig_groups,
                        key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)
        _register_namedtuple(namedtuples, group)

    for op in all_ops:
        name = '__{}__'.format(op)
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = _flatten_overloaded_hints(unsorted_tensor_method_hints)

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = [
        '{} = {}'.format(name, defn) for name, defn in namedtuples.items()
    ]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: These are deprecated, maybe we shouldn't type hint them
    legacy_storage_base_hints = []
    dt = ('Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Bool',
          'Half', 'BFloat16', 'ComplexDouble', 'ComplexFloat', 'QUInt8',
          'QInt8', 'QInt32', 'QUInt4x2')
    for c in dt:
        legacy_storage_base_hints.append(
            'class {}StorageBase(object): ...'.format(c))
    for c in dt:
        legacy_storage_base_hints.append(
            'class Cuda{}StorageBase(object): ...'.format(c))

    legacy_class_hints = []
    for c in ('DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
              'ShortTensor', 'HalfTensor', 'CharTensor', 'ByteTensor',
              'BoolTensor'):
        legacy_class_hints.append('class {}(Tensor): ...'.format(c))

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        '{}: dtype = ...'.format(n) for n in [
            'float32', 'float', 'float64', 'double', 'float16', 'bfloat16',
            'half', 'uint8', 'int8', 'int16', 'short', 'int32', 'int', 'int64',
            'long', 'complex32', 'complex64', 'cfloat', 'complex128',
            'cdouble', 'quint8', 'qint8', 'qint32', 'bool', 'quint4x2'
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that contain hints, to prevent undefined
    # symbols to be included in the `__all__` directive.
    hinted_function_names = [
        name for name, hint in unsorted_function_hints.items() if hint
    ]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split('\n')
    all_directive[0] = '__all__ = {}'.format(all_directive[0])

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        'namedtuple_defs': namedtuple_defs,
        'function_hints': function_hints,
        'tensor_method_hints': tensor_method_hints,
        'legacy_class_hints': legacy_class_hints,
        'legacy_storage_base_hints': legacy_storage_base_hints,
        'dtype_class_hints': dtype_class_hints,
        'all_directive': all_directive
    }
    fm.write_with_template(
        'torch/_C/__init__.pyi', 'torch/_C/__init__.pyi.in', lambda: {
            'generated_comment': '@' +
            'generated from torch/_C/__init__.pyi.in',
            **env,
        })
    fm.write_with_template(
        'torch/_C/_VariableFunctions.pyi',
        'torch/_C/_VariableFunctions.pyi.in', lambda: {
            'generated_comment': '@' +
            'generated from torch/_C/_VariableFunctions.pyi.in',
            **env,
        })
    fm.write_with_template(
        'torch/_VF.pyi', 'torch/_C/_VariableFunctions.pyi.in', lambda: {
            'generated_comment': '@' +
            'generated from torch/_C/_VariableFunctions.pyi.in',
            **env,
        })
    gen_nn_functional(fm)
Пример #9
0
def gen_nn_functional(fm: FileManager) -> None:
    """Write torch/nn/functional.pyi and torch/_C/_nn.pyi from their templates.

    Both stubs receive the same ``from .. import X as X`` re-exports. Their
    ``X: Callable`` hints differ: _nn.pyi additionally declares hardtanh,
    leaky_relu and hardsigmoid, which functional.pyi already defines.
    """
    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    imports = [
        'conv1d',
        'conv2d',
        'conv3d',
        'conv_transpose1d',
        'conv_transpose2d',
        'conv_transpose3d',
        'conv_tbc',
        'avg_pool1d',
        'relu_',
        'selu_',
        'celu_',
        'rrelu_',
        'pixel_shuffle',
        'pixel_unshuffle',
        'channel_shuffle',
        'pdist',
        'cosine_similarity',
    ]
    # Functions generated by `torch._jit_internal.boolean_dispatch`
    dispatches = [
        'fractional_max_pool2d',
        'fractional_max_pool3d',
        'max_pool1d',
        'max_pool2d',
        'max_pool3d',
        'adaptive_max_pool1d',
        'adaptive_max_pool2d',
        'adaptive_max_pool3d',
    ]
    # Functions directly imported from `torch._C`
    from_c = [
        'avg_pool2d',
        'avg_pool3d',
        'hardtanh_',
        'elu_',
        'leaky_relu_',
        'logsigmoid',
        'softplus',
        'softshrink',
        'one_hot',
    ]
    import_code = ["from .. import {0} as {0}".format(fn_name) for fn_name in imports]
    # TODO make these types more precise
    functional_dispatch_code = [
        "{}: Callable".format(fn_name) for fn_name in (dispatches + from_c)
    ]
    fm.write_with_template(
        'torch/nn/functional.pyi', 'torch/nn/functional.pyi.in', lambda: {
            'imported_hints': import_code,
            'dispatched_hints': functional_dispatch_code,
        })

    # functional.pyi already contains the definitions for these functions,
    # so we don't export them to it; they are declared only in _nn.pyi.
    # A new list under a new name is built here instead of mutating `from_c`
    # and rebinding the name the lambda above closes over. Otherwise, if the
    # env callable were ever invoked lazily, functional.pyi would silently
    # pick up these extra hints.
    nn_from_c = from_c + ['hardtanh', 'leaky_relu', 'hardsigmoid']
    nn_dispatch_code = [
        "{}: Callable".format(fn_name) for fn_name in (dispatches + nn_from_c)
    ]
    fm.write_with_template(
        'torch/_C/_nn.pyi', 'torch/_C/_nn.pyi.in', lambda: {
            'imported_hints': import_code,
            'dispatched_hints': nn_dispatch_code,
        })
Пример #10
0
 def make_file_manager(install_dir: str) -> FileManager:
     """Build a FileManager writing under ``install_dir``.

     ``template_dir`` and ``dry_run`` are taken from the enclosing scope
     (not visible in this excerpt).
     """
     manager = FileManager(
         install_dir=install_dir,
         template_dir=template_dir,
         dry_run=dry_run,
     )
     return manager
Пример #11
0
def gen_external(native_functions_path: str, external_path: str) -> None:
    """Generate NNC external-call wrappers for ATen ``out=`` functions.

    A native function from ``native_functions_path`` is eligible when it:

    * is an out function;
    * has at most one out argument;
    * has no keyword-only inputs;
    * has only ``Tensor`` positional inputs.

    For each eligible function this emits a C++ wrapper ``nnc_aten_<name>``.
    The wrapper rebuilds tensors from the raw buffers and calls
    ``at::<name>_out``; a ``RegisterNNCExternalFunction`` registration is
    emitted alongside it. Both lists are substituted into the template at
    ``external_path`` and written to ``external_functions_codegen.cpp`` in
    the current directory.

    Wrappers are named by the overload-free base name, so only the first
    eligible overload of a base name is emitted; emitting a second one
    would redefine the same C++ symbols.
    """
    native_functions = parse_native_yaml(native_functions_path)
    func_decls = []
    func_registrations = []
    emitted_names = set()
    # Backslashes are not allowed inside f-string expressions, so the
    # newline separator is bound to a name once, outside the loop.
    nl = '\n'
    for func in native_functions:
        schema = func.func
        name = schema.name.name.base
        arguments = schema.arguments
        # Only supports extern calls for functions with out variants
        if not schema.is_out_fn():
            continue

        # Doesn't currently support functions with more than one out parameter
        if len(arguments.out) > 1:
            continue

        # Doesn't currently support kwarg arguments
        if (len(arguments.pre_tensor_options_kwarg_only) > 0
                or len(arguments.post_tensor_options_kwarg_only) > 0):
            continue

        self_arg = [arguments.self_arg.argument] if arguments.self_arg is not None else []
        positional = (list(arguments.pre_self_positional) + self_arg
                      + list(arguments.post_self_positional))
        # Every positional input must be a plain Tensor: the wrapper can only
        # pull tensors out of the buffer list.
        if not all(
                isinstance(arg.type, model.BaseType)
                and arg.type.name == model.BaseTy.Tensor
                for arg in positional):
            continue

        # Another overload with the same base name was already emitted; a
        # second wrapper would be a C++ redefinition of nnc_aten_<name>.
        if name in emitted_names:
            continue
        emitted_names.add(name)

        arg_names = [arg.name for arg in positional]
        # tensors[0] is bound to the output `r`; inputs follow in positional order.
        tensor_decls = [
            f"const at::Tensor& {arg.name} = tensors[{idx + 1}];"
            for idx, arg in enumerate(positional)
        ]

        # NOTE(review): the generated wrapper swallows every C++ exception
        # thrown by the ATen call (empty catch-all). This looks like a
        # deliberate best-effort; confirm before changing.
        func_decl = f"""\
void nnc_aten_{name}(
    int64_t bufs_num,
    void** buf_data,
    int64_t* buf_ranks,
    int64_t* buf_dims,
    int8_t* buf_dtypes,
    int64_t args_num,
    int64_t* extra_args) {{
  std::vector<at::Tensor> tensors =
      constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_dtypes);
  at::Tensor& r = tensors[0];
  {nl.join(tensor_decls)}
  try {{
    at::{name}_out({', '.join(['r'] + arg_names)});
  }} catch (...) {{
  }}
}}"""
        func_registration = f"""\
const static RegisterNNCExternalFunction nnc_{name}(
    "nnc_aten_{name}",
    nnc_aten_{name});"""
        func_decls.append(func_decl)
        func_registrations.append(func_registration)
    fm = FileManager(install_dir='.', template_dir='.', dry_run=False)
    fm.write_with_template(
        'external_functions_codegen.cpp', external_path, lambda: {
            'external_registrations': func_registrations,
            'external_functions': func_decls
        })