示例#1
0
文件: bcalc.py 项目: sunqm/sparse
def einsum(idx_str, *tensors, **kwargs):
    '''Perform a more efficient einsum via reshaping to a matrix multiply.

    Current differences compared to numpy.einsum:
    This assumes that each repeated index is actually summed (i.e. no 'i,i->i')
    and appears only twice (i.e. no 'ij,ik,il->jkl'). The output indices must
    be explicitly specified (i.e. 'ij,j->i' and not 'ij,j').
    '''

    DEBUG = kwargs.get('DEBUG', False)

    idx_str = idx_str.replace(' ', '')
    indices = "".join(re.split(',|->', idx_str))
    if '->' not in idx_str or any(indices.count(x) > 2 for x in set(indices)):
        #return np.einsum(idx_str,*tensors)
        raise NotImplementedError

    if idx_str.count(',') > 1:
        indices = re.split(',|->', idx_str)
        indices_in = indices[:-1]
        idx_final = indices[-1]
        n_shared_max = 0
        for i in range(len(indices_in)):
            for j in range(i):
                tmp = list(set(indices_in[i]).intersection(indices_in[j]))
                n_shared_indices = len(tmp)
                if n_shared_indices > n_shared_max:
                    n_shared_max = n_shared_indices
                    shared_indices = tmp
                    [a, b] = [i, j]
        tensors = list(tensors)
        A, B = tensors[a], tensors[b]
        idxA, idxB = indices[a], indices[b]
        idx_out = list(idxA + idxB)
        idx_out = "".join([x for x in idx_out if x not in shared_indices])
        C = einsum(idxA + "," + idxB + "->" + idx_out, A, B)
        indices_in.pop(a)
        indices_in.pop(b)
        indices_in.append(idx_out)
        tensors.pop(a)
        tensors.pop(b)
        tensors.append(C)
        return einsum(",".join(indices_in) + "->" + idx_final, *tensors)

    A, B = tensors

    # Call numpy.asarray because A or B may be HDF5 Datasets
    # A = numpy.asarray(A, order='A')
    # B = numpy.asarray(B, order='A')
    # if A.size < 2000 or B.size < 2000:
    #     return numpy.einsum(idx_str, *tensors)

    # Split the strings into a list of idx char's
    idxA, idxBC = idx_str.split(',')
    idxB, idxC = idxBC.split('->')
    idxA, idxB, idxC = [list(x) for x in [idxA, idxB, idxC]]
    assert (len(idxA) == A.ndim)
    assert (len(idxB) == B.ndim)

    if DEBUG:
        print("*** Einsum for", idx_str)
        print(" idxA =", idxA)
        print(" idxB =", idxB)
        print(" idxC =", idxC)

    # Get the range for each index and put it in a dictionary
    rangeA = dict()
    rangeB = dict()
    block_rangeA = dict()
    block_rangeB = dict()

    for idx, rnge in zip(idxA, A.outer_shape):  # ZHC NOTE
        rangeA[idx] = rnge
    for idx, rnge in zip(idxB, B.outer_shape):
        rangeB[idx] = rnge
    for idx, rnge in zip(idxA, A.block_shape):
        block_rangeA[idx] = rnge
    for idx, rnge in zip(idxB, B.block_shape):
        block_rangeB[idx] = rnge

    if DEBUG:
        print("rangeA =", rangeA)
        print("rangeB =", rangeB)
        print("block_rangeA =", block_rangeA)
        print("block_rangeB =", block_rangeB)

    # Find the shared indices being summed over
    shared_idxAB = list(set(idxA).intersection(idxB))
    #if len(shared_idxAB) == 0:
    #    return np.einsum(idx_str,A,B)
    idxAt = list(idxA)
    idxBt = list(idxB)
    inner_shape = 1
    block_inner_shape = 1
    insert_B_loc = 0
    for n in shared_idxAB:
        if rangeA[n] != rangeB[n]:
            err = ('ERROR: In index string %s, the range of index %s is '
                   'different in A (%d) and B (%d)' %
                   (idx_str, n, rangeA[n], rangeB[n]))
            raise RuntimeError(err)

        # Bring idx all the way to the right for A
        # and to the left (but preserve order) for B
        idxA_n = idxAt.index(n)
        idxAt.insert(len(idxAt) - 1, idxAt.pop(idxA_n))

        idxB_n = idxBt.index(n)
        idxBt.insert(insert_B_loc, idxBt.pop(idxB_n))
        insert_B_loc += 1

        inner_shape *= rangeA[n]
        block_inner_shape *= block_rangeA[n]

    if DEBUG:
        print("shared_idxAB =", shared_idxAB)
        print("inner_shape =", inner_shape)
        print("block_inner_shape =", block_inner_shape)

    # Transpose the tensors into the proper order and reshape into matrices
    new_orderA = [idxA.index(idx) for idx in idxAt]
    new_orderB = [idxB.index(idx) for idx in idxBt]

    if DEBUG:
        print("Transposing A as", new_orderA)
        print("Transposing B as", new_orderB)
        print("Reshaping A as (-1,", inner_shape, ")")
        print("Reshaping B as (", inner_shape, ",-1)")
        print("Reshaping block A as (-1,", block_inner_shape, ")")
        print("Reshaping block B as (", block_inner_shape, ",-1)")

    shapeCt = list()
    block_shapeCt = list()
    idxCt = list()
    for idx in idxAt:
        if idx in shared_idxAB:
            break
        shapeCt.append(rangeA[idx])
        block_shapeCt.append(block_rangeA[idx])
        idxCt.append(idx)
    for idx in idxBt:
        if idx in shared_idxAB:
            continue
        shapeCt.append(rangeB[idx])
        block_shapeCt.append(block_rangeB[idx])
        idxCt.append(idx)
    new_orderCt = [idxCt.index(idx) for idx in idxC]

    np_shapeCt = tuple(np.multiply(shapeCt, block_shapeCt))
    if A.nnz == 0 or B.nnz == 0:
        shapeCt = [shapeCt[i] for i in new_orderCt]
        block_shapeCt = [block_shapeCt[i] for i in new_orderCt]
        return BCOO(np.array([],dtype = np.int), data = np.array([], dtype = \
                    np.result_type(A.dtype,B.dtype)), shape=np_shapeCt,\
                    block_shape = block_shapeCt, has_duplicates=False,\
                    sorted=True).transpose(new_orderCt)

    At = A.transpose(new_orderA)
    Bt = B.transpose(new_orderB)

    # ZHC TODO optimize
    # if At.flags.f_contiguous:
    #     At = numpy.asarray(At.reshape((-1,inner_shape), (-1,block_inner_shape)), order='F')
    # else:
    At = At.block_reshape((-1, inner_shape),
                          block_shape=(-1, block_inner_shape))
    # if Bt.flags.f_contiguous:
    #     Bt = numpy.asarray(Bt.reshape((inner_shape,-1), (block_inner_shape,-1)), order='F')
    # else:
    Bt = Bt.block_reshape((inner_shape, -1),
                          block_shape=(block_inner_shape, -1))

    #AdotB = At.tobsr().dot(Bt.tobsr())
    At = At.tobsr()
    Bt = Bt.tobsr()
    AdotB = At.dot(Bt)

    AdotB_bcoo = BCOO.from_bsr(AdotB)

    if DEBUG:
        print("AdotB bsr format indptr, indices")
        print(AdotB.indptr)
        print(AdotB.indices)
        print("AdotB bcoo format coords")
        print(AdotB_bcoo.coords)

    return AdotB_bcoo.block_reshape(
        shapeCt, block_shape=block_shapeCt).transpose(new_orderCt)
