def packed_quantized_conv2d_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    """
    Mapping from quantized Conv2d module to acc_op.conv.

    We unpack all the parameters in this mapper and pass them directly to the
    conv2d node.

    Args:
        node: the node whose ``target`` (a dotted module path) names the
            quantized conv module inside ``mod``.
        mod: the root module. The unpacked weight (and bias, when present)
            are registered on it as buffers named ``<prefix>_weight`` /
            ``<prefix>_bias``, where ``<prefix>`` is the target with dots
            replaced by underscores.

    Returns:
        The newly inserted ``quantized_conv2d`` call_function node, which
        shares ``node.meta``. This function does not replace the uses of
        ``node``; that is left to the caller.
    """
    assert isinstance(node.target, str)
    conv_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # weight()/bias() are method calls (presumably unpacking the packed
    # params); call each once and reuse the tensors for both the buffers and
    # the tensor metadata.
    weight = conv_module.weight()
    bias = conv_module.bias()

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, weight)
    if bias is not None:
        mod.register_buffer(bias_name, bias)

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(weight)

        get_bias = None
        if bias is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(bias)

        # Create kwargs for acc_op.conv
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "stride": conv_module.stride,
            "padding": conv_module.padding,
            "dilation": conv_module.dilation,
            "groups": conv_module.groups,
            "padding_mode": conv_module.padding_mode,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=conv_module.scale,
                q_zero_point=conv_module.zero_point,
            ),
        }

        # Created inside the context so the new node lands before `node`.
        new_node = node.graph.call_function(quantized_conv2d, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
def proxy_call(func_overload, args, kwargs=None):
    """
    Record ``func_overload`` in the FX graph and run it on the real tensors.

    If ``func_overload`` has an entry in ``CURRENT_DECOMPOSITION_TABLE``, that
    decomposition's result is returned instead. Reading a value out of a
    tracing tensor (``aten._local_scalar_dense``) is rejected.

    Otherwise ProxyTensor arguments are replaced by their ``.proxy`` and the
    overload packet is called on them to record the op. For ops whose name
    ends (but does not start) with ``_`` (treated as in-place), ``args[0]``
    is re-pointed at the new proxy. The overload is then executed under
    ``no_dispatch()``.

    Args:
        func_overload: the op overload being dispatched (``.overloadpacket``
            is read from it).
        args: positional arguments, possibly containing ProxyTensor values.
        kwargs: keyword arguments; ``None`` is treated as an empty dict.

    Returns:
        ``wrap_output(real_out, proxy_out)``.

    Raises:
        RuntimeError: if ``func_overload`` is ``aten._local_scalar_dense.default``.
    """
    # The default of None must be normalized before any `**kwargs` expansion.
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar. "
            "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check"
        )

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
def _test_const_fold_tensor_meta(self, requires_grad):
    """
    Verify tensor_meta is handled correctly.

    Traces a small module with two parameters, runs ShapeProp, splits the
    constant subgraphs and runs folding. It then checks that every
    ``get_attr`` node's recorded ``tensor_meta`` matches the metadata of the
    attribute it points at. Finally it checks that the folded and unfolded
    modules produce equal outputs.

    Args:
        requires_grad: forwarded to the ``torch.nn.Parameter`` constructors.
    """

    class ConstFoldTestModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
            self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

        def forward(self, x, y):
            a = self.attr_1 + self.attr_1
            x = x - a
            return x * y + self.attr_2

    mod = ConstFoldTestModule()
    gm = torch.fx.symbolic_trace(mod)
    in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
    ShapeProp(gm).propagate(in_x, in_y)
    mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
    self._verify_const_fold_mod(mod_folded)

    mod_folded.run_folding()

    for n in mod_folded.graph.nodes:
        if n.op == "get_attr":
            attr = self._get_attr(n)
            # assertEqual, not the deprecated assertEquals alias (removed in Python 3.12).
            self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

    # Now run both folded and non-folded to check results equal.
    base_result = mod(in_x, in_y)
    fold_result = mod_folded(in_x, in_y)
    self.assertTrue(torch.equal(fold_result, base_result))
