示例#1
0
 def test_eval(self):
     """Solve a complex product system, then check ``eval`` reproduces the data.

     Data are generated by evaluating each expression string against known
     (10, 3) complex arrays.  Starting from a perturbed guess, a new
     ``LinProductSolver`` is built and solved five times, each pass seeded
     with the previous pass's solution.  The final solution must match the
     known arrays, and ``eval`` on that solution must match the data, both
     to 4 decimals.
     """
     # Known answers.  The names x, y, z, w and their conjugates x_, y_, z_,
     # w_ look unused but are read by the eval() calls below, so they must
     # not be renamed.
     counts = np.arange(1, 31)
     x = (counts * (1.0 + 1.0j)).reshape(10, 3)
     y = (counts * (2.0 - 3.0j)).reshape(10, 3)
     z = (counts * (3.0 - 9.0j)).reshape(10, 3)
     w = (counts * (4.0 + 2.0j)).reshape(10, 3)
     x_ = np.conjugate(x)
     y_ = np.conjugate(y)
     z_ = np.conjugate(z)
     w_ = np.conjugate(w)

     expressions = (
         'x*y+z*w',
         '2*x_*y_+z*w-1.0j*z*w',
         '2*x*w',
         '1.0j*x + y*z',
         '-1*x*z+3*y*w*x+y',
         '2*w_',
         '2*x_ + 3*y - 4*z',
     )
     # Plain loop on purpose: eval() must run in this function's scope to
     # see the arrays above.
     data = {}
     for ex in expressions:
         data[ex] = eval(ex)

     guess = dict(x=1.1 * x, y=.9 * y, z=1.1 * z, w=1.2 * w)
     # Only five passes: reducing iters prevents printing a bunch of
     # "The exact solution is  x = 0".
     for _ in range(5):
         solver = linsolve.LinProductSolver(data, guess, sparse=self.sparse)
         guess = solver.solve()

     for var in 'wxyz':
         np.testing.assert_almost_equal(guess[var], eval(var), 4)

     evaluated = solver.eval(guess)
     for eq, expected in data.items():
         np.testing.assert_almost_equal(expected, evaluated[eq], 4)
示例#2
0
 def test_chisq(self):
     """Check the chi-squared reported for an inconsistent system.

     ``'x*y'`` and ``'.5*x*y+.5*x*y'`` are the same product written two
     ways but carry different data (1 and 2), so the three equations cannot
     all be satisfied exactly.  After five solve passes, each seeded with
     the previous pass's solution, ``chisq`` is expected to be 0.5.

     NOTE(review): 0.5 is what unweighted squared residuals give at
     x*y = 1.5, y = 1 (0.25 + 0.25 + 0) -- confirm that this is how
     ``chisq`` is defined.
     """
     d = {'x*y': 1, '.5*x*y+.5*x*y': 2, 'y': 1}
     currentSol = {'x': 2.3, 'y': .9}
     # Only five passes: reducing iters prevents printing a bunch of
     # "The exact solution is  x = 0".
     for _ in range(5):
         testSolve = linsolve.LinProductSolver(d,
                                               currentSol,
                                               sparse=self.sparse)
         currentSol = testSolve.solve()
     chisq = testSolve.chisq(currentSol)
     np.testing.assert_almost_equal(chisq, .5)
示例#3
0
 def test_real_solve(self):
     """Solve a real-valued product system from a slightly offset guess.

     The data are the three pairwise products of x=1, y=2, z=3, all with
     weight 1.  The starting point is each true value plus 0.01.  A single
     ``solve`` call must return every parameter to 4 decimals.
     """
     # x, y, z must keep these names: the final check eval()s each solution
     # key against these locals.
     x, y, z = 1., 2., 3.
     data = {'x*y': x * y, 'x*z': x * z, 'y*z': y * z}
     wgts = dict.fromkeys(data, 1.)
     start = {'x': x + .01, 'y': y + .01, 'z': z + .01}

     solver = linsolve.LinProductSolver(data, start, wgts, sparse=self.sparse)
     solution = solver.solve()
     for name in solution:
         np.testing.assert_almost_equal(solution[name], eval(name), 4)
示例#4
0
 def test_init(self):
     """Check the inner linear system set up by ``LinProductSolver``.

     A solver is built from three conjugate-product equations with a
     starting guess offset by 0.01.  Each key of ``ls.ls.keys`` is then
     evaluated as a Python expression with every parameter (and its
     conjugate) set to 1 and every ``d``-prefixed term set to 0.001; each
     must equal 0.002.  Exactly three parameters are expected in
     ``ls.ls.prms``.

     NOTE(review): presumably each key is a first-order expansion such as
     ``x*dy_ + dx*y_`` (1*0.001 + 0.001*1 = 0.002) -- confirm against
     linsolve.
     """
     x, y, z = 1. + 1j, 2. + 2j, 3. + 3j
     data = {
         'x*y_': x * y.conjugate(),
         'x*z_': x * z.conjugate(),
         'y*z_': y * z.conjugate(),
     }
     wgts = dict.fromkeys(data, 1.)
     start = {'x': x + .01, 'y': y + .01, 'z': z + .01}
     solver = linsolve.LinProductSolver(data, start, wgts, sparse=self.sparse)

     # Rebind the names that the eval() below resolves.  They look unused
     # but are read through the evaluated key strings, so keep them.
     x = y = z = 1.
     x_ = y_ = z_ = 1.
     dx, dy, dz = .001, .001, .001
     dx_, dy_, dz_ = .001, .001, .001
     for lin_key in solver.ls.keys:
         self.assertAlmostEqual(eval(lin_key), 0.002)
     self.assertEqual(len(solver.ls.prms), 3)
