def __init__(self, **kwargs):
    """
    Initialize the transcription, declare its options and apply user values.

    The options 'num_segments', 'segment_ends', 'order' and 'compressed' are
    declared first, then the ``_declare_options`` and ``initialize`` hooks are
    called, and only afterwards are the keyword arguments applied. Because of
    that ordering, ``kwargs`` may set options declared by either hook.

    Parameters
    ----------
    **kwargs : dict
        Option values passed to ``self.options.update`` once every option has
        been declared.
    """
    # Not built here; presumably populated later (e.g. during setup) — TODO confirm.
    self.grid_data = None

    self.options = OptionsDictionary()

    self.options.declare('num_segments', types=int,
                         desc='Number of segments')
    self.options.declare(
        'segment_ends', default=None, types=(Sequence, np.ndarray),
        allow_none=True,
        desc='Locations of segment ends or None for equally '
             'spaced segments')
    self.options.declare('order', default=3, types=(int, Sequence, np.ndarray),
                         desc='Order of the state transcription')
    self.options.declare(
        'compressed', default=True, types=bool,
        desc='Use compressed transcription, meaning state and control values '
             'at segment boundaries are not duplicated on input. This '
             'implicitly enforces value continuity between segments but in '
             'some cases may make the problem more difficult to solve.')

    # Hooks (presumably overridden by subclasses) run before the user values are
    # applied so that kwargs can target options they declare.
    self._declare_options()
    self.initialize()

    self.options.update(kwargs)
def test_read_only(self):
    """Writing to an option of a read-only OptionsDictionary raises KeyError."""
    frozen = OptionsDictionary(read_only=True)
    frozen.declare('permanent', 3.0)

    with self.assertRaises(KeyError) as raised:
        frozen['permanent'] = 4.0

    # The expected text is used as a regex and searched for in str() of the error.
    assertRegex(self, str(raised.exception),
                "Tried to set read-only option 'permanent'.")
def test_locking(self):
    """Options declared with lock_on_setup reject changes once the class is locked."""
    # OptionsDictionary.locked is set on the class and affects every instance, so
    # restore its prior value afterwards instead of leaking the lock into other tests.
    self.addCleanup(setattr, OptionsDictionary, 'locked',
                    getattr(OptionsDictionary, 'locked', False))

    opt1 = OptionsDictionary()
    opt1.add_option('zzz', 10.0, lock_on_setup=True)
    opt2 = OptionsDictionary()
    opt2.add_option('xxx', 10.0, lock_on_setup=True)

    # Changes are still allowed before locking.
    opt1['zzz'] = 15.0
    opt2['xxx'] = 12.0

    OptionsDictionary.locked = True

    with self.assertRaises(RuntimeError) as err:
        opt1['zzz'] = 14.0
    expected_msg = "The 'zzz' option cannot be changed after setup."
    self.assertEqual(str(err.exception), expected_msg)

    with self.assertRaises(RuntimeError) as err:
        opt2['xxx'] = 13.0
    expected_msg = "The 'xxx' option cannot be changed after setup."
    self.assertEqual(str(err.exception), expected_msg)
class TestOptions(unittest.TestCase):
    """Tests for the add_option-based OptionsDictionary API."""

    def test_options_dictionary(self):
        """Check lookup, type, enum, bounds, removal and deprecation handling."""
        self.options = OptionsDictionary()

        # Make sure we can't address keys we haven't added
        with self.assertRaises(KeyError) as cm:
            self.options['junk']

        self.assertEqual('"Option \'{}\' has not been added"'.format('junk'), str(cm.exception))

        # Type checking - don't set a float with an int
        self.options.add_option('atol', 1e-6)
        self.assertEqual(self.options['atol'], 1.0e-6)

        with self.assertRaises(ValueError) as cm:
            self.options['atol'] = 1

        if PY2:
            self.assertEqual("'atol' should be a '<type 'float'>'", str(cm.exception))
        else:
            self.assertEqual("'atol' should be a '<class 'float'>'", str(cm.exception))

        # Check enum out of range
        self.options.add_option('xyzzy', 0, values=[0, 1, 2, 3])
        for value in [0, 1, 2, 3]:
            self.options['xyzzy'] = value

        with self.assertRaises(ValueError) as cm:
            self.options['xyzzy'] = 4

        self.assertEqual("'xyzzy' must be one of the following values: '[0, 1, 2, 3]'",
                         str(cm.exception))

        # Type checking for boolean
        self.options.add_option('conmin_diff', True)
        self.options['conmin_diff'] = True
        self.options['conmin_diff'] = False

        with self.assertRaises(ValueError) as cm:
            self.options['conmin_diff'] = "YES!"

        if PY2:
            self.assertEqual("'conmin_diff' should be a '<type 'bool'>'", str(cm.exception))
        else:
            self.assertEqual("'conmin_diff' should be a '<class 'bool'>'", str(cm.exception))

        # Test Max and Min
        self.options.add_option('maxiter', 10, lower=0, upper=10)
        for value in range(0, 11):
            self.options['maxiter'] = value

        with self.assertRaises(ValueError) as cm:
            self.options['maxiter'] = 15

        self.assertEqual("maximum allowed value for 'maxiter' is '10'", str(cm.exception))

        with self.assertRaises(ValueError) as cm:
            self.options['maxiter'] = -1

        self.assertEqual("minimum allowed value for 'maxiter' is '0'", str(cm.exception))

        # Make sure we can't do this
        with self.assertRaises(ValueError) as cm:
            self.options.maxiter = -1

        self.assertEqual("Use dict-like access for option 'maxiter'", str(cm.exception))

        # test removal
        self.assertIn('conmin_diff', self.options)
        self.options.remove_option('conmin_diff')
        self.assertNotIn('conmin_diff', self.options)

        # test Deprecation
        self.options._add_deprecation('max_iter', 'maxiter')
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            # Trigger a warning.
            self.options['max_iter'] = 5

            self.assertEqual(len(w), 1)
            self.assertEqual(str(w[0].message),
                             "Option 'max_iter' is deprecated. Use 'maxiter' instead.")

    def test_locking(self):
        """Options declared with lock_on_setup reject changes once the class is locked."""
        # OptionsDictionary.locked is set on the class and affects every instance, so
        # restore its prior value afterwards instead of leaking the lock into other tests.
        self.addCleanup(setattr, OptionsDictionary, 'locked',
                        getattr(OptionsDictionary, 'locked', False))

        opt1 = OptionsDictionary()
        opt1.add_option('zzz', 10.0, lock_on_setup=True)
        opt2 = OptionsDictionary()
        opt2.add_option('xxx', 10.0, lock_on_setup=True)

        # Changes are still allowed before locking.
        opt1['zzz'] = 15.0
        opt2['xxx'] = 12.0

        OptionsDictionary.locked = True

        with self.assertRaises(RuntimeError) as err:
            opt1['zzz'] = 14.0
        expected_msg = "The 'zzz' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)

        with self.assertRaises(RuntimeError) as err:
            opt2['xxx'] = 13.0
        expected_msg = "The 'xxx' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)
def test_bad_option_name(self):
    """Declaring an option whose name is not a valid identifier emits a deprecation warning."""
    warning_text = (
        "'foo:bar' is not a valid python name and will become an invalid option name "
        "in a future release. You can prevent this warning (and future exceptions) by "
        "declaring this option using a valid python name."
    )
    options = OptionsDictionary()
    with assert_warning(OMDeprecationWarning, warning_text):
        options.declare('foo:bar', 1.0)
