Example #1
0
def test_cpd_shared_exec(backendopt):
    """Compare one shared-execution ALS sweep with a hand-written ALS sweep.

    For every backend in ``backendopt`` a random CP problem is generated,
    ``cpd_als_shared_exec`` performs one sweep, and each returned factor is
    checked against the factor obtained by solving the ALS normal equations
    in order: A first, then B with the new A, then C with the new A and B.
    """
    dim = 3

    def gram(mat):
        # mat^T mat
        return T.transpose(mat) @ mat

    def als_update(tensor, subscripts, first, second):
        # Contract the tensor with the two fixed factors, then right-multiply
        # by the inverse of the Hadamard product of their Gram matrices.
        rhs = T.einsum(subscripts, tensor, first, second)
        return rhs @ T.inv(gram(first) * gram(second))

    for datatype in backendopt:
        T.set_backend(datatype)

        input_val = init_rand_cp(dim, size, rank)
        factors, tensor_val = input_val
        a_mat, b_mat, c_mat = factors

        outputs = cpd_als_shared_exec(dim, size, rank, 1, input_val)

        # Reference sweep: every update consumes the factors refreshed just
        # before it, so these three statements must stay in this order.
        a_mat = als_update(tensor_val, "abc,bk,ck->ak", b_mat, c_mat)
        b_mat = als_update(tensor_val, "abc,ak,ck->bk", a_mat, c_mat)
        c_mat = als_update(tensor_val, "abc,ak,bk->ck", a_mat, b_mat)

        for idx, reference in enumerate((a_mat, b_mat, c_mat)):
            assert T.norm(outputs[idx] - reference) < 1e-8
Example #2
0
def test_cpd_hessian_optimize_diag(backendopt):
    """Check ``optimize`` on the diagonal blocks of the CPD loss Hessian.

    Each optimized block must be an ``ad.AddNode`` graph with exactly five
    operations. Evaluating the optimized blocks must reproduce the analytic
    value ``2 * (X^T X * Y^T Y)_{bd} * I_{ac}``, where X and Y are the two
    factors other than the one being differentiated.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        # Keep the optimized graphs. Previously the result of ``optimize``
        # was discarded, so the executor below only validated the
        # unoptimized Hessian numerically.
        hessian_diag = []
        for node in (hessian[0][0], hessian[1][1], hessian[2][2]):
            node = optimize(node)
            assert isinstance(node, ad.AddNode)
            num_operations = len([
                n for n in find_topo_sort([node]) if isinstance(n, ad.OpNode)
            ])
            # Expected 5 operations, e.g. for the C block (other factors
            # A and B):
            #   1. T.einsum('ca,cb->ab', A, A)
            #   2. T.einsum('ca,cb->ab', B, B)
            #   3. T.einsum('ab,ab->ab', (1), (2))
            #   4. T.einsum('bd,ac->abcd', (3), T.identity(size))
            #   5. (4) + (4)
            assert num_operations == 5
            hessian_diag.append(node)

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', B_val, B_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, B_val, B_val,
                         T.identity(size))
        ]
        for idx, expected in enumerate(expected_hes_diag_val):
            assert T.norm(hes_diag_vals[idx] - expected) < 1e-8
Example #3
0
def test_cpd_grad(backendopt):
    """Check the CPD loss and its gradients w.r.t. the three factor matrices.

    The loss is compared with ``||einsum('ia,ja,ka->ijk', A, B, C) - X||^2``.
    Each gradient is compared with
    ``2 * sum_{other two modes} residual * (other two factors)``.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
            feed_dict={
                input_tensor: input_tensor_val,
                A: A_val,
                B: B_val,
                C: C_val
            })

        expected_output_tensor = T.einsum("ia,ja,ka->ijk", A_val, B_val, C_val)
        expected_residual = expected_output_tensor - input_tensor_val
        expected_norm_error = T.norm(expected_residual)
        expected_loss = expected_norm_error * expected_norm_error

        # Each gradient is two pairwise contractions of the residual. Only
        # the A- and B-contracted residuals are needed for all three.
        expected_contract_residual_A = 2 * T.einsum("ijk,ia->ajk",
                                                    expected_residual, A_val)
        expected_contract_residual_B = 2 * T.einsum("ijk,ja->iak",
                                                    expected_residual, B_val)

        expected_grad_A = T.einsum("iak,ka->ia", expected_contract_residual_B,
                                   C_val)
        expected_grad_B = T.einsum("ajk,ka->ja", expected_contract_residual_A,
                                   C_val)
        expected_grad_C = T.einsum("ajk,ja->ka", expected_contract_residual_A,
                                   B_val)

        assert abs(loss_val - expected_loss) < 1e-8
        for got, expected in ((grad_A_val, expected_grad_A),
                              (grad_B_val, expected_grad_B),
                              (grad_C_val, expected_grad_C)):
            assert T.norm(got - expected) < 1e-8
