def check_derivs(self, mode, use_jit, nondiff):
    """
    Check totals of an ImplicitFuncComp that solves a*x**2 + b*x + c = 0.

    With a=1, b=-4, c=3 the root picked by solve_nl is x=3 and dR/dx = 2*a*x + b = 2.
    Implicit differentiation then gives:

    - dx/da = -x**2/2 = -4.5
    - dx/db = -x/2 = -1.5
    - dx/dc = -1/2 = -0.5

    Parameters
    ----------
    mode : str
        Derivative mode passed to ``Problem.setup``.
    use_jit : bool
        Passed through to ``ImplicitFuncComp``.
    nondiff : bool
        If True, the wrapped functions also take the option arguments ``ex1`` and ``ex2``.
    """
    if nondiff:
        def apply_nl(a, b, c, x, ex1, ex2):
            R_x = a * x**2 + b * x + c
            return R_x

        def solve_nl(a, b, c, x, ex1, ex2):
            x = (-b + (b**2 - 4 * a * c)**0.5) / (2 * a)
            return x

        f = (omf.wrap(apply_nl)
             .add_output('x', resid='R_x', val=0.0)
             .declare_option('ex1', default='foo')
             .declare_option('ex2', default='bar')
             .declare_partials(of='*', wrt='*', method='jax'))
    else:
        def apply_nl(a, b, c, x):
            R_x = a * x**2 + b * x + c
            return R_x

        def solve_nl(a, b, c, x):
            x = (-b + (b**2 - 4 * a * c)**0.5) / (2 * a)
            return x

        f = (omf.wrap(apply_nl)
             .add_output('x', resid='R_x', val=0.0)
             .declare_partials(of='*', wrt='*', method='jax'))

    p = om.Problem()
    p.model.add_subsystem(
        'comp', om.ImplicitFuncComp(f, solve_nonlinear=solve_nl, use_jit=use_jit))
    # need this since comp is implicit and doesn't have a solve_linear
    p.model.comp.linear_solver = om.DirectSolver()
    p.setup(check=True, mode=mode)

    p.set_val('comp.a', 1.)
    p.set_val('comp.b', -4.)
    p.set_val('comp.c', 3.)
    p.run_model()

    J = p.compute_totals(wrt=['comp.a', 'comp.b', 'comp.c'], of=['comp.x'])
    assert_near_equal(J['comp.x', 'comp.a'], [[-4.5]])
    assert_near_equal(J['comp.x', 'comp.b'], [[-1.5]])
    assert_near_equal(J['comp.x', 'comp.c'], [[-0.5]])
def test_inout_var(self):
    """Metadata added one variable at a time via add_input/add_output is reported back."""
    def func(a, b, c):
        x = a * b
        y = b * c
        return x, y

    wrapped = omf.wrap(func)
    wrapped = wrapped.add_input('a', units='m')
    wrapped = wrapped.add_input('b', units='inch')
    wrapped = wrapped.add_input('c', units='ft')
    wrapped = wrapped.add_output('x', units='cm')
    wrapped = wrapped.add_output('y', units='km')

    in_meta = list(wrapped.get_input_meta())
    self.assertEqual(list(wrapped.get_input_names()), ['a', 'b', 'c'])
    for idx, units in enumerate(['m', 'inch', 'ft']):
        meta = in_meta[idx][1]
        self.assertEqual(meta['val'], 1.0)
        self.assertEqual(meta['shape'], ())
        self.assertEqual(meta['units'], units)

    out_meta = list(wrapped.get_output_meta())
    self.assertEqual(list(wrapped.get_output_names()), ['x', 'y'])
    # each output depends only on the inputs used to compute it
    for idx, (units, deps) in enumerate([('cm', {'b', 'a'}), ('km', {'b', 'c'})]):
        meta = out_meta[idx][1]
        self.assertEqual(meta['val'], 1.0)
        self.assertEqual(meta['shape'], ())
        self.assertEqual(meta['units'], units)
        self.assertEqual(meta['deps'], deps)
