def test_eval(self):
    """LinProductSolver.eval on the recovered solution must reproduce the data.

    Four complex (10, 3) arrays are combined into several nonlinear product
    expressions. The solver is restarted from the previous answer a fixed
    number of times, beginning from a scaled guess. The solution and its
    forward-modelled data are then compared with the truth.
    """
    scales = {'x': 1.0 + 1.0j, 'y': 2.0 - 3.0j, 'z': 3.0 - 9.0j, 'w': 4.0 + 2.0j}
    truth = {name: (np.arange(1, 31) * s).reshape(10, 3) for name, s in scales.items()}
    # Expressions refer to conjugated variables with a trailing underscore.
    namespace = dict(truth)
    namespace.update({name + '_': np.conjugate(arr) for name, arr in truth.items()})
    expressions = [
        'x*y+z*w',
        '2*x_*y_+z*w-1.0j*z*w',
        '2*x*w',
        '1.0j*x + y*z',
        '-1*x*z+3*y*w*x+y',
        '2*w_',
        '2*x_ + 3*y - 4*z',
    ]
    # Explicit namespace so eval does not depend on this frame's locals.
    data = {ex: eval(ex, {}, namespace) for ex in expressions}
    start_factors = {'x': 1.1, 'y': .9, 'z': 1.1, 'w': 1.2}
    currentSol = {name: f * truth[name] for name, f in start_factors.items()}
    # Only a few passes: more would print a bunch of "The exact solution is x = 0".
    for _ in range(5):
        solver = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
        currentSol = solver.solve()
    for name in 'wxyz':
        np.testing.assert_almost_equal(currentSol[name], truth[name], 4)
    modeled = solver.eval(currentSol)
    for key in data:
        np.testing.assert_almost_equal(data[key], modeled[key], 4)
def test_chisq(self):
    """chisq of an inconsistent product system should settle at 0.5.

    'x*y' is set to 1 directly and to 2 via '.5*x*y+.5*x*y', so the three
    equations cannot all be satisfied. After a few solver passes the reported
    chisq is expected to be .5.
    """
    data = {'x*y': 1, '.5*x*y+.5*x*y': 2, 'y': 1}
    currentSol = {'x': 2.3, 'y': .9}
    # reducing iters prevents printing a bunch of "The exact solution is x = 0"
    for _ in range(5):
        solver = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
        currentSol = solver.solve()
    chisq = solver.chisq(currentSol)
    np.testing.assert_almost_equal(chisq, .5)
def test_real_solve(self):
    """Recover real x, y, z from their pairwise products, starting near the truth."""
    truth = dict(x=1., y=2., z=3.)
    keys = ['x*y', 'x*z', 'y*z']
    data = {}
    for key in keys:
        left, right = key.split('*')
        data[key] = truth[left] * truth[right]
    wgts = dict.fromkeys(keys, 1.)
    # Start slightly off the true values.
    sol0 = {name: value + .01 for name, value in truth.items()}
    solver = linsolve.LinProductSolver(data, sol0, wgts, sparse=self.sparse)
    sol = solver.solve()
    for name, value in sol.items():
        np.testing.assert_almost_equal(value, truth[name], 4)
def test_init(self):
    """Check the linearized system that LinProductSolver builds internally.

    After construction, each key of the inner solver (ls.ls.keys) is evaluated
    as a Python expression and must equal 0.002. The inner solver must also
    expose exactly 3 parameters.
    """
    x,y,z = 1.+1j, 2.+2j, 3.+3j
    # Data are products of one variable with another's conjugate; unit weights.
    d,w = {'x*y_':x*y.conjugate(), 'x*z_':x*z.conjugate(), 'y*z_':y*z.conjugate()}, {}
    for k in list(d.keys()): w[k] = 1.
    # Starting guess: truth offset by .01.
    sol0 = {}
    for k in 'xyz': sol0[k] = eval(k)+.01
    ls = linsolve.LinProductSolver(d,sol0,w,sparse=self.sparse)
    # Rebind every name that may appear in the linearized keys so eval(k) below
    # can evaluate them numerically. Do not rename these locals: eval resolves
    # names from this function's local scope. Values are set to 1 and
    # d-prefixed perturbations (and their conjugates) to .001.
    x,y,z = 1.,1.,1.
    x_,y_,z_ = 1.,1.,1.
    dx = dy = dz = .001
    dx_ = dy_ = dz_ = .001
    # Each key is expected to evaluate to 0.002. Presumably each is a
    # first-order expansion with two .001-sized terms, e.g. dx*y_ + x*dy_
    # (TODO confirm the key format in linsolve).
    for k in ls.ls.keys: self.assertAlmostEqual(eval(k), 0.002)
    # Presumably one perturbation parameter per variable x, y, z -- TODO confirm.
    self.assertEqual(len(ls.ls.prms), 3)
