def test_contract(self): try: from pyscf.lib import tblis_einsum tblis_available = True except (ImportError, OSError): tblis_available = False if tblis_available: a = numpy.random.random((5,4,6)) b = numpy.random.random((4,9,6)) c1 = numpy.ones((9,5), dtype=numpy.complex128) c0 = tblis_einsum._contract('ijk,jlk->li', a, b, out=c1, alpha=.5j, beta=.2) c1 = numpy.ones((9,5), dtype=numpy.complex128) c1 = c1*.2 + numpy.einsum('ijk,jlk->li', a, b)*.5j self.assertTrue(abs(c0-c1).max() < 1e-13)
def test_contract(self): try: from pyscf.lib import tblis_einsum tblis_available = True except (ImportError, OSError): tblis_available = False if tblis_available: a = numpy.random.random((5, 4, 6)) b = numpy.random.random((4, 9, 6)) c1 = numpy.ones((9, 5), dtype=numpy.complex128) c0 = tblis_einsum._contract('ijk,jlk->li', a, b, out=c1, alpha=.5j, beta=.2) c1 = numpy.ones((9, 5), dtype=numpy.complex128) c1 = c1 * .2 + numpy.einsum('ijk,jlk->li', a, b) * .5j self.assertTrue(abs(c0 - c1).max() < 1e-13)
def _contract(subscripts, *tensors, **kwargs):
    '''Contract two tensors by reducing the einsum to one matrix product.

    The fast path handles two operands with an explicit output in which every
    index appears exactly twice. An index shared by A and B is summed over.
    An index appearing in one operand and in the output is kept.

    The contraction is carried out as transpose -> reshape -> ``dot`` ->
    reshape -> transpose. All other cases fall back to ``numpy.einsum``:
      * a different number of operands, or no explicit output;
      * an index repeated within an operand;
      * an index appearing only once or more than twice;
      * no shared index;
      * either operand with fewer than ``EINSUM_MAX_SIZE`` elements.

    When ``FOUND_TBLIS`` is set and the result dtype is double,
    ``tblis_einsum._contract`` is used instead.

    Args:
        subscripts : str
            Einsum subscripts with an explicit output, e.g. ``'ijk,jlk->li'``.
            Spaces are ignored.
        *tensors : array_like
            The operands. They are converted with ``numpy.asarray`` so that
            objects such as HDF5 datasets are accepted.
        **kwargs :
            Forwarded verbatim to ``tblis_einsum._contract`` on the tblis
            path. On every other path only ``DEBUG`` (print the intermediate
            bookkeeping) is read and other keywords are ignored.

    Returns:
        ndarray whose axes follow the output subscripts. On the ``dot`` path
        this is a transposed view of the product.

    Raises:
        ValueError: if a summed index has different extents in A and B.
            On fallback paths ``numpy.einsum`` raises its own errors.
    '''
    idx_str = subscripts.replace(' ', '')
    indices = idx_str.replace(',', '').replace('->', '')
    # The scheme below needs exactly two operands and an explicit output.
    # An index appearing more than twice (e.g. a batch index shared by both
    # operands and the output) cannot be expressed as a single matrix product.
    if (len(tensors) != 2 or idx_str.count(',') != 1 or '->' not in idx_str
            or any(indices.count(x) > 2 for x in set(indices))):
        return numpy.einsum(idx_str, *tensors)

    A, B = tensors
    # Call numpy.asarray because A or B may be HDF5 Datasets
    A = numpy.asarray(A, order='A')
    B = numpy.asarray(B, order='A')
    if A.size < EINSUM_MAX_SIZE or B.size < EINSUM_MAX_SIZE:
        # Reuse the converted arrays instead of converting the inputs again.
        return numpy.einsum(idx_str, A, B)

    C_dtype = numpy.result_type(A, B)
    if FOUND_TBLIS and C_dtype == numpy.double:
        # tblis is slow for complex type
        return tblis_einsum._contract(idx_str, A, B, **kwargs)

    DEBUG = kwargs.get('DEBUG', False)

    # Split the strings into a list of idx char's
    idxA, idxBC = idx_str.split(',')
    idxB, idxC = idxBC.split('->')
    if len(idxA) != A.ndim or len(idxB) != B.ndim:
        # Subscripts do not match the operand ranks. Let numpy.einsum raise
        # its descriptive error. A bare assert would be stripped under
        # ``python -O``, and zip() below would then silently truncate.
        return numpy.einsum(idx_str, A, B)

    if DEBUG:
        print("*** Einsum for", idx_str)
        print(" idxA =", idxA)
        print(" idxB =", idxB)
        print(" idxC =", idxC)

    # Get the range for each index and put it in a dictionary
    rangeA = dict(zip(idxA, A.shape))
    rangeB = dict(zip(idxB, B.shape))
    if DEBUG:
        print("rangeA =", rangeA)
        print("rangeB =", rangeB)

    # duplicated indices 'in,ijj->n'
    if len(rangeA) != A.ndim or len(rangeB) != B.ndim:
        return numpy.einsum(idx_str, A, B)

    # An index occurring only once (summed inside a single operand, e.g.
    # 'ijk,jl->il') cannot be folded into the matrix product. Its axis would
    # survive into the product and the final transpose would fail.
    if any(indices.count(x) == 1 for x in set(indices)):
        return numpy.einsum(idx_str, A, B)

    # Find the shared indices being summed over. They are kept in their order
    # of appearance in A rather than set iteration order, so the operand
    # layouts do not depend on string hashing.
    shared_idxAB = [idx for idx in idxA if idx in idxB]
    if not shared_idxAB:  # Indices must overlap
        return numpy.einsum(idx_str, A, B)

    idxAt = list(idxA)
    idxBt = list(idxB)
    inner_shape = 1
    for insert_B_loc, n in enumerate(shared_idxAB):
        if rangeA[n] != rangeB[n]:
            err = ('ERROR: In index string %s, the range of index %s is '
                   'different in A (%d) and B (%d)' %
                   (idx_str, n, rangeA[n], rangeB[n]))
            raise ValueError(err)

        # Bring idx all the way to the right for A
        # and to the left (but preserve order) for B
        idxAt.append(idxAt.pop(idxAt.index(n)))
        idxBt.insert(insert_B_loc, idxBt.pop(idxBt.index(n)))

        inner_shape *= rangeA[n]

    if DEBUG:
        print("shared_idxAB =", shared_idxAB)
        print("inner_shape =", inner_shape)

    # Transpose the tensors into the proper order and reshape into matrices
    new_orderA = [idxA.index(idx) for idx in idxAt]
    new_orderB = [idxB.index(idx) for idx in idxBt]
    if DEBUG:
        print("Transposing A as", new_orderA)
        print("Transposing B as", new_orderB)
        print("Reshaping A as (-1,", inner_shape, ")")
        print("Reshaping B as (", inner_shape, ",-1)")

    # After the moves above, the shared indices occupy the last n_shared
    # slots of idxAt and the first n_shared slots of idxBt. The remaining
    # (free) indices, A's then B's, label the axes of the matrix product.
    n_shared = len(shared_idxAB)
    free_A = idxAt[:len(idxAt) - n_shared]
    free_B = idxBt[n_shared:]
    idxCt = free_A + free_B
    shapeCt = [rangeA[idx] for idx in free_A] + [rangeB[idx] for idx in free_B]
    new_orderCt = [idxCt.index(idx) for idx in idxC]

    if A.size == 0 or B.size == 0:
        shapeCt = [shapeCt[i] for i in new_orderCt]
        return numpy.zeros(shapeCt, dtype=C_dtype)

    At = A.transpose(new_orderA)
    Bt = B.transpose(new_orderB)
    # Hand dot() contiguous matrices, using Fortran layout when the
    # transposed view is already Fortran-contiguous and C layout otherwise.
    order_A = 'F' if At.flags.f_contiguous else 'C'
    At = numpy.asarray(At.reshape(-1, inner_shape), order=order_A)
    order_B = 'F' if Bt.flags.f_contiguous else 'C'
    Bt = numpy.asarray(Bt.reshape(inner_shape, -1), order=order_B)
    # Rows of the product enumerate free_A, and columns enumerate free_B, in
    # C index order (the reshapes above are C-order). So unflatten with an
    # explicit C-order reshape. order='A' would reinterpret the data in
    # Fortran order if the product were Fortran-contiguous.
    return dot(At, Bt).reshape(shapeCt, order='C').transpose(new_orderCt)