def test_apply_nonlinear_option(self):
    """A declared option reaches apply_nonlinear and Newton still converges with cs partials."""
    def apply_nl(a, b, c, x, opt):
        R_x = a * x**2 + b * x + c
        if opt == 'foo':
            R_x = -R_x
        return R_x

    wrapped = (omf.wrap(apply_nl)
               .add_output('x', resid='R_x', val=0.0)
               .declare_option('opt', default='foo')
               .declare_partials(of='*', wrt='*', method='cs'))

    prob = om.Problem()
    model = prob.model
    model.add_subsystem('comp', om.ImplicitFuncComp(wrapped))
    # the implicit comp has no solve_linear of its own, so a direct solver is required
    model.linear_solver = om.DirectSolver()
    model.nonlinear_solver = om.NewtonSolver(solve_subsystems=False, iprint=0)
    prob.setup()

    for name, val in (('a', 2.), ('b', -8.), ('c', 6.)):
        prob.set_val(f'comp.{name}', val)
    prob.run_model()

    partials = prob.check_partials(includes=['comp'], out_stream=None)
    assert_check_partials(partials, atol=1e-5)
    totals = prob.check_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'],
                               out_stream=None)
    assert_check_totals(totals)
def __init__(self, compute, compute_partials=None, **kwargs):
    """
    Initialize attributes.
    """
    super().__init__(**kwargs)
    self._compute = omf.wrap(compute)
    # in case we're doing jit, force setup of wrapped func because we compute output shapes
    # during setup and that won't work on a jit compiled function
    if self._compute._call_setup:
        self._compute._setup()

    # if the wrapped function reports that it uses jax, turn the use_jax option on
    if self._compute._use_jax:
        self.options['use_jax'] = True

    if self.options['use_jax']:
        if jax is None:
            raise RuntimeError(f"{self.msginfo}: jax is not installed. Try 'pip install jax'.")
        self._compute_jax = omf.jax_decorate(self._compute._f)

    self._tangents = None
    self._compute_partials = compute_partials

    if self.options['use_jax'] and self.options['use_jit']:
        # inputs whose metadata carries 'is_option' (presumably those added via
        # declare_option) are passed to jit as static arguments
        static_argnums = [i for i, m in enumerate(self._compute._inputs.values())
                          if 'is_option' in m]
        try:
            self._compute_jax = jit(self._compute_jax, static_argnums=static_argnums)
        except Exception as err:
            # chain the original error so its traceback is reported as the direct cause
            raise RuntimeError(f"{self.msginfo}: failed jit compile of compute function: "
                               f"{err}") from err
def test_return_names(self):
    """A function without a return statement has no return names."""
    def func(a):
        b = a + 1
        # no return statement

    names = omf.wrap(func).get_return_names()
    self.assertEqual(names, [])
def test_solve_nonlinear(self):
    """A user solve_nonlinear gives a state whose cs partials and totals check out."""
    def apply_nl(a, b, c, x):
        R_x = a * x**2 + b * x + c
        return R_x

    def solve_nl(a, b, c, x):
        x = (-b + (b**2 - 4 * a * c)**0.5) / (2 * a)
        return x

    wrapped = omf.wrap(apply_nl)
    wrapped = wrapped.add_output('x', resid='R_x', val=0.0)
    wrapped = wrapped.declare_partials(of='*', wrt='*', method='cs')

    prob = om.Problem()
    prob.model.add_subsystem('comp', om.ImplicitFuncComp(wrapped, solve_nonlinear=solve_nl))
    # the implicit comp has no solve_linear of its own, so a direct solver is required
    prob.model.linear_solver = om.DirectSolver()
    prob.setup()

    for name, val in (('a', 2.), ('b', -8.), ('c', 6.)):
        prob.set_val(f'comp.{name}', val)
    prob.run_model()

    partials = prob.check_partials(includes=['comp'], out_stream=None)
    assert_check_partials(partials, atol=1e-5)
    totals = prob.check_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'],
                               out_stream=None)
    assert_check_totals(totals)
def test_nometa(self):
    """Without any metadata, inputs and outputs get scalar val=1.0 defaults."""
    def func(a, b, c):
        x = a * b
        y = b * c
        return x, y

    wrapped = omf.wrap(func)

    in_meta = list(wrapped.get_input_meta())
    self.assertEqual(list(wrapped.get_input_names()), ['a', 'b', 'c'])
    for idx in range(3):
        meta = in_meta[idx][1]
        self.assertEqual(meta['val'], 1.0)
        self.assertEqual(meta['shape'], ())

    out_meta = list(wrapped.get_output_meta())
    self.assertEqual(list(wrapped.get_output_names()), ['x', 'y'])
    for idx in range(2):
        meta = out_meta[idx][1]
        self.assertEqual(meta['val'], 1.0)
        self.assertEqual(meta['shape'], ())
