示例#1
0
    def test_contract(self):
        '''tblis_einsum._contract with complex alpha/beta agrees with numpy.

        The reference value is beta*C + alpha*einsum('ijk,jlk->li', a, b)
        where C is a (9,5) complex buffer of ones, alpha=.5j and beta=.2.
        '''
        try:
            from pyscf.lib import tblis_einsum
        except (ImportError, OSError):
            # The tblis extension is optional (OSError presumably comes from
            # a failing shared-library load).  Report a skip rather than
            # letting the test pass without checking anything.
            self.skipTest('pyscf.lib.tblis_einsum is not available')

        a = numpy.random.random((5,4,6))
        b = numpy.random.random((4,9,6))

        c1 = numpy.ones((9,5), dtype=numpy.complex128)
        c0 = tblis_einsum._contract('ijk,jlk->li', a, b, out=c1, alpha=.5j, beta=.2)
        # c1 was passed as out= and may have been overwritten, so rebuild the
        # reference from a fresh buffer of ones.
        c1 = numpy.ones((9,5), dtype=numpy.complex128)
        c1 = c1*.2 + numpy.einsum('ijk,jlk->li', a, b)*.5j
        self.assertLess(abs(c0-c1).max(), 1e-13)
示例#2
0
    def test_contract(self):
        '''tblis_einsum._contract with complex alpha/beta agrees with numpy.

        The reference value is beta * C + alpha * einsum('ijk,jlk->li', a, b)
        where C is a (9, 5) complex buffer of ones, alpha=.5j and beta=.2.
        '''
        try:
            from pyscf.lib import tblis_einsum
        except (ImportError, OSError):
            # The tblis extension is optional (OSError presumably comes from
            # a failing shared-library load).  Report a skip rather than
            # letting the test pass without checking anything.
            self.skipTest('pyscf.lib.tblis_einsum is not available')

        a = numpy.random.random((5, 4, 6))
        b = numpy.random.random((4, 9, 6))

        c1 = numpy.ones((9, 5), dtype=numpy.complex128)
        c0 = tblis_einsum._contract('ijk,jlk->li',
                                    a,
                                    b,
                                    out=c1,
                                    alpha=.5j,
                                    beta=.2)
        # c1 was passed as out= and may have been overwritten, so rebuild the
        # reference from a fresh buffer of ones.
        c1 = numpy.ones((9, 5), dtype=numpy.complex128)
        c1 = c1 * .2 + numpy.einsum('ijk,jlk->li', a, b) * .5j
        self.assertLess(abs(c0 - c1).max(), 1e-13)
示例#3
0
def _contract(subscripts, *tensors, **kwargs):
    idx_str = subscripts.replace(' ','')
    indices  = idx_str.replace(',', '').replace('->', '')
    if '->' not in idx_str or any(indices.count(x)>2 for x in set(indices)):
        return numpy.einsum(idx_str, *tensors)

    A, B = tensors
    # Call numpy.asarray because A or B may be HDF5 Datasets 
    A = numpy.asarray(A, order='A')
    B = numpy.asarray(B, order='A')
    if A.size < EINSUM_MAX_SIZE or B.size < EINSUM_MAX_SIZE:
        return numpy.einsum(idx_str, *tensors)

    C_dtype = numpy.result_type(A, B)
    if FOUND_TBLIS and C_dtype == numpy.double:
        # tblis is slow for complex type
        return tblis_einsum._contract(idx_str, A, B, **kwargs)

    DEBUG = kwargs.get('DEBUG', False)

    # Split the strings into a list of idx char's
    idxA, idxBC = idx_str.split(',')
    idxB, idxC = idxBC.split('->')
    assert(len(idxA) == A.ndim)
    assert(len(idxB) == B.ndim)

    if DEBUG:
        print("*** Einsum for", idx_str)
        print(" idxA =", idxA)
        print(" idxB =", idxB)
        print(" idxC =", idxC)

    # Get the range for each index and put it in a dictionary
    rangeA = dict(zip(idxA, A.shape))
    rangeB = dict(zip(idxB, B.shape))
    #rangeC = dict(zip(idxC, C.shape))
    if DEBUG:
        print("rangeA =", rangeA)
        print("rangeB =", rangeB)

    # duplicated indices 'in,ijj->n'
    if len(rangeA) != A.ndim or len(rangeB) != B.ndim:
        return numpy.einsum(idx_str, A, B)

    # Find the shared indices being summed over
    shared_idxAB = set(idxA).intersection(idxB)
    if len(shared_idxAB) == 0: # Indices must overlap
        return numpy.einsum(idx_str, A, B)

    idxAt = list(idxA)
    idxBt = list(idxB)
    inner_shape = 1
    insert_B_loc = 0
    for n in shared_idxAB:
        if rangeA[n] != rangeB[n]:
            err = ('ERROR: In index string %s, the range of index %s is '
                   'different in A (%d) and B (%d)' %
                   (idx_str, n, rangeA[n], rangeB[n]))
            raise ValueError(err)

        # Bring idx all the way to the right for A
        # and to the left (but preserve order) for B
        idxA_n = idxAt.index(n)
        idxAt.insert(len(idxAt)-1, idxAt.pop(idxA_n))

        idxB_n = idxBt.index(n)
        idxBt.insert(insert_B_loc, idxBt.pop(idxB_n))
        insert_B_loc += 1

        inner_shape *= rangeA[n]

    if DEBUG:
        print("shared_idxAB =", shared_idxAB)
        print("inner_shape =", inner_shape)

    # Transpose the tensors into the proper order and reshape into matrices
    new_orderA = [idxA.index(idx) for idx in idxAt]
    new_orderB = [idxB.index(idx) for idx in idxBt]

    if DEBUG:
        print("Transposing A as", new_orderA)
        print("Transposing B as", new_orderB)
        print("Reshaping A as (-1,", inner_shape, ")")
        print("Reshaping B as (", inner_shape, ",-1)")

    shapeCt = list()
    idxCt = list()
    for idx in idxAt:
        if idx in shared_idxAB:
            break
        shapeCt.append(rangeA[idx])
        idxCt.append(idx)
    for idx in idxBt:
        if idx in shared_idxAB:
            continue
        shapeCt.append(rangeB[idx])
        idxCt.append(idx)
    new_orderCt = [idxCt.index(idx) for idx in idxC]

    if A.size == 0 or B.size == 0:
        shapeCt = [shapeCt[i] for i in new_orderCt]
        return numpy.zeros(shapeCt, dtype=C_dtype)

    At = A.transpose(new_orderA)
    Bt = B.transpose(new_orderB)

    if At.flags.f_contiguous:
        At = numpy.asarray(At.reshape(-1,inner_shape), order='F')
    else:
        At = numpy.asarray(At.reshape(-1,inner_shape), order='C')
    if Bt.flags.f_contiguous:
        Bt = numpy.asarray(Bt.reshape(inner_shape,-1), order='F')
    else:
        Bt = numpy.asarray(Bt.reshape(inner_shape,-1), order='C')

    return dot(At,Bt).reshape(shapeCt, order='A').transpose(new_orderCt)