Example #1
def _get_torchscript_builtins():
    functions = []
    builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f'Module for {fn} not found')
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                functions.append(_emit_schema(mod.__name__, fn.__name__, schema))

    return "TorchScript Builtin Functions", functions
Example #2
def _get_math_builtins():
    functions = []
    builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
    builtins = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins:
        mod = inspect.getmodule(fn)
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                schema = _emit_schema(mod.__name__, fn.__name__, schema)
                if 'Tensor' in schema:
                    # Skip Tensor ops that have the same name as math functions
                    # (they will show up in the tensor methods section)
                    continue
                functions.append(schema)

    return "``math`` Module", functions
Example #3
def _get_nn_functional_ops():
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(torch.nn.functional):
        attr = getattr(mod, elem)
        if not inspect.isfunction(attr) or _hidden(elem[0]):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f'Module for {attr} not found')

        if 'torch.nn.functional' not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            schema = scripted.schema
            functions.append(_emit_schema(name, elem, schema))
        except Exception:  # noqa
            # Skip interpolate / boolean dispatched things
            pass

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        functions.append(_emit_schema(name, elem, schema))
    return "Supported PyTorch Functions", functions
Example #4
def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module. This
    ConcreteModuleType doesn't have a JIT type associated with it yet; it
    must be filled in by the caller.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, torch.nn.ModuleDict):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()

    class_annotations = getattr(nn_module, '__annotations__', {})
    if isinstance(nn_module, torch.quantization.QuantWrapper):
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(nn_module,
                                                "__jit_ignored_attributes__",
                                                list())
    concrete_type_builder.add_ignored_attributes(
        user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # The forward function from Module is special; never use its annotation; we
        # need to infer the type directly using the JIT.  I originally wanted to write
        # this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        inferred = False
        try:
            if name in class_annotations and class_annotations[
                    name] != torch.nn.Module.__annotations__["forward"]:
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], _jit_internal.fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    item.type, _jit_internal.fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as re:
            raise RuntimeError(
                "Error inferring type for {name}: {item}: {re}".format(
                    name=name, item=item, re=re))

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True,
                                            False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False,
                                            True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so we register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False,
                                                False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                attr_type.type())
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError(
                    "added_names must be submodule, parameter, or buffer")

            warnings.warn(
                "'{}' was found in ScriptModule constants, "
                " but it is a non-constant {}. Consider removing it.".format(
                    name, hint))
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but was not actually set in __init__. "
                          "Consider removing it.".format(name))
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(
            name, _get_valid_constant(name, value,
                                      type(nn_module).__name__))
        added_names.add(name)

    # populate overloads
    overloads = getattr(nn_module, "__overloads__", {})
    # update with any annotated overloads
    overloads.update(
        get_overload_name_mapping(
            get_overload_annotations(nn_module, ignored_properties)))
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name,
                    torch._C._jit_try_infer_type(scripted_fn).type(), value)
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = (
                    "(This function exists as an attribute on the Python module, "
                    "but we failed to compile it to a TorchScript function. "
                    "\nThe error stack is reproduced here:\n{}").format(e)
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name,
                                                       builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name,
                torch._C._jit_try_infer_type(value).type(), value)
            continue

        # If we got here, this is a regular "data" attribute. Add it to the concrete type.
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False,
                                                False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = "Its type was inferred; try adding a type annotation for the attribute." if inferred else ""
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = "(This attribute exists on the Python module, " \
                f"but we failed to convert Python type: '{torch.typename(type(value))}' " \
                f"to a TorchScript type. {additional_info})"
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder
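For context, here is a small module exercising the three typing paths infer_type distinguishes above: a class annotation, an explicit torch.jit.Attribute, and inference from the value itself (a sketch; the module and attribute names are invented):

import torch
from typing import List

class AttrPaths(torch.nn.Module):
    scale: float  # present in __annotations__, so the annotation is used

    def __init__(self):
        super().__init__()
        self.scale = 2.0
        # torch.jit.Attribute carries its type explicitly
        self.names = torch.jit.Attribute([], List[str])
        # no annotation: the type is inferred from the value (-> int)
        self.offset = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale + self.offset

scripted = torch.jit.script(AttrPaths())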
Example #5
def infer_concrete_type_builder(nn_module):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module. This
    ConcreteModuleType doesn't have a JIT type associated with it yet; it
    must be filled in by the caller.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, torch.nn.ModuleDict):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()

    class_annotations = getattr(nn_module, '__annotations__', {})

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        if name in class_annotations:
            attr_type = torch.jit.annotations.ann_to_type(
                class_annotations[name], _jit_internal.fake_range())
        elif isinstance(item, torch.jit.Attribute):
            attr_type = torch.jit.annotations.ann_to_type(
                item.type, _jit_internal.fake_range())
        else:
            attr_type = torch._C._jit_try_infer_type(item)
        return attr_type

    added_names = set()

    for name, item in nn_module._parameters.items():
        assert item is None or isinstance(item, torch.Tensor)
        attr_type = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type, True)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        assert item is None or isinstance(item, torch.Tensor)
        attr_type = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type, False)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        attr_type = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so we register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type, False)
            continue
        if attr_type is not None:
            assert attr_type.is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                attr_type)
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = concrete_type_store.get_or_create_concrete_type(
                item)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError(
                    "added_names must be submodule, parameter, or buffer")

            warnings.warn(
                "'{}' was found in ScriptModule constants, "
                " but it is a non-constant {}. Consider removing it.".format(
                    name, hint))
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but it's bc-breaking so
            # we need to warn for at least one release
            warnings.warn("'{}' was found in ScriptModule constants, "
                          "but was not actually set in __init__. "
                          "Consider removing it.".format(name))
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(name,
                                           _get_valid_constant(name, value))
        added_names.add(name)

    # populate overloads
    overloads = getattr(nn_module, "__overloads__", {})
    # update with any annotated overloads
    overloads.update(
        get_overload_name_mapping(get_overload_annotations(nn_module)))
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in blacklist or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name, torch._C._jit_try_infer_type(scripted_fn), value)
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = (
                    "(This function exists as an attribute on the Python module, "
                    "but we failed to compile it to a TorchScript function. "
                    "\nThe error stack is reproduced here:\n{}").format(e)
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name,
                                                       builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name, torch._C._jit_try_infer_type(value), value)
            continue

        # If we got here, this is a regular "data" attribute. Add it to the concrete type.
        attr_type = infer_type(name, value)
        if attr_type is not None:
            concrete_type_builder.add_attribute(name, attr_type, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            hint = ("(This attribute exists on the Python module, "
                    "but we failed to convert Python type: '{}' "
                    "to a TorchScript type.)").format(type(value).__name__)
            concrete_type_builder.add_failed_attribute(name, hint)

    # Add @property methods as failed attributes, to give a better error message.
    for name, value in type(nn_module).__dict__.items():
        if isinstance(value, property):
            hint = (
                "\n(This attribute exists on the Python module, but it's an @property "
                "method. @property methods are not yet supported in TorchScript. "
                "Please file a feature request on Github)")
            concrete_type_builder.add_failed_attribute(name, hint)

    return concrete_type_builder
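The __constants__ handling shared by both versions is what allows plain Python values to be baked into the compiled graph. A minimal sketch (module name invented):

import torch

class WithConstant(torch.nn.Module):
    __constants__ = ['stride']  # registered via add_constant during type inference

    def __init__(self):
        super().__init__()
        self.stride = 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # stride is treated as a compile-time constant in the scripted graph
        return x[:, ::self.stride]

scripted = torch.jit.script(WithConstant())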