def check_derivs(self, mode, use_jit, method):
    """
    Check totals of a vectorized ExplicitFuncComp that also declares two options.

    Parameters
    ----------
    mode : str
        Derivative mode passed to ``Problem.setup``.
    use_jit : bool
        Passed through to ``ExplicitFuncComp``.
    method : str
        Partial-derivative method passed to ``declare_partials``.
    """
    def func(a, b, c, ex1, ex2):
        x = 2. * a * b + 3. * c
        y = 5. * a * c - 2.5 * b
        return x, y

    wrapped = omf.wrap(func).defaults(shape=3)
    wrapped = wrapped.declare_option('ex1', default='foo').declare_option('ex2', default='bar')
    wrapped = wrapped.declare_partials(of='*', wrt='*', method=method)

    prob = om.Problem()
    prob.model.add_subsystem('comp', om.ExplicitFuncComp(wrapped, use_jit=use_jit))
    prob.setup(mode=mode)
    prob.run_model()

    totals = prob.compute_totals(of=['comp.x', 'comp.y'], wrt=['comp.a', 'comp.b', 'comp.c'])

    # all inputs sit at their default of 1.0, so every sub-jac is a scaled identity
    eye = np.eye(3)
    diag_scale = {
        ('comp.x', 'comp.a'): 2.,
        ('comp.x', 'comp.b'): 2.,
        ('comp.x', 'comp.c'): 3.,
        ('comp.y', 'comp.a'): 5.,
        ('comp.y', 'comp.b'): -2.5,
        ('comp.y', 'comp.c'): 5.,
    }
    for key, scale in diag_scale.items():
        assert_near_equal(totals[key], eye * scale)
def test_user_compute_partials_func(self):
    """A user compute_partials that fills sparse (rows/cols) sub-jacs gives correct totals."""
    def J_func(x, y, z, J):
        # the following sub-jacs are 4x4 based on the sizes of foo, bar, x, and y, but the partials
        # were declared specifying rows and cols (in this case sub-jacs are diagonal), so we only
        # store the nonzero values of the sub-jacs, resulting in an actual size of 4 rather than 4x4.
        J['foo', 'x'] = -3*np.log(z)/(3*x+2*y)**2
        J['foo', 'y'] = -2*np.log(z)/(3*x+2*y)**2

        J['bar', 'x'] = 2.*np.ones(4)
        J['bar', 'y'] = np.ones(4)

        # z is a scalar so the true size of this sub-jac is 4x1
        J['foo', 'z'] = 1/(z*(3*x+2*y))

    def func(x=np.zeros(4), y=np.ones(4), z=3):
        foo = np.log(z)/(3*x+2*y)
        bar = 2.*x + y
        return foo, bar

    wrapped = omf.wrap(func).defaults(units='m')
    wrapped = wrapped.add_output('foo', units='1/m', shape=4).add_output('bar', shape=4)
    wrapped = wrapped.declare_partials(of='foo', wrt=('x', 'y'),
                                       rows=np.arange(4), cols=np.arange(4))
    wrapped = wrapped.declare_partials(of='foo', wrt='z')
    wrapped = wrapped.declare_partials(of='bar', wrt=('x', 'y'),
                                       rows=np.arange(4), cols=np.arange(4))

    prob = om.Problem()
    prob.model.add_subsystem('comp', om.ExplicitFuncComp(wrapped, compute_partials=J_func))
    prob.setup(force_alloc_complex=True)
    prob.run_model()

    totals = prob.check_totals(of=['comp.foo', 'comp.bar'],
                               wrt=['comp.x', 'comp.y', 'comp.z'], method='cs')
    assert_check_totals(totals)
def test_abs_complex_step(self):
    """Complex-step partials through abs() have the expected sign on each side of zero."""
    def func(x=-2.0):
        y=2.0*abs(x)
        return y

    wrapped = omf.wrap(func).declare_partials(of='*', wrt='*', method='cs')

    prob = om.Problem()
    comp = prob.model.add_subsystem('C1', om.ExplicitFuncComp(wrapped))
    prob.setup()
    prob.set_solver_print(level=0)
    prob.run_model()

    assert_near_equal(comp._outputs['y'], 4.0, 0.00001)

    # expected dy/dx: +2 for a tiny positive x, -2 for a negative x, and +2 at exactly 0
    for x_val, slope in ((1.0e-10, 2.0), (-3.0, -2.0), (0.0, 2.0)):
        comp._inputs['x'] = x_val
        comp._linearize()
        assert_near_equal(comp._jacobian['y', 'x'], [[slope]], 0.00001)
