def proxy_call(func_overload, args, kwargs=None):
    """Record ``func_overload`` on the argument proxies and run it for real.

    Args:
        func_overload: the operator overload being dispatched; its
            ``overloadpacket`` is what gets called on the proxies.
        args: positional arguments, possibly containing ``ProxyTensor``s.
        kwargs: keyword arguments, or ``None`` for "no keyword arguments".

    Returns:
        The decomposition's result when ``func_overload`` has an entry in
        ``CURRENT_DECOMPOSITION_TABLE``; otherwise
        ``wrap_output(real_out, proxy_out)``.

    Raises:
        RuntimeError: for ``aten._local_scalar_dense.default``.
    """
    # The signature allows ``kwargs=None`` but every call below splats it
    # with ``**``, which fails on ``None``.
    if kwargs is None:
        kwargs = {}

    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError(
            "It appears that you're trying to get value out of a tracing tensor - erroring out! "
            "It's likely that this is caused by data-dependent control flow or similar. "
            "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check"
        )

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not: a trailing
    # underscore without a leading one.
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        # Rebind the first argument's proxy to the op's output proxy and
        # record that argument's tensor metadata on the new node.
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Route ``func`` to a registered handler, or fall back to dense tensors.

    Returns ``NotImplemented`` when ``cls`` is not a subclass of every entry
    in ``types``. Otherwise a truthy entry of ``cls.handled_ops`` keyed by
    ``func.__name__`` is called with the original arguments. Without such an
    entry, ``DiagTensorBelow`` inputs are expanded with ``e.diag.diag()``,
    ``func`` is re-run, and tensor outputs that are 1-D, or 2-D with all
    non-zeros on the diagonal, are wrapped back into ``DiagTensorBelow``.
    """
    for t in types:
        if not issubclass(cls, t):
            return NotImplemented

    # For everything else, call the handler:
    handler = cls.handled_ops.get(func.__name__, None)
    if handler:
        return handler(*args, **kwargs or {})

    # Note that here, because we don't need to provide the autograd formulas
    # we can have a default "fallback" that creates a plain Tensor based
    # on the diag elements and calls the func again.
    def to_dense(e):
        if isinstance(e, DiagTensorBelow):
            return e.diag.diag()
        return e

    def to_diag(e):
        if not isinstance(e, torch.Tensor):
            return e
        if e.ndim == 1:
            return DiagTensorBelow(e)
        if e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
            return DiagTensorBelow(e.diag())
        return e

    dense_args = tree_map(to_dense, args)
    dense_kwargs = tree_map(to_dense, kwargs or {})
    return tree_map(to_diag, func(*dense_args, **dense_kwargs))
def decomposition_decorator(f): nonlocal registry if registry is None: registry = decomposition_table def add_op_to_table(aten_op): overloads = [] if isinstance(aten_op, torch._ops.OpOverload): overloads.append(aten_op) else: assert isinstance(aten_op, torch._ops.OpOverloadPacket) for ol in aten_op.overloads(): overloads.append(getattr(aten_op, ol)) for op_overload in overloads: if op_overload in registry: raise RuntimeError(f"duplicate registrations for {op_overload}") registry[op_overload] = f # TODO: factor this logic into OpOverload or Library API name = op_overload._schema.name if op_overload._schema.overload_name: name += "." + op_overload._schema.overload_name if ( not disable_meta # TorchScript dumps a bunch of extra nonsense overloads # which don't have corresponding dispatcher entries, we need # to filter those out and torch._C._dispatch_has_kernel(name) and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta') ): meta_lib.impl(op_overload, f) # To handle allowing multiple aten_ops at once tree_map(add_op_to_table, aten_op) return f
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Dispatch ``func`` through ``_PARTIAL_TENSOR_OPS`` or reject it.

    The process group is taken from the ``_PartialTensor`` leaves of ``args``
    and then ``kwargs``: the first non-``None`` ``_process_group`` wins.

    Raises:
        RuntimeError: when ``func`` has no entry in ``_PARTIAL_TENSOR_OPS``.
    """
    # Find process_group
    process_group = None

    def record_process_group(leaf):
        nonlocal process_group
        if process_group is not None:
            return
        if isinstance(leaf, _PartialTensor):
            process_group = leaf._process_group

    for tree in (args, kwargs):
        tree_map(record_process_group, tree)

    if func in _PARTIAL_TENSOR_OPS:
        handler = _PARTIAL_TENSOR_OPS[func]
        return handler(types, args, kwargs, process_group)

    # Need to disable all dispatch to print args and kwargs appropriately.
    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        with torch._C.DisableTorchFunction():
            message = (
                f"torch function '{func.__name__}', with args: {args} and "
                f"kwargs: {kwargs} not supported for PartialTensor!"
            )
            raise RuntimeError(message)
    finally:
        del guard
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on unwrapped tensors, rewriting ``add`` into ``add_``.

    ``InplaceLoggingTensor`` arguments are replaced by their ``elem``; tensor
    outputs are wrapped back into ``InplaceLoggingTensor``. The op is executed
    inside ``cls.context()`` and the *original* ``func`` is what gets logged.
    ``kwargs=None`` is treated as "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, InplaceLoggingTensor) else e

    def wrap(e):
        return InplaceLoggingTensor(e) if isinstance(e, torch.Tensor) else e

    f = func
    # this subclass converts all `add()` ops into `add_()` ops
    if f is torch.ops.aten.add.Tensor:
        f = torch.ops.aten.add_.Tensor

    with cls.context():
        rs = tree_map(wrap, f(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
    # after running the (potentially transformed) op,
    # log the original op that we saw.
    logging.getLogger("LoggingTensor").info(
        f"{func.__module__}.{func.__name__}", args, kwargs, rs)
    return rs
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
    """Decide whether ``node`` is supported, based on op kind and devices.

    Rejects nodes whose ``op`` is not in ``CALLABLE_NODE_OPS`` and
    ``aten.embedding_dense_backward.default``; always accepts
    ``operator.getitem``. Any other node is accepted only if no tensor in the
    ``'fake_result'`` metadata of its inputs or of the node itself lives on a
    non-``cuda`` device.
    """
    rejected_targets = [torch.ops.aten.embedding_dense_backward.default]
    if node.op not in CALLABLE_NODE_OPS or node.target in rejected_targets:
        return False

    if node.target in [operator.getitem]:
        return True

    non_cuda_hits = []

    def note_non_cuda(t):
        if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
            non_cuda_hits.append(True)

    # Inputs are inspected first, then the node's own result.
    for fx_node in (*node.all_input_nodes, node):
        tree_map(note_non_cuda, fx_node.meta['fake_result'])

    # NB: factory function is accounted for because the result would be
    # cpu or cuda
    return not non_cuda_hits
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):  # type: ignore
    """Intercepts all operations performed on this handle object.

    Before the operation runs, every ``SsdTensorHandle`` in ``args`` and
    ``kwargs`` is replaced by the tensor from ``to_tensor()``. Afterwards a
    handle is written back with ``to_file()`` when it was the first positional
    argument of an op whose name ends in ``_``, or when it was passed as the
    ``out`` keyword argument.

    The tensor's ``_version`` is recorded next to each handle, but this method
    does not compare it against anything.

    ``kwargs=None`` is treated as "no keyword arguments".
    """
    # ``**`` and the ``"out"`` lookup below both need a real mapping.
    if kwargs is None:
        kwargs = {}

    ssd_tensor_handles = []

    def unwrap(e: Any) -> torch.Tensor:
        if isinstance(e, SsdTensorHandle):
            t = e.to_tensor()
            ssd_tensor_handles.append((e, t._version))  # type: ignore
            return t
        else:
            return e

    r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

    is_inplace_op = func.__name__.endswith("_")
    # Handles can also arrive through kwargs only, so ``args`` may be empty.
    first_arg = args[0] if args else None
    out_arg = kwargs.get("out")

    for e, _saved_version in ssd_tensor_handles:
        inplace_is_this_tensor = is_inplace_op and e is first_arg
        out_is_this_tensor = e is out_arg
        if inplace_is_this_tensor or out_is_this_tensor:
            e.to_file()
    return r
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    """Re-run ``func`` on zero-filled CPU stand-ins for the fake inputs.

    Every ``FakeTensor`` in ``args``/``kwargs`` is replaced by
    ``torch.zeros_like(e, device="cpu")`` and ``func`` is executed under
    ``no_dispatch()``. The outputs are mapped through ``MetaConverter()``.

    Raises:
        ``orig_not_implemented_exception``: if the CPU run raises (chained
        from the new exception), or if an output is one of the inputs or
        shares storage with one.
    """
    with no_dispatch():

        def to_cpu(e):
            return torch.zeros_like(e, device="cpu") if isinstance(e, FakeTensor) else e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            raise orig_not_implemented_exception from new_exception

        tensor_impls = set()
        storages = set()

        flat_inputs, _ = tree_flatten((args, kwargs))
        for inp in flat_inputs:
            if not isinstance(inp, torch.Tensor):
                continue
            tensor_impls.add(inp)
            storages.add(inp.storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to cpu, error on reused impls
        flat_outputs, _ = tree_flatten(r)
        for out in flat_outputs:
            if out in tensor_impls:
                raise orig_not_implemented_exception
            if isinstance(out, torch.Tensor) and out.storage()._cdata in storages:
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
def _torch_dispatch(cls, func, types, args=(), kwargs=None):
    """Run ``func`` and, when a decomposition exists, cross-check its output.

    Restores ``self.precision``/``self.rel_tol`` from the saved values and
    records ``func`` in the ``called``/``all_called`` bookkeeping. Ops that
    are not in ``decomposition_table``, ``aten.detach.default``, and calls
    rejected by ``any_unsupported`` are simply executed. Everything else is
    executed both natively and through its decomposition and the flattened
    outputs are compared.

    ``kwargs=None`` is treated as "no keyword arguments".

    Returns:
        The (unflattened) result of the native ``func`` call.
    """
    # The default ``None`` cannot be splatted with ``**`` below.
    if kwargs is None:
        kwargs = {}

    self.precision = saved_precision
    self.rel_tol = saved_rel_tol

    called.add(func)
    all_called[func] += 1

    # Stuff we shouldn't bother testing
    # (TODO: remove detach from the decomp table?)
    if func not in decomposition_table or func in [
        torch.ops.aten.detach.default
    ] or any_unsupported(args, kwargs):
        return func(*args, **kwargs)

    decomposed.add(func)
    all_decomposed.add(func)

    # We take 2 main strategies for verifying correctness/numerical stability of decompositions
    # The first one is simply tolerance checking between decomp_out and pytorch_out
    # However, for fp16/bf16 and reductions, this becomes very
    # finicky, as there are not many guarantees we can make.
    # So, for fp16/bf16, we instead compare the difference of
    # {decomp_out, pytorch_out_64} and {pytorch_out,
    # pytorch_out_64}. In other words, we compare how far the
    # decomposition and pytorch are from the "ground truth" (i.e.
    # fp64). If the decomposition results in more error, we error

    decomposition = decomposition_table[func]

    do_relative_check = test_dtype in [torch.float16, torch.bfloat16]
    real_out_unflat = func(*args, **kwargs)
    real_out, _ = tree_flatten(real_out_unflat)
    decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
    assert len(real_out) == len(decomp_out)

    if do_relative_check:
        upcast = partial(upcast_tensor, dtype=torch.float64)
        real_out_double, _ = tree_flatten(
            func(*tree_map(upcast, args), **tree_map(upcast, kwargs)))
        for i, (orig, decomp, ref) in enumerate(
                zip(real_out, decomp_out, real_out_double)):
            if orig is None:
                assert decomp is None
                continue
            op_assert_ref(self, func, test_dtype, i, orig, decomp, ref, args, kwargs)
    else:
        for orig, decomp in zip(real_out, decomp_out):
            if orig is None:
                assert decomp is None
                continue
            op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)

    return real_out_unflat
def wrapped(*args): phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined] fx_tracer = PythonKeyTracer() fake_tensor_mode: Any = nullcontext() if tracing_mode == "real": fake_tensor_mode = nullcontext() elif tracing_mode == "fake": fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) elif tracing_mode == "symbolic": fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False) else: raise AssertionError(f"Unexpected tracing type: {tracing_mode}") proxy_mode = ProxyTorchDispatchMode( fx_tracer, trace_factory_functions=trace_factory_functions) def wrap_fake_concrete(x): if isinstance(x, torch.Tensor): return fake_tensor_mode.from_tensor( x) # type: ignore[attr-defined] return x shape_env = ShapeEnv() # todo: Figure out a more informative name for symints def wrap_fake_symbolic(x, sym_shape): if isinstance(x, torch.Tensor): val = FakeTensor(fake_tensor_mode, torch.empty(sym_shape, device="meta"), x.device) return val return x wrap_fn_map = { "real": lambda x: x, "fake": wrap_fake_concrete, } if tracing_mode == "symbolic": flat_shapes = shape_env.create_shapes_for_args(args) flat_args, spec = pytree.tree_flatten(args) args = pytree.tree_unflatten( list( map(lambda a: wrap_fake_symbolic(a[0], a[1]), zip(flat_args, flat_shapes))), spec) else: args = pytree.tree_map(wrap_fn_map[tracing_mode], args) with decompose( decomposition_table ), fake_tensor_mode, proxy_mode: # type: ignore[attr-defined] t = dispatch_trace(wrap_key(f, args, proxy_mode), tracer=fx_tracer, concrete_args=tuple(phs)) # TODO: kind of a bad way to do it, should maybe figure out a better way t.shape_env = shape_env # type: ignore[assignment] return t
def forward(ctx, *flat_tensor_args):
    """Run the compiled forward graph, compiling it on the first call.

    On the first call (``compiled_fw is None``) the joint forward/backward
    graph is traced with ``make_fx``, split by ``partition_fn`` and both
    halves are compiled; the results are cached in the enclosing scope's
    ``compiled_fw``/``compiled_bw``/``num_outs``. Later calls only execute
    ``compiled_fw``.

    Returns:
        The first ``num_outs`` forward outputs as a tuple. The remaining
        outputs are stored on ``ctx`` via ``save_for_backward``.
    """
    nonlocal compiled_fw, compiled_bw, num_outs
    # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
    # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
    # NOTE(review): the flag is restored near the end of this function without
    # a try/finally, so an exception in between leaves it disabled.
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    if compiled_fw is None:
        # Detach tensor inputs while preserving their requires_grad flag.
        flat_tensor_args = pytree.tree_map(
            lambda x: x.detach().requires_grad_(x.requires_grad)
            if isinstance(x, Tensor) else x, flat_tensor_args)
        fake_mode = FakeTensorMode.push(
        ) if config.use_fake_tensor else nullcontext()
        with preserve_rng_state(), fake_mode as mode:
            # Set input tensors that require grad to leaves
            # NOTE(review): the conditional parses as
            # ``from_tensor(x) if mode else (x if isinstance(...) else x)``,
            # so with a truthy ``mode`` every leaf goes through
            # ``mode.from_tensor`` regardless of its type.
            fake_flat_tensor_args = pytree.tree_map(
                lambda x: mode.from_tensor(x) if mode else x
                if isinstance(x, Tensor) else x, flat_tensor_args)
            with torch.set_grad_enabled(grad_state):
                out = flat_fn(*fake_flat_tensor_args)
            out = pytree.tree_map(
                lambda x: x.detach().contiguous()
                if isinstance(x, Tensor) else x, out)

            # A non-sequence result counts as a single output.
            if isinstance(out, (list, tuple)):
                num_outs = len(out)
            else:
                num_outs = 1

            joint_inputs = (fake_flat_tensor_args, out)
            # Entries of ``decompositions`` override the defaults on key clashes.
            aot_decompositions = {
                **aot_autograd_decompositions,
                **decompositions
            }
            with torch.set_grad_enabled(grad_state):
                fx_g = make_fx(joint_forward_backward,
                               aot_decompositions)(*joint_inputs)

                if config.use_functionalize:
                    # Functionalize the foward backward graph. First create a
                    # fake fn to make functionalize happy
                    def fake_fn(primals, tangents):
                        return fx_g(primals, tangents)

                    fx_g = make_fx(
                        functionalize(fake_fn))(*joint_inputs)
        fw_module, bw_module = partition_fn(fx_g, joint_inputs)
        # print(fw_module.code, bw_module.code)

        compiled_fw = fw_compiler(fw_module, flat_tensor_args)
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

        # Backward inputs: the extra forward outputs first, then the
        # user-visible outputs.
        bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
        compiled_bw = bw_compiler(bw_module, bw_args)
    else:
        fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
    torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
    ctx.save_for_backward(*fw_outs[num_outs:])
    return tuple(fw_outs[0:num_outs])
