예제 #1
0
def test_cpd_hessian_optimize_offdiag(backendopt):
    """Optimize two off-diagonal CPD Hessian blocks and execute them.

    For each backend in ``backendopt``: build the 3-way CPD graph, take the
    Hessian of ``loss`` with respect to the factors ``A``, ``B``, ``C``, and
    optimize the (A, B) and (B, A) blocks. Each optimized block must be an
    ``ad.AddNode``, and the optimized graphs are then run on random inputs
    as a smoke test (their values are not compared here).

    ``size`` and ``rank`` are not defined in this function; presumably they
    are module-level test constants.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        # optimize() returns the rewritten root node; keep it so that the
        # structural assertion and the executor both operate on the
        # optimized graphs rather than the original ones.
        hessian_offdiag = [optimize(hessian[0][1]), optimize(hessian[1][0])]
        for node in hessian_offdiag:
            assert isinstance(node, ad.AddNode)
            # The number of OpNodes in the optimized graph is currently
            # non-deterministic, so it is not asserted (an earlier
            # expectation was 14).

        executor = ad.Executor(hessian_offdiag)
        # Smoke test only: checks that the optimized graphs are executable.
        executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })
예제 #2
0
def test_large_matmul_chain(backendopt):
    """optimize() must correctly decompose a 60-operand matrix-chain einsum.

    Builds ``x0 @ x1 @ ... @ x59`` as a single einsum whose index letters
    start at ``chr(192)`` (outside ASCII), optimizes it, executes it, and
    compares the result with the chain product computed directly on the
    backend tensors.
    """
    n = 60
    size = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        # Index letters l_0 .. l_n.  Operand k uses subscripts "l_k l_{k+1}",
        # so neighbouring operands share exactly one contracted index.
        letters = [chr(192 + k) for k in range(n + 1)]

        x_list = [
            ad.Variable(name=f"x{i}", shape=[size, size]) for i in range(n)
        ]
        for k, x_node in enumerate(x_list):
            x_node.subscripts = f"{letters[k]}{letters[k + 1]}"

        input_subs = ','.join(x_node.subscripts for x_node in x_list)
        einsum_subscripts = input_subs + '->' + letters[0] + letters[n]

        # Decompose the large einsum; optimize() also rewrites the einsum
        # expressions of the generated tree so no non-ASCII letters remain.
        out = optimize(ad.einsum(einsum_subscripts, *x_list))
        executor = ad.Executor([out])

        x_val_list = [T.random([size, size]) for _ in range(n)]
        (out_val,) = executor.run(feed_dict=dict(zip(x_list, x_val_list)))

        expected = x_val_list[0]
        for x_val in x_val_list[1:]:
            expected = expected @ x_val
        assert float_eq(out_val, expected, tol=1e-2)
예제 #3
0
def test_cpd_jtjvp_optimized(benchmark):
    """Benchmark execution of the optimized CPD JtJ-vector products.

    For each backend in ``BACKEND_TYPES``: build the 3-way CPD graph, form
    the JtJvp nodes for the factors ``A``, ``B``, ``C`` along the directions
    ``v_A``, ``v_B``, ``v_C``, optimize and deduplicate them, check that each
    is an ``ad.AddNode``, and time ``Executor.run`` with the ``benchmark``
    fixture.

    ``size`` and ``rank`` are not defined in this function; presumably they
    are module-level test constants.
    """
    # The factor list returned by cpd_graph is unpacked into exactly three
    # nodes below, so the tensor order must be 3. Define it locally (as the
    # sibling tests do) instead of relying on an outer ``dim``.
    dim = 3
    for datatype in BACKEND_TYPES:
        # NOTE(review): pytest-benchmark's ``benchmark`` fixture may only be
        # invoked once per test; if BACKEND_TYPES has more than one entry the
        # second call is expected to fail. Confirm, or parametrize over the
        # backends instead.
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)
        executor_JtJvps = ad.Executor(JtJvps)

        # Only the timing matters here; the result values are not checked.
        benchmark(executor_JtJvps.run,
                  feed_dict={
                      A: A_val,
                      B: B_val,
                      C: C_val,
                      input_tensor: input_tensor_val,
                      v_A: v_A_val,
                      v_B: v_B_val,
                      v_C: v_C_val
                  })
예제 #4
0
def test_simplify_optimize_w_tail_einsum():
    """optimize() and simplify() both reduce ``A @ (I @ I)`` to ``A``.

    The graph multiplies a 2x2 variable by the product of two 2x2 identity
    nodes. Each transformation, applied in turn to that graph, is expected
    to return a node comparing equal to ``A``.
    """
    A = ad.Variable(name="A", shape=[2, 2])
    identity_product = ad.einsum("ab,bc->ac", ad.identity(2), ad.identity(2))
    out = ad.einsum("ab,bc->ac", A, identity_product)

    # optimize() runs before simplify(), matching the original order.
    for transformed in (optimize(out), simplify(out)):
        assert transformed == A
예제 #5
0
def test_simplify_optimize_w_tail_einsum(backendopt):
    """optimize() and simplify() both reduce ``A @ (I @ I)`` to ``A``.

    Repeated for every backend in ``backendopt``: a 2x2 variable is
    multiplied by the product of two 2x2 identity nodes. Each transformation,
    applied in turn to that graph, must return a node comparing equal to
    ``A``.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        identity_product = ad.einsum("ab,bc->ac", ad.identity(2),
                                     ad.identity(2))
        out = ad.einsum("ab,bc->ac", A, identity_product)

        # optimize() runs before simplify(), matching the original order.
        for transformed in (optimize(out), simplify(out)):
            assert transformed == A