Example #4
0
def test_cpd_jtjvp_optimize(backendopt):
    """Check optimized, deduplicated JtJ-vector products of the CPD residual.

    The products are taken against the three factors. After ``optimize`` and
    ``dedup`` each product graph must be an ``ad.AddNode``, and its value must
    match ``expect_jtjvp_val`` on random factors and random vectors.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        factor_nodes, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = factor_nodes
        v_A, v_B, v_C = (ad.Variable(name=label, shape=[size, rank])
                         for label in ("v_A", "v_B", "v_C"))

        factor_vals, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = factor_vals
        # A second random CP draw supplies the three vectors; its tensor is
        # not needed.
        vector_vals, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = vector_vals

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])
        JtJvps = list(map(optimize, JtJvps))
        dedup(*JtJvps)
        assert all(isinstance(node, ad.AddNode) for node in JtJvps)

        feed_dict = {
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
            v_A: v_A_val,
            v_B: v_B_val,
            v_C: v_C_val,
        }
        jtjvp_val = ad.Executor(JtJvps).run(feed_dict=feed_dict)

        expected_hvp_val = expect_jtjvp_val(A_val, B_val, C_val, v_A_val,
                                            v_B_val, v_C_val)

        for idx in range(3):
            assert T.norm(jtjvp_val[idx] - expected_hvp_val[idx]) < 1e-8
Example #5
0
def test_gauge_transform_left(backendopt):
    """Check ``gauge_transform_mps(..., right=False)`` on a random 4-site MPS.

    The gauge transform must leave the fully contracted tensor unchanged.
    Every site except the last must be orthonormal under the contraction
    used below.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        original = rand_mps(num=4, rank=4, size=2)
        gauged = gauge_transform_mps(original, right=False)

        # The transformation must not change the contracted MPS.
        contraction = 'ab,acd,cef,eg->bdfg'
        full_original = T.einsum(contraction, *original)
        full_gauged = T.einsum(contraction, *gauged)
        assert T.norm(full_original - full_gauged) < 1e-8

        def assert_identity(gram):
            assert T.norm(gram - T.identity(gram.shape[0])) < 1e-8

        # Orthogonality of every site except the rightmost one.
        assert_identity(T.einsum("ab,cb->ac", gauged[0], gauged[0]))
        num_sites = len(original)
        for site_idx in range(1, num_sites - 1):
            site = gauged[site_idx]
            assert_identity(T.einsum("abc,adc->bd", site, site))
Example #6
0
def test_cpd_hessian_simplify(backendopt):
    """Check ``simplify`` on the diagonal blocks of the CPD loss Hessian.

    After simplification, the first input of each block must have five
    inputs. Evaluating the simplified blocks must reproduce the analytic
    value ``2 * (X^T X * Y^T Y)_{bd} * I_{ac}``, where X and Y are the two
    factors other than the one being differentiated.
    """
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        # TODO (issue #101): test the off-diagonal elements
        # Keep the simplified graphs. Previously the result of ``simplify``
        # was discarded, so the executor below never evaluated them.
        hessian_diag = []
        for node in (hessian[0][0], hessian[1][1], hessian[2][2]):
            node = simplify(node)
            input_node = node.inputs[0]
            assert len(input_node.inputs) == 5
            hessian_diag.append(node)

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', B_val, B_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, B_val, B_val,
                         T.identity(size))
        ]
        for idx, expected in enumerate(expected_hes_diag_val):
            assert T.norm(hes_diag_vals[idx] - expected) < 1e-8