示例#2
0
def _contract(subscripts, *tensors, **kwargs):
    DEBUG = kwargs.get('DEBUG', False)

    idx_str = subscripts.replace(' ', '')
    indices = idx_str.replace(',', '').replace('->', '')
    if '->' not in idx_str or any(indices.count(x) > 2 for x in set(indices)):
        # TODO 1. No ->, contract over repeated indices 2. more than 2 indices need to contract.
        raise NotImplementedError

    A, B = tensors

    # mix type, transfer to dense case
    if not (isinstance(A, BCOO) and isinstance(B, BCOO)):
        print(
            "Warning: the block einsum takes non-BCOO objects, try to transfer to dense..."
        )
        if hasattr(A, 'todense'):
            A = A.todense()
        if hasattr(B, 'todense'):
            B = B.todense()
        return np.einsum(idx_str, A, B)

    # ZHC NOTE threshold to determine which lib to use here?

    # Split the strings into a list of idx char's
    idxA, idxBC = idx_str.split(',')
    idxB, idxC = idxBC.split('->')
    #idxA, idxB, idxC = [list(x) for x in [idxA,idxB,idxC]]
    assert (len(idxA) == A.ndim)
    assert (len(idxB) == B.ndim)

    if DEBUG:
        print("*** Einsum for", idx_str)
        print(" idxA =", idxA)
        print(" idxB =", idxB)
        print(" idxC =", idxC)

    # Get the range for each index and put it in a dictionary
    rangeA = dict(zip(idxA, A.outer_shape))
    rangeB = dict(zip(idxB, B.outer_shape))
    block_rangeA = dict(zip(idxA, A.block_shape))
    block_rangeB = dict(zip(idxB, B.block_shape))

    if DEBUG:
        print("rangeA =", rangeA)
        print("rangeB =", rangeB)
        print("block_rangeA =", block_rangeA)
        print("block_rangeB =", block_rangeB)

    # duplicated indices 'in,ijj->n' # TODO: first index out the repeated indices.
    if len(rangeA) != A.ndim or len(rangeB) != B.ndim:
        raise NotImplementedError

    # Find the shared indices being summed over
    shared_idxAB = list(set(idxA).intersection(idxB))
    if len(shared_idxAB) == 0:  # TODO Indices must overlap
        raise NotImplementedError

    idxAt = list(idxA)
    idxBt = list(idxB)
    inner_shape = 1
    block_inner_shape = 1
    insert_B_loc = 0
    for n in shared_idxAB:
        if rangeA[n] != rangeB[n]:
            err = (
                'ERROR: In index string %s, the outer_shape range of index %s is '
                'different in A (%d) and B (%d)' %
                (idx_str, n, rangeA[n], rangeB[n]))
            raise ValueError(err)
        if block_rangeA[n] != block_rangeB[n]:
            err = (
                'ERROR: In index string %s, the block_shape range of index %s is '
                'different in A (%d) and B (%d)' %
                (idx_str, n, block_rangeA[n], block_rangeB[n]))
            raise RuntimeError(err)

        # Bring idx all the way to the right for A
        # and to the left (but preserve order) for B
        idxA_n = idxAt.index(n)
        idxAt.insert(len(idxAt) - 1, idxAt.pop(idxA_n))

        idxB_n = idxBt.index(n)
        idxBt.insert(insert_B_loc, idxBt.pop(idxB_n))
        insert_B_loc += 1

        inner_shape *= rangeA[n]
        block_inner_shape *= block_rangeA[n]

    if DEBUG:
        print("shared_idxAB =", shared_idxAB)
        print("inner_shape =", inner_shape)
        print("block_inner_shape =", block_inner_shape)

    # Transpose the tensors into the proper order and reshape into matrices
    new_orderA = [idxA.index(idx) for idx in idxAt]
    new_orderB = [idxB.index(idx) for idx in idxBt]

    if DEBUG:
        print("Transposing A as", new_orderA)
        print("Transposing B as", new_orderB)
        print("Reshaping A as (-1,", inner_shape, ")")
        print("Reshaping B as (", inner_shape, ",-1)")
        print("Reshaping block A as (-1,", block_inner_shape, ")")
        print("Reshaping block B as (", block_inner_shape, ",-1)")

    shapeCt = list()
    block_shapeCt = list()
    idxCt = list()
    for idx in idxAt:
        if idx in shared_idxAB:
            break
        shapeCt.append(rangeA[idx])
        block_shapeCt.append(block_rangeA[idx])
        idxCt.append(idx)
    for idx in idxBt:
        if idx in shared_idxAB:
            continue
        shapeCt.append(rangeB[idx])
        block_shapeCt.append(block_rangeB[idx])
        idxCt.append(idx)
    new_orderCt = [idxCt.index(idx) for idx in idxC]

    if A.nnz == 0 or B.nnz == 0:
        shapeCt = [shapeCt[i] for i in new_orderCt]
        block_shapeCt = [block_shapeCt[i] for i in new_orderCt]
        np_shapeCt = tuple(np.multiply(shapeCt, block_shapeCt))

        return BCOO(np.array([],dtype = np.int), data = np.array([], dtype = \
                    np.result_type(A.dtype,B.dtype)), shape=np_shapeCt,\
                    block_shape = block_shapeCt, has_duplicates=False,\
                    sorted=True)

    At = A.transpose(new_orderA)
    Bt = B.transpose(new_orderB)

    At = At.block_reshape((-1, inner_shape),
                          block_shape=(-1, block_inner_shape))
    Bt = Bt.block_reshape((inner_shape, -1),
                          block_shape=(block_inner_shape, -1))

    #AdotB = At.tobsr().dot(Bt.tobsr())
    At = At.tobsr()
    Bt = Bt.tobsr()
    AdotB = At.dot(Bt)
    AdotB_bcoo = BCOO.from_bsr(AdotB)

    if DEBUG:
        print("AdotB bsr format indptr, indices")
        print(AdotB.indptr)
        print(AdotB.indices)
        print("AdotB bcoo format coords")
        print(AdotB_bcoo.coords)

    return AdotB_bcoo.block_reshape(
        shapeCt, block_shape=block_shapeCt).transpose(new_orderCt)