class TestOptionsDict(unittest.TestCase):
    """Tests for the declare-based OptionsDictionary API and its table rendering."""

    def setUp(self):
        self.dict = OptionsDictionary()

    @unittest.skipIf(tabulate is None, reason="package 'tabulate' is not installed")
    def test_reprs(self):
        """repr() mirrors the underlying dict; str() renders an options table."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        self.assertEqual(repr(self.dict), repr(self.dict._dict))

        def rule(*widths):
            # One run of '=' per column, separated by single spaces.
            return " ".join("=" * width for width in widths)

        # NOTE(review): the cell text below is separated by single spaces, which does not
        # line up with the '=' column widths (e.g. "Option Default" vs a 9-wide first
        # column); the column padding looks collapsed. Verify against the real output.
        narrow_rule = rule(9, 12, 19, 21, 20)
        self.assertEqual(
            self.dict.__str__(width=89),
            '\n'.join([
                narrow_rule,
                "Option Default Acceptable Values Acceptable Types Description",
                narrow_rule,
                "comp MyComp N/A ['ExplicitComponent']",
                "flag False [True, False] ['bool']",
                "long_desc **Required** N/A ['str'] This description is ",
                " long and verbose, so",
                " it takes up multipl",
                " e lines in the optio",
                " ns table.",
                "test **Required** ['a', 'b'] N/A Test integer value",
                narrow_rule,
            ]))

        # if the table can't be represented in specified width, then we get the full width
        # version; the Description column is then as wide as the full long_desc text (89).
        wide_rule = rule(9, 12, 19, 21, 89)
        self.assertEqual(
            self.dict.__str__(width=40),
            '\n'.join([
                wide_rule,
                "Option Default Acceptable Values Acceptable Types Description",
                wide_rule,
                "comp MyComp N/A ['ExplicitComponent']",
                "flag False [True, False] ['bool']",
                "long_desc **Required** N/A ['str'] This description is l"
                "ong and verbose, so it takes up multiple lines in the options table.",
                "test **Required** ['a', 'b'] N/A Test integer value",
                wide_rule,
            ]))

    @unittest.skipIf(tabulate is None, reason="package 'tabulate' is not installed")
    def test_to_table(self):
        """to_table(fmt='github') renders a markdown-style table of the options."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        # Lengths of the dash runs in the header separator, one per column.
        rule = "|" + "|".join("-" * width for width in (11, 14, 21, 23, 91)) + "|"
        # NOTE(review): as in test_reprs, the cell padding here looks collapsed to single
        # spaces (" Option " is narrower than its 11-dash column); verify.
        expected = "\n".join([
            "| Option | Default | Acceptable Values | Acceptable Types "
            "| Description "
            " |",
            rule,
            "| comp | MyComp | N/A | ['ExplicitComponent'] | "
            " "
            " |",
            "| flag | False | [True, False] | ['bool'] | "
            " "
            " |",
            "| long_desc | **Required** | N/A | ['str'] | Th"
            "is description is long and verbose, so it takes up multiple lines in the opti"
            "ons table. |",
            "| test | **Required** | ['a', 'b'] | N/A | Te"
            "st integer value "
            " |",
        ])

        self.assertEqual(self.dict.to_table(fmt='github'), expected)

    @unittest.skipIf(tabulate is None, reason="package 'tabulate' is not installed")
    def test_deprecation_col(self):
        """A Deprecation column appears only while some option declares a deprecation."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.',
                          deprecation='This option is deprecated')

        # Spaces are stripped from the rendered table before comparing, so only the
        # cell text and the dash runs matter here.
        rule = "|" + "|".join("-" * width for width in (11, 14, 21, 23, 91, 27)) + "|"
        expected = "\n".join([
            "|Option|Default|AcceptableValues|AcceptableTypes|Description|Deprecation|",
            rule,
            "|comp|MyComp|N/A|['ExplicitComponent']||N/A|",
            "|flag|False|[True,False]|['bool']||N/A|",
            "|long_desc|**Required**|N/A|['str']|Thisdescriptionislongandverbose,soit"
            "takesupmultiplelinesintheoptionstable.|Thisoptionisdeprecated|",
            "|test|**Required**|['a','b']|N/A|Testintegervalue|N/A|",
        ])

        self.assertEqual(
            self.dict.to_table(fmt='github').replace(" ", ""),
            expected)

        # After redeclaring every option without a deprecation, the table is expected
        # to have no Deprecation column.
        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        rule = "|" + "|".join("-" * width for width in (11, 14, 21, 23, 91)) + "|"
        expected = "\n".join([
            "|Option|Default|AcceptableValues|AcceptableTypes|Description|",
            rule,
            "|comp|MyComp|N/A|['ExplicitComponent']||",
            "|flag|False|[True,False]|['bool']||",
            "|long_desc|**Required**|N/A|['str']|Thisdescriptionislongandverbose,soit"
            "takesupmultiplelinesintheoptionstable.|",
            "|test|**Required**|['a','b']|N/A|Testintegervalue|",
        ])

        self.assertEqual(
            self.dict.to_table(fmt='github').replace(" ", ""),
            expected)

    def test_type_checking(self):
        """Declared types are enforced on assignment, including multiple types and bools."""
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        expected_msg = "Value ('') of option 'test' has type 'str', " \
                       "but type 'int' was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # multiple types are allowed
        self.dict.declare('test_multi', types=(int, float), desc='Test multiple types')

        self.dict['test_multi'] = 1
        self.assertEqual(self.dict['test_multi'], 1)
        self.assertEqual(type(self.dict['test_multi']), int)

        self.dict['test_multi'] = 1.0
        self.assertEqual(self.dict['test_multi'], 1.0)
        self.assertEqual(type(self.dict['test_multi']), float)

        with self.assertRaises(TypeError) as context:
            self.dict['test_multi'] = ''

        expected_msg = "Value ('') of option 'test_multi' has type 'str', " \
                       "but one of types ('int', 'float') was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work and allowed values are populated
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

        meta = self.dict._dict['flag']
        self.assertEqual(meta['values'], (True, False))

    def test_allow_none(self):
        """allow_none=True lets a typed option hold None."""
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        """Either types or values may be given, but not both."""
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(
            str(context.exception),
            "'types' and 'values' were both specified for option 'test3'.")

    def test_check_valid_template(self):
        """The template check_valid function rejects any value."""
        # test the template 'check_valid' function
        from openmdao.utils.options_dictionary import check_valid
        self.dict.declare('test', check_valid=check_valid)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = 1

        expected_msg = "Option 'test' with value 1 is not valid."
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid(self):
        """A custom check_valid callback is invoked on assignment."""
        self.dict.declare('even_test', types=int, check_valid=check_even)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        """Setting an undeclared option raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        """The 'in' operator reports only declared options."""
        self.dict.declare('test')

        contains = 'undeclared' in self.dict
        self.assertFalse(contains)

        contains = 'test' in self.dict
        self.assertTrue(contains)

    def test_update(self):
        """update() assigns values to declared options."""
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        """update() with an undeclared key raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        """Reading an undeclared option raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "Option 'missing' has not been declared."
        self.assertEqual(expected_msg, context.exception.args[0])

    def test_get_default(self):
        """An option returns its default until it is assigned."""
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        """Only values from the declared 'values' list may be assigned."""
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: "\(" and "\[" are invalid escapes in ordinary string literals.
        expected_msg = (
            r"Value \(<object object at 0x[0-9A-Fa-f]+>\) of option 'test' is not one of "
            r"\[<object object at 0x[0-9A-Fa-f]+>,"
            r" <object object at 0x[0-9A-Fa-f]+>\]\.")
        self.assertRegex(str(context.exception), expected_msg)

    def test_read_only(self):
        """Writing to an option of a read-only OptionsDictionary raises KeyError."""
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        # Used as a regex, so the trailing period is escaped.
        expected_msg = r"Tried to set read-only option 'permanent'\."
        self.assertRegex(str(context.exception), expected_msg)

    def test_bounds(self):
        """lower/upper bounds are enforced on assignment."""
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = "Value (3.0) of option 'x' exceeds maximum allowed value of 2.0."
        self.assertEqual(str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = "Value (-3.0) of option 'x' is less than minimum allowed value of 0.0."
        self.assertEqual(str(context.exception), expected_msg)

    def test_undeclare(self):
        """undeclare() removes a previously declared option."""
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "Option 'test' has not been declared."
        self.assertEqual(expected_msg, context.exception.args[0])

    def test_deprecated_option(self):
        """A deprecated option warns on its first access only."""
        msg = 'Option "test1" is deprecated.'
        self.dict.declare('test1', deprecation=msg)

        # test double set
        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = None
        # Should only generate warning first time
        with assert_no_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = None

        # Also test set and then get
        msg = 'Option "test2" is deprecated.'
        self.dict.declare('test2', deprecation=msg)

        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test2'] = None
        # Should only generate warning first time; the read itself is what is checked.
        with assert_no_warning(OMDeprecationWarning, msg):
            self.dict['test2']

    def test_deprecated_tuple_option(self):
        """A (message, alias) deprecation warns once and forwards to the aliased option."""
        msg = 'Option "test1" is deprecated. Use "foo" instead.'
        self.dict.declare('test1', deprecation=(msg, 'foo'))
        self.dict.declare('foo')

        # test double set
        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = 'xyz'
        # Should only generate warning first time
        with assert_no_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = 'zzz'

        with assert_no_warning(OMDeprecationWarning, msg):
            option = self.dict['test1']
        with assert_no_warning(OMDeprecationWarning):
            option2 = self.dict['foo']
        self.assertEqual(option, option2)

        # Also test set and then get
        msg = 'Option "test2" is deprecated. Use "foo2" instead.'
        self.dict.declare('test2', deprecation=(msg, 'foo2'))
        self.dict.declare('foo2')

        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test2'] = 'abcd'

        # Should only generate warning first time
        with assert_no_warning(OMDeprecationWarning, msg):
            option = self.dict['test2']
        with assert_no_warning(OMDeprecationWarning):
            option2 = self.dict['foo2']
        self.assertEqual(option, option2)

        # test bad alias
        msg = 'Option "test3" is deprecated. Use "foo3" instead.'
        self.dict.declare('test3', deprecation=(msg, 'foo3'))

        with self.assertRaises(KeyError) as context:
            self.dict['test3'] = 'abcd'

        expected_msg = "Can't find aliased option 'foo3' for deprecated option 'test3'."
        self.assertEqual(context.exception.args[0], expected_msg)

    def test_bad_option_name(self):
        """Declaring an option whose name is not a valid identifier emits a warning."""
        opt = OptionsDictionary()
        msg = "'foo:bar' is not a valid python name and will become an invalid option name in a future release. You can prevent this warning (and future exceptions) by declaring this option using a valid python name."

        with assert_warning(OMDeprecationWarning, msg):
            opt.declare('foo:bar', 1.0)
def test_options_dictionary(self):
    """Exercise lookup, type, enum, bounds, removal and deprecation of add_option options."""
    self.options = OptionsDictionary()
    opts = self.options

    # Unknown keys cannot be read.
    with self.assertRaises(KeyError) as ctx:
        opts['junk']
    self.assertEqual('"Option \'{}\' has not been added"'.format('junk'), str(ctx.exception))

    # A float-valued option refuses an int.
    opts.add_option('atol', 1e-6)
    self.assertEqual(opts['atol'], 1.0e-6)
    with self.assertRaises(ValueError) as ctx:
        opts['atol'] = 1
    wanted = ("'atol' should be a '<type 'float'>'" if PY2
              else "'atol' should be a '<class 'float'>'")
    self.assertEqual(wanted, str(ctx.exception))

    # Enumerated values: each listed value is accepted, anything else raises.
    opts.add_option('iprint', 0, values=[0, 1, 2, 3])
    for level in range(4):
        opts['iprint'] = level
    with self.assertRaises(ValueError) as ctx:
        opts['iprint'] = 4
    self.assertEqual(
        "'iprint' must be one of the following values: '[0, 1, 2, 3]'",
        str(ctx.exception))

    # Boolean options refuse strings.
    opts.add_option('conmin_diff', True)
    for flag in (True, False):
        opts['conmin_diff'] = flag
    with self.assertRaises(ValueError) as ctx:
        opts['conmin_diff'] = "YES!"
    wanted = ("'conmin_diff' should be a '<type 'bool'>'" if PY2
              else "'conmin_diff' should be a '<class 'bool'>'")
    self.assertEqual(wanted, str(ctx.exception))

    # Bounds: 0..10 inclusive are accepted, values outside raise.
    opts.add_option('maxiter', 10, lower=0, upper=10)
    for count in range(11):
        opts['maxiter'] = count
    with self.assertRaises(ValueError) as ctx:
        opts['maxiter'] = 15
    self.assertEqual("maximum allowed value for 'maxiter' is '10'", str(ctx.exception))
    with self.assertRaises(ValueError) as ctx:
        opts['maxiter'] = -1
    self.assertEqual("minimum allowed value for 'maxiter' is '0'", str(ctx.exception))

    # Attribute-style assignment is rejected in favour of item access.
    with self.assertRaises(ValueError) as ctx:
        opts.maxiter = -1
    self.assertEqual("Use dict-like access for option 'maxiter'", str(ctx.exception))

    # Removal.
    self.assertTrue('conmin_diff' in opts)
    opts.remove_option('conmin_diff')
    self.assertFalse('conmin_diff' in opts)

    # Assigning through a deprecated alias emits one warning.
    opts._add_deprecation('max_iter', 'maxiter')
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        opts['max_iter'] = 5
        self.assertEqual(len(caught), 1)
        self.assertEqual(
            str(caught[0].message),
            "Option 'max_iter' is deprecated. Use 'maxiter' instead.")
def test_options_dictionary(self):
    """Exercise lookup, type, enum and low/high bound handling of add_option options."""
    self.options = OptionsDictionary()
    opts = self.options

    # Unknown keys cannot be read.
    with self.assertRaises(KeyError) as ctx:
        opts['junk']
    self.assertEqual('"Option \'{}\' has not been added"'.format('junk'), str(ctx.exception))

    # A float-valued option refuses an int.
    opts.add_option('atol', 1e-6)
    self.assertEqual(opts['atol'], 1.0e-6)
    with self.assertRaises(ValueError) as ctx:
        opts['atol'] = 1
    wanted = ("'atol' should be a '<type 'float'>'" if PY2
              else "'atol' should be a '<class 'float'>'")
    self.assertEqual(wanted, str(ctx.exception))

    # Enumerated values: each listed value is accepted, anything else raises.
    opts.add_option('iprint', 0, values=[0, 1, 2, 3])
    for level in range(4):
        opts['iprint'] = level
    with self.assertRaises(ValueError) as ctx:
        opts['iprint'] = 4
    self.assertEqual(
        "'iprint' must be one of the following values: '[0, 1, 2, 3]'",
        str(ctx.exception))

    # Boolean options refuse strings.
    opts.add_option('conmin_diff', True)
    for flag in (True, False):
        opts['conmin_diff'] = flag
    with self.assertRaises(ValueError) as ctx:
        opts['conmin_diff'] = "YES!"
    wanted = ("'conmin_diff' should be a '<type 'bool'>'" if PY2
              else "'conmin_diff' should be a '<class 'bool'>'")
    self.assertEqual(wanted, str(ctx.exception))

    # Bounds (low/high keywords): 0..10 inclusive are accepted, values outside raise.
    opts.add_option('maxiter', 10, low=0, high=10)
    for count in range(11):
        opts['maxiter'] = count
    with self.assertRaises(ValueError) as ctx:
        opts['maxiter'] = 15
    self.assertEqual("maximum allowed value for 'maxiter' is '10'", str(ctx.exception))
    with self.assertRaises(ValueError) as ctx:
        opts['maxiter'] = -1
    self.assertEqual("minimum allowed value for 'maxiter' is '0'", str(ctx.exception))

    # Attribute-style assignment is rejected in favour of item access.
    with self.assertRaises(ValueError) as ctx:
        opts.maxiter = -1
    self.assertEqual("Use dict-like access for option 'maxiter'", str(ctx.exception))
def setUp(self):
    """Give every test a fresh, empty OptionsDictionary stored as ``self.dict``."""
    fresh = OptionsDictionary()
    self.dict = fresh
class TestOptionsDict(unittest.TestCase):
    """Tests for the declare-based OptionsDictionary API and its text table."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_reprs(self):
        """repr() mirrors the underlying dict; str() renders an options table."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        self.assertEqual(repr(self.dict), repr(self.dict._dict))

        def rule(*widths):
            # One run of '=' per column, separated by single spaces.
            return " ".join("=" * width for width in widths)

        # NOTE(review): the cell text below is separated by single spaces, which does not
        # line up with the '=' column widths (e.g. "Option Default" vs a 9-wide first
        # column); the column padding looks collapsed. Verify against the real output.
        narrow_rule = rule(9, 12, 17, 21, 20)
        self.assertEqual(self.dict.__str__(width=83), '\n'.join([
            narrow_rule,
            "Option Default Acceptable Values Acceptable Types Description ",
            narrow_rule,
            "comp MyComp N/A ['ExplicitComponent'] ",
            "flag False [True, False] ['bool'] ",
            "long_desc **Required** N/A ['str'] This description is ",
            " long and verbose, so",
            " it takes up multipl",
            " e lines in the optio",
            " ns table.",
            "test **Required** ['a', 'b'] N/A Test integer value ",
            narrow_rule,
        ]))

        # if the table can't be represented in specified width, then we get the full width
        # version; the Description column is then as wide as the full long_desc text (89).
        wide_rule = rule(9, 12, 17, 21, 89) + " "
        self.assertEqual(self.dict.__str__(width=40), '\n'.join([
            wide_rule,
            "Option Default Acceptable Values Acceptable Types Description "
            " ",
            wide_rule,
            "comp MyComp N/A ['ExplicitComponent'] "
            " ",
            "flag False [True, False] ['bool'] "
            " ",
            "long_desc **Required** N/A ['str'] This description is l"
            "ong and verbose, so it takes up multiple lines in the options table. ",
            "test **Required** ['a', 'b'] N/A Test integer value "
            " ",
            wide_rule,
        ]))

    def test_type_checking(self):
        """Declared types are enforced on assignment, including multiple types and bools."""
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        expected_msg = "Value ('') of option 'test' has type 'str', " \
                       "but type 'int' was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # multiple types are allowed
        self.dict.declare('test_multi', types=(int, float), desc='Test multiple types')

        self.dict['test_multi'] = 1
        self.assertEqual(self.dict['test_multi'], 1)
        self.assertEqual(type(self.dict['test_multi']), int)

        self.dict['test_multi'] = 1.0
        self.assertEqual(self.dict['test_multi'], 1.0)
        self.assertEqual(type(self.dict['test_multi']), float)

        with self.assertRaises(TypeError) as context:
            self.dict['test_multi'] = ''

        expected_msg = "Value ('') of option 'test_multi' has type 'str', " \
                       "but one of types ('int', 'float') was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work and allowed values are populated
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

        meta = self.dict._dict['flag']
        self.assertEqual(meta['values'], (True, False))

    def test_allow_none(self):
        """allow_none=True lets a typed option hold None."""
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        """Either types or values may be given, but not both."""
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        """A custom check_valid callback is invoked on assignment."""
        self.dict.declare('even_test', types=int, check_valid=check_even)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):
        """The legacy type_ keyword still works but emits a DeprecationWarning."""
        msg = "In declaration of option 'even_test' the '_type' arg is deprecated. Use 'types' instead."

        with assert_warning(DeprecationWarning, msg):
            self.dict.declare('even_test', type_=int, check_valid=check_even)

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        """Setting an undeclared option raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        """The 'in' operator reports only declared options."""
        self.dict.declare('test')

        contains = 'undeclared' in self.dict
        self.assertFalse(contains)

        contains = 'test' in self.dict
        self.assertTrue(contains)

    def test_update(self):
        """update() assigns values to declared options."""
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        """update() with an undeclared key raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        """Reading an undeclared option raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        """An option returns its default until it is assigned."""
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        """Only values from the declared 'values' list may be assigned."""
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: "\(" and "\[" are invalid escapes in ordinary string literals.
        expected_msg = (
            r"Value \(<object object at 0x[0-9A-Fa-f]+>\) of option 'test' is not one of "
            r"\[<object object at 0x[0-9A-Fa-f]+>,"
            r" <object object at 0x[0-9A-Fa-f]+>\]\.")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        """Writing to an option of a read-only OptionsDictionary raises KeyError."""
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        # Used as a regex, so the trailing period is escaped.
        expected_msg = r"Tried to set read-only option 'permanent'\."
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        """lower/upper bounds are enforced on assignment."""
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = "Value (3.0) of option 'x' exceeds maximum allowed value of 2.0."
        self.assertEqual(str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = "Value (-3.0) of option 'x' is less than minimum allowed value of 0.0."
        self.assertEqual(str(context.exception), expected_msg)

    def test_undeclare(self):
        """undeclare() removes a previously declared option."""
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
class TestOptionsDict(unittest.TestCase):
    """Unit tests for OptionsDictionary declaration, validation and lookup."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_reprs(self):
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        # NOTE(review): this compares the raw return value of __repr__() with the
        # _dict attribute itself, so it only passes if __repr__ returns that dict
        # rather than a string; confirm against the installed OptionsDictionary.
        self.assertEqual(self.dict.__repr__(), self.dict._dict)

        self.assertEqual(self.dict.__str__(width=83), '\n'.join([
            "========= ============ ================= ===================== ====================",
            "Option Default Acceptable Values Acceptable Types Description ",
            "========= ============ ================= ===================== ====================",
            "comp MyComp N/A ['ExplicitComponent'] ",
            "flag False N/A ['bool'] ",
            "long_desc **Required** N/A ['str'] This description is ",
            " long and verbose, so",
            " it takes up multipl",
            " e lines in the optio",
            " ns table.",
            "test **Required** ['a', 'b'] N/A Test integer value ",
            "========= ============ ================= ===================== ====================",
        ]))

        # if the table can't be represented in specified width, then we get the full width version
        self.assertEqual(self.dict.__str__(width=40), '\n'.join([
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
            "Option Default Acceptable Values Acceptable Types Description "
            " ",
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
            "comp MyComp N/A ['ExplicitComponent'] "
            " ",
            "flag False N/A ['bool'] "
            " ",
            "long_desc **Required** N/A ['str'] This description is l"
            "ong and verbose, so it takes up multiple lines in the options table. ",
            "test **Required** ['a', 'b'] N/A Test integer value "
            " ",
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
        ]))

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Option 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
            self.assertEqual(len(w), 1)
            self.assertEqual(str(w[-1].message),
                             "In declaration of option 'even_test' the '_type' arg is deprecated. Use 'types' instead.")

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        contains = 'undeclared' in self.dict
        self.assertTrue(not contains)

        contains = 'test' in self.dict
        self.assertTrue(contains)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings so the regex escapes are not (invalid) Python string escapes.
        expected_msg = (r"Option 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

    def test_undeclare(self):
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
class TranscriptionBase(object):
    """
    Base class for phase transcriptions.

    Declares the options shared by all transcriptions and implements the setup
    steps common to them (time, controls, parameters, boundary constraints and
    objectives).  Transcription-specific methods raise NotImplementedError here
    and must be overridden by subclasses.

    Parameters
    ----------
    **kwargs : dict
        Option values.  They are applied after ``_declare_options`` and
        ``initialize`` have run, so options declared by subclasses can be set too.
    """

    def __init__(self, **kwargs):
        self.grid_data = None

        self.options = OptionsDictionary()

        self.options.declare('num_segments', types=int, desc='Number of segments')
        self.options.declare('segment_ends', default=None, types=(Sequence, np.ndarray),
                             allow_none=True,
                             desc='Locations of segment ends or None for equally '
                                  'spaced segments')
        self.options.declare('order', default=3, types=(int, Sequence, np.ndarray),
                             desc='Order of the state transcription')
        self.options.declare('compressed', default=True, types=bool,
                             desc='Use compressed transcription, meaning state and control values '
                                  'at segment boundaries are not duplicated on input. This '
                                  'implicitly enforces value continuity between segments but in '
                                  'some cases may make the problem more difficult to solve.')

        # Subclass hooks run before kwargs are applied so their options can be set as well.
        self._declare_options()
        self.initialize()

        self.options.update(kwargs)

    def _declare_options(self):
        """Hook for subclasses to declare additional options; does nothing here."""
        pass

    def initialize(self):
        """Hook called from __init__ before kwargs are applied; does nothing here."""
        pass

    def setup_grid(self, phase):
        """
        Setup the GridData object for the Transcription

        Parameters
        ----------
        phase
            The phase to which this transcription applies.
        """
        raise NotImplementedError('Transcription {0} does not implement method '
                                  'setup_grid.'.format(self.__class__.__name__))

    def setup_time(self, phase):
        """
        Setup up the time component and time extents for the phase.

        Time extents that are not inputs to the phase are added as outputs of a
        'time_extents' IndepVarComp, and become design variables unless they are
        fixed.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.

        Returns
        -------
        comps
            A list of the component names needed for time extents.
        """
        time_options = phase.time_options
        time_units = time_options['units']

        default_vals = {'t_initial': time_options['initial_val'],
                        't_duration': time_options['duration_val']}
        indeps = []
        comps = []

        # Warn about invalid options
        phase.check_time_options()

        # Extents flagged as inputs are provided from outside the phase.
        if not time_options['input_initial']:
            indeps.append('t_initial')
        if not time_options['input_duration']:
            indeps.append('t_duration')

        if indeps:
            indep = IndepVarComp()
            for var in indeps:
                indep.add_output(var, val=default_vals[var], units=time_units)
            phase.add_subsystem('time_extents', indep, promotes_outputs=['*'])
            comps += ['time_extents']

        if not (time_options['input_initial'] or time_options['fix_initial']):
            lb, ub = time_options['initial_bounds']
            lb = -INF_BOUND if lb is None else lb
            ub = INF_BOUND if ub is None else ub

            phase.add_design_var('t_initial',
                                 lower=lb,
                                 upper=ub,
                                 scaler=time_options['initial_scaler'],
                                 adder=time_options['initial_adder'],
                                 ref0=time_options['initial_ref0'],
                                 ref=time_options['initial_ref'])

        if not (time_options['input_duration'] or time_options['fix_duration']):
            lb, ub = time_options['duration_bounds']
            lb = -INF_BOUND if lb is None else lb
            ub = INF_BOUND if ub is None else ub

            phase.add_design_var('t_duration',
                                 lower=lb,
                                 upper=ub,
                                 scaler=time_options['duration_scaler'],
                                 adder=time_options['duration_adder'],
                                 ref0=time_options['duration_ref0'],
                                 ref=time_options['duration_ref'])

        return comps

    def setup_controls(self, phase):
        """
        Add a ControlGroup to the phase if the phase has any controls.

        The group's controls:*, control_values:* and control_rates:* variables are
        promoted, and dt_dstau is connected to the group.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.
        """
        phase._check_control_options()

        if phase.control_options:
            control_group = ControlGroup(control_options=phase.control_options,
                                         time_units=phase.time_options['units'],
                                         grid_data=self.grid_data)

            phase.add_subsystem('control_group',
                                subsys=control_group,
                                promotes=['controls:*', 'control_values:*', 'control_rates:*'])

            phase.connect('dt_dstau', 'control_group.dt_dstau')

    def setup_polynomial_controls(self, phase):
        """
        Adds the polynomial control group to the model if any polynomial controls are present.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.
        """
        if phase.polynomial_control_options:
            sys = PolynomialControlGroup(grid_data=self.grid_data,
                                         polynomial_control_options=phase.polynomial_control_options,
                                         time_units=phase.time_options['units'])
            phase.add_subsystem('polynomial_control_group', subsys=sys,
                                promotes_inputs=['*'], promotes_outputs=['*'])

    def setup_design_parameters(self, phase):
        """
        Add an IndepVarComp providing the design parameters, if any, and connect them.

        Parameters flagged with 'opt' are also added as design variables.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.
        """
        phase._check_design_parameter_options()

        if phase.design_parameter_options:
            indep = phase.add_subsystem('design_params', subsys=IndepVarComp(),
                                        promotes_outputs=['*'])

            for name, options in iteritems(phase.design_parameter_options):
                src_name = 'design_parameters:{0}'.format(name)

                if options['opt']:
                    lb = -INF_BOUND if options['lower'] is None else options['lower']
                    ub = INF_BOUND if options['upper'] is None else options['upper']

                    phase.add_design_var(name=src_name,
                                         lower=lb,
                                         upper=ub,
                                         scaler=options['scaler'],
                                         adder=options['adder'],
                                         ref0=options['ref0'],
                                         ref=options['ref'])

                _shape = (1, ) + options['shape']

                indep.add_output(name=src_name,
                                 val=options['val'],
                                 shape=_shape,
                                 units=options['units'])

                for tgts, src_idxs in self.get_parameter_connections(name, phase):
                    phase.connect(src_name, list(tgts), src_indices=src_idxs,
                                  flat_src_indices=True)

    def setup_input_parameters(self, phase):
        """
        Adds a InputParameterComp to allow input parameters to be connected from sources
        external to the phase.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.
        """
        if phase.input_parameter_options:
            passthru = InputParameterComp(input_parameter_options=phase.input_parameter_options)

            phase.add_subsystem('input_params', subsys=passthru, promotes_inputs=['*'],
                                promotes_outputs=['*'])

        for name in phase.input_parameter_options:
            src_name = 'input_parameters:{0}_out'.format(name)

            for tgts, src_idxs in self.get_parameter_connections(name, phase):
                phase.connect(src_name, list(tgts), src_indices=src_idxs,
                              flat_src_indices=True)

    def setup_traj_parameters(self, phase):
        """
        Adds a InputParameterComp (with traj_params=True) to allow trajectory parameters
        to be connected from sources external to the phase.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.
        """
        if phase.traj_parameter_options:
            passthru = InputParameterComp(input_parameter_options=phase.traj_parameter_options,
                                          traj_params=True)

            phase.add_subsystem('traj_params', subsys=passthru, promotes_inputs=['*'],
                                promotes_outputs=['*'])

        for name in phase.traj_parameter_options:
            src_name = 'traj_parameters:{0}_out'.format(name)

            for tgts, src_idxs in self.get_parameter_connections(name, phase):
                # NOTE(review): unlike design and input parameters, flat_src_indices is not
                # passed here; confirm that this is intended.
                phase.connect(src_name, list(tgts), src_indices=src_idxs)

    def setup_states(self, phase):
        raise NotImplementedError('Transcription {0} does not implement method '
                                  'setup_states.'.format(self.__class__.__name__))

    def setup_ode(self, phase):
        raise NotImplementedError('Transcription {0} does not implement method '
                                  'setup_ode.'.format(self.__class__.__name__))

    def setup_timeseries_outputs(self, phase):
        raise NotImplementedError('Transcription {0} does not implement method '
                                  'setup_timeseries_outputs.'.format(self.__class__.__name__))

    def setup_boundary_constraints(self, loc, phase):
        """
        Adds BoundaryConstraintComp for initial and/or final boundary constraints if necessary
        and issues appropriate connections.

        Parameters
        ----------
        loc : str
            The kind of boundary constraints being setup.  Must be one of 'initial' or 'final'.
        phase
            The phase object to which this transcription instance applies.

        Raises
        ------
        ValueError
            If loc is invalid, or if a constraint's lower/upper/equals value has a shape
            incompatible with its indices (when given) or with the constrained variable.
        """
        if loc not in ('initial', 'final'):
            raise ValueError('loc must be one of \'initial\' or \'final\'.')

        bc_comp = None
        bc_dict = phase._initial_boundary_constraints \
            if loc == 'initial' else phase._final_boundary_constraints

        if bc_dict:
            bc_comp = phase.add_subsystem('{0}_boundary_constraints'.format(loc),
                                          subsys=BoundaryConstraintComp(loc=loc))

        for var, options in iteritems(bc_dict):
            con_name = options['constraint_name']

            # Constraint options are a copy of options with constraint_name key removed.
            con_options = options.copy()
            con_options.pop('constraint_name')

            src, shape, units, linear = self._get_boundary_constraint_src(var, loc, phase)

            con_units = options.get('units', None)

            # Prefer the shape reported by the source; fall back to the option, then scalar.
            shape = options['shape'] if shape is None else shape
            if shape is None:
                shape = (1, )

            self._check_boundary_constraint_bounds(var, options, shape)

            size = int(np.prod(shape))
            con_options['shape'] = shape
            con_options['units'] = units if con_units is None else con_units
            con_options['linear'] = linear

            # Build the correct src_indices regardless of shape: the first `size` flat
            # elements of the source for 'initial', the last `size` for 'final'.
            if loc == 'initial':
                src_idxs = np.arange(size, dtype=int).reshape(shape)
            else:
                src_idxs = np.arange(-size, 0, dtype=int).reshape(shape)

            bc_comp._add_constraint(con_name, **con_options)

            phase.connect(src,
                          '{0}_boundary_constraints.{0}_value_in:{1}'.format(loc, con_name),
                          src_indices=src_idxs,
                          flat_src_indices=True)

    @staticmethod
    def _check_boundary_constraint_bounds(var, options, shape):
        """
        Raise ValueError if lower/upper/equals of a boundary constraint have an invalid shape.

        None and scalar (0-d) values are always accepted.  Other values must have shape
        ``(len(options['indices']),)`` when indices are given, and ``shape`` otherwise.
        Values are tested with ``is None`` rather than truthiness so that numpy arrays
        do not raise an ambiguous-truth-value error.

        Parameters
        ----------
        var : str
            Name of the constrained variable, used in error messages.
        options : dict
            Boundary constraint options with 'indices', 'lower', 'upper' and 'equals' keys.
        shape : tuple
            Shape of the constrained variable.
        """
        if options['indices'] is not None:
            expected_shape = (len(options['indices']), )
            problem = (' not compatible with the provided indices. Provide them as a '
                       'flat array with the same size as indices.')
        else:
            expected_shape = shape
            problem = ' not compatible with its shape, and no indices were provided.'

        subjects = (('lower', 'The lower bounds of boundary constraint on {0} are'),
                    ('upper', 'The upper bounds of boundary constraint on {0} are'),
                    ('equals', 'The equality boundary constraint value on {0} is'))

        for key, subject in subjects:
            value = options[key]
            if value is None or np.ndim(value) == 0:
                continue
            if np.asarray(value).shape != expected_shape:
                raise ValueError((subject + problem).format(var))

    def setup_objective(self, phase):
        """
        Find the path of the objective(s) and add the objective using the standard OpenMDAO
        method.

        Parameters
        ----------
        phase
            The phase object to which this transcription applies.

        Raises
        ------
        ValueError
            If a non-scalar objective has no index, the index is out of range, or the
            objective loc is not 'initial' or 'final'.
        """
        for name, options in iteritems(phase._objectives):
            index = options['index']
            loc = options['loc']

            obj_path, shape, units, _ = self._get_boundary_constraint_src(name, loc, phase)

            shape = options['shape'] if shape is None else shape
            if shape is None:
                # Same scalar assumption used for boundary constraints.
                shape = (1, )

            size = int(np.prod(shape))
            obj_index = self._get_objective_index(index, size, shape, loc)

            from ..phase import Phase
            # super(Phase, phase) resolves add_objective past Phase in the MRO, bypassing
            # any override defined on Phase itself.
            super(Phase, phase).add_objective(obj_path, ref=options['ref'], ref0=options['ref0'],
                                              index=obj_index, adder=options['adder'],
                                              scaler=options['scaler'],
                                              parallel_deriv_color=options['parallel_deriv_color'],
                                              vectorize_derivs=options['vectorize_derivs'])

    @staticmethod
    def _get_objective_index(index, size, shape, loc):
        """
        Return the flat source index of an objective at the given boundary location.

        Parameters
        ----------
        index : int or None
            Index into the flattened objective variable; negative values count from the
            end.  May be None only when ``size`` is 1.
        size : int
            Number of elements of the objective variable.
        shape : tuple
            Shape of the objective variable, used only in error messages.
        loc : str
            'initial' or 'final'.

        Returns
        -------
        int
            The non-negative index for 'initial', or that index offset by ``-size``
            (a negative index counting from the end of the source) for 'final'.

        Raises
        ------
        ValueError
            If index is None for a non-scalar variable, if index lies outside
            ``[-size, size)``, or if loc is invalid.
        """
        if size > 1 and index is None:
            raise ValueError('Objective variable is non-scalar {0} but no index specified '
                             'for objective'.format(shape))

        idx = 0 if index is None else index

        # Validate before normalizing negative indices; normalizing first would wrap an
        # index below -size into a still-negative value that passes the check.
        if idx >= size or idx < -size:
            raise ValueError('Objective index={0}, but the shape of the objective '
                             'variable is {1}'.format(index, shape))
        if idx < 0:
            idx += size

        if loc == 'final':
            return -size + idx
        if loc == 'initial':
            return idx
        raise ValueError('Invalid value for objective loc: {0}. Must be '
                         'one of \'initial\' or \'final\'.'.format(loc))

    def _get_boundary_constraint_src(self, name, loc, phase):
        raise NotImplementedError('Transcription {0} does not implement method '
                                  '_get_boundary_constraint_src.'.format(self.__class__.__name__))

    def _get_rate_source_path(self, name, loc, phase):
        raise NotImplementedError('Transcription {0} does not implement method '
                                  '_get_rate_source_path.'.format(self.__class__.__name__))

    def get_parameter_connections(self, name, phase):
        """
        Returns a list containing tuples of each path and related indices to which the
        given parameter name is to be connected.

        Parameters
        ----------
        name : str
            The name of the parameter for which connection information is desired.
        phase
            The phase object to which this transcription applies.

        Returns
        -------
        connection_info : list of (paths, indices)
            A list containing a tuple of target paths and corresponding src_indices to which the
            given parameter is to be connected.
        """
        raise NotImplementedError('Transcription {0} does not implement method '
                                  'get_parameter_connections.'.format(self.__class__.__name__))

    def check_config(self, phase, logger):
        """
        Log a warning for each ODE output whose shape is unknown and assumed scalar.

        Path constraints, initial and final boundary constraints, and timeseries outputs
        are checked.  Variables that phase.classify_var does not classify as 'ode'
        (states, time, controls, etc.) are skipped.

        Parameters
        ----------
        phase
            The phase object to which this transcription applies.
        logger
            Logger that receives the warnings.
        """
        checks = (('path constraint', phase._path_constraints),
                  ('boundary constraint', phase._initial_boundary_constraints),
                  ('boundary constraint', phase._final_boundary_constraints),
                  ('timeseries output', phase._timeseries_outputs))

        for kind, var_options in checks:
            for var, options in iteritems(var_options):
                if phase.classify_var(var) != 'ode':
                    continue
                if options['shape'] is None:
                    logger.warning('Unable to infer shape of {0} \'{1}\' in phase \'{2}\'. '
                                   'Scalar assumed. If this ODE output is not scalar, '
                                   'connection errors will result.'.format(kind, var, phase.name))
class TestOptionsDict(unittest.TestCase):
    """Unit tests for OptionsDictionary declaration, validation and lookup."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Option 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
            self.assertEqual(len(w), 1)
            self.assertEqual(str(w[-1].message),
                             "In declaration of option 'even_test' the '_type' arg is deprecated. Use 'types' instead.")

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        contains = 'undeclared' in self.dict
        self.assertTrue(not contains)

        contains = 'test' in self.dict
        self.assertTrue(contains)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings so the regex escapes are not (invalid) Python string escapes.
        expected_msg = (r"Option 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

    def test_undeclare(self):
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
def __init__(self, **kwargs):
    """
    Initialize the driver and set its 'generator' option to a SALib DOE generator.

    Parameters
    ----------
    **kwargs : dict
        Driver option values.  'sa_method_name' selects the method ("Morris" or
        "Sobol") and 'sa_doe_options' holds the method-specific settings.

    Raises
    ------
    RuntimeError
        If SALib is not installed or the method name is not recognized.
    """
    super(SalibDOEDriver, self).__init__()

    if SALIB_NOT_INSTALLED:
        raise RuntimeError("SALib library is not installed. "
                           "cf. https://salib.readthedocs.io/en/latest/getting-started.html")

    self.options.declare("sa_method_name", default="Morris", values=["Morris", "Sobol"],
                         desc="either Morris or Sobol")
    self.options.declare("sa_doe_options", types=dict, default={},
                         desc="options for given SALib sensitivity analysis method")
    self.options.update(kwargs)

    # Method-specific settings are validated by their own OptionsDictionary
    # before being handed to the generator.
    self.sa_settings = OptionsDictionary()
    method = self.options["sa_method_name"]

    if method == "Morris":
        self.sa_settings.declare("n_trajs", types=int, default=2,
                                 desc="number of trajectories to apply morris method")
        self.sa_settings.declare("n_levels", types=int, default=4,
                                 desc="number of grid levels")
        self.sa_settings.update(self.options["sa_doe_options"])

        n_trajs = self.sa_settings["n_trajs"]
        n_levels = self.sa_settings["n_levels"]
        self.options["generator"] = SalibMorrisDOEGenerator(n_trajs, n_levels)

    elif method == "Sobol":
        self.sa_settings.declare("n_samples", types=int, default=500,
                                 desc="number of samples to generate")
        self.sa_settings.declare("calc_second_order", types=bool, default=True,
                                 desc="calculate second-order sensitivities")
        self.sa_settings.update(self.options["sa_doe_options"])

        n_samples = self.sa_settings["n_samples"]
        calc_snd = self.sa_settings["calc_second_order"]
        self.options["generator"] = SalibSobolDOEGenerator(n_samples, calc_snd)

    else:
        # Unreachable while 'values' restricts sa_method_name; kept as a safeguard.
        raise RuntimeError("Bad sensitivity analysis method name '{}'".format(method))
class TestOptionsDict(unittest.TestCase):
    """Unit tests for OptionsDictionary using the ``type_`` declaration keyword."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_type_checking(self):
        self.dict.declare('test', type_=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Entry 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, type_=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', type_=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        self.dict.declare('test3', type_=int, values=['a', 'b'])
        self.dict['test3'] = 1
        self.assertEqual(self.dict['test3'], 1)
        self.dict['test3'] = 'a'
        self.assertEqual(self.dict['test3'], 'a')

    def test_isvalid(self):
        self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        contains = 'undeclared' in self.dict
        self.assertTrue(not contains)

        contains = 'test' in self.dict
        self.assertTrue(contains)

    def test_update(self):
        self.dict.declare('test', default='Test value', type_=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Entry 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, type_=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings so the regex escapes are not (invalid) Python string escapes.
        expected_msg = (r"Entry 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)
from openmdao.api import OptionsDictionary

# Options controlling Dymos-wide behavior: partials checking, the plotting
# backend, and notebook-enhanced output.
options = OptionsDictionary()

options.declare('include_check_partials',
                types=bool,
                default=True,
                desc='If True, include dymos components when checking partials.')

options.declare('plots',
                values=['matplotlib', 'bokeh'],
                default='matplotlib',
                desc='The plot library used to generate output plots for Dymos.')

options.declare('notebook_mode',
                types=bool,
                default=False,
                desc='If True, provide notebook-enhanced plots and outputs.')
class TestOptionsDict(unittest.TestCase):
    """Tests for declaring, setting, validating and printing OptionsDictionary entries."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_reprs(self):
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        self.assertEqual(repr(self.dict), repr(self.dict._dict))

        self.assertEqual(
            self.dict.__str__(width=83),
            '\n'.join([
                "========= ============ ================= ===================== ====================",
                "Option Default Acceptable Values Acceptable Types Description ",
                "========= ============ ================= ===================== ====================",
                "comp MyComp N/A ['ExplicitComponent'] ",
                "flag False [True, False] ['bool'] ",
                "long_desc **Required** N/A ['str'] This description is ",
                " long and verbose, so",
                " it takes up multipl",
                " e lines in the optio",
                " ns table.",
                "test **Required** ['a', 'b'] N/A Test integer value ",
                "========= ============ ================= ===================== ====================",
            ]))

        # if the table can't be represented in specified width, then we get the full width version
        self.assertEqual(
            self.dict.__str__(width=40),
            '\n'.join([
                "========= ============ ================= ===================== ====================="
                "==================================================================== ",
                "Option Default Acceptable Values Acceptable Types Description "
                " ",
                "========= ============ ================= ===================== ====================="
                "==================================================================== ",
                "comp MyComp N/A ['ExplicitComponent'] "
                " ",
                "flag False [True, False] ['bool'] "
                " ",
                "long_desc **Required** N/A ['str'] This description is l"
                "ong and verbose, so it takes up multiple lines in the options table. ",
                "test **Required** ['a', 'b'] N/A Test integer value "
                " ",
                "========= ============ ================= ===================== ====================="
                "==================================================================== ",
            ]))

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        expected_msg = ("Value ('') of option 'test' has type 'str', "
                        "but type 'int' was expected.")
        self.assertEqual(expected_msg, str(context.exception))

        # multiple types are allowed
        self.dict.declare('test_multi', types=(int, float), desc='Test multiple types')

        self.dict['test_multi'] = 1
        self.assertEqual(self.dict['test_multi'], 1)
        self.assertIs(type(self.dict['test_multi']), int)

        self.dict['test_multi'] = 1.0
        self.assertEqual(self.dict['test_multi'], 1.0)
        self.assertIs(type(self.dict['test_multi']), float)

        with self.assertRaises(TypeError) as context:
            self.dict['test_multi'] = ''

        expected_msg = ("Value ('') of option 'test_multi' has type 'str', "
                        "but one of types ('int', 'float') was expected.")
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work and allowed values are populated
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

        meta = self.dict._dict['flag']
        self.assertEqual(meta['values'], (True, False))

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertIsNone(self.dict['test'])

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(
            str(context.exception),
            "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, check_valid=check_even)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):
        msg = "In declaration of option 'even_test' the '_type' arg is deprecated. Use 'types' instead."

        with assert_warning(DeprecationWarning, msg):
            self.dict.declare('even_test', type_=int, check_valid=check_even)

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: the backslashes are regex escapes, not Python string escapes.
        expected_msg = (
            r"Value \(<object object at 0x[0-9A-Fa-f]+>\) of option 'test' is not one of "
            r"\[<object object at 0x[0-9A-Fa-f]+>, <object object at 0x[0-9A-Fa-f]+>\].")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set read-only option 'permanent'.")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = "Value (3.0) of option 'x' exceeds maximum allowed value of 2.0."
        self.assertEqual(str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = "Value (-3.0) of option 'x' is less than minimum allowed value of 0.0."
        self.assertEqual(str(context.exception), expected_msg)

    def test_undeclare(self):
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
from openmdao.api import OptionsDictionary

# Dymos-wide options; currently only controls whether dymos components
# participate in partial-derivative checks.
options = OptionsDictionary()

options.declare(
    'include_check_partials',
    types=bool,
    default=False,
    desc='If True, include dymos components when checking partials.',
)
def add_linkage(self, name, vars, shape=(1, ), equals=None, lower=None, upper=None,
                units=None, scaler=None, adder=None, ref0=None, ref=None, linear=False):
    """
    Add a linkage constraint to be managed by this component.

    .. math ::

        C_n = y_{n1} - y_{n0}

    where :math:`y_1` is the value of the variable at the beginning or end of phase 1, and
    :math:`y_0` is the value of the variable at the beginning or end of phase 0.

    The location of the source of the constraint can be set by the user based on connected
    indices.

    The name of each linkage constraint will be LNK_var where LNK is the name of the linkage
    and var is each of the vars in that linkage.

    Parameters
    ----------
    name : str
        The name of one or more linkage constraints to be added.
    vars : str or iterable
        The name of one or more linked variables to be added. Any iterable (including a
        one-shot generator) is accepted; it is materialized once so it can be traversed
        more than once below.
    shape : tuple
        The shape of the constraint being formed. Must be compliant with the shape of the
        variable.
    equals : float or ndarray
        The prescribed difference of y_1 - y_0, enforced by the optimizer. If equals, lower
        and upper are all None, equals defaults to ``np.zeros(shape)``.
    lower : float or ndarray
        The minimum allowable difference of y_1 - y_0, enforced by the optimizer.
    upper : float or ndarray
        The maximum allowable difference of y_1 - y_0, enforced by the optimizer.
    units : str, dict, or None
        The units of the linkage constraint. If given as a string, the units will apply to
        each variable in vars. If given as a dict, it should be keyed with variables in vars,
        and the associated value being the corresponding units; variables missing from the
        dict get None. Default is None.
    scaler : float, ndarray, or None
        The scaler applied to this constraint by the driver.
    adder : float, ndarray, or None
        The adder applied to this constraint by the driver.
    ref0 : float, ndarray, or None
        The zero-reference value of this constraint, used for scaling by the driver.
    ref : float, ndarray, or None
        The one-reference value of this constraint, used for scaling by the driver.
    linear : bool
        If True, this constraint will be treated as a linear constraint by the optimizer.
        This should only be done if the *total derivative* of the constraint is linear. That
        is, the affected variables in each phase are design variables or linear functions of
        design variables. Default is False.
    """
    if equals is None and lower is None and upper is None:
        # No bound supplied: default to the equality constraint y_1 - y_0 = 0.
        equals = np.zeros(shape)

    # Normalize to a tuple. A plain iterable is materialized here because it is walked
    # twice (once to build _units, once to declare linkages); a generator would otherwise
    # be exhausted by the first pass and no linkages would be added.
    if isinstance(vars, string_types):
        _vars = (vars, )
    else:
        _vars = tuple(vars)

    # A single unit string (or None) applies to every var; otherwise units is a var->units map.
    if isinstance(units, string_types) or units is None:
        _units = {var: units for var in _vars}
    else:
        _units = units

    for var in _vars:
        lnk = OptionsDictionary()

        # Declare the fields of a single linkage entry before assigning them.
        lnk.declare('name', types=(string_types, ))
        lnk.declare('equals', types=(float, np.ndarray), allow_none=True)
        lnk.declare('lower', types=(float, np.ndarray), allow_none=True)
        lnk.declare('upper', types=(float, np.ndarray), allow_none=True)
        lnk.declare('units', types=string_types, allow_none=True)
        lnk.declare('scaler', types=(float, np.ndarray), allow_none=True)
        lnk.declare('adder', types=(float, np.ndarray), allow_none=True)
        lnk.declare('ref0', types=(float, np.ndarray), allow_none=True)
        lnk.declare('ref', types=(float, np.ndarray), allow_none=True)
        lnk.declare('linear', types=bool)
        lnk.declare('shape', types=tuple)
        lnk.declare('cond0_name', types=string_types)
        lnk.declare('cond1_name', types=string_types)

        lnk['name'] = '{0}_{1}'.format(name, var)
        lnk['equals'] = equals
        lnk['lower'] = lower
        lnk['upper'] = upper
        lnk['scaler'] = scaler
        lnk['adder'] = adder
        lnk['ref0'] = ref0
        lnk['ref'] = ref
        lnk['shape'] = shape
        lnk['linear'] = linear
        lnk['units'] = _units.get(var, None)

        # The two sides of the linkage are named '<linkage>_<var>:lhs' / ':rhs'.
        lnk['cond0_name'] = '{0}:lhs'.format(lnk['name'])
        lnk['cond1_name'] = '{0}:rhs'.format(lnk['name'])

        self.options['linkages'].append(lnk)
def test_options_dictionary(self):
    # Exercise add_option's type, enumerated-value and bounds checks, plus the rule that
    # options may only be assigned with dict-style access.
    opts = OptionsDictionary()
    self.options = opts

    # Reading a key that was never added is a KeyError.
    with self.assertRaises(KeyError) as cm:
        opts['junk']
    self.assertEqual('"Option \'{}\' has not been added"'.format('junk'), str(cm.exception))

    # A float option rejects an int value.
    opts.add_option('atol', 1e-6)
    self.assertEqual(opts['atol'], 1.0e-6)
    with self.assertRaises(ValueError) as cm:
        opts['atol'] = 1
    expected = ("'atol' should be a '<type 'float'>'" if PY2
                else "'atol' should be a '<class 'float'>'")
    self.assertEqual(expected, str(cm.exception))

    # Enumerated option: each listed value is accepted, anything else is rejected.
    opts.add_option('iprint', 0, values=[0, 1, 2, 3])
    for allowed in [0, 1, 2, 3]:
        opts['iprint'] = allowed
    with self.assertRaises(ValueError) as cm:
        opts['iprint'] = 4
    self.assertEqual(
        "'iprint' must be one of the following values: '[0, 1, 2, 3]'", str(cm.exception))

    # A bool option rejects a string value.
    opts.add_option('conmin_diff', True)
    opts['conmin_diff'] = True
    opts['conmin_diff'] = False
    with self.assertRaises(ValueError) as cm:
        opts['conmin_diff'] = "YES!"
    expected = ("'conmin_diff' should be a '<type 'bool'>'" if PY2
                else "'conmin_diff' should be a '<class 'bool'>'")
    self.assertEqual(expected, str(cm.exception))

    # Bounded option: every value in [lower, upper] is accepted, values outside are not.
    opts.add_option('maxiter', 10, lower=0, upper=10)
    for allowed in range(11):
        opts['maxiter'] = allowed
    with self.assertRaises(ValueError) as cm:
        opts['maxiter'] = 15
    self.assertEqual("maximum allowed value for 'maxiter' is '10'", str(cm.exception))
    with self.assertRaises(ValueError) as cm:
        opts['maxiter'] = -1
    self.assertEqual("minimum allowed value for 'maxiter' is '0'", str(cm.exception))

    # Attribute-style assignment is refused.
    with self.assertRaises(ValueError) as cm:
        opts.maxiter = -1
    self.assertEqual("Use dict-like access for option 'maxiter'", str(cm.exception))
class TestOptionsDict(unittest.TestCase):
    """Tests for declaring, setting and validating OptionsDictionary entries (older messages)."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Entry 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertIsNone(self.dict['test'])

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
            self.assertEqual(len(w), 1)
            self.assertEqual(
                str(w[-1].message),
                "In declaration of option 'even_test' the '_type' arg is deprecated. Use 'types' instead.")

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Entry 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: the backslashes are regex escapes, not Python string escapes.
        expected_msg = (r"Entry 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)