Example #1
def test_kronecker_product_non_even(backendopt):

    A = ad.Variable(name="A", shape=[4, 4, 2, 2])
    B = ad.Variable(name="B", shape=[2, 2])

    out = ad.einsum("abcd,ef->abcdef", A, B)
    inv = ad.tensorinv(out, ind=2)
    newinv = optimize_inverse(inv)

    assert inv is newinv
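
The `ind=2` argument above appears to follow the `numpy.linalg.tensorinv` convention: the first `ind` axes are flattened into the rows and the remaining axes into the columns of a square matrix, which is then inverted. Below is a minimal NumPy sketch of that convention with the 4x4x2x2x2x2 shape from the test above; whether `ad.tensorinv`/`T.tensorinv` delegate to NumPy is an assumption, and NumPy is used here purely for illustration.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 2, 2, 2, 2))  # 4*4 == 2*2*2*2 == 16, so the matrix view is square
xinv = np.linalg.tensorinv(x, ind=2)         # result shape: (2, 2, 2, 2, 4, 4)
flat = np.linalg.inv(x.reshape(16, 16))      # the same inverse, computed as a plain 16x16 matrix
assert np.allclose(xinv.reshape(16, 16), flat)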
Example #2
def test_kronecker_product_nondecomposable(backendopt):

    A = ad.Variable(name="A", shape=[2, 3])
    B = ad.Variable(name="B", shape=[3, 2])

    out = ad.einsum("ab,cd->acbd", A, B)
    inv = ad.tensorinv(out)
    newinv = optimize_inverse(inv)

    assert inv is newinv
Example #3
def test_kronecker_product_repeated_inputs(backendopt):

    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])

        out = ad.einsum("ab,cd->acbd", A, A)
        inv = ad.tensorinv(out)
        newinv = optimize_inverse(inv)

        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)

        assert tree_eq(inv, newinv, [A], tol=1e-5)
Example #4
def test_inv_multiple_decomposation(backendopt):
    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])

        out = ad.einsum("ab,cd,ef->acebdf", A, B, C)
        inv = ad.tensorinv(out)
        newinv = optimize_inverse(inv)

        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)
        assert len(newinv.inputs) == 3

        assert tree_eq(inv, newinv, [A, B, C], tol=1e-5)
Example #5
def test_high_dim_inv(backendopt):

    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2, 2, 2])
        B = ad.Variable(name="B", shape=[2, 2, 2, 2])

        out = ad.einsum("aceg,dbhf->abcdefgh", A, B)
        inv = ad.tensorinv(out)
        # T.einsum('aceg,bdfh->abcdefgh',
        #     T.tensorinv(T.einsum('aceg->aceg',A), ind=2),
        #     T.tensorinv(T.einsum('dbhf->bdfh',B), ind=2))
        newinv = optimize_inverse(inv)

        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)

        assert tree_eq(inv, newinv, [A, B], tol=1e-6)
Example #6
def test_complex_product_inv(backendopt):

    for datatype in backendopt:
        T.set_backend(datatype)

        A = ad.Variable(name="A", shape=[2, 2])
        B = ad.Variable(name="B", shape=[2, 2])
        C = ad.Variable(name="C", shape=[2, 2])
        D = ad.Variable(name="D", shape=[2, 2])

        out = ad.einsum("ab,bc,de,ef->adcf", A, B, C, D)
        inv = ad.tensorinv(out)
        # T.einsum('ac,df->adcf',
        #     T.tensorinv(T.einsum('ab,bc->ac',A,B), ind=1),
        #     T.tensorinv(T.einsum('de,ef->df',C,D), ind=1))
        newinv = optimize_inverse(inv)

        assert isinstance(newinv, ad.EinsumNode)
        for node in newinv.inputs:
            assert isinstance(node, ad.TensorInverseNode)

        assert tree_eq(inv, newinv, [A, B, C, D], tol=1e-5)
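
The decompositions checked in the last few tests (and spelled out in the commented expected results) rest on the identity that the inverse of a Kronecker/tensor product of square factors equals the Kronecker product of the factor inverses. A minimal NumPy check of that identity follows; NumPy stands in for the backend `T` purely for illustration.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 2))
B = rng.standard_normal((2, 2))

# inv(A kron B) == inv(A) kron inv(B): the property the tests above verify at the graph level.
lhs = np.linalg.inv(np.kron(A, B))
rhs = np.kron(np.linalg.inv(A), np.linalg.inv(B))
assert np.allclose(lhs, rhs)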
Example #7
def simplify(output_node):
    """Simplify a graph with a single output node.
    The simplified form will distribute selected operations
    (+), and fuse all connected einsums.

    Args:
        node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        linearize(node)
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        node = declone(ret_node.node)
        return node

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    output_node = output_pnode.node
    output_node = fuse_all_einsums(output_node)

    # prune the orthonormal matmuls
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_orthonormal_matmuls(node)
                replace_node(pnode, new_node)

    # prune inverse nodes
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_inv_node(node)
                replace_node(pnode, new_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                prune_identity_nodes(node)
                new_node = prune_scalar_nodes(node)
                replace_node(pnode, new_node)

    # collapse symmetric expressions
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # sympy_simplify the distributed nodes
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node
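
A minimal usage sketch for `simplify`, assuming the function above and its helpers are in scope and that the autodiff module is imported as `ad`, as in the test examples above; the `import autodiff as ad` module name is an assumption made for illustration.

import autodiff as ad  # assumed module name, matching the `ad.` prefix used above

A = ad.Variable(name="A", shape=[2, 2])
B = ad.Variable(name="B", shape=[2, 2])

# Two chained matrix products written as nested einsums.
inner = ad.einsum("ab,bc->ac", A, B)
out = ad.einsum("ac,cd->ad", inner, B)

# Per the docstring, the connected einsums are fused, so the result is
# expected to be a single einsum over A and B.
simplified = simplify(out)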