Exemplo n.º 1
0
def test_prune_inv_no_inv(backendopt):
    """An einsum without any inverse input must come back untouched.

    The graph is a plain matrix product of two variables, so there is
    nothing to prune and prune_inv_node is expected to return the very
    same node object it was given.
    """
    lhs = ad.Variable(name="A", shape=[2, 2])
    rhs = ad.Variable(name="B", shape=[2, 2])
    product = ad.einsum('ab,bc->ac', lhs, rhs)

    assert prune_inv_node(product) is product
Exemplo n.º 2
0
def test_prune_inv_different_num_inputs_no_pruning(backendopt):
    """No pruning when the inverse's operand count differs from the rest.

    The inverted einsum contracts two copies of A, while the outer einsum
    multiplies that inverse by only a single A. The node is expected to
    be returned as-is (identity check, not just equality).
    """
    mat = ad.Variable(name="A", shape=[2, 2])

    # inv(A @ A) @ A
    squared = ad.einsum('ab,bc->ac', mat, mat)
    inverse = ad.tensorinv(squared, ind=1)
    product = ad.einsum('ab,bc->ac', inverse, mat)

    assert prune_inv_node(product) is product
Exemplo n.º 3
0
def test_prune_inv_set_not_match(backendopt):
    """No pruning when the inverse's inputs differ from the other operands.

    The inverse is taken of the einsum over (A, B), but the outer einsum
    only multiplies it by A, so the node must be returned unchanged.
    """
    first = ad.Variable(name="A", shape=[2, 2])
    second = ad.Variable(name="B", shape=[2, 2])

    # inv(A @ B) @ A
    inner = ad.einsum('ab,bc->ac', first, second)
    inverse = ad.tensorinv(inner, ind=1)
    product = ad.einsum('ab,bc->ac', inverse, first)

    assert prune_inv_node(product) is product
Exemplo n.º 4
0
def test_prune_inv_nonmatmul_no_pruning(backendopt):
    """No pruning when the inverse is not combined via a matrix product.

    The subscripts 'ac,ab,bc->ac' keep both indices of the inverse in the
    output, i.e. the inverse is combined element-wise on (a, c) with the
    contraction of A and B rather than being matrix-multiplied with it.
    The node is therefore expected to be returned unchanged.
    """
    first = ad.Variable(name="A", shape=[2, 2])
    second = ad.Variable(name="B", shape=[2, 2])

    inner = ad.einsum('ab,bc->ac', first, second)
    inverse = ad.tensorinv(inner, ind=1)
    # Not a matmul between the inverse and its own input: cannot be pruned.
    combined = ad.einsum('ac,ab,bc->ac', inverse, first, second)

    assert prune_inv_node(combined) is combined
Exemplo n.º 5
0
def test_prune_inv_nodes_transpose(backendopt):
    """Pruning must see through a transposed inverse input.

    The inverted einsum writes its output as 'ca' (transposed), and the
    outer einsum contracts the inverse over its first index against
    A and B. The whole expression is expected to collapse to an identity
    node that still evaluates to the same value as the original graph.
    """
    # NOTE(review): the loop variable is never used in the body -- confirm
    # whether a backend switch was meant to happen on each iteration.
    for _backend in backendopt:
        first = ad.Variable(name="A", shape=[2, 2])
        second = ad.Variable(name="B", shape=[2, 2])

        transposed = ad.einsum('ab,bc->ca', first, second)
        inverse = ad.tensorinv(transposed, ind=1)
        # inv(transposed.T) @ transposed.T
        original = ad.einsum('ca,cd,de->ae', inverse, first, second)
        pruned = prune_inv_node(original)

        assert isinstance(pruned, ad.IdentityNode)
        assert tree_eq(original, pruned, [first, second], tol=1e-6)
Exemplo n.º 6
0
def test_prune_inv_multiple_inv(backendopt):
    """Several inverses inside one einsum chain are all pruned.

    The chain is A0 @ A1 @ A1 @ inv(A1 @ A1) @ A2 @ A2 @ inv(A2 @ A2).
    After pruning, none of the direct inputs of the result may be an
    einsum node, and the result must evaluate to the same value as the
    original graph.
    """
    # NOTE(review): the loop variable is never used in the body -- confirm
    # whether a backend switch was meant to happen on each iteration.
    for _backend in backendopt:
        a0 = ad.Variable(name="A0", shape=[2, 2])
        a1 = ad.Variable(name="A1", shape=[2, 2])
        a2 = ad.Variable(name="A2", shape=[2, 2])

        # Build the inverses in the same order the chain lists them.
        inv_a1_sq = ad.tensorinv(ad.einsum('ab,bc->ac', a1, a1), ind=1)
        inv_a2_sq = ad.tensorinv(ad.einsum('ab,bc->ac', a2, a2), ind=1)
        chain = ad.einsum('ab,bc,cd,de,ef,fg,gh->ah', a0, a1, a1, inv_a1_sq,
                          a2, a2, inv_a2_sq)
        pruned = prune_inv_node(chain)

        assert not any(
            isinstance(operand, ad.EinsumNode) for operand in pruned.inputs)

        assert tree_eq(chain, pruned, [a0, a1, a2], tol=1e-6)
