Example #1: cached AOT Autograd forward that traces the joint forward/backward graph with make_fx (optionally under FakeTensorMode and with functionalization), partitions it, and compiles both halves.
        def forward(ctx, *flat_tensor_args):
            nonlocal compiled_fw, compiled_bw, num_outs
            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
            if compiled_fw is None:
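                # Detach the inputs (preserving their requires_grad flags) so the
                # traced graph sees leaf tensors.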
                flat_tensor_args = pytree.tree_map(
                    lambda x: x.detach().requires_grad_(x.requires_grad)
                    if isinstance(x, Tensor) else x, flat_tensor_args)
                fake_mode = FakeTensorMode.push(
                ) if config.use_fake_tensor else nullcontext()
                with preserve_rng_state(), fake_mode as mode:
                    # Convert tensor inputs to fake tensors (when enabled) so the
                    # tracing below does not perform any real computation.
                    fake_flat_tensor_args = pytree.tree_map(
                        lambda x: mode.from_tensor(x)
                        if mode and isinstance(x, Tensor) else x,
                        flat_tensor_args)
                    with torch.set_grad_enabled(grad_state):
                        out = flat_fn(*fake_flat_tensor_args)
                    out = pytree.tree_map(
                        lambda x: x.detach().contiguous()
                        if isinstance(x, Tensor) else x, out)

                    if isinstance(out, (list, tuple)):
                        num_outs = len(out)
                    else:
                        num_outs = 1

                    joint_inputs = (fake_flat_tensor_args, out)
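                    # joint_forward_backward takes (primals, tangents); the outputs
                    # computed above stand in as shape/dtype templates for the tangents.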
                    aot_decompositions = {
                        **aot_autograd_decompositions,
                        **decompositions
                    }
                    with torch.set_grad_enabled(grad_state):
                        fx_g = make_fx(joint_forward_backward,
                                       aot_decompositions)(*joint_inputs)

                        if config.use_functionalize:
                            # Functionalize the forward backward graph. First create a
                            # fake fn to make functionalize happy
                            def fake_fn(primals, tangents):
                                return fx_g(primals, tangents)

                            fx_g = make_fx(
                                functionalize(fake_fn))(*joint_inputs)
                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                # print(fw_module.code, bw_module.code)

                compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
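                # fw_outs holds the user-visible outputs first, followed by the
                # activations the partitioner chose to save for the backward pass.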

                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
                compiled_bw = bw_compiler(bw_module, bw_args)
            else:
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])
Example #2: currently skipped test checking that make_fx(grad(f)) keeps aten.tanh_backward without a decomposition table and decomposes it when one is supplied.
    def test_make_fx_no_decompose(self, device):
        # FIXME
        return self.skipTest("error: maximum recursion reached")

        def f(x):
            return torch.tanh(x).sum()

        fx_f = make_fx(grad(f))(torch.randn(5))
        ops = set([i.target for i in fx_f.graph.nodes])

        self.assertEqual(torch.ops.aten.tanh_backward in ops, True)

        fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
        ops = set([i.target for i in fx_f.graph.nodes])
        self.assertEqual(torch.ops.aten.tanh_backward in ops, False)
Example #3: tracing a function whose inputs mix a device tensor and a CPU scalar tensor.
    def test_scalar_device(self, device):
        def f(a, b):
            return a + b

        inps = [torch.randn(3, device=device), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))
Example #4: a simpler variant of the AOT Autograd forward without fake tensors or functionalization.
        def forward(ctx, *flat_tensor_args):
            nonlocal compiled_fw, compiled_bw, num_outs
            if compiled_fw is None:
                with torch.set_grad_enabled(grad_state):
                    out = flat_fn(*flat_tensor_args)
                out = pytree.tree_map(
                    lambda x: x.detach().contiguous()
                    if isinstance(x, Tensor) else x, out)

                if isinstance(out, (list, tuple)):
                    num_outs = len(out)
                else:
                    num_outs = 1

                joint_inputs = (flat_tensor_args, out)
                aot_decompositions = {
                    **aot_autograd_decompositions,
                    **decompositions
                }
                with torch.set_grad_enabled(grad_state):
                    fx_g = make_fx(joint_forward_backward,
                                   aot_decompositions)(*joint_inputs)
                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
                # print(fw_module.code, bw_module.code)

                compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
                compiled_bw = bw_compiler(bw_module, bw_args)
            else:
                fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])
Example #5: tracing a nested jacrev(jacrev(f)) transform with make_fx.
    def test_make_fx_jacrev(self, device):
        def f(x):
            return x.sin().sum()
        inp = torch.randn(3)
        f = jacrev(jacrev(f))
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))
Example #6: tracing a vmap-transformed function with make_fx.
    def test_make_fx_vmap(self, device):
        def f(x):
            return torch.sin(x)
        inp = torch.randn(5, 3)
        f = vmap(f)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(5, 3)
        self.assertEqual(fx_f(new_inp), f(new_inp))
Example #7: tracing a grad-transformed function with make_fx.
    def test_make_fx_grad(self, device):
        def f(x):
            return torch.sin(x).sum()
        inp = torch.randn(3)
        f = grad(f)
        fx_f = make_fx(f)(inp)

        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))
Example #8: tracing the function returned by vjp and re-running the trace on a new cotangent.
    def test_make_fx_vjp(self, device):
        def f(x):
            return torch.sin(x).sum()

        primals = torch.randn(3)
        _, vjp_fn = vjp(f, primals)
        cotangent = torch.randn(())
        fx_f = make_fx(vjp_fn)(cotangent, True, True)
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
Example #9: aot_dispatch_base, the inference-only AOT Autograd path: trace, compile, and wrap the forward.
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
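    # Inference-only path: trace flat_fn once with the configured decompositions,
    # compile the resulting graph, and return a thin wrapper around it.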
    fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
    with track_graph_compiling("inference"):
        compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

    @wraps(compiled_fw)
    def new_fn(args):
        fw_outs = call_func_with_args(compiled_fw, args)
        return fw_outs

    return new_fn
Example #10: profiling a traced graph before and after the fx_graph_cse pass.
def profile_function(name, f, inp):
    fx_g = make_fx(f)(inp)

    new_g = fx_graph_cse(fx_g.graph)
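    # new_g is a plain fx.Graph with duplicate computations merged; rebuild a
    # GraphModule around it below so it can be executed and profiled.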
    new_g = fx.GraphModule(fx_g, new_g)
    # do not benchmark against the scripted version because script already does some CSE
    # script_f = torch.jit.script(fx_g)
    # script_g = torch.jit.script(new_g)
    # avg_cuda_time_f = profile_it(script_f, inp)
    # avg_cuda_time_g = profile_it(script_g, inp)
    avg_cuda_time_f = profile_it(fx_g, inp)
    avg_cuda_time_g = profile_it(new_g, inp)
    num_node_decrease = len(fx_g.graph.nodes) - len(new_g.graph.nodes)

    print(f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {len(fx_g.graph.nodes)}")
Example #11: minifier test whose checker looks for aten.add.Tensor in the graph.
    def test_tup_use(self):
        def f(a, b):
            tup = torch.std_mean(a)
            return (tup[0] + b * tup[1], )

        inps = [torch.randn(3), torch.randn(3)]

        def has_add(fx_g, inps):
            return (torch.ops.aten.add.Tensor
                    in set([i.target for i in fx_g.graph.nodes]))

        failing_f = make_fx(f)(*inps)
        min_f, inps = minifier(failing_f, inps, has_add)

        self.assertEqual(len(min_f.graph.nodes), 4)
        self.assertEqual(len(inps), 2)
Example #12: minifier test whose checker looks for aten.mul.Tensor in the graph.
    def test_has_mul_minifier(self):
        def failing_f(x, y):
            y = y / 3
            x = x + 3
            x = x * y
            return x + y

        inps = [torch.randn(3), torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def has_mul(fx_g, inps):
            return (torch.ops.aten.mul.Tensor
                    in set([i.target for i in fx_g.graph.nodes]))

        min_f, inps = minifier(failing_f, inps, has_mul)
        self.assertEqual(len(min_f.graph.nodes), 4)
        self.assertEqual(len(inps), 2)
Example #13: the same minifier test, with the checker written against the aten.mul overload packet and plain asserts.
    def test_has_mul_minifier(self):
        def failing_f(x, y):
            y = y / 3
            x = x + 3
            x = x * y
            return x + y

        inps = [torch.randn(3), torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def pass_checker(fx_g, inps):
            return (torch.ops.aten.mul
                    in set([i.target for i in fx_g.graph.nodes]))

        min_f, inps = minifier(failing_f, inps, pass_checker)
        assert len(min_f.graph.nodes) == 4
        assert len(inps) == 2
Example #14: minifier test whose checker fires when the graph returns one of its own inputs.
    def test_input_returned(self):
        def f(a, b, c):
            a = a.sin()
            c = c.cos()
            d = a * c
            return (a, b, c, d)

        inps = [torch.randn(3) for _ in range(3)]

        def inputs_returned(fx_g, inps):
            inps = set(get_placeholders(fx_g.graph))
            outs = set(get_outputs(fx_g.graph))
            return len(inps & outs) > 0

        failing_f = make_fx(f)(*inps)
        min_f, inps = minifier(failing_f, inps, inputs_returned)
        self.assertEqual(len(min_f.graph.nodes), 2)
        self.assertEqual(len(inps), 1)
Example #15: lowering a training function to TorchScript via make_fx with selected backward decompositions, then round-tripping through save/load for mlir-translate.
def generate_graph(model, inputs, training_fn):
    # TODO: Pass the decomposition_table according to the model/needs.
    fx_g = make_fx(training_fn,
                   decomposition_table=get_decompositions([
                       torch.ops.aten.embedding_dense_backward,
                       torch.ops.aten.native_layer_norm_backward,
                       torch.ops.aten.slice_backward,
                       torch.ops.aten.select_backward
                   ]))(dict(model.named_parameters()),
                       dict(model.named_buffers()), inputs)
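    # Swap in the default CodeGen so the generated forward takes plain positional
    # arguments (dropping the pytree flatten/unflatten wrapper), then recompile.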
    fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
    fx_g.recompile()
    fx_g = change_fx_graph_return_to_tuple(fx_g)
    ts_g = torch.jit.script(fx_g)
    # TODO: Without this save/load round trip, scripting creates some unnecessary
    # functions that cause problems during mlir-translate.
    temp = tempfile.NamedTemporaryFile(suffix='_heavy_dep', prefix='temp_ts_')
    ts_g.save(temp.name)
    new_ts = torch.jit.load(temp.name)
    return new_ts
Example #16: minifier test that reduces the graph down to the NaN-producing subgraph.
    def test_has_add_mul(self):
        def failing_f(x):
            x = x * 3
            x = x + 5
            x = x.cos()
            zero = x - x
            result = zero / zero
            result = result + 3
            return (result * 2, )

        inps = [torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def pass_checker(fx_g, inps):
            # The reproducer "fails" when the graph produces NaNs even though
            # none of its inputs are NaN.
            for i in inps:
                if torch.isnan(i).any():
                    return False
            return torch.isnan(fx_g(*inps)[0]).any()

        min_f, inps = minifier(failing_f, inps, pass_checker)
        assert len(min_f.graph.nodes) == 3
        assert len(inps) == 1
Example #17: OpInfo-driven test comparing each traced op against eager execution on freshly randomized inputs.
    def test_make_fx_exhaustive(self, device, dtype, op):
        def f(args, kwargs):
            return op.op(*args, **kwargs)

        sample_inputs_itr = op.sample_inputs(device,
                                             dtype,
                                             requires_grad=False)
        new_f = None
        for sample_input in sample_inputs_itr:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs

            new_f = make_fx(f)(args, kwargs)
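            # Re-randomize the float tensor inputs in place so the comparison below
            # checks that the traced graph recomputes from the new values rather than
            # replaying values captured at trace time.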
            for arg in args:
                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                    arg.uniform_(0, 1)
            try:
                old_out = f(args, kwargs)
            except Exception:
                continue
            new_out = new_f(args, kwargs)
            self.assertEqual(new_out, old_out)
Example #18: the same NaN minifier test, written with assertEqual.
    def test_has_add_mul(self):
        def failing_f(x):
            x = x * 3
            x = x + 5
            x = x.cos()
            zero = x - x
            result = zero / zero
            result = result + 3
            return (result * 2, )

        inps = [torch.randn(3)]
        failing_f = make_fx(failing_f)(*inps)

        def has_nans(fx_g, inps):
            # The reproducer "fails" when the graph produces NaNs even though
            # none of its inputs are NaN.
            for i in inps:
                if torch.isnan(i).any():
                    return False
            return torch.isnan(fx_g(*inps)[0]).any()

        min_f, inps = minifier(failing_f, inps, has_nans)
        self.assertEqual(len(min_f.graph.nodes), 3)
        self.assertEqual(len(inps), 1)
Example #19: CSE test helper that checks the node-count delta, that a second pass is a no-op, and that results match.
def check(f, t, delta, check_val=True, graph_input=False):
    if graph_input:
        fx_g = f
    else:
        fx_g = make_fx(f)(t)
    new_graph = fx_graph_cse(fx_g.graph)
    new_g = fx.GraphModule(fx_g, new_graph)

    # the number of nodes should decrease or stay the same
    old_num_nodes = len(fx_g.graph.nodes)
    new_num_nodes = len(new_graph.nodes)
    if delta == -1:
        assert old_num_nodes >= new_num_nodes, (
            f"number of nodes increased {old_num_nodes}, {new_num_nodes}")
    else:
        assert old_num_nodes == new_num_nodes + delta, (
            f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
        )

    # a second pass should not reduce more nodes
    pass_2_graph = fx_graph_cse(new_graph)
    pass_2_num_nodes = len(pass_2_graph.nodes)
    assert pass_2_num_nodes == new_num_nodes, (
        f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
    )

    # check correctness
    if check_val:
        true_result = fx_g(t)
        our_result = new_g(t)
        if true_result is None:  # both return None
            assert our_result is None, f"true result is None, CSE result is {our_result}"
        else:  # results returned are the same
            assert torch.all(true_result == our_result), (
                f"results are different {true_result}, {our_result}"
            )  # check results are the same
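
A minimal usage sketch for the helper above (the function and the expected delta are illustrative, not taken from the original test suite; it assumes the same imports as the example: torch, fx, make_fx, fx_graph_cse):

def duplicated_cos(x):
    a = x.cos()
    b = x.cos()  # duplicate of `a`; CSE should fold the two cos nodes into one
    return a + b

# 5 nodes before CSE (placeholder, cos, cos, add, output) -> 4 after, so delta == 1
check(duplicated_cos, torch.randn(3), delta=1)
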
Example #20: aot_dispatch_autograd, the full AOT Autograd training path with an eagerly compiled forward and a lazily compiled backward.
def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    joint_forward_backward = create_joint_forward_backward(flat_fn)
    out = flat_fn(*flat_args)
    out = pytree.tree_map(
        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
        out,
    )

    if isinstance(out, (list, tuple)):
        _num_outs = len(out)
    else:
        _num_outs = 1

    joint_inputs = (flat_args, out)
    fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)

    if config.use_functionalize:
        # Functionalize the forward backward graph. First create a
        # fake fn to make functionalize happy
        def fake_fn(primals, tangents):
            return fx_g(primals, tangents)

        fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)

    if config.debug_joint:
        print(fx_g.code)

    with torch.no_grad():
        with track_graph_compiling("joint"):
            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)

        if config.debug_graphs:
            print(fw_module.code, bw_module.code)

        with track_graph_compiling("forward"):
            compiled_fw_func = aot_config.fw_compiler(fw_module, flat_args)

        if config.debug_partitioner:
            fw_outs = call_func_with_args(compiled_fw_func, flat_args)
            activation_sizes = 0
            for out in fw_outs[_num_outs:]:
                if isinstance(out, torch.Tensor):
                    activation_sizes += out.storage().nbytes()
            print(f"Real Activations Stored(GB): {activation_sizes/1e9}")

    class CompiledFunction(torch.autograd.Function):
        compiled_fw = compiled_fw_func
        compiled_bw = None
        num_outs = _num_outs

        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *flat_tensor_args):
            fw_outs = call_func_with_args(
                CompiledFunction.compiled_fw, flat_tensor_args
            )
            num_outs = CompiledFunction.num_outs
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])

        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            contiguous_args = [t.contiguous() for t in flat_args]
            all_args = list(ctx.saved_tensors) + list(contiguous_args)
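            # The backward graph's inputs are the saved activations followed by the
            # incoming gradients, matching the input order of the partitioned bw_module.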
            if CompiledFunction.compiled_bw is None:
                with track_graph_compiling("backward", True):
                    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                        bw_module, all_args
                    )
            ctx.maybe_clear_saved_tensors()
            out = call_func_with_args(
                CompiledFunction.compiled_bw, all_args, steal_args=True
            )

            return tuple(out)

    return CompiledFunction.apply
Example #21: benchmarking grad(f) in eager mode against its make_fx trace and an nnc_jit-compiled version.
from functorch import grad, nnc_jit, make_fx, make_nnc
import torch
import time


def f(x):
    return torch.sin(x).sum()


inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt)
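# grad_fx is a plain fx.GraphModule recording the gradient computation; grad_nnc
# compiles the same transformed function through NNC for the timing comparison below.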
loopnest = make_nnc(grad_pt)(inp)
print(loopnest)


def bench(name, f, iters=10000, warmup=3):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(f"{name}: ", time.time() - begin)


bench("Pytorch: ", lambda: grad_pt(inp))
bench("FX: ", lambda: grad_fx(inp))
bench("NNC: ", lambda: grad_nnc(inp))