示例#5
0
 def test_complex_conj_solve(self):
     """Iteratively solve a conjugate-product system in ``lsqr`` mode.

     The data are ``x*conj(y)``, ``x*conj(z)`` and ``y*conj(z)`` for known
     complex x, y, z, all with weight 1.  The test does not compare the
     solved parameters to the known values; it only requires that the
     solved parameters reproduce the three data products to 3 decimals.
     """
     x, y, z = 1. + 1j, 2. + 2j, 3. + 3j
     data = {
         'x*y_': x * y.conjugate(),
         'x*z_': x * z.conjugate(),
         'y*z_': y * z.conjugate(),
     }
     wgts = dict.fromkeys(data, 1.)
     # Start 0.01 away from the values used to build the data.
     start = {'x': x + .01, 'y': y + .01, 'z': z + .01}

     solver = linsolve.LinProductSolver(data, start, wgts, sparse=self.sparse)
     solver.prm_order = dict(x=0, y=1, z=2)
     _, sol = solver.solve_iteratively(mode='lsqr')  # XXX fails for pinv

     sx, sy, sz = sol['x'], sol['y'], sol['z']
     checks = (
         (sx * sy.conjugate(), 'x*y_'),
         (sx * sz.conjugate(), 'x*z_'),
         (sy * sz.conjugate(), 'y*z_'),
     )
     for product, key in checks:
         np.testing.assert_almost_equal(product, data[key], 3)
示例#6
0
 def test_degen_sol(self):
     """Test how various solvers deal with degenerate solutions.

     The same conjugate-product system is solved iteratively in ``pinv``
     and then ``lsqr`` mode, reusing one solver object.  In each mode the
     solved parameters must reproduce the three data products to
     3 decimals; the parameters themselves are not compared to the values
     used to build the data.
     """
     x, y, z = 1. + 1j, 2. + 2j, 3. + 3j
     data = {
         'x*y_': x * y.conjugate(),
         'x*z_': x * z.conjugate(),
         'y*z_': y * z.conjugate(),
     }
     wgts = dict.fromkeys(data, 1.)
     # Start 0.01 away from the values used to build the data.
     start = {'x': x + .01, 'y': y + .01, 'z': z + .01}

     solver = linsolve.LinProductSolver(data, start, wgts, sparse=self.sparse)
     solver.prm_order = dict(x=0, y=1, z=2)
     for mode in ('pinv', 'lsqr'):
         _, sol = solver.solve_iteratively(mode=mode)
         sx, sy, sz = sol['x'], sol['y'], sol['z']
         checks = (
             (sx * sy.conjugate(), 'x*y_'),
             (sx * sz.conjugate(), 'x*z_'),
             (sy * sz.conjugate(), 'y*z_'),
         )
         for product, key in checks:
             np.testing.assert_almost_equal(product, data[key], 3)
示例#7
0
 def test_solve_iteratively_dtype(self):
     """Check ``solve_iteratively`` preserves dtype and finds the known answer.

     For complex128 and then complex64, the data and the starting guess are
     cast to that dtype and the system is solved with ``conv_crit=1e-7``.
     Every returned parameter must have the input dtype and match the known
     arrays to 4 decimals.
     """
     # Imported locally so this method does not rely on a file-level import.
     from contextlib import redirect_stdout

     # Known answers.  The names x, y, z, w and their conjugates x_, y_, z_,
     # w_ look unused but are read by the eval() calls below, so they must
     # not be renamed.
     counts = np.arange(1, 31)
     x = (counts * (1.0 + 1.0j)).reshape(10, 3)
     y = (counts * (2.0 - 3.0j)).reshape(10, 3)
     z = (counts * (3.0 - 9.0j)).reshape(10, 3)
     w = (counts * (4.0 + 2.0j)).reshape(10, 3)
     x_, y_, z_, w_ = (np.conjugate(x), np.conjugate(y), np.conjugate(z),
                       np.conjugate(w))
     expressions = [
         'x*y+z*w', '2*x_*y_+z*w-1.0j*z*w', '2*x*w', '1.0j*x + y*z',
         '-1*x*z+3*y*w*x+y', '2*w_', '2*x_ + 3*y - 4*z'
     ]
     for dtype in (np.complex128, np.complex64):
         # Plain loop on purpose: eval() must run in this function's scope
         # to see the arrays above.
         data = {}
         for ex in expressions:
             data[ex] = eval(ex).astype(dtype)
         currentSol = {
             'x': (1.1 * x).astype(dtype),
             'y': (.9 * y).astype(dtype),
             'z': (1.1 * z).astype(dtype),
             'w': (1.2 * w).astype(dtype),
         }
         testSolve = linsolve.LinProductSolver(data,
                                               currentSol,
                                               sparse=self.sparse)
         # Swallow the "The exact solution is  x = 0" prints.  The context
         # manager restores sys.stdout even if solve_iteratively raises, so
         # a failure here cannot silence output from later tests.
         with redirect_stdout(io.StringIO()):
             _, new_sol = testSolve.solve_iteratively(conv_crit=1e-7)
         for var in 'wxyz':
             assert new_sol[var].dtype == dtype
             np.testing.assert_almost_equal(new_sol[var], eval(var), 4)