コード例 #1
0
def test_tree_distribution_two_layers(dist_op, backendopt):
    """
        [Distributive] ((A + B) * G) * C

        will produce

        AGC + BGC

        Note that (A + B) * G is contracted first, so the sum sits two
        einsum layers below the output node.
    """

    for datatype in backendopt:
        if datatype == "taco":
            # '..,kk,..->..' is not supported in taco
            continue
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        # Inner layer: (A + B) * G. Subscripts are written without spaces,
        # like every other einsum in these tests.
        interm = ad.einsum('ik,kk->ik', dist_op(a, b), g)
        # Outer layer: ((A + B) * G) * C.
        output = ad.einsum('ik,kj->ij', interm, c)

        new_output = distribute_tree(output)
        # The sum must be pushed all the way up to the root.
        assert isinstance(new_output, dist_op)

        assert tree_eq(output, new_output, [a, b, c, g])
コード例 #2
0
def test_tree_distribution_ppE(dist_op, backendopt):
    """
        [Distributive] ((A + B) + C) * G

        will produce

        AG + BG + CG

        The sum (A + B) is itself an input of the outer sum (A + B) + C,
        so distribution has to reach through a nested dist_op.
    """
    # taco cannot express the repeated-index contraction '..,kk,..->..',
    # so those backends are filtered out up front.
    usable_backends = [name for name in backendopt if name != "taco"]

    for backend in usable_backends:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])

        inner_sum = dist_op(a, b)
        outer_sum = dist_op(inner_sum, c)
        output = ad.einsum('ik,kk->ik', outer_sum, g)

        distributed = distribute_tree(output)
        assert isinstance(distributed, dist_op)
        assert tree_eq(output, distributed, [a, b, c, g])
コード例 #3
0
def test_tree_distribution_w_add_output(dist_op, backendopt):
    """
        [Distributive]
        Test C * (A + B) + D * (E + F)
            = (C * A + C * B) + (D * E + D * F)

        The output node is itself a dist_op. After distribution, each of its
        two inputs is expected to be a dist_op as well.
    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 3])
        b = ad.Variable(name="b", shape=[3, 3])
        c = ad.Variable(name="c", shape=[3, 3])

        d = ad.Variable(name="d", shape=[3, 3])
        e = ad.Variable(name="e", shape=[3, 3])
        f = ad.Variable(name="f", shape=[3, 3])

        # C * (A + B)
        out1 = ad.einsum('ik,kj->ij', c, dist_op(a, b))
        # D * (E + F)
        out2 = ad.einsum('ik,kj->ij', d, dist_op(e, f))
        output = dist_op(out1, out2)
        new_output = distribute_tree(output)
        assert isinstance(new_output, dist_op)
        # Both top-level summands must themselves have been distributed.
        for input_node in new_output.inputs:
            assert isinstance(input_node, dist_op)
        assert tree_eq(output, new_output, [a, b, c, d, e, f])
コード例 #4
0
def test_tree_distribution_order(dist_op, backendopt):
    """
        [Distributive]
        Test C * (A + B) = C * A + C * B

        Unlike test_tree_distribution, the sum is the second einsum operand
        here rather than the first.
    """

    for backend in backendopt:
        T.set_backend(backend)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        summed = dist_op(a, b)
        output = ad.einsum('ik,kj->ij', c, summed)

        distributed = distribute_tree(output)
        assert isinstance(distributed, dist_op)
        assert tree_eq(output, distributed, [a, b, c])
コード例 #5
0
def test_tree_distribution_mim(dist_op, backendopt):
    """
        [Distributive] (A + B) * G * (C + D)

        will produce

        AGC + AGD + BGC + BGD

        Note that G must be in the middle.

        We do the following:
        (A+B)*G*(C+D)
        = A*G*(C+D) + B*G*(C+D)
        = AGC + AGD + BGC + BGD

        Mim: man in the middle
    """

    for datatype in backendopt:
        if datatype == "taco":
            # '..,kk,..->..' is not supported in taco
            continue
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        g = ad.Variable(name="g", shape=[2, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[2, 3])

        add_nodeab = dist_op(a, b)
        add_nodecd = dist_op(c, d)
        # Three-operand einsum with G sandwiched between the two sums.
        output = ad.einsum('ik,kk,kj->ij', add_nodeab, g, add_nodecd)

        new_output = distribute_tree(output)

        # Expected shape: (AGC + AGD) + (BGC + BGD), i.e. a sum whose
        # inputs are themselves sums.
        assert isinstance(new_output, dist_op)
        for node in new_output.inputs:
            assert isinstance(node, dist_op)

        assert tree_eq(output, new_output, [a, b, c, d, g])
コード例 #6
0
def test_tree_distribution_four_terms(dist_op, backendopt):
    # Raw docstring: the ASCII diagram contains backslashes, which would
    # otherwise be invalid escape sequences ("\ ").
    r"""
        [Distributive] (A + B) * (C + D)
        A    B     C     D   inputs
         \   |     |    /
          \  |     |   /
           \ |     |  /
           A + B   C+D
             \     |
              \    |
               \   |
               output

        will produce

        AC + AD + BC + BD

    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])
        d = ad.Variable(name="d", shape=[2, 3])

        dist_nodeab = dist_op(a, b)
        dist_nodecd = dist_op(c, d)
        output = ad.einsum('ik,kj->ij', dist_nodeab, dist_nodecd)

        # Idea:
        # (A + B) * (C + D) = A * (C + D) + B * (C + D)
        # Then do A * (C + D) and B * (C + D)
        new_output = distribute_tree(output)

        # Expected shape: (AC + AD) + (BC + BD) -- a two-input sum whose
        # inputs are both sums.
        assert isinstance(new_output, dist_op)
        add1, add2 = new_output.inputs
        assert isinstance(add1, dist_op)
        assert isinstance(add2, dist_op)

        assert tree_eq(output, new_output, [a, b, c, d])
コード例 #7
0
def test_tree_distribution(dist_op, backendopt):
    # Raw docstring: the ASCII diagrams contain backslashes, which would
    # otherwise be invalid escape sequences ("\ ").
    r"""
        [Distributive] An einsum graph like
        A    B     C  inputs
         \   |     |
          \  |     |
           \ |     |
           A + B   |
             \     |
              \    |
               \   |
               output

        will produce

        A    C    B inputs
         \   |\   |
          \  | \  |
           \ |  \ |
            AC   BC
             \    |
              \   |
               \  |
               output (AC + BC)

        In words: (A + B) * C = A * C + B * C, where * is an einsum node.

    """

    for datatype in backendopt:
        T.set_backend(datatype)

        a = ad.Variable(name="a", shape=[3, 2])
        b = ad.Variable(name="b", shape=[3, 2])
        c = ad.Variable(name="c", shape=[2, 3])

        add_node = dist_op(a, b)
        output = ad.einsum('ik,kj->ij', add_node, c)

        new_output = distribute_tree(output)
        # After distribution the root is the sum, not the einsum.
        assert isinstance(new_output, dist_op)

        assert tree_eq(output, new_output, [a, b, c])