def test_eval(self):
    """Fit complex product equations, then check ``eval`` at the solution.

    Four complex (10, 3) arrays x, y, z, w are combined into seven
    expressions whose values form the data.  Starting from a guess scaled
    10-20% away from the truth, the solver is rebuilt around the latest
    solution and re-solved 40 times.  The converged solution must match each
    true array, and ``eval`` evaluated at that solution must reproduce every
    data value, both to 4 decimal places.
    """
    base = np.arange(1, 31)
    truth = {
        'x': (base * (1.0 + 1.0j)).reshape(10, 3),
        'y': (base * (2.0 - 3.0j)).reshape(10, 3),
        'z': (base * (3.0 - 9.0j)).reshape(10, 3),
        'w': (base * (4.0 + 2.0j)).reshape(10, 3),
    }
    expressions = [
        'x*y+z*w',
        '2*x*y+z*w-1.0j*z*w',
        '2*x*w',
        '1.0j*x + y*z',
        '-1*x*z+3*y*w*x+y',
        '2*w',
        '2*x + 3*y - 4*z',
    ]
    # Evaluate every expression against the true arrays only (explicit
    # namespace, so nothing else in this frame can leak into the data).
    data = {expr: eval(expr, {}, truth) for expr in expressions}
    # Initial guess: each variable perturbed by a fixed multiplicative factor.
    start_scale = {'x': 1.1, 'y': .9, 'z': 1.1, 'w': 1.2}
    current = {name: factor * truth[name] for name, factor in start_scale.items()}
    for _ in range(40):
        solver = linsolve.LinProductSolver(data, current, sparse=self.sparse)
        current = solver.solve()
    for name in 'wxyz':
        np.testing.assert_almost_equal(current[name], truth[name], 4)
    model = solver.eval(current)
    for expr, measured in data.items():
        np.testing.assert_almost_equal(measured, model[expr], 4)
def test_sums_of_products(self):
    """Fit sums of products that mix variables with their conjugates.

    In the data expressions a trailing underscore (``x_``, ``y_`` ...) is
    evaluated as the complex conjugate of that variable (presumably the
    solver interprets the suffix the same way).  The solver starts from a
    guess scaled 10-20% away from the truth, is rebuilt and re-solved 20
    times, and must recover x, y, z, w to 4 decimal places.
    """
    base = np.arange(1, 31)
    factors = (
        ('x', 1.0 + 1.0j),
        ('y', 2.0 - 3.0j),
        ('z', 3.0 - 9.0j),
        ('w', 4.0 + 2.0j),
    )
    truth = {}
    for name, factor in factors:
        arr = base * factor
        arr.shape = (10, 3)
        truth[name] = arr
    # Names visible to the expressions: the variables plus their conjugates.
    namespace = dict(truth)
    for name, arr in truth.items():
        namespace[name + '_'] = np.conjugate(arr)
    expressions = [
        'x*y+z*w',
        '2*x_*y_+z*w-1.0j*z*w',
        '2*x*w',
        '1.0j*x + y*z',
        '-1*x*z+3*y*w*x+y',
        '2*w_',
        '2*x_ + 3*y - 4*z',
    ]
    data = {}
    for expr in expressions:
        data[expr] = eval(expr, {}, namespace)
    current = {
        'x': 1.1 * truth['x'],
        'y': .9 * truth['y'],
        'z': 1.1 * truth['z'],
        'w': 1.2 * truth['w'],
    }
    for _ in range(20):
        solver = linsolve.LinProductSolver(data, current, sparse=self.sparse)
        current = solver.solve()
    for name in 'wxyz':
        np.testing.assert_almost_equal(current[name], truth[name], 4)
def test_chisq(self):
    """Check the chi-square left by a deliberately inconsistent system.

    The data require ``x*y == 1`` and ``x*y == 2`` (the latter written as
    ``.5*x*y+.5*x*y``) together with ``y == 1``, so no exact solution exists.
    After 40 re-linearize/solve iterations, the chi-square at the solution is
    expected to be 0.5. That value is consistent with x*y settling at 1.5,
    which leaves residuals of +/-0.5 under unit weights.
    """
    data = {'x*y': 1, '.5*x*y+.5*x*y': 2, 'y': 1}
    current = {'x': 2.3, 'y': .9}
    for _ in range(40):
        solver = linsolve.LinProductSolver(data, current, sparse=self.sparse)
        current = solver.solve()
    chisq = solver.chisq(current)
    np.testing.assert_almost_equal(chisq, .5)
def test_complex_solve(self):
    """Recover complex scalars x, y, z from their pairwise products.

    The initial guess is offset by 0.01 from the truth.  A single ``solve()``
    must then match each true value to 4 decimal places.

    ``sparse=self.sparse`` is passed, as in the other solver tests. This makes
    the test exercise whichever solver mode the test class is configured for.
    """
    truth = {'x': 1 + 1j, 'y': 2 + 2j, 'z': 3 + 2j}
    keys = ['x*y', 'x*z', 'y*z']
    # Evaluate keys against an explicit namespace rather than frame locals.
    d = {k: eval(k, {}, truth) for k in keys}
    w = dict.fromkeys(keys, 1.)
    sol0 = {k: v + .01 for k, v in truth.items()}
    ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
    sol = ls.solve()
    for k in sol:
        self.assertAlmostEqual(sol[k], truth[k], 4)
