def ann_to_type(ann):
    """Convert a Python type annotation into the matching TorchScript type.

    ``Tuple``, ``List``, ``Dict`` and ``Optional`` annotations are converted
    recursively; ``float``/``int``/``str``/``bool`` map to the corresponding
    primitive types, and classes carrying ``__torch_script_class__`` map to a
    ``ClassType`` named after their qualified name.

    Args:
        ann: the annotation to convert. ``None`` and ``torch.Tensor`` both map
            to ``TensorType`` (``None`` presumably standing for an omitted
            annotation -- TODO confirm).

    Returns:
        The TorchScript type object for ``ann``.

    Raises:
        ValueError: if ``ann`` is not one of the supported annotations.
    """
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0]))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0])
        value = ann_to_type(ann.__args__[1])
        return DictType(key, value)
    elif is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            return OptionalType(ann_to_type(ann.__args__[0]))
        else:
            return OptionalType(ann_to_type(ann.__args__[1]))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    raise ValueError("Unknown type annotation: '{}'".format(ann))
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. ``None`` and subclasses of
            ``torch.Tensor`` map to ``TensorType``; container annotations
            (Tuple, List, Dict, Optional, RRef, Future) are converted
            recursively.
        loc: source range passed on to class compilation and to
            ``torch._C._resolve_type_from_object`` (presumably used for
            error reporting -- TODO confirm).

    Returns:
        The TorchScript type for ``ann``. Annotations not matched by any case
        below are handed to ``torch._C._resolve_type_from_object`` (e.g. to
        turn a NamedTuple into a tuple type), and its result is returned
        unchanged -- presumably ``None`` when nothing matches.
    """
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolvable element type falls through to the fallback below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    if is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        # Why Callable? forward is declared to be a Callable so that
        # people can define it without mypy complaining. But we shouldn't
        # try to recursively compile it!
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Callable)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. ``None`` and ``torch.Tensor`` map to
            ``TensorType``; container annotations (Tuple, List, Dict,
            Optional, RRef, Future) are converted recursively.
        loc: source range passed on to class compilation and to
            ``torch._C._resolve_type_from_object`` (presumably used for
            error reporting -- TODO confirm).

    Returns:
        The TorchScript type for ``ann``. Annotations not matched by any case
        below are handed to ``torch._C._resolve_type_from_object`` (e.g. to
        turn a NamedTuple into a tuple type), and its result is returned
        unchanged -- presumably ``None`` when nothing matches.
    """
    if ann is None:
        return TensorType.get()
    if ann is torch.Tensor:
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolvable element type falls through to the fallback below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    if is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        ignored_builtin_classes = (torch.nn.Module, tuple, list)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. ``None`` and ``torch.Tensor`` map to
            ``TensorType``; container annotations (Tuple, List, Dict,
            Optional, RRef, Future) are converted recursively.
        loc: source range passed to ``torch._C._resolve_type_from_object``
            (presumably used for error reporting -- TODO confirm).

    Returns:
        The TorchScript type for ``ann``, or ``None`` when
        ``torch._C._resolve_type_from_object`` cannot resolve it either.
    """
    if ann is None:
        return TensorType.get()
    if ann is torch.Tensor:
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # Do not build ListType(None); an unresolvable element type falls
        # through to the generic fallback at the end.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            return OptionalType(try_ann_to_type(ann.__args__[0], loc))
        else:
            return OptionalType(try_ann_to_type(ann.__args__[1], loc))
    if is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def ann_to_type(ann, resolver=None):
    """Convert a Python type annotation into the matching TorchScript type.

    Args:
        ann: the annotation to convert. ``None`` and ``torch.Tensor`` map to
            ``TensorType``; container annotations (Tuple, List, Dict,
            Optional, RRef) are converted recursively.
        resolver: optional ``(rcb, loc)`` pair, where ``rcb`` is a resolution
            callback and ``loc`` a SourceRange. When given, annotations not
            matched below are looked up by ``__name__`` via
            ``torch._C._resolve_type`` (e.g. to resolve a NamedTuple to a
            tuple type). It is also forwarded to nested element annotations.

    Returns:
        The TorchScript type object for ``ann``.

    Raises:
        ValueError: if ``ann`` cannot be converted.
    """
    # resolver should be a Tuple[Callable, SourceRange] where the Callable
    # is a resolutionCallback
    if ann is None:
        return TensorType.get()
    elif ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a, resolver) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0], resolver))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0], resolver)
        value = ann_to_type(ann.__args__[1], resolver)
        return DictType(key, value)
    elif is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            return OptionalType(ann_to_type(ann.__args__[0], resolver))
        else:
            return OptionalType(ann_to_type(ann.__args__[1], resolver))
    elif is_rref(ann):
        return RRefType(ann_to_type(ann.__args__[0], resolver))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif ann is Any:
        return AnyType.get()
    elif ann is type(None):
        return NoneType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    elif hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    elif ann is torch.device:
        return DeviceObjType.get()
    elif resolver is not None:
        # Maybe resolve a NamedTuple to a Tuple Type
        rcb, loc = resolver
        # Objects without a __name__ cannot be looked up by name; let them
        # reach the ValueError below instead of raising AttributeError.
        name = getattr(ann, "__name__", None)
        if name is not None:
            the_type = torch._C._resolve_type(name, loc, rcb)
            if the_type is not None:
                return the_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
def _is_constant_tensor_list(node):
    """Return whether ``node`` is a ``prim::Constant`` producing a tensor list.

    Outputs typed as a subtype of ``List[Tensor]`` or ``List[Optional[Tensor]]``
    both qualify.

    Args:
        node: a JIT graph node (anything exposing ``kind()`` and
            ``output().type()``).

    Returns:
        ``True`` for a constant tensor list, ``False`` otherwise.
    """
    if node.kind() != "prim::Constant":
        return False
    output_type = node.output().type()
    if output_type.isSubtypeOf(ListType.ofTensors()):
        return True
    if output_type.isSubtypeOf(ListType(OptionalType.ofTensor())):
        return True
    # Explicit False so every path returns a bool rather than an implicit None.
    return False
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. ``None`` and subclasses of
            ``torch.Tensor`` map to ``TensorType``; container annotations
            (Tuple, List, Dict, Optional, RRef, Future) are converted
            recursively; ``enum.Enum`` subclasses map to ``EnumType``.
        loc: source range passed on to class compilation, enum value-type
            inference and ``torch._C._resolve_type_from_object`` (presumably
            used for error reporting -- TODO confirm).

    Returns:
        The TorchScript type for ``ann``. Annotations not matched by any case
        below are handed to ``torch._C._resolve_type_from_object`` (e.g. to
        turn a NamedTuple into a tuple type), and its result is returned
        unchanged -- presumably ``None`` when nothing matches.

    Raises:
        AssertionError: if the contained type of an ``Optional`` cannot be
            resolved.
        NotImplementedError: for an ``enum.Enum`` subclass while enum support
            is disabled.
    """
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolvable element type falls through to the fallback below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            valid_type = try_ann_to_type(ann.__args__[0], loc)
        else:
            valid_type = try_ann_to_type(ann.__args__[1], loc)
        assert valid_type, "Unsupported annotation {} could not be resolved.".format(
            repr(ann))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if not is_enum_support_enabled():
            raise NotImplementedError(
                "Enum support is work in progress, please do not use it now")
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc))
    if inspect.isclass(ann):
        if hasattr(ann, "__torch_script_class__"):
            return ClassType(_qualified_name(ann))
        ignored_builtin_classes = (torch.nn.Module, tuple, list)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(_qualified_name(ann))

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. ``None`` and subclasses of
            ``torch.Tensor`` map to ``TensorType``; container annotations
            (Tuple, List, Dict, Optional, RRef, Future) are converted
            recursively; ``enum.Enum`` subclasses map to ``EnumType`` and
            are compiled first if no script class is registered for them.
        loc: source range passed on to class compilation, enum value-type
            inference and ``torch._C._resolve_type_from_object`` (presumably
            used for error reporting -- TODO confirm).

    Returns:
        The TorchScript type for ``ann``. Annotations not matched by any case
        below are handed to ``torch._C._resolve_type_from_object`` (e.g. to
        turn a NamedTuple into a tuple type), and its result is returned
        unchanged -- presumably ``None`` when nothing matches.

    Raises:
        AssertionError: if the contained type of an ``Optional`` cannot be
            resolved.
    """
    if ann is None:
        return TensorType.get()
    if inspect.isclass(ann) and issubclass(ann, torch.Tensor):
        return TensorType.get()
    if is_tuple(ann):
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolvable element type falls through to the fallback below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        qualified_name = _qualified_name(ann)
        if _get_script_class(qualified_name) is None:
            torch.jit._script._recursive_compile_class(ann, loc)
        return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        qualified_name = _qualified_name(ann)
        if _get_script_class(qualified_name) is not None:
            return ClassType(qualified_name)
        ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
        if torch._jit_internal.can_compile_class(ann) and not issubclass(
                ann, ignored_builtin_classes):
            torch.jit._script._recursive_compile_class(ann, loc)
            return ClassType(qualified_name)

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. A missing annotation
            (``inspect.Signature.empty``) becomes an inferred Tensor type;
            ``None`` maps to ``NoneType``; container annotations (Tuple, List,
            Dict, Optional, Union, RRef, Future) are converted recursively.
        loc: source range used in error messages (``loc.highlight()``) and
            passed on to class compilation and
            ``torch._C._resolve_type_from_object``.

    Returns:
        The TorchScript type for ``ann``. Annotations not matched by any case
        below are handed to ``torch._C._resolve_type_from_object`` (e.g. to
        turn a NamedTuple into a tuple type), and its result is returned
        unchanged -- presumably ``None`` when nothing matches.

    Raises:
        ValueError: if a Dict key or value annotation cannot be resolved.
        AssertionError: if an Optional's contained type or a Union member
            cannot be resolved.
    """
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann.__args__) == 1 and ann.__args__[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolvable element type falls through to the fallback below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann.__args__) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # We need these extra checks because both `None` and invalid
        # values will return `None`
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann.__args__:
            if a is None:
                # Append NoneType exactly once; the recursive call below
                # would otherwise add it a second time.
                inner.append(NoneType.get())
                continue
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
            # Report the member that failed, not its (None) resolution.
            assert maybe_type, msg.format(repr(ann), repr(a))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def _optional_input_placeholder_tensor(g):
    """Emit a ``prim::Constant`` in graph ``g`` and type it ``Optional[Tensor]``.

    Args:
        g: the graph to emit into (anything exposing ``op(kind)``).

    Returns:
        The object returned by ``g.op("prim::Constant")``, after its type has
        been set to ``OptionalType.ofTensor()``.
    """
    placeholder = g.op("prim::Constant")
    optional_tensor_type = OptionalType.ofTensor()
    placeholder.setType(optional_tensor_type)
    return placeholder