def __init__(self, elem, proxy, *, requires_grad=None):
    """
    Store ``elem`` and ``proxy`` on this wrapper and record tensor metadata.

    The proxy's node receives a ``'tensor_meta'`` entry. For a sparse
    ``elem`` that entry is an empty dict; otherwise it is the metadata
    extracted from ``self``. ``requires_grad`` is accepted but unused here.
    """
    meta = {} if elem.is_sparse else _extract_tensor_metadata(self)
    proxy.node.meta['tensor_meta'] = meta
    self.elem, self.proxy = elem, proxy
def packed_quantized_linear_mapper(node: torch.fx.Node, mod: nn.Module) -> torch.fx.Node:
    """
    Mapping from quantized_linear module to acc_op.linear.

    We unpack weight and bias in this mapper and pass them directly to the
    linear node.

    Args:
        node: the node whose ``target`` (a dotted module path) names the
            quantized linear module inside ``mod``.
        mod: the root module. The unpacked weight (and bias, when present)
            are registered on it as buffers named ``<prefix>_weight`` /
            ``<prefix>_bias``, where ``<prefix>`` is the target with dots
            replaced by underscores.

    Returns:
        The newly inserted ``quantized_linear`` call_function node, which
        shares ``node.meta``. This function does not replace the uses of
        ``node``; that is left to the caller.
    """
    # Same guard as the conv2d mapper: the target must be a module path.
    assert isinstance(node.target, str)
    linear_module = dict(mod.named_modules())[node.target]
    prefix = node.target.replace(".", "_")
    weight_name = f"{prefix}_weight"
    bias_name = f"{prefix}_bias"

    # weight()/bias() are method calls (presumably unpacking the packed
    # params); call each once and reuse the tensors for both the buffers and
    # the tensor metadata.
    weight = linear_module.weight()
    bias = linear_module.bias()

    # Store weight and bias in the main module
    mod.register_buffer(weight_name, weight)
    if bias is not None:
        mod.register_buffer(bias_name, bias)

    with node.graph.inserting_before(node):
        # Insert get_attr nodes for weight and bias
        get_weight = node.graph.get_attr(weight_name)
        get_weight.meta["tensor_meta"] = _extract_tensor_metadata(weight)

        get_bias = None
        if bias is not None:
            get_bias = node.graph.get_attr(bias_name)
            get_bias.meta["tensor_meta"] = _extract_tensor_metadata(bias)

        # Create kwargs for acc_op.quantized_linear
        kwargs = {
            "input": node.kwargs["input"],
            "weight": get_weight,
            "bias": get_bias,
            "acc_out_ty": acc_utils.build_raw_tensor_meta(
                q_scale=linear_module.scale,
                q_zero_point=linear_module.zero_point,
            ),
        }

        # Created inside the context so the new node lands before `node`.
        new_node = node.graph.call_function(quantized_linear, kwargs=kwargs)
        new_node.meta = node.meta
        return new_node
def __init__(self, elem, proxy, *, requires_grad=None):
    """
    Store ``elem`` and ``proxy`` on this ProxyTensor and record tensor metadata.

    The proxy's node receives a ``'tensor_meta'`` entry. For a sparse
    ``elem`` that entry is an empty dict; otherwise it is the metadata
    extracted from ``self``. ``requires_grad`` is accepted but unused here.
    """
    meta = {} if elem.is_sparse else _extract_tensor_metadata(self)
    proxy.node.meta['tensor_meta'] = meta
    # Wrapping a ProxyTensor inside another ProxyTensor that belongs to the
    # same trace is a layering violation; catch that situation here.
    assert not isinstance(elem, ProxyTensor) or elem.proxy.tracer is not proxy.tracer
    self.elem, self.proxy = elem, proxy
