Example #1
0
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

      `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
        Returns a list of tuples suitable for passing to the model like
        `for inp in model.get_all_bundled_inputs(): model(*inp)`

      `get_num_bundled_inputs() -> int`
        Equivalent to `len(model.get_all_bundled_inputs())`,
        but slightly easier to call from C++.

      `run_on_bundled_input(idx: int) -> Any`
        Run the model on bundled input number `idx`.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs`.
        `get_all_bundled_inputs` will simply call this method
        and cache the value.
      - The `inputs` argument to this function can be a non-empty list of
        tuples, of the same form that will be returned by
        `get_all_bundled_inputs`.
        This function will attempt to optimize arguments so that (e.g.)
        arguments like `torch.zeros(1000)` will be represented compactly.
        Only top-level arguments will be optimized.
        Tensors in lists or tuples will not.

    Arguments are validated before the model is modified, so a call that is
    rejected for bad arguments leaves `model` untouched.

    Args:
        model: The scripted module to augment in place.
        inputs: Sample inputs for `forward`. Must be None exactly when the
            model already defines `_generate_bundled_inputs`.
        _receive_inflate_expr: If given, the generated expression that
            re-inflates the deflated inputs is appended to it (debugging aid).

    Raises:
        Exception: If `model` is not a ScriptModule, if `inputs` is given
            while `_generate_bundled_inputs` already exists, or if `inputs`
            is None/empty while it does not.
        TypeError: If an element of `inputs` is not a tuple or a list.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    # TorchScript source for `_generate_bundled_inputs`, or None when the
    # model author already supplied that method.
    generate_definition: Optional[str] = None
    deflated_inputs: List[Tuple[Any, ...]] = []

    if hasattr(model, "_generate_bundled_inputs"):
        # Model author already defined _generate_bundled_inputs.
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined")
    else:
        # Materialize once so emptiness can be checked for any iterable.
        input_list = list(inputs) if inputs is not None else []
        if not input_list:
            # An empty list would generate `return [ ]`, which TorchScript
            # types as List[Tensor] (its default for empty list literals)
            # rather than the registered tuple-list type.
            raise Exception(
                "inputs must be specified if _generate_bundled_inputs is not already defined")

        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        parts: List[str] = []
        for inp_idx, args in enumerate(input_list):
            if not isinstance(args, (tuple, list)):
                # A bare tensor would otherwise be iterated element-by-element
                # along its first dimension and silently become the wrong
                # argument list.
                raise TypeError(
                    f"Error bundled input idx: {inp_idx} is not a Tuple or a List")
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f"    {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)
        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)
        generate_definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
            {}
                ]
            """).format(expr)

    # All validation passed: now mutate the model. Register storage for the
    # deflated inputs and for the lazily populated inflated cache.
    forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
    deflated_inputs_type = torch._C.ListType(torch._C.TupleType(forward_arg_types))
    inflated_inputs_type = torch._C.OptionalType(deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)

    if generate_definition is not None:
        model._bundled_inputs_deflated = deflated_inputs
        model.define(generate_definition)

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))
Example #2
0
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[
        str]]] = None,  # Optional argument to provide info about the function or its inputs
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        `get_all_bundled_inputs_for_<function_name>` will simply call this
        method and cache the value. If the user chooses this method,
        inputs[<function>] should map to None.
      - The `inputs` argument to this function can be a dictionary mapping functions to a
        non-empty list of tuples, of the same form that will be returned by
        get_all_bundled_inputs_for_<function_name>.

      It is highly recommended (though not enforced) that if multiple functions have the same input style,
      you create separate bundled inputs for each function. Reusing the same input and bundling it to multiple
      functions can cause issues with other torch.jit functionality like freeze.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.

    Args:
        model: The scripted module to augment in place.
        inputs: Maps each method (e.g. `model.forward`) to its sample inputs,
            or to None when the model defines the matching generator.
        _receive_inflate_expr: If given, each generated re-inflation
            expression is appended to it (debugging aid).
        info: Optional per-function list of strings exposed through
            `get_bundled_inputs_functions_and_info`.

    Raises:
        Exception: If `model` is not a ScriptModule, or if a function's
            entry in `inputs` conflicts with whether its generator exists.
        TypeError: If a bundled input is not a tuple or a list.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        # Script methods may expose `.name` rather than `__name__`.
        function_name = getattr(function, "__name__", None) or getattr(function, "name")

        # Validate (and deflate) this function's inputs before registering
        # anything for it, so a rejected entry leaves no stray attributes.
        generate_definition: Optional[str] = None
        deflated_inputs: List[Tuple[Any, ...]] = []
        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            parts: List[str] = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, (tuple, list)):
                    # A bare tensor would otherwise be iterated element-by-element
                    # along its first dimension and silently become the wrong
                    # argument list.
                    raise TypeError(
                        f"Error bundled input for function {function_name} idx: {inp_idx} "
                        "is not a Tuple or a List")
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    deflated, inflater = _inflate_expr(
                        arg, f"deflated[{inp_idx}][{arg_idx}]")
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)
            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            generate_definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)

        # Register per-function storage for the deflated inputs and the
        # lazily populated inflated cache.
        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore
        deflated_inputs_type: ListType = ListType(
            TupleType(function_arg_types))
        inflated_inputs_type: OptionalType[ListType] = OptionalType(
            deflated_inputs_type)
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type, [])
        model._c._register_attribute(
            "_bundled_inputs_inflated_{name}".format(name=function_name),
            inflated_inputs_type, None)

        if generate_definition is not None:
            # Store only under the per-function name that the generated method
            # reads. Writing the shared `_bundled_inputs_deflated` attribute here
            # would clobber inputs bundled through the single-forward API.
            setattr(
                model,
                "_bundled_inputs_deflated_{name}".format(name=function_name),
                deflated_inputs)
            model.define(generate_definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(
            textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                if self._bundled_inputs_inflated_{name} is None:
                    self._bundled_inputs_inflated_{name} = self._generate_bundled_inputs_for_{name}()
                all_inputs = self._bundled_inputs_inflated_{name}
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(
            info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(
                textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(
                textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))
            model.define(
                textwrap.dedent("""
                def run_on_bundled_input(self, idx: int):
                    return self(*self.get_all_bundled_inputs()[idx])
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(
        textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))