Example #7
0
def test_HinverseG(backendopt):
    """Solve ``A x = b`` with ``conjugate_gradient`` and check the residual.

    ``A = R^T R + I`` is built from a seeded random ``R``, so it is
    symmetric positive definite.
    """
    for datatype in backendopt:
        T.set_backend(datatype)

        N = 10
        T.seed(1224)

        A = T.random([N, N])
        A = T.transpose(A) @ A
        A = A + T.identity(N)
        b = T.random([N])

        def hess_fn(x):
            # Receives a list of vectors (here just one) and returns a list
            # holding A applied to it.
            return [T.einsum("ab,b->a", A, x[0])]

        error_tol = 1e-9
        x, = conjugate_gradient(hess_fn, [b], error_tol)
        # The norm is invariant to element signs, so no T.abs is needed.
        assert T.norm(T.einsum("ab,b->a", A, x) - b) <= 1e-4
Example #8
0
def test_transpose():
    """Trace an axis permutation with ``make_graph`` (jax backend) and compare
    the executor outputs with calling the function eagerly."""

    def testfunc(a):
        # The executor always returns a list of outputs, so the reference
        # function returns a 1-tuple to line up with it.
        return np.transpose(a, (1, 0, 2)),

    T.set_backend('jax')
    inputs = [T.random((5, 10, 7))]

    out_nodes, variables = make_graph(testfunc, *inputs)
    executor = ad.Executor(out_nodes)
    outvals = executor.run(
        feed_dict={var: val for var, val in zip(variables, inputs)})
    expect_outvals = testfunc(*inputs)

    for got, expected in zip(outvals, expect_outvals):
        assert T.norm(got - expected) < 1e-6
Example #9
0
def test_simpledot():
    """Trace a two-output function (``w @ x + b + 1`` and ``x``) with
    ``make_graph`` on the jax backend and compare each executor output with
    eager evaluation."""

    def testfunc(w, b, x):
        return np.dot(w, x) + b + np.ones(5), x

    T.set_backend('jax')
    w, b, x = T.random((5, 10)), T.random((5, )), T.random((10, ))
    inputs = [w, b, x]

    out_nodes, variables = make_graph(testfunc, *inputs)
    executor = ad.Executor(out_nodes)
    outvals = executor.run(
        feed_dict={var: val for var, val in zip(variables, inputs)})
    expect_outvals = testfunc(*inputs)

    for got, expected in zip(outvals, expect_outvals):
        assert T.norm(got - expected) < 1e-6