def augment_many_model_functions_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
        info: Optional[Dict[Callable, List[str]]] = None,  # Optional argument to provide info about the function or its inputs
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        get_all_bundled_inputs_for_<function_name> will simply call this
        method and cache the value. If the user chooses this method,
        inputs[<function>] should map to None.

      - The `inputs` argument to this function can be a dictionary mapping
        functions to a list of inputs, of the same form that will be returned
        by get_all_bundled_inputs_for_<function_name>. The type of the inputs
        is List[Tuple[Any, ...]]. The outer list corresponds with a list of
        inputs; the inner tuple is the list of args that together make up one
        input. For inputs of functions that take one arg, this will be a tuple
        of length one. The Any, ... is the actual data that makes up the args,
        e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings
    providing extra information about that function's bundled inputs. This
    could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.

    All arguments are validated before the model is modified, so a rejected
    call leaves the model untouched.

    Raises:
        Exception: if ``model`` is not a ScriptModule, ``inputs`` is empty, or
            an entry's inputs conflict with (or are missing for) its
            ``_generate_bundled_inputs_for_<function_name>`` method.
        TypeError: if an entry is not a Sequence, or one of its inputs is not
            a tuple or list.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")

    # Validate every entry before touching the model. The checks and their
    # order per function match the processing loop below, so the first error
    # raised is the same one a single combined pass would raise.
    for function, input_list in inputs.items():
        function_name = function.__name__

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(
                "Error inputs for function {0} is not a Sequence".format(function_name))

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined"
                    .format(name=function_name))
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined"
                .format(name=function_name, ))
        else:
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, (tuple, list)):
                    raise TypeError(
                        "Error bundled input for function {0} idx: {1} is not a Tuple or a List"
                        .format(function_name, inp_idx))

    # TorchScript statements for get_bundled_inputs_functions_and_info, one
    # entry per line (indentation is added when the method is assembled).
    info_lines: List[str] = []

    for function, input_list in inputs.items():
        function_name = function.__name__

        # The first schema argument is presumably `self`; the bundled inputs
        # supply the remaining ones.
        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore[attr-defined]
        deflated_inputs_type = ListType(TupleType(function_arg_types))
        inflated_inputs_type = OptionalType(deflated_inputs_type)
        model._c._register_attribute(
            "_bundled_inputs_deflated_{name}".format(name=function_name),
            deflated_inputs_type,
            [])
        model._c._register_attribute(
            "_bundled_inputs_inflated_{name}".format(name=function_name),
            inflated_inputs_type,
            None)

        if not hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(model, "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs)
            # dedent() runs before format(), so the expression lines land at
            # column 0 inside the brackets, where indentation is irrelevant.
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)
        # Otherwise the model author already defined
        # _generate_bundled_inputs_for_<function_name>.

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                if self._bundled_inputs_inflated_{name} is None:
                    self._bundled_inputs_inflated_{name} = self._generate_bundled_inputs_for_{name}()
                all_inputs = self._bundled_inputs_inflated_{name}
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(info[function]) if info and function in info else '[]'
        info_lines.extend([
            "temp_dict : Dict[str,List[str]] = {}",
            "info: List[str] = " + inputs_info,
            "temp_dict['info'] = info",
            "temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']".format(
                name=function_name),
            "all_inputs['{name}'] = temp_dict".format(name=function_name),
        ])

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))
            model.define(textwrap.dedent("""
                def run_on_bundled_input(self, idx: int):
                    return self(*self.get_all_bundled_inputs()[idx])
                """))

    # Define some high level helper methods that act on all bundled inputs.
    # The body is assembled with explicit indentation so it does not depend
    # on the whitespace of string literals in this source file.
    info_body = "".join("    " + line + "\n" for line in info_lines)
    model.define(
        "def get_bundled_inputs_functions_and_info(self):\n"
        "    all_inputs : Dict[str, Dict[str,List[str]]] = {}\n"
        + info_body
        + "    return all_inputs\n")