예제 #6
0
def test_cpd_hessian_optimize_diag(backendopt):
    """Optimize the diagonal CPD Hessian blocks and check them numerically.

    For each backend in ``backendopt``: build the 3-way CPD graph, take the
    Hessian of ``loss`` with respect to the factors ``A``, ``B``, ``C``, and
    optimize the three diagonal blocks. Each optimized block must be an
    ``ad.AddNode`` with exactly 5 OpNodes. The optimized graphs are then
    executed and compared with the closed form
    ``2 * einsum('eb,ed,fb,fd,ac->abcd', P, P, Q, Q, I)``, where P and Q are
    the two other factors and I is ``T.identity(size)``.

    ``size`` and ``rank`` are not defined in this function; presumably they
    are module-level test constants.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        # optimize() returns the rewritten root node; keep those nodes so
        # that the values checked below come from the optimized graphs.
        hessian_diag = [
            optimize(hessian[0][0]),
            optimize(hessian[1][1]),
            optimize(hessian[2][2]),
        ]
        for node in hessian_diag:
            assert isinstance(node, ad.AddNode)
            # Expected structure of an optimized block (5 OpNodes). For
            # example, for the block of C, with X = A and Y = B:
            #   1. einsum('ca,cb->ab', X, X)            Gram matrix of X
            #   2. einsum('ca,cb->ab', Y, Y)            Gram matrix of Y
            #   3. einsum('ab,ab->ab', (1), (2))        elementwise product
            #   4. einsum('bd,ac->abcd', (3), identity(size))
            #   5. (4) + (4)                            the root AddNode
            num_operations = sum(1 for x in find_topo_sort([node])
                                 if isinstance(x, ad.OpNode))
            assert num_operations == 5

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        identity = T.identity(size)
        # The block for each factor depends on the other two factors.
        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', p, p, q, q, identity)
            for p, q in ((B_val, C_val), (A_val, C_val), (A_val, B_val))
        ]
        assert len(hes_diag_vals) == len(expected_hes_diag_val)
        for got, want in zip(hes_diag_vals, expected_hes_diag_val):
            assert T.norm(got - want) < 1e-8
예제 #7
0
def test_cpd_jtjvp_optimize(backendopt):
    """Optimized CPD JtJ-vector products must match the reference values.

    For every backend in ``backendopt``: build the 3-way CPD graph, form the
    JtJvp nodes for the factors ``A``, ``B``, ``C`` along the directions
    ``v_A``, ``v_B``, ``v_C``, optimize and deduplicate them, check that each
    is an ``ad.AddNode``, run them on random data, and compare the results
    with ``expect_jtjvp_val``.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        factor_nodes, input_tensor, loss, residual = cpd_graph(
            dim, size, rank)
        A, B, C = factor_nodes
        v_A, v_B, v_C = [
            ad.Variable(name=var_name, shape=[size, rank])
            for var_name in ("v_A", "v_B", "v_C")
        ]

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals
        direction_vals, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = direction_vals

        raw_nodes = ad.jtjvps(output_node=residual,
                              node_list=[A, B, C],
                              vector_list=[v_A, v_B, v_C])
        jtjvp_nodes = [optimize(raw) for raw in raw_nodes]
        dedup(*jtjvp_nodes)
        assert all(isinstance(node, ad.AddNode) for node in jtjvp_nodes)
        runner = ad.Executor(jtjvp_nodes)

        feed = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
            v_A: v_A_val,
            v_B: v_B_val,
            v_C: v_C_val,
        }
        jtjvp_val = runner.run(feed_dict=feed)

        reference = expect_jtjvp_val(A_val, B_val, C_val, v_A_val, v_B_val,
                                     v_C_val)
        for k in (0, 1, 2):
            assert T.norm(jtjvp_val[k] - reference[k]) < 1e-8