Example #10
0
def test_inner_product_einsum(backendopt):
    """Check the value and gradient of ``einsum('i,i->', x, x)``.

    For ``x = [3, 4]`` the expected value is ``||x||^2 = 25`` and the
    expected gradient is ``2 * x``.
    """
    for datatype in backendopt:
        T.set_backend(datatype)
        # The declared shape must match the value fed below (length 2).
        # It was previously declared as [3].
        x = ad.Variable(name="x", shape=[2])
        x_inner = ad.einsum('i,i->', x, x)

        grad_x, = ad.gradients(x_inner, [x])

        executor = ad.Executor([x_inner, grad_x])
        x_val = T.tensor([3., 4.])  # shape (2,)

        y_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_yval = T.norm(x_val)**2
        expected_grad_x_val = 2 * x_val

        assert isinstance(x_inner, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
Example #11
0
def test_mul():
    """Trace an elementwise ``w * x + b`` with ``make_graph`` on the jax
    backend and compare the executor output with eager evaluation."""

    def testfunc(w, b, x):
        # The executor always returns a list of outputs, so the reference
        # function returns a 1-tuple to line up with it.
        return w * x + b,

    T.set_backend('jax')
    shape = (5, 10)
    w, b, x = T.random(shape), T.random(shape), T.random(shape)
    inputs = [w, b, x]

    out_nodes, variables = make_graph(testfunc, *inputs)
    executor = ad.Executor(out_nodes)
    outvals = executor.run(
        feed_dict={var: val for var, val in zip(variables, inputs)})
    expect_outvals = testfunc(*inputs)

    for got, expected in zip(outvals, expect_outvals):
        assert T.norm(got - expected) < 1e-6
Example #12
0
def test_norm(backendopt):
    """Check ``ad.norm(x) ** 2`` and its gradient on a fixed 3x2 input.

    The expected value is ``||x||^2`` and the expected gradient is ``2 * x``.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        x = ad.Variable(name="x", shape=[3, 2])
        y = ad.norm(x)
        z = y**2

        grad_x, = ad.gradients(z, [x])

        executor = ad.Executor([z, grad_x])
        x_val = T.tensor([[1., 2.], [3., 4.], [5., 6.]])  # 3x2

        z_val, grad_x_val = executor.run(feed_dict={x: x_val})

        expected_zval = T.norm(x_val)**2
        expected_grad_x_val = 2 * x_val

        assert isinstance(z, ad.Node)
        # ||x|| = sqrt(91) is not exactly representable. The graph may
        # therefore round differently from the reference, especially for a
        # gradient taken through the norm. Compare with a tolerance, as
        # test_cpd_grad does, instead of exact equality.
        assert abs(z_val - expected_zval) < 1e-8
        assert T.norm(grad_x_val - expected_grad_x_val) < 1e-8
Example #13
0
def test_cpd():
    """Trace the CPD squared-residual loss with ``make_graph`` (jax backend)
    and compare the executor result with eager evaluation."""

    def testfunc(A, B, C, X):
        # Named ``approx`` rather than ``T`` so the backend alias ``T`` used
        # by the enclosing test is not shadowed.
        approx = np.einsum("ia,ja,ka->ijk", A, B, C)
        res = approx - X
        return np.einsum("ijk,ijk->", res, res),

    size = 3
    rank = 2
    T.set_backend('jax')
    X = T.random((size, size, size))
    A = T.random((size, rank))
    B = T.random((size, rank))
    C = T.random((size, rank))
    inputs = [A, B, C, X]

    out_nodes, variables = make_graph(testfunc, *inputs)
    executor = ad.Executor(out_nodes)
    feed_dict = dict(zip(variables, inputs))

    outvals = executor.run(feed_dict=feed_dict)
    expect_outvals = testfunc(*inputs)

    for outval, expect_outval in zip(outvals, expect_outvals):
        assert T.norm(outval - expect_outval) < 1e-6
Example #14
0
File: cpd.py  Project: LinjianMa/AutoHOOT
def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """Run 10 nonlinear-least-squares sweeps of an order-3 CP decomposition.

    Args:
        size: extent of every mode of the random input tensor (passed to
            ``cpd_graph`` / ``init_rand_cp``).
        rank: CP rank (second dimension of each factor variable).
        regularization: initial damping passed to ``optimizer.step``. After
            every sweep it is halved while decreasing, or doubled while
            increasing; the direction flips below 1e-7 and above 1.
        mode: how the JtJ-vector product used by CG is evaluated:
            'ad' runs the optimized AD graph through ``ad.Executor``;
            'optimized' calls ``examples.cpd_jtjvp_optimized.jtjvp``;
            'jax' writes ``examples/jax_jtjvp.py`` from the AD graph via
            ``SourceToSource`` and calls the generated ``jtjvp``.

    Returns:
        ``(total_cg_time, fitness)``. ``total_cg_time`` is the value returned
        by ``optimizer.step`` on the last sweep. ``fitness`` is
        ``1 - sqrt(loss) / ||input||``, measured just before that last step.
    """
    assert mode in {'ad', 'jax', 'optimized'}

    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list

        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])
        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            # Close (and so flush) the generated module before hess_fn
            # imports it. The handle used to be left open, so the import
            # could see a partially written file.
            with open("examples/jax_jtjvp.py", "w") as jax_file:
                StS.forward(JtJvps,
                            file=jax_file,
                            function_name='jtjvp',
                            backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        fitness = 0.

        for i in range(10):

            t0 = time.time()

            def hess_fn(v):
                # v holds the three direction matrices (for A, B, C). The
                # argument order for the generated jtjvp functions mirrors
                # the order they were generated with.
                if mode == 'optimized':
                    from examples.cpd_jtjvp_optimized import jtjvp
                    return jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
                elif mode == 'ad':
                    return executor_JtJvps.run(
                        feed_dict={
                            A: A_val,
                            B: B_val,
                            C: C_val,
                            input_tensor: input_tensor_val,
                            v_A: v[0],
                            v_B: v[1],
                            v_C: v[2]
                        })
                elif mode == 'jax':
                    from examples.jax_jtjvp import jtjvp
                    return jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            # The gradients are halved before being handed to the optimizer.
            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")

            # Oscillate the damping between roughly 1e-7 and 1.
            if regularization < 1e-07:
                regu_increase = True
            elif regularization > 1:
                regu_increase = False
            if regu_increase:
                regularization = regularization * 2
            else:
                regularization = regularization / 2

        # NOTE(review): this return sits inside the backend loop, so only the
        # first entry of BACKEND_TYPES is ever benchmarked. Confirm intended.
        return total_cg_time, fitness