def _distribute(binary_op_node, output):
    """ Distribute the operations. E.g (A + B) * C = A * C + B * C 

    Currently only consider the case where the binary_op is plus node.

    Args:
        binary_op_node: This is the plus node
        output: This is the (A + B) * C node.
    Return:
        The new output node that already distribute the computation.
    """
    assert isinstance(binary_op_node, ad.DistributiveNode)
    assert isinstance(output, ad.EinsumNode)
    assert binary_op_node in output.inputs

    # Find the children; the binary op should have exactly two inputs.
    A, B = binary_op_node.inputs
    AC_seq = [
        tmp if tmp.name != binary_op_node.name else A for tmp in output.inputs
    ]
    BC_seq = [
        tmp if tmp.name != binary_op_node.name else B for tmp in output.inputs
    ]
    AC = ad.einsum(output.einsum_subscripts, *AC_seq)
    BC = ad.einsum(output.einsum_subscripts, *BC_seq)
    return type(binary_op_node)(AC, BC)
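# A minimal usage sketch for _distribute (illustrative helper, assuming `+`
# on nodes builds the distributive plus node, as in the `0.5 * inner1 +
# 0.5 * inner2` style used by the tests below).
def _distribute_usage_sketch():
    a = ad.Variable(name="a", shape=[3, 3])
    b = ad.Variable(name="b", shape=[3, 3])
    c = ad.Variable(name="c", shape=[3, 3])
    plus = a + b
    output = ad.einsum('ik,kj->ij', plus, c)  # (a + b) @ c
    # Returns a node computing a @ c + b @ c.
    return _distribute(plus, output)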
def test_tree_distribution_w_add_output(dist_op, backendopt):
    """
        Test C * (A + B) + D * (E + F)
            = (C * A + C * B) + (D * E + D * F)
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])

        d = ad.Variable(name="d", shape=[3, 3])
        e = ad.Variable(name="e", shape=[3, 3])
        f = ad.Variable(name="f", shape=[3, 3])

        out1 = ad.einsum('ik,kj->ij', c, dist_op(a, b))
        out2 = ad.einsum('ik,kj->ij', d, dist_op(e, f))
        output = dist_op(out1, out2)
        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        for input_node in new_output.inputs:
            assert isinstance(input_node, dist_op)
        assert tree_eq(output, new_output, [a, b, c, d, e, f])
def test_einsum_multiuse(backendopt):
    """
        Test manual fusion.
        A    B   inputs 
        |\   |
        | \  |
        |  \ |
        |   C
        |  / 
        | /
        output

        Note that here we assume A is split into two vars by some other operation.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        a_copy = ad.Variable(name="a2", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a_copy, c)
        # New graph
        out_new = fuse_einsums(output, [a, a_copy, b])
        assert tree_eq(output, out_new, [a, a_copy, b])
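# A plain-numpy sketch (illustrative, not part of the test suite) of the
# fusion target above: the nested einsums are equivalent to the single
# einsum 'ik,im,mj->kj' over a_copy, a, b.
def _manual_fuse_numpy_check():
    import numpy as np
    a_val, a_copy_val = np.random.rand(3, 2), np.random.rand(3, 2)
    b_val = np.random.rand(2, 3)
    nested = np.einsum('ik,ij->kj', a_copy_val,
                       np.einsum('ik,kj->ij', a_val, b_val))
    fused = np.einsum('ik,im,mj->kj', a_copy_val, a_val, b_val)
    assert np.allclose(nested, fused)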
def test_einsum_fuse_w_identity(backendopt):
    """
        [Fuse einsum with multiple identities]
        We want to fuse
            A   identity  identity
            |    \       /
            |     \     /
            |      \   /
            |       es
            |     /
            |    /
            |   /
            |  /
            es
        Here es is einsum.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', a, es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_out = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, new_out, [a])
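# A plain-numpy sketch (illustrative) of the invariant the test above relies
# on: contracting with a chain of identity matrices is a no-op, so the fused
# tree must evaluate to `a` itself.
def _fuse_w_identity_numpy_check():
    import numpy as np
    a_val = np.random.rand(3, 3)
    fused = np.einsum('ai,ik,kj->aj', a_val, np.eye(3), np.eye(3))
    assert np.allclose(fused, a_val)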
def test_einsum_fuse_graph(backendopt):
    """
        [Fuse einsum used twice]
        This case is rather subtle.
        We want to auto fuse
            A   B   C 
            |    \ /  
            |     es  
            |    /|   
            |  /  |   
            es    |   
              \   |   
                es
        Here es is einsum.
    """

    for datatype in backendopt:
        T.set_backend(datatype)
        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        BC = ad.einsum('ik, kj->ij', b, c)  # 3x3

        ABC = ad.einsum('ik, kj->ij', a, BC)  # 3x3

        out = ad.einsum('jk, ki->ji', ABC, BC)  # 3x3

        linearize(out)
        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_z = fuse_einsums(out.node, ins)

        assert tree_eq(out.node, new_z, [a, b, c])
def test_einsum_multiuse_auto_copy(backendopt):
    """
        Test auto-linearization and auto fusion.
        A    B   inputs 
        |\   |
        | \  |
        |  \ |
        |   C
        |  / 
        | /
        output

        Next: we would need to autoprune.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a1", shape=[3, 2])
        b = ad.Variable(name="b", shape=[2, 3])

        c = ad.einsum('ik,kj->ij', a, b)
        output = ad.einsum('ik,ij->kj', a, c)

        linearize(output)
        all_nodes = find_topo_sort([output])
        cloned_nodes = [
            tmp for tmp in all_nodes if isinstance(tmp, ad.CloneNode)
        ]

        out_new = fuse_einsums(output, [*cloned_nodes, b])
        # Test that every input is now fused.
        assert all([not isinstance(x, ad.EinsumNode) for x in out_new.inputs])

        assert tree_eq(output, out_new, [*cloned_nodes, b])
def test_get_transpose_indices_dup():
    a = ad.Variable(name='a', shape=[2, 2])
    h = ad.Variable(name='h', shape=[2, 2, 2])
    out1 = ad.einsum("ad,bc,ecd->abe", a, a, h)
    out2 = ad.einsum("ac,bd,ecd->eab", a, a, h)
    trans = get_transpose_indices(out1, out2)
    assert trans == [2, 0, 1] or trans == [2, 1, 0]
def test_executor_dependent(backendopt):

    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[3, 3])
        B = ad.Variable(name="B", shape=[3, 3])
        AA = ad.einsum('ab,ab->', A, A)
        BB = ad.einsum('ab,ab->', B, B)
        AB = ad.einsum('ab,ab->', A, B)

        out_A = AA + AB
        out_B = AB + AA

        executor = ad.Executor({out_A, out_B})

        data = gen_dict([A, B])
        A_val, = executor.run(feed_dict=data,
                              reset_graph=False,
                              out_nodes=[out_A])
        data2 = gen_dict([A])
        data2.update({B: data[B]})
        B_val, = executor.run(feed_dict=data2, out_nodes=[out_B])
        # Check that A's value is not reused in the computation of B_val.
        assert A_val != B_val
def test_collapse_symmetric_expr_complex():
    """
    out1:
    A1 - a - A2 - b - A3
    |         |        |
    c         d        e
    |         |        |
    H1 - f - H2 - g - H3
    |         |        |
    h         i        j
    out2:
    a         b        c
    |         |        |
    H1 - d - H2 - e - H3
    |         |        |
    f         g        h
    A1 - i - A2 - j - A3
    """
    H1 = ad.Variable(name="H1", shape=[2, 2, 2], symmetry=[[0, 2]])
    H2 = ad.Variable(name="H2", shape=[2, 2, 2, 2], symmetry=[[0, 2]])
    H3 = ad.Variable(name="H3", shape=[2, 2, 2], symmetry=[[0, 1]])

    A1 = ad.Variable(name="A1", shape=[2, 2])
    A2 = ad.Variable(name="A2", shape=[2, 2, 2])
    A3 = ad.Variable(name="A3", shape=[2, 2])

    out1 = ad.einsum("ca,dab,eb,cfh,dgif,ejg->hij", A1, A2, A3, H1, H2, H3)
    out2 = ad.einsum("fi,gij,hj,adf,begd,che->abc", A1, A2, A3, H1, H2, H3)

    collapse_symmetric_expr(out1, out2)

    assert out1.name == out2.name
def test_prune_identity(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        a1 = ad.Variable(name="a1", shape=[3, 3])
        a2 = ad.Variable(name="a2", shape=[3, 3])
        i1 = ad.identity(3)
        i2 = ad.identity(3)
        i3 = ad.identity(3)

        out = ad.einsum("ab,cd,ac,be,ef->abdf", a1, a2, i1, i2, i3)
        prune_identity_nodes(out)
        """
        Explanation of the einsum above:
        The identity node i1 means that indices a and c refer to the same
        dimension, so we can remove i1 and rewrite the expr as
        ad.einsum("ab,ad,be,ef->abdf", a1, a2, i2, i3).
        We can also combine i2 and i3 because index e is redundant; therefore,
        we can rewrite the expr as
        ad.einsum("ab,ad,bf->abdf", a1, a2, i2).
        """
        out_expect = ad.einsum("ab,ad,bf->abdf", a1, a2, i2)
        assert len(out.inputs) == 3

        assert tree_eq(out, out_expect, [a1, a2])
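# A plain-numpy sketch (illustrative) of the pruning argument in the
# docstring above: the pruned subscripts give the same tensor as the
# original ones.
def _prune_identity_numpy_check():
    import numpy as np
    a1_val, a2_val = np.random.rand(3, 3), np.random.rand(3, 3)
    eye = np.eye(3)
    full = np.einsum("ab,cd,ac,be,ef->abdf", a1_val, a2_val, eye, eye, eye)
    pruned = np.einsum("ab,ad,bf->abdf", a1_val, a2_val, eye)
    assert np.allclose(full, pruned)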
def test_tree_distribution_two_layers(dist_op, backendopt):
    """
        [Distributive] ((A + B) * G) * C

        will produce
        
        AGC + BGC

        Note that (A+B) * G is contracted first.
    """

    for datatype in backendopt:
        if datatype == "taco":
            # '..,kk,..->..' is not supported in taco
            continue
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        interm = ad.einsum('ik, kk->ik', dist_op(a, b), g)
        output = ad.einsum('ik,kj->ij', interm, c)

        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)

        assert tree_eq(output, new_output, [a, b, c, g])
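# A plain-numpy sketch (illustrative, using + for dist_op) of the identity
# being tested: ((a + b) * g) * c equals a*g*c + b*g*c by linearity.
def _two_layer_distribution_numpy_check():
    import numpy as np
    a_val, b_val = np.random.rand(3, 2), np.random.rand(3, 2)
    g_val, c_val = np.random.rand(2, 2), np.random.rand(2, 3)
    lhs = np.einsum('ik,kj->ij',
                    np.einsum('ik,kk->ik', a_val + b_val, g_val), c_val)
    rhs = (np.einsum('ik,kk,kj->ij', a_val, g_val, c_val) +
           np.einsum('ik,kk,kj->ij', b_val, g_val, c_val))
    assert np.allclose(lhs, rhs)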
def test_hvp2(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        x = ad.Variable(name="x", shape=[3, 1])
        H = ad.Variable(name="H", shape=[3, 3])
        v = ad.Variable(name="v", shape=[3, 1])
        y = ad.sum(
            ad.einsum("ab,bc->ac", ad.einsum("ab,bc->ac", ad.transpose(x), H),
                      x))

        grad_x, = ad.gradients(y, [x])
        Hv, = ad.hvp(output_node=y, node_list=[x], vector_list=[v])

        executor = ad.Executor([y, grad_x, Hv])
        x_val = T.tensor([[1.], [2.], [3]])  # 3x1
        v_val = T.tensor([[1.], [2.], [3]])  # 3x1
        H_val = T.tensor([[2., 0., 0.], [0., 2., 0.], [0., 0., 2.]])  # 3x3
        y_val, grad_x_val, Hv_val = executor.run(feed_dict={
            x: x_val,
            H: H_val,
            v: v_val
        })
        Hx = T.dot(H_val, x_val)
        expected_yval = T.sum(T.dot(T.transpose(x_val), Hx))
        expected_grad_x_val = 2 * Hx
        expected_hv_val = T.tensor([[4.], [8.], [12.]])

        assert isinstance(y, ad.Node)
        assert T.array_equal(y_val, expected_yval)
        assert T.array_equal(grad_x_val, expected_grad_x_val)
        assert T.array_equal(Hv_val, expected_hv_val)
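# Worked derivation behind the expected values above, for reference:
#   y(x) = x^T H x, so grad_x y = (H + H^T) x = 2 H x for the symmetric H
#   used here; the Hessian of y is H + H^T = 2 H, so the HVP is
#   2 H v = 2 * (2 I) v = [4, 8, 12]^T for v = [1, 2, 3]^T.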
def test_prune_inv_different_num_inputs_no_pruning(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])

    inv_input = ad.einsum('ab,bc->ac', A, A)
    output = ad.einsum('ab,bc->ac', ad.tensorinv(inv_input, ind=1), A)
    new_output = prune_inv_node(output)

    assert new_output is output
def test_prune_inv_set_not_match(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])

    inv = ad.tensorinv(ad.einsum('ab,bc->ac', A, B), ind=1)
    output = ad.einsum('ab,bc->ac', inv, A)
    new_output = prune_inv_node(output)

    assert new_output is output
def test_get_common_ancestor_intermediate_leaves(backendopt):

    a = ad.Variable(name="a", shape=[2, 2])
    b = ad.Variable(name="b", shape=[2, 2])
    c = ad.einsum("ab,bc->ac", a, b)
    d = ad.einsum("ab,ab->ab", c, c)

    ancestor = get_common_ancestor(d, d.inputs, c)
    assert ancestor == d
def test_collapse_expr_w_identity():
    a = ad.Variable(name="a", shape=[2, 2])
    I = ad.identity(2)

    out1 = ad.einsum("ab,bc->ac", a, I)
    out2 = ad.einsum("ab,cb->ac", a, I)

    collapse_symmetric_expr(out1, out2)

    assert out1.name == out2.name
def test_cannot_collapse_expr():
    h = ad.Variable(name="h", shape=[2, 2, 2, 2])
    a = ad.Variable(name="a", shape=[2, 2])

    out1 = ad.einsum("ijkl,ik->jl", h, a)
    out2 = ad.einsum("ijkl,jl->ik", h, a)

    collapse_symmetric_expr(out1, out2)

    assert out1.name != out2.name
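# A plain-numpy sketch (illustrative) of why no collapse happens: h carries
# no symmetry, so the two contraction patterns generically yield different
# tensors (holds with probability one for random data).
def _cannot_collapse_numpy_check():
    import numpy as np
    h_val = np.random.rand(2, 2, 2, 2)
    a_val = np.random.rand(2, 2)
    out1 = np.einsum("ijkl,ik->jl", h_val, a_val)
    out2 = np.einsum("ijkl,jl->ik", h_val, a_val)
    assert not np.allclose(out1, out2)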
def test_collapse_symmetry_w_multiple_contraction_ind():

    H = ad.Variable(name="H", shape=[2, 2], symmetry=[[0, 1]])
    x1 = ad.Variable(name="x1", shape=[2])
    x2 = ad.Variable(name="x2", shape=[2])

    inner1 = ad.einsum("ab,a,b->", H, x1, x2)
    inner2 = ad.einsum("ab,b,a->", H, x1, x2)

    collapse_symmetric_expr(inner1, inner2)
    assert inner1.name == inner2.name
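# A plain-numpy sketch (illustrative) of why the two expressions collapse:
# for a symmetric H, swapping which vector binds to which index leaves the
# scalar unchanged.
def _symmetric_contraction_numpy_check():
    import numpy as np
    h_val = np.random.rand(2, 2)
    h_val = h_val + h_val.T  # enforce the [[0, 1]] symmetry
    x1_val, x2_val = np.random.rand(2), np.random.rand(2)
    assert np.allclose(np.einsum("ab,a,b->", h_val, x1_val, x2_val),
                       np.einsum("ab,b,a->", h_val, x1_val, x2_val))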
def test_simplify_optimize_w_tail_einsum():

    A = ad.Variable(name="A", shape=[2, 2])

    out = ad.einsum("ab,bc->ac", A,
                    ad.einsum("ab,bc->ac", ad.identity(2), ad.identity(2)))
    newout_optimize = optimize(out)
    newout_simplify = simplify(out)

    assert newout_optimize == A
    assert newout_simplify == A
def test_einsum_rewrite_duplicate_input():

    a = ad.Variable(name="a", shape=[3, 2])

    x = ad.einsum('ca,cb->ab', a, a)
    y = ad.einsum('cb,ca->ab', a, a)

    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
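# A plain-numpy sketch (illustrative) of the equality behind the test above:
# with the same operand passed twice, the two orderings denote the same
# contraction, so a canonical rewrite must map them to identical subscripts.
def _duplicate_input_numpy_check():
    import numpy as np
    a_val = np.random.rand(3, 2)
    assert np.allclose(np.einsum('ca,cb->ab', a_val, a_val),
                       np.einsum('cb,ca->ab', a_val, a_val))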
def test_prune_inv_nonmatmul_no_pruning(backendopt):
    A = ad.Variable(name="A", shape=[2, 2])
    B = ad.Variable(name="B", shape=[2, 2])

    inv_input = ad.einsum('ab,bc->ac', A, B)
    # Elementwise product of inv(inv_input) and inv_input (not a matrix
    # product), so it cannot be pruned.
    output = ad.einsum('ac,ab,bc->ac', ad.tensorinv(inv_input, ind=1), A, B)

    new_output = prune_inv_node(output)

    assert new_output is output
def test_remove_transposes_multiple_trans():
    a = ad.Variable(name="a", shape=[2, 2, 2, 2])

    intermediate1 = ad.einsum("abcd->dcba", a)
    intermediate2 = ad.einsum("abcd->abdc", a)

    ret1 = ad.einsum("dcba->badc", intermediate1)
    ret2 = ad.einsum("abdc->badc", intermediate2)

    remove_transposes(find_topo_sort([ret1, ret2]))

    assert ret1.name == ret2.name
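# A plain-numpy sketch (illustrative) of the equality behind the test above:
# both transpose chains compose to the single permutation (1, 0, 3, 2).
def _transpose_composition_numpy_check():
    import numpy as np
    a_val = np.random.rand(2, 2, 2, 2)
    ret1 = np.einsum("dcba->badc", np.einsum("abcd->dcba", a_val))
    ret2 = np.einsum("abdc->badc", np.einsum("abcd->abdc", a_val))
    assert np.allclose(ret1, ret2)
    assert np.allclose(ret1, np.transpose(a_val, (1, 0, 3, 2)))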
def test_einsum_equal_repeated_transpose():

    A = ad.Variable(name="A", shape=[3, 5])

    x = ad.einsum('or,ob->br', A, A)
    y = ad.einsum('eb,ed->bd', A, A)

    uf1 = rewrite_einsum_expr(x)
    uf2 = rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_einsum_fuse_only_identity(backendopt):

    for datatype in backendopt:
        T.set_backend(datatype)

        es_identity = ad.einsum('ik,kj->ij', ad.identity(3), ad.identity(3))
        out = ad.einsum('ai,ij->aj', ad.identity(3), es_identity)

        tree, = find_sub_einsumtree(PseudoNode(out))
        out, ins = tree
        new_out = fuse_einsums(out.node, ins)
        assert tree_eq(out.node, new_out, [])
def test_simplify_symmetric_einsum_expr():

    H = ad.Variable(name="H", shape=[2, 2], symmetry=[[0, 1]])
    x1 = ad.Variable(name="x1", shape=[2])
    x2 = ad.Variable(name="x2", shape=[2])

    inner1 = ad.einsum("ab,a,b->", H, x1, x2)
    inner2 = ad.einsum("ab,b,a->", H, x1, x2)
    out = 0.5 * inner1 + 0.5 * inner2
    newout_simplify = simplify(out)

    # ad.einsum("ab,a,b->", H, x1, x2) or ad.einsum("ab,b,a->", H, x1, x2)
    assert isinstance(newout_simplify, ad.EinsumNode)
def test_einsum_equal():

    a1 = ad.Variable(name="a1", shape=[3, 2])
    a2 = ad.Variable(name="a2", shape=[2, 3])

    x = ad.einsum('ik,kj->ij', a1, a2)
    y = ad.einsum('ml,sm->sl', a2, a1)

    rewrite_einsum_expr(x)
    rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_get_common_ancestor_w_inv(backendopt):

    A = ad.Variable(name="A", shape=[3, 3])
    X = ad.Variable(name="X", shape=[3, 3, 3])
    inv = ad.tensorinv(ad.einsum("ab,ac->bc", A, A), ind=1)
    einsum_node = ad.einsum('abc,ad,ce->bce', X, A, inv)
    opt_einsum = generate_optimal_tree(einsum_node)
    sub_einsum = get_common_ancestor(opt_einsum, einsum_node.inputs, A)

    # sub_einsum should be ad.einsum('ad,abc->dbc', A, X) and shouldn't include the inv node.
    assert sorted(get_all_inputs(sub_einsum),
                  key=lambda node: node.name) == sorted(
                      [A, X], key=lambda node: node.name)
def test_einsum_equal_repeated_transpose_scalar():

    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])

    x = ad.einsum("ac,ba,bc->", A, A, B)
    y = ad.einsum("ba,ac,bc->", A, A, B)

    uf1 = rewrite_einsum_expr(x)
    uf2 = rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs
def test_prune_inv_nodes_transpose(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)
        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])

        inv_input = ad.einsum('ab,bc->ca', A, B)
        # inv(inv_input.T) @ inv_input.T
        output = ad.einsum('ca,cd,de->ae', ad.tensorinv(inv_input, ind=1), A,
                           B)
        new_output = prune_inv_node(output)

        assert isinstance(new_output, ad.IdentityNode)
        assert tree_eq(output, new_output, [A, B], tol=1e-6)
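# A plain-numpy sketch (illustrative) of why the pruned result is the
# identity: einsum('ab,bc->ca', A, B) builds (A @ B).T, tensorinv with ind=1
# is a matrix inverse here, and the contraction reduces to
# inv(A @ B) @ (A @ B) = I.
def _prune_inv_transpose_numpy_check():
    import numpy as np
    A_val, B_val = np.random.rand(2, 2), np.random.rand(2, 2)
    inv_input = np.einsum('ab,bc->ca', A_val, B_val)  # (A @ B).T
    out = np.einsum('ca,cd,de->ae', np.linalg.inv(inv_input), A_val, B_val)
    assert np.allclose(out, np.eye(2))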
def test_einsum_equal_uf_assign_order():

    A = ad.Variable(name="A", shape=[3, 3])
    B = ad.Variable(name="B", shape=[3, 3])
    I = ad.identity(10)

    x = ad.einsum('pb,or,ob,pr,st->srtb', B, A, A, B, I)
    y = ad.einsum('eb,ed,fb,fd,ac->abcd', A, A, B, B, I)

    uf1 = rewrite_einsum_expr(x)
    uf2 = rewrite_einsum_expr(y)

    assert x.einsum_subscripts == y.einsum_subscripts
    assert x.inputs == y.inputs