예제 #8
0
파일: cpd.py 프로젝트: LinjianMa/AutoHOOT
def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """Run 10 sweeps of the CP nonlinear-least-squares solver on random data.

    Builds the 3-way CPD graph, its loss gradients, and its JtJ-vector
    products. It then iterates ``cp_nls_optimizer.step`` for 10 sweeps,
    printing residual, fitness, regularization, and per-sweep wall time.

    Args:
        size: forwarded to ``cpd_graph`` / ``init_rand_cp``; also the first
            dimension of the ``v_*`` direction variables (shape
            ``[size, rank]``).
        rank: CP rank; the second dimension of the ``v_*`` variables.
        regularization: initial regularization passed to
            ``optimizer.step``. After each sweep it is halved, or doubled
            once it has dropped below 1e-07, until it exceeds 1 again.
        mode: how the JtJ-vector product is evaluated.
            ``'ad'`` runs the optimized graph through ``ad.Executor``.
            ``'optimized'`` calls ``jtjvp`` from
            ``examples.cpd_jtjvp_optimized``.
            ``'jax'`` first writes ``examples/jax_jtjvp.py`` with
            ``SourceToSource`` and then calls its ``jtjvp``.

    Returns:
        ``(total_cg_time, fitness)`` as reported by the last sweep.
    """
    assert mode in {'ad', 'jax', 'optimized'}

    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list

        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])
        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            # The context manager guarantees the generated module is flushed
            # and closed before hess_fn imports it below.
            with open("examples/jax_jtjvp.py", "w") as jax_file:
                StS.forward(JtJvps,
                            file=jax_file,
                            function_name='jtjvp',
                            backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        fitness = 0.

        def hess_fn(v):
            """Return the JtJ-vector products for directions v = [v_A, v_B, v_C].

            A_val / B_val / C_val are read from the enclosing scope at call
            time, so this always uses the factors of the current sweep.
            """
            if mode == 'optimized':
                from examples.cpd_jtjvp_optimized import jtjvp
                return jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
            elif mode == 'ad':
                return executor_JtJvps.run(
                    feed_dict={
                        A: A_val,
                        B: B_val,
                        C: C_val,
                        input_tensor: input_tensor_val,
                        v_A: v[0],
                        v_B: v[1],
                        v_C: v[2]
                    })
            elif mode == 'jax':
                from examples.jax_jtjvp import jtjvp
                return jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

        for i in range(10):

            t0 = time.time()

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")

            regularization, regu_increase = _next_regularization(
                regularization, regu_increase)

        # NOTE(review): this return sits inside the backend loop, so only the
        # first entry of BACKEND_TYPES is ever run. Confirm this is intended.
        return total_cg_time, fitness


def _next_regularization(regularization, regu_increase):
    """Return the updated ``(regularization, regu_increase)`` for the next sweep.

    The direction flips to "increase" once regularization drops below 1e-07
    and back to "decrease" once it exceeds 1. The value is then doubled or
    halved accordingly.
    """
    if regularization < 1e-07:
        regu_increase = True
    elif regularization > 1:
        regu_increase = False
    if regu_increase:
        regularization = regularization * 2
    else:
        regularization = regularization / 2
    return regularization, regu_increase