Example 1
def _split_tensor_list_constants(g, block):
    for node in block.nodes():
        for subblock in node.blocks():
            _split_tensor_list_constants(g, subblock)
        if node.kind() == "prim::Constant":
            output_type = node.output().type()
            if output_type.isSubtypeOf(ListType.ofTensors()):
                inputs = [g.create("prim::Constant").t_('value', t)
                          .insertBefore(node).output()
                          for t in node['value']]
                lc = (g.create("prim::ListConstruct", inputs)
                      .insertBefore(node)
                      .output()
                      .setType(ListType.ofTensors()))
                node.output().replaceAllUsesWith(lc)
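
This pass recursively rewrites every prim::Constant of type Tensor[] into one prim::Constant per tensor feeding a prim::ListConstruct, so later passes only ever see tensor-valued constants. A minimal sketch of how such a list constant can arise in the first place, assuming a recent PyTorch build (_jit_pass_constant_propagation is an internal pass, and whether it folds the list may vary by version):

import torch
from typing import List

@torch.jit.script
def make_list() -> List[torch.Tensor]:
    return [torch.ones(1), torch.zeros(1)]

# Constant propagation can fold the whole list into a single
# prim::Constant of type Tensor[]; _split_tensor_list_constants then
# splits it back into per-tensor constants plus a prim::ListConstruct.
torch._C._jit_pass_constant_propagation(make_list.graph)
print(make_list.graph)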
Example 2
def ann_to_type(ann):
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(ann_to_type(ann.__args__[0]))
        else:
            return OptionalType(ann_to_type(ann.__args__[1]))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    raise ValueError("Unknown type annotation: '{}'".format(ann))
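
ann_to_type maps Python typing annotations onto TorchScript type objects (TensorType, ListType, DictType, OptionalType, and so on). A quick way to see the mapping in effect is to script a function and inspect its schema; a minimal sketch:

import torch
from typing import Dict, List, Optional

@torch.jit.script
def f(a: List[int], b: Optional[float], c: Dict[str, torch.Tensor]) -> List[int]:
    return a

# The printed schema shows the resolved JIT types, roughly:
#   f(int[] a, float? b, Dict(str, Tensor) c) -> int[]
print(f.schema)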
Example 3
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    if is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        # Why Callable?  forward is declared to be a Callable so that
        # people can define it without mypy complaining.  But we shouldn't
        # try to recursively compile it!
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Callable)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
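
This variant can also recursively compile a user-defined class and return a ClassType for it. The public @torch.jit.script class path exercises the same resolution; a sketch, assuming a PyTorch build with TorchScript class support:

import torch

@torch.jit.script
class Counter(object):
    def __init__(self, start: int):
        self.value = start

    def bump(self) -> int:
        self.value += 1
        return self.value

@torch.jit.script
def use(c: Counter) -> int:
    # the `Counter` annotation resolves to a ClassType
    return c.bump()

print(use(Counter(41)))  # 42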
Example 4
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            valid_type = try_ann_to_type(ann.__args__[0], loc)
        else:
            valid_type = try_ann_to_type(ann.__args__[1], loc)
        assert valid_type, "Unsupported annotation {} could not be resolved.".format(
            repr(ann))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        ignored_builtin_classes = (torch.nn.Module, tuple, list)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
Example 5
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    elif is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    elif is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    elif is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    elif is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    elif is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif ann is Any:
        return AnyType.get()
    elif ann is type(None):
        return NoneType.get()
    elif inspect.isclass(ann) and hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    elif inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    elif ann is torch.device:
        return DeviceObjType.get()
    else:
        # Maybe resolve a NamedTuple to a Tuple Type
        def fake_rcb(key):
            return None
        the_type = torch._C._resolve_type_from_object(ann, loc, fake_rcb)
        if the_type is not None:
            return the_type
    return None
Example 6
def ann_to_type(ann):
    if ann is None:
        return DynamicType.get()
    elif ann is torch.Tensor:
        return DynamicType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    raise ValueError("The only supported annotations kinds are Tensor and Tuple[...]")
Example 7
def _split_tensor_list_constants(g, block):
    for node in block.nodes():
        for subblock in node.blocks():
            _split_tensor_list_constants(g, subblock)
        if _is_constant_tensor_list(node):
            inputs = []
            for val in node.output().toIValue():
                input = g.insertConstant(val)
                input.node().moveBefore(node)
                inputs.append(input)

            lc = (g.create("prim::ListConstruct",
                           inputs).insertBefore(node).output().setType(
                               ListType.ofTensors()))
            node.output().replaceAllUsesWith(lc)
Example 8
def to(g, self, *args):
    # ONNX doesn't have a concept of a device, so we ignore device casts
    if len(args) == 2:
        if args[0].type().isSubtypeOf(ListType.ofInts()):
            # aten::to(Tensor, Device, bool)
            return self
        else:
            # aten::to(Tensor, ScalarType, bool)
            dtype = _get_const(args[0], 'i', 'dtype')
            return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
    elif len(args) == 3:
        # aten::to(Tensor, Device, ScalarType, bool)
        dtype = _get_const(args[1], 'i', 'dtype')
        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
    else:
        raise NotImplementedError("Unknown aten::to signature")
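
Because ONNX has no device concept, this symbolic drops device casts and lowers dtype casts to an ONNX Cast node. A minimal export sketch, assuming the TorchScript-based torch.onnx.export path:

import io

import torch

class CastModel(torch.nn.Module):
    def forward(self, x):
        # the dtype change becomes an ONNX Cast;
        # a .to(device) call here would be dropped entirely
        return x.to(torch.float16)

buf = io.BytesIO()
torch.onnx.export(CastModel(), (torch.randn(2, 3),), buf)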
Example 9
def ann_to_type(ann, resolver=None):
    # resolver should be a Tuple[Callable, SourceRange] where the Callable
    # is a resolutionCallback
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(ann_to_type(ann.__args__[0]))
        else:
            return OptionalType(ann_to_type(ann.__args__[1]))
    elif is_rref(ann):
        return RRefType(ann_to_type(ann.__args__[0]))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif ann is Any:
        return AnyType.get()
    elif ann is type(None):
        return NoneType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    elif hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    elif ann is torch.device:
        return DeviceObjType.get()
    elif resolver is not None:
        # Maybe resolve a NamedTuple to a Tuple Type
        rcb, loc = resolver
        the_type = torch._C._resolve_type(ann.__name__, loc, rcb)
        if the_type is not None:
            return the_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
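
The resolver branch exists mainly so a NamedTuple annotation can be resolved to a named TupleType. A sketch of that case through the public API:

import torch
from typing import NamedTuple

class Point(NamedTuple):
    x: torch.Tensor
    y: torch.Tensor

@torch.jit.script
def flip(p: Point) -> Point:
    # the NamedTuple annotation resolves to a named TupleType
    return Point(p.y, p.x)

print(flip(Point(torch.zeros(1), torch.ones(1))))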
Example 10
def ann_to_type(ann):
    if ann is None:
        return DynamicType.get()
    elif ann is torch.Tensor:
        return DynamicType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    raise ValueError("Unknown type annotation: '{}'".format(ann.__name__))
Example 11
def _is_tensor_list(x):
    return x.type().isSubtypeOf(ListType.ofTensors())
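
_is_tensor_list checks a graph value's static type against Tensor[]. The same check can be made from Python against a scripted function's inputs; a sketch:

import torch
from typing import List

@torch.jit.script
def takes_tensors(xs: List[torch.Tensor]) -> int:
    return len(xs)

xs_value = list(takes_tensors.graph.inputs())[0]
print(xs_value.type().isSubtypeOf(torch._C.ListType.ofTensors()))  # True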
Example 12
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[
        str]]] = None,  # Optional argument to provide info about the function or its inputs
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`;
        `get_all_bundled_inputs_for_<function_name>` will simply call this method
        and cache the value. If the user chooses this method, inputs[<function>]
        should map to None.
      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of tuples, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.

      It is highly recommended (though not enforced) that, even when multiple functions share the
      same input style, you create separate bundled inputs for each function. Reusing the same
      input objects across multiple functions can cause issues with other torch.jit functionality like freeze.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        function_name = function.__name__

        function_arg_types = [
            arg.type for arg in function.schema.arguments[1:]
        ]  # type: ignore
        deflated_inputs_type: ListType = ListType(
            TupleType(function_arg_types))
        inflated_inputs_type: OptionalType[ListType] = OptionalType(
            deflated_inputs_type)
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type, [])
        model._c._register_attribute(
            "_bundled_inputs_inflated_{name}".format(name=function_name),
            inflated_inputs_type, None)

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name, ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    deflated, inflater = _inflate_expr(
                        arg, f"deflated[{inp_idx}][{arg_idx}]")
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)
            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            model._bundled_inputs_deflated = deflated_inputs
            setattr(
                model,
                "_bundled_inputs_deflated_{name}".format(name=function_name),
                deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(
            textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                if self._bundled_inputs_inflated_{name} is None:
                    self._bundled_inputs_inflated_{name} = self._generate_bundled_inputs_for_{name}()
                all_inputs = self._bundled_inputs_inflated_{name}
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(
            info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(
                textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(
                textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))
            model.define(
                textwrap.dedent("""
                def run_on_bundled_input(self, idx: int):
                    return self(*self.get_all_bundled_inputs()[idx])
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(
        textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
Example 13
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[
        str]]] = None,  # Optional argument to provide info about the function or its inputs
    skip_size_check=False,
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method, inputs[<function>] should map to None.

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds to the
        list of inputs; each inner tuple is the group of args that together make up one input.
        For functions that take a single arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")

    if hasattr(model, "get_all_bundled_inputs") or hasattr(
            model, "get_bundled_inputs_functions_and_info"):
        raise Exception(
            "Models can only be augmented with bundled inputs once. "
            "This Model seems to have already been augmented with "
            "bundled inputs. Please start afresh with one that "
            "doesn't have bundled inputs.", )

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        if hasattr(function, "__name__"):
            function_name = function.__name__
        else:
            if hasattr(function, "name"):
                function_name = function.name  # type: ignore[attr-defined]
            else:
                raise Exception(
                    'At least one of your functions has no name attribute; please ensure all have one, e.g. m.foo.name = "foo"'
                )

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(
                "Error inputs for function {0} is not a Sequence".format(
                    function_name))

        function_arg_types = [
            arg.type for arg in function.schema.arguments[1:]
        ]  # type: ignore[attr-defined]
        deflated_inputs_type: ListType = ListType(
            TupleType(function_arg_types))
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type, [])

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name, ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, Tuple) and not isinstance(
                        args, List):  # type: ignore[arg-type]
                    raise TypeError(
                        "Error bundled input for function {0} idx: {1} is not a Tuple or a List"
                        .format(function_name, inp_idx))
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    inflate_helper_fn_name = _get_inflate_helper_fn_name(
                        arg_idx, inp_idx, function_name)
                    deflated, inflater, helper_definition = _inflate_expr(
                        arg,
                        f"deflated[{inp_idx}][{arg_idx}]",
                        inflate_helper_fn_name,
                        skip_size_check=skip_size_check,
                    )
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                    if helper_definition:
                        model.define(textwrap.dedent(helper_definition))
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(
                model,
                "_bundled_inputs_deflated_{name}".format(name=function_name),
                deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(
            textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                all_inputs = self._generate_bundled_inputs_for_{name}()
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(
            info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(
                textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(
                textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(
        textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
Example 14
def try_ann_to_type(ann, loc):
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann.__args__) == 1 and ann.__args__[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann.__args__) == set([int, float, complex]):
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann.__args__:
            if a is None:
                inner.append(NoneType.get())
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
            assert maybe_type, msg.format(repr(ann), repr(a))
            inner.append(maybe_type)
        return UnionType(inner)    # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
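
Beyond unions, this version resolves enum.Enum subclasses to an EnumType, compiling the enum class on first use. A sketch through the public scripting path, assuming a PyTorch build with TorchScript enum support:

import enum

import torch

class Color(enum.Enum):
    RED = 1
    GREEN = 2

@torch.jit.script
def pick(c: Color) -> int:
    # the Enum annotation resolves to an EnumType with int values
    if c == Color.RED:
        return 0
    return 1

print(pick(Color.GREEN))  # 1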
Example 15
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.getInferred()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        qualified_name = _qualified_name(ann)
        if _get_script_class(qualified_name) is None:
            torch.jit._script._recursive_compile_class(ann, loc)
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc),
                        list(ann))
    if inspect.isclass(ann):
        qualified_name = _qualified_name(ann)
        if _get_script_class(qualified_name) is not None:
            return ClassType(qualified_name)
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(qualified_name)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
Example 16
def _run_symbolic_function(g,
                           n,
                           inputs,
                           env,
                           operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch
        from torch.onnx.symbolic_helper import _export_onnx_opset_version as opset_version
        import torch.onnx.symbolic_registry as sym_registry

        sym_registry.register_version('', opset_version)
        if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
            import torch.onnx.symbolic_caffe2
            torch.onnx.symbolic_caffe2.register_quantized_ops(
                'caffe2', opset_version)

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")
        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = sym_registry.is_registered_op(
                op_name, '', opset_version)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op
                                       and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {
                    k + "_" + n.kindOf(k)[0]: n[k]
                    for k in n.attributeNames()
                }
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)

            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn(
                        "ONNX export failed on ATen operator {} because "
                        "torch.onnx.symbolic_opset{}.{} does not exist".format(
                            op_name, opset_version, op_name))
                op_fn = sym_registry.get_registered_op(op_name, '',
                                                       opset_version)
                return op_fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                if n.kindOf("value") == "s":
                    return g.op("Constant", value_s=n["value"])
                elif n.output().type().isSubtypeOf(
                        ListType.ofInts()) or n.output().type().isSubtypeOf(
                            ListType.ofFloats()):
                    vals = n.output().toIValue()
                    value = torch.stack([torch.tensor(v)
                                         for v in vals]) if len(vals) else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError(
                        "Unsupported prim::Constant kind: `{}`. Send a bug report."
                        .format(n.kindOf("value")))
            elif n.mustBeNone(
            ) or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name,
                                      *inputs,
                                      outputs=n.outputsSize())
                new_node = new_op_outputs[0].node(
                ) if n.outputsSize() > 1 else new_op_outputs.node()
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block,
                                                  operator_export_type, env)
                return new_op_outputs
            else:
                # TODO: we should lift prim's symbolic out
                symbolic_name = 'prim_' + op_name
                is_exportable = sym_registry.is_registered_op(
                    symbolic_name, '', opset_version)
                if not is_exportable:
                    warnings.warn(
                        "ONNX export failed on primitive operator {}; please report a bug"
                        .format(op_name))
                symbolic_fn = sym_registry.get_registered_op(
                    symbolic_name, '', opset_version)
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        elif ns == "quantized":
            domain = ''
            if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK:
                domain = 'caffe2'
            attrs = {k: n[k] for k in n.attributeNames()}

            if not sym_registry.is_registered_op(op_name, domain,
                                                 opset_version):
                warnings.warn(
                    "ONNX export failed on quantized operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. ".format(
                        ns, op_name, opset_version, op_name))
            op_fn = sym_registry.get_registered_op(op_name, domain,
                                                   opset_version)
            return op_fn(g, *inputs, **attrs)

        # custom ops
        elif sym_registry.is_registered_version(ns, opset_version):
            if not sym_registry.is_registered_op(op_name, ns, opset_version):
                warnings.warn(
                    "ONNX export failed on custom operator {}::{} because "
                    "torch.onnx.symbolic_opset{}.{} does not exist. "
                    "Have you registered your symbolic function with "
                    "torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn)?"
                    .format(ns, op_name, opset_version, op_name))
            symbolic_fn = sym_registry.get_registered_op(
                op_name, ns, opset_version)
            attrs = {k: n[k] for k in n.attributeNames()}
            return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn(
                "ONNX export failed on an operator with unrecognized namespace {}::{}; "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
                "Otherwise please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(
            e.args[0], op_name), )
        raise
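
The custom-ops branch dispatches on the ns::op_name pair through the symbolic registry; the warning text points at torch.onnx.register_custom_op_symbolic as the way to populate it. A registration sketch for a hypothetical operator mynamespace::my_relu:

import torch.onnx

def my_relu_symbolic(g, input):
    # emit a standard ONNX Relu for the (hypothetical) custom op
    return g.op("Relu", input)

# registered for opset 9 so _run_symbolic_function can find it
torch.onnx.register_custom_op_symbolic('mynamespace::my_relu', my_relu_symbolic, 9)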
Example 17
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

      `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
        Returns a list of tuples suitable for passing to the model like
        `for inp in model.get_all_bundled_inputs(): model(*inp)`

      `get_num_bundled_inputs() -> int`
        Equivalent to `len(model.get_all_bundled_inputs())`,
        but slightly easier to call from C++.

      `run_on_bundled_input(idx: int) -> Any`
        Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs`;
        `get_all_bundled_inputs` will simply call this method
        and cache the value.
      - The `inputs` argument to this function can be a list of tuples,
        of the same form that will be returned by get_all_bundled_inputs.
        This function will attempt to optimize arguments so that (e.g.)
        arguments like `torch.zeros(1000)` will be represented compactly.
        Only top-level arguments will be optimized.
        Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    forward_arg_types = [
        arg.type for arg in model.forward.schema.arguments[1:]
    ]
    deflated_inputs_type: ListType = ListType(TupleType(forward_arg_types))
    inflated_inputs_type: OptionalType[ListType] = OptionalType(
        deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated",
                                 deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated",
                                 inflated_inputs_type, None)

    if hasattr(model, "_generate_bundled_inputs"):
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined"
            )
        # Model author already defined _generate_bundled_inputs.
    elif inputs is None:
        raise Exception(
            "inputs must be specified if _generate_bundled_inputs is not already defined"
        )
    else:
        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        deflated_inputs = []
        parts = []
        for inp_idx, args in enumerate(inputs):
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(
                    arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f"    {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)
        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)
        model._bundled_inputs_deflated = deflated_inputs
        definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
            {}
                ]
            """).format(expr)
        model.define(definition)

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(
        textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(
        textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(
        textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))