def test_list_outputs_resids_tol(self):
    """list_outputs with residuals_tol still reports the unconverged balance output."""
    def func(a=2.0, b=5.0, c=3.0, x=np.ones(2)):
        y = a * x ** 2 + b * x + c
        return y

    wrapped = omf.wrap(func).add_output('y', shape=2)

    prob = om.Problem()
    model = prob.model
    model.add_subsystem("quad_1", om.ExplicitFuncComp(wrapped))

    balance = model.add_subsystem("balance", om.BalanceComp())
    balance.add_balance("x_1", val=np.array([1, -1]), rhs_val=np.array([0., 0.]))

    model.connect("balance.x_1", "quad_1.x")
    model.connect("quad_1.y", "balance.lhs:x_1")

    model.linear_solver = om.ScipyKrylov()
    model.nonlinear_solver = om.NewtonSolver(solve_subsystems=False, maxiter=100, iprint=0)

    prob.setup()
    # with maxiter set to 0 the Newton solver performs no iterations before list_outputs runs
    model.nonlinear_solver.options["maxiter"] = 0
    prob.run_model()

    stream = StringIO()
    model.list_outputs(residuals=True, residuals_tol=1e-5, out_stream=stream)
    text = stream.getvalue()
    self.assertIn("balance", text)
    self.assertIn("x_1", text)
def test_coloring1(self):
    """Partial coloring with cs gives correct totals in both fwd and rev modes."""
    mat, inshapes, outshapes = mat_factory(3, 2)
    outsizes = [np.prod(shp) for shp in outshapes]

    def func(a, b, c):
        ivec = np.hstack([a.flat, b.flat, c.flat])
        ovec = mat.dot(ivec)
        x, y = ovec2outs(ovec, outsizes)
        return x, y

    wrapped = omf.wrap(func)
    wrapped = wrapped.add_inputs(a={'shape': inshapes[0]},
                                 b={'shape': inshapes[1]},
                                 c={'shape': inshapes[2]})
    wrapped = wrapped.add_outputs(x={'shape': outshapes[0]}, y={'shape': outshapes[1]})
    wrapped = wrapped.declare_coloring(wrt='*', method='cs', show_summary=False)

    prob = om.Problem()
    prob.model.add_subsystem('comp', om.ExplicitFuncComp(wrapped))

    of_names = ['comp.x', 'comp.y']
    wrt_names = ['comp.a', 'comp.b', 'comp.c']
    for mode in ('fwd', 'rev'):
        prob.setup(mode=mode)
        prob.run_model()
        assert_check_totals(prob.check_totals(of=of_names, wrt=wrt_names, method='cs',
                                              out_stream=None))
def test_defaults(self):
    """defaults() sets metadata on both inputs and outputs, plus the partials/coloring method."""
    def func(a):
        x = a * 2.0
        return x

    wrapped = omf.wrap(func).defaults(units='cm', val=7., method='cs')
    wrapped = wrapped.declare_partials(of='x', wrt='a').declare_coloring(wrt='*')

    in_meta = list(wrapped.get_input_meta())
    self.assertEqual(list(wrapped.get_input_names()), ['a'])
    a_meta = in_meta[0][1]
    self.assertEqual(a_meta['val'], 7.0)
    self.assertEqual(a_meta['units'], 'cm')
    self.assertEqual(a_meta['shape'], ())

    out_meta = list(wrapped.get_output_meta())
    self.assertEqual(list(wrapped.get_output_names()), ['x'])
    x_meta = out_meta[0][1]
    self.assertEqual(x_meta['val'], 7.0)
    self.assertEqual(x_meta['shape'], ())
    self.assertEqual(x_meta['units'], 'cm')

    # the default method propagates into declare_partials and declare_coloring
    self.assertEqual(list(wrapped.get_declare_partials())[0]['method'], 'cs')
    self.assertEqual(wrapped.get_declare_coloring()['method'], 'cs')
def test_infer_outnames_replace_inpname(self):
    """An explicitly added output name replaces a returned name that collides with an input."""
    def func(a, b, c):
        x = a * b
        return x, c

    # 'c' is an input name, so the second return value is named 'q' instead
    wrapped = omf.wrap(func).add_output('q')
    out_names = [name for name, _meta in wrapped.get_output_meta()]
    self.assertEqual(out_names, ['x', 'q'])
