Example #1
def cross_einsum_connect(uf, output_node, dims_info):
    """
        Link the dimension literals of an einsum op in the union-find structure.
        
        Args: 
            uf: union find data structure.
            output_node: An einsum node.
            dims_info: A list of the dimension information of all nodes, including the output_node.
        
        Inputs of the einsum node can have duplicates.
    """
    assert (isinstance(output_node, ad.EinsumNode))
    # for child in output_node.inputs:
    #     assert (isinstance(child, ad.EinsumNode))

    in_subs, out_subs, _ = _parse_einsum_input(
        (output_node.einsum_subscripts, *output_node.inputs))
    in_subs_list = in_subs.split(',')
    whole_str = out_subs + "".join(in_subs_list)

    record = {}

    for pos, pair in enumerate(zip(whole_str, dims_info)):
        char, litername = pair
        if char in record:
            # this subscript char was seen before: merge the two dimensions
            uf.connect(litername, record[char])
        else:
            record[char] = litername
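A minimal runnable sketch of the same idea with plain NumPy and a hypothetical dict-based union-find standing in for `uf`: every (node, axis) position whose subscript character matches ends up in the same set.

import numpy as np
from numpy.core.einsumfunc import _parse_einsum_input  # private helper, same import as used in these examples

class SimpleUF:
    # hypothetical stand-in for the union-find structure used above
    def __init__(self):
        self.parent = {}
    def find(self, x):
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x
    def connect(self, a, b):
        self.parent[self.find(a)] = self.find(b)

A, B = np.ones((2, 3)), np.ones((3, 4))
in_subs, out_subs, _ = _parse_einsum_input(("ij,jk->ik", A, B))
whole_str = out_subs + "".join(in_subs.split(','))

# one literal name per (node, axis) position: the output dims first, then the inputs
dims_info = [f"dim_{pos}" for pos in range(len(whole_str))]

uf, record = SimpleUF(), {}
for char, litername in zip(whole_str, dims_info):
    if char in record:
        uf.connect(litername, record[char])
    else:
        record[char] = litername

# the contracted 'j' axes of the two inputs end up with one representative
assert uf.find("dim_3") == uf.find("dim_4")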
Example #2
def dedup_transpose(graph, node, trans_node, trans_indices):
    """
    Replace node with trans_node, and update its output nodes in the graph accordingly.

    Parameters
    ----------
    graph: list of nodes denoting a connected graph.
    node: node to be replaced.
    trans_node: the transposed node that will replace node.
    trans_indices: the transpose indices.
    """
    assert node in graph
    assert trans_node in graph

    with OutputInjectedModeP([PseudoNode(n) for n in graph]):
        for onode in node.outputs:
            # NOTE: currently we cannot deal with non-einsum nodes.
            assert isinstance(onode, ad.EinsumNode)
            in_subs, out_subs, _ = _parse_einsum_input(
                (onode.einsum_subscripts, *onode.inputs))
            in_subs_list = in_subs.split(',')
            for (i, n) in enumerate(onode.inputs):
                if n is node:
                    onode.inputs[i] = trans_node
                    str_list = list(in_subs_list[i])
                    in_subs_list[i] = "".join(
                        [str_list[j] for j in trans_indices])

            new_subscripts = ",".join(in_subs_list) + "->" + out_subs
            onode.einsum_subscripts = new_subscripts
            onode.set_inputs(onode.inputs)
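A plain-NumPy check of the invariant this rewrite relies on: transposing an input by `trans_indices` while permuting its subscript string with the same indices leaves every downstream einsum unchanged.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4))
B = rng.standard_normal((4, 5))

trans_indices = (2, 0, 1)
A_trans = np.transpose(A, trans_indices)            # plays the role of trans_node
new_sub = "".join("ijk"[j] for j in trans_indices)  # "kij", mirroring the loop above

out_old = np.einsum("ijk,kl->ijl", A, B)
out_new = np.einsum(f"{new_sub},kl->ijl", A_trans, B)
np.testing.assert_allclose(out_old, out_new)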
Example #3
def prune_scalar_nodes(einsum_node):
    """
        Remove the scalar input nodes of an einsum_node.
        Args:
            einsum_node: A fused einsum node.
        Return:
            The pruned einsum node, scaled by the product of the removed scalars
            when that product is not 1.
    """
    in_subs, out_subs, _ = _parse_einsum_input(
        (einsum_node.einsum_subscripts, *einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    new_inputs, new_input_subs, scalars = [], [], []

    for i in range(len(in_subs_list)):
        if in_subs_list[i] == "" and isinstance(einsum_node.inputs[i],
                                                ad.ScalarNode):
            scalars.append(einsum_node.inputs[i].value)
        else:
            new_inputs.append(einsum_node.inputs[i])
            new_input_subs.append(in_subs_list[i])

    scalar = np.prod(scalars)

    new_subscripts = ",".join(new_input_subs) + "->" + out_subs
    output_node = ad.einsum(new_subscripts, *new_inputs)

    if scalar == 1.:
        return output_node
    else:
        return scalar * output_node
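The rewrite relies on scalar operands (empty subscripts) only scaling the contraction; a quick NumPy check of that fact, independent of the `ad` graph API:

import numpy as np

A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)

# the two scalars enter with empty subscripts; pruning them and multiplying
# their product back in gives the same result
full = np.einsum(",,ij,jk->ik", 2.0, 3.0, A, B)
pruned = (2.0 * 3.0) * np.einsum("ij,jk->ik", A, B)
np.testing.assert_allclose(full, pruned)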
Example #4
def apl(node, subscripts, x):
    operands = []
    operands.append(subscripts)
    operands.extend(x)
    # we use the internal einsum function to parse subscripts to a normal form.
    sub_op, sub_y, operands = _parse_einsum_input(operands)
    sub_op = sub_op.split(',')
    return dict(y=numpy.einsum(_join_einsum_sub(sub_op, sub_y), *x), sub_op=sub_op, sub_y=sub_y, x=x)
Example #5
def split_inv_einsum(inv_node):
    """
    Optimize the inverse of an einsum expression so that the inverse
    is applied to several smaller tensors.

    Parameters
    ----------
    inv_node: the inverse of a fused einsum node.

    Returns
    -------
    If the input node cannot be optimized, then return the input node.
    If it can be optimized, return the optimized node.

    """
    einsum_node = inv_node.inputs[0]
    assert isinstance(einsum_node, ad.EinsumNode)
    # einsum_node is a fused einsum
    for node in einsum_node.inputs:
        assert not isinstance(node, ad.EinsumNode)

    in_subs, out_subs, _ = _parse_einsum_input(
        (einsum_node.einsum_subscripts, *einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    p_einsum_node = PseudoNode(node=einsum_node, subscript=out_subs)
    p_in_nodes = []
    for i, node in enumerate(einsum_node.inputs):
        p_in_nodes.append(PseudoNode(node=node, subscript=in_subs_list[i]))

    dsets = inv_disjoint_sets(p_einsum_node, p_in_nodes)

    # If the node cannot be decomposed, just return the input node
    if len(dsets) == 1:
        return inv_node

    new_inputs = []
    for dset in dsets:
        input_decomp_einsum = list(
            filter(lambda node: any(char in dset for char in node.subscript),
                   p_in_nodes))
        out_subs = "".join(
            [char for char in p_einsum_node.subscript if char in dset])

        decomp_node = generate_new_einsum(input_decomp_einsum, out_subs)

        decomp_node.set_in_indices_length(int(len(out_subs) / 2))

        input_node = PseudoNode(node=ad.tensorinv(decomp_node),
                                subscript=out_subs)
        new_inputs.append(input_node)

    return generate_new_einsum(new_inputs, p_einsum_node.subscript)
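The split is justified by the fact that when the matricized einsum factors into disjoint blocks (a Kronecker-product structure), the inverse factors the same way. A NumPy sketch of that identity, separate from the graph machinery here:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)   # well-conditioned factors
B = rng.standard_normal((4, 4)) + 4 * np.eye(4)

# inv(A kron B) == inv(A) kron inv(B), so one large inverse can be replaced
# by inverses of the smaller disjoint factors
lhs = np.linalg.inv(np.kron(A, B))
rhs = np.kron(np.linalg.inv(A), np.linalg.inv(B))
np.testing.assert_allclose(lhs, rhs, atol=1e-10)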
Example #6
def morph(operation, array, reduce=None, **shape_hints):
    """ This is an experimental version of a generalized reshape.
    See test cases for examples.
    """
    operation = _normalize(operation)
    source, target = operation.split('->')

    # Expanding reshape
    array = _expanding_reshape(array, source, target, **shape_hints)

    # Initial squeeze
    squeeze_operation = operation.split('->')[0].split()
    for axis, op in reversed(list(enumerate(squeeze_operation))):
        if op == '1':
            array = np.squeeze(array, axis=axis)

    # Transpose
    transposition_operation = operation.replace('1', ' ').replace('*', ' ')
    try:
        in_shape, out_shape, (array, ) = _parse_einsum_input(
            [transposition_operation.replace(' ', ''), array])

        if len(set(in_shape) - set(out_shape)) > 0:
            assert reduce is not None, ('Missing reduce function', reduce,
                                        transposition_operation)

            reduce_axis = tuple(
                [i for i, s in enumerate(in_shape) if s not in out_shape])
            array = reduce(array, axis=reduce_axis)
            in_shape = ''.join([s for s in in_shape if s in out_shape])

        array = np.einsum(f'{in_shape}->{out_shape}', array)
    except ValueError as e:
        msg = (f'op: {transposition_operation} ({in_shape}->{out_shape}), '
               f'shape: {np.shape(array)}')

        if len(e.args) == 1:
            e.args = (e.args[0] + '\n\n' + msg, )
        else:
            print(msg)
        raise

    # Final reshape
    source = transposition_operation.split('->')[-1]
    target = operation.split('->')[-1]

    return _shrinking_reshape(array, source, target)
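A rough plain-NumPy illustration of the kind of rearrangement `morph` performs; the operation strings appear only in the comments, and each step (transpose, merging reshape, reduce) maps to one stage of the function above.

import numpy as np

x = np.arange(24).reshape(2, 3, 4)          # axes: t, b, f

# 'tbf -> bft': pure transposition, handled through einsum
y = np.einsum("tbf->bft", x)
assert y.shape == (3, 4, 2)

# 'tbf -> b(ft)': transpose followed by a shrinking reshape that merges axes
z = np.einsum("tbf->bft", x).reshape(3, 4 * 2)
assert z.shape == (3, 8)

# 'tbf -> bf': a label disappears, so a reduce function is required
np.testing.assert_array_equal(np.sum(x, axis=0), np.einsum("tbf->bf", x))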
Example #7
def n_mode_eigendec(node, tensor_val, rank):
    """
    Eigendecomposition of the mode-n unfolding of an input node.
    Used in Tucker decomposition to update the core tensor
    and one factor matrix.

    Parameters
    ----------
    node: the input einsum node. Note that it must be the EinsumNode
        of the core tensor node and one factor matrix node.
    tensor_val: the value of the input node
    rank: Tucker decomposition rank

    Returns
    -------
    1. the core tensor
    2. the corresponding factor matrix
    """
    assert isinstance(node, ad.EinsumNode)
    assert len(node.inputs) == 2

    in_subs, out_subs, _ = _parse_einsum_input(
        (node.einsum_subscripts, *node.inputs))
    core_subs, A_subs = in_subs.split(',')

    assert len(A_subs) == 2

    contracted_char = list(set(A_subs) - set(out_subs))[0]

    out_subs_2 = "".join(
        [char if char not in A_subs else contracted_char for char in out_subs])
    # used for tensor_val.T @ tensor_val in its matricized form
    einstr = out_subs + "," + out_subs_2 + "->" + A_subs

    Y = T.einsum(einstr, tensor_val, tensor_val)
    U, _, _ = T.svd(Y)
    U = U[:, :rank]

    einstr = out_subs + "," + A_subs + "->" + core_subs
    core = T.einsum(einstr, tensor_val, U)
    return core, U
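The same update written directly in NumPy for a third-order tensor: unfold along the mode being updated, form the Gram matrix, and keep the leading singular vectors. This is a sketch of the standard HOOI-style step, not the graph-based implementation above.

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 5, 6))    # tensor_val
rank = 3
mode = 0                              # update the factor of the first mode

# mode-n unfolding: the chosen mode becomes the rows
Xn = np.moveaxis(X, mode, 0).reshape(X.shape[mode], -1)

# Gram matrix of the unfolding, then its leading eigenvectors via SVD
# (this mirrors Y = T.einsum(out_subs + "," + out_subs_2 + "->" + A_subs, ...))
Y = Xn @ Xn.T
U, _, _ = np.linalg.svd(Y)
U = U[:, :rank]                       # factor matrix for this mode

# core tensor: contract the tensor with the factor along that mode
core = np.einsum("ijk,ir->rjk", X, U)
assert core.shape == (rank, 5, 6)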
Example #8
def grad_einsum(argnum, g, ans, vs, gvs, operands, kwargs):
    if isinstance(operands[0], string_types):  # using "ijk" convention.
        in_subs, out_subs, _ = _parse_einsum_input(tuple(map(getval,
                                                             operands)))
        string, operands = operands[0], operands[1:]

        in_subs_list = in_subs.split(',')
        op_num = argnum - 1
        subs_wrt = in_subs_list[op_num]
        rest_of_ops = operands[:op_num] + operands[op_num + 1:]
        rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num + 1:]

        # subscripts that only appear in subs_wrt (and not in other subscript lists
        # or in the output) are implicitly being summed out, as if contracted
        # against a tensor of ones. we make that tensor of ones explicit to handle
        # the necessary vjp broadcasting inside einsum.
        other_named_subs = set(''.join([out_subs] + rest_of_subs))
        naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt)
                        if sub not in other_named_subs]
        if naked_summed:
            naked_summed_dims, ones_subs = zip(*naked_summed)
            ones_subs = ''.join(ones_subs)
            ones = onp.ones(
                onp.array(operands[op_num].shape)[list(naked_summed_dims)])
            new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs)
            new_operands = (g, ones) + rest_of_ops
        else:
            new_input_subs = ','.join([out_subs] + rest_of_subs)
            new_operands = (g, ) + rest_of_ops

        new_subscripts = new_input_subs + '->' + subs_wrt
        return unbroadcast(vs, gvs, anp.einsum(new_subscripts, *new_operands))
    else:  # using (op0, sublist0, op1, sublist1, ..., sublistout) convention
        if len(operands) % 2 == 0:
            raise NotImplementedError("Need sublistout argument")
        operands = list(operands)
        rest_of_ops = [operands[-1]] + operands[:argnum] + \
                operands[(argnum+2):-1] + [operands[argnum+1]]
        return unbroadcast_einsum(vs, gvs, anp.einsum(g, *rest_of_ops),
                                  operands[argnum + 1])
Example #9
def grad_einsum(argnum, g, ans, vs, gvs, operands, kwargs):
    if isinstance(operands[0], string_types):  # using "ijk" convention.
        in_subs, out_subs, _ = _parse_einsum_input(tuple(map(getval, operands)))
        string, operands = operands[0], operands[1:]

        in_subs_list = in_subs.split(',')
        op_num = argnum - 1
        subs_wrt = in_subs_list[op_num]
        rest_of_ops = operands[:op_num] + operands[op_num+1:]
        rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num+1:]

        # subscripts that only appear in subs_wrt (and not in other subscript lists
        # or in the output) are implicitly being summed out, as if contracted
        # against a tensor of ones. we make that tensor of ones explicit to handle
        # the necessary vjp broadcasting inside einsum.
        other_named_subs = set(''.join([out_subs] + rest_of_subs))
        naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt)
                        if sub not in other_named_subs]
        if naked_summed:
            naked_summed_dims, ones_subs = zip(*naked_summed)
            ones_subs = ''.join(ones_subs)
            ones = onp.ones(onp.array(operands[op_num].shape)[list(naked_summed_dims)])
            new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs)
            new_operands = (g, ones) + rest_of_ops
        else:
            new_input_subs = ','.join([out_subs] + rest_of_subs)
            new_operands = (g,) + rest_of_ops

        new_subscripts = new_input_subs + '->' + subs_wrt
        return unbroadcast(vs, gvs, anp.einsum(new_subscripts, *new_operands))
    else:  # using (op0, sublist0, op1, sublist1, ..., sublistout) convention
        if len(operands) % 2 == 0:
            raise NotImplementedError("Need sublistout argument")
        operands = list(operands)
        rest_of_ops = [operands[-1]] + operands[:argnum] + \
                operands[(argnum+2):-1] + [operands[argnum+1]]
        return unbroadcast_einsum(vs, gvs, anp.einsum(g, *rest_of_ops), operands[argnum + 1])
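A small NumPy check of the rule implemented above for the plain matmul case: the VJP with respect to one operand is another einsum in which the upstream gradient takes the output's subscripts and the target subscript is that of the operand being differentiated.

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((3, 4))
g = rng.standard_normal((2, 4))     # upstream gradient, same shape as the output

# VJP wrt A: feed g in place of A's slot and target A's subscripts
vjp_A = np.einsum("ik,jk->ij", g, B)
np.testing.assert_allclose(vjp_A, g @ B.T)

# VJP wrt B, by the same recipe
vjp_B = np.einsum("ij,ik->jk", A, g)
np.testing.assert_allclose(vjp_B, A.T @ g)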
Example #10
def einsum(*operands, optimize=False, requires_grad=False):
    # This block has not been verified; it was copied over directly.
    operands = list(operands)
    if isinstance(operands[0], str):
        # operands form: "ijk, ijk", x, y
        variables = operands[1:]
        if any(isinstance(i, Tensor) for i in operands):
            operands[1:] = (var.data if isinstance(var, Tensor) else var
                            for var in operands[1:])
    else:
        # operands form: op0, sublist0, op1, sublist1, ..., [sublistout]
        end = -1 if len(operands) % 2 else None  # -1 if sublistout is included
        variables = operands[:end:2]
        if any(isinstance(i, Tensor) for i in operands):
            operands[:end:2] = (var.data if isinstance(var, Tensor) else var
                                for var in operands[:end:2])

    in_lbls, out_lbls, _ = _parse_einsum_input(operands)
    return Tensor._op(EinSum,
                      *variables,
                      op_kwargs=dict(in_lbls=in_lbls,
                                     out_lbls=out_lbls,
                                     optimize=optimize),
                      requires_grad=requires_grad)
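For reference, the two operand conventions handled by the two branches above are interchangeable in NumPy itself:

import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)

# "ijk" string convention vs. the (op0, sublist0, op1, sublist1, sublistout) convention
out_str = np.einsum("ij,jk->ik", A, B)
out_lst = np.einsum(A, [0, 1], B, [1, 2], [0, 2])
np.testing.assert_array_equal(out_str, out_lst)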
Example #11
def parse_einsum_input(*args):
    return _parse_einsum_input(args)
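Observed behaviour of this private NumPy helper (not public API, so subject to change): it returns the comma-joined input subscripts, the resolved output subscript, and the operand list, inferring an implicit output when no "->" is present.

import numpy as np
from numpy.core.einsumfunc import _parse_einsum_input

A, B = np.ones((2, 3)), np.ones((3, 4))
in_subs, out_subs, ops = _parse_einsum_input(("ij,jk", A, B))
assert (in_subs, out_subs) == ("ij,jk", "ik")   # implicit output inferred
assert len(ops) == 2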
Example #12
def einsum(*operands, optimize=False, constant=False):
    r"""
    einsum(subscripts, *operands)

    Evaluates the Einstein summation convention on the operands. This implementation
    exactly mirrors that of ``numpy.einsum`` and supports back-propagation through
    all variety of tensor-products, sums, traces, and views that it can perform.

    The following docstring was adapted from the documentation for ``numpy.einsum``

    Using the Einstein summation convention, many common multi-dimensional
    array operations can be represented in a simple fashion.  This function
    provides a way to compute such summations. The best way to understand this
    function is to try the examples below, which show how many common NumPy/MyGrad
    functions can be implemented as calls to ``einsum``.

    Back-propagation via ``einsum`` is optimized such that any tensor that occurs
    redundantly within the summation will only have its gradient computed once.
    This optimization accommodates all number and combination of redundancies that can
    be encountered.

    E.g. back-propping through ``einsum('...,...->', x, x)`` will only incur a single
    computation/accumulation for ``x.grad`` rather than two. This permits users to
    leverage the efficiency of sum-reduction, where ``(x ** 2).sum()`` is sub-optimal,
    without being penalized during back-propagation.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.

    operands : array_like
        The tensors used in the summation.

    optimize : {False, True, 'greedy', 'optimal'}, optional (default=False)
        Controls if intermediate optimization should occur; also enables
        the use of BLAS where possible. This can produce significant speedups
        for computations like matrix multiplication.

        No optimization will occur if False, and True will default to the 'greedy'
        algorithm. Also accepts an explicit contraction list from the
        ``np.einsum_path`` function. See ``np.einsum_path`` for more details.

    constant : bool, optional (default=False)
        If True, the resulting Tensor is a constant.

    Returns
    -------
    output : mygrad.Tensor
        The calculation based on the Einstein summation convention.

    Notes
    -----
    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Repeated subscript labels in one operand take the diagonal.  For example,
    ``einsum('ii', a)`` is equivalent to ``np.trace(a)`` (however, the former
    supports back-propagation).

    Whenever a label is repeated, it is summed, so ``einsum('i, i', a, b)``
    is equivalent to ``np.inner(a, b)``.  If a label appears only once,
    it is not summed, so ``einsum('i', a)`` produces a view of ``a``
    with no changes.

    The order of labels in the output is by default alphabetical.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D tensor, while
    ``einsum('ji', a)`` takes its transpose.

    The output can be controlled by specifying output subscript labels
    as well.  This specifies the label order, and allows summing to
    be disallowed or forced when desired.  The call ``einsum('i->', a)``
    is like ``np.sum(a, axis=-1)``, and ``einsum('ii->i', a)``
    is like ``np.diag(a)``.  The difference is that `einsum` does not
    allow broadcasting by default.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, you can do
    ``einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new tensor.  Thus, taking the diagonal as ``einsum('ii->i', a)``
    produces a view.

    An alternative way to provide the subscripts and operands is as
    ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``. The examples
    below have corresponding `einsum` calls with the two parameter methods.

    Examples
    --------
    >>> import mygrad as mg
    >>> import numpy as np
    >>> a = mg.arange(25).reshape(5,5)
    >>> b = mg.arange(5)
    >>> c = mg.arange(6).reshape(2,3)

    Compute the trace of ``a``, :math:`\sum_{i}{A_{ii}} = f`:

    >>> einsum('ii', a)
    Tensor(60)
    >>> einsum(a, [0, 0])
    Tensor(60)
    >>> np.trace(a.data)
    array(60)

    Return a view along the diagonal of ``a``, :math:`A_{ii} = F_{i}`:

    >>> einsum('ii->i', a)
    Tensor([ 0,  6, 12, 18, 24])
    >>> einsum(a, [0,0], [0])
    Tensor([ 0,  6, 12, 18, 24])
    >>> np.diag(a.data)
    array([ 0,  6, 12, 18, 24])

    Compute the matrix-vector product of ``a`` with ``b``, :math:`\sum_{j}{A_{ij} B_{j}} = F_{i}`:

    >>> einsum('ij,j', a, b)
    Tensor([ 30,  80, 130, 180, 230])
    >>> einsum(a, [0,1], b, [1])
    Tensor([ 30,  80, 130, 180, 230])
    >>> mg.matmul(a, b)
    Tensor([ 30,  80, 130, 180, 230])
    >>> einsum('...j,j', a, b)
    Tensor([ 30,  80, 130, 180, 230])

    Take the transpose of ``c``, :math:`C_{ji} = F_{ij}`:

    >>> einsum('ji', c)
    Tensor([[0, 3],
            [1, 4],
            [2, 5]])
    >>> einsum(c, [1, 0])
    Tensor([[0, 3],
            [1, 4],
            [2, 5]])
    >>> c.T
    Tensor([[0, 3],
            [1, 4],
            [2, 5]])

    Compute ``3 * c``:

    >>> einsum('..., ...', 3, c)
    Tensor([[ 0,  3,  6],
            [ 9, 12, 15]])
    >>> einsum(',ij', 3, c)
    Tensor([[ 0,  3,  6],
            [ 9, 12, 15]])
    >>> einsum(3, [Ellipsis], c, [Ellipsis])
    Tensor([[ 0,  3,  6],
            [ 9, 12, 15]])
    >>> 3 * c
    Tensor([[ 0,  3,  6],
            [ 9, 12, 15]])

    Compute the inner product of ``b`` with itself, :math:`\sum_{i}{B_{i} B_{i}} = f`:

    >>> einsum('i,i', b, b)
    Tensor(30)
    >>> einsum(b, [0], b, [0])
    Tensor(30)
    >>> np.inner(b.data, b.data)
    30

    Compute the outer product of ``array([1, 2])`` with ``b``, :math:`A_{i}B_{j} = F_{ij}`:

    >>> einsum('i,j', np.arange(2)+1, b)
    Tensor([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> einsum(np.arange(2)+1, [0], b, [1])
    Tensor([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> einsum('i...->...', a)
    Tensor([50, 55, 60, 65, 70])
    >>> einsum(a, [0,Ellipsis], [Ellipsis])
    Tensor([50, 55, 60, 65, 70])
    >>> np.sum(a, axis=0)
    array([50, 55, 60, 65, 70])

    Compute the tensor product :math:`\sum_{ij}{A_{ijk} B_{jil}} = F_{kl}`

    >>> a = mg.arange(60.).reshape(3,4,5)
    >>> b = mg.arange(24.).reshape(4,3,2)
    >>> einsum('ijk,jil->kl', a, b)
    Tensor([[ 4400.,  4730.],
            [ 4532.,  4874.],
            [ 4664.,  5018.],
            [ 4796.,  5162.],
            [ 4928.,  5306.]])
    >>> einsum(a, [0,1,2], b, [1,0,3], [2,3])
    Tensor([[ 4400.,  4730.],
            [ 4532.,  4874.],
            [ 4664.,  5018.],
            [ 4796.,  5162.],
            [ 4928.,  5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[ 4400.,  4730.],
            [ 4532.,  4874.],
            [ 4664.,  5018.],
            [ 4796.,  5162.],
            [ 4928.,  5306.]])

    Matrix multiply ``a.T`` with ``b.T``, :math:`\sum_{k}{A_{ki} B_{jk}} = F_{ij}`

    >>> a = mg.arange(6).reshape((3,2))
    >>> b = mg.arange(12).reshape((4,3))
    >>> einsum('ki,jk->ij', a, b)
    Tensor([[10, 28, 46, 64],
            [13, 40, 67, 94]])
    >>> einsum('ki,...k->i...', a, b)
    Tensor([[10, 28, 46, 64],
            [13, 40, 67, 94]])
    >>> einsum('k...,jk', a, b)
    Tensor([[10, 28, 46, 64],
            [13, 40, 67, 94]])

    Make an assignment to a view along the diagonal of ``a``:

    >>> a = mg.zeros((3, 3))
    >>> einsum('ii->i', a).data[:] = 1
    >>> a
    Tensor([[ 1.,  0.,  0.],
            [ 0.,  1.,  0.],
            [ 0.,  0.,  1.]])
    """

    # TODO: normalize error handling for invalid inputs
    operands = list(operands)
    if isinstance(operands[0], str):
        # operands form: "ijk, ijk", x, y
        variables = operands[1:]
        if any(isinstance(i, Tensor) for i in operands):
            operands[1:] = (
                var.data if isinstance(var, Tensor) else var for var in operands[1:]
            )
    else:
        # operands form: op0, sublist0, op1, sublist1, ..., [sublistout]
        end = -1 if len(operands) % 2 else None  # -1 if sublistout is included
        variables = operands[:end:2]
        if any(isinstance(i, Tensor) for i in operands):
            operands[:end:2] = (
                var.data if isinstance(var, Tensor) else var for var in operands[:end:2]
            )

    in_lbls, out_lbls, _ = _parse_einsum_input(operands)
    return Tensor._op(
        EinSum,
        *variables,
        op_kwargs=dict(in_lbls=in_lbls, out_lbls=out_lbls, optimize=optimize),
        constant=constant
    )
Example #13
def dmrg_local_update(intermediate, eigvec, max_mps_rank):
    """
    Perform local update for DMRG.

    Parameters
    ----------
    intermediate: the input einsum node. Its inputs are two mps sites.
    eigvec: the eigenvector to get the low rank decomposition.
    max_mps_rank: maximum mps tensor rank.
    """
    # parse intermediate strings
    inputs = intermediate.inputs
    assert len(inputs) == 2

    # Here input names are formatted as A{i}.
    index_input_0 = int(inputs[0].name[1:])
    index_input_1 = int(inputs[1].name[1:])

    in_subs, out_subs, _ = _parse_einsum_input(
        (intermediate.einsum_subscripts, *intermediate.inputs))

    if index_input_0 > index_input_1:
        # right site appears first
        right_subs, left_subs = in_subs.split(',')
    else:
        left_subs, right_subs = in_subs.split(',')

    map_subs_indices = dict(zip(out_subs,
                                list(range(len(intermediate.shape)))))

    contract_char, = list(set(left_subs) - set(out_subs))

    left_uncontract_chars = list(set(left_subs) - set(contract_char))
    right_uncontract_chars = list(set(right_subs) - set(contract_char))

    left_indices = [map_subs_indices[char] for char in left_uncontract_chars]
    right_indices = [map_subs_indices[char] for char in right_uncontract_chars]

    left_uncontract_str = "".join(left_uncontract_chars)
    right_uncontract_str = "".join(right_uncontract_chars)

    #############################################################
    # svd decomposition to get updated sites
    eigvec_shape = intermediate.shape
    eigvec_mat = T.transpose(eigvec, left_indices + right_indices)
    eigvec_mat = T.reshape(eigvec_mat,
                           (np.prod([eigvec_shape[i]
                                     for i in left_indices]), -1))

    U, s, VT = T.svd(eigvec_mat)
    rank = min([max_mps_rank, eigvec_mat.shape[0], eigvec_mat.shape[1]])
    U, s, VT = U[:, :rank], s[:rank], VT[:rank, :]
    VT = T.diag(s) @ VT

    U = T.reshape(U, [eigvec_shape[i] for i in left_indices] + [rank])
    VT = T.reshape(VT, ([rank] + [eigvec_shape[i] for i in right_indices]))

    left = T.einsum(f"{left_uncontract_str}{contract_char}->{left_subs}", U)
    right = T.einsum(f"{contract_char}{right_uncontract_str}->{right_subs}",
                     VT)

    return left, right
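The core of the local update, independent of the graph code: reshape the two-site eigenvector into a matrix, take a rank-truncated SVD, and absorb the singular values into the right factor. A plain-NumPy sketch, assuming a (left bond, phys1, phys2, right bond) axis ordering:

import numpy as np

rng = np.random.default_rng(0)
Dl, d1, d2, Dr = 4, 2, 2, 4           # bond and physical dimensions
max_mps_rank = 3

eigvec = rng.standard_normal((Dl, d1, d2, Dr))

# group (left bond, first physical) vs (second physical, right bond)
mat = eigvec.reshape(Dl * d1, d2 * Dr)

U, s, VT = np.linalg.svd(mat, full_matrices=False)
rank = min(max_mps_rank, *mat.shape)
U, s, VT = U[:, :rank], s[:rank], VT[:rank, :]
VT = np.diag(s) @ VT                  # absorb singular values into the right site

left = U.reshape(Dl, d1, rank)        # updated left MPS site
right = VT.reshape(rank, d2, Dr)      # updated right MPS site

# contracting the two sites reproduces eigvec up to the rank truncation
approx = np.einsum("abr,rcd->abcd", left, right)
assert approx.shape == eigvec.shape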
Example #14
def parse_einsum_input(*args):
    return _parse_einsum_input(args)
Example #15
def my_einsum(*operands):
    from numpy.core.einsumfunc import _parse_einsum_input
    operands = _parse_einsum_input(operands)
    return np.einsum("->".join(operands[:-1]), *operands[-1])
Example #16
def prune_single_inv_node(einsum_node, inv_node):
    """
    Prune the inv_node in the einsum node if the conditions are met.

    Note:
    1. can only optimize the node when the input of inv is an einsum node.
    2. only supports the case when the split nodes are different from the remaining ones.
        For example: ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, C) will be
        optimized to ad.einsum("ab,bc->ac", C, ad.identity()),
        but we cannot optimize ad.einsum("ab,bc,cd,de->ae", inv("ab,bc->ac", A, B), A, B, B).

    Parameters
    ----------
    einsum_node: The fused einsum node
    inv_node: the input inv node to be pruned

    Returns
    -------
    If the einsum_node cannot be optimized, then return the input einsum_node.
    If it can be optimized, return the optimized einsum node.

    """
    from graph_ops.graph_transformer import rewrite_einsum_expr
    from graph_ops.graph_generator import split_einsum

    inv_node_input = inv_node.inputs[0]
    if not isinstance(inv_node_input, ad.EinsumNode):
        logger.info(f"inv input is not einsum node, can't prune inv")
        return einsum_node

    if not set(inv_node_input.inputs).issubset(set(einsum_node.inputs)):
        logger.info(
            f"inv inputs is not subset of einsum node inputs, can't prune inv")
        return einsum_node

    einsum_inputs_in_inv = [
        n for n in einsum_node.inputs if n in inv_node_input.inputs
    ]
    if len(einsum_inputs_in_inv) < len(inv_node_input.inputs):
        logger.info(
            f"number of inv inputs is more than corresponding einsum inputs, can't prune inv"
        )
        return einsum_node

    split_einsum_node = split_einsum(
        einsum_node,
        list(set(einsum_node.inputs) - set(inv_node_input.inputs)))

    # Assign pseudo nodes and chars
    in_subs, out_subs, _ = _parse_einsum_input(
        (split_einsum_node.einsum_subscripts, *split_einsum_node.inputs))
    in_subs_list = in_subs.split(',')

    updated_p_in_nodes = []
    for i, node in enumerate(split_einsum_node.inputs):
        if isinstance(node, ad.EinsumNode):
            p_einsum_input = PseudoNode(node=node, subscript=in_subs_list[i])
        elif node is inv_node:
            p_inv_input = PseudoNode(node=node, subscript=in_subs_list[i])
        else:
            updated_p_in_nodes.append(
                PseudoNode(node=node, subscript=in_subs_list[i]))

    contract_char = "".join(
        set(p_einsum_input.subscript) & set(p_inv_input.subscript))
    uncontract_str = "".join(
        set("".join([p_einsum_input.subscript, p_inv_input.subscript])) -
        set(contract_char))

    if not (len(p_einsum_input.subscript) == 2 and len(p_inv_input.subscript)
            == 2 and len(contract_char) == 1 and len(uncontract_str) == 2):
        # this is not a matmul. Just return the initial node
        logger.info(
            f"the op between inv input and the selected einsum is not matmul, can't prune inv"
        )
        return einsum_node

    if p_einsum_input.subscript[0] == p_inv_input.subscript[
            0] or p_einsum_input.subscript[1] == p_inv_input.subscript[1]:
        # the str is like "ab,ac", and one einsum needs to be transposed to compare
        p_in_subs, p_out_subs, _ = _parse_einsum_input(
            (p_einsum_input.node.einsum_subscripts,
             *p_einsum_input.node.inputs))
        einsum_input = ad.einsum(
            f"{p_in_subs}->{p_out_subs[1]}{p_out_subs[0]}",
            *p_einsum_input.node.inputs)
    else:
        einsum_input = p_einsum_input.node

    rewrite_einsum_expr(einsum_input)
    rewrite_einsum_expr(inv_node_input)

    if einsum_input.name != inv_node_input.name:
        logger.info(
            f"inv input and the selected einsum have different expressions, can't prune inv"
        )
        return einsum_node

    # prune the inv node
    updated_p_in_nodes = updated_p_in_nodes + [
        PseudoNode(node=ad.identity(inv_node_input.shape[0]),
                   subscript=uncontract_str)
    ]

    return generate_new_einsum(updated_p_in_nodes, out_subs)
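The rewrite is justified by a simple matrix identity: if the inverse's input is the same matmul chain that also appears in the outer einsum, the two cancel to an identity matrix. A NumPy check of the pattern from the docstring, using hypothetical well-conditioned factors:

import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n)) + n * np.eye(n)
B = rng.standard_normal((n, n)) + n * np.eye(n)
C = rng.standard_normal((n, n))

# einsum("ab,bc,cd,de->ae", inv(A @ B), A, B, C) collapses to I @ C == C,
# which matches the optimized form einsum("ab,bc->ac", identity, C)
lhs = np.einsum("ab,bc,cd,de->ae", np.linalg.inv(A @ B), A, B, C)
rhs = np.einsum("ab,bc->ac", np.eye(n), C)
np.testing.assert_allclose(lhs, rhs, atol=1e-8)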