def try_ann_to_type(ann, loc):
    """Try to convert a Python type annotation into a TorchScript type.

    Args:
        ann: the annotation to convert. A missing annotation
            (``inspect.Signature.empty``) becomes an inferred Tensor type;
            ``None`` maps to ``NoneType``; container annotations (Tuple, List,
            Dict, Optional, RRef, Future) are converted recursively.
        loc: source range used in error messages (``loc.highlight()``) and
            passed on to class compilation and
            ``torch._C._resolve_type_from_object``.

    Returns:
        The TorchScript type for ``ann``. Annotations not matched by any case
        below are handed to ``torch._C._resolve_type_from_object`` (e.g. to
        turn a NamedTuple into a tuple type), and its result is returned
        unchanged -- presumably ``None`` when nothing matches.

    Raises:
        ValueError: if a Dict key or value annotation cannot be resolved.
        AssertionError: if an Optional's contained type cannot be resolved.
    """
    if ann is inspect.Signature.empty:
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`, whose
        # __args__ is `((),)` on some Python versions.
        if len(ann.__args__) == 1 and ann.__args__[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann.__args__])
    if is_list(ann):
        elem_type = try_ann_to_type(ann.__args__[0], loc)
        # An unresolvable element type falls through to the fallback below.
        if elem_type:
            return ListType(elem_type)
    if is_dict(ann):
        key = try_ann_to_type(ann.__args__[0], loc)
        value = try_ann_to_type(ann.__args__[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[0]}' at {loc.highlight()}")
        if value is None:
            raise ValueError(f"Unknown type annotation: '{ann.__args__[1]}' at {loc.highlight()}")
        return DictType(key, value)
    if is_optional(ann):
        # Optional[T] is Union[T, None] and NoneType may sit in either
        # position. Compare by identity: issubclass() raises TypeError when
        # its first argument is a typing construct such as List[int].
        if ann.__args__[1] is type(None):
            contained = ann.__args__[0]
        else:
            contained = ann.__args__[1]
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved."
        assert valid_type, msg.format(repr(ann), repr(contained))
        return OptionalType(valid_type)
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann.__args__[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann.__args__[0], loc))
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type.
    # The resolution callback deliberately resolves no names.
    def fake_rcb(key):
        return None
    return torch._C._resolve_type_from_object(ann, loc, fake_rcb)
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[List[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
) -> None:
    """Add bundled sample inputs to a model.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `run_on_bundled_input(idx: int) -> Any`
            Run the model on bundled input number `idx`

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs`.
        get_all_bundled_inputs will simply call this method
        and cache the value.
      - The `inputs` argument to this function can be a list of tuples,
        of the same form that will be returned by get_all_bundled_inputs.

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.

    All arguments are validated before the model is modified, so a rejected
    call leaves the model untouched.

    Raises:
        Exception: if ``model`` is not a ScriptModule, or ``inputs``
            conflicts with (or is missing for) ``_generate_bundled_inputs``.
        TypeError: if one of the inputs is not a tuple or list.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    # Validate before registering any attributes on the model.
    if hasattr(model, "_generate_bundled_inputs"):
        if inputs is not None:
            raise Exception(
                "inputs is not None, but _generate_bundled_inputs is already defined"
            )
        # Model author already defined _generate_bundled_inputs.
    elif inputs is None:
        raise Exception(
            "inputs must be specified if _generate_bundled_inputs is not already defined"
        )
    else:
        # Materialize once: the inputs are iterated again below.
        inputs = list(inputs)
        for inp_idx, args in enumerate(inputs):
            # Anything else would be silently enumerated as if it were an
            # argument tuple.
            if not isinstance(args, (tuple, list)):
                raise TypeError(
                    "Error bundled input idx: {0} is not a Tuple or a List".format(inp_idx))

    # The first schema argument is presumably `self`; the bundled inputs
    # supply the remaining ones.
    forward_arg_types = [arg.type for arg in model.forward.schema.arguments[1:]]
    deflated_inputs_type = ListType(TupleType(forward_arg_types))
    inflated_inputs_type = OptionalType(deflated_inputs_type)
    model._c._register_attribute("_bundled_inputs_deflated", deflated_inputs_type, [])
    model._c._register_attribute("_bundled_inputs_inflated", inflated_inputs_type, None)

    if not hasattr(model, "_generate_bundled_inputs"):
        # Iterate over the inputs and args in each input.
        # Accumulate `deflated_inputs` as (possibly) compressed values
        # and `parts` to be joined into the expression that unpacks them.
        deflated_inputs = []
        parts = []
        for inp_idx, args in enumerate(inputs):
            deflated_args = []
            parts.append("(")
            for arg_idx, arg in enumerate(args):
                deflated, inflater = _inflate_expr(arg, f"deflated[{inp_idx}][{arg_idx}]")
                deflated_args.append(deflated)
                parts.append(f"    {inflater},")
            deflated_inputs.append(tuple(deflated_args))
            parts.append("),")
        parts.append("")
        expr = "\n".join(parts)
        # Back-channel return this expr for debugging.
        if _receive_inflate_expr is not None:
            _receive_inflate_expr.append(expr)
        model._bundled_inputs_deflated = deflated_inputs
        # dedent() runs before format(), so the expression lines land at
        # column 0 inside the brackets, where indentation is irrelevant.
        definition = textwrap.dedent("""
            def _generate_bundled_inputs(self):
                deflated = self._bundled_inputs_deflated
                return [
            {expr}
                ]
            """).format(expr=expr)
        model.define(definition)

    # Define get_all_bundled_inputs that caches the generated inputs.
    model.define(textwrap.dedent("""
        def get_all_bundled_inputs(self):
            if self._bundled_inputs_inflated is None:
                self._bundled_inputs_inflated = self._generate_bundled_inputs()
            all_inputs = self._bundled_inputs_inflated
            assert all_inputs is not None
            return all_inputs
        """))

    # Define some helper methods.
    model.define(textwrap.dedent("""
        def get_num_bundled_inputs(self):
            return len(self.get_all_bundled_inputs())
        """))
    model.define(textwrap.dedent("""
        def run_on_bundled_input(self, idx: int):
            return self(*self.get_all_bundled_inputs()[idx])
        """))