def __init__(self, apply_nonlinear, solve_nonlinear=None, linearize=None, solve_linear=None,
             **kwargs):
    """
    Initialize attributes.
    """
    super().__init__(**kwargs)
    self._apply_nonlinear_func = omf.wrap(apply_nonlinear)
    self._solve_nonlinear_func = solve_nonlinear
    self._solve_linear_func = solve_linear
    self._linearize_func = linearize
    self._linearize_info = None
    self._tangents = None
    self._jac2func_inds = None

    # any user-supplied callback replaces the corresponding method on this instance
    if solve_nonlinear:
        self.solve_nonlinear = self._user_solve_nonlinear
    if linearize:
        self.linearize = self._user_linearize
    if solve_linear:
        self.solve_linear = self._user_solve_linear

    # if the wrapped function reports that it uses jax, turn the use_jax option on
    if self._apply_nonlinear_func._use_jax:
        self.options['use_jax'] = True

    # setup requires an undecorated, unjitted function, so do it now
    if self._apply_nonlinear_func._call_setup:
        self._apply_nonlinear_func._setup()

    if self.options['use_jax']:
        if jax is None:
            raise RuntimeError(
                f"{self.msginfo}: jax is not installed. Try 'pip install jax'."
            )
        self._apply_nonlinear_func_jax = omf.jax_decorate(self._apply_nonlinear_func._f)

    if self.options['use_jax'] and self.options['use_jit']:
        # inputs whose metadata carries 'is_option' (presumably those added via
        # declare_option) are passed to jit as static arguments
        static_argnums = [
            i for i, m in enumerate(self._apply_nonlinear_func._inputs.values())
            if 'is_option' in m
        ]
        try:
            with omf.jax_context(self._apply_nonlinear_func._f.__globals__):
                self._apply_nonlinear_func_jax = jit(
                    self._apply_nonlinear_func_jax, static_argnums=static_argnums)
        except Exception as err:
            # the function being compiled is apply_nonlinear, so name it in the message,
            # and chain the original error as the direct cause
            raise RuntimeError(
                f"{self.msginfo}: failed jit compile of apply_nonlinear "
                f"function: {err}") from err
def test_declare_coloring(self):
    """declare_coloring stores its arguments and may only be called once per function."""
    def func(a, b):
        x = a * b
        y = a / b
        return x, y

    wrapped = omf.wrap(func).declare_coloring(wrt='*', method='cs')
    self.assertEqual(wrapped.get_declare_coloring(), {'wrt': '*', 'method': 'cs'})

    with self.assertRaises(Exception) as cm:
        (omf.wrap(func)
         .declare_coloring(wrt='a', method='cs')
         .declare_coloring(wrt='b', method='cs'))

    self.assertEqual(str(cm.exception), "declare_coloring has already been called.")
def test_declare_coloring(self):
    """declare_coloring records its keyword arguments as metadata."""
    def func(a, b):
        x = a * b
        y = a / b
        return x, y

    coloring_meta = omf.wrap(func).declare_coloring(wrt='*', method='cs').get_declare_coloring()
    self.assertEqual(coloring_meta, {'wrt': '*', 'method': 'cs'})
def test_inout_vars(self):
    """Bulk add_inputs/add_outputs metadata, including shape and val, is reported back."""
    def func(a, b, c):
        x = a * b
        y = b * c
        return x, y

    input_specs = {
        'a': {'units': 'm'},
        'b': {'units': 'inch', 'shape': 3, 'val': 7.},
        'c': {'units': 'ft'},
    }
    output_specs = {
        'x': {'units': 'cm', 'shape': 3},
        'y': {'units': 'km', 'shape': 3},
    }
    wrapped = omf.wrap(func).add_inputs(**input_specs).add_outputs(**output_specs)

    in_meta = list(wrapped.get_input_meta())
    self.assertEqual([name for name, _ in in_meta], ['a', 'b', 'c'])
    a_meta, b_meta, c_meta = (meta for _, meta in in_meta)

    self.assertEqual(a_meta['val'], 1.0)
    self.assertEqual(a_meta['shape'], ())
    self.assertEqual(a_meta['units'], 'm')

    # a scalar val combined with shape=3 is broadcast to a full array
    np.testing.assert_allclose(b_meta['val'], np.ones(3) * 7.)
    self.assertEqual(b_meta['shape'], (3, ))
    self.assertEqual(b_meta['units'], 'inch')

    self.assertEqual(c_meta['val'], 1.0)
    self.assertEqual(c_meta['shape'], ())
    self.assertEqual(c_meta['units'], 'ft')

    out_meta = list(wrapped.get_output_meta())
    self.assertEqual([name for name, _ in out_meta], ['x', 'y'])
    for (_, meta), (units, deps) in zip(out_meta, [('cm', {'b', 'a'}), ('km', {'b', 'c'})]):
        np.testing.assert_allclose(meta['val'], np.ones(3))
        self.assertEqual(meta['shape'], (3, ))
        self.assertEqual(meta['units'], units)
        self.assertEqual(meta['deps'], deps)