def wrapper(f):
    """Register ``f`` in ``meta_table`` for every leaf of the enclosing ``op``.

    When ``register_dispatcher`` is truthy, ``f`` is also passed to
    ``meta_lib.impl`` under the overload's own name, or under the packet name
    for the "default" overload. Returns ``f`` unchanged.
    """

    def register(overload):
        meta_table[overload] = f
        if not register_dispatcher:
            return
        if overload._overloadname != "default":
            name = overload.__name__
        else:
            name = overload.overloadpacket.__name__
        meta_lib.impl(name, f)

    tree_map(register, op)
    return f
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
    """
    Prototype ATen executor.

    Just executes the context's graph. With ``executor="aten"`` this calls
    ``gm.forward``; with ``executor="nvfuser"`` the graph is replayed through
    each target's ``impl_nvfuser`` lowering inside a ``FusionDefinition`` and
    the resulting fusion is executed on the positional tensor arguments.

    Raises:
        RuntimeError: "nvfuser" was requested but CUDA is not available.
        ValueError: ``executor`` is neither "aten" nor "nvfuser".
    """
    if executor == "aten":
        return gm.forward(*args, **kwargs)

    if executor != "nvfuser":
        msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
            executor)
        raise ValueError(msg)

    if not torch.cuda.is_available():
        raise RuntimeError(
            "Attempting to use nvFuser trace executor but CUDA is not available!"
        )

    # PROTOTYPE nvfuser executor
    # Everything in the graph must support nvfuser
    fusion = Fusion()
    with FusionDefinition(fusion) as fd:

        class FusionInterpreter(torch.fx.Interpreter):
            def call_function(self, target, args, kwargs):
                lowering = target.impl_nvfuser
                lowered_args = (fd, ) + args
                return lowering(*lowered_args, **kwargs)

        def to_nv(arg):
            if not isinstance(arg, torch.Tensor):
                return arg
            nv_tensor = fd.define_tensor(arg.size(), arg.stride(),
                                         getnvFuserDtype(arg.dtype))
            fd.add_input(nv_tensor)
            return nv_tensor

        # Transforms graph to call nvfuser lowerings
        nv_args = tree_map(to_nv, args)
        nv_kwargs = tree_map(to_nv, kwargs)

        out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
        flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
        for o in flat_out:
            fd.add_output(o)

    tensor_inputs = tuple(arg for arg in args if isinstance(arg, torch.Tensor))
    return torch.utils._pytree.tree_unflatten(
        fusion.execute(tensor_inputs),
        unflatten_spec,
    )
