def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar. "
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
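# A minimal, self-contained sketch of the trailing-underscore heuristic used
# above to detect in-place ops. is_inplace_name is a hypothetical helper for
# illustration, not part of the real API:
def is_inplace_name(name: str) -> bool:
    # in-place ops end in "_" (add_, mul_); dunder names (__and__) do not count
    return name.endswith("_") and not name.startswith("_")

assert is_inplace_name("add_")
assert not is_inplace_name("__and__")
assert not is_inplace_name("add")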
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    with no_dispatch():
        def to_cpu(e):
            if isinstance(e, FakeTensor):
                return torch.zeros_like(e, device="cpu")
            return e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            raise orig_not_implemented_exception from new_exception

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                tensor_impls.add(e)
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, because of the conversion to cpu; error on reused impls
        for e in tree_flatten(r)[0]:
            if e in tensor_impls or (
                isinstance(e, torch.Tensor) and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
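# Self-contained illustration of the tree_map / tree_flatten pattern that the
# fallback above relies on (torch.utils._pytree backs the tree_* calls here):
import torch
import torch.utils._pytree as pytree

nested = {"a": torch.ones(2), "b": [torch.zeros(3), 4]}
# map a function over tensor leaves only, passing other leaves through
mapped = pytree.tree_map(
    lambda e: e * 2 if isinstance(e, torch.Tensor) else e, nested)
# flatten to a list of leaves plus a spec that can rebuild the structure
leaves, spec = pytree.tree_flatten(nested)
rebuilt = pytree.tree_unflatten(leaves, spec)
assert rebuilt["b"][1] == 4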
def new(self, *args, **kwargs):
    # torch.Tensor.new does not go through the normal dispatcher pattern,
    # so in order to use the same pattern as normal invocation of
    # returning meta device within the kernel, we need to intercept
    # the call here
    out_device = self.fake_device
    if "device" in kwargs:
        kwarg_device = kwargs.pop("device")
        out_device = kwarg_device if kwarg_device else out_device
        kwargs["device"] = "meta"
    with in_kernel_invocation_manager(self.fake_mode):
        with no_dispatch():
            meta_out = super().new(*args, **kwargs)
    with no_dispatch():
        return FakeTensor(self.fake_mode, meta_out, out_device)
def wrap_with_proxy(e, proxy, constant):
    if isinstance(e, torch.Tensor):
        with no_dispatch():
            return ProxyTensor(e, proxy, constant=constant, proxy_mode=proxy_mode)
    else:
        return e
def new(self, *args, **kwargs):
    # torch.Tensor.new does not go through the normal dispatcher pattern,
    # so in order to use the same pattern as normal invocation of
    # returning meta device within the kernel, we need to intercept
    # the call here.
    # Because it doesn't go through the dispatcher, we run into errors
    # when attempting to compute an output in meta, so
    # we compute the real tensor and then convert it to meta
    out_device = self.fake_device
    with no_dispatch():
        real_out = super().new(*args, **kwargs)
    assert not isinstance(real_out, FakeTensor), real_out
    assert real_out.device.type != "meta", real_out.device
    with no_dispatch():
        meta_out = MetaConverter()(real_out)
        return FakeTensor(self.fake_mode, meta_out, out_device)
def to_copy(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    input_device = new_kwargs.pop("device", None)
    out_device = input_device if input_device else new_kwargs["input"].device
    with no_dispatch():
        input = new_kwargs.pop("input").to("meta")
        return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device)
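# Hedged sketch of the normalize_function call above: it maps positional
# arguments onto keyword arguments, exposing the first argument under the
# name "input" (as the code above relies on) so device handling is uniform:
import torch
from torch.fx.operator_schemas import normalize_function

_, new_kwargs = normalize_function(
    torch.ops.aten._to_copy.default,
    args=(torch.ones(2),),
    kwargs={"dtype": torch.float64},
    normalize_to_only_use_kwargs=True,
)
assert new_kwargs["input"].shape == (2,)
assert new_kwargs["dtype"] is torch.float64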
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
    # these should all be supported, just to be safe
    # avoid fallback for operators which inplace modify metadata
    # because the input fake tensors would be unmodified
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        args = tree_map(to_real_tensor, args)
        kwargs = tree_map(to_real_tensor, kwargs)

        r = func(*args, **kwargs)

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                if not e.is_sparse:
                    storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, because of the conversion to device, unless we can
        # reuse an input impl
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor)
                and not e.is_sparse
                and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

        def map_out(e):
            if isinstance(e, torch.Tensor):
                if id(e) in inp_impls:
                    return inp_impls[id(e)]
                else:
                    return fake_mode.fake_tensor_converter(fake_mode, e)
            else:
                return e

        return tree_map(map_out, r)
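# Small demonstration of the storage-identity test above: a view shares its
# base's storage, which is how reused impls are detected after the fallback:
import torch

base = torch.ones(4)
view = base[:2]
assert view.storage()._cdata == base.storage()._cdata  # aliasing detected
fresh = torch.ones(4)
assert fresh.storage()._cdata != base.storage()._cdata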
def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
    func = func_overload.overloadpacket
    if any(isinstance(arg, ProxyTensor) for arg in args):
        return proxy_call(func_overload, args, kwargs)
    else:
        proxy_out = self.tracer.create_proxy(
            'call_function', func, args, kwargs,
            name=self.tracer.graph._target_to_str(func.__name__)
        )

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        return wrap_output(real_out, proxy_out)
def wrapped(*args):
    flat_args, args_spec = pytree.tree_flatten(args)
    assert len(flat_args) == len(flat_inps)
    for idx, arg in enumerate(flat_args):
        if isinstance(flat_inps[idx], torch.Tensor):
            with no_dispatch():
                flat_args[idx] = ProxyTensor(
                    flat_inps[idx], arg,
                    requires_grad=flat_inps[idx].is_leaf
                )
        else:
            flat_args[idx] = flat_inps[idx]
    tree_args = pytree.tree_unflatten(flat_args, args_spec)

    out = f(*tree_args)
    flat_outs, out_spec = pytree.tree_flatten(out)
    for idx in range(len(flat_outs)):
        # ProxyTensor is a torch.Tensor subclass, so one check suffices
        if isinstance(flat_outs[idx], ProxyTensor):
            flat_outs[idx] = flat_outs[idx].proxy
    return pytree.tree_unflatten(flat_outs, out_spec)
def from_real_tensor(self, fake_mode, t):
    maybe_memo = self._get_memo(t)
    if maybe_memo is not None:
        return maybe_memo
    existing_device = t.device
    # not yet supported in metatensors
    if t.is_complex():
        raise UnsupportedFakeTensorException("complex nyi in meta tensors")
    if t.is_sparse:
        raise UnsupportedFakeTensorException("sparse nyi in meta tensors")
    if t.is_quantized:
        raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
    with no_dispatch():
        out = FakeTensor(fake_mode, self.meta_converter(t), existing_device)
        if type(t) is torch.nn.Parameter:
            out = torch.nn.Parameter(out, requires_grad=out.requires_grad)  # type: ignore[assignment]
    self.set_tensor_memo(t, out)
    return out
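# Sketch of the Parameter re-wrapping step above: conversion produces a plain
# tensor, so Parameter-ness has to be restored by wrapping the result again.
# detach().clone() stands in for the meta conversion here:
import torch

p = torch.nn.Parameter(torch.ones(2))
out = p.detach().clone()
if type(p) is torch.nn.Parameter:
    out = torch.nn.Parameter(out, requires_grad=p.requires_grad)
assert isinstance(out, torch.nn.Parameter) and out.requires_grad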
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar. "
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    inner_res = func_overload(
        *pytree.tree_map(unwrap_elem, args),
        **pytree.tree_map(unwrap_elem, kwargs)
    )

    # Needed to sync up metadata for in-place operators that modify metadata
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res)
def __repr__(self):
    with no_dispatch():
        return f"ProxyTensor({self.as_subclass(torch.Tensor)}, proxy={self.proxy})"  # type: ignore[arg-type]
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    def unwrap(e):
        return e.elem if isinstance(e, CompositeCompliantTensor) else e

    def wrap(e):
        return CompositeCompliantTensor(e) if isinstance(e, torch.Tensor) else e

    if func == torch.ops.aten._local_scalar_dense.default:
        raise RuntimeError(
            ".item() is not allowed to be called inside of composite "
            "functions in the PyTorch library because not all backends "
            "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.")

    if func.overloadpacket.__name__ in ('set_', 'resize_'):
        raise RuntimeError(
            f"{func.__name__} is not allowed to be called inside of "
            f"Composite operators.")

    if is_inplace(func):
        # NB: We are making an assumption that if the function is in-place,
        # then the first argument is being written to. Introspection please save us!
        mutated_argument = args[0]
        if not isinstance(mutated_argument, CompositeCompliantTensor) and \
                any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
            raise RuntimeError(
                'Not composite compliant: performing in-place operation '
                f'{func.__name__} where the Tensor being written to is '
                'a regular Tensor but the other tensors are Tensor Subclasses. '
                'Please try to avoid this in-place operation.')

    with enable_reentrant_dispatch():
        with contextlib.nullcontext() if enable_recursive_torch_dispatch else no_dispatch():
            unwrapped_args = tree_map(unwrap, args)
            unwrapped_kwargs = tree_map(unwrap, kwargs)
            unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
            rs = tree_map(wrap, unwrapped_rs)

    if is_view_fn(func) and autograd_view_consistency:
        # Note [Alias Result]
        # Autograd asserts that for B = A.view_fn(...), B and A's storages
        # are the same. Here we try to make B alias A to avoid those asserts.
        # See https://github.com/pytorch/pytorch/issues/65339 for more information
        # about the issue.
        with enable_reentrant_dispatch():
            with no_dispatch():
                # Idea: this is a weird way of getting a storage that aliases the input.
                # This is a workaround for #65339.
                # 1. under no_dispatch, all of the wrapper tensors look like regular
                #    tensors with special storage (the storage is nullptr and
                #    advertises CPU/CUDA device).
                # 2. we run func, which ends up running the view operation
                # 3. All view operations reuse the input's storage and return
                #    result Tensor(s) with new sizes/strides/offset that alias
                #    the input.
                # 4. we set the storage (and sizes/strides/offset) of the wrapper
                #    tensor results to be that of the tensors that alias the input
                result = func(*args, **kwargs)
                if isinstance(result, (tuple, list)):
                    for a, b in zip(rs, result):
                        a.set_(b)
                else:
                    rs.set_(result)

    # Some operations are allowed to in-place modify the metadata of the
    # inputs. The only ones are the "inplace view functions"; when we
    # run into these, we manually modify the metadata of the input.
    with no_dispatch():
        if is_inplace_view_fn(func):
            func(*args, **kwargs)

    # For each CompositeCompliantTensor t, we check that t and t.elem
    # have consistent metadata. If they don't have consistent metadata,
    # that means the operator did something fishy.
    check = partial(check_metadata_consistency, CCT=cls)
    tree_map(check, args)
    tree_map(check, kwargs)
    tree_map(check, rs)

    return rs
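# Tiny demonstration of the set_-based aliasing repair from Note [Alias Result]:
# set_ points a tensor at another storage, making the two tensors alias:
import torch

src = torch.arange(6.0)
dst = torch.empty(0)
# positional args: source storage, storage_offset, size, stride
dst.set_(src.storage(), 0, (2, 3), (3, 1))
assert dst.data_ptr() == src.data_ptr()  # dst now aliases src's storage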
def meta_tensor(self, t):
    if t not in self.tensor_memo:
        with torch.inference_mode(t.is_inference()):
            if t._is_view():
                # Construct views in two steps: recursively meta-fy their
                # base, and then create the view off that. NB: doing it
                # directly from storage is WRONG because this won't cause
                # version counters to get shared.
                assert t._is_view()
                base = self.meta_tensor(t._base)

                def is_c_of_r(complex_dtype, real_dtype):
                    return is_complex_dtype(complex_dtype) and \
                        corresponding_real_dtype(complex_dtype) == real_dtype

                if base.dtype == t.dtype:
                    pass
                elif is_c_of_r(base.dtype, t.dtype):
                    base = torch.view_as_real(base)
                elif is_c_of_r(t.dtype, base.dtype):
                    base = torch.view_as_complex(base)
                else:
                    # This is not guaranteed to succeed. If it fails, it
                    # means there is another dtype-converting view function
                    # that hasn't been handled here
                    base = base.view(t.dtype)

                with torch.enable_grad():
                    r = base.as_strided(t.size(), t.stride(), t.storage_offset())
            else:
                is_leaf = safe_is_leaf(t)
                # Fake up some autograd history.
                if t.requires_grad:
                    r = torch.empty((0,), dtype=t.dtype, device='meta', requires_grad=True)
                    if not is_leaf:
                        with torch.enable_grad():
                            # The backward function here will be wrong, but
                            # that's OK; our goal is just to get the metadata
                            # looking as close as possible; we're not going to
                            # actually try to backward() on these produced
                            # metas. TODO: would be safer to install some
                            # sort of unsupported grad_fn here
                            r = r.clone()
                else:
                    r = torch.empty((0,), dtype=t.dtype, device='meta')
                # As long as meta storage is not supported, need to prevent
                # redispatching on set_(Storage, ...) which will choke with
                # meta storage
                s = self.meta_storage(t.storage())
                with no_dispatch():
                    with torch.no_grad():
                        r.set_(s, t.storage_offset(), t.size(), t.stride())
        torch._C._set_conj(r, t.is_conj())
        torch._C._set_neg(r, t.is_neg())
        self.tensor_memo[t] = r
    return self.tensor_memo[t]
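# Sketch of the two-step view reconstruction above, using .to("meta") as a
# stand-in for the recursive meta-fication of the base:
import torch

base = torch.arange(12.0)
v = base.view(3, 4)[1:]  # a view of base: shape (2, 4), offset 4
meta_base = base.to("meta")
meta_v = meta_base.as_strided(v.size(), v.stride(), v.storage_offset())
assert meta_v.shape == v.shape
assert meta_v.stride() == v.stride()
assert meta_v.storage_offset() == v.storage_offset()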
def from_tensor(self, tensor):
    with no_dispatch():
        return self.fake_tensor_converter(self, tensor)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs if kwargs else {}

    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        if args[0].fake_mode.in_kernel_invocation:
            return torch.device("meta")
        else:
            return args[0].fake_device

    # prims already wrap FakeTensor inputs to FakeTensor outputs
    # and do device logic; we don't need to do anything but run them
    if "prims::" in func._schema.name:
        with no_dispatch():
            return func(*args, **kwargs)

    with no_dispatch():  # TODO: apply as no_dispatch decorator
        converter = self.fake_tensor_converter

        # this is generated from torch.tensor(), which does not use the
        # dispatcher, to allow wrapper subclasses to wrap the new tensor;
        # we need to handle it before error checking
        if func == torch.ops.aten.lift.default:
            assert (
                len(kwargs) == 0
                and len(args) == 1
                and type(args[0]) is torch.Tensor
            )
            with no_dispatch():
                return converter(self, args[0])

        def wrap(e, device=None):
            if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
                return converter(self, e, device)
            else:
                return e

        # if we are in the dispatch mode, we will enter this function even if
        # the inputs are not FakeTensors. For now, throw if there are any
        # non-Fake Tensor inputs and just support constructors.
        # TODO: extend more broadly
        conversion_made = False

        def check_non_fake_tensor(x):
            nonlocal conversion_made
            conversion_made = conversion_made or (
                isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor)
            )

        tree_map(check_non_fake_tensor, args)
        tree_map(check_non_fake_tensor, kwargs)

        if conversion_made:
            raise Exception(
                "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                f"Please convert all Tensors to FakeTensors first. Found in {func}"
            )

        for run_impl_check, op_impl in op_implementations:
            if run_impl_check(func):
                return op_impl(self, func, *args, **kwargs)

        self.in_kernel_invocation = True
        try:
            r = func(*args, **kwargs)
        except NotImplementedError as not_implemented_error:
            if not self.allow_cpu_fallback:
                raise not_implemented_error
            r = run_cpu_fallback(func, args, kwargs, not_implemented_error)
        finally:
            self.in_kernel_invocation = False

        # TODO: handle non-kwarg devices
        assert func not in _device_not_kwarg_ops, f"NYI: {func}"

        # if device is specified, use that
        if kwargs.get("device", None):
            return tree_map(partial(wrap, device=kwargs["device"]), r)

        common_device = FakeTensor._find_common_device(func, args, kwargs)

        return tree_map(partial(wrap, device=common_device), r)
def torch_dispatch_impl(cls_or_mode_instance, func, types, args, kwargs, run_function):
    kwargs = kwargs if kwargs else {}

    in_fake_mode = isinstance(cls_or_mode_instance, FakeTensorMode)
    converter = (
        cls_or_mode_instance.fake_tensor_converter
        if in_fake_mode
        else FakeTensorConverter()
    )

    # This class virtualizes .device() calls; we need to short-circuit
    # it instead of calling device again, or we would keep on recursing
    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        return args[0].fake_device

    def wrap(e, device=None):
        if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
            return converter(e, device)
        else:
            return e

    # if we are in the dispatch mode, we will enter this function even if
    # the inputs are not FakeTensors. For now, throw if there are any
    # non-Fake Tensor inputs and just support constructors.
    # TODO: extend more broadly
    if isinstance(cls_or_mode_instance, FakeTensorMode):
        conversion_made = False

        def check_non_fake_tensor(x):
            nonlocal conversion_made
            conversion_made = conversion_made or (
                isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor)
            )

        tree_map(check_non_fake_tensor, args)
        tree_map(check_non_fake_tensor, kwargs)

        if conversion_made:
            raise Exception(
                "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                f"Please convert all Tensors to FakeTensors first. Found in {func}"
            )

    # _to_copy fails when run with FakeTensors to cuda device
    # TODO: debug
    if func == torch.ops.aten._to_copy.default:
        _, new_kwargs = normalize_function(
            func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )
        out_device = new_kwargs.pop("device", new_kwargs["input"].device)
        with no_dispatch():
            input = new_kwargs.pop("input").to("meta")
            return FakeTensor(torch.ops.aten._to_copy(input, **new_kwargs), out_device)

    if _is_tensor_constructor(func):
        assert func not in _non_kwarg_device_constructors
        _, new_kwargs = normalize_function(
            func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )
        # cpu is the default device if none is specified
        out_device = new_kwargs.pop("device", torch.device("cpu"))
        new_kwargs["device"] = torch.device("meta")
        r = run_function(func, types, (), new_kwargs)
        return FakeTensor(r, out_device)

    r = run_function(func, types, args, kwargs)

    # TODO: handle non-kwarg devices
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"

    # if device is specified, use that
    if kwargs.get("device", None):
        return tree_map(partial(wrap, device=kwargs["device"]), r)

    # operators which copy size from another tensor do not
    # also take device from the size tensor
    # other size_as operators are not builtin operators
    if func == aten.resize_as_.default:
        _, new_kwargs = normalize_function(
            func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
        )
        # device of the input is returned
        return tree_map(partial(wrap, device=new_kwargs["input"].device), r)

    common_device = FakeTensor._find_common_device(func, args, kwargs)

    return tree_map(partial(wrap, device=common_device), r)
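# Hypothetical illustration of the common-device rule applied at the end:
# tensor arguments must agree on a device, which becomes the output device.
# find_common_device here is illustrative only, not the real helper:
import torch
from torch.utils._pytree import tree_flatten

def find_common_device(args, kwargs):
    devices = {e.device for e in tree_flatten((args, kwargs))[0]
               if isinstance(e, torch.Tensor)}
    assert len(devices) <= 1, f"mismatched devices: {devices}"
    return devices.pop() if devices else torch.device("cpu")

assert find_common_device((torch.ones(2),), {}) == torch.device("cpu")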
def proxy_call(func_overload, args, kwargs=None):
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        t, = args
        assert not kwargs
        if t.constant is not None:
            with maybe_disable_fake_tensor_mode():
                return t.constant.item()
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar. "
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        if isinstance(e, torch._C.SymbolicIntNode):
            if isinstance(e.get_pyobj(), ProxySymInt):
                return e.get_pyobj().sym_int
            else:
                raise RuntimeError(f"Something has gone wrong, we are trying to put SymInt {e.get_pyobj()} into the graph, "
                                   f"even though it's not a ProxySymInt. This is a bug.")
        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    inner_res = func_overload(
        *pytree.tree_map(unwrap_elem, args),
        **pytree.tree_map(unwrap_elem, kwargs)
    )

    # Needed to sync up metadata for in-place operators that modify metadata
    # TODO: instead forward the metadata to the inner tensor so updating
    # is not necessary
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
    # is *statically* known to be a constant (currently, this only happens if
    # you run torch.tensor; deterministic factory functions like torch.arange
    # don't get this treatment). When the tensor in question is small, it's
    # helpful to do constant propagation in case we call item() (in which
    # case we can return the constant value that is known, rather than give
    # an error.) The logic here tests if constant propagation is possible
    # (because all of the inputs are constant). If so, we disable fake tensor
    # mode (if it is on) and do true compute on the constant.
    #
    # It's worth highlighting that we're making a policy decision here.
    # There is a potential that the tensor is actually quite large, and we
    # don't actually want to run the compute. The tensor being quite large
    # is one of the reasons why factory functions don't get this treatment
    # (since they can be quite large; if a parameter is initialized to a
    # constant value it will be!) Similarly, there is also a potential
    # to run an operator that blows up the size of a small tensor; we don't
    # protect against this case, but we could force, e.g., only single
    # element constant computation by testing the numel of the result before
    # propagating const-ness. Similarly, we don't require the constant to
    # live on CPU, but we could.
    all_constant = True
    any_constant = False

    def check_constant(e):
        nonlocal all_constant, any_constant
        if isinstance(e, ProxyTensor):
            if e.constant is None:
                all_constant = False
            else:
                any_constant = True

    pytree.tree_map(check_constant, args)
    pytree.tree_map(check_constant, kwargs)

    def unwrap_constant(e):
        if isinstance(e, ProxyTensor):
            return e.constant
        return e

    constant = None
    # NB: do NOT include factories as constants
    if all_constant and any_constant:
        with maybe_disable_fake_tensor_mode():
            constant = func_overload(
                *pytree.tree_map(unwrap_constant, args),
                **pytree.tree_map(unwrap_constant, kwargs)
            )

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res, constant=constant)
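# Self-contained sketch of the all_constant/any_constant scan above, with a
# plain wrapper class standing in for ProxyTensor (which is assumed
# unavailable here):
import torch.utils._pytree as pytree

class Wrapper:
    def __init__(self, constant=None):
        self.constant = constant

def scan(tree):
    all_constant, any_constant = True, False

    def check(e):
        nonlocal all_constant, any_constant
        if isinstance(e, Wrapper):
            if e.constant is None:
                all_constant = False
            else:
                any_constant = True

    pytree.tree_map(check, tree)
    return all_constant, any_constant

assert scan([Wrapper(1), Wrapper(2)]) == (True, True)
assert scan([Wrapper(1), Wrapper(None)]) == (False, True)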
def __repr__(self):
    with no_dispatch():
        return f"ProxyTensor({self.elem}, proxy={self.proxy})"
def wrap_with_proxy(e, proxy):
    if isinstance(e, torch.Tensor):
        with no_dispatch():
            return ProxyTensor(e, proxy, **kwargs)
    else:
        return e
def wrap_with_proxy(e, proxy):
    # exact type check: subclasses of torch.Tensor are deliberately excluded
    if type(e) == torch.Tensor:
        with no_dispatch():
            return ProxyTensor(e, proxy)
    else:
        return e
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs if kwargs else {}

    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        if args[0].fake_mode.in_kernel_invocation:
            return torch.device("meta")
        else:
            return args[0].fake_device

    flat_arg_tensors = [
        i for i in tree_flatten((args, kwargs))[0] if isinstance(i, FakeTensor)
    ]
    has_symbolic_sizes = any(i.has_sym_ints for i in flat_arg_tensors)
    if has_symbolic_sizes:
        # TODO: Find better approach for this
        # Avoid circular import
        from torch._decomp import decomposition_table
        from torch._meta_registrations import meta_table

        # TODO: hack, doesn't actually work.
        # see https://github.com/pytorch/pytorch/pull/81598#issuecomment-1192030435
        with enable_torch_dispatch_mode(self), torch.overrides.enable_reentrant_dispatch():
            if func in meta_table:
                r = meta_table[func](*args, **kwargs)
                return r
            if func in decomposition_table:
                return decomposition_table[func](*args, **kwargs)

        with no_dispatch():
            if symbolic_shapes.is_symbolic_op(func):
                return symbolic_shapes.handle_symbolic_op(func, args, kwargs)

    # prims already wrap FakeTensor inputs to FakeTensor outputs
    # and do device logic; we don't need to do anything but run them
    if "prims::" in func._schema.name:
        with no_dispatch():
            return func(*args, **kwargs)

    if has_symbolic_sizes:
        constructors = [torch.ops.aten.empty.SymInt]
        if func not in constructors:
            raise RuntimeError(
                f"{func} - couldn't find symbolic meta function/decomposition")

    with no_dispatch():  # TODO: apply as no_dispatch decorator
        converter = self.fake_tensor_converter

        def wrap(e, device=None):
            if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
                return converter(self, e, device)
            else:
                return e

        # if we are in the dispatch mode, we will enter this function even if
        # the inputs are not FakeTensors. For now, throw if there are any
        # non-Fake Tensor inputs and just support constructors.
        # TODO: extend more broadly
        conversion_made = False
        subclass_seen = False

        def check_non_fake_tensor(x):
            nonlocal conversion_made, subclass_seen
            conversion_made = conversion_made or (
                isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor)
            )
            subclass_seen = subclass_seen or (
                isinstance(x, torch.Tensor)
                and not isinstance(x, FakeTensor)
                and type(x) is not torch.Tensor
            )

        tree_map(check_non_fake_tensor, args)
        tree_map(check_non_fake_tensor, kwargs)

        # Suppose we enable fake tensor mode. This means that fake tensor
        # mode will run first. But what if we do an operation that
        # involves a tensor subclass that will desugar into normal tensor
        # operations? Without this line, fake tensor mode will run first,
        # decide that a conversion was made (since there was a non fake
        # tensor argument), and report an error that converting non
        # fake tensor is not supported. What we actually wanted to happen
        # was to give the subclass a chance to figure out what it wants to
        # do before erroring out. Returning NotImplemented here allows this.
        #
        # NB: If you're seeing a mysterious infinite loop involving fake
        # tensor, it might be related to this line. Though I'm not sure
        # how you'll know to read this comment, as this line won't show up
        # in the stack trace.
        if subclass_seen:
            return NotImplemented

        # this is generated from torch.tensor(), which does not use the
        # dispatcher, to allow wrapper subclasses to wrap the new tensor;
        # we need to handle it before error checking
        if func in [
            torch.ops.aten.lift_fresh.default,
            torch.ops.aten.lift_fresh_copy.default,
        ]:
            assert (
                len(kwargs) == 0
                and len(args) == 1
                and type(args[0]) is torch.Tensor
            ), f"{args} {kwargs}"
            with no_dispatch():
                return converter(self, args[0])

        if conversion_made:
            raise Exception(
                "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})"
            )

        for run_impl_check, op_impl in op_implementations:
            if run_impl_check(func):
                return op_impl(self, func, *args, **kwargs)

        with in_kernel_invocation_manager(self):
            try:
                r = func(*args, **kwargs)
            except NotImplementedError as not_implemented_error:
                if not self.allow_fallback_kernels:
                    raise not_implemented_error
                r = run_fallback_kernel(func, args, kwargs, not_implemented_error)

        # TODO: handle non-kwarg devices
        assert func not in _device_not_kwarg_ops, f"NYI: {func}"

        # if device is specified, use that
        if kwargs.get("device", None):
            return tree_map(partial(wrap, device=kwargs["device"]), r)

        common_device = FakeTensor._find_common_device(func, args, kwargs)

        return tree_map(partial(wrap, device=common_device), r)
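# Minimal, runnable sketch of the mode-based dispatch shape the method above
# implements (TorchDispatchMode is the public scaffolding; the logging body
# here is illustrative only):
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatch: {func}")
        # the mode is disabled while we are inside __torch_dispatch__,
        # so this call runs the real kernel rather than recursing
        return func(*args, **kwargs)

with LoggingMode():
    torch.ones(2) + torch.ones(2)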
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    with no_dispatch():
        return cls._torch_dispatch(func, types, args, kwargs)
def clone(fake_mode, func, input, memory_format=None):
    out_device = input.device
    with no_dispatch():
        out = torch.ops.aten._to_copy(input.to("meta"), memory_format=memory_format)
        return FakeTensor(fake_mode, out, out_device)