def test_solve_lin_nl_linearize_reordered_args(self):
    """User solve_nonlinear/linearize/solve_linear work when the state is the first argument."""
    def apply_nl(x, a, b, c):
        R_x = a * x**2 + b * x + c
        return R_x

    def solve_nl(x, a, b, c):
        x = (-b + (b**2 - 4 * a * c)**0.5) / (2 * a)
        return x

    def linearize(x, a, b, c, partials):
        partials['x', 'a'] = x**2
        partials['x', 'b'] = x
        partials['x', 'c'] = 1.0
        partials['x', 'x'] = 2 * a * x + b

        inv_jac = 1.0 / (2 * a * x + b)
        return inv_jac

    def solve_linear(d_x, mode, inv_jac):
        if mode == 'fwd':
            d_x = inv_jac * d_x
            return d_x
        elif mode == 'rev':
            dR_x = inv_jac * d_x
            return dR_x

    wrapped = (omf.wrap(apply_nl)
               .add_output('x', resid='R_x', val=0.0)
               .declare_partials(of='*', wrt='*'))

    comp = om.ImplicitFuncComp(wrapped, solve_linear=solve_linear, linearize=linearize,
                               solve_nonlinear=solve_nl)
    prob = om.Problem()
    prob.model.add_subsystem('comp', comp)
    prob.setup()

    for name, val in (('a', 2.), ('b', -8.), ('c', 6.)):
        prob.set_val(f'comp.{name}', val)
    prob.run_model()

    partials = prob.check_partials(includes=['comp'], out_stream=None)
    assert_check_partials(partials, atol=1e-5)
    totals = prob.check_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'],
                               out_stream=None)
    assert_check_totals(totals)
def test_declare_partials_jax_mixed2(self):
    """Mixing a non-jax method with a later method='jax' declare_partials call is rejected."""
    def func(a, b):
        x = a * b
        y = a / b
        return x, y

    expected_msg = "If multiple calls to declare_partials() are made on the same function object and any set method='jax', then all must set method='jax'."

    with self.assertRaises(Exception) as cm:
        wrapped = omf.wrap(func)
        wrapped = wrapped.declare_partials(of='y', wrt=['a', 'b'], method='fd')
        wrapped.declare_partials(of='x', wrt=['a', 'b'], method='jax')

    self.assertEqual(cm.exception.args[0], expected_msg)
def test_declare_partials(self):
    """Each declare_partials call is recorded, in call order, with its arguments."""
    def func(a, b):
        x = a * b
        y = a / b
        return x, y

    wrapped = omf.wrap(func)
    wrapped = wrapped.declare_partials(of='x', wrt=['a', 'b'], method='cs')
    wrapped = wrapped.declare_partials(of='y', wrt=['a', 'b'], method='fd')

    recorded = list(wrapped.get_declare_partials())
    expected = [
        {'of': 'x', 'wrt': ['a', 'b'], 'method': 'cs'},
        {'of': 'y', 'wrt': ['a', 'b'], 'method': 'fd'},
    ]
    for idx, exp in enumerate(expected):
        self.assertEqual(recorded[idx], exp)
def test_jax_out_shape_compute(self):
    """With method='jax', output shapes are computed from the function when not given."""
    def func(a=np.ones((3,3)), b=np.ones((3,3))):
        x = a * b
        y = (a / b)[:,[1,2]]
        return x, y

    wrapped = omf.wrap(func).declare_partials(of='*', wrt='*', method='jax')
    out_meta = list(wrapped.get_output_meta())
    self.assertEqual(list(wrapped.get_output_names()), ['x', 'y'])

    for idx, (exp_name, exp_shape) in enumerate([('x', (3, 3)), ('y', (3, 2))]):
        self.assertEqual(out_meta[idx][0], exp_name)
        self.assertEqual(out_meta[idx][1]['shape'], exp_shape)
