示例#1
0
文件: bicgstab.py 项目: qsnake/gpaw
 def multi_zdotc(s, x, y, nvec):
     """Fill ``s`` with per-pair ``dotc`` products, reduce it, return it.

     For each index ``k`` in ``range(nvec)`` this stores
     ``dotc(x[k], y[k])`` into ``s[k]``. It then calls
     ``self.gd.comm.sum(s)`` and returns ``s``, the same object that was
     passed in.

     ``self`` and ``dotc`` are not parameters. They are free names looked
     up in the enclosing scope: presumably the surrounding BiCGStab method's
     instance and a conjugated-dot-product BLAS helper (TODO confirm).
     """
     for k in range(nvec):
         prod = dotc(x[k], y[k])
         s[k] = prod
     # The return value of comm.sum is ignored, so this relies on it
     # reducing ``s`` in place across the communicator (presumably an MPI
     # all-reduce over the grid-domain decomposition; TODO confirm).
     comm = self.gd.comm
     comm.sum(s)
     return s
示例#2
0
# Check gemmdot with floats
# a, a2 and b are defined earlier in the script (outside this excerpt).
# gemmdot, gemm and dotc are presumably gpaw BLAS wrappers imported there
# too (confirm). Each assertion compares a wrapper against the equivalent
# NumPy expression using exact equality.
assert np.all(np.dot(a, b) == gemmdot(a, b))
# trans='t' multiplies by the transpose of the second operand.
assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='t'))
# Conjugation is a no-op on real input, so trans='c' must match the plain
# transpose here.
assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='c'))
# No np.all(): np.dot(b, b) is expected to be a scalar (b presumably 1-D;
# confirm against its definition outside this excerpt).
assert np.dot(b, b) == gemmdot(b, b)

# Check gemmdot with complex arrays
# Rebind the inputs to complex multiples so the conjugating variants are
# actually exercised. Subsequent checks use these complex versions until
# a2 and a are rebound again further down.
a = a * (2 + 1.j)
a2 = a2 * (-1 + 3.j)
b = b * (3 - 2.j)
assert np.all(np.dot(a, b) == gemmdot(a, b))
assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='t'))
# trans='c' conjugates as well as transposes the second operand.
assert np.all(np.dot(a, a2.T.conj()) == gemmdot(a, a2, trans='c'))
assert np.dot(b, b) == gemmdot(b, b, trans='n')
# For the vector case, trans='c' conjugates the second operand.
assert np.dot(b, b.conj()) == gemmdot(b, b, trans='c')
# dotc is checked against np.vdot, which conjugates its FIRST argument and
# flattens multi-dimensional input.
# NOTE(review): exact == on complex floats only holds if every product and
# partial sum is exactly representable (presumably the inputs are small
# integer-valued arrays; confirm). Otherwise np.allclose would be the safer
# comparison.
assert np.vdot(a, 5.j * a) == dotc(a, 5.j * a)

# Check gemm for transa='n'
# a2 becomes a complex 4-D array of shape (7, 5, 1, 3), so a's axis 1 must
# have length 7 for the tensordot below.
a2 = np.arange(7 * 5 * 1 * 3).reshape(7, 5, 1, 3) * (-1. + 4.j) + 3.
# Reference result: contract axis 1 of a with axis 0 of a2.
c = np.tensordot(a, a2, [1, 0])
# With alpha=1 and beta=-1, the test expects gemm to overwrite c in place
# with tensordot(a, a2, [1, 0]) - c, i.e. all zeros. Note that a2 is passed
# before a (presumably gemm follows BLAS column-major operand order;
# confirm against gemm's definition).
gemm(1., a2, a, -1., c, 'n')
assert not c.any()

# Check gemm for transa='c'
# a becomes a complex 4-D array of shape (4, 5, 1, 3), whose trailing axes
# match a2's trailing axes (5, 1, 3).
a = np.arange(4 * 5 * 1 * 3).reshape(4, 5, 1, 3) * (3. - 2.j) + 4.
# Reference result: contract the last three axes of a with those of
# conj(a2), giving shape (4, 7).
c = np.tensordot(a, a2.conj(), [[1, 2, 3], [1, 2, 3]])
# As above, gemm is expected to leave c all zeros (reference minus itself),
# this time conjugating a2 because of transa='c'.
gemm(1., a2, a, -1., c, 'c')
assert not c.any()

