def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    """Verify that single-value and batched argument registration on
    torch._C._SchemaInfo agree on `may_alias`, then run the op unchanged."""
    kwarg_view = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    ).kwargs
    arg0 = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0)
    arg1 = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)
    info_single = torch._C._SchemaInfo(func._schema)
    info_batched = torch._C._SchemaInfo(func._schema)
    # Before any values are registered, inputs 0 and 1 must not alias.
    for info in (info_single, info_batched):
        self.test_self.assertFalse(info.may_alias(arg0, arg1))
    # Register the same values one at a time and in bulk.
    for name, value in kwarg_view.items():
        info_single.add_argument_value(name, value)
    info_batched.add_argument_values(kwarg_view)
    # Both registration paths must now report the inputs as aliasing.
    for info in (info_single, info_batched):
        self.test_self.assertTrue(info.may_alias(arg0, arg1))
    return func(*args, **kwargs)
def normalized_arguments(
        self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Normalize this node's args/kwargs against its Python target's signature.

    Matches the stored `args`/`kwargs` to the signature of the called
    function or module and fills in default values; when
    `normalize_to_only_use_kwargs` is true the result uses keyword
    arguments exclusively, in positional order. Positional-only and
    vararg parameters are not supported. `arg_types`/`kwarg_types` may
    be required to disambiguate overloads.

    Args:
        root (torch.nn.Module): Module upon which to resolve module targets.
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Returns NamedTuple ArgsKwargsPair, or `None` if not successful.
    """
    if self.op == 'call_module':
        assert isinstance(self.target, str)
        return normalize_module(root, self.target, self.args, self.kwargs)  # type: ignore[arg-type]
    if self.op == 'call_function':
        assert callable(self.target)
        return normalize_function(self.target, self.args, self.kwargs,
                                  arg_types, kwarg_types)  # type: ignore[arg-type]
    # Other node kinds (placeholder, get_attr, output, ...) cannot be normalized.
    return None
def index_put_(fake_mode, func, *args, **kwargs):
    """Run an in-place index_put under the kernel-invocation guard and hand
    back the mutated input tensor, matching the op's in-place contract."""
    # Perform the mutation; the call's own return value is not what we hand back.
    with in_kernel_invocation_manager(fake_mode):
        func(*args, **kwargs)
    # An in-place op returns its own input, so recover it by name.
    normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    return normalized.kwargs["input"]
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    """Handle aten.to-style ops whose device argument may be positional.

    Normalizes the call to keyword form, records the requested output
    device (falling back to the input tensor's device when none is given),
    redirects the actual computation to the "meta" device, and converts
    the meta result back to a fake tensor on the recorded device.
    """
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    # Run the real kernel on meta so no real memory is allocated.
    new_kwargs["device"] = torch.device("meta")
    # Pop the input and pass it positionally: calling func(*args, **new_kwargs)
    # would pass the input (and any other positional argument) twice — once
    # via *args and once via the fully-normalized keyword dict.
    inp = new_kwargs.pop("input")
    r = func(inp, **new_kwargs)
    return fake_mode.fake_tensor_converter(fake_mode, r, out_device)
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    """Run `func` under the kernel-invocation guard and wrap the result as a
    FakeTensor placed on the same device as the op's input tensor."""
    normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    target_device = normalized.kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        result = func(*args, **kwargs)
    return FakeTensor(fake_mode, result, target_device)
def to_copy(fake_mode, func, *args, **kwargs):
    """Emulate aten._to_copy: run the copy on the meta device and wrap the
    result as a FakeTensor on the requested (or inherited) output device."""
    normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    new_kwargs = normalized.kwargs
    requested_device = new_kwargs.pop("device", None)
    # No explicit device means the copy stays on the input's device.
    out_device = requested_device if requested_device else new_kwargs["input"].device
    with no_dispatch():
        meta_input = new_kwargs.pop("input").to("meta")
        meta_out = aten._to_copy(meta_input, **new_kwargs)
        return FakeTensor(fake_mode, meta_out, out_device)
def call_function(
        self, target : Target, args : Tuple[Argument, ...],
        kwargs : Dict[str, Any],
        arg_types : Optional[Tuple[Any, ...]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None):
    """Trace a call_function node with its arguments normalized against the
    target's signature; fall back to the default handler when normalization
    is not possible for this target."""
    assert callable(target)
    normalized = normalize_function(target, args, kwargs,
                                    arg_types, kwarg_types,  # type: ignore[arg-type]
                                    self.normalize_to_only_use_kwargs)
    if not normalized:
        # Normalization failed (e.g. unresolvable overload): trace as-is.
        return super().call_function(target, args, kwargs)
    return self.tracer.create_proxy(
        'call_function', target, normalized.args, normalized.kwargs)
def index_tensor(fake_mode, func, *args, **kwargs):
    """aten.index handler: reject bool/uint8 indices, then run under the
    kernel-invocation guard and wrap the output on the input's device."""
    # dynamic shape op if indices are bool/uint8: the output shape would
    # depend on tensor *values*, which a fake tensor cannot know.
    check_no_bool_index_tensors(func, *args, **kwargs)
    normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    target_device = normalized.kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        result = func(*args, **kwargs)
    return FakeTensor(fake_mode, result, target_device)
def call_function(self, target: Target, args: Tuple[Argument, ...],
                  kwargs: Dict[str, Any],
                  arg_types: Optional[Tuple[Any, ...]] = None,
                  kwarg_types: Optional[Dict[str, Any]] = None):
    """Trace `target` with its arguments normalized to keyword-only form,
    falling back to the default tracer behavior when normalization fails.

    NOTE(review): this passes the result of `normalize_function` directly as
    the kwargs of the traced call, i.e. it assumes the helper in scope returns
    the kwargs mapping itself (or None). A variant returning an
    ArgsKwargsPair would need unpacking first — confirm against the imported
    `normalize_function`.
    """
    assert callable(target)
    new_kwargs = normalize_function(target, args, kwargs, arg_types, kwarg_types)  # type: ignore
    if new_kwargs:
        # All arguments ride in kwargs, hence the empty positional tuple.
        return self.tracer.create_proxy('call_function', target, (), new_kwargs)
    else:
        return super().call_function(target, args, kwargs)
def contructors(fake_mode, func, *args, **kwargs):
    """Handle tensor-constructor ops: build the result on the meta device and
    wrap it as a FakeTensor on the requested (or default) output device."""
    assert func not in _non_kwarg_device_constructors
    normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    new_kwargs = normalized.kwargs
    if func in _like_tensor_constructors:
        # *_like constructors inherit the device of the reference tensor.
        # TODO: file issue
        default_device = new_kwargs["input"].device
        positional = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        positional = ()
    requested_device = new_kwargs.pop("device", None)
    out_device = default_device if requested_device is None else requested_device
    # Allocate on meta so no real memory is touched.
    new_kwargs["device"] = torch.device("meta")
    result = func(*positional, **new_kwargs)
    return FakeTensor(fake_mode, result, out_device)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    """Run `func` and verify its observed behavior against its schema:

    * an argument may only be mutated if its schema alias_info (or that of
      an argument it aliases at runtime) is marked writable;
    * an argument may only alias an output if their schema alias sets
      intersect.

    Records every dispatched op name in `self.ops` and returns the op's
    real output unchanged. Raises RuntimeError on an undeclared mutation
    or alias.
    """

    def has_mutated(before, after):
        # Value comparison only makes sense between two real tensors.
        return not torch.equal(before, after) if isinstance(before, torch.Tensor) and isinstance(after, torch.Tensor) else False

    def has_aliased(lhs, rhs):
        return torch._C._is_alias_of(lhs, rhs) if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor) else False

    def is_mutable(arg_alias_pairs):
        # True if any schema argument in the alias group is writable.
        for arg in arg_alias_pairs:
            if arg.alias_info is not None and arg.alias_info.is_write:
                return True
        return False

    def is_aliasing(output_alias_info, arg_alias_pairs):
        # True if the output's alias set intersects any set in the group.
        if output_alias_info is None:
            return False
        for arg in arg_alias_pairs:
            if arg.alias_info is not None and bool(len(output_alias_info.after_set & arg.alias_info.after_set)):
                return True
        return False

    def are_args_aliasing(lhs, rhs):
        # Pairwise aliasing test between two flattened value lists.
        for lhs_value in lhs:
            for rhs_value in rhs:
                if (has_aliased(lhs_value, rhs_value)):
                    return True
        return False

    def standardize_name(name):
        # normalize_function reports the schema's "self" argument as "input".
        return name if name != "self" else "input"

    def unwrap(e):
        # Unwrap tensor subclasses to their inner tensor when they carry one.
        if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
            try:
                return e.elem
            except AttributeError as t:
                return e
        else:
            return e

    self.ops.append(func._schema.name)
    arguments = normalize_function(func, args, kwargs, normalize_to_only_use_kwargs=True).kwargs
    # Snapshot argument values before running the op so mutation is detectable.
    cloned_arguments = dict(zip(arguments.keys(), clone_inputs(arguments.values())))
    out = func(*args, **kwargs)

    # Construct an aliasing map between op arguments for verifying aliasing pairs
    # between op arguments and op outputs. This is used to allow cases where two aliasing arguments
    # cause a non-mutable/non-aliasing argument to mutate or alias.
    arg_alias_pairs_map = {arg.name: [arg] for arg in func._schema.arguments}
    for i_arg, j_arg in combinations(func._schema.arguments, 2):
        i_values = tree_map(unwrap, tree_flatten(arguments.get(standardize_name(i_arg.name)))[0])
        j_values = tree_map(unwrap, tree_flatten(arguments.get(standardize_name(j_arg.name)))[0])
        if are_args_aliasing(i_values, j_values):
            arg_alias_pairs_map[i_arg.name].append(j_arg)
            arg_alias_pairs_map[j_arg.name].append(i_arg)

    for arg in func._schema.arguments:
        name = standardize_name(arg.name)
        if arguments.get(name) is not None:
            arg_alias_pairs = arg_alias_pairs_map[arg.name]
            before = tree_flatten(cloned_arguments.get(name))[0]
            after = tree_flatten(arguments.get(name))[0]
            u_values = tree_map(unwrap, after)
            u_out = tree_map(unwrap, out)
            u_out = u_out if isinstance(u_out, tuple) else (u_out, )
            # Mutation is legal only when some member of the alias group is writable.
            if any([has_mutated(i, j) for i, j in zip(before, after)]) and not is_mutable(arg_alias_pairs):
                raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
            # Every output this argument aliases must be declared in the schema.
            for v in u_values:
                for j in range(len(u_out)):
                    if has_aliased(v, u_out[j]) and not is_aliasing(func._schema.returns[j].alias_info, arg_alias_pairs):
                        raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
    return out
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    """Run `func` and verify mutation/aliasing behavior via SchemaInfo.

    SchemaInfo is seeded with the actual argument values so value-dependent
    alias information is available to the queries. Detected (declared)
    mutations and aliases are recorded in `self.mutated` and `self.aliasing`;
    undeclared ones raise RuntimeError.
    """

    def has_mutated(before, after, md):
        # A tensor counts as mutated when its values, strides, or storage
        # identity (`_cdata`) changed relative to the pre-call snapshot `md`.
        if type(before) == torch.Tensor and type(after) == torch.Tensor:
            return not (torch.equal(before, after) and md[0] == after.stride() and md[1] == after.storage()._cdata)
        return False

    def has_aliased(lhs, rhs):
        try:
            return torch._C._overlaps(lhs, rhs)
        except Exception as exception:
            # Values that cannot be inspected for overlap are treated as
            # non-aliasing; any other failure is re-raised.
            if str(exception).startswith("Cannot inspect value of type "):
                return False
            else:
                raise exception

    def standardize_name(name):
        # normalize_function reports the schema's "self" argument as "input".
        return name if name != "self" else "input"

    def unwrap(e):
        # Unwrap tensor subclasses to their inner tensor when they carry one.
        if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
            try:
                return e.elem
            except AttributeError as t:
                return e
        return e

    def parse_metadata(e):
        # Snapshot (strides, storage _cdata) for a tensor; None otherwise.
        if isinstance(e, torch.Tensor):
            if not type(e) == torch.Tensor:
                try:
                    current = e.elem
                    return (deepcopy(current.stride()), current.storage()._cdata)
                except AttributeError as t:
                    return None
            else:
                return (deepcopy(e.stride()), e.storage()._cdata)
        return None

    self.ops.append(func._schema.name)

    # Clone and process arguments and outputs
    pre_arguments = normalize_function(func, args, kwargs, normalize_to_only_use_kwargs=True).kwargs
    c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
    cloned_arguments = {name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args}
    cloned_metadata = {name: tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments}

    out = func(*args, **kwargs)
    arguments = {name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments}
    tuple_out = out if isinstance(out, tuple) else (out, )
    tuple_out = tree_map(unwrap, tuple_out)

    # Seed SchemaInfo with the real argument values so the alias/mutability
    # queries below can use value-dependent information.
    schema_info = SchemaInfo(func._schema)
    schema_info.add_argument_values(pre_arguments)

    # Process arguments with outputs
    for i in range(len(func._schema.arguments)):
        arg = func._schema.arguments[i]
        name = standardize_name(arg.name)
        if arguments.get(name) is not None:
            before = cloned_arguments.get(name)
            md = cloned_metadata.get(name)
            after = arguments.get(name)
            # Every output aliasing this argument must be allowed by the schema.
            for j in range(len(tuple_out)):
                if has_aliased(tuple_out[j], after):
                    if not schema_info.may_contain_alias(SchemaArgument(SchemaArgType.output, j), SchemaArgument(SchemaArgType.input, i)):
                        raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
                    else:
                        self.aliasing.append(Aliasing(func._schema.name, name, f"output_{j}"))
            # Every mutation of this argument must be allowed by the schema.
            if any(has_mutated(a, b, c) for a, b, c in zip(tree_flatten(before)[0], tree_flatten(after)[0], md)):
                if not schema_info.is_mutable(SchemaArgument(SchemaArgType.input, i)):
                    raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
                else:
                    self.mutated.append(Mutation(func._schema.name, name))

    # Aliasing between outputs
    for i, j in combinations(range(len(func._schema.returns)), 2):
        if has_aliased(tuple_out[i], tuple_out[j]):
            if not schema_info.may_contain_alias(SchemaArgument(SchemaArgType.output, i), SchemaArgument(SchemaArgType.output, j)):
                raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly')
    return out
def torch_dispatch_impl(cls_or_mode_instance, func, types, args, kwargs, run_function):
    """Shared dispatch implementation for FakeTensor and FakeTensorMode.

    Executes `func` (via `run_function`) and wraps any plain-tensor outputs
    as FakeTensors, deciding the logical output device from the call:
    explicit `device` kwarg > input tensor's device (for _to_copy,
    constructors, resize_as_) > common device of all fake inputs.

    Args:
        cls_or_mode_instance: a FakeTensorMode instance, or the FakeTensor
            subclass/instance this dispatch was entered from.
        func: the operator being dispatched.
        types, args, kwargs: the usual __torch_dispatch__ payload.
        run_function: callable that actually executes the op.
    """
    kwargs = kwargs if kwargs else {}
    in_fake_mode = isinstance(cls_or_mode_instance, FakeTensorMode)
    # In mode dispatch reuse the mode's converter; otherwise build a fresh one.
    converter = cls_or_mode_instance.fake_tensor_converter if in_fake_mode else FakeTensorConverter()

    # This class virtualizes .device() calls; short-circuit here instead of
    # dispatching device again, or we would keep on recurring.
    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        return args[0].fake_device

    def wrap(e, device=None):
        # Wrap plain tensors produced by the underlying kernel as FakeTensors.
        if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
            return converter(e, device)
        else:
            return e

    # if we are in the dispatch mode, we will enter this function even if the inputs
    # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
    # and just support constructors. TODO: extend more broadly
    if isinstance(cls_or_mode_instance, FakeTensorMode):
        conversion_made = False

        def check_non_fake_tensor(x):
            nonlocal conversion_made
            conversion_made = conversion_made or (isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor))

        tree_map(check_non_fake_tensor, args)
        tree_map(check_non_fake_tensor, kwargs)
        if conversion_made:
            raise Exception(
                "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                f"Please convert all Tensors to FakeTensors first. Found in {func}"
            )

    # _to_copy fails when run with FakeTensors to cuda device
    # TODO: debug
    if func == torch.ops.aten._to_copy.default:
        _, new_kwargs = normalize_function(func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True)
        out_device = new_kwargs.pop("device", new_kwargs["input"].device)
        with no_dispatch():
            input = new_kwargs.pop("input").to("meta")
            return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)

    if _is_tensor_constructor(func):
        assert func not in _non_kwarg_device_constructors
        _, new_kwargs = normalize_function(func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True)
        # cpu is default device if none is specified
        out_device = new_kwargs.pop("device", torch.device("cpu"))
        # Allocate on meta so no real memory is touched.
        new_kwargs["device"] = torch.device("meta")
        r = run_function(func, types, (), new_kwargs)
        return FakeTensor(r, out_device)

    r = run_function(func, types, args, kwargs)

    # TODO: handle non-kwarg devices
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"

    # if device is specified, use that
    if kwargs.get("device", None):
        return tree_map(partial(wrap, device=kwargs["device"]), r)

    # operators which copy size from another tensor do not
    # also take device from the size tensor
    # other size_as operators are not builtin operators
    if func == aten.resize_as_.default:
        _, new_kwargs = normalize_function(func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True)
        # device of the input is returned
        return tree_map(partial(wrap, device=new_kwargs["input"].device), r)

    common_device = FakeTensor._find_common_device(func, args, kwargs)
    return tree_map(partial(wrap, device=common_device), r)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    """Run `func` and verify mutation/aliasing behavior against its schema.

    Builds an alias group per argument (so arguments that alias each other at
    runtime share mutability/aliasing permissions), checks argument mutations
    and argument-to-output aliasing against the schema's alias_info
    (including '*' wildcard sets), and finally checks output-to-output
    aliasing. Declared mutations/aliases are recorded in `self.mutated` /
    `self.aliasing`; undeclared ones raise RuntimeError.
    """

    def has_mutated(before, after, md):
        # A tensor counts as mutated when its values, strides, or storage
        # identity (`_cdata`) changed relative to the pre-call snapshot `md`.
        if type(before) == torch.Tensor and type(after) == torch.Tensor:
            return not (torch.equal(before, after) and md[0] == after.stride() and md[1] == after.storage()._cdata)
        return False

    def has_aliased(lhs, rhs):
        if type(lhs) == torch.Tensor and type(rhs) == torch.Tensor:
            return torch._C._is_alias_of(lhs, rhs)
        return False

    def is_mutable(arg_alias_pairs):
        # True if any schema argument in the alias group is writable.
        for arg in arg_alias_pairs:
            if arg.alias_info is not None and arg.alias_info.is_write:
                return True
        return False

    def is_aliasing(output, arg_alias_pairs):
        # True if the schema allows `output` to alias any argument in the group.
        for arg in arg_alias_pairs:
            if arg.alias_info is not None:
                if '*' in arg.alias_info.after_set:
                    # Wildcard: any same-typed output (or list of that type) may alias.
                    same_types = output.type == arg.type
                    elems_same_types = (isinstance(output.type, torch._C.ListType) and output.type.getElementType() == arg.type)
                    if same_types or elems_same_types:
                        return True
                elif output.alias_info is not None:
                    # Otherwise the alias sets must intersect explicitly.
                    share_aliasing_sets = bool(len(output.alias_info.after_set & arg.alias_info.after_set))
                    if share_aliasing_sets:
                        return True
        return False

    def are_args_aliasing(lhs, rhs):
        # Pairwise aliasing test between two flattened value lists.
        for lhs_value in lhs:
            for rhs_value in rhs:
                if (has_aliased(lhs_value, rhs_value)):
                    return True
        return False

    def standardize_name(name):
        # normalize_function reports the schema's "self" argument as "input".
        return name if name != "self" else "input"

    def unwrap(e):
        # Unwrap tensor subclasses to their inner tensor when they carry one.
        if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
            try:
                return e.elem
            except AttributeError as t:
                return e
        return e

    def parse_metadata(e):
        # Snapshot (strides, storage _cdata) for a tensor; None otherwise.
        if isinstance(e, torch.Tensor):
            if not type(e) == torch.Tensor:
                try:
                    current = e.elem
                    return (deepcopy(current.stride()), current.storage()._cdata)
                except AttributeError as t:
                    return None
            else:
                return (deepcopy(e.stride()), e.storage()._cdata)
        return None

    self.ops.append(func._schema.name)

    # Clone and process arguments and outputs
    pre_arguments = normalize_function(func, args, kwargs, normalize_to_only_use_kwargs=True).kwargs
    c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
    cloned_arguments = {name: tree_map(unwrap, tree_flatten(c_p_args.get(name))[0]) for name in c_p_args}
    cloned_metadata = {name: tree_map(parse_metadata, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments}

    out = func(*args, **kwargs)
    arguments = {name: tree_map(unwrap, tree_flatten(pre_arguments.get(name))[0]) for name in pre_arguments}
    tuple_out = out if isinstance(out, tuple) else (out, )
    u_out = [tree_map(unwrap, tree_flatten(i)[0]) for i in tuple_out]

    # Construct an aliasing map between op arguments for verifying aliasing pairs
    # between op arguments and op outputs. This is used to allow cases where two aliasing arguments
    # cause a non-mutable/non-aliasing argument to mutate or alias.
    arg_alias_pairs_map = {standardize_name(arg.name): [arg] for arg in func._schema.arguments}

    # Construct an aliasing set for each output for verifying aliasing pairs
    # between op outputs
    out_alias_pairs_map = [set() for arg in func._schema.returns]

    # Aliasing between arguments
    for i_arg, j_arg in combinations(func._schema.arguments, 2):
        i_values = arguments.get(standardize_name(i_arg.name))
        j_values = arguments.get(standardize_name(j_arg.name))
        if are_args_aliasing(i_values, j_values):
            arg_alias_pairs_map[standardize_name(i_arg.name)].append(j_arg)
            arg_alias_pairs_map[standardize_name(j_arg.name)].append(i_arg)

    # Process arguments with outputs
    for arg in func._schema.arguments:
        name = standardize_name(arg.name)
        if arguments.get(name) is not None:
            arg_alias_pairs = arg_alias_pairs_map[name]
            before = cloned_arguments.get(name)
            md = cloned_metadata.get(name)
            after = arguments.get(name)
            # Every output aliasing this argument must be permitted by its group.
            for v in after:
                for i in range(len(u_out)):
                    for j in range(len(u_out[i])):
                        if has_aliased(v, u_out[i][j]):
                            if not is_aliasing(func._schema.returns[i], arg_alias_pairs):
                                raise RuntimeError(f'Argument {name} is not defined to alias output but was aliasing')
                            else:
                                self.aliasing.append(Aliasing(func._schema.name, name, f"output_{i}"))
                                out_alias_pairs_map[i].add(name)
                            break
            # Every mutation of this argument must be permitted by its group.
            if any(has_mutated(i, j, k) for i, j, k in zip(before, after, md)):
                if not is_mutable(arg_alias_pairs):
                    raise RuntimeError(f"Argument {name} is not defined as mutable but was mutated")
                else:
                    self.mutated.append(Mutation(func._schema.name, name))

    # Aliasing between outputs
    for i, j in combinations(range(len(func._schema.returns)), 2):
        if are_args_aliasing(u_out[i], u_out[j]):
            # Two outputs may alias only when they both alias a common input.
            share_aliasing_inputs = bool(len(out_alias_pairs_map[i] & out_alias_pairs_map[j]))
            if (not share_aliasing_inputs):
                raise RuntimeError(f'Outputs {i} and {j} alias unexpectedly')
    return out