def test_infer_outnames_err(self):
    """An added output name that matches no unnamed return value raises a clear error."""
    def func(a, b, c):
        x = a * b
        y = b * c
        return x, y

    wrapped = omf.wrap(func).add_output('q')

    with self.assertRaises(Exception) as ctx:
        wrapped.get_output_meta()

    expected = "There must be an unnamed return value for every unmatched output name ['q'] but only found 0."
    self.assertEqual(ctx.exception.args[0], expected)
def test_complex_step_multivars(self):
    """
    Check complex-step totals of a multi-input, multi-output function in fwd and rev modes.

    The expected sub-jacobians are the analytic partials evaluated at the function's
    default input values.
    """
    def func(a=np.arange(1,4,dtype=float), b=np.arange(3,6,dtype=float), c=np.arange(5,8,dtype=float)):
        x = a**2 + c * 3.
        y = b * -1.
        z = 1.5 * a + b * b - c
        return x, y, z

    f = (omf.wrap(func)
         .declare_partials(of='*', wrt='*', method='cs')
         .defaults(shape=3))

    prob = om.Problem(om.Group())
    prob.model.add_subsystem('comp', om.ExplicitFuncComp(f))
    prob.set_solver_print(level=0)

    zeros = np.zeros((3, 3))
    eye = np.eye(3)
    expected = {
        ('comp.x', 'comp.a'): np.diag(np.arange(1, 4, dtype=float) * 2.),
        ('comp.x', 'comp.b'): zeros,
        ('comp.x', 'comp.c'): eye * 3.,
        ('comp.y', 'comp.a'): zeros,
        ('comp.y', 'comp.b'): -eye,
        ('comp.y', 'comp.c'): zeros,
        ('comp.z', 'comp.a'): eye * 1.5,
        ('comp.z', 'comp.b'): np.diag(np.arange(3, 6, dtype=float) * 2.),
        ('comp.z', 'comp.c'): -eye,
    }
    of_names = ['comp.x', 'comp.y', 'comp.z']
    wrt_names = ['comp.a', 'comp.b', 'comp.c']

    # the same totals must come out of both derivative modes
    for mode in ('fwd', 'rev'):
        prob.setup(mode=mode)
        prob.run_model()
        J = prob.compute_totals(of=of_names, wrt=wrt_names, return_format='flat_dict')
        for key, val in expected.items():
            assert_near_equal(J[key], val, 0.00001)
def test_jax_out_shape_check(self):
    """A declared output shape that disagrees with the jax-computed shape raises an error."""
    def func(a=np.ones((3,3)), b=np.ones((3,3))):
        x = a * b
        y = (a / b)[:,[1,2]]
        return x, y

    wrapped = omf.wrap(func)
    wrapped = wrapped.add_outputs(x={}, y={'shape': (3,3)})
    wrapped = wrapped.declare_partials(of='*', wrt='*', method='jax')

    with self.assertRaises(Exception) as cm:
        list(wrapped.get_output_meta())

    expected = "shape from metadata for return value 'y' of (3, 3) doesn't match computed shape of (3, 2)."
    self.assertEqual(cm.exception.args[0], expected)