Exemplo n.º 7
0
def test_prune_inv_nodes_cpd(backendopt):
    """Pruning of an inverse inside a multi-operand B/C contraction.

    The inverse is taken of a four-operand einsum over B and C, and the
    outer einsum contains those same B/C operands together with A. The
    pruned node is expected to have exactly two inputs, with A as its
    only variable input, and to evaluate to the same value as the
    original graph.
    """
    # NOTE(review): the loop variable is never used in the body -- confirm
    # whether a backend switch was meant to happen on each iteration.
    for _backend in backendopt:
        mat_a = ad.Variable(name="A", shape=[2, 2])
        mat_b = ad.Variable(name="B", shape=[2, 2])
        mat_c = ad.Variable(name="C", shape=[2, 2])

        gram = ad.einsum('ab,dc,ac,db->bc', mat_b, mat_c, mat_b, mat_c)
        gram_inv = ad.tensorinv(gram, ind=1)
        original = ad.einsum('ed,ea,cd,ba,ca,gd->bg', mat_c, mat_c, mat_b,
                             mat_a, mat_b, gram_inv)

        pruned = prune_inv_node(original)

        # Expected pruned form: T.einsum('ba,ag->bg', A, T.identity(2))
        assert len(pruned.inputs) == 2
        for operand in pruned.inputs:
            if not isinstance(operand, ad.VariableNode):
                continue
            assert operand == mat_a

        assert tree_eq(original, pruned, [mat_a, mat_b, mat_c], tol=1e-6)
Exemplo n.º 8
0
def simplify(output_node):
    """Simplify a graph with a single output node.

    The simplified form will distribute selected operations
    (+), and fuse all connected einsums.

    The passes run in this order:
        1. distribute_graph_w_linearize, then fuse all einsum subtrees.
        2. optimize_inverse on every TensorInverseNode.
        3. fuse all einsum subtrees again.
        4. prune_orthonormal_matmuls on every EinsumNode.
        5. prune_inv_node on every EinsumNode.
        6. prune_identity_nodes / prune_scalar_nodes on every EinsumNode.
        7. collapse_symmetric_expr on every pair of nodes.
        8. sympy_simplify, only when the output is a DistributiveNode.

    Args:
        output_node: The output node.
    Returns:
        node: The newly generated node.
    """
    def fuse_all_einsums(node):
        # Fuse every einsum subtree reachable from `node` and return the
        # resulting root node.
        linearize(node)
        # Wrap the root in a PseudoNode so that replace_node can swap it.
        ret_node = PseudoNode(node)
        all_pnodes = find_topo_sort_p([ret_node])
        with OutputInjectedModeP(all_pnodes):
            trees = find_sub_einsumtree(ret_node)
            for tree in trees:
                # Each tree is an (output pseudo-node, input nodes) pair.
                out_node_p, in_nodes = tree
                new_z = fuse_einsums(out_node_p.node, in_nodes)
                prune_identity_nodes(new_z)
                replace_node(out_node_p, new_z)

        # presumably undoes the cloning introduced by linearize -- TODO confirm
        node = declone(ret_node.node)
        return node

    output_node = distribute_graph_w_linearize(output_node)
    output_node = fuse_all_einsums(output_node)

    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    # optimize inverse
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that we can collapse the add node.
                rewrite_einsum_expr(node)
            # Re-assign the inputs to themselves; presumably this refreshes
            # state derived from the inputs (e.g. the node name) -- TODO
            # confirm against set_inputs.
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.TensorInverseNode):
                new_inv_node = optimize_inverse(node)
                replace_node(pnode, new_inv_node)

    # fuse again
    # Read the root back from the pseudo-node, since the pass above may
    # have replaced it.
    output_node = output_pnode.node
    output_node = fuse_all_einsums(output_node)

    # prune the orthonormal matmuls
    # NOTE(review): output_pnode is not rebuilt here, so this pass starts
    # from the pseudo-node created before the second fuse_all_einsums call;
    # it is only re-created from output_node for the next pass. Confirm
    # this is intended.
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_orthonormal_matmuls(node)
                replace_node(pnode, new_node)

    # prune inverse nodes
    output_pnode = PseudoNode(output_node)
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                new_node = prune_inv_node(node)
                replace_node(pnode, new_node)

    # prune the scalar nodes and remove unnecessary identity nodes
    all_pnodes = find_topo_sort_p([output_pnode])
    with OutputInjectedModeP(all_pnodes):
        for pnode in all_pnodes:
            node = pnode.node
            if node.inputs != []:
                node.set_inputs(node.inputs)
            if isinstance(node, ad.EinsumNode):
                prune_identity_nodes(node)
                new_node = prune_scalar_nodes(node)
                replace_node(pnode, new_node)

    # collapse symmetric expressions
    # Every node is compared against each node that precedes it in the
    # topological order.
    all_pnodes = find_topo_sort_p([output_pnode])
    for i in range(len(all_pnodes)):
        for j in range(i):
            collapse_symmetric_expr(all_pnodes[i].node, all_pnodes[j].node)

    # Node types that are not collected as sympy inputs below.
    sympy_input_types = (ad.DistributiveNode, ad.ScalarNode, ad.MulNode)
    # Run sympy_simplify on the distributed nodes.
    # NOTE(review): from here on the code uses output_node, not
    # output_pnode.node. If replace_node swapped the root of output_pnode
    # in the two pruning passes above, that replacement would not be
    # visible here or in the return value -- confirm.
    if isinstance(output_node, ad.DistributiveNode):
        sympy_inputs = []
        all_nodes = find_topo_sort([output_node])
        for node in all_nodes:
            if isinstance(node, ad.EinsumNode):
                # To make sure the same einsum nodes have the same name,
                # so that they can be reduced by sympy.
                rewrite_einsum_expr(node)
            if node.inputs != []:
                node.set_inputs(node.inputs)
            # Every node that is not a distributive/scalar/mul node is
            # collected as an input for sympy_simplify.
            if not isinstance(node, sympy_input_types):
                sympy_inputs.append(node)
        output_node = sympy_simplify(output_node, sympy_inputs)

    return output_node