def init(z, b):
    """Build the starting state for the iterative update.

    Args:
      z: Second argument forwarded to the enclosing operator ``A(x, z)``.
      b: Right-hand side; its shape fixes the shape of the initial guess.

    Returns:
      ``(p, r, x, term_err)`` where ``x`` is an all-zero field shaped like
      ``b``, ``r = b - A(x, z)`` is the initial residual, ``p`` is the same
      object as ``r`` (presumably the initial search direction of a CG-style
      method -- confirm), and ``term_err = eps * norm(b)`` is the error
      threshold scaled by the size of ``b``.
    """
    # Termination threshold is relative to the magnitude of the source.
    threshold = eps * vecfield.norm(b)
    guess = vecfield.zeros(b.shape)
    residual = b - A(guess, z)
    return residual, residual, guess, threshold
def test_loop_fns(self):
    """Run ``jaxwell.solve`` for a single iteration on a point source.

    Checks the returned field's type and shape, that exactly one error value
    is reported, and that it matches the recorded reference value.

    NOTE(review): despite its name this exercises ``jaxwell.solve`` rather
    than ``loop_fns``; if it shares a class with another ``test_loop_fns``
    one of them is silently shadowed -- confirm.
    """
    shape = (1, 1, 10, 10, 10)
    src = onp.zeros(shape, onp.complex128)
    # Single unit excitation at the centre voxel, third component only.
    src[0, 0, 5, 5, 5] = 1.
    b = vecfield.VecField(0 * src, 0 * src, src)
    z = vecfield.zeros(shape)

    x, errs = jaxwell.solve(
        z,
        b,
        ths=((10, 10),) * 3,
        pml_params=operators.PmlParams(w_eff=0.3),
        max_iters=1,
    )

    self.assertIsInstance(x, vecfield.VecField)
    self.assertEqual(x.shape, shape)
    self.assertEqual(len(errs), 1)
    self.assertAlmostEqual(errs[0], 35.25115523)
def test_loop_fns(self):
    """Check the init/iterate pair returned by ``solver.loop_fns``.

    The initial state must have ``p`` equal to ``r``, an all-zero ``x`` and a
    termination error of ``1e-6``. One iteration must then produce the
    recorded reference error.
    """
    shape = (1, 1, 10, 10, 10)
    loop_init, loop_iter = solver.loop_fns(
        shape=shape,
        ths=((2, 2),) * 3,
        pml_params=operators.PmlParams(w_eff=0.3),
        eps=1e-6,
    )

    src = onp.zeros(shape, onp.complex128)
    # Single unit excitation at the centre voxel, third component only.
    src[0, 0, 5, 5, 5] = 1.
    b = vecfield.VecField(0 * src, 0 * src, src)
    z = vecfield.zeros(shape)

    p, r, x, term_err = loop_init(z, b)
    onp.testing.assert_array_equal(p, r)
    onp.testing.assert_array_equal(x, onp.zeros_like(x))
    # b holds a single 1, so an eps-relative threshold presumably equals eps.
    self.assertEqual(term_err, 1e-6)

    p, r, x, err = loop_iter(p, r, x, z)
    self.assertAlmostEqual(err, 0.8660254)
def test_fns(self):
    """Check ``cocg.solver``'s init/iterate pair on a single point source.

    Builds an operator ``A`` from ``operators.operator`` plus its
    preconditioners. It then verifies:

    - the initial state has ``p == r``, an all-zero ``x`` and a termination
      error of ``1e-6``;
    - one iteration yields the recorded reference error.
    """
    shape = (1, 1, 10, 10, 10)
    ths = ((2, 2),) * 3
    pml_params = operators.PmlParams(w_eff=0.3)
    # Preconditioners take only the trailing three axes (presumably spatial).
    pre, inv_pre = operators.preconditioners(shape[2:], ths, pml_params)

    def A(x, z):
        return operators.operator(x, z, pre, inv_pre, ths, pml_params)

    b = onp.zeros(shape, onp.complex128)
    # Single unit excitation at the centre voxel, third component only.
    b[0, 0, 5, 5, 5] = 1.
    b = vecfield.VecField(0 * b, 0 * b, b)
    z = vecfield.zeros(shape)

    # Named cocg_init/cocg_iter (not init/iter) so the builtin ``iter`` is not
    # shadowed inside this test.
    cocg_init, cocg_iter = cocg.solver(A, b, eps=1e-6)

    p, r, x, term_err = cocg_init(z, b)
    onp.testing.assert_array_equal(p, r)
    onp.testing.assert_array_equal(x, onp.zeros_like(x))
    self.assertEqual(term_err, 1e-6)

    p, r, x, err = cocg_iter(p, r, x, z)
    self.assertAlmostEqual(err, 0.8660254)
def test_to_tuple(self):
    """The first entry of ``vf.to_tuple`` drops the two leading singleton axes."""
    field = vf.zeros((1, 1, 2, 3, 4))
    first_component = vf.to_tuple(field)[0]
    self.assertEqual(first_component.shape, (2, 3, 4))
def test_zeros(self):
    """``vf.zeros`` keeps the requested shape and produces complex128 data."""
    field = vf.zeros((10, 20, 30))
    self.assertEqual(field.shape, (10, 20, 30))
    self.assertEqual(field.dtype, np.complex128)