def wrapped(a):
    """Call the enclosing ``f`` on ``a`` with functionalization enabled.

    ``a`` is converted with ``torch._to_functional_tensor`` first.
    Functionalization is always disabled again, even if ``f`` raises. The
    input and every output leaf are then synced, and the outputs are
    converted back with ``torch._from_functional_tensor``.
    """
    functional_input = torch._to_functional_tensor(a)
    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        functional_out = f(functional_input)
    finally:
        torch._disable_functionalization()

    torch._sync(functional_input)
    tree_map(torch._sync, functional_out)
    return tree_map(torch._from_functional_tensor, functional_out)
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on the wrapped ``elem`` tensors and re-wrap tensor outputs.

    ``TorchDispatchTensor`` leaves of ``args``/``kwargs`` are replaced by
    their ``elem``; every ``torch.Tensor`` in the result is wrapped in a new
    ``TorchDispatchTensor``. A falsy ``kwargs`` is treated as ``{}``.
    """

    def unwrap(x):
        if isinstance(x, TorchDispatchTensor):
            return x.elem
        return x

    def wrap(x):
        if isinstance(x, torch.Tensor):
            return TorchDispatchTensor(x)
        return x

    raw_args = tree_map(unwrap, args)
    raw_kwargs = tree_map(unwrap, kwargs or {})
    raw_out = func(*raw_args, **raw_kwargs)
    return tree_map(wrap, raw_out)
def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
    """Yield ``(reference, vmap_result)`` pairs for batched variants of ``op``.

    For every ``(batched_args, in_dims, kwargs)`` produced by the input
    generator, two pairs are yielded:

    1. ``(loop_out, batched_out)`` - the ``loop`` result (or ``None`` when
       ``compute_loop_out`` is false) against a single ``vmap`` call;
    2. ``(expected, output)`` - ``batched_out`` with an extra dim at index 1
       against a nested ``vmap`` whose inner level only batches a dummy.

    Batch-norm style ops (per ``opinfo.name``) use a dedicated generator.
    """
    out_dim = 0
    batch_size = 4
    # instance norm calls batch norm
    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")

    def add_bdim_if_tensor(x):
        return x.unsqueeze(1) if isinstance(x, torch.Tensor) else x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
    if opinfo is not None and opinfo.name in batch_norm_fns:
        generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)

    for batched_args, in_dims, batched_kwargs in generator:
        loop_out = (
            loop(op, in_dims, out_dim, batch_size, *batched_args, **batched_kwargs)
            if compute_loop_out
            else None
        )

        # Used for debugging the resulting operations
        # from functorch import make_fx
        # def f(a):
        #     return op(a)
        # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
        # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
        single_level = vmap(op, in_dims=in_dims, out_dims=out_dim)
        batched_out = single_level(*batched_args, **batched_kwargs)
        yield (loop_out, batched_out)

        # Tests case where we dispatch to a batching rule with no bdims
        # This should be handled by autogenerated plumbing. For vmap support
        # added via a manual plumbing you may need to handle this specially.
        dummy = torch.ones(batch_size, 1)
        expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

        inner_in_dims = (0, ) + pytree.tree_map(lambda x: None, in_dims)
        outer_in_dims = (0, ) + in_dims
        two_level = vmap(vmap(f, inner_in_dims), outer_in_dims)
        output = two_level(dummy, *batched_args, **batched_kwargs)
        yield (expected, output)
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
    """Run ``func`` on real zero tensors and convert the outputs back to fake.

    Each ``FakeTensor`` input is replaced by ``torch.zeros_like`` on its
    ``fake_device``. Outputs that are one of those stand-ins are mapped back
    to the original fake tensor; other tensor outputs go through
    ``fake_mode.fake_tensor_converter``.

    Raises:
        ``orig_not_implemented_exception``: if ``func`` is tagged
        ``inplace_view``, or if an output that is not itself an input
        stand-in shares storage with an input.
    """
    # these should all be supported, just to be safe
    # avoid fallback for operators which inplace modify metadata
    # because the input fake tensors would be umodified
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        # Maps id(real stand-in) -> the FakeTensor it replaced.
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    # Mirror the fake tensor's coalesced flag on the stand-in.
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        args = tree_map(to_real_tensor, args)
        kwargs = tree_map(to_real_tensor, kwargs)

        r = func(*args, **kwargs)

    # NOTE(review): ``tensor_impls`` is created but never read in this function.
    tensor_impls = set()
    storages = set()

    # Collect the storages of all dense tensor inputs.
    for e in tree_flatten((args, kwargs))[0]:
        if isinstance(e, torch.Tensor):
            if not e.is_sparse:
                storages.add(e.storage()._cdata)

    # TODO: also check metadata change on inputs
    # proper aliasing/metadata relationship between outputs and inputs will
    # not be set up, bc of conversion to device, unless we can reuse an
    # input impl
    for e in tree_flatten(r)[0]:
        if id(e) not in inp_impls and (
            isinstance(e, torch.Tensor)
            and not e.is_sparse
            and e.storage()._cdata in storages
        ):
            raise orig_not_implemented_exception

    def map_out(e):
        if isinstance(e, torch.Tensor):
            if id(e) in inp_impls:
                # The op returned one of the stand-ins: hand back the
                # original fake tensor.
                return inp_impls[id(e)]
            else:
                return fake_mode.fake_tensor_converter(fake_mode, e)
        else:
            return e

    return tree_map(map_out, r)
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on the inner tensors and return the result unwrapped.

    Every ``NonRewrappingTensor`` in ``args``/``kwargs`` is replaced by its
    ``tensor`` attribute. The output is returned as produced by ``func``; it
    is not wrapped back into the subclass. ``kwargs=None`` is treated as
    "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e) -> torch.Tensor:
        return e.tensor if isinstance(e, NonRewrappingTensor) else e

    # Return an unwrapped tensor no longer of original subclass type.
    return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Record the op name, then run ``func`` on unwrapped tensors.

    ``func.__name__`` is appended to the module-level
    ``schema_check_recorded_ops`` list before the op runs. Instances of
    ``cls`` are replaced by their ``elem``; tensor outputs are wrapped in
    ``cls``. ``kwargs=None`` is treated as "no keyword arguments".
    """
    global schema_check_recorded_ops

    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, cls) else e

    def wrap(e):
        return cls(e) if isinstance(e, torch.Tensor) else e

    schema_check_recorded_ops.append(func.__name__)
    out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
    return tree_map(wrap, out)
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on unwrapped tensors inside ``cls.context()`` and log it.

    Instances of ``cls`` are replaced by their ``elem``; tensor outputs are
    wrapped in ``cls``. After the op ran, its qualified name is sent to the
    "LoggingTensor" logger with ``args``, ``kwargs`` and the result as log
    arguments. ``kwargs=None`` is treated as "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, cls) else e

    def wrap(e):
        return cls(e) if isinstance(e, torch.Tensor) else e

    with cls.context():
        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
    logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
    return rs
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on unwrapped positional arguments and log the call.

    ``LoggingTensor`` leaves of ``args`` are replaced by their ``elem`` and
    tensor outputs are wrapped in ``LoggingTensor``. Keyword arguments are
    not supported: a non-empty ``kwargs`` trips the assertion below.
    """
    # TODO: handle kwargs
    assert not kwargs

    def unwrap(e):
        if isinstance(e, LoggingTensor):
            return e.elem
        return e

    def wrap(e):
        if isinstance(e, torch.Tensor):
            return LoggingTensor(e)
        return e

    raw_args = tree_map(unwrap, args)
    rs = tree_map(wrap, func(*raw_args))
    logger = logging.getLogger("LoggingTensor")
    logger.info(f"{func.__module__}.{func.__name__}", args, rs)
    return rs
def check_requires_grad(*args, **kwargs):
    """Return whether any ``TorchMLIRTensor`` argument requires grad.

    Every leaf of ``args`` and ``kwargs`` is inspected; leaves that are not
    ``TorchMLIRTensor`` instances are ignored. With no such leaf the result
    is ``False``.
    """
    flags = []

    def collect(leaf):
        if isinstance(leaf, TorchMLIRTensor):
            flags.append(leaf.requires_grad)

    for tree in (args, kwargs):
        tree_map(collect, tree)

    # OR the flags together in encounter order.
    requires_grad = False
    for flag in flags:
        requires_grad |= flag
    return requires_grad
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on the wrapped tensors and verify metadata consistency.

    ``CompositeCompliantTensor`` leaves are replaced by their ``elem`` for the
    real call and tensor outputs are wrapped again. View functions
    additionally alias the wrapper outputs to the wrapper inputs, and inplace
    view functions are replayed on the wrappers themselves. Finally
    ``check_metadata_consistency`` is applied to every leaf of ``args``,
    ``kwargs`` and the result.

    ``kwargs=None`` is treated as "no keyword arguments".

    Raises:
        RuntimeError: if the op's packet name is ``set_`` or ``resize_``.
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, CompositeCompliantTensor) else e

    def wrap(e):
        return CompositeCompliantTensor(e) if isinstance(
            e, torch.Tensor) else e

    if func.overloadpacket.__name__ in ('set_', 'resize_'):
        raise RuntimeError(
            f"{func.__name__} is not allowed to be called inside of "
            f"CompositeImplicitAutograd operators.")

    with no_dispatch():
        unwrapped_args = tree_map(unwrap, args)
        unwrapped_kwargs = tree_map(unwrap, kwargs)
        unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
        rs = tree_map(wrap, unwrapped_rs)

    if is_view_fn(func):
        # Autograd asserts that for B = A.view_fn(...), B and A's storages
        # are the same. Here we try to make B alias A to avoid those asserts.
        # See https://github.com/pytorch/pytorch/issues/65339 for more information
        # about the issue.
        with no_dispatch():
            # Idea: this is a weird way of getting a storage that aliases the input.
            # This is a workaround for #65339.
            # 1. under no_dispatch, all of the wrapper tensors look like regular
            #    tensors with special storage (the storage is nullptr and
            #    advertises CPU/CUDA device).
            # 2. we run func, which ends up running the view operation
            # 3. All view operations reuse the input's storage and return
            #    result Tensor(s) with new sizes/strides/offset that alias
            #    the input.
            # 4. we set the storage (and sizes/strides/offset) of the wrapper
            #    tensor results to be that of the tensors that alias the input
            result = func(*args, **kwargs)
            if isinstance(result, (tuple, list)):
                for a, b in zip(rs, result):
                    a.set_(b)
            else:
                rs.set_(result)

    # Some operations are allowed to in-place modify the metadata of the
    # inputs. The only ones are the "inplace view functions"; when we
    # run into these, we manually modify the metadata of the input.
    with no_dispatch():
        if is_inplace_view_fn(func):
            func(*args, **kwargs)

    # For each CompositeCompliantTensor t, we check that t and t.elem
    # have consistent metadata. If they don't have consistent metadata,
    # that means the operator did something fishy.
    check = partial(check_metadata_consistency)
    tree_map(check, args)
    tree_map(check, kwargs)
    tree_map(check, rs)
    return rs
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on unwrapped tensors, re-pointing ``args[0]`` for listed ops.

    Instances of ``cls`` are replaced by their ``elem``; tensor outputs are
    wrapped in ``cls``. When ``func._schema.name`` is listed in
    ``IncorrectAliasTensor.INCORRECT_OPS``, the first argument's ``elem`` is
    replaced by the raw output before wrapping. ``kwargs=None`` is treated
    as "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, cls) else e

    def wrap(e):
        return cls(e) if isinstance(e, torch.Tensor) else e

    unwrapped_args = tree_map(unwrap, args)
    out = func(*unwrapped_args, **tree_map(unwrap, kwargs))
    if func._schema.name in IncorrectAliasTensor.INCORRECT_OPS:
        args[0].elem = out

    return tree_map(wrap, out)
def compute_quantities_for_vmap_test(op, orig_batched_args, orig_kwarg_values, in_dims, out_dim=0, batch_size=2, compute_loop_out=True, clone_inputs=False):
    """Yield two ``(reference, vmap_result)`` pairs for one batched input set.

    1. ``(loop_out, batched_out)`` - the ``loop`` result (``None`` when
       ``compute_loop_out`` is false) against a single ``vmap`` call.
    2. ``(expected, output)`` - ``batched_out`` with an extra dim at index 1
       against a nested ``vmap`` whose inner level only batches a dummy.

    With ``clone_inputs`` set, the inputs are re-cloned through
    ``clone_if_tensor`` before each of the three invocations.
    """

    def fresh_inputs():
        if not clone_inputs:
            return orig_batched_args, orig_kwarg_values
        cloned_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
        cloned_kwargs = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
        return cloned_args, cloned_kwargs

    def add_bdim_if_tensor(x):
        return x.unsqueeze(1) if isinstance(x, torch.Tensor) else x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    batched_args, kwarg_values = fresh_inputs()
    loop_out = (
        loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
        if compute_loop_out
        else None
    )

    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = fresh_inputs()
    single_level = vmap(op, in_dims=in_dims, out_dims=out_dim)
    batched_out = single_level(*batched_args, **kwarg_values)
    yield (loop_out, batched_out)

    # Tests case where we dispatch to a batching rule with no bdims
    # This should be handled by autogenerated plumbing. For vmap support
    # added via a manual plumbing you may need to handle this specially.
    dummy = torch.ones(batch_size, 1)
    expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0, ) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0, ) + in_dims
    batched_args, kwarg_values = fresh_inputs()
    two_level = vmap(vmap(f, inner_in_dims), outer_in_dims)
    output = two_level(dummy, *batched_args, **kwarg_values)
    yield (expected, output)
def check_with_mode(op, args, kwargs):
    """Call ``op`` on ``CompositeCompliantTensor``-wrapped inputs under python mode.

    Every ``torch.Tensor`` leaf of ``args`` and ``kwargs`` is wrapped before
    the call. A ``RuntimeError`` raised by the call is handed to
    ``raise_composite_compliance_error``.
    """

    def to_compliant(e):
        if isinstance(e, torch.Tensor):
            return CompositeCompliantTensor(e)
        return e

    wrapped_args = tree_map(to_compliant, args)
    wrapped_kwargs = tree_map(to_compliant, kwargs)
    try:
        with enable_python_mode(CompositeCompliantTensor):
            op(*wrapped_args, **wrapped_kwargs)
    # see NOTE: [What errors are Composite Compiance trying to catch?]
    except RuntimeError as err:
        raise_composite_compliance_error(err)
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` under ``no_dispatch()`` on unwrapped tensors and log it.

    ``NonWrapperSublass`` leaves are replaced by their ``elem``; tensor
    outputs are wrapped in ``NonWrapperSublass``. The call is then reported
    to the "NonWrapperSublass" logger. ``kwargs=None`` is treated as
    "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, NonWrapperSublass) else e

    def wrap(e):
        return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e

    # no_dispatch is only needed if you use enable_python_mode.
    # It prevents infinite recursion.
    with no_dispatch():
        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
    logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
    return rs
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
    """Record ``func_overload`` on the argument proxies and run it for real.

    If the overload has an entry in ``CURRENT_DECOMPOSITION_TABLE``, that
    entry is called instead and its result returned. Otherwise the overload
    packet is called on the ``proxy`` of each ``PythonTensor`` argument, the
    real op runs under ``no_dispatch()``, and each plain ``torch.Tensor``
    output is paired with its proxy in a new ``PythonTensor``.

    ``kwargs=None`` is treated as "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    # Commenting this out for now since it causes some spurious failures (such as error checking)
    # if func == aten._local_scalar_dense:
    #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
    #                        "It's likely that this is caused by data-dependent control flow or similar.")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, PythonTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(
            args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    def wrap_with_proxy(e, proxy):
        # Some ops (like native_batch_norm_backward) return undefined tensors that get
        # converted into None in python.
        # As the function signature expects tensors, if we directly return these None
        # tensors back to C++, we'll error.
        if e is None:
            e = torch.empty(())
        # Exact type match: tensor subclasses are returned as they are.
        if type(e) == torch.Tensor:
            return PythonTensor(e, proxy)
        else:
            return e

    if isinstance(real_out, tuple):
        return tuple(
            wrap_with_proxy(e, proxy_out[idx])
            for idx, e in enumerate(real_out))
    elif isinstance(real_out, list):
        return [
            wrap_with_proxy(e, proxy_out[idx])
            for idx, e in enumerate(real_out)
        ]
    elif isinstance(real_out, torch.Tensor):
        return wrap_with_proxy(real_out, proxy_out)
    else:
        return real_out