def check_derivs(self, mode, m, n, o, p, q):
    """
    Compare jax totals of an ExplicitFuncComp against an equivalent ExecComp.

    Parameters
    ----------
    mode : str
        Derivative mode passed to ``Problem.setup``.
    m, n, o, p, q : int
        Dimensions. ``a`` is (n, m), ``b`` is (m, o) and ``c`` is (p, q).
        ``x = 2*a.dot(b)`` is (n, o) and ``y = 3*c`` is (p, q).
    """
    def func(a, b, c):
        x = 2. * a.dot(b)
        y = 3. * c
        return x, y

    in_names = ('a', 'b', 'c')
    out_names = ('x', 'y')
    ishapes = {'a': (n, m), 'b': (m, o), 'c': (p, q)}
    oshapes = {'x': (n, o), 'y': (p, q)}

    f = omf.wrap(func).declare_partials(of='*', wrt='*', method='jax')
    for name in in_names:
        f.add_input(name, shape=ishapes[name])
    for name in out_names:
        f.add_output(name, shape=oshapes[name])

    # the same random inputs are fed to both components so their totals are comparable
    rand_inputs = {name: np.random.random(ishapes[name]) for name in in_names}

    of_names = [f"comp.{name}" for name in out_names]
    wrt_names = [f"comp.{name}" for name in in_names]

    def _totals_for(comp):
        # build a one-component problem, set the shared inputs and return its totals
        prob = om.Problem()
        prob.model.add_subsystem('comp', comp)
        prob.setup(mode=mode)
        for name in in_names:
            prob[f"comp.{name}"] = rand_inputs[name]
        prob.run_model()
        return prob.compute_totals(of=of_names, wrt=wrt_names)

    J = _totals_for(om.ExplicitFuncComp(f, use_jax=True))
    Jchk = _totals_for(om.ExecComp(['x=2.*a.dot(b)', 'y=3.*c'],
                                   x={'shape': oshapes['x']},
                                   y={'shape': oshapes['y']},
                                   a={'shape': ishapes['a']},
                                   b={'shape': ishapes['b']},
                                   c={'shape': ishapes['c']}))

    for out in of_names:
        for inp in wrt_names:
            assert_near_equal(J[out, inp], Jchk[out, inp])
def test_array_lhs(self):
    """An output built with np.array from indexed input entries is computed correctly."""
    def func(x=np.array([1., 2., 3.])):
        y=np.array([x[1], x[0]])
        return y

    comp = om.ExplicitFuncComp(omf.wrap(func).add_output('y', shape=2))

    prob = om.Problem()
    C1 = prob.model.add_subsystem('C1', comp)
    prob.setup()
    prob.set_solver_print(level=0)
    prob.run_model()

    expected_y = np.array([2., 1.])
    assert_near_equal(C1._outputs['y'], expected_y, 0.00001)
def test_set_out_names(self):
    """output_names() names unnamed return values, which get scalar defaults."""
    def func(a, b, c):
        return a * b, b * c

    wrapped = omf.wrap(func).output_names('x', 'y')

    out_meta = list(wrapped.get_output_meta())
    self.assertEqual(list(wrapped.get_output_names()), ['x', 'y'])
    for idx in (0, 1):
        meta = out_meta[idx][1]
        self.assertEqual(meta['val'], 1.0)
        self.assertEqual(meta['shape'], ())
def test_declare_option(self):
    """declare_option metadata (types, values, desc) is reported with the input metadata."""
    def func(a, opt):
        if opt == 'foo':
            x = a * 2.0
        else:
            x = a / 2.0
        return x

    wrapped = omf.wrap(func).declare_option('opt', types=str, values=('foo', 'bar'),
                                            desc='an opt')

    # 'opt' is the second function argument, so it is the second input entry
    name, meta = list(wrapped.get_input_meta())[1]
    self.assertEqual(name, 'opt')
    self.assertEqual(meta['types'], str)
    self.assertEqual(meta['values'], ('foo', 'bar'))
    self.assertEqual(meta['desc'], 'an opt')
def check_derivs(self, mode, shape, use_jit):
    """
    Check jax totals of x = 2*a*b + 3*c for inputs of the given shape.

    Inputs sit at their default of 1.0, so every sub-jacobian is a scaled identity
    of size prod(shape).

    Parameters
    ----------
    mode : str
        Derivative mode passed to ``Problem.setup``.
    shape : int or tuple of int
        Shape applied to all variables via ``defaults``.
    use_jit : bool
        Passed through to ``ExplicitFuncComp``.
    """
    def func(a, b, c):
        x = 2. * a * b + 3. * c
        return x

    f = omf.wrap(func).defaults(shape=shape).declare_partials(of='*', wrt='*', method='jax')
    p = om.Problem()
    p.model.add_subsystem('comp', om.ExplicitFuncComp(f, use_jax=True, use_jit=use_jit))
    p.setup(mode=mode)
    p.run_model()
    J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])

    # np.product is deprecated and removed in NumPy 2.0; np.prod is the supported spelling
    I = np.eye(np.prod(shape, dtype=int))
    assert_near_equal(J['comp.x', 'comp.a'], I * 2.)
    assert_near_equal(J['comp.x', 'comp.b'], I * 2.)
    assert_near_equal(J['comp.x', 'comp.c'], I * 3.)