Example 1
def proxy_call(func_overload, args, kwargs=None):
    func = func_overload.overloadpacket
    if func_overload in CURRENT_DECOMPOSITION_TABLE:
        return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
    if func_overload == aten._local_scalar_dense.default:
        raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
                           "It's likely that this is caused by data-dependent control flow or similar."
                           "Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")

    def unwrap_proxy(e):
        return e.proxy if isinstance(e, ProxyTensor) else e

    proxy_args = pytree.tree_map(unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

    proxy_out = func(*proxy_args, **proxy_kwargs)

    # Kind of a hacky way to test if an op is in-place or not
    if func.__name__[-1] == "_" and func.__name__[0] != "_":
        args[0].proxy = proxy_out
        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])

    with no_dispatch():
        real_out = func_overload(*args, **kwargs)

    return wrap_output(real_out, proxy_out)
Example 2
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **kwargs or {})
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if (isinstance(e, torch.Tensor) and e.ndim == 2
                        and e.count_nonzero() == e.diag().count_nonzero()):
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(
                wrap,
                func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs
Example 3
    def decomposition_decorator(f):
        nonlocal registry
        if registry is None:
            registry = decomposition_table

        def add_op_to_table(aten_op):
            overloads = []
            if isinstance(aten_op, torch._ops.OpOverload):
                overloads.append(aten_op)
            else:
                assert isinstance(aten_op, torch._ops.OpOverloadPacket)
                for ol in aten_op.overloads():
                    overloads.append(getattr(aten_op, ol))
            for op_overload in overloads:
                if op_overload in registry:
                    raise RuntimeError(f"duplicate registrations for {op_overload}")
                registry[op_overload] = f
                # TODO: factor this logic into OpOverload or Library API
                name = op_overload._schema.name
                if op_overload._schema.overload_name:
                    name += "." + op_overload._schema.overload_name
                if (
                    not disable_meta
                    # TorchScript dumps a bunch of extra nonsense overloads
                    # which don't have corresponding dispatcher entries, we need
                    # to filter those out
                    and torch._C._dispatch_has_kernel(name)
                    and not torch._C._dispatch_has_kernel_for_dispatch_key(name, 'Meta')
                ):
                    meta_lib.impl(op_overload, f)

        # To allow registering multiple aten_ops at once
        tree_map(add_op_to_table, aten_op)
        return f
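
The decorator above registers a decomposition for one or more aten overloads and, unless disabled, also installs it as a Meta kernel. Below is a minimal, self-contained sketch of the same registration pattern; the table and decorator names are illustrative, not the real torch._decomp API.

import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten

my_decomposition_table = {}

def register_decomposition(aten_op, registry=None):
    def decorator(f):
        table = my_decomposition_table if registry is None else registry

        def add_op_to_table(op):
            if isinstance(op, torch._ops.OpOverload):
                overloads = [op]
            else:
                # An OpOverloadPacket: expand to all of its overloads.
                overloads = [getattr(op, name) for name in op.overloads()]
            for overload in overloads:
                if overload in table:
                    raise RuntimeError(f"duplicate registrations for {overload}")
                table[overload] = f

        # tree_map lets callers pass a single op or a list/tuple of ops.
        tree_map(add_op_to_table, aten_op)
        return f
    return decorator

@register_decomposition(aten.silu)
def silu_decomp(x):
    return x * torch.sigmoid(x)

assert aten.silu.default in my_decomposition_table
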
Example 4
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Find process_group
        process_group = None

        def find_process_group(e):
            nonlocal process_group
            if process_group is None and isinstance(e, _PartialTensor):
                process_group = e._process_group

        tree_map(find_process_group, args)
        tree_map(find_process_group, kwargs)

        if func in _PARTIAL_TENSOR_OPS:
            return _PARTIAL_TENSOR_OPS[func](types, args, kwargs,
                                             process_group)

        # Need to disable all dispatch to print args and kwargs appropriately.
        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
        try:
            with torch._C.DisableTorchFunction():
                raise RuntimeError(
                    f"torch function '{func.__name__}', with args: {args} and "
                    f"kwargs: {kwargs} not supported for PartialTensor!")
        finally:
            del guard
Example 5
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            if isinstance(e, InplaceLoggingTensor):
                return e.elem
            else:
                return e

        def wrap(e):
            if isinstance(e, torch.Tensor):
                return InplaceLoggingTensor(e)
            else:
                return e

        f = func
        # this subclass converts all `add()` ops into `add_()` ops
        if f is torch.ops.aten.add.Tensor:
            f = torch.ops.aten.add_.Tensor

        with cls.context():
            rs = tree_map(
                wrap, f(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        # after running the (potentially transformed) op,
        # log the original op that we saw.
        logging.getLogger("LoggingTensor").info(
            f"{func.__module__}.{func.__name__}", args, kwargs, rs)
        return rs
Example 6
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        if node.op not in CALLABLE_NODE_OPS:
            return False

        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
            return False

        if node.target in [operator.getitem]:
            return True

        found_not_cuda = False

        def find_not_cuda(t):
            nonlocal found_not_cuda
            if isinstance(t, torch.Tensor) and t.device.type != 'cuda':
                found_not_cuda = True

        for n in node.all_input_nodes:
            tree_map(find_not_cuda, n.meta['fake_result'])

        tree_map(find_not_cuda, node.meta['fake_result'])

        # NB: factory functions are accounted for because their result would be
        # cpu or cuda

        return not found_not_cuda
Example 7
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):  # type: ignore
        """Intercepts all operations performed on this handle object.

        Before any operation, the tensor attribute is unwrapped from the handle
        and used in the operation. We maintain a reference to the tensor and its
        current version to track whether modifications have been made. If we detect
        changes to the tensor, we write it to the file maintained by the Handle.
        """
        ssd_tensor_handles = []

        def unwrap(e: Any) -> torch.Tensor:
            if isinstance(e, SsdTensorHandle):
                t = e.to_tensor()
                ssd_tensor_handles.append((e, t._version))  # type: ignore
                return t
            else:
                return e

        r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

        for e, saved_version in ssd_tensor_handles:
            inplace_is_this_tensor = func.__name__[-1] == "_" and e is args[0]
            out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
            if inplace_is_this_tensor or out_is_this_tensor:
                e.to_file()
        return r
Example 8
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    with no_dispatch():

        def to_cpu(e):
            if isinstance(e, FakeTensor):
                return torch.zeros_like(e, device="cpu")
            return e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            raise orig_not_implemented_exception from new_exception

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                tensor_impls.add(e)
                storages.add(e.storage()._cdata)

        # TODO: also check metadata changes on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, because of the conversion to cpu; error on reused impls
        for e in tree_flatten(r)[0]:
            if e in tensor_impls or (isinstance(e, torch.Tensor)
                                     and e.storage()._cdata in storages):
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
Example 9
            def _torch_dispatch(cls, func, types, args=(), kwargs=None):
                self.precision = saved_precision
                self.rel_tol = saved_rel_tol

                called.add(func)
                all_called[func] += 1

                # Stuff we shouldn't bother testing
                # (TODO: remove detach from the decomp table?)
                if func not in decomposition_table or func in [
                        torch.ops.aten.detach.default
                ] or any_unsupported(args, kwargs):
                    return func(*args, **kwargs)

                decomposed.add(func)
                all_decomposed.add(func)

                # We take two main strategies for verifying the correctness/numerical
                # stability of decompositions.
                # The first is simply tolerance checking between decomp_out and pytorch_out.
                # However, for fp16/bf16 and reductions, this becomes very
                # finicky, as there are not many guarantees we can make.
                # So, for fp16/bf16, we instead compare the difference of
                # {decomp_out, pytorch_out_64} and {pytorch_out,
                # pytorch_out_64}. In other words, we compare how far the
                # decomposition and pytorch each are from the "ground truth" (i.e.
                # fp64). If the decomposition results in more error, we raise an error.

                decomposition = decomposition_table[func]

                do_relative_check = test_dtype in [
                    torch.float16, torch.bfloat16
                ]
                real_out_unflat = func(*args, **kwargs)
                real_out, _ = tree_flatten(real_out_unflat)
                decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
                assert len(real_out) == len(decomp_out)

                if do_relative_check:
                    upcast = partial(upcast_tensor, dtype=torch.float64)
                    real_out_double, _ = tree_flatten(
                        func(*tree_map(upcast, args),
                             **tree_map(upcast, kwargs)))
                    for i, orig, decomp, ref in zip(range(len(real_out)),
                                                    real_out, decomp_out,
                                                    real_out_double):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_ref(self, func, test_dtype, i, orig, decomp,
                                      ref, args, kwargs)
                else:
                    for orig, decomp in zip(real_out, decomp_out):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_equal(self, func, test_dtype, orig, decomp,
                                        args, kwargs)

                return real_out_unflat
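
A small sketch of the fp64 "ground truth" comparison described in the comments above, using a hand-written silu decomposition as a stand-in; the op, dtype, and helper name are illustrative.

import torch

def max_err(out, out_fp64):
    # Distance of a low-precision result from the fp64 reference.
    return (out.double() - out_fp64).abs().max().item()

x = torch.randn(8, dtype=torch.bfloat16)
ref = torch.nn.functional.silu(x)              # what the real kernel computes
ref64 = torch.nn.functional.silu(x.double())   # fp64 "ground truth"
decomp = x * torch.sigmoid(x)                  # decomposition under test
# The decomposition passes if it is not meaningfully farther from fp64
# than the reference kernel itself is.
print(max_err(decomp, ref64), "vs", max_err(ref, ref64))
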
Example 10
    def wrapped(*args):
        phs = pytree.tree_map(lambda _: fx.PH,
                              args)  # type: ignore[attr-defined]
        fx_tracer = PythonKeyTracer()
        fake_tensor_mode: Any = nullcontext()
        if tracing_mode == "real":
            fake_tensor_mode = nullcontext()
        elif tracing_mode == "fake":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
        elif tracing_mode == "symbolic":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        else:
            raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

        proxy_mode = ProxyTorchDispatchMode(
            fx_tracer, trace_factory_functions=trace_factory_functions)

        def wrap_fake_concrete(x):
            if isinstance(x, torch.Tensor):
                return fake_tensor_mode.from_tensor(
                    x)  # type: ignore[attr-defined]

            return x

        shape_env = ShapeEnv()

        # todo: Figure out a more informative name for symints
        def wrap_fake_symbolic(x, sym_shape):
            if isinstance(x, torch.Tensor):
                val = FakeTensor(fake_tensor_mode,
                                 torch.empty(sym_shape, device="meta"),
                                 x.device)
                return val
            return x

        wrap_fn_map = {
            "real": lambda x: x,
            "fake": wrap_fake_concrete,
        }
        if tracing_mode == "symbolic":
            flat_shapes = shape_env.create_shapes_for_args(args)
            flat_args, spec = pytree.tree_flatten(args)
            args = pytree.tree_unflatten(
                list(
                    map(lambda a: wrap_fake_symbolic(a[0], a[1]),
                        zip(flat_args, flat_shapes))), spec)
        else:
            args = pytree.tree_map(wrap_fn_map[tracing_mode], args)

        with decompose(
                decomposition_table
        ), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
            t = dispatch_trace(wrap_key(f, args, proxy_mode),
                               tracer=fx_tracer,
                               concrete_args=tuple(phs))

        # TODO: kind of a bad way to do it, should maybe figure out a better way
        t.shape_env = shape_env  # type: ignore[assignment]
        return t
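
The closure above backs a make_fx-style tracing entry point; a hedged usage sketch of the public API (the exact import path varies across torch/functorch versions):

import torch
from torch.fx.experimental.proxy_tensor import make_fx  # or: from functorch import make_fx

def f(x):
    return torch.sin(x) * 2

# "real" tracing mode is the default; "fake"/"symbolic" avoid running real kernels.
g = make_fx(f)(torch.randn(3))
print(g.code)  # an FX graph of aten-level ops
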
Example 11
        def forward(ctx, *flat_tensor_args):
            nonlocal compiled_fw, compiled_bw, num_outs
            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
            if compiled_fw is None:
                flat_tensor_args = pytree.tree_map(
                    lambda x: x.detach().requires_grad_(x.requires_grad)
                    if isinstance(x, Tensor) else x, flat_tensor_args)
                fake_mode = FakeTensorMode.push(
                ) if config.use_fake_tensor else nullcontext()
                with preserve_rng_state(), fake_mode as mode:
                    # Set input tensors that require grad to leaves
                    fake_flat_tensor_args = pytree.tree_map(
                        lambda x: mode.from_tensor(x) if mode else x
                        if isinstance(x, Tensor) else x, flat_tensor_args)
                    with torch.set_grad_enabled(grad_state):
                        out = flat_fn(*fake_flat_tensor_args)
                    out = pytree.tree_map(
                        lambda x: x.detach().contiguous()
                        if isinstance(x, Tensor) else x, out)

                    if isinstance(out, (list, tuple)):
                        num_outs = len(out)
                    else:
                        num_outs = 1

                    joint_inputs = (fake_flat_tensor_args, out)
                    aot_decompositions = {
                        **aot_autograd_decompositions,
                        **decompositions
                    }
                    with torch.set_grad_enabled(grad_state):
                        fx_g = make_fx(joint_forward_backward,
                                       aot_decompositions)(*joint_inputs)

                        if config.use_functionalize:
                            # Functionalize the forward-backward graph. First create a
                            # fake fn to make functionalize happy.
                            def fake_fn(primals, tangents):
                                return fx_g(primals, tangents)

                            fx_g = make_fx(
                                functionalize(fake_fn))(*joint_inputs)
                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                # print(fw_module.code, bw_module.code)

                compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
                compiled_bw = bw_compiler(bw_module, bw_args)
            else:
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])
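
For context, a hedged sketch of driving this forward/backward machinery through functorch's public aot_function API, with a pass-through "compiler" that just prints and returns each FX module:

import torch
from functorch.compile import aot_function

def print_compile(fx_module, example_inputs):
    # A no-op "compiler": inspect the graph, then run it as-is.
    print(fx_module.code)
    return fx_module

def f(x, y):
    return (x * y).sin().sum()

aot_f = aot_function(f, fw_compiler=print_compile, bw_compiler=print_compile)
out = aot_f(torch.randn(3, requires_grad=True), torch.randn(3))
out.backward()
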
Example 12
    def wrapper(f):
        def add_func(op):
            meta_table[op] = f
            if register_dispatcher:
                name = (op.__name__ if op._overloadname != "default" else
                        op.overloadpacket.__name__)
                meta_lib.impl(name, f)

        tree_map(add_func, op)
        return f
Example 13
def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
    """
    Prototype ATen executor.

    Just executes the context's graph.
    """

    if executor == "aten":
        return gm.forward(*args, **kwargs)
    elif executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Everything in the graph must support nvfuser

        fusion = Fusion()
        with FusionDefinition(fusion) as fd:

            class FusionInterpreter(torch.fx.Interpreter):
                def call_function(self, target, args, kwargs):
                    target = target.impl_nvfuser
                    args = (fd, ) + args
                    return target(*args, **kwargs)

            def to_nv(arg):
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(arg.size(), arg.stride(),
                                         getnvFuserDtype(arg.dtype))
                    fd.add_input(x)
                    return x
                else:
                    return arg

            # Transforms graph to call nvfuser lowerings
            nv_args = tree_map(to_nv, args)
            nv_kwargs = tree_map(to_nv, kwargs)

            out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
            flat_out, unflatten_spec = torch.utils._pytree.tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

            return torch.utils._pytree.tree_unflatten(
                fusion.execute(
                    tuple(arg for arg in args
                          if isinstance(arg, torch.Tensor))),
                unflatten_spec,
            )

    msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
        executor)
    raise ValueError(msg)
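
A hedged usage sketch of the execute helper above with the default "aten" executor; the traced function and torch.fx.symbolic_trace are only for illustration (any FX GraphModule works):

import torch
import torch.fx

def add_mul(a, b):
    return (a + b) * b

gm = torch.fx.symbolic_trace(add_mul)
out = execute(gm, torch.randn(4), torch.randn(4), executor="aten")
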
Example 14
 def wrapped(a):
     input_functional = torch._to_functional_tensor(a)
     torch._enable_functionalization(reapply_views=reapply_views)
     try:
         out = f(input_functional)
     finally:
         torch._disable_functionalization()
     torch._sync(input_functional)
     tree_map(torch._sync, out)
     out_unwrapped = tree_map(torch._from_functional_tensor, out)
     return out_unwrapped
Example 15
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, TorchDispatchTensor) else x

        def wrap(x):
            return TorchDispatchTensor(x) if isinstance(x, torch.Tensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})

        return tree_map(wrap, func(*args, **kwargs))
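
The snippet above omits how the wrapper subclass itself is constructed. A minimal self-contained sketch of the full pattern using Tensor._make_wrapper_subclass (PassthroughTensor is an illustrative name):

import torch
from torch.utils._pytree import tree_map

class PassthroughTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # A "wrapper" subclass: holds the real tensor in .elem and only
        # mirrors its metadata (size/strides/dtype/device) itself.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=elem.device, requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, PassthroughTensor) else x

        def wrap(x):
            return PassthroughTensor(x) if isinstance(x, torch.Tensor) else x

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

t = PassthroughTensor(torch.randn(3))
print((t + t).elem)  # dispatch unwraps, runs aten.add, and rewraps
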
Example 16
def get_fallback_and_vmap_exhaustive(op,
                                     arg_values,
                                     kwarg_values,
                                     opinfo=None,
                                     compute_loop_out=True,
                                     bdims=(0, -1)):
    out_dim = 0
    batch_size = 4
    generator = get_exhaustive_batched_inputs(arg_values,
                                              kwarg_values,
                                              batch_size,
                                              bdims=bdims)
    # instance norm calls batch norm
    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")
    if opinfo is not None and opinfo.name in batch_norm_fns:
        generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values,
                                                                 kwarg_values,
                                                                 batch_size,
                                                                 bdims=bdims)
    for batched_args, in_dims, kwarg_values in generator:
        if compute_loop_out:
            loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args,
                            **kwarg_values)
        else:
            loop_out = None
        # Used for debugging the resulting operations
        # from functorch import make_fx
        # def f(a):
        #     return op(a)
        # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
        # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
        batched_out = vmap(op, in_dims=in_dims,
                           out_dims=out_dim)(*batched_args, **kwarg_values)
        yield (loop_out, batched_out)

        # Tests the case where we dispatch to a batching rule with no bdims.
        # This should be handled by autogenerated plumbing. For vmap support
        # added via manual plumbing you may need to handle this specially.
        def add_bdim_if_tensor(x):
            if isinstance(x, torch.Tensor):
                return x.unsqueeze(1)
            return x

        def f(dummy, *args, **kwargs):
            return op(*args, **kwargs)

        dummy = torch.ones(batch_size, 1)
        expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

        inner_in_dims = (0, ) + pytree.tree_map(lambda x: None, in_dims)
        outer_in_dims = (0, ) + in_dims
        output = vmap(vmap(f, inner_in_dims),
                      outer_in_dims)(dummy, *batched_args, **kwarg_values)
        yield (expected, output)
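
A small sketch of the "dummy leading batch dim" trick used above: the inner vmap level sees no batched dims for the real inputs, so the op exercises the no-bdim path of its batching rule (the op and shapes are illustrative):

import torch
from functorch import vmap

def op(x):
    return x.sum(-1)

x = torch.randn(4, 3)        # batch of 4 samples of shape (3,)
dummy = torch.ones(4, 1)     # only the dummy is batched at the inner level

inner = vmap(lambda d, t: op(t), in_dims=(0, None))
out = vmap(inner, in_dims=(0, 0))(dummy, x)  # shape (4, 1)
expected = vmap(op)(x).unsqueeze(1)          # add the dummy's dim back
assert torch.allclose(out, expected)
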
Example 17
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
    # These should all be supported; just to be safe,
    # avoid the fallback for operators which in-place modify metadata,
    # because the input fake tensors would be left unmodified.
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        args = tree_map(to_real_tensor, args)
        kwargs = tree_map(to_real_tensor, kwargs)

        r = func(*args, **kwargs)

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                if not e.is_sparse:
                    storages.add(e.storage()._cdata)

        # TODO: also check metadata changes on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, because of the conversion to a real device, unless we
        # can reuse an input impl
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor)
                and not e.is_sparse
                and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

    def map_out(e):
        if isinstance(e, torch.Tensor):
            if id(e) in inp_impls:
                return inp_impls[id(e)]
            else:
                return fake_mode.fake_tensor_converter(fake_mode, e)
        else:
            return e

    return tree_map(map_out, r)
Example 18
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e) -> torch.Tensor:
                    if isinstance(e, NonRewrappingTensor):
                        t = e.tensor
                        return t
                    else:
                        return e

                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r
Example 19
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        global schema_check_recorded_ops
        schema_check_recorded_ops.append(func.__name__)
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)
Example 20
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        with cls.context():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
        return rs
Example 21
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, LoggingTensor) else e

        def wrap(e):
            return LoggingTensor(e) if isinstance(e, torch.Tensor) else e

        # TODO: handle kwargs
        assert not kwargs
        rs = tree_map(wrap, func(*tree_map(unwrap, args)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, rs)
        return rs
Example 22
def check_requires_grad(*args, **kwargs):
    requires_grad = False

    def check_grad(e):
        nonlocal requires_grad
        if isinstance(e, TorchMLIRTensor):
            requires_grad |= e.requires_grad

    tree_map(check_grad, args)
    tree_map(check_grad, kwargs)

    return requires_grad
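
The same nonlocal-flag-plus-tree_map pattern works for arbitrarily nested inputs; a sketch with plain tensors, since TorchMLIRTensor is not defined in this snippet:

import torch
from torch.utils._pytree import tree_map

def any_requires_grad(*args, **kwargs):
    found = False

    def check(e):
        nonlocal found
        if isinstance(e, torch.Tensor) and e.requires_grad:
            found = True

    # tree_map walks tuples, lists, and dicts, so nested structures are covered.
    tree_map(check, args)
    tree_map(check, kwargs)
    return found

a = torch.randn(2, requires_grad=True)
assert any_requires_grad([a], scale=2.0)
assert not any_requires_grad(torch.zeros(2))
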
Example 23
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, CompositeCompliantTensor) else e

        def wrap(e):
            return CompositeCompliantTensor(e) if isinstance(
                e, torch.Tensor) else e

        if func.overloadpacket.__name__ in ('set_', 'resize_'):
            raise RuntimeError(
                f"{func.__name__} is not allowed to be called inside of "
                f"CompositeImplicitAutograd operators.")

        with no_dispatch():
            unwrapped_args = tree_map(unwrap, args)
            unwrapped_kwargs = tree_map(unwrap, kwargs)
            unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
            rs = tree_map(wrap, unwrapped_rs)

        if is_view_fn(func):
            # Autograd asserts that for B = A.view_fn(...), B and A's storages
            # are the same. Here we try to make B alias A to avoid those asserts.
            # See https://github.com/pytorch/pytorch/issues/65339 for more information
            # about the issue.
            with no_dispatch():
                # Idea: this is a weird way of getting a storage that aliases the input.
                # This is a workaround for #65339.
                # 1. under no_dispatch, all of the wrapper tensors look like regular
                #    tensors with special storage (the storage is nullptr and
                #    advertises a CPU/CUDA device).
                # 2. we run func, which ends up running the view operation
                # 3. All view operations reuse the input's storage and return
                #    result Tensor(s) with new sizes/strides/offset that alias
                #    the input.
                # 4. we set the storage (and sizes/strides/offset) of the wrapper
                #    tensor results to be that of the tensors that alias the input
                result = func(*args, **kwargs)
                if isinstance(result, tuple) or isinstance(result, list):
                    for a, b in zip(rs, result):
                        a.set_(b)
                else:
                    rs.set_(result)

        # Some operations are allowed to in-place modify the metadata of the
        # inputs. The only ones are the "inplace view functions"; when we
        # run into these, we manually modify the metadata of the input.
        with no_dispatch():
            if is_inplace_view_fn(func):
                func(*args, **kwargs)

        # For each CompositeCompliantTensor t, we check that t and t.elem
        # have consistent metadata. If they don't have consistent metadata,
        # that means the operator did something fishy.
        check = partial(check_metadata_consistency)
        tree_map(check, args)
        tree_map(check, kwargs)
        tree_map(check, rs)
        return rs
Example 24
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e
        unwrapped_args = tree_map(unwrap, args)
        out = func(*unwrapped_args, **tree_map(unwrap, kwargs))
        if func._schema.name in IncorrectAliasTensor.INCORRECT_OPS:
            args[0].elem = out

        return tree_map(wrap, out)
Example 25
def compute_quantities_for_vmap_test(op,
                                     orig_batched_args,
                                     orig_kwarg_values,
                                     in_dims,
                                     out_dim=0,
                                     batch_size=2,
                                     compute_loop_out=True,
                                     clone_inputs=False):
    def maybe_clone_inputs():
        if clone_inputs:
            batched_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
            kwarg_values = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
            return batched_args, kwarg_values
        return orig_batched_args, orig_kwarg_values

    batched_args, kwarg_values = maybe_clone_inputs()
    if compute_loop_out:
        loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args,
                        **kwarg_values)
    else:
        loop_out = None
    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = maybe_clone_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args,
                                                              **kwarg_values)
    yield (loop_out, batched_out)

    # Tests the case where we dispatch to a batching rule with no bdims.
    # This should be handled by autogenerated plumbing. For vmap support
    # added via manual plumbing you may need to handle this specially.
    def add_bdim_if_tensor(x):
        if isinstance(x, torch.Tensor):
            return x.unsqueeze(1)
        return x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    dummy = torch.ones(batch_size, 1)
    expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0, ) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0, ) + in_dims
    batched_args, kwarg_values = maybe_clone_inputs()
    output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args,
                                                         **kwarg_values)
    yield (expected, output)
Example 26
def check_with_mode(op, args, kwargs):
    def wrap(e):
        return CompositeCompliantTensor(e) if isinstance(e,
                                                         torch.Tensor) else e

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)
    try:
        with enable_python_mode(CompositeCompliantTensor):
            op(*args, **kwargs)
    # see NOTE: [What errors are Composite Compliance trying to catch?]
    except RuntimeError as err:
        raise_composite_compliance_error(err)
Example 27
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSublass) else e

                def wrap(e):
                    return NonWrapperSublass(e) if isinstance(e, torch.Tensor) else e

                # no_dispatch is only needed if you use enable_python_mode.
                # It prevents infinite recursion.
                with no_dispatch():
                    rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                logging.getLogger("NonWrapperSublass").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                return rs
Example 28
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        if func_overload in CURRENT_DECOMPOSITION_TABLE:
            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
        # Commenting this out for now since it causes some spurious failures (such as error checking)
        # if func == aten._local_scalar_dense:
        #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
        #                        "It's likely that this is caused by data-dependent control flow or similar.")

        def unwrap_proxy(e):
            return e.proxy if isinstance(e, PythonTensor) else e

        proxy_args = pytree.tree_map(unwrap_proxy, args)
        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)

        proxy_out = func(*proxy_args, **proxy_kwargs)

        # Kind of a hacky way to test if an op is in-place or not
        if func.__name__[-1] == "_" and func.__name__[0] != "_":
            args[0].proxy = proxy_out
            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(
                args[0])

        with no_dispatch():
            real_out = func_overload(*args, **kwargs)

        def wrap_with_proxy(e, proxy):
            # Some ops (like native_batch_norm_backward) return undefined tensors that get
            # converted into None in python.
            # As the function signature expects tensors, if we directly return these None
            # tensors back to C++, we'll error.
            if e is None:
                e = torch.empty(())
            if type(e) == torch.Tensor:
                return PythonTensor(e, proxy)
            else:
                return e

        if isinstance(real_out, tuple):
            return tuple(
                wrap_with_proxy(e, proxy_out[idx])
                for idx, e in enumerate(real_out))
        elif isinstance(real_out, list):
            return [
                wrap_with_proxy(e, proxy_out[idx])
                for idx, e in enumerate(real_out)
            ]
        elif isinstance(real_out, torch.Tensor):
            return wrap_with_proxy(real_out, proxy_out)
        else:
            return real_out
Example 29
    def _find_common_device(func, args, kwargs):
        # cpu zero-dim tensors can be used in cuda kernels,
        # so overwrite the common_device if the only existing
        # device comes from a cpu zero-dim tensor
        common_device = None
        is_cpu_zero_dim = None

        def cpu_zero_dim(t):
            return t.device.type == "cpu" and t.dim() == 0

        def merge_devices(t):
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices!
            # if current tensor is cpu 0 dim, defer to existing device
            if t_is_cpu_zero_dim:
                return

            # current device is from cpu 0 dim tensor, overwrite
            if is_cpu_zero_dim:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices of non-zero dim tensors, throw
            # This might be valid behavior and may need to be explicitly modeled, e.g. reshape_as
            raise Exception(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        tree_map(merge_devices, args)
        tree_map(merge_devices, kwargs)

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device
Example 30
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSubclass) else e

                def wrap(e):
                    return NonWrapperSubclass(e) if isinstance(
                        e, torch.Tensor) else e

                rs = tree_map(
                    wrap,
                    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
                logging.getLogger("NonWrapperSubclass").info(
                    f"{func.__module__}.{func.__name__}", args, kwargs, rs)
                return rs