Code Example #1
File: annotations.py Project: Huskyeder/pytorch
def ann_to_type(ann):
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(ann_to_type(ann.__args__[0]))
        else:
            return OptionalType(ann_to_type(ann.__args__[1]))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    raise ValueError("Unknown type annotation: '{}'".format(ann))
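These converters all rely on the `__args__` tuple that `typing` generics expose, and on the fact that `Optional[T]` is normalized to `Union[T, None]` with `type(None)` stored as one of the arguments. A minimal standard-library sketch (no PyTorch required) of what the `is_optional` branch above is checking:

from typing import Optional

ann = Optional[int]                               # normalized to Union[int, NoneType]
print(ann.__args__)                               # (<class 'int'>, <class 'NoneType'>)

# This is why the code can test which slot holds NoneType:
print(issubclass(ann.__args__[1], type(None)))    # True -> the contained type is __args__[0]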
Code Example #2
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    if is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        # Why Callable?  forward is declared to be a Callable so that
        # people can define it without mypy complaining.  But we shouldn't
        # try to recursively compile it!
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Callable)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
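Compared with Example #1, the Tensor check here is `inspect.isclass(ann) and issubclass(ann, torch.Tensor)` instead of `ann is torch.Tensor`, so subclasses of `torch.Tensor` also resolve to `TensorType`; the `inspect.isclass` guard keeps `issubclass` from raising `TypeError` on non-class annotations such as `List[int]`. A small sketch of the difference (the `MyTensor` class is hypothetical):

import torch

class MyTensor(torch.Tensor):  # hypothetical Tensor subclass
    pass

print(MyTensor is torch.Tensor)             # False -> the identity check in Example #1 misses it
print(issubclass(MyTensor, torch.Tensor))   # True  -> this version maps it to TensorType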
Code Example #3
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if ann is torch.Tensor:
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    if is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        ignored_builtin_classes = (torch.nn.Module, tuple, list)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(ann, ignored_builtin_classes):
            torch.jit._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
Code Example #4
File: annotations.py Project: zollhoefer/pytorch
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    elif is_list(ann):
        return ListType(try_ann_to_type(ann.__args__[0], loc))
    elif is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    elif is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    elif is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    elif is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif ann is Any:
        return AnyType.get()
    elif ann is type(None):
        return NoneType.get()
    elif inspect.isclass(ann) and hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    elif inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    elif ann is torch.device:
        return DeviceObjType.get()
    else:
        # Maybe resolve a NamedTuple to a Tuple Type
        def fake_rcb(key):
            return None

        the_type = torch._C._resolve_type_from_object(ann, loc, fake_rcb)
        if the_type is not None:
            return the_type
    return None
Code Example #5
File: annotations.py Project: yunpengsnail/pytorch
def ann_to_type(ann, resolver=None):
    # resolver should be a Tuple[Callable, SourceRange] where the Callable
    # is a resolutionCallback
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            return OptionalType(ann_to_type(ann.__args__[0]))
        else:
            return OptionalType(ann_to_type(ann.__args__[1]))
    elif is_rref(ann):
        return RRefType(ann_to_type(ann.__args__[0]))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif ann is Any:
        return AnyType.get()
    elif ann is type(None):
        return NoneType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    elif hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    elif ann is torch.device:
        return DeviceObjType.get()
    elif resolver is not None:
        # Maybe resolve a NamedTuple to a Tuple Type
        rcb, loc = resolver
        the_type = torch._C._resolve_type(ann.__name__, loc, rcb)
        if the_type is not None:
            return the_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
Code Example #6
File: utils.py Project: zollhoefer/pytorch
def _is_constant_tensor_list(node):
    if node.kind() != "prim::Constant":
        return False
    output_type = node.output().type()
    if output_type.isSubtypeOf(ListType.ofTensors()):
        return True
    if output_type.isSubtypeOf(ListType(OptionalType.ofTensor())):
        return True
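    # Any other constant kind falls through and implicitly returns None, which callers treat as falsy.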
Code Example #7
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            valid_type = try_ann_to_type(ann.__args__[0], loc)
        else:
            valid_type = try_ann_to_type(ann.__args__[1], loc)
        assert valid_type, "Unsupported annotation {} could not be resolved.".format(
            repr(ann))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if not is_enum_support_enabled():
            raise NotImplementedError(
                "Enum support is work in progress, please do not use it now")
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc))
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        ignored_builtin_classes = (torch.nn.Module, tuple, list)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
Code Example #8
def try_ann_to_type(ann, loc):
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        qualified_name = _qualified_name(ann)
        if _get_script_class(qualified_name) is None:
            torch.jit._script._recursive_compile_class(ann, loc)
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc),
                        list(ann))
    if inspect.isclass(ann):
        qualified_name = _qualified_name(ann)
        if _get_script_class(qualified_name) is not None:
            return ClassType(qualified_name)
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(qualified_name)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None

    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
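The `enum.Enum` branch above compiles the enum class on first use and derives its value type with `get_enum_value_type`, which presumes all members share a single value type. A hedged sketch of an annotation this branch is meant to accept (the `Color` enum and `tint` function are hypothetical):

import enum

class Color(enum.Enum):  # hypothetical enum; every member value is an int
    RED = 1
    GREEN = 2

# An annotation like the one below would hit the enum branch, which compiles Color
# recursively and builds an EnumType from its qualified name, the shared value type
# (int here), and the member list:
def tint(c: Color) -> int:
    return c.value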
Code Example #9
def try_ann_to_type(ann, loc):
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann.__args__) == 1 and ann.__args__[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is a hack to recognize NumberType
        if set(ann.__args__) == set([int, float, complex]):
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann.__args__:
            if a is None:
                inner.append(NoneType.get())
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
            assert maybe_type, msg.format(repr(ann), repr(maybe_type))
            inner.append(maybe_type)
        return UnionType(inner)    # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
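The `is_union` branch above special-cases `Union[int, float, complex]` as TorchScript's `NumberType`; since `is_optional` is tested first, `Optional[T]` (itself a `Union[T, None]`) never reaches this branch. A standard-library sketch of the set comparison:

from typing import Union

ann = Union[int, float, complex]
print(set(ann.__args__) == {int, float, complex})   # True -> mapped to NumberType.get()

# Member order is irrelevant, which is why a set comparison is used:
print(Union[float, complex, int].__args__)          # (<class 'float'>, <class 'complex'>, <class 'int'>)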
Code Example #10
def _optional_input_placeholder_tensor(g):
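    # Emit an untyped prim::Constant and tag it as Optional[Tensor], so the node can
    # stand in for an optional tensor input that was not provided.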
    n = g.op("prim::Constant")
    n.setType(OptionalType.ofTensor())
    return n
Code Example #11
File: bundled_inputs.py Project: shankar1729/pytorch
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[
        str]]] = None,  # Optional argument to provide info about the function or its inputs
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then the following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`;
        get_all_bundled_inputs_for_<function_name> will simply call this method
        and cache the value. If the user chooses this method, inputs[<function>]
        should map to None.
      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        The type of the inputs is List[Tuple[Any, ...]]. The outer list is the list of inputs;
        each inner tuple holds the args that together make up one input.
        For functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")

    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        function_name = function.__name__

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(
                "Error: inputs for function {0} is not a Sequence".format(
                    function_name))

        function_arg_types = [
            arg.type for arg in function.schema.arguments[1:]
        ]  # type: ignore
        deflated_inputs_type: ListType = ListType(
            TupleType(function_arg_types))
        inflated_inputs_type: OptionalType[ListType] = OptionalType(
            deflated_inputs_type)
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type, [])
        model._c._register_attribute(
            "_bundled_inputs_inflated_{name}".format(name=function_name),
            inflated_inputs_type, None)

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name, ))
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, Tuple) and not isinstance(
                        args, List):  # type: ignore
                    raise TypeError(
                        "Error: bundled input for function {0} at idx {1} is not a Tuple or a List"
                        .format(function_name, inp_idx))
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    deflated, inflater = _inflate_expr(
                        arg, f"deflated[{inp_idx}][{arg_idx}]")
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)
            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(
                model,
                "_bundled_inputs_deflated_{name}".format(name=function_name),
                deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(
            textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                if self._bundled_inputs_inflated_{name} is None:
                    self._bundled_inputs_inflated_{name} = self._generate_bundled_inputs_for_{name}()
                all_inputs = self._bundled_inputs_inflated_{name}
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(
            info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined API for forward, these wrappers are provided
        if function_name == 'forward':
            model.define(
                textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(
                textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))
            model.define(
                textwrap.dedent("""
                def run_on_bundled_input(self, idx: int):
                    return self(*self.get_all_bundled_inputs()[idx])
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(
        textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
Code Example #12
File: annotations.py Project: thomascong121/NCRF
def try_ann_to_type(ann, loc):
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        if issubclass(ann.__args__[1], type(None)):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
Code Example #13
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

      `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
        Returns a list of tuples suitable for passing to the model like
        `for inp in model.get_all_bundled_inputs(): model(*inp)`

      `get_num_bundled_inputs() -> int`
        Equivalent to `len(model.get_all_bundled_inputs())`,
        but slightly easier to call from C++.

      `run_on_bundled_input(idx: int) -> Any`
        Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs`;
        get_all_bundled_inputs will simply call this method
        and cache the value.
      - The `inputs` argument to this function can be a list of tuples,
        of the same form that will be returned by get_all_bundled_inputs.
        This function will attempt to optimize arguments so that (e.g.)
        arguments like `torch.zeros(1000)` will be represented compactly.
        Only top-level arguments will be optimized.
        Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    forward_arg_types = [
        arg.type for arg in model.forward.schema.arguments[1:]
    ]
    deflated_inputs_type: ListType = ListType(TupleType(forward_arg_types))
    inflated_inputs_type: OptionalType[ListType] = OptionalType(
        deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated",
                                 deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated",
                                 inflated_inputs_type, None)

    if hasattr(model, "_generate_bundled_inputs"):
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined"
            )
        # Model author already defined _generate_bundled_inputs.
    elif inputs is None:
        raise Exception(
            "inputs must be specified if _generate_bundled_inputs is not already defined"
        )
    else:
        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        deflated_inputs = []
        parts = []
        for inp_idx, args in enumerate(inputs):
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(
                    arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f"    {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)
        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)
        model._bundled_inputs_deflated = deflated_inputs
        definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
            {}
                ]
            """).format(expr)
        model.define(definition)

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(
        textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(
        textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(
        textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))