def __init__(self, elem, proxy, *, requires_grad=None, constant=None):
    """
    Store ``elem``, ``proxy`` and ``constant`` on this ProxyTensor.

    The proxy's node receives a ``'tensor_meta'`` entry. It is an empty dict
    when ``elem`` is sparse or ``self.has_sym_ints`` is truthy; otherwise it
    is the metadata extracted from ``self``. ``requires_grad`` is accepted
    but unused here.
    """
    # TODO: hack since _extract_tensor_metadata currently tries to access stride
    skip_metadata = elem.is_sparse or self.has_sym_ints
    proxy.node.meta['tensor_meta'] = {} if skip_metadata else _extract_tensor_metadata(self)
    # Wrapping a ProxyTensor inside another ProxyTensor that belongs to the
    # same trace is a layering violation; catch that situation here.
    assert not isinstance(elem, ProxyTensor) or elem.proxy.tracer is not proxy.tracer
    self.elem, self.proxy, self.constant = elem, proxy, constant
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
    """
    Record ``func_overload`` in the FX graph and run it on the real tensors.

    Decompositions in ``CURRENT_DECOMPOSITION_TABLE`` take precedence. In all
    other cases, PythonTensor arguments are replaced by their ``.proxy`` to
    record the op, the overload is executed under ``no_dispatch()``, and every
    exact ``torch.Tensor`` output is re-wrapped as a PythonTensor paired with
    the matching proxy output.

    Args:
        cls: the class this dispatch handler belongs to (unused here).
        func_overload: the op overload being dispatched.
        types: unused here.
        args: positional arguments, possibly containing PythonTensor values.
        kwargs: keyword arguments; ``None`` is treated as an empty dict.

    Returns:
        The real output, with tensors (and ``None`` entries, see below)
        wrapped as PythonTensor. Tuples and lists are wrapped element-wise.
    """
    # The default of None must be normalized before any `**kwargs` expansion.
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    # Commenting this out for now since it causes some spurious failures (such as error checking)
    # if func == aten._local_scalar_dense:
    #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
    #                        "It's likely that this is caused by data-dependent control flow or similar.")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, PythonTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    def wrap_with_proxy(e, proxy):
        # Some ops (like native_batch_norm_backward) return undefined tensors that get
        # converted into None in python.
        # As the function signature expects tensors, if we directly return these None
        # tensors back to C++, we'll error.
        if e is None:
            e = torch.empty(())
        # Exact type check on purpose: Tensor subclasses are returned unwrapped.
        if type(e) is torch.Tensor:
            return PythonTensor(e, proxy)
        return e

    if isinstance(real_out, tuple):
        return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
    elif isinstance(real_out, list):
        return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
    elif isinstance(real_out, torch.Tensor):
        return wrap_with_proxy(real_out, proxy_out)
    else:
        return real_out
def __new__(cls, elem, proxy, *, requires_grad=None):
    """
    Create a ``cls`` tensor wrapping ``elem`` and attach ``proxy`` to it.

    When ``elem`` is sparse or ``requires_grad`` is given, the instance is
    built via ``torch.Tensor._make_subclass``; otherwise via
    ``super().__new__``. The proxy's node receives a ``'tensor_meta'``
    entry: an empty dict for sparse ``elem``, else the metadata of the new
    instance.
    """
    # Hack to deal with super().__new__ not working for sparse tensors
    needs_make_subclass = elem.is_sparse or requires_grad is not None
    if needs_make_subclass:
        wrapped = torch.Tensor._make_subclass(cls, elem, requires_grad)
    else:
        wrapped = super().__new__(cls, elem)  # type: ignore[call-arg]

    proxy.node.meta['tensor_meta'] = {} if elem.is_sparse else _extract_tensor_metadata(wrapped)
    wrapped.proxy = proxy  # type: ignore[attr-defined]
    return wrapped
