# The following examples exercise CP decomposition (CPD) routines built on an
# einsum-based automatic-differentiation package. They assume module-level
# globals (dim, size, rank, BACKEND_TYPES) and helpers imported elsewhere in
# the module (ad, T, cpd_graph, init_rand_cp, optimize, simplify, dedup,
# find_topo_sort, generate_sequential_optimal_tree, conjugate_gradient,
# cp_nls_optimizer, expect_jtjvp_val), the standard-library time and math
# modules, pytest-benchmark's `benchmark` fixture, and the tensorly
# (tl, parafac) and scikit-tensor (dtensor, sk_cp_als) packages for the
# comparison benchmarks; exact import paths are project-specific.


def test_cpd_jtjvp_optimized(benchmark):
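    """Benchmark the optimized and deduplicated JtJvp (Gauss-Newton product
    J^T J v of the CPD residual) graphs on every backend."""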
    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)
        executor_JtJvps = ad.Executor(JtJvps)

        jtjvp_val = benchmark(executor_JtJvps.run,
                              feed_dict={
                                  A: A_val,
                                  B: B_val,
                                  C: C_val,
                                  input_tensor: input_tensor_val,
                                  v_A: v_A_val,
                                  v_B: v_B_val,
                                  v_C: v_C_val
                              })


def test_cpd_jtjvp(benchmark):
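    """Benchmark the hand-written reference JtJvp computation
    (expect_jtjvp_val) on random factors, for comparison with the graph
    version above."""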
    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list
        expected_hvp_val = benchmark(expect_jtjvp_val, A_val, B_val, C_val,
                                     v_A_val, v_B_val, v_C_val)


def test_cpd_shared_exec(backendopt):
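    """Run one sweep of cpd_als_shared_exec and compare each factor against
    the hand-computed ALS update."""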
    dim = 3

    for datatype in backendopt:
        T.set_backend(datatype)

        input_val = init_rand_cp(dim, size, rank)
        A_list, input_tensor_val = input_val
        A_val, B_val, C_val = A_list

        outputs = cpd_als_shared_exec(dim, size, rank, 1, input_val)

        # Expected values: one ALS sweep computed by hand, where each factor
        # update reuses the factors already updated earlier in the sweep.
        A_val = T.einsum(
            "abc,bk,ck->ak", input_tensor_val, B_val, C_val) @ T.inv(
                (T.transpose(B_val) @ B_val) * (T.transpose(C_val) @ C_val))
        B_val = T.einsum(
            "abc,ak,ck->bk", input_tensor_val, A_val, C_val) @ T.inv(
                (T.transpose(A_val) @ A_val) * (T.transpose(C_val) @ C_val))
        C_val = T.einsum(
            "abc,ak,bk->ck", input_tensor_val, A_val, B_val) @ T.inv(
                (T.transpose(A_val) @ A_val) * (T.transpose(B_val) @ B_val))

        assert T.norm(outputs[0] - A_val) < 1e-8
        assert T.norm(outputs[1] - B_val) < 1e-8
        assert T.norm(outputs[2] - C_val) < 1e-8


def cpd_gradient_descent(size, rank, learning_rate):
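    """Fit a CPD by plain gradient descent on the loss, printing the loss at
    every iteration."""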
    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        for i in range(100):
            loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
                feed_dict={
                    input_tensor: input_tensor_val,
                    A: A_val,
                    B: B_val,
                    C: C_val
                })
            A_val -= learning_rate * grad_A_val
            B_val -= learning_rate * grad_B_val
            C_val -= learning_rate * grad_C_val
            print(f'At iteration {i} the loss is: {loss_val}')


def test_cpd_hessian_optimize_offdiag(backendopt):
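    """Optimize two off-diagonal blocks of the CPD Hessian and check that the
    optimized graphs are sums; the operation-count assertion is disabled
    because the count is currently non-deterministic."""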
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        hessian_offdiag = [hessian[0][1], hessian[1][0]]
        for node in hessian_offdiag:
            node = optimize(node)
            assert isinstance(node, ad.AddNode)
            num_operations = len(
                list(
                    filter(lambda x: isinstance(x, ad.OpNode),
                           find_topo_sort([node]))))
            # This is currently non-deterministic.
            # assert num_operations == 14

        executor = ad.Executor(hessian_offdiag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })


def test_cpd_als_sktensor(benchmark):
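    """Benchmark one ALS iteration of scikit-tensor's cp_als (aliased
    sk_cp_als) for comparison; sktensor operates on NumPy arrays, so the loop
    body performs no backend switch."""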
    for datatype in BACKEND_TYPES:

        _, input_tensor_val = init_rand_cp(dim, size, rank)
        benchmark(sk_cp_als,
                  dtensor(input_tensor_val),
                  rank=rank,
                  max_iter=1,
                  init='random')


def test_cpd_jtjvp_optimize(backendopt):
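    """Check the optimized and deduplicated JtJvp graphs against the
    hand-written reference values."""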
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)
        executor_JtJvps = ad.Executor(JtJvps)

        jtjvp_val = executor_JtJvps.run(
            feed_dict={
                A: A_val,
                B: B_val,
                C: C_val,
                input_tensor: input_tensor_val,
                v_A: v_A_val,
                v_B: v_B_val,
                v_C: v_C_val
            })

        expected_hvp_val = expect_jtjvp_val(A_val, B_val, C_val, v_A_val,
                                            v_B_val, v_C_val)

        assert T.norm(jtjvp_val[0] - expected_hvp_val[0]) < 1e-8
        assert T.norm(jtjvp_val[1] - expected_hvp_val[1]) < 1e-8
        assert T.norm(jtjvp_val[2] - expected_hvp_val[2]) < 1e-8


def cpd_newton(size, rank):
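    """Fit a CPD with Newton's method: each step solves the Newton system by
    conjugate gradient, using Hessian-vector products (ad.hvp) instead of a
    materialized Hessian."""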
    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])
        grads = ad.gradients(loss, [A, B, C])
        Hvps = ad.hvp(output_node=loss,
                      node_list=[A, B, C],
                      vector_list=[v_A, v_B, v_C])

        executor_grads = ad.Executor([loss] + grads)
        executor_Hvps = ad.Executor(Hvps)

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        for i in range(100):

            def hess_fn(v):
                return executor_Hvps.run(
                    feed_dict={
                        A: A_val,
                        B: B_val,
                        C: C_val,
                        input_tensor: input_tensor_val,
                        v_A: v[0],
                        v_B: v[1],
                        v_C: v[2]
                    })

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            delta = conjugate_gradient(
                hess_fn=hess_fn,
                grads=[grad_A_val, grad_B_val, grad_C_val],
                error_tol=1e-9,
                max_iters=250)

            A_val -= delta[0]
            B_val -= delta[1]
            C_val -= delta[2]
            print(f'At iteration {i} the loss is: {loss_val}')


def cpd_als_shared_exec(dim, size, rank, num_iter, input_val=None):
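    """ALS for CPD with one executor shared across all factor updates.

    Each update applies the inverse of the corresponding diagonal Hessian
    block (ad.tensorinv) to the gradient; generate_sequential_optimal_tree
    arranges the update graphs so that intermediates can be reused between
    consecutive updates within a sweep.
    """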

    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)

    full_hessian = ad.hessian(loss, A_list)
    hessians = [full_hessian[i][i] for i in range(len(full_hessian))]
    grads = ad.gradients(loss, A_list)

    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for (hes, grad) in zip(hessians, grads)
    ]

    new_A_list = [simplify(A - update) for (A, update) in zip(A_list, updates)]
    new_A_list = generate_sequential_optimal_tree(new_A_list, A_list)

    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    if input_val is None:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val

    for sweep in range(num_iter):
        t0 = time.time()
        # als iterations
        for i in range(len(A_list)):

            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})

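            # After the first update of a sweep the execution graph is kept
            # (reset_graph=False) and only the factor computed in the previous
            # step is evicted, so intermediates that do not depend on it are
            # reused.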
            if i == 0:
                A_val_list[0], = executor.run(feed_dict=feed_dict,
                                              out_nodes=[new_A_list[0]])
            else:
                A_val_list[i], = executor.run(feed_dict=feed_dict,
                                              reset_graph=False,
                                              evicted_inputs=[A_list[i - 1]],
                                              out_nodes=[new_A_list[i]])

        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)

        print(f'At iteration {sweep} the loss is: {loss_val}')
        t1 = time.time()
        print(f"[ {sweep} ] Sweep took {t1 - t0} seconds")

    return A_val_list


def test_cpd_hessian_optimize_diag(backendopt):
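    """Optimize the diagonal blocks of the CPD Hessian, check the operation
    count of each optimized graph (see the inline listing below), and verify
    the values numerically."""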
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        hessian_diag = [hessian[0][0], hessian[1][1], hessian[2][2]]
        for node in hessian_diag:
            node = optimize(node)
            assert isinstance(node, ad.AddNode)
            num_operations = len(
                list(
                    filter(lambda x: isinstance(x, ad.OpNode),
                           find_topo_sort([node]))))
            """
            Use this assertion to test the optimize function.
            5 operations:
            1. T.einsum('ca,cb->ab',A,A),
            2. T.einsum('ca,cb->ab',B,B),
            3. T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),
            4. T.einsum('bd,ac->abcd',T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),T.identity(10)),
            5. (T.einsum('bd,ac->abcd',T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),T.identity(10))+
            T.einsum('bd,ac->abcd',T.einsum('ab,ab->ab',T.einsum('ca,cb->ab',A,A),T.einsum('ca,cb->ab',B,B)),T.identity(10)))
            """
            assert num_operations == 5

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', B_val, B_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, B_val, B_val,
                         T.identity(size))
        ]
        assert T.norm(hes_diag_vals[0] - expected_hes_diag_val[0]) < 1e-8
        assert T.norm(hes_diag_vals[1] - expected_hes_diag_val[1]) < 1e-8
        assert T.norm(hes_diag_vals[2] - expected_hes_diag_val[2]) < 1e-8


def test_cpd_als_tensorly(benchmark):
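    """Benchmark one ALS iteration of tensorly's parafac for comparison."""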
    for datatype in BACKEND_TYPES:
        tl.set_backend(datatype)
        assert tl.get_backend() == datatype

        _, input_tensor_val = init_rand_cp(dim, size, rank)
        input_tensor = tl.tensor(input_tensor_val, dtype='float64')
        factors = benchmark(parafac,
                            input_tensor,
                            rank=rank,
                            init='random',
                            tol=0,
                            n_iter_max=1,
                            verbose=0)


def test_cpd_grad(backendopt):
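    """Check ad.gradients of the CPD loss against gradients assembled by hand
    from the residual tensor."""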
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        grad_A, grad_B, grad_C = ad.gradients(loss, [A, B, C])
        executor = ad.Executor([loss, grad_A, grad_B, grad_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        loss_val, grad_A_val, grad_B_val, grad_C_val = executor.run(
            feed_dict={
                input_tensor: input_tensor_val,
                A: A_val,
                B: B_val,
                C: C_val
            })

        expected_output_tensor = T.einsum("ia,ja,ka->ijk", A_val, B_val, C_val)
        expected_residual = expected_output_tensor - input_tensor_val
        expected_norm_error = T.norm(expected_residual)
        expected_loss = expected_norm_error * expected_norm_error

        expected_contract_residual_A = 2 * T.einsum("ijk,ia->ajk",
                                                    expected_residual, A_val)
        expected_contract_residual_B = 2 * T.einsum("ijk,ja->iak",
                                                    expected_residual, B_val)
        expected_contract_residual_C = 2 * T.einsum("ijk,ka->ija",
                                                    expected_residual, C_val)

        expected_grad_A = T.einsum("iak,ka->ia", expected_contract_residual_B,
                                   C_val)
        expected_grad_B = T.einsum("ajk,ka->ja", expected_contract_residual_A,
                                   C_val)
        expected_grad_C = T.einsum("ajk,ja->ka", expected_contract_residual_A,
                                   B_val)

        assert abs(loss_val - expected_loss) < 1e-8
        assert T.norm(grad_A_val - expected_grad_A) < 1e-8
        assert T.norm(grad_B_val - expected_grad_B) < 1e-8
        assert T.norm(grad_C_val - expected_grad_C) < 1e-8


def cpd_als(dim, size, rank, num_iter, input_val=None):
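    """Baseline ALS for CPD: every factor update applies the inverse of its
    diagonal Hessian block to the gradient, with one executor run per
    update."""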

    A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)

    full_hessian = ad.hessian(loss, A_list)
    hessians = [full_hessian[i][i] for i in range(len(full_hessian))]
    grads = ad.gradients(loss, A_list)

    updates = [
        ad.tensordot(ad.tensorinv(hes), grad, [[2, 3], [0, 1]])
        for (hes, grad) in zip(hessians, grads)
    ]

    new_A_list = [simplify(A - update) for (A, update) in zip(A_list, updates)]

    executor = ad.Executor(new_A_list)
    executor_loss = ad.Executor([simplify(loss)])

    if input_val is None:
        A_val_list, input_tensor_val = init_rand_cp(dim, size, rank)
    else:
        A_val_list, input_tensor_val = input_val

    for sweep in range(num_iter):
        # als iterations
        for i in range(len(A_list)):

            feed_dict = dict(zip(A_list, A_val_list))
            feed_dict.update({input_tensor: input_tensor_val})
            A_val_list[i], = executor.run(feed_dict=feed_dict,
                                          out_nodes=[new_A_list[i]])

        feed_dict = dict(zip(A_list, A_val_list))
        feed_dict.update({input_tensor: input_tensor_val})
        loss_val, = executor_loss.run(feed_dict=feed_dict)
        print(f'At iteration {sweep} the loss is: {loss_val}')

    return A_val_list


def test_cpd_hessian_simplify(backendopt):
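    """Simplify the diagonal blocks of the CPD Hessian and check them against
    hand-written einsum expressions."""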
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        hessian = ad.hessian(loss, [A, B, C])
        # TODO (issue #101): test the off-diagonal elements
        hessian_diag = [hessian[0][0], hessian[1][1], hessian[2][2]]
        for node in hessian_diag:
            node = simplify(node)
            input_node = node.inputs[0]
            assert len(input_node.inputs) == 5

        executor = ad.Executor(hessian_diag)
        hes_diag_vals = executor.run(feed_dict={
            A: A_val,
            B: B_val,
            C: C_val,
            input_tensor: input_tensor_val,
        })

        expected_hes_diag_val = [
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', B_val, B_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, C_val, C_val,
                         T.identity(size)),
            2 * T.einsum('eb,ed,fb,fd,ac->abcd', A_val, A_val, B_val, B_val,
                         T.identity(size))
        ]
        assert T.norm(hes_diag_vals[0] - expected_hes_diag_val[0]) < 1e-8
        assert T.norm(hes_diag_vals[1] - expected_hes_diag_val[1]) < 1e-8
        assert T.norm(hes_diag_vals[2] - expected_hes_diag_val[2]) < 1e-8


def test_cpd_als_shared_exec(benchmark):
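    """Benchmark one full sweep of cpd_als_shared_exec."""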
    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        input_val = init_rand_cp(dim, size, rank)
        outputs = benchmark(cpd_als_shared_exec, dim, size, rank, 1,
                            input_val)


def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """
    mode: ad / optimized / jax
    """
    assert mode in {'ad', 'jax', 'optimized'}

    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list

        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])
        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            StS.forward(JtJvps,
                        file=open("examples/jax_jtjvp.py", "w"),
                        function_name='jtjvp',
                        backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        time_all, fitness = 0., 0.

        for i in range(10):

            t0 = time.time()

            def hess_fn(v):
                if mode == 'optimized':
                    from examples.cpd_jtjvp_optimized import jtjvp
                    return jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
                elif mode == 'ad':
                    return executor_JtJvps.run(
                        feed_dict={
                            A: A_val,
                            B: B_val,
                            C: C_val,
                            input_tensor: input_tensor_val,
                            v_A: v[0],
                            v_B: v[1],
                            v_C: v[2]
                        })
                elif mode == 'jax':
                    from examples.jax_jtjvp import jtjvp
                    return jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")
            time_all += t1 - t0

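            # Sweep the damping parameter back and forth between 1e-7 and 1:
            # double it while increasing, halve it while decreasing.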
            if regularization < 1e-07:
                regu_increase = True
            elif regularization > 1:
                regu_increase = False
            if regu_increase:
                regularization = regularization * 2
            else:
                regularization = regularization / 2

        return total_cg_time, fitness
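

# Minimal usage sketch for cpd_nls (hypothetical values: size=20 and rank=5
# are illustrative, and mode='ad' avoids the pre-generated modules that the
# 'optimized' and 'jax' modes import):
#
#     total_cg_time, fitness = cpd_nls(size=20, rank=5, mode='ad')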