Example #1
def get_exhaustive_batched_inputs_batch_norm_is_training(
        arg_values, kwarg_values, batch_size=2):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    num_tensors = sum(is_tensors)
    # If there's only an input, we can't batch it since running_mean/var
    # will be seen as unbatched tensors.
    if num_tensors == 1:
        return
    bdim_choices = get_bdim_choices_batch_norm(num_tensors, *arg_values)

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims))
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values
Example #2
def generate_subclass_choices_args_kwargs(args, kwargs):
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_args_kwargs = list(args) + list(flat_kwargs)
    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs):
        new_args = choice[:len(args)]
        new_kwargs = tree_unflatten(choice[len(args):], spec)
        which_args_are_wrapped = debug_metadata[:len(args)]
        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):],
                                                  spec)
        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped
Example #3
def generate_subclass_choices_args_kwargs(args, kwargs, CCT):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_args_kwargs = list(args) + list(flat_kwargs)
    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT):
        new_args = choice[:len(args)]
        new_kwargs = tree_unflatten(choice[len(args):], spec)
        which_args_are_wrapped = debug_metadata[:len(args)]
        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped
Example #4
def _create_batched_inputs(
        flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else
                      _add_batch_dim(arg, in_dim, vmap_level)  # type: ignore
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec)
Example #5
    def wrapped_with_chunks(*args, **kwargs):
        _check_out_dims_is_int_or_int_pytree(out_dims, func)
        _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
            in_dims, args, func)
        # Chunk flat arguments
        chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims,
                                                chunks)

        # Apply vmap on chunks
        chunks_output = []
        rs = torch.get_rng_state() if randomness == "same" else None
        for flat_args in chunks_flat_args:
            batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
            if rs is not None:
                torch.set_rng_state(rs)
            chunks_output.append(
                _flat_vmap(func, batch_size, flat_in_dims, flat_args,
                           args_spec, out_dims, randomness, **kwargs))
        flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
        # Removing temporary variables helps to reduce memory usage on devices like CUDA.
        del chunks_output

        # concat chunks on out_dim
        flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
        assert len(flat_out_dims) == len(flat_output_chunks)
        flat_output = []
        for out_dim in flat_out_dims:
            flat_output.append(torch.cat(flat_output_chunks[0], dim=out_dim))
            # release source data
            del flat_output_chunks[0]
        del flat_output_chunks

        # finally unflatten the output
        return tree_unflatten(flat_output, arg_spec)
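# A hedged, self-contained sketch of the chunking idea above (not from the
# original source): split the batched input into chunks, run the function per
# chunk, then torch.cat the per-chunk outputs along the output batch dim, which
# matches running the function on the full batch at once.
import torch

x = torch.randn(6, 3)
chunk_outs = [c.sin() for c in x.split(2, dim=0)]  # stand-in for the per-chunk vmap calls
assert torch.allclose(torch.cat(chunk_outs, dim=0), x.sin())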
Example #6
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)
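# A minimal usage sketch, not part of the original example: it assumes the
# `_autograd_grad` definition above together with the `torch` and pytree
# imports that definition already relies on. Leaves of `inputs` that do not
# require grad come back as zeros, and the pytree structure of `inputs` is
# preserved in the result.
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)                       # does not require grad
out = (x * 2).sum()
grads = _autograd_grad((out,), {"a": x, "b": y})
assert torch.allclose(grads["a"], torch.full_like(x, 2.0))
assert torch.equal(grads["b"], torch.zeros_like(y))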
Example #7
        def run_test_with_leaf(leaf):
            values, treespec = tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, LeafSpec())

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)
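# A small illustrative sketch, not part of the original test, assuming the
# torch.utils._pytree import path: a leaf flattens to a one-element list plus
# a LeafSpec, and unflattening that pair round-trips the value.
from torch.utils._pytree import LeafSpec, tree_flatten, tree_unflatten

values, spec = tree_flatten(42)      # any non-container value is a leaf
assert values == [42]
assert spec == LeafSpec()
assert tree_unflatten(values, spec) == 42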
Example #8
    def unflatten_outs(self, out):
        if self._pytree_info is None:
            return out
        if not isinstance(out, list):
            out = [out]
        assert self._pytree_info.out_spec is not None
        return pytree.tree_unflatten(out, self._pytree_info.out_spec)
Example #9
            def flat_fn(*flat_tensor_args):
                # The inputs are flattened tensor args. Prepare the args in the
                # order that the original function expects. Add static args as well.
                # They will appear as tensor constants in the traced graph.
                nonlocal out_spec, static_args

                tensor_args, kwargs = pytree.tree_unflatten(
                    flat_tensor_args, tensor_args_spec)
                if static_argnums is None:
                    args = tensor_args
                else:
                    args = rearrange(tensor_args, static_args, static_argnums)
                tree_out = fn(*args, **kwargs)
                flat_out, spec = pytree.tree_flatten(tree_out)
                for i in flat_out:
                    is_known_type = False
                    for j in KNOWN_TYPES:
                        if isinstance(i, j):
                            is_known_type = True
                            break
                    if not is_known_type:
                        raise RuntimeError(
                            f"Found {type(i)} in output, which is not a known type. "
                            "If this type holds tensors, you need to register a pytree for it. "
                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
                            "explanation why. If you don't need to register a pytree, please "
                            "leave a comment explaining your use case and we'll make this more "
                            "ergonomic to deal with")
                out_spec.set(spec)
                return flat_out
Example #10
    def flatten_fn(*args):
        tree_args = pytree.tree_unflatten(list(args), in_spec)
        tree_out = root_fn(*tree_args)
        out_args, out_spec = pytree.tree_flatten(tree_out)
        assert isinstance(self.graph._codegen, _PyTreeCodeGen)
        self.graph._codegen.pytree_info = self.graph._codegen.pytree_info._replace(
            out_spec=out_spec)
        return out_args
Example #11
    def flatten_fn(*args):
        tree_args = pytree.tree_unflatten(list(args), in_spec)
        tree_out = root_fn(*tree_args)
        out_args, out_spec = pytree.tree_flatten(tree_out)
        assert self.graph._pytree_info is not None
        self.graph._pytree_info = self.graph._pytree_info._replace(out_spec=out_spec)
        return out_args
Example #12
    def wrapped(*args):
        phs = pytree.tree_map(lambda _: fx.PH,
                              args)  # type: ignore[attr-defined]
        fx_tracer = PythonKeyTracer()
        fake_tensor_mode: Any = nullcontext()
        if tracing_mode == "real":
            fake_tensor_mode = nullcontext()
        elif tracing_mode == "fake":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
        elif tracing_mode == "symbolic":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        else:
            raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

        proxy_mode = ProxyTorchDispatchMode(
            fx_tracer, trace_factory_functions=trace_factory_functions)

        def wrap_fake_concrete(x):
            if isinstance(x, torch.Tensor):
                return fake_tensor_mode.from_tensor(
                    x)  # type: ignore[attr-defined]

            return x

        shape_env = ShapeEnv()

        # TODO: Figure out a more informative name for symints
        def wrap_fake_symbolic(x, sym_shape):
            if isinstance(x, torch.Tensor):
                val = FakeTensor(fake_tensor_mode,
                                 torch.empty(sym_shape, device="meta"),
                                 x.device)
                return val
            return x

        wrap_fn_map = {
            "real": lambda x: x,
            "fake": wrap_fake_concrete,
        }
        if tracing_mode == "symbolic":
            flat_shapes = shape_env.create_shapes_for_args(args)
            flat_args, spec = pytree.tree_flatten(args)
            args = pytree.tree_unflatten(
                list(
                    map(lambda a: wrap_fake_symbolic(a[0], a[1]),
                        zip(flat_args, flat_shapes))), spec)
        else:
            args = pytree.tree_map(wrap_fn_map[tracing_mode], args)

        with decompose(
                decomposition_table
        ), fake_tensor_mode, proxy_mode:  # type: ignore[attr-defined]
            t = dispatch_trace(wrap_key(f, args, proxy_mode),
                               tracer=fx_tracer,
                               concrete_args=tuple(phs))

        # TODO: kind of a bad way to do it, should maybe figure out a better way
        t.shape_env = shape_env  # type: ignore[assignment]
        return t
Example #13
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert (len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                with no_dispatch():
                    flat_args[idx] = ProxyTensor(flat_inps[idx], arg, requires_grad=flat_inps[idx].is_leaf)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)
Example #14
    def wrapped(*args):
        flat_args, args_spec = pytree.tree_flatten(args)
        assert (len(flat_args) == len(flat_inps))
        for idx, arg in enumerate(flat_args):
            if isinstance(flat_inps[idx], torch.Tensor):
                flat_args[idx] = PythonTensor(flat_inps[idx], arg)
            else:
                flat_args[idx] = flat_inps[idx]

        tree_args = pytree.tree_unflatten(flat_args, args_spec)
        out = f(*tree_args)
        flat_outs, out_spec = pytree.tree_flatten(out)
        for idx in range(len(flat_outs)):
            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(
                    flat_outs[idx], PythonTensor):
                flat_outs[idx] = flat_outs[idx].proxy
        return pytree.tree_unflatten(flat_outs, out_spec)
Example #15
    def test_flatten_unflatten_torch_namedtuple_return_type(self):
        x = torch.randn(3, 3)
        expected = torch.max(x, dim=0)

        values, spec = tree_flatten(expected)
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)
Example #16
        def run_test(pytree):
            values, treespec = tree_flatten(pytree)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: Python's basic data structures (dict, list, tuple) all have
            # contents equality defined on them, so the following works for them.
            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)
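# A hedged round-trip sketch, not part of the original test, assuming the
# torch.utils._pytree import path: flattening a nested container yields its
# leaves in a flat list plus a spec, and tree_unflatten rebuilds an equal
# structure.
from torch.utils._pytree import tree_flatten, tree_unflatten

nested = {"a": [1, 2], "b": (3, {"c": 4})}
leaves, spec = tree_flatten(nested)
assert leaves == [1, 2, 3, 4]
assert spec.num_leaves == 4
assert tree_unflatten(leaves, spec) == nested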
Example #17
        def run_test(tup):
            expected_spec = TreeSpec(tuple, None, [LeafSpec() for _ in tup])
            values, treespec = tree_flatten(tup)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertTrue(isinstance(unflattened, tuple))
Example #18
        def run_test(lst):
            expected_spec = TreeSpec(list, None, [LeafSpec() for _ in lst])
            values, treespec = tree_flatten(lst)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertTrue(isinstance(unflattened, list))
Example #19
    def functional_call(*args, **kwargs):
        with _stateless.reparametrize_module(
                mod, pytree.tree_unflatten(args[:params_len], params_spec)):
            out = mod(*args[params_len:], **kwargs)
        if not isinstance(out, (tuple, list)):
            raise RuntimeError(
                "Graph output must be a tuple(). This is so that we can avoid "
                "pytree processing of the outputs. Please change the module to "
                "have tuple outputs or use aot_module instead.")
        return out
Example #20
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None):
    CCT = generate_cct(enable_recursive_torch_dispatch=True, autograd_view_consistency=False)
    # Permutations of args and kwargs in CCT.
    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

        def maybe_tangent(t):
            assert type(t) is not CCT
            # Generate a `tangent` tensor
            # if the given object is a Tensor with requires_grad set.
            if isinstance(t, torch.Tensor) and t.requires_grad:
                return torch.randn_like(t)
            elif is_tensorlist(t):
                return list(torch.randn_like(e) if e.requires_grad else None for e in t)
            return None

        tangent_args = tuple(maybe_tangent(arg) for arg in args)
        flat_kwargs, spec = tree_flatten(kwargs)
        flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
        tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

        # Permutations of tangent args and tangent kwargs in CCT.
        for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT):
            new_tang_args, new_tang_kwargs, \
                which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice

            with fwAD.dual_level():
                def maybe_make_dual(dual):
                    # Returns dual tensor if primal is a tensor/tensor subclass
                    # with requires_grad set.
                    primal, tangent = dual
                    if isinstance(primal, torch.Tensor) and primal.requires_grad:
                        return fwAD.make_dual(primal, tangent)
                    elif is_tensorlist(primal):
                        return tuple(fwAD.make_dual(pri, tang) if tang is not None else pri
                                     for pri, tang in zip(primal, tangent))
                    return primal

                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}

                try:
                    if gradcheck_wrapper is None:
                        op(*op_args, **op_kwargs)
                    else:
                        gradcheck_wrapper(op, *op_args, **op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {which_args_are_wrapped}\n"
                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
                    )
Example #21
        def run_test(dct):
            expected_spec = TreeSpec(dict, list(dct.keys()),
                                     [LeafSpec() for _ in dct.values()])
            values, treespec = tree_flatten(dct)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(values, list(dct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, dct)
            self.assertTrue(isinstance(unflattened, dict))
Example #22
        def run_test(odict):
            expected_spec = TreeSpec(OrderedDict, list(odict.keys()),
                                     [LeafSpec() for _ in odict.values()])
            values, treespec = tree_flatten(odict)
            self.assertTrue(isinstance(values, list))
            self.assertEqual(values, list(odict.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, odict)
            self.assertTrue(isinstance(unflattened, OrderedDict))
Example #23
def get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size=2):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    bdim_choices = get_bdim_choices(sum(is_tensors))

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims))
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values
Example #24
    def test_flatten_unflatten_return_type(self, op):
        x = torch.randn(3, 3)
        expected = op(x, dim=0)

        values, spec = tree_flatten(expected)
        # Check that values is actually List[Tensor] and not (ReturnType(...),)
        for value in values:
            self.assertTrue(isinstance(value, torch.Tensor))
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)
Example #25
    def wrapper(*cotangents, retain_graph=True, create_graph=True):
        flat_cotangents, cotangents_spec = tree_flatten(cotangents)
        if primals_out_spec != cotangents_spec:
            raise RuntimeError(
                f'Expected pytree structure of cotangents to be the same '
                f'as pytree structure of outputs to the function. '
                f'cotangents: {treespec_pprint(cotangents_spec)}, '
                f'primal output: {treespec_pprint(primals_out_spec)}')
        result = _autograd_grad(flat_primals_out,
                                flat_diff_primals,
                                flat_cotangents,
                                retain_graph=retain_graph,
                                create_graph=create_graph)
        return tree_unflatten(result, primals_spec)
Example #26
def _create_batched_inputs(in_dims: in_dims_t, args: Tuple, vmap_level: int,
                           func: Callable) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if len(args) == 0:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
                f'0 <= in_dim < {arg.dim()}.')

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else torch._add_batch_dim(
            arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec), batch_size
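# A hedged illustration, not from the original source, of the in_dims
# broadcasting step above; the torch.utils._pytree import path for the private
# helper is an assumption. An integer in_dims (here 0) is broadcast to one
# entry per flattened argument before the per-argument validation runs.
import torch
from torch.utils._pytree import tree_flatten, _broadcast_to_and_flatten

args = (torch.randn(2, 3), {"w": torch.randn(2, 4)})
flat_args, args_spec = tree_flatten(args)
flat_in_dims = _broadcast_to_and_flatten(0, args_spec)
assert flat_in_dims == [0, 0]    # one batch dim per flattened leaf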
Example #27
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO: Remove the following hack for namedtuples
            result = tuple(result)
            result = tuple(r for r in result if isinstance(r, Tensor) and (
                r.is_floating_point() or r.is_complex()))
            assert len(result) > 0
        return result
Example #28
def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2,
          *batched_args, **kwarg_values):
    outs = []
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert (args_spec == dims_spec1)
    assert (args_spec == dims_spec2)
    assert (len(flat_dims1) == len(flat_dims2))
    for idx1 in range(batch_size1):
        out_split = []
        arg_split = [
            a.select(in_dim1, idx1) if in_dim1 is not None else a
            for a, in_dim1 in zip(flat_args, flat_dims1)
        ]
        for idx2 in range(batch_size2):
            new_args = [
                a.select(in_dim, idx2) if in_dim is not None else a
                for a, in_dim in zip(arg_split, flat_dims2)
            ]
            out = op(*pytree.tree_unflatten(new_args, args_spec),
                     **kwarg_values)
            out_split.append(out)
        outs.append(out_split)

    loop_out = []
    for out_split in outs:
        if isinstance(out_split[0], torch.Tensor):
            loop_out.append(torch.stack(out_split, out_dim1))
        else:
            new_out = []
            for idx in range(len(out_split[0])):
                new_out.append(
                    torch.stack([i[idx] for i in out_split], out_dim1))
            loop_out.append(new_out)

    new_out = []
    if isinstance(loop_out, torch.Tensor):
        new_out = torch.stack(loop_out, out_dim2)
    else:
        for idx in range(len(loop_out[0])):
            new_out.append(torch.stack([i[idx] for i in loop_out], out_dim2))
    return new_out
Example #29
    def wrapper(*args):
        level = _grad_increment_nesting()
        output, aux, grad_input = None, None, None
        try:
            args = _wrap_all_tensors(args, level)
            diff_args = _slice_argnums(args, argnums)
            tree_map_(partial(_create_differentiable, level=level), diff_args)

            output = f(*args)
            if has_aux:
                output, aux = output

            if not isinstance(output, torch.Tensor):
                raise RuntimeError(
                    'grad_and_value(f)(*args): Expected f(*args) '
                    f'to return a Tensor, got {type(output)}')
            if output.dim() != 0:
                raise RuntimeError(
                    'grad_and_value(f)(*args): Expected f(*args) '
                    'to return a scalar Tensor, got tensor with '
                    f'{output.dim()} dims. Maybe you wanted to '
                    'use the vjp or jacrev APIs instead?')

            flat_diff_args, spec = tree_flatten(diff_args)

            # NB: need create_graph so that backward pass isn't run in no_grad mode
            flat_outputs = _as_tuple(output)
            flat_grad_input = _autograd_grad(flat_outputs,
                                             flat_diff_args,
                                             create_graph=True)
            grad_input = tree_unflatten(flat_grad_input, spec)

        finally:
            if grad_input is not None:
                grad_input = _undo_create_differentiable(grad_input, level)
            if output is not None:
                output = _undo_create_differentiable(output, level)
            if aux is not None:
                aux = _undo_create_differentiable(aux, level)
            _grad_decrement_nesting()
        if has_aux:
            return grad_input, (output, aux)
        return grad_input, output
Example #30
    def functional_call(*args, **kwargs):
        with stateless._reparametrize_module(
            mod, pytree.tree_unflatten(args[:params_len], params_spec)
        ):
            if isinstance(mod, torch.fx.GraphModule):
                with fx_traceback.override_stack_trace(), torch.autograd.detect_anomaly(
                    check_nan=False
                ):
                    out = Interpreter(mod).run(*args[params_len:], **kwargs)
            else:
                out = mod(*args[params_len:], **kwargs)

        if not isinstance(out, (tuple, list)):
            raise RuntimeError(
                "Graph output must be a tuple(). This is so that we can avoid "
                "pytree processing of the ouputs. Please change the module to "
                "have tuple outputs or use aot_module instead."
            )
        return out