def __new__(cls, elem, proxy):
    """
    Create a ``cls`` tensor wrapping ``elem`` and attach ``proxy`` to it.

    The instance is built with ``torch.Tensor._make_subclass`` and keeps
    ``elem.requires_grad``. The proxy's node receives a ``'tensor_meta'``
    entry: an empty dict for sparse ``elem``, else the metadata of the new
    instance.
    """
    # Wrapping something in PythonTensor implicitly detaches
    # gradients.  If something required grad, we will collect it as if it
    # were a leaf.  A consequence of detaching in this way is you
    # need to maintain a parameter cache when translating tensors
    # into PythonTensor, so you don't create multiple copies of
    # a gradient (they are aliased, but they would count as independent
    # leaves).  An alternate strategy would be to avoid implicitly
    # detaching and instead "catch" gradients as they exit the
    # PythonTensor boundary.
    # assert not elem.requires_grad or not torch.is_grad_enabled()
    wrapped = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
    wrapped.proxy = proxy
    proxy.node.meta['tensor_meta'] = {} if elem.is_sparse else _extract_tensor_metadata(wrapped)
    return wrapped
def proxy_call(func_overload, args, kwargs=None):
    """
    Record ``func_overload`` in the FX graph and run it on the inner tensors.

    Decompositions in ``CURRENT_DECOMPOSITION_TABLE`` take precedence. Reading
    a value out of a tracing tensor (``aten._local_scalar_dense``) is
    rejected.

    Otherwise the op is called once with ProxyTensor arguments replaced by
    their ``.proxy`` (recording it in the graph) and once with them replaced
    by their ``.elem`` (computing the real result). Ops tagged
    ``inplace_view`` are additionally re-run on the original arguments under
    ``no_dispatch()`` so that the wrappers' metadata stays in sync.

    Args:
        func_overload: the op overload being dispatched.
        args: positional arguments, possibly containing ProxyTensor values.
        kwargs: keyword arguments; ``None`` is treated as an empty dict.

    Returns:
        ``wrap_output(inner_res, proxy_res)``.

    Raises:
        RuntimeError: if ``func_overload`` is ``aten._local_scalar_dense.default``.
    """
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar. "
            "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check"
        )

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        return e.elem if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)
    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))

    # Needed to sync up metadata for in-place operators that modify metadata
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res)
def proxy_call(func_overload, args, kwargs=None):
    """
    Record ``func_overload`` in the FX graph, run it on the inner values, and
    propagate statically-known constants.

    - Decompositions in ``CURRENT_DECOMPOSITION_TABLE`` take precedence.
    - ``aten._local_scalar_dense`` takes a single argument. If that argument
      carries a ``constant``, ``constant.item()`` is returned, computed with
      fake tensor mode disabled; otherwise a RuntimeError is raised.
    - Otherwise the op is recorded on the proxies (``proxy_res``) and run on
      the unwrapped values (``inner_res``). When unwrapping, a
      ``SymbolicIntNode`` whose py-object is a ``ProxySymInt`` is replaced by
      its ``sym_int``; any other ``SymbolicIntNode`` is an error.
    - When every ProxyTensor argument has a ``constant`` and at least one
      does, the op is also run on those constants (fake tensor mode
      disabled), and the result is passed on as ``constant``.

    Args:
        func_overload: the op overload being dispatched.
        args: positional arguments, possibly containing ProxyTensor values.
        kwargs: keyword arguments; ``None`` is treated as an empty dict.

    Returns:
        ``wrap_output(inner_res, proxy_res, constant=constant)``, or a Python
        scalar for the constant ``_local_scalar_dense`` case.

    Raises:
        RuntimeError: for a non-constant ``_local_scalar_dense`` call, or when
            a non-ProxySymInt SymInt would be put into the graph.
    """
    if kwargs is None:
        kwargs = {}
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        t, = args
        assert not kwargs
        if t.constant is not None:
            with maybe_disable_fake_tensor_mode():
                return t.constant.item()
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar. "
            "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check"
        )

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    def unwrap_elem(e):
        if isinstance(e, ProxyTensor):
            return e.elem
        if isinstance(e, torch._C.SymbolicIntNode):
            if isinstance(e.get_pyobj(), ProxySymInt):
                return e.get_pyobj().sym_int
            else:
                raise RuntimeError(
                    f"Something has gone wrong, we are trying to put SymInt {e.get_pyobj()} into the graph, "
                    "even though it's not a ProxySymInt. This is a bug."
                )
        return e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_res = func_overload(*proxy_args, **proxy_kwargs)
    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_res
        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
    inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))

    # Needed to sync up metadata for in-place operators that modify metadata
    # TODO: instead forward the metadata to the inner tensor so updating
    # is not necessary
    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
        with no_dispatch():
            func_overload(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
    # is *statically* known to be a constant (currently, this only happens if
    # you run torch.tensor; deterministic factory functions like torch.arange
    # don't get this treatment).  When the tensor in question is small, it's
    # helpful to do constant propagation in case we call item() (in which
    # case we can return the constant value that is known, rather than give
    # an error.)  The logic here tests if constant propagation is possible
    # (because all of the inputs are constant).  If so, we disable fake tensor
    # mode (if it is on) and do true compute on the constant.
    #
    # It's worth highlighting that we're making a policy decision here.
    # There is a potential that the tensor is actually quite large, and we
    # don't actually want to run the compute.  The tensor being quite large
    # is one of the reasons why factory functions don't get this treatment
    # (since they can be quite large; if a parameter is initialized to a
    # constant value it will be!)  Similarly, there is also a potential
    # to run an operator that blows up the size of a small tensor; we don't
    # protect against this case, but we could force, e.g., only single
    # element constant computation by testing the numel of the result before
    # propagating const-ness.  Similarly, we don't require the constant to
    # live on CPU, but we could.
    all_constant = True
    any_constant = False

    def check_constant(e):
        nonlocal all_constant, any_constant
        if isinstance(e, ProxyTensor):
            if e.constant is None:
                all_constant = False
            else:
                any_constant = True

    # tree_map is used only for its side effect of visiting every leaf.
    pytree.tree_map(check_constant, args)
    pytree.tree_map(check_constant, kwargs)

    def unwrap_constant(e):
        if isinstance(e, ProxyTensor):
            return e.constant
        return e

    constant = None
    # NB: do NOT include factories as constants
    if all_constant and any_constant:
        with maybe_disable_fake_tensor_mode():
            constant = func_overload(
                *pytree.tree_map(unwrap_constant, args),
                **pytree.tree_map(unwrap_constant, kwargs)
            )

    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
    return wrap_output(inner_res, proxy_res, constant=constant)
def proxy_call(proxy_mode, func_overload, args, kwargs=None):
    """
    Handle ``func_overload`` for ``proxy_mode``: record it in the FX graph and run it.

    Resolution order:

    1. If ``func_overload`` is in ``CURRENT_DECOMPOSITION_TABLE``, run that
       decomposition under ``proxy_mode.restore()``. Its result is returned
       unless it is ``NotImplemented``.
    2. Otherwise, unless the op is ``torch.ops.aten.size.default``, try
       ``func_overload.decompose`` the same way.
    3. For ops tagged ``data_dependent_output``: if all inputs are constant,
       run the op on the constants with fake tensor mode disabled and return
       that result. Otherwise raise RuntimeError.
    4. Otherwise, emit the op on the proxies (``proxy_out``) and run it on
       ``args``/``kwargs`` to get ``out``. A constant result may also be
       computed (see the comment near the end). ``out`` is then linked to
       ``proxy_out`` via ``track_tensor_tree``, and ``out`` is returned.

    "All inputs constant" means that every tracked tensor input has a
    non-None ``constant`` and no ``SymInt`` appears in ``args``/``kwargs``.

    Args:
        proxy_mode: the active proxy mode; its ``restore()`` and ``tracer``
            are used.
        func_overload: the op overload being dispatched.
        args: positional arguments of the op.
        kwargs: keyword arguments; ``None`` is treated as an empty dict.

    Raises:
        RuntimeError: for a ``data_dependent_output`` op whose inputs are not
            all constant.
    """
    if kwargs is None:
        kwargs = {}

    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        with proxy_mode.restore():
            r = CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
            if r is not NotImplemented:
                return r

    # Some of these are not "real" aten ops and will fail if we
    # call _dispatch_has_kernel_for_dispatch_key on them.
    # This list is probably incomplete
    if func_overload not in [torch.ops.aten.size.default]:
        with proxy_mode.restore():
            r = func_overload.decompose(*args, **kwargs)
            if r is not NotImplemented:
                return r

    tracer = proxy_mode.tracer
    # Map every tensor in the arguments to its tracked _ProxyTensor record.
    f_args, f_kwargs = pytree.tree_map_only(torch.Tensor, fetch_tensor_proxy(tracer), (args, kwargs))

    # If there are SymInts, we also should not consider this constant.
    # However, fake tensor handling of SymInts is sufficiently broken that
    # I couldn't write a test for this case
    all_constant = (
        pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
        # TODO: maybe constant SymInts should also be allowed?  Not sure if
        # this can happen
        and pytree.tree_all_only(SymInt, lambda _: False, (args, kwargs))
    )

    if torch.Tag.data_dependent_output in func_overload.tags:  # type: ignore[attr-defined]
        # Check if all of the Tensor inputs are constants
        if all_constant:
            const_args, const_kwargs = pytree.tree_map_only(
                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
            )
            with maybe_disable_fake_tensor_mode():
                return func_overload(*const_args, **const_kwargs)
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar."
        )
    # Replace tensors by their proxies, then SymInts by their proxies.
    proxy_args, proxy_kwargs = pytree.tree_map_only(
        SymInt,
        fetch_symint_proxy(proxy_mode.tracer),
        pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
    )
    proxy_out = func_overload(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        # This makes DCE marginally less likely to DCE inplace operations.
        # It is not strictly necessary
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    out = func_overload(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
    # is *statically* known to be a constant (currently, this only happens if
    # you run torch.tensor; deterministic factory functions like torch.arange
    # don't get this treatment).  When the tensor in question is small, it's
    # helpful to do constant propagation in case we call item() (in which
    # case we can return the constant value that is known, rather than give
    # an error.)  The logic here tests if constant propagation is possible
    # (because all of the inputs are constant).  If so, we disable fake tensor
    # mode (if it is on) and do true compute on the constant.
    #
    # It's worth highlighting that we're making a policy decision here.
    # There is a potential that the tensor is actually quite large, and we
    # don't actually want to run the compute.  The tensor being quite large
    # is one of the reasons why factory functions don't get this treatment
    # (since they can be quite large; if a parameter is initialized to a
    # constant value it will be!)  Similarly, there is also a potential
    # to run an operator that blows up the size of a small tensor; we don't
    # protect against this case, but we could force, e.g., only single
    # element constant computation by testing the numel of the result before
    # propagating const-ness.  Similarly, we don't require the constant to
    # live on CPU, but we could.
    any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))

    constant = None
    # NB: do NOT include factories as constants
    # Constant propagation also requires a non-random op and that every
    # output tensor has at most CONSTANT_NUMEL_LIMIT elements.
    if (torch.Tag.nondeterministic_seeded not in func_overload.tags  # type: ignore[attr-defined]
            and all_constant
            and any_constant
            and pytree.tree_all_only(torch.Tensor, lambda t: t.numel() <= CONSTANT_NUMEL_LIMIT, out)):
        with maybe_disable_fake_tensor_mode():
            const_args, const_kwargs = pytree.tree_map_only(
                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
            )
            constant = func_overload(*const_args, **const_kwargs)

    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    return out
def wrap_with_proxy(e, proxy, constant):
    """
    Track tensor ``e`` with ``tracer`` and record its metadata on ``proxy``.

    Values that are not ``torch.Tensor`` are ignored. Sparse tensors are
    tracked but receive no ``'tensor_meta'`` entry. Always returns None.
    """
    if not isinstance(e, torch.Tensor):
        return
    track_tensor(e, proxy, tracer=tracer, constant=constant)
    if e.is_sparse:
        return
    proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(e)