def _find_common_device(func, args, kwargs):
    """Return the single device shared by the ``FakeTensor`` arguments.

    Zero-dim CPU tensors are allowed to coexist with tensors on another
    device: they never override an established device, and a device that was
    established only by such a tensor is replaced by the first other device
    encountered.

    Raises:
        Exception: two differing devices where neither side is a zero-dim
            CPU tensor.
        AssertionError: no ``FakeTensor`` was found at all.
    """
    # cpu - zero-dim tensors can be called in cuda kernels,
    # so overwrite the common_device if it the only existing
    # device comes from a cpu zero-dim tensor
    common_device = None
    is_cpu_zero_dim = None

    def is_cpu_scalar(t):
        return t.device.type == "cpu" and t.dim() == 0

    def merge_devices(t):
        nonlocal common_device, is_cpu_zero_dim

        if not isinstance(t, FakeTensor):
            return

        if common_device is None:
            common_device = t.device
            is_cpu_zero_dim = is_cpu_scalar(t)
            return

        incoming_is_cpu_scalar = is_cpu_scalar(t)

        if t.device == common_device:
            if is_cpu_zero_dim:
                is_cpu_zero_dim = incoming_is_cpu_scalar
        elif incoming_is_cpu_scalar:
            # mismatching devices !
            # if current tensor is cpu 0 dim, defer to existing device
            pass
        elif is_cpu_zero_dim:
            # current device is from cpu 0 dim tensor, overwrite
            common_device = t.device
            is_cpu_zero_dim = incoming_is_cpu_scalar
        else:
            # mismatching devices of non-zero dim tensors, throw
            # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
            raise Exception(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

    for tree in (args, kwargs):
        tree_map(merge_devices, tree)

    assert common_device is not None, f"Could not find common device for {func}"

    return common_device
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Run ``func`` on unwrapped tensors, re-wrap tensor outputs, and log it.

    ``NonWrapperSubclass`` leaves are replaced by their ``elem``; tensor
    outputs are wrapped in ``NonWrapperSubclass``. The call is then reported
    to the "NonWrapperSubclass" logger. ``kwargs=None`` is treated as
    "no keyword arguments".
    """
    # ``**`` cannot splat ``None``; normalise so the default signature works.
    if kwargs is None:
        kwargs = {}

    def unwrap(e):
        return e.elem if isinstance(e, NonWrapperSubclass) else e

    def wrap(e):
        return NonWrapperSubclass(e) if isinstance(
            e, torch.Tensor) else e

    rs = tree_map(
        wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
    logging.getLogger("NonWrapperSubclass").info(
        f"{func.__module__}.{func.__name__}", args, kwargs, rs)
    return rs