def _gemmdot(a, b, alpha=1.0, beta=1.0, out=None, trans='n'):
    """Matrix multiplication using gemm.

    return reference to out, where::

      out <- alpha * a . b + beta * out

    If out is None, a suitably sized zero array will be created.

    ``a.b`` denotes matrix multiplication, where the product-sum is
    over the last dimension of a, and either
    the first dimension of b (for trans='n'), or
    the last dimension of b (for trans='t' or 'c').

    If trans='c', the complex conjugate of b is used.

    When both a and b are 1-D, ``out`` must be None and the value of
    ``alpha * dotu(a, b)`` (or ``alpha * dotc(b, a)`` for trans='c') is
    returned directly; ``beta`` plays no role in that case.
    """
    # Store original shapes; the 2-D reshapes below discard them.
    ashape = a.shape
    bshape = b.shape

    # Vector-vector multiplication is handled by dotu / dotc.
    if a.ndim == 1 and b.ndim == 1:
        assert out is None
        if trans == 'c':
            # dotc conjugates its *first* argument, hence the (b, a) order.
            return alpha * _gpaw.dotc(b, a)
        return alpha * _gpaw.dotu(a, b)

    # Map all arrays to 2D arrays: a -> (rows, k); b -> (k, cols) for
    # trans='n', or (cols, k) for trans='t'/'c'.
    a = a.reshape(-1, a.shape[-1])
    if trans == 'n':
        b = b.reshape(b.shape[0], -1)
        ncols = b.shape[1]
    else:  # 't' or 'c'
        b = b.reshape(-1, b.shape[-1])
        ncols = b.shape[0]

    # Apply BLAS gemm routine
    outshape = a.shape[0], ncols
    if out is None:
        # (ATLAS can't handle uninitialized output array)
        out = np.zeros(outshape, a.dtype)
    else:
        # NOTE(review): reshape returns a copy for a non-contiguous ``out``;
        # the result would then be returned but not written into the
        # caller's array -- confirm callers only pass contiguous arrays.
        out = out.reshape(outshape)
    # Argument order (b, a) follows this module's gemm convention --
    # presumably compensating for BLAS column-major layout; TODO confirm.
    gemm(alpha, b, a, beta, out, trans)

    # Determine actual shape of result array from the original shapes.
    if trans == 'n':
        outshape = ashape[:-1] + bshape[1:]
    else:  # 't' or 'c'
        outshape = ashape[:-1] + bshape[:-1]
    return out.reshape(outshape)
def dotc(a, b):
    r"""Dot product, conjugating the first vector with complex arguments.

    Returns the value of the operation::

        _
       \   cc
        )   a       * b
       /_    ijk...    ijk...
       ijk...

    ``cc`` denotes complex conjugation, i.e. the result is the sum over
    all elements of ``a.conj() * b``.

    Both arrays must be contiguous, have the same shape, and be either
    both float or both complex (checked with ``assert``).
    """
    # NOTE(review): an identical ``dotc`` is defined again immediately
    # after this one and shadows it; one of the two can be removed.
    assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
            (is_contiguous(a, complex) and is_contiguous(b, complex)))
    assert a.shape == b.shape
    return _gpaw.dotc(a, b)
def dotc(a, b):
    r"""Dot product, conjugating the first vector with complex arguments.

    Returns the value of the operation::

        _
       \   cc
        )   a       * b
       /_    ijk...    ijk...
       ijk...

    ``cc`` denotes complex conjugation, i.e. the result is the sum over
    all elements of ``a.conj() * b``.

    Both arrays must be contiguous, have the same shape, and be either
    both float or both complex (checked with ``assert``).
    """
    # NOTE(review): this redefines an identical ``dotc`` that appears
    # immediately before it; one of the two definitions can be removed.
    assert ((is_contiguous(a, float) and is_contiguous(b, float)) or
            (is_contiguous(a, complex) and is_contiguous(b, complex)))
    assert a.shape == b.shape
    return _gpaw.dotc(a, b)