def test_real_solve(self):
    """Recover real scalars x, y, z from their pairwise products.

    The initial guess is offset by 0.01 from the truth.  A single ``solve()``
    must then match each true value to 4 decimal places.

    ``sparse=self.sparse`` is passed, as in the other solver tests. This makes
    the test exercise whichever solver mode the test class is configured for.
    """
    truth = {'x': 1., 'y': 2., 'z': 3.}
    keys = ['x*y', 'x*z', 'y*z']
    # Evaluate keys against an explicit namespace rather than frame locals.
    d = {k: eval(k, {}, truth) for k in keys}
    w = dict.fromkeys(keys, 1.)
    sol0 = {k: v + .01 for k, v in truth.items()}
    ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
    sol = ls.solve()
    for k in sol:
        self.assertAlmostEqual(sol[k], truth[k], 4)
def test_single_term(self):
    """Solve a system in which one equation involves a single variable.

    Two product equations plus the scaled single-variable term ``2*z`` give
    three equations in x, y and z.  The guess starts 0.01 away from the truth,
    and one ``solve()`` must recover each value to 4 decimal places.
    """
    truth = {'x': 1., 'y': 2., 'z': 3.}
    keys = ['x*y', 'x*z', '2*z']
    d = {}
    w = {}
    for key in keys:
        d[key] = eval(key, {}, truth)
        w[key] = 1.
    sol0 = {name: val + .01 for name, val in truth.items()}
    ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
    sol = ls.solve()
    for name, val in sol.items():
        self.assertAlmostEqual(val, truth[name], 4)
def test_complex_array_solve(self):
    """Recover complex (3, 10) arrays x, y, z from their pairwise products.

    x, y and z are identical complex arrays holding 0..29 reshaped to
    (3, 10).  Every datum gets unit weight, and the guess is offset by 0.01.
    After one ``solve()``, each array must match to 2 decimal places.

    Uses the builtin ``complex`` dtype: the ``np.complex`` alias was
    deprecated in NumPy 1.20 and removed in 1.24.
    """
    x = np.arange(30, dtype=complex).reshape(3, 10)
    y = np.arange(30, dtype=complex).reshape(3, 10)
    z = np.arange(30, dtype=complex).reshape(3, 10)
    d = {'x*y': x * y, 'x*z': x * z, 'y*z': y * z}
    w = {k: np.ones(v.shape) for k, v in d.items()}
    sol0 = {'x': x + .01, 'y': y + .01, 'z': z + .01}
    ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
    # Set explicitly before solving (presumably fixes the parameter ordering
    # the solver uses — TODO confirm against LinProductSolver).
    ls.prm_order = {'x': 0, 'y': 1, 'z': 2}
    sol = ls.solve()
    np.testing.assert_almost_equal(sol['x'], x, 2)
    np.testing.assert_almost_equal(sol['y'], y, 2)
    np.testing.assert_almost_equal(sol['z'], z, 2)
def test_complex_conj_solve(self):
    """Solve products of one variable with another's conjugate.

    Each data key ``a*b_`` carries the value ``a * conj(b)``.  Products of
    this form are unchanged by a common phase rotation of x, y and z, which
    is presumably why the test compares the fitted products, not the values
    themselves, with the data.  The comparison is to 3 decimal places.
    """
    x, y, z = 1. + 1j, 2. + 2j, 3. + 3j
    d = {
        'x*y_': x * y.conjugate(),
        'x*z_': x * z.conjugate(),
        'y*z_': y * z.conjugate(),
    }
    w = {k: 1. for k in d}
    sol0 = {'x': x + .01, 'y': y + .01, 'z': z + .01}
    ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
    # Set explicitly before solving (presumably fixes the parameter ordering
    # the solver uses — TODO confirm against LinProductSolver).
    ls.prm_order = {'x': 0, 'y': 1, 'z': 2}
    sol = ls.solve()
    x, y, z = sol['x'], sol['y'], sol['z']
    self.assertAlmostEqual(x * y.conjugate(), d['x*y_'], 3)
    self.assertAlmostEqual(x * z.conjugate(), d['x*z_'], 3)
    self.assertAlmostEqual(y * z.conjugate(), d['y*z_'], 3)
def test_init(self):
    """Check the linearized system built by the constructor.

    The internal solver ``ls.ls`` exposes ``keys`` that are evaluable
    expression strings. They are presumably written in terms of x, y, z,
    their conjugates (``x_`` ...) and perturbations (``dx``, ``dx_`` ...).
    With every variable set to 1 and every perturbation to 0.001, each key
    must evaluate to 0.002.  ``ls.ls.prms`` must hold exactly three entries.
    """
    x, y, z = 1. + 1j, 2. + 2j, 3. + 3j
    d = {
        'x*y_': x * y.conjugate(),
        'x*z_': x * z.conjugate(),
        'y*z_': y * z.conjugate(),
    }
    w = {k: 1. for k in d}
    sol0 = {'x': x + .01, 'y': y + .01, 'z': z + .01}
    ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse)
    # Rebind the names the linearized key strings are evaluated against.
    x, y, z = 1., 1., 1.
    x_, y_, z_ = 1., 1., 1.
    dx = dy = dz = .001
    dx_ = dy_ = dz_ = .001
    for k in ls.ls.keys:
        # Bare eval must run directly in this frame (not inside a
        # comprehension) so it sees the locals bound just above.
        self.assertAlmostEqual(eval(k), 0.002)
    self.assertEqual(len(ls.ls.prms), 3)