Example #1
def run_cpu_fallback(func, args, kwargs, orig_not_implemented_exception):
    with no_dispatch():

        def to_cpu(e):
            if isinstance(e, FakeTensor):
                return torch.zeros_like(e, device="cpu")
            return e

        try:
            args = tree_map(to_cpu, args)
            kwargs = tree_map(to_cpu, kwargs)
            r = func(*args, **kwargs)
        except Exception as new_exception:
            raise orig_not_implemented_exception from new_exception

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                tensor_impls.add(e)
                storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # the proper aliasing/metadata relationship between outputs and inputs
        # will not be set up because of the conversion to cpu, so error out on
        # reused impls
        for e in tree_flatten(r)[0]:
            if e in tensor_impls or (isinstance(e, torch.Tensor)
                                     and e.storage()._cdata in storages):
                raise orig_not_implemented_exception

    # we're only converting these to MetaTensors now, not Fake Tensors,
    # and the cpu inputs should be temporary. just convert outputs to meta
    # and continue
    return tree_map(MetaConverter(), r)
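
Note (not from the source): the storage-identity check above works because an
aliasing output shares its storage with an input. A minimal sketch, assuming
the private storage()._cdata handle behaves as used in the snippet:

import torch

x = torch.zeros(4)
storages = {x.storage()._cdata}
y = x.view(2, 2)       # aliasing view: shares the input's storage
assert y.storage()._cdata in storages
z = x.clone()          # fresh storage: safe to convert to meta
assert z.storage()._cdata not in storages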
Example #2
def vjp(f, *primals):
    level = _grad_increment_nesting()
    try:
        primals = _wrap_all_tensors(primals, level)
        diff_primals = _create_differentiable(primals, level)
        primals_out = f(*diff_primals)
        results = _undo_create_differentiable(primals_out, level)

        flat_diff_primals, primals_spec = tree_flatten(diff_primals)
        flat_primals_out, primals_out_spec = tree_flatten(
            _as_tuple(primals_out))

        def wrapper(*cotangents, retain_graph=True, create_graph=True):
            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
            if primals_out_spec != cotangents_spec:
                raise RuntimeError(
                    f'Expected pytree structure of cotangents to be the same '
                    f'as pytree structure of outputs to the function. '
                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
                    f'primal output: {treespec_pprint(primals_out_spec)}')
            result = _autograd_grad(flat_primals_out,
                                    flat_diff_primals,
                                    flat_cotangents,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
            return tree_unflatten(result, primals_spec)

    finally:
        _grad_decrement_nesting()

    return results, wrapper
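
For context, a minimal usage sketch of the vjp API above (assuming the
functorch-style import); the returned wrapper maps cotangents to
vector-Jacobian products:

import torch
from functorch import vjp

out, vjp_fn = vjp(torch.sin, torch.randn(3))
(grad_x,) = vjp_fn(torch.ones_like(out))   # equals cos(x)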
Example #3
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # need to handle here to avoid infinite recursion
        # see [in_kernel_invocation]
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # Because fake mode can return NotImplemented (if it sees a subclass
        # it doesn't know how to deal with), this test here is important
        # because the next dispatch after a fake mode will attempt to use
        # subclasses of tensors to dispatch, and any FakeTensor arguments
        # will be considered eligible.
        if any(not issubclass(t, FakeTensor) and t is not torch.Tensor
               for t in types):
            return NotImplemented

        fake_mode = None
        for arg in itertools.chain(
                tree_flatten(args)[0],
                tree_flatten(kwargs)[0]):
            if isinstance(arg, FakeTensor):
                if fake_mode is None:
                    fake_mode = arg.fake_mode
                else:
                    assert fake_mode is arg.fake_mode, "Mixing modes NYI"

        with enable_torch_dispatch_mode(fake_mode):
            return func(*args, **kwargs)
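
A hypothetical minimal subclass (the name is illustrative, not from the
source) showing why the NotImplemented return above matters: declining lets
the next handler among the participating types try the dispatch instead of
mishandling an unknown subclass.

import torch

class CooperativeTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # step aside if an op mixes in a subclass we don't understand,
        # mirroring the check in the snippet above
        if any(not issubclass(t, cls) and t is not torch.Tensor
               for t in types):
            return NotImplemented
        return func(*args, **(kwargs or {}))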
Example #4
def assert_ref_meta_equal(test_case, meta_rs, rs, msg_callable):
    flat_meta_rs, _ = tree_flatten(meta_rs)
    flat_rs, _ = tree_flatten(rs)
    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
    for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):

        def test_assert(cond, msg):
            if not cond:
                raise RuntimeError(f"output {i}: {msg_callable(msg)}")

        if not isinstance(r, torch.Tensor):
            continue
        test_assert(isinstance(meta_r, torch.Tensor),
                    f"but real {i}th result is Tensor")
        test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
        test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
        # NOTE: stride checking is currently disabled
        # See https://github.com/pytorch/pytorch/issues/78050
        # same_strides, _ = prims.utils.check_significant_strides(meta_r, r)
        # test_assert(same_strides, f"but real stride was {r.stride()}")
        test_assert(meta_r.storage_offset() == r.storage_offset(),
                    f"but real storage_offset was {r.storage_offset()}")
        test_assert(meta_r.requires_grad == r.requires_grad,
                    f"but real requires_grad was {r.requires_grad}")
        test_assert(meta_r.is_conj() == r.is_conj(),
                    f"but real is_conj was {r.is_conj()}")
        test_assert(meta_r.is_neg() == r.is_neg(),
                    f"but real is_neg was {r.is_neg()}")
Example #5
            def _torch_dispatch(cls, func, types, args=(), kwargs=None):
                self.precision = saved_precision
                self.rel_tol = saved_rel_tol

                called.add(func)
                all_called[func] += 1

                # Stuff we shouldn't bother testing
                # (TODO: remove detach from the decomp table?)
                if func not in decomposition_table or func in [
                        torch.ops.aten.detach.default
                ] or any_unsupported(args, kwargs):
                    return func(*args, **kwargs)

                decomposed.add(func)
                all_decomposed.add(func)

                # We take two main strategies for verifying the correctness
                # and numerical stability of decompositions. The first is
                # simple tolerance checking between decomp_out and pytorch_out.
                # However, for fp16/bf16 and reductions this becomes very
                # finicky, as there are not many guarantees we can make.
                # So, for fp16/bf16, we instead compare the difference of
                # {decomp_out, pytorch_out_64} and {pytorch_out,
                # pytorch_out_64}. In other words, we compare how far the
                # decomposition and pytorch each are from the "ground truth"
                # (i.e. fp64). If the decomposition results in more error, we
                # raise an error.

                decomposition = decomposition_table[func]

                do_relative_check = test_dtype in [
                    torch.float16, torch.bfloat16
                ]
                real_out_unflat = func(*args, **kwargs)
                real_out, _ = tree_flatten(real_out_unflat)
                decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
                assert len(real_out) == len(decomp_out)

                if do_relative_check:
                    upcast = partial(upcast_tensor, dtype=torch.float64)
                    real_out_double, _ = tree_flatten(
                        func(*tree_map(upcast, args),
                             **tree_map(upcast, kwargs)))
                    for i, orig, decomp, ref in zip(range(len(real_out)),
                                                    real_out, decomp_out,
                                                    real_out_double):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_ref(self, func, test_dtype, i, orig, decomp,
                                      ref, args, kwargs)
                else:
                    for orig, decomp in zip(real_out, decomp_out):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_equal(self, func, test_dtype, orig, decomp,
                                        args, kwargs)

                return real_out_unflat
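
The fp16/bf16 strategy described in the comment reduces to comparing each
result's distance from the fp64 reference. A standalone sketch of that check
(hypothetical helper and slack factor, not the test's actual tolerances):

import torch

def decomp_close_enough(orig, decomp, ref64, slack=2.0):
    # how far eager and the decomposition each are from the fp64
    # "ground truth"; the decomposition may be at most `slack` times worse
    orig_err = (orig.double() - ref64).abs().max()
    decomp_err = (decomp.double() - ref64).abs().max()
    return decomp_err <= slack * orig_err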
Example #6
def _outs_and_grads(fn, inps):
    outs = fn(*inps)
    for out in pytree.tree_flatten(outs)[0]:
        if isinstance(out, torch.Tensor) and out.requires_grad:
            out.sum().backward(retain_graph=True)
    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
    for inp in pytree.tree_flatten(inps)[0]:
        inp.grad = None
    return outs, grads
Example #7
def gather_leaf_tensors(args, kwargs):
    leaf_tensors = []
    args, args_spec = tree_flatten(args)
    kwargs, kwargs_spec = tree_flatten(kwargs)
    args = args + kwargs
    for arg in args:
        if not isinstance(arg, torch.Tensor):
            continue
        if arg.requires_grad:
            leaf_tensors.append(arg)
    return leaf_tensors
Example #8
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
    # these should all be supported, just to be safe
    # avoid the fallback for operators which inplace-modify metadata,
    # because the input fake tensors would be left unmodified
    if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
        raise orig_not_implemented_exception

    with no_dispatch():
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        args = tree_map(to_real_tensor, args)
        kwargs = tree_map(to_real_tensor, kwargs)

        r = func(*args, **kwargs)

        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                if not e.is_sparse:
                    storages.add(e.storage()._cdata)

        # TODO: also check metadata change on inputs
        # the proper aliasing/metadata relationship between outputs and inputs
        # will not be set up because of the conversion to device, unless we
        # can reuse an input impl
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor)
                and not e.is_sparse
                and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

    def map_out(e):
        if isinstance(e, torch.Tensor):
            if id(e) in inp_impls:
                return inp_impls[id(e)]
            else:
                return fake_mode.fake_tensor_converter(fake_mode, e)
        else:
            return e

    return tree_map(map_out, r)
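
A toy illustration (not from the source) of the id()-based bookkeeping above:
real tensors created for the fallback are remembered by id(), so an op that
returns one of its inputs can be mapped back to the original fake tensor.

import torch

inp_impls = {}

def to_real_tensor(e, fake):
    out = torch.zeros_like(e)
    inp_impls[id(out)] = fake   # remember which fake tensor this stands for
    return out

fake = object()                 # stand-in for a FakeTensor
real = to_real_tensor(torch.empty(2), fake)
result = real                   # e.g. an in-place op returning its input
assert inp_impls[id(result)] is fake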
Example #9
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)
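
For reference, the None-to-zeros substitution above exists because
torch.autograd.grad(..., allow_unused=True) returns None for inputs the
outputs do not depend on; a minimal demonstration:

import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
gx, gy = torch.autograd.grad(x.sum(), (x, y), allow_unused=True)
assert gy is None   # _autograd_grad would substitute torch.zeros_like(y)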
Example #10
            def flat_fn(*flat_tensor_args):
                # The inputs are flattened tensor args. Prepare the args in
                # the order that the original function expects. Add static
                # args as well; they will appear as tensor constants in the
                # traced graph.
                nonlocal out_spec, static_args

                tensor_args, kwargs = pytree.tree_unflatten(
                    flat_tensor_args, tensor_args_spec)
                if static_argnums is None:
                    args = tensor_args
                else:
                    args = rearrange(tensor_args, static_args, static_argnums)
                tree_out = fn(*args, **kwargs)
                flat_out, spec = pytree.tree_flatten(tree_out)
                for i in flat_out:
                    is_known_type = False
                    for j in KNOWN_TYPES:
                        if isinstance(i, j):
                            is_known_type = True
                            break
                    if not is_known_type:
                        raise RuntimeError(
                            f"Found {type(i)} in output, which is not a known type. "
                            "If this type holds tensors, you need to register a pytree for it. "
                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
                            "explanation why. If you don't need to register a pytree, please "
                            "leave a comment explaining your use case and we'll make this more "
                            "ergonomic to deal with")
                out_spec.set(spec)
                return flat_out
Example #11
def flatten_fn(*args):
    tree_args = pytree.tree_unflatten(list(args), in_spec)
    tree_out = root_fn(*tree_args)
    out_args, out_spec = pytree.tree_flatten(tree_out)
    assert isinstance(self.graph._codegen, _PyTreeCodeGen)
    self.graph._codegen.pytree_info = self.graph._codegen.pytree_info._replace(out_spec=out_spec)
    return out_args
Example #12
def wrap_key(f, inps):
    flat_inps, _ = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert (len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                with no_dispatch():
                    flat_args[idx] = ProxyTensor(
                        flat_inps[idx],
                        arg,
                        requires_grad=flat_inps[idx].is_leaf)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(
                    flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped
Example #13
def get_isolated_graphmodule(func, args, kwargs):
    """A helper function used to get the GraphModule for the given func.

    It's expected to be used in the ProxyTensor tracing context.
    It detaches the args and kwargs from the current tracer so that the trace of
    the current graph module can be created without any side-effects.
    """
    # make_fx doesn't support kwargs, so we need to do this flattening
    # and then unflatten the args before calling func
    all_args, spec = pytree.tree_flatten((args, kwargs))

    def wrapped(args):
        fn_args, fn_kwargs = pytree.tree_unflatten(args, spec)
        return func(*fn_args, **fn_kwargs)

    unwrapped_all_args = [unwrap_elem(a) for a in all_args]

    # Current implementation doesn't support the case when ProxyTensor is
    # wrapped with another Tensor subclass
    # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
    # TODO: Once https://github.com/pytorch/pytorch/pull/82549 is merged, we can
    # remove this
    assert all(
        getattr(a, "elem", None) is None for a in unwrapped_all_args
        if isinstance(a, torch.Tensor)
    ), "ProxyTensor is wrapped with another Tensor subclass"

    with disable_proxy_modes_tracing():
        gm = make_fx(wrapped)(unwrapped_all_args)
    return gm
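
The flatten-then-unflatten trick above is a general way to push kwargs through
an interface that only accepts positional arguments; a standalone sketch:

import torch
from torch.utils import _pytree as pytree

def call_positionally(func, args, kwargs):
    flat, spec = pytree.tree_flatten((args, kwargs))

    def wrapped(flat_args):
        a, kw = pytree.tree_unflatten(flat_args, spec)
        return func(*a, **kw)

    return wrapped(flat)

out = call_positionally(torch.add, (torch.ones(2),), {"other": torch.ones(2)})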
Example #14
def test_treespec_repr(self):
    # Check that it looks sane
    pytree = (0, [0, 0, 0])
    _, spec = tree_flatten(pytree)
    self.assertEqual(
        repr(spec),
        'TreeSpec(tuple, None, [*, TreeSpec(list, None, [*, *, *])])')
Example #15
def wrap_key(f, inps):
    flat_inps, inp_spec = pytree.tree_flatten(inps)

    @functools.wraps(f)
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert (len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = addPythonKey(PythonTensor(
                    flat_inps[idx], arg))
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)

        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and hasPythonKey(
                    flat_outs[idx]):
                flat_outs[idx] = removePythonKey(flat_outs[idx]).proxy
        return pytree.tree_unflatten(flat_outs, out_spec)

    return wrapped
Example #16
def get_exhaustive_batched_inputs_batch_norm_is_training(
        arg_values, kwarg_values, batch_size=2):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    num_tensors = sum(is_tensors)
    # if there's only an input, we can't batch it, since running_mean/var
    # will be seen as unbatched tensors
    if num_tensors == 1:
        return
    bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims))
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values
Example #17
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            )

            flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            # FIXME?: assumes result is a single tensor
            assert isinstance(result, TensorLike)
            return _maybe_convert_to_dtype(result, result_dtype)
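
A stripped-down sketch of the bind/promote/call pattern above, using only
inspect and torch.promote_types (a simplification, not the library's actual
elementwise dtype rule):

import inspect
from functools import reduce
import torch

def promote_elementwise(fn):
    sig = inspect.signature(fn)

    def _fn(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # assumes at least one tensor argument
        dtypes = [v.dtype for v in bound.arguments.values()
                  if isinstance(v, torch.Tensor)]
        compute_dtype = reduce(torch.promote_types, dtypes)
        bound.arguments.update({
            k: v.to(compute_dtype)
            for k, v in bound.arguments.items()
            if isinstance(v, torch.Tensor)
        })
        return fn(**bound.arguments)

    return _fn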
Example #18
def flatten_fn(*args):
    tree_args = pytree.tree_unflatten(list(args), in_spec)
    tree_out = root_fn(*tree_args)
    out_args, out_spec = pytree.tree_flatten(tree_out)
    assert self.graph._pytree_info is not None
    self.graph._pytree_info = self.graph._pytree_info._replace(out_spec=out_spec)
    return out_args
Example #19
def normalize_op_input_output2(f,
                               args,
                               kwargs,
                               output_process_fn_grad=None,
                               requires_grad=True):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(i for i, arg in enumerate(flat_args)
                         if diff_arg(arg, requires_grad=requires_grad))
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO: Remove the following hack for namedtuples
            result = tuple(result)
            result = tuple(r for r in result if isinstance(r, Tensor) and (
                r.is_floating_point() or r.is_complex()))
            assert len(result) > 0
        return result

    return wrapped, primals
Example #20
        def run_test_with_leaf(leaf):
            values, treespec = tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, LeafSpec())

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)
Example #21
def substitute(arg_list):
    arg_list, spec = tree_flatten(arg_list)
    for i in range(len(arg_list)):
        v = arg_list[i]
        if isinstance(v, torch.fx.node.Node) and v in env:
            arg_list[i] = env[v]
    return tuple(arg_list), spec
Example #22
    def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
        func = func_overload.overloadpacket
        # We don't want to convert torch.tensor constants into tracing objects.
        if func_overload == aten.lift.default:
            return args[0]
        if any(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0]):
            return proxy_call(func_overload, args, kwargs)
        # When we trace through a torch.tensor invocation, you never actually
        # see a torch.ops.aten.tensor call. Instead, the way this function is
        # implemented internally is that we allocate a plain tensor (this is
        # *guaranteed* to be a plain tensor, we disable all modes when doing
        # so), and then call at::lift_fresh on it (to give modes a chance to do
        # their stuff).  Furthermore, the tensor argument to lift_fresh is guaranteed
        # to be freshly allocated, so we want lift_fresh to be a no-op (directly
        # returning the input argument).
        #
        # Here is the basic problem: when we trace this sequence of executions
        # into an FX graph, what happens to this call sequence?  Traditionally,
        # tensor constants get interned as buffers on the FX GraphModule.  But
        # this is dangerous.  Consider:
        #
        #       x = torch.tensor(1)
        #       x.add_(2)
        #
        # Naively, this traces into:
        #
        #       t = self._tensor_constant0  # initialized to torch.tensor(1)
        #       x = torch.ops.aten.lift_fresh(t)
        #       x.add_(2)
        #
        # If lift_fresh returns t directly, the subsequent add_ call will
        # modify the tensor constant. Really, the problem is that we've
        # violated the invariant that the argument to lift is fresh.  So we
        # should preserve the invariant by replacing lift_fresh with
        # lift_fresh_copy:
        #
        #       t = self._tensor_constant0  # initialized to torch.tensor(1)
        #       x = torch.ops.aten.lift_fresh_copy(t)
        #       x.add_(2)
        #
        # This is what the overload modification does.
        else:
            if func_overload is torch.ops.aten.lift_fresh.default:
                func_overload = torch.ops.aten.lift_fresh_copy.default

            proxy_res = self.tracer.create_proxy('call_function', func_overload, args, kwargs,
                                                 name=self.tracer.graph._target_to_str(func.__name__))

            inner_res = func_overload(*args, **kwargs)

            # If this is a lift, the input tensor is guaranteed to be a
            # constant, so we keep a copy of the original argument around so
            # that we can query it if we're asked to item() it at some later
            # point
            is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default
            if is_lift:
                with maybe_disable_fake_tensor_mode():
                    constant = args[0].clone()
            else:
                constant = None
            return wrap_output(inner_res, proxy_res, constant=constant)
Example #23
    def wrapped(*args):
        phs = pytree.tree_map(lambda _: fx.PH,
                              args)  # type: ignore[attr-defined]
        fx_tracer = PythonKeyTracer()
        fake_tensor_mode: Any = nullcontext()
        if tracing_mode == "real":
            fake_tensor_mode = nullcontext()
        elif tracing_mode == "fake":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
        elif tracing_mode == "symbolic":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        else:
            raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

        proxy_mode = ProxyTorchDispatchMode(
            fx_tracer, trace_factory_functions=trace_factory_functions)

        def wrap_fake_concrete(x):
            if isinstance(x, torch.Tensor):
                return fake_tensor_mode.from_tensor(
                    x)  # type: ignore[attr-defined]

            return x

        shape_env = ShapeEnv()

        # todo: Figure out a more informative name for symints
        def wrap_fake_symbolic(x, sym_shape):
            if isinstance(x, torch.Tensor):
                val = FakeTensor(fake_tensor_mode,
                                 torch.empty(sym_shape, device="meta"),
                                 x.device)
                return val
            return x

        wrap_fn_map = {
            "real": lambda x: x,
            "fake": wrap_fake_concrete,
        }
        if tracing_mode == "symbolic":
            flat_shapes = shape_env.create_shapes_for_args(args)
            flat_args, spec = pytree.tree_flatten(args)
            args = pytree.tree_unflatten(
                list(
                    map(lambda a: wrap_fake_symbolic(a[0], a[1]),
                        zip(flat_args, flat_shapes))), spec)
        else:
            args = pytree.tree_map(wrap_fn_map[tracing_mode], args)

        with decompose(
                decomposition_table
        ), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
            t = dispatch_trace(wrap_key(f, args, proxy_mode),
                               tracer=fx_tracer,
                               concrete_args=tuple(phs))

        # TODO: kind of a bad way to do it, should maybe figure out a better way
        t.shape_env = shape_env  # type: ignore[assignment]
        return t
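
For orientation, a minimal use of the make_fx entry point this wrapper backs
(a sketch, assuming the proxy_tensor API of the snippet's vintage):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x) + 1

gm = make_fx(f, tracing_mode="fake")(torch.randn(3))
print(gm.graph)   # aten ops traced without running real kernels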
Example #24
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert (len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = PythonTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(
                    flat_outs[idx], PythonTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)
Example #25
def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule):
    num_fwd_outputs = joint_module._out_spec.children_specs[0].num_leaves
    outputs = pytree.tree_flatten([
        node.args for node in joint_module.graph.nodes if node.op == 'output'
    ])[0]
    fwd_outputs = outputs[:num_fwd_outputs]
    bwd_outputs = outputs[num_fwd_outputs:]
    return fwd_outputs, bwd_outputs
Example #26
    def test_flatten_unflatten_torch_namedtuple_return_type(self):
        x = torch.randn(3, 3)
        expected = torch.max(x, dim=0)

        values, spec = tree_flatten(expected)
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)
Example #27
        def run_test(pytree):
            values, treespec = tree_flatten(pytree)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: python basic data structures (dict, list, tuple) all have
            # contents equality defined on them, so the following works for them
            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)
Example #28
def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2,
          *batched_args, **kwarg_values):
    outs = []
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert (args_spec == dims_spec1)
    assert (args_spec == dims_spec2)
    assert (len(flat_dims1) == len(flat_dims2))
    for idx1 in range(batch_size1):
        out_split = []
        arg_split = [
            a.select(in_dim1, idx1) if in_dim1 is not None else a
            for a, in_dim1 in zip(flat_args, flat_dims1)
        ]
        for idx2 in range(batch_size2):
            new_args = [
                a.select(in_dim, idx2) if in_dim is not None else a
                for a, in_dim in zip(arg_split, flat_dims2)
            ]
            out = op(*pytree.tree_unflatten(new_args, args_spec),
                     **kwarg_values)
            out_split.append(out)
        outs.append(out_split)

    loop_out = []
    for out_split in outs:
        if isinstance(out_split[0], torch.Tensor):
            loop_out.append(torch.stack(out_split, out_dim1))
        else:
            new_out = []
            for idx in range(len(out_split[0])):
                new_out.append(
                    torch.stack([i[idx] for i in out_split], out_dim1))
            loop_out.append(new_out)

    new_out = []
    if isinstance(loop_out[0], torch.Tensor):
        new_out = torch.stack(loop_out, out_dim2)
    else:
        for idx in range(len(loop_out[0])):
            new_out.append(torch.stack([i[idx] for i in loop_out], out_dim2))
    return new_out
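
loop2 serves as a slow reference for nested vmap; a sketch (assuming loop2 is
in scope and the functorch vmap API) of the equivalence it is checked against:

import torch
from functorch import vmap

op = torch.mul
x, y = torch.randn(5, 3), torch.randn(7, 3)
expected = vmap(vmap(op, in_dims=(None, 0)), in_dims=(0, None))(x, y)
reference = loop2(op, (0, None), (None, 0), 0, 0, 5, 7, x, y)
assert torch.allclose(expected, reference)   # both are (5, 7, 3)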
Example #29
        def run_test(tup):
            expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])
            values, treespec = tree_flatten(tup)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertTrue(isinstance(unflattened, tuple))
Example #30
        def run_test(lst):
            expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])
            values, treespec = tree_flatten(lst)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertTrue(isinstance(unflattened, list))