# Check axpy
# c starts as 5j * a; the axpy call that uses it follows outside this
# excerpt.
c = 5.j * a
示例#3
0
 def multi_zdotc(s, x, y, nvec):
     """Fill ``s`` with per-pair ``dotc`` products, reduce it, return it.

     For each index ``k`` in ``range(nvec)`` this stores
     ``dotc(x[k], y[k])`` into ``s[k]``. It then calls
     ``self.gd.comm.sum(s)`` and returns ``s``, the same object that was
     passed in.

     ``self`` and ``dotc`` are not parameters. They are free names looked
     up in the enclosing scope: presumably the surrounding BiCGStab method's
     instance and a conjugated-dot-product BLAS helper (TODO confirm).
     """
     for k in range(nvec):
         prod = dotc(x[k], y[k])
         s[k] = prod
     # The return value of comm.sum is ignored, so this relies on it
     # reducing ``s`` in place across the communicator (presumably an MPI
     # all-reduce over the grid-domain decomposition; TODO confirm).
     comm = self.gd.comm
     comm.sum(s)
     return s
示例#4
0
文件: blas.py 项目: qsnake/gpaw
# Check gemmdot with floats
# a, a2 and b are defined earlier in the script (outside this excerpt).
# gemmdot, gemm and dotc are presumably gpaw BLAS wrappers imported there
# too (confirm). Each assertion compares a wrapper against the equivalent
# NumPy expression using exact equality.
assert np.all(np.dot(a, b) == gemmdot(a, b))
# trans='t' multiplies by the transpose of the second operand.
assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='t'))
# Conjugation is a no-op on real input, so trans='c' must match the plain
# transpose here.
assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='c'))
# No np.all(): np.dot(b, b) is expected to be a scalar (b presumably 1-D;
# confirm against its definition outside this excerpt).
assert np.dot(b, b) == gemmdot(b, b)

# Check gemmdot with complex arrays
# Rebind the inputs to complex multiples so the conjugating variants are
# actually exercised. Subsequent checks use these complex versions until
# a2 and a are rebound again further down.
a = a * (2 + 1.j)
a2 = a2 * (-1 + 3.j)
b = b * (3 - 2.j)
assert np.all(np.dot(a, b) == gemmdot(a, b))
assert np.all(np.dot(a, a2.T) == gemmdot(a, a2, trans='t'))
# trans='c' conjugates as well as transposes the second operand.
assert np.all(np.dot(a, a2.T.conj()) == gemmdot(a, a2, trans='c'))
assert np.dot(b, b) == gemmdot(b, b, trans='n')
# For the vector case, trans='c' conjugates the second operand.
assert np.dot(b, b.conj()) == gemmdot(b, b, trans='c')
# dotc is checked against np.vdot, which conjugates its FIRST argument and
# flattens multi-dimensional input.
# NOTE(review): exact == on complex floats only holds if every product and
# partial sum is exactly representable (presumably the inputs are small
# integer-valued arrays; confirm). Otherwise np.allclose would be the safer
# comparison.
assert np.vdot(a, 5.j * a) == dotc(a, 5.j * a)

# Check gemm for transa='n'
# a2 becomes a complex 4-D array of shape (7, 5, 1, 3), so a's axis 1 must
# have length 7 for the tensordot below.
a2 = np.arange(7 * 5 * 1 * 3).reshape(7, 5, 1, 3) * (-1. + 4.j) + 3.
# Reference result: contract axis 1 of a with axis 0 of a2.
c = np.tensordot(a, a2, [1, 0])
# With alpha=1 and beta=-1, the test expects gemm to overwrite c in place
# with tensordot(a, a2, [1, 0]) - c, i.e. all zeros. Note that a2 is passed
# before a (presumably gemm follows BLAS column-major operand order;
# confirm against gemm's definition).
gemm(1., a2, a, -1., c, 'n')
assert not c.any()

# Check gemm for transa='c'
# a becomes a complex 4-D array of shape (4, 5, 1, 3), whose trailing axes
# match a2's trailing axes (5, 1, 3).
a = np.arange(4 * 5 * 1 * 3).reshape(4, 5, 1, 3) * (3. - 2.j) + 4.
# Reference result: contract the last three axes of a with those of
# conj(a2), giving shape (4, 7).
c = np.tensordot(a, a2.conj(), [[1, 2, 3], [1, 2, 3]])
# As above, gemm is expected to leave c all zeros (reference minus itself),
# this time conjugating a2 because of transa='c'.
gemm(1., a2, a, -1., c, 'c')
assert not c.any()

# Check axpy
# c starts as 5j * a; the axpy call that uses it follows outside this
# excerpt.
c = 5.j * a