def optimize(node):
    """Optimize a graph with a single output node.

    Args:
        node: The output node.
    Returns:
        node: The root node of the optimized graph.
    """
    node = distribute_tree(node)
    linearize(node)

    all_nodes = find_topo_sort([node])
    ret_node = PseudoNode(node)
    with OutputInjectedMode(all_nodes):
        trees = find_sub_einsumtree(ret_node)
        for tree in trees:
            out_node_p, in_nodes = tree
            new_z = fuse_einsums(out_node_p.node, in_nodes)
            prune_identity_nodes(new_z)
            new_z = generate_optimal_tree(new_z)
            replace_node(out_node_p, new_z)

    node = declone(ret_node.node)
    all_nodes = find_topo_sort([node])
    for n in all_nodes:
        if isinstance(n, ad.EinsumNode):
            rewrite_einsum_expr(n)

    # Re-set inputs so each node's cached expression reflects the rewritten
    # einsum strings (also avoids shadowing the output `node` in the loop).
    for n in find_topo_sort([node]):
        if n.inputs != []:
            n.set_inputs(n.inputs)

    dedup(node)
    return node
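
# A minimal usage sketch of optimize(); the variable names and shapes here are
# illustrative, not taken from the tests below. optimize() distributes,
# linearizes, fuses, and reorders the einsum graph under a single output.
def example_optimize_usage():
    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])
    # A nested einsum that optimize() can fuse into one contraction tree.
    out = ad.einsum("ab,bc->ac", ad.einsum("ab,bc->ac", a, b), b)
    return optimize(out)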
def test_cpd_jtjvp_optimized(benchmark):
    dim = 3  # assumption: matches test_cpd_jtjvp_optimize; size and rank come from module scope
    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)
        executor_JtJvps = ad.Executor(JtJvps)

        jtjvp_val = benchmark(executor_JtJvps.run,
                              feed_dict={
                                  A: A_val,
                                  B: B_val,
                                  C: C_val,
                                  input_tensor: input_tensor_val,
                                  v_A: v_A_val,
                                  v_B: v_B_val,
                                  v_C: v_C_val
                              })
def generate_sequential_optimal_tree(einsum_nodes, input_nodes):
    """
    Regenerate einsum expressions based on the dimension tree.
    Parameters
    ----------
    einsum_nodes : list
        List of einsum nodes to be calculated based on the dimension tree.
    input_nodes : list
        List of input nodes; the contractions in einsum_nodes are performed
        in order from the end of this list to its start.

    Returns
    -------
        List of einsum nodes whose results equal those of einsum_nodes, but
        whose contractions follow the dimension-tree calculation sequence.

    Examples
    --------
    >>> einsum_node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    >>> einsum_node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    >>> einsum_node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    >>> dt = generate_sequential_optimal_tree([einsum_node_A, einsum_node_B, einsum_node_C], [A, B, C])
    >>> dt
    [ad.einsum('bm,abm->am', B, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,abm->bm', A, ad.einsum('cm,abcm->abm', C, ad.einsum('abcd,dm->abcm', X, D))),
    ad.einsum('am,bm,abcm->cm', A, B, ad.einsum('abcd,dm->abcm', X, D)),
    ]
    (The exact einsum strings may differ.)
    """

    if len(einsum_nodes) == 1 and len(input_nodes) == 1:
        return einsum_nodes

    new_nodes = []
    for (i, node) in enumerate(einsum_nodes):
        contract_order = input_nodes[i + 1:]
        contract_order.reverse()
        contract_order = contract_order + input_nodes[:i]
        # keep only the contract-order nodes that are actually inputs of node
        contract_order = list(
            filter(lambda n: n in node.inputs, contract_order))

        new_nodes.append(
            generate_optimal_tree_w_constraint(node, contract_order))

    # After generate_optimal_tree_w_constraint, some einsum strings are no
    # longer in canonical form; rewrite them so dedup can match equivalent nodes.
    all_nodes = find_topo_sort(new_nodes)
    with OutputInjectedMode(all_nodes):
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)

    dedup(*new_nodes)
    remove_transposes(find_topo_sort(new_nodes))
    return new_nodes
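
# Hedged illustration of why the dimension tree pays off: after the rewrite,
# the sibling outputs share partial contractions (e.g. the 'abcd,dm->abcm'
# term in the docstring example above), so dedup can compute them once. The
# nodes built below mirror the docstring example.
def example_dimension_tree_sharing(X, A, B, C, D):
    node_A = ad.einsum("abcd,bm,cm,dm->am", X, B, C, D)
    node_B = ad.einsum("abcd,am,cm,dm->bm", X, A, C, D)
    node_C = ad.einsum("abcd,am,bm,dm->cm", X, A, B, D)
    return generate_sequential_optimal_tree([node_A, node_B, node_C],
                                            [A, B, C])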
def test_dedup():
    """
    Deduplicate identical subtrees in the graph.
    """

    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])

    c = a + b
    d = a + b
    z = c + d

    dedup(z)
    # After dedup, both inputs of z should be the same node object.
    assert z.inputs[0] == z.inputs[1]
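
# dedup also accepts several output nodes at once (cf. dedup(*JtJvps) in the
# tests above), sharing common subexpressions across outputs. A hedged sketch:
def example_dedup_across_outputs():
    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])
    y = (a + b) + a
    z = (a + b) + b
    dedup(y, z)
    # Both outputs should now reference a single shared (a + b) node.
    assert y.inputs[0] == z.inputs[0]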
def test_cpd_jtjvp_optimize(backendopt):
    dim = 3
    for datatype in backendopt:
        T.set_backend(datatype)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list
        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list
        v_A_list, _ = init_rand_cp(dim, size, rank)
        v_A_val, v_B_val, v_C_val = v_A_list

        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        for node in JtJvps:
            assert isinstance(node, ad.AddNode)
        executor_JtJvps = ad.Executor(JtJvps)

        jtjvp_val = executor_JtJvps.run(
            feed_dict={
                A: A_val,
                B: B_val,
                C: C_val,
                input_tensor: input_tensor_val,
                v_A: v_A_val,
                v_B: v_B_val,
                v_C: v_C_val
            })

        expected_hvp_val = expect_jtjvp_val(A_val, B_val, C_val, v_A_val,
                                            v_B_val, v_C_val)

        assert T.norm(jtjvp_val[0] - expected_hvp_val[0]) < 1e-8
        assert T.norm(jtjvp_val[1] - expected_hvp_val[1]) < 1e-8
        assert T.norm(jtjvp_val[2] - expected_hvp_val[2]) < 1e-8
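
# For reference: with the CP residual r(A, B, C) = einsum('ia,ja,ka->ijk',
# A, B, C) - X, the Gauss-Newton product JtJv has a closed form in the factor
# Gram matrices and Hadamard products (*). The NumPy-style sketch below is an
# assumption about what expect_jtjvp_val computes, not code from this repo.
def reference_jtjvp(A, B, C, vA, vB, vC):
    AtA, BtB, CtC = A.T @ A, B.T @ B, C.T @ C
    out_A = vA @ (BtB * CtC) + A @ ((vB.T @ B) * CtC + BtB * (vC.T @ C))
    out_B = vB @ (AtA * CtC) + B @ ((vA.T @ A) * CtC + AtA * (vC.T @ C))
    out_C = vC @ (AtA * BtB) + C @ ((vA.T @ A) * BtB + AtA * (vB.T @ B))
    return [out_A, out_B, out_C]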
def cpd_nls(size, rank, regularization=1e-7, mode='ad'):
    """
    mode: ad / optimized / jax
    """
    assert mode in {'ad', 'jax', 'optimized'}

    dim = 3

    for datatype in BACKEND_TYPES:
        T.set_backend(datatype)
        T.seed(1)

        A_list, input_tensor, loss, residual = cpd_graph(dim, size, rank)
        A, B, C = A_list

        v_A = ad.Variable(name="v_A", shape=[size, rank])
        v_B = ad.Variable(name="v_B", shape=[size, rank])
        v_C = ad.Variable(name="v_C", shape=[size, rank])
        grads = ad.gradients(loss, [A, B, C])
        JtJvps = ad.jtjvps(output_node=residual,
                           node_list=[A, B, C],
                           vector_list=[v_A, v_B, v_C])

        A_list, input_tensor_val = init_rand_cp(dim, size, rank)
        A_val, B_val, C_val = A_list

        if mode == 'jax':
            from source import SourceToSource
            StS = SourceToSource()
            StS.forward(JtJvps,
                        file=open("examples/jax_jtjvp.py", "w"),
                        function_name='jtjvp',
                        backend='jax')

        executor_grads = ad.Executor([loss] + grads)
        JtJvps = [optimize(JtJvp) for JtJvp in JtJvps]
        dedup(*JtJvps)
        executor_JtJvps = ad.Executor(JtJvps)
        optimizer = cp_nls_optimizer(input_tensor_val, [A_val, B_val, C_val])

        regu_increase = False
        normT = T.norm(input_tensor_val)
        time_all, fitness = 0., 0.

        for i in range(10):

            t0 = time.time()

            def hess_fn(v):
                if mode == 'optimized':
                    from examples.cpd_jtjvp_optimized import jtjvp
                    return jtjvp([v[0], B_val, C_val, v[1], A_val, v[2]])
                elif mode == 'ad':
                    return executor_JtJvps.run(
                        feed_dict={
                            A: A_val,
                            B: B_val,
                            C: C_val,
                            input_tensor: input_tensor_val,
                            v_A: v[0],
                            v_B: v[1],
                            v_C: v[2]
                        })
                elif mode == 'jax':
                    from examples.jax_jtjvp import jtjvp
                    return jtjvp([B_val, C_val, v[0], A_val, v[1], v[2]])

            loss_val, grad_A_val, grad_B_val, grad_C_val = executor_grads.run(
                feed_dict={
                    A: A_val,
                    B: B_val,
                    C: C_val,
                    input_tensor: input_tensor_val
                })

            res = math.sqrt(loss_val)
            fitness = 1 - res / normT
            print(f"[ {i} ] Residual is {res} fitness is: {fitness}")
            print(f"Regularization is: {regularization}")

            [A_val, B_val, C_val], total_cg_time = optimizer.step(
                hess_fn=hess_fn,
                grads=[grad_A_val / 2, grad_B_val / 2, grad_C_val / 2],
                regularization=regularization)

            t1 = time.time()
            print(f"[ {i} ] Sweep took {t1 - t0} seconds")
            time_all += t1 - t0

            # Oscillate the regularization between ~1e-7 and 1: double while
            # increasing, halve while decreasing (an LM-style damping schedule).
            if regularization < 1e-07:
                regu_increase = True
            elif regularization > 1:
                regu_increase = False
            if regu_increase:
                regularization = regularization * 2
            else:
                regularization = regularization / 2

        # Note: returns after the first backend in BACKEND_TYPES; time_all
        # accumulates the total sweep time but is not returned here.
        return total_cg_time, fitness
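
# Hedged usage sketch for cpd_nls (the size/rank values are illustrative):
# runs ten Gauss-Newton sweeps and reports the last sweep's CG time and fitness.
def example_cpd_nls_run():
    cg_time, fit = cpd_nls(size=64, rank=10, mode='ad')
    print(f"last-sweep CG time: {cg_time}, final fitness: {fit}")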