def test_complex_conj_solve(self):
    """Iteratively solve complex x, y, z from products of the form a*b_.

    Only the products are compared, since the data constrain products rather
    than the individual values.
    """
    x0, y0, z0 = 1. + 1j, 2. + 2j, 3. + 3j
    data = {
        'x*y_': x0 * y0.conjugate(),
        'x*z_': x0 * z0.conjugate(),
        'y*z_': y0 * z0.conjugate(),
    }
    wgts = {key: 1. for key in data}
    sol0 = {name: value + .01 for name, value in zip('xyz', (x0, y0, z0))}
    solver = linsolve.LinProductSolver(data, sol0, wgts, sparse=self.sparse)
    solver.prm_order = {'x': 0, 'y': 1, 'z': 2}
    # XXX fails for pinv, so only lsqr is exercised here.
    _, sol = solver.solve_iteratively(mode='lsqr')
    for key, measured in data.items():
        # 'a*b_' -> recovered a times conjugate of recovered b.
        left, right = key.rstrip('_').split('*')
        np.testing.assert_almost_equal(sol[left] * sol[right].conjugate(), measured, 3)
def test_degen_sol(self):
    """Each solver mode must fit a degenerate complex product system.

    Only products of one variable with another's conjugate are measured. A
    common phase rotation of x, y, z leaves every product unchanged, so the
    individual values are not unique. The checks therefore compare recovered
    products with the data rather than x, y, z themselves.
    """
    x, y, z = 1. + 1j, 2. + 2j, 3. + 3j
    d = {
        'x*y_': x * y.conjugate(),
        'x*z_': x * z.conjugate(),
        'y*z_': y * z.conjugate(),
    }
    w = {k: 1. for k in d}
    # Perturbed starting point; each mode gets its own copy.
    sol0 = {'x': x + .01, 'y': y + .01, 'z': z + .01}
    for mode in ('pinv', 'lsqr'):
        # Build a fresh solver for every mode. solve_iteratively may update the
        # solver's internal state, so reusing one instance would let the second
        # mode start from the first mode's converged answer instead of sol0.
        ls = linsolve.LinProductSolver(d, dict(sol0), w, sparse=self.sparse)
        ls.prm_order = {'x': 0, 'y': 1, 'z': 2}
        _, sol = ls.solve_iteratively(mode=mode)
        # Distinct names so the true x, y, z above are not shadowed.
        sx, sy, sz = sol['x'], sol['y'], sol['z']
        np.testing.assert_almost_equal(sx * sy.conjugate(), d['x*y_'], 3, err_msg=mode)
        np.testing.assert_almost_equal(sx * sz.conjugate(), d['x*z_'], 3, err_msg=mode)
        np.testing.assert_almost_equal(sy * sz.conjugate(), d['y*z_'], 3, err_msg=mode)
def test_solve_iteratively_dtype(self):
    """solve_iteratively must preserve the input complex dtype and converge.

    Runs once with complex128 and once with complex64 data and starting
    guesses. Checks the solution dtype and its values against the truth.
    """
    from contextlib import redirect_stdout

    x = np.arange(1, 31) * (1.0 + 1.0j)
    x.shape = (10, 3)
    y = np.arange(1, 31) * (2.0 - 3.0j)
    y.shape = (10, 3)
    z = np.arange(1, 31) * (3.0 - 9.0j)
    z.shape = (10, 3)
    w = np.arange(1, 31) * (4.0 + 2.0j)
    w.shape = (10, 3)
    # Conjugates are referenced (as x_, y_, ...) by the eval'd expressions below.
    x_, y_, z_, w_ = list(map(np.conjugate, (x, y, z, w)))
    expressions = [
        'x*y+z*w',
        '2*x_*y_+z*w-1.0j*z*w',
        '2*x*w',
        '1.0j*x + y*z',
        '-1*x*z+3*y*w*x+y',
        '2*w_',
        '2*x_ + 3*y - 4*z',
    ]
    for dtype in (np.complex128, np.complex64):
        # Plain loop (not a comprehension) so eval sees this function's locals.
        data = {}
        for ex in expressions:
            data[ex] = eval(ex).astype(dtype)
        currentSol = {'x': 1.1 * x, 'y': .9 * y, 'z': 1.1 * z, 'w': 1.2 * w}
        currentSol = {k: v.astype(dtype) for k, v in currentSol.items()}
        testSolve = linsolve.LinProductSolver(data, currentSol, sparse=self.sparse)
        # Silence the "The exact solution is x = 0" prints. redirect_stdout
        # restores sys.stdout even if solve_iteratively raises.
        with redirect_stdout(io.StringIO()):
            _, new_sol = testSolve.solve_iteratively(conv_crit=1e-7)
        for var in 'wxyz':
            assert new_sol[var].dtype == dtype
            np.testing.assert_almost_equal(new_sol[var], eval(var), 4)