示例#1
0
    def __init__(self, **kwargs):
        """
        Declare the transcription options and apply any user overrides.

        Parameters
        ----------
        **kwargs : dict
            Option values passed to ``self.options.update`` after the base
            options are declared and after the ``_declare_options`` and
            ``initialize`` hooks have run (presumably so that options declared
            by those hooks can also be set here).
        """
        # Only initialized to None here; nothing in this method fills it in.
        self.grid_data = None

        self.options = OptionsDictionary()

        self.options.declare('num_segments',
                             types=int,
                             desc='Number of segments')
        self.options.declare(
            'segment_ends',
            default=None,
            types=(Sequence, np.ndarray),
            allow_none=True,
            desc='Locations of segment ends or None for equally '
            'spaced segments')
        self.options.declare('order',
                             default=3,
                             types=(int, Sequence, np.ndarray),
                             desc='Order of the state transcription')
        self.options.declare(
            'compressed',
            default=True,
            types=bool,
            # Adjacent literals are concatenated, so each fragment must carry
            # its own separating space.
            desc='Use compressed transcription, meaning state and control values '
            'at segment boundaries are not duplicated on input.  This '
            'implicitly enforces value continuity between segments but in '
            'some cases may make the problem more difficult to solve.')

        # Hooks that run before user values are applied.
        self._declare_options()
        self.initialize()
        # Apply user-supplied values last, after every declaration above.
        self.options.update(kwargs)
    def test_read_only(self):
        """Assigning to an option of a read-only OptionsDictionary raises KeyError."""
        frozen = OptionsDictionary(read_only=True)
        frozen.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as cm:
            frozen['permanent'] = 4.0

        # str() of a KeyError wraps its message in quotes, so the message is
        # matched as a pattern (via the assertRegex helper, presumably the
        # six-style ``assertRegex(testcase, text, regex)``) rather than compared exactly.
        assertRegex(self, str(cm.exception),
                    "Tried to set read-only option 'permanent'.")
    def test_read_only(self):
        """A read-only OptionsDictionary refuses writes with a KeyError."""
        read_only_opts = OptionsDictionary(read_only=True)
        read_only_opts.declare('permanent', 3.0)

        expected_pattern = "Tried to set read-only option 'permanent'."
        with self.assertRaises(KeyError) as raised:
            read_only_opts['permanent'] = 4.0

        # KeyError's str() adds surrounding quotes, hence a pattern search.
        assertRegex(self, str(raised.exception), expected_pattern)
示例#4
0
    def test_locking(self):
        """Options added with ``lock_on_setup=True`` cannot change once locked.

        ``OptionsDictionary.locked`` is assigned on the class itself, so the
        lock applies to every instance (both ``opt1`` and ``opt2`` below).
        The previous value is restored via ``addCleanup``. Without that, the
        lock would leak into every test that runs later in the same process.
        """
        opt1 = OptionsDictionary()
        opt1.add_option('zzz', 10.0, lock_on_setup=True)
        opt2 = OptionsDictionary()
        opt2.add_option('xxx', 10.0, lock_on_setup=True)

        # While unlocked, lock_on_setup options can still be changed.
        opt1['zzz'] = 15.0
        opt2['xxx'] = 12.0

        # Restore the class-level flag after the test, even if it fails.
        self.addCleanup(setattr, OptionsDictionary, 'locked',
                        getattr(OptionsDictionary, 'locked', False))
        OptionsDictionary.locked = True

        with self.assertRaises(RuntimeError) as err:
            opt1['zzz'] = 14.0

        expected_msg = "The 'zzz' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)

        with self.assertRaises(RuntimeError) as err:
            opt2['xxx'] = 13.0

        expected_msg = "The 'xxx' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)
示例#5
0
    def test_locking(self):
        """Options added with ``lock_on_setup=True`` cannot change once locked.

        ``OptionsDictionary.locked`` is assigned on the class itself, so the
        lock applies to every instance (both ``opt1`` and ``opt2`` below).
        The previous value is restored via ``addCleanup``. Without that, the
        lock would leak into every test that runs later in the same process.
        """
        opt1 = OptionsDictionary()
        opt1.add_option('zzz', 10.0, lock_on_setup=True)
        opt2 = OptionsDictionary()
        opt2.add_option('xxx', 10.0, lock_on_setup=True)

        # While unlocked, lock_on_setup options can still be changed.
        opt1['zzz'] = 15.0
        opt2['xxx'] = 12.0

        # Restore the class-level flag after the test, even if it fails.
        self.addCleanup(setattr, OptionsDictionary, 'locked',
                        getattr(OptionsDictionary, 'locked', False))
        OptionsDictionary.locked = True

        with self.assertRaises(RuntimeError) as err:
            opt1['zzz'] = 14.0

        expected_msg = "The 'zzz' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)

        with self.assertRaises(RuntimeError) as err:
            opt2['xxx'] = 13.0

        expected_msg = "The 'xxx' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)
示例#6
0
class TestOptions(unittest.TestCase):
    """Tests for the ``add_option``-style OptionsDictionary API."""

    def test_options_dictionary(self):
        """Check lookup, type/enum/bound validation, removal and deprecation."""
        self.options = OptionsDictionary()

        # Make sure we can't address keys we haven't added

        with self.assertRaises(KeyError) as cm:
            self.options['junk']

        # str() of a KeyError wraps the message in an extra pair of quotes.
        self.assertEqual('"Option \'{}\' has not been added"'.format('junk'), str(cm.exception))

        # Type checking - don't set a float with an int

        self.options.add_option('atol', 1e-6)
        self.assertEqual(self.options['atol'], 1.0e-6)

        with self.assertRaises(ValueError) as cm:
            self.options['atol'] = 1

        # The type repr differs: '<type ...>' on Python 2, '<class ...>' on 3.
        if PY2:
            self.assertEqual("'atol' should be a '<type 'float'>'", str(cm.exception))
        else:
            self.assertEqual("'atol' should be a '<class 'float'>'", str(cm.exception))

        # Check enum out of range

        self.options.add_option('xyzzy', 0, values=[0, 1, 2, 3])
        for value in [0, 1, 2, 3]:
            self.options['xyzzy'] = value

        with self.assertRaises(ValueError) as cm:
            self.options['xyzzy'] = 4

        self.assertEqual("'xyzzy' must be one of the following values: '[0, 1, 2, 3]'", str(cm.exception))

        # Type checking for boolean

        self.options.add_option('conmin_diff', True)
        self.options['conmin_diff'] = True
        self.options['conmin_diff'] = False

        with self.assertRaises(ValueError) as cm:
            self.options['conmin_diff'] = "YES!"

        if PY2:
            self.assertEqual("'conmin_diff' should be a '<type 'bool'>'", str(cm.exception))
        else:
            self.assertEqual("'conmin_diff' should be a '<class 'bool'>'", str(cm.exception))

        # Test Max and Min (both bounds are inclusive: 0..10 are all accepted)

        self.options.add_option('maxiter', 10, lower=0, upper=10)
        for value in range(0, 11):
            self.options['maxiter'] = value

        with self.assertRaises(ValueError) as cm:
            self.options['maxiter'] = 15

        self.assertEqual("maximum allowed value for 'maxiter' is '10'", str(cm.exception))

        with self.assertRaises(ValueError) as cm:
            self.options['maxiter'] = -1

        self.assertEqual("minimum allowed value for 'maxiter' is '0'", str(cm.exception))

        # Make sure we can't do this
        with self.assertRaises(ValueError) as cm:
            self.options.maxiter = -1

        self.assertEqual("Use dict-like access for option 'maxiter'", str(cm.exception))

        # test removal
        self.assertIn('conmin_diff', self.options)
        self.options.remove_option('conmin_diff')
        self.assertNotIn('conmin_diff', self.options)

        # test Deprecation
        self.options._add_deprecation('max_iter', 'maxiter')
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")

            # Trigger a warning.
            self.options['max_iter'] = 5

            self.assertEqual(len(w), 1)
            self.assertEqual(str(w[0].message),
                             "Option 'max_iter' is deprecated. Use 'maxiter' instead.")

    def test_locking(self):
        """Options added with ``lock_on_setup=True`` cannot change once locked.

        ``OptionsDictionary.locked`` is assigned on the class itself, so the
        lock applies to every instance. The previous value is restored via
        ``addCleanup`` so the lock does not leak into later tests.
        """
        opt1 = OptionsDictionary()
        opt1.add_option('zzz', 10.0, lock_on_setup=True)
        opt2 = OptionsDictionary()
        opt2.add_option('xxx', 10.0, lock_on_setup=True)

        # While unlocked, lock_on_setup options can still be changed.
        opt1['zzz'] = 15.0
        opt2['xxx'] = 12.0

        # Restore the class-level flag after the test, even if it fails.
        self.addCleanup(setattr, OptionsDictionary, 'locked',
                        getattr(OptionsDictionary, 'locked', False))
        OptionsDictionary.locked = True

        with self.assertRaises(RuntimeError) as err:
            opt1['zzz'] = 14.0

        expected_msg = "The 'zzz' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)

        with self.assertRaises(RuntimeError) as err:
            opt2['xxx'] = 13.0

        expected_msg = "The 'xxx' option cannot be changed after setup."
        self.assertEqual(str(err.exception), expected_msg)
示例#7
0
    def test_bad_option_name(self):
        """Declaring an option whose name is not a valid Python name warns."""
        options = OptionsDictionary()
        expected_warning = "'foo:bar' is not a valid python name and will become an invalid option name in a future release. You can prevent this warning (and future exceptions) by declaring this option using a valid python name."

        # ':' makes the name an invalid identifier; the declare still succeeds
        # but must emit the deprecation warning above.
        with assert_warning(OMDeprecationWarning, expected_warning):
            options.declare('foo:bar', 1.0)
示例#8
0
class TestOptionsDict(unittest.TestCase):
    """Tests for the ``declare``-style OptionsDictionary API."""

    def setUp(self):
        """Give each test a fresh, empty OptionsDictionary."""
        self.dict = OptionsDictionary()

    @unittest.skipIf(tabulate is None,
                     reason="package 'tabulate' is not installed")
    def test_reprs(self):
        """``repr`` mirrors the internal dict; ``str`` renders a width-aware table."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc',
                          types=str,
                          desc='This description is long and verbose, so it '
                          'takes up multiple lines in the options table.')

        self.assertEqual(repr(self.dict), repr(self.dict._dict))

        self.assertEqual(
            self.dict.__str__(width=89), '\n'.join([
                "=========  ============  ===================  =====================  ====================",
                "Option     Default       Acceptable Values    Acceptable Types       Description",
                "=========  ============  ===================  =====================  ====================",
                "comp       MyComp        N/A                  ['ExplicitComponent']",
                "flag       False         [True, False]        ['bool']",
                "long_desc  **Required**  N/A                  ['str']                This description is ",
                "                                                                     long and verbose, so",
                "                                                                      it takes up multipl",
                "                                                                     e lines in the optio",
                "                                                                     ns table.",
                "test       **Required**  ['a', 'b']           N/A                    Test integer value",
                "=========  ============  ===================  =====================  ====================",
            ]))

        # if the table can't be represented in specified width, then we get the full width version
        self.assertEqual(
            self.dict.__str__(width=40), '\n'.join([
                "=========  ============  ===================  =====================  ====================="
                "====================================================================",
                "Option     Default       Acceptable Values    Acceptable Types       Description",
                "=========  ============  ===================  =====================  ====================="
                "====================================================================",
                "comp       MyComp        N/A                  ['ExplicitComponent']",
                "flag       False         [True, False]        ['bool']",
                "long_desc  **Required**  N/A                  ['str']                This description is l"
                "ong and verbose, so it takes up multiple lines in the options table.",
                "test       **Required**  ['a', 'b']           N/A                    Test integer value",
                "=========  ============  ===================  =====================  ====================="
                "====================================================================",
            ]))

    @unittest.skipIf(tabulate is None,
                     reason="package 'tabulate' is not installed")
    def test_to_table(self):
        """``to_table(fmt='github')`` renders the options as a pipe table."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc',
                          types=str,
                          desc='This description is long and verbose, so it '
                          'takes up multiple lines in the options table.')

        expected = "| Option    | Default      | Acceptable Values   | Acceptable Types      " \
                   "| Description                                                            " \
                   "                   |\n" \
                   "|-----------|--------------|---------------------|-----------------------|--" \
                   "----------------------------------------------------------------------------" \
                   "-------------|\n" \
                   "| comp      | MyComp       | N/A                 | ['ExplicitComponent'] |   " \
                   "                                                                             " \
                   "           |\n" \
                   "| flag      | False        | [True, False]       | ['bool']              |   " \
                   "                                                                             " \
                   "           |\n" \
                   "| long_desc | **Required** | N/A                 | ['str']               | Th" \
                   "is description is long and verbose, so it takes up multiple lines in the opti" \
                   "ons table. |\n" \
                   "| test      | **Required** | ['a', 'b']          | N/A                   | Te" \
                   "st integer value                                                             " \
                   "           |"

        self.assertEqual(self.dict.to_table(fmt='github'), expected)

    @unittest.skipIf(tabulate is None,
                     reason="package 'tabulate' is not installed")
    def test_deprecation_col(self):
        """A Deprecation column appears only when some option is deprecated."""
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc',
                          types=str,
                          desc='This description is long and verbose, so it '
                          'takes up multiple lines in the options table.',
                          deprecation='This option is deprecated')

        # Whitespace is stripped before comparing, so only cell contents and
        # column structure are checked.
        expected = "|Option|Default|AcceptableValues|AcceptableTypes|Description|Deprecation|\n|" \
        "-----------|--------------|---------------------|-----------------------|-----------------" \
        "--------------------------------------------------------------------------|---------------" \
        "------------|\n|comp|MyComp|N/A|['ExplicitComponent']||N/A|\n|flag|False|[True,False]|" \
        "['bool']||N/A|\n|long_desc|**Required**|N/A|['str']|Thisdescriptionislongandverbose,soit" \
        "takesupmultiplelinesintheoptionstable.|Thisoptionisdeprecated|\n|test|**Required**|" \
        "['a','b']|N/A|Testintegervalue|N/A|"

        self.assertEqual(
            self.dict.to_table(fmt='github').replace(" ", ""), expected)

        my_comp = MyComp()

        # Re-declare the same options, this time with no deprecation on any of
        # them; the expected table below has no Deprecation column.
        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc',
                          types=str,
                          desc='This description is long and verbose, so it '
                          'takes up multiple lines in the options table.')

        expected = "|Option|Default|AcceptableValues|AcceptableTypes|Description|\n|-----------|----" \
        "----------|---------------------|-----------------------|-----------------------------------" \
        "--------------------------------------------------------|\n|comp|MyComp|N/A|" \
        "['ExplicitComponent']||\n|flag|False|[True,False]|['bool']||\n|long_desc|**Required**|N/A|" \
        "['str']|Thisdescriptionislongandverbose,soittakesupmultiplelinesintheoptionstable.|\n|test|" \
        "**Required**|['a','b']|N/A|Testintegervalue|"

        self.assertEqual(
            self.dict.to_table(fmt='github').replace(" ", ""), expected)

    def test_type_checking(self):
        """Values must match ``types``; bool options get (True, False) as values."""
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        expected_msg = "Value ('') of option 'test' has type 'str', " \
                       "but type 'int' was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # multiple types are allowed
        self.dict.declare('test_multi',
                          types=(int, float),
                          desc='Test multiple types')

        self.dict['test_multi'] = 1
        self.assertEqual(self.dict['test_multi'], 1)
        self.assertEqual(type(self.dict['test_multi']), int)

        self.dict['test_multi'] = 1.0
        self.assertEqual(self.dict['test_multi'], 1.0)
        self.assertEqual(type(self.dict['test_multi']), float)

        with self.assertRaises(TypeError) as context:
            self.dict['test_multi'] = ''

        expected_msg = "Value ('') of option 'test_multi' has type 'str', " \
                       "but one of types ('int', 'float') was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work and allowed values are populated
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

        meta = self.dict._dict['flag']
        self.assertEqual(meta['values'], (True, False))

    def test_allow_none(self):
        """``allow_none=True`` lets a typed option hold None."""
        self.dict.declare('test',
                          types=int,
                          allow_none=True,
                          desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        """``types`` and ``values`` each work alone but not together."""
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(
            str(context.exception),
            "'types' and 'values' were both specified for option 'test3'.")

    def test_check_valid_template(self):
        """The template ``check_valid`` from options_dictionary rejects the value."""
        # test the template 'check_valid' function
        from openmdao.utils.options_dictionary import check_valid
        self.dict.declare('test', check_valid=check_valid)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = 1

        expected_msg = "Option 'test' with value 1 is not valid."
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid(self):
        """A custom ``check_valid`` callback (``check_even``) vetoes odd values."""
        self.dict.declare('even_test', types=int, check_valid=check_even)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        """Setting an undeclared option raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        """``in`` reports only declared options."""
        self.dict.declare('test')

        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        """``update`` assigns values to declared options."""
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        """``update`` with an undeclared key raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        """Reading an undeclared option raises KeyError."""
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "Option 'missing' has not been declared."
        self.assertEqual(expected_msg, context.exception.args[0])

    def test_get_default(self):
        """An option returns its default until explicitly set."""
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        """Only the listed ``values`` may be assigned."""
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: '\(' and '\[' are regex escapes, not Python escapes.
        expected_msg = (
            r"Value \(<object object at 0x[0-9A-Fa-f]+>\) of option 'test' is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
            r" <object object at 0x[0-9A-Fa-f]+>\].")
        self.assertRegex(str(context.exception), expected_msg)

    def test_read_only(self):
        """A read-only dictionary rejects assignment with KeyError."""
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set read-only option 'permanent'.")
        self.assertRegex(str(context.exception), expected_msg)

    def test_bounds(self):
        """``lower``/``upper`` bounds are enforced on assignment."""
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = "Value (3.0) of option 'x' exceeds maximum allowed value of 2.0."
        self.assertEqual(str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = "Value (-3.0) of option 'x' is less than minimum allowed value of 0.0."
        self.assertEqual(str(context.exception), expected_msg)

    def test_undeclare(self):
        """``undeclare`` removes a previously declared option."""
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "Option 'test' has not been declared."
        self.assertEqual(expected_msg, context.exception.args[0])

    def test_deprecated_option(self):
        """A deprecated option warns only on its first access."""
        msg = 'Option "test1" is deprecated.'
        self.dict.declare('test1', deprecation=msg)

        # test double set
        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = None
        # Should only generate warning first time
        with assert_no_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = None

        # Also test set and then get
        msg = 'Option "test2" is deprecated.'
        self.dict.declare('test2', deprecation=msg)

        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test2'] = None
        # Should only generate warning first time; the read is the action under test.
        with assert_no_warning(OMDeprecationWarning, msg):
            self.dict['test2']

    def test_deprecated_tuple_option(self):
        """A (message, alias) deprecation forwards to the alias and warns once."""
        msg = 'Option "test1" is deprecated. Use "foo" instead.'
        self.dict.declare('test1', deprecation=(msg, 'foo'))
        self.dict.declare('foo')

        # test double set
        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = 'xyz'
        # Should only generate warning first time
        with assert_no_warning(OMDeprecationWarning, msg):
            self.dict['test1'] = 'zzz'

        with assert_no_warning(OMDeprecationWarning, msg):
            option = self.dict['test1']
        with assert_no_warning(OMDeprecationWarning):
            option2 = self.dict['foo']
        self.assertEqual(option, option2)

        # Also test set and then get
        msg = 'Option "test2" is deprecated. Use "foo2" instead.'
        self.dict.declare('test2', deprecation=(msg, 'foo2'))
        self.dict.declare('foo2')

        with assert_warning(OMDeprecationWarning, msg):
            self.dict['test2'] = 'abcd'
        # Should only generate warning first time
        with assert_no_warning(OMDeprecationWarning, msg):
            option = self.dict['test2']
        with assert_no_warning(OMDeprecationWarning):
            option2 = self.dict['foo2']
        self.assertEqual(option, option2)

        # test bad alias
        msg = 'Option "test3" is deprecated. Use "foo3" instead.'
        self.dict.declare('test3', deprecation=(msg, 'foo3'))

        with self.assertRaises(KeyError) as context:
            self.dict['test3'] = 'abcd'

        expected_msg = "Can't find aliased option 'foo3' for deprecated option 'test3'."
        self.assertEqual(context.exception.args[0], expected_msg)

    def test_bad_option_name(self):
        """Non-identifier option names trigger an OMDeprecationWarning."""
        opt = OptionsDictionary()
        msg = "'foo:bar' is not a valid python name and will become an invalid option name in a future release. You can prevent this warning (and future exceptions) by declaring this option using a valid python name."

        with assert_warning(OMDeprecationWarning, msg):
            opt.declare('foo:bar', 1.0)
示例#9
0
    def test_options_dictionary(self):
        """Walk the legacy ``add_option`` API through its validation paths.

        Covers the following:

        * unknown-key lookup
        * float and bool type enforcement
        * enumerated values
        * inclusive lower/upper bounds
        * the ban on attribute-style assignment
        * option removal
        * deprecated aliases
        """
        opts = OptionsDictionary()
        self.options = opts

        # Reading a key that was never added fails.
        with self.assertRaises(KeyError) as raised:
            opts['junk']
        self.assertEqual('"Option \'{}\' has not been added"'.format('junk'),
                         str(raised.exception))

        # A float-valued option refuses an int.
        opts.add_option('atol', 1e-6)
        self.assertEqual(opts['atol'], 1.0e-6)
        with self.assertRaises(ValueError) as raised:
            opts['atol'] = 1
        expected = ("'atol' should be a '<type 'float'>'" if PY2
                    else "'atol' should be a '<class 'float'>'")
        self.assertEqual(expected, str(raised.exception))

        # Only the enumerated values are accepted.
        opts.add_option('iprint', 0, values=[0, 1, 2, 3])
        for level in range(4):
            opts['iprint'] = level
        with self.assertRaises(ValueError) as raised:
            opts['iprint'] = 4
        self.assertEqual(
            "'iprint' must be one of the following values: '[0, 1, 2, 3]'",
            str(raised.exception))

        # A bool-valued option refuses a string.
        opts.add_option('conmin_diff', True)
        opts['conmin_diff'] = True
        opts['conmin_diff'] = False
        with self.assertRaises(ValueError) as raised:
            opts['conmin_diff'] = "YES!"
        expected = ("'conmin_diff' should be a '<type 'bool'>'" if PY2
                    else "'conmin_diff' should be a '<class 'bool'>'")
        self.assertEqual(expected, str(raised.exception))

        # Bounds are inclusive at both ends.
        opts.add_option('maxiter', 10, lower=0, upper=10)
        for n in range(11):
            opts['maxiter'] = n
        with self.assertRaises(ValueError) as raised:
            opts['maxiter'] = 15
        self.assertEqual("maximum allowed value for 'maxiter' is '10'",
                         str(raised.exception))
        with self.assertRaises(ValueError) as raised:
            opts['maxiter'] = -1
        self.assertEqual("minimum allowed value for 'maxiter' is '0'",
                         str(raised.exception))

        # Attribute-style assignment of an option is rejected.
        with self.assertRaises(ValueError) as raised:
            opts.maxiter = -1
        self.assertEqual("Use dict-like access for option 'maxiter'",
                         str(raised.exception))

        # Removing an option takes it out of the container.
        self.assertTrue('conmin_diff' in opts)
        opts.remove_option('conmin_diff')
        self.assertFalse('conmin_diff' in opts)

        # Writing through a deprecated alias emits exactly one warning.
        opts._add_deprecation('max_iter', 'maxiter')
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            opts['max_iter'] = 5
            self.assertEqual(len(caught), 1)
            self.assertEqual(
                str(caught[0].message),
                "Option 'max_iter' is deprecated. Use 'maxiter' instead.")
示例#10
0
    def test_options_dictionary(self):
        """Exercise lookup, type, enum and bound checks of ``add_option``."""
        self.options = OptionsDictionary()

        # --- unknown keys -------------------------------------------------
        with self.assertRaises(KeyError) as ctx:
            self.options['junk']
        self.assertEqual('"Option \'{}\' has not been added"'.format('junk'),
                         str(ctx.exception))

        # --- float type enforcement ---------------------------------------
        self.options.add_option('atol', 1e-6)
        self.assertEqual(self.options['atol'], 1.0e-6)

        with self.assertRaises(ValueError) as ctx:
            self.options['atol'] = 1
        if PY2:
            type_msg = "'atol' should be a '<type 'float'>'"
        else:
            type_msg = "'atol' should be a '<class 'float'>'"
        self.assertEqual(type_msg, str(ctx.exception))

        # --- enumerated values --------------------------------------------
        self.options.add_option('iprint', 0, values=[0, 1, 2, 3])
        for allowed in [0, 1, 2, 3]:
            self.options['iprint'] = allowed

        with self.assertRaises(ValueError) as ctx:
            self.options['iprint'] = 4
        self.assertEqual(
            "'iprint' must be one of the following values: '[0, 1, 2, 3]'",
            str(ctx.exception))

        # --- bool type enforcement ----------------------------------------
        self.options.add_option('conmin_diff', True)
        for flag in (True, False):
            self.options['conmin_diff'] = flag

        with self.assertRaises(ValueError) as ctx:
            self.options['conmin_diff'] = "YES!"
        if PY2:
            type_msg = "'conmin_diff' should be a '<type 'bool'>'"
        else:
            type_msg = "'conmin_diff' should be a '<class 'bool'>'"
        self.assertEqual(type_msg, str(ctx.exception))

        # --- bounds -------------------------------------------------------
        # NOTE(review): other versions of this test pass lower=/upper=; confirm
        # that add_option in the targeted version accepts low=/high=.
        self.options.add_option('maxiter', 10, low=0, high=10)
        for value in range(11):
            self.options['maxiter'] = value

        with self.assertRaises(ValueError) as ctx:
            self.options['maxiter'] = 15
        self.assertEqual("maximum allowed value for 'maxiter' is '10'",
                         str(ctx.exception))

        with self.assertRaises(ValueError) as ctx:
            self.options['maxiter'] = -1
        self.assertEqual("minimum allowed value for 'maxiter' is '0'",
                         str(ctx.exception))

        # --- attribute-style assignment is forbidden ----------------------
        with self.assertRaises(ValueError) as ctx:
            self.options.maxiter = -1
        self.assertEqual("Use dict-like access for option 'maxiter'",
                         str(ctx.exception))
 def setUp(self):
     self.dict = OptionsDictionary()
class TestOptionsDict(unittest.TestCase):
    """Tests for OptionsDictionary declaration, validation, lookup and table formatting."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_reprs(self):
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        self.assertEqual(repr(self.dict), repr(self.dict._dict))

        self.assertEqual(self.dict.__str__(width=83), '\n'.join([
            "========= ============ ================= ===================== ====================",
            "Option    Default      Acceptable Values Acceptable Types      Description         ",
            "========= ============ ================= ===================== ====================",
            "comp      MyComp       N/A               ['ExplicitComponent']                     ",
            "flag      False        [True, False]     ['bool']                                  ",
            "long_desc **Required** N/A               ['str']               This description is ",
            "                                                               long and verbose, so",
            "                                                                it takes up multipl",
            "                                                               e lines in the optio",
            "                                                               ns table.",
            "test      **Required** ['a', 'b']        N/A                   Test integer value  ",
            "========= ============ ================= ===================== ====================",
        ]))

        # if the table can't be represented in specified width, then we get the full width version
        self.assertEqual(self.dict.__str__(width=40), '\n'.join([
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
            "Option    Default      Acceptable Values Acceptable Types      Description          "
            "                                                                     ",
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
            "comp      MyComp       N/A               ['ExplicitComponent']                      "
            "                                                                     ",
            "flag      False        [True, False]     ['bool']                                   "
            "                                                                     ",
            "long_desc **Required** N/A               ['str']               This description is l"
            "ong and verbose, so it takes up multiple lines in the options table. ",
            "test      **Required** ['a', 'b']        N/A                   Test integer value   "
            "                                                                     ",
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
        ]))

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        expected_msg = "Value ('') of option 'test' has type 'str', " \
                       "but type 'int' was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # multiple types are allowed
        self.dict.declare('test_multi', types=(int, float), desc='Test multiple types')

        self.dict['test_multi'] = 1
        self.assertEqual(self.dict['test_multi'], 1)
        self.assertEqual(type(self.dict['test_multi']), int)

        self.dict['test_multi'] = 1.0
        self.assertEqual(self.dict['test_multi'], 1.0)
        self.assertEqual(type(self.dict['test_multi']), float)

        with self.assertRaises(TypeError) as context:
            self.dict['test_multi'] = ''

        expected_msg = "Value ('') of option 'test_multi' has type 'str', " \
                       "but one of types ('int', 'float') was expected."
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work and allowed values are populated
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

        meta = self.dict._dict['flag']
        self.assertEqual(meta['values'], (True, False))

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, check_valid=check_even)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):

        msg = "In declaration of option 'even_test' the '_type' arg is deprecated.  Use 'types' instead."

        with assert_warning(DeprecationWarning, msg):
            self.dict.declare('even_test', type_=int, check_valid=check_even)

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Option 'even_test' with value 3 is not an even number."
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        # Membership reflects declaration; assertIn/assertNotIn report the operands on failure.
        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Option 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: the regex escapes (\( \) \[ \]) are invalid string escapes otherwise.
        expected_msg = (r"Value \(<object object at 0x[0-9A-Fa-f]+>\) of option 'test' is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\].")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set read-only option 'permanent'.")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = "Value (3.0) of option 'x' exceeds maximum allowed value of 2.0."
        self.assertEqual(str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = "Value (-3.0) of option 'x' is less than minimum allowed value of 0.0."
        self.assertEqual(str(context.exception), expected_msg)

    def test_undeclare(self):
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
class TestOptionsDict(unittest.TestCase):
    """Tests for OptionsDictionary written against an older message/repr format of the library."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_reprs(self):
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc', types=str,
                          desc='This description is long and verbose, so it '
                               'takes up multiple lines in the options table.')

        # NOTE(review): this compares the raw return value of __repr__() with the dict itself,
        # so it only passes if this library version's __repr__ returns the dict object rather
        # than a str -- confirm against the OptionsDictionary version under test.
        self.assertEqual(self.dict.__repr__(), self.dict._dict)

        self.assertEqual(self.dict.__str__(width=83), '\n'.join([
            "========= ============ ================= ===================== ====================",
            "Option    Default      Acceptable Values Acceptable Types      Description         ",
            "========= ============ ================= ===================== ====================",
            "comp      MyComp       N/A               ['ExplicitComponent']                     ",
            "flag      False        N/A               ['bool']                                  ",
            "long_desc **Required** N/A               ['str']               This description is ",
            "                                                               long and verbose, so",
            "                                                                it takes up multipl",
            "                                                               e lines in the optio",
            "                                                               ns table.",
            "test      **Required** ['a', 'b']        N/A                   Test integer value  ",
            "========= ============ ================= ===================== ====================",
        ]))

        # if the table can't be represented in specified width, then we get the full width version
        self.assertEqual(self.dict.__str__(width=40), '\n'.join([
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
            "Option    Default      Acceptable Values Acceptable Types      Description          "
            "                                                                     ",
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
            "comp      MyComp       N/A               ['ExplicitComponent']                      "
            "                                                                     ",
            "flag      False        N/A               ['bool']                                   "
            "                                                                     ",
            "long_desc **Required** N/A               ['str']               This description is l"
            "ong and verbose, so it takes up multiple lines in the options table. ",
            "test      **Required** ['a', 'b']        N/A                   Test integer value   "
            "                                                                     ",
            "========= ============ ================= ===================== ====================="
            "==================================================================== ",
        ]))

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Option 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
            self.assertEqual(len(w), 1)
            self.assertEqual(str(w[-1].message), "In declaration of option 'even_test' the '_type' arg is deprecated.  Use 'types' instead.")

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        # Membership reflects declaration; assertIn/assertNotIn report the operands on failure.
        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: the regex escapes (\[ \]) are invalid string escapes otherwise.
        expected_msg = (r"Option 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

    def test_undeclare(self):
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
示例#14
0
class TranscriptionBase(object):
    def __init__(self, **kwargs):
        """
        Declare the options shared by every transcription, then apply ``kwargs``.

        The subclass hooks ``_declare_options`` and ``initialize`` run (in that order)
        after the base options are declared and before ``kwargs`` are applied, so
        options declared by a subclass can also be set through ``kwargs``.

        Parameters
        ----------
        **kwargs
            Option values, applied with ``self.options.update``.
        """
        # NOTE(review): presumably filled in by a concrete transcription's setup_grid; confirm.
        self.grid_data = None

        self.options = OptionsDictionary()

        self.options.declare('num_segments',
                             types=int,
                             desc='Number of segments')
        self.options.declare(
            'segment_ends',
            default=None,
            types=(Sequence, np.ndarray),
            allow_none=True,
            desc='Locations of segment ends or None for equally '
            'spaced segments')
        self.options.declare('order',
                             default=3,
                             types=(int, Sequence, np.ndarray),
                             desc='Order of the state transcription')
        # Adjacent literals are concatenated verbatim, so each fragment must end with a space.
        self.options.declare(
            'compressed',
            default=True,
            types=bool,
            desc='Use compressed transcription, meaning state and control values '
            'at segment boundaries are not duplicated on input.  This '
            'implicitly enforces value continuity between segments but in '
            'some cases may make the problem more difficult to solve.')

        self._declare_options()
        self.initialize()
        self.options.update(kwargs)

    def _declare_options(self):
        pass

    def initialize(self):
        """Hook for subclass initialization; a no-op in the base class."""
        return None

    def setup_grid(self, phase):
        """
        Setup the GridData object for the Transcription.

        The base implementation always raises; concrete transcriptions must override it.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.

        Raises
        ------
        NotImplementedError
            Always, naming the concrete transcription class.
        """
        # The trailing space in the first fragment keeps "method" and "setup_grid" apart,
        # matching the messages of the other unimplemented setup_* hooks.
        raise NotImplementedError('Transcription {0} does not implement method '
                                  'setup_grid.'.format(self.__class__.__name__))

    def setup_time(self, phase):
        """
        Setup up the time component and time extents for the phase.

        Each of ``t_initial`` and ``t_duration`` comes from one of two places:

        - Outside the phase, when ``input_initial`` / ``input_duration`` is set.
        - Otherwise, a ``time_extents`` IndepVarComp added here.

        Each one that is neither an input nor fixed (``fix_initial`` /
        ``fix_duration``) is also added as a design variable. ``None`` bounds are
        replaced by +/- INF_BOUND.

        Parameters
        ----------
        phase
            The phase to which this transcription applies.

        Returns
        -------
        comps
            A list of the component names needed for time extents
            (``['time_extents']`` or an empty list).
        """
        time_options = phase.time_options
        time_units = time_options['units']

        default_vals = {
            't_initial': time_options['initial_val'],
            't_duration': time_options['duration_val']
        }
        comps = []

        # Warn about invalid options
        phase.check_time_options()

        # Time extents not supplied from outside the phase are owned by an IndepVarComp here.
        indeps = []
        if not time_options['input_initial']:
            indeps.append('t_initial')
        if not time_options['input_duration']:
            indeps.append('t_duration')

        if indeps:
            indep = IndepVarComp()

            for var in indeps:
                indep.add_output(var, val=default_vals[var], units=time_units)

            phase.add_subsystem('time_extents', indep, promotes_outputs=['*'])
            comps.append('time_extents')

        if not (time_options['input_initial'] or time_options['fix_initial']):
            lb, ub = time_options['initial_bounds']
            lb = -INF_BOUND if lb is None else lb
            ub = INF_BOUND if ub is None else ub

            phase.add_design_var('t_initial',
                                 lower=lb,
                                 upper=ub,
                                 scaler=time_options['initial_scaler'],
                                 adder=time_options['initial_adder'],
                                 ref0=time_options['initial_ref0'],
                                 ref=time_options['initial_ref'])

        if not (time_options['input_duration'] or time_options['fix_duration']):
            lb, ub = time_options['duration_bounds']
            lb = -INF_BOUND if lb is None else lb
            ub = INF_BOUND if ub is None else ub

            phase.add_design_var('t_duration',
                                 lower=lb,
                                 upper=ub,
                                 scaler=time_options['duration_scaler'],
                                 adder=time_options['duration_adder'],
                                 ref0=time_options['duration_ref0'],
                                 ref=time_options['duration_ref'])

        # Return the documented component list (previously built but never returned).
        return comps

    def setup_controls(self, phase):
        """
        Add the ControlGroup to ``phase`` when it has any controls.

        ``phase._check_control_options()`` is always called first. If
        ``phase.control_options`` is non-empty, a ``control_group`` subsystem is
        added and ``dt_dstau`` is connected into it. The subsystem promotes
        ``controls:*``, ``control_values:*`` and ``control_rates:*``.
        """
        phase._check_control_options()

        if not phase.control_options:
            return

        group = ControlGroup(control_options=phase.control_options,
                             time_units=phase.time_options['units'],
                             grid_data=self.grid_data)
        promoted = ['controls:*', 'control_values:*', 'control_rates:*']
        phase.add_subsystem('control_group', subsys=group, promotes=promoted)
        phase.connect('dt_dstau', 'control_group.dt_dstau')

    def setup_polynomial_controls(self, phase):
        """
        Add a ``polynomial_control_group`` subsystem when ``phase`` has polynomial controls.

        All inputs and outputs of the group are promoted. Nothing is added when
        ``phase.polynomial_control_options`` is empty.
        """
        poly_opts = phase.polynomial_control_options
        if not poly_opts:
            return

        group = PolynomialControlGroup(grid_data=self.grid_data,
                                       polynomial_control_options=poly_opts,
                                       time_units=phase.time_options['units'])
        phase.add_subsystem('polynomial_control_group', subsys=group,
                            promotes_inputs=['*'], promotes_outputs=['*'])

    def setup_design_parameters(self, phase):
        """
        Add the ``design_params`` IndepVarComp and connect each design parameter to its targets.

        ``phase._check_design_parameter_options()`` is always called. Each entry of
        ``phase.design_parameter_options`` then gets:

        - An output ``design_parameters:<name>`` of shape ``(1,) + options['shape']``.
        - A design variable, if ``options['opt']`` is set. ``None`` bounds become
          +/- INF_BOUND.
        - A connection to every target reported by ``self.get_parameter_connections``.
        """
        phase._check_design_parameter_options()

        param_opts = phase.design_parameter_options
        if not param_opts:
            return

        ivc = phase.add_subsystem('design_params', subsys=IndepVarComp(),
                                  promotes_outputs=['*'])

        for param_name, opts in iteritems(param_opts):
            output_name = 'design_parameters:{0}'.format(param_name)

            if opts['opt']:
                lower = opts['lower']
                upper = opts['upper']
                phase.add_design_var(name=output_name,
                                     lower=-INF_BOUND if lower is None else lower,
                                     upper=INF_BOUND if upper is None else upper,
                                     scaler=opts['scaler'],
                                     adder=opts['adder'],
                                     ref0=opts['ref0'],
                                     ref=opts['ref'])

            # A leading unit dimension is prepended to the declared parameter shape.
            ivc.add_output(name=output_name,
                           val=opts['val'],
                           shape=(1, ) + opts['shape'],
                           units=opts['units'])

            for targets, src_indices in self.get_parameter_connections(param_name, phase):
                phase.connect(output_name, list(targets),
                              src_indices=src_indices,
                              flat_src_indices=True)

    def setup_input_parameters(self, phase):
        """
        Add an ``input_params`` InputParameterComp so input parameters can come from outside the phase.

        The component (all inputs and outputs promoted) is only added when
        ``phase.input_parameter_options`` is non-empty. Each parameter's
        ``input_parameters:<name>_out`` output is then connected, with flat source
        indices, to the targets returned by ``self.get_parameter_connections``.
        """
        param_opts = phase.input_parameter_options

        if param_opts:
            comp = InputParameterComp(input_parameter_options=param_opts)
            phase.add_subsystem('input_params', subsys=comp,
                                promotes_inputs=['*'], promotes_outputs=['*'])

        for param_name in param_opts:
            output_name = 'input_parameters:{0}_out'.format(param_name)

            for targets, src_indices in self.get_parameter_connections(param_name, phase):
                phase.connect(output_name, list(targets),
                              src_indices=src_indices,
                              flat_src_indices=True)

    def setup_traj_parameters(self, phase):
        """
        Add a ``traj_params`` InputParameterComp so trajectory parameters can come from outside the phase.

        The component is built with ``traj_params=True`` and added, with all inputs
        and outputs promoted, only when ``phase.traj_parameter_options`` is non-empty.
        Each parameter's ``traj_parameters:<name>_out`` output is then connected to the
        targets returned by ``self.get_parameter_connections``.
        """
        traj_opts = phase.traj_parameter_options

        if traj_opts:
            comp = InputParameterComp(input_parameter_options=traj_opts,
                                      traj_params=True)
            phase.add_subsystem('traj_params', subsys=comp,
                                promotes_inputs=['*'], promotes_outputs=['*'])

        for param_name, _opts in iteritems(traj_opts):
            output_name = 'traj_parameters:{0}_out'.format(param_name)

            # NOTE(review): unlike the design/input parameter hooks, no flat_src_indices is
            # passed here; confirm whether that difference is intentional.
            for targets, src_indices in self.get_parameter_connections(param_name, phase):
                phase.connect(output_name, list(targets), src_indices=src_indices)

    def setup_states(self, phase):
        """Unimplemented hook; concrete transcriptions must override it."""
        template = 'Transcription {0} does not implement method setup_states.'
        raise NotImplementedError(template.format(self.__class__.__name__))

    def setup_ode(self, phase):
        """Unimplemented hook; concrete transcriptions must override it."""
        template = 'Transcription {0} does not implement method setup_ode.'
        raise NotImplementedError(template.format(self.__class__.__name__))

    def setup_timeseries_outputs(self, phase):
        """Unimplemented hook; concrete transcriptions must override it."""
        template = 'Transcription {0} does not implement method setup_timeseries_outputs.'
        raise NotImplementedError(template.format(self.__class__.__name__))

    def setup_boundary_constraints(self, loc, phase):
        """
        Adds BoundaryConstraintComp for initial and/or final boundary constraints if necessary
        and issues appropriate connections.

        Parameters
        ----------
        loc : str
            The kind of boundary constraints being setup.  Must be one of 'initial' or 'final'.
        phase
            The phase object to which this transcription instance applies.

        Raises
        ------
        ValueError
            If ``loc`` is invalid, or if a non-scalar lower/upper/equals value of a constraint
            does not match the shape of its indices (when indices are given) or the shape of
            its source (when they are not).
        """
        if loc not in ('initial', 'final'):
            raise ValueError('loc must be one of \'initial\' or \'final\'.')
        bc_comp = None

        bc_dict = phase._initial_boundary_constraints \
            if loc == 'initial' else phase._final_boundary_constraints

        if bc_dict:
            bc_comp = phase.add_subsystem(
                '{0}_boundary_constraints'.format(loc),
                subsys=BoundaryConstraintComp(loc=loc))

        # Option key and the leading part of the error message used if its value mismatches.
        bound_checks = (
            ('lower', 'The lower bounds of boundary constraint on {0} are not '),
            ('upper', 'The upper bounds of boundary constraint on {0} are not '),
            ('equals', 'The equality boundary constraint value on {0} is not '),
        )

        for var, options in iteritems(bc_dict):
            con_name = options['constraint_name']

            # Constraint options are a copy of options with constraint_name key removed.
            con_options = options.copy()
            con_options.pop('constraint_name')

            src, shape, units, linear = self._get_boundary_constraint_src(
                var, loc, phase)

            con_units = options.get('units', None)

            # Fall back to the user-declared shape, and finally assume a scalar.
            shape = options['shape'] if shape is None else shape
            if shape is None:
                shape = (1, )

            if options['indices'] is not None:
                # Indices provided: non-scalar values must be flat with one entry per index.
                expected_shape = (len(options['indices']), )
                mismatch = ('compatible with the provided indices. Provide them as a flat '
                            'array with the same size as indices.')
            else:
                # No indices: non-scalar values must have the shape of the source.
                expected_shape = shape
                mismatch = 'compatible with its shape, and no indices were provided.'

            for key, subject in bound_checks:
                value = options[key]
                # Compare against None explicitly: the truth value of a multi-element numpy
                # array is ambiguous (raises), and falsy arrays would skip the check.
                if value is not None and not np.isscalar(value) and \
                        np.asarray(value).shape != expected_shape:
                    raise ValueError((subject + mismatch).format(var))

            size = int(np.prod(shape))
            con_options['shape'] = shape
            con_options['units'] = units if con_units is None else con_units
            con_options['linear'] = linear

            # Build the correct src_indices regardless of shape: the first `size` flat values
            # for the initial boundary, the last `size` flat values for the final boundary.
            if loc == 'initial':
                src_idxs = np.arange(size, dtype=int).reshape(shape)
            else:
                src_idxs = np.arange(-size, 0, dtype=int).reshape(shape)

            bc_comp._add_constraint(con_name, **con_options)

            phase.connect(src,
                          '{0}_boundary_constraints.{0}_value_in:{1}'.format(
                              loc, con_name),
                          src_indices=src_idxs,
                          flat_src_indices=True)

    def setup_objective(self, phase):
        """
        Find the path of the objective(s) and add the objective using the standard OpenMDAO method.

        Parameters
        ----------
        phase
            The phase object to which this transcription instance applies.

        Raises
        ------
        ValueError
            If an objective's loc is not 'initial' or 'final', if a non-scalar objective
            variable has no index, or if the index is out of range for the variable's shape.
        """
        for name, options in iteritems(phase._objectives):
            index = options['index']
            loc = options['loc']

            # Validate loc before querying the source so the user gets a clear message.
            if loc not in ('initial', 'final'):
                raise ValueError(
                    'Invalid value for objective loc: {0}. Must be '
                    'one of \'initial\' or \'final\'.'.format(loc))

            obj_path, shape, units, _ = self._get_boundary_constraint_src(
                name, loc, phase)

            # Fall back to the user-declared shape, and finally assume a scalar
            # (the same convention setup_boundary_constraints uses).
            shape = options['shape'] if shape is None else shape
            if shape is None:
                shape = (1, )

            size = int(np.prod(shape))

            if size > 1 and index is None:
                raise ValueError(
                    'Objective variable is non-scalar {0} but no index specified '
                    'for objective'.format(shape))

            idx = 0 if index is None else index

            # Range-check before normalizing: checking afterwards lets an index below
            # -size wrap to a negative value that passes and selects the wrong entry.
            if idx >= size or idx < -size:
                raise ValueError(
                    'Objective index={0}, but the shape of the objective '
                    'variable is {1}'.format(index, shape))

            if idx < 0:
                idx = size + idx

            # The final boundary value occupies the last `size` flat entries of the source,
            # the initial boundary value the first `size`.
            obj_index = -size + idx if loc == 'final' else idx

            from ..phase import Phase
            super(Phase, phase).add_objective(
                obj_path,
                ref=options['ref'],
                ref0=options['ref0'],
                index=obj_index,
                adder=options['adder'],
                scaler=options['scaler'],
                parallel_deriv_color=options['parallel_deriv_color'],
                vectorize_derivs=options['vectorize_derivs'])

    def _get_boundary_constraint_src(self, name, loc, phase):
        raise NotImplementedError('Transcription {0} does not implement method'
                                  '_get_boundary_constraint_source.'.format(
                                      self.__class__.__name__))

    def _get_rate_source_path(self, name, loc, phase):
        raise NotImplementedError('Transcription {0} does not implement method'
                                  '_get_rate_source_path.'.format(
                                      self.__class__.__name__))

    def get_parameter_connections(self, name, phase):
        """
        Return where the parameter ``name`` must be connected within ``phase``.

        Parameters
        ----------
        name : str
            The name of the parameter for which connection information is desired.
        phase
            The phase object to which this transcription applies.

        Returns
        -------
        connection_info : list of (paths, indices)
            A list of tuples, each holding target paths and the corresponding src_indices
            to which the given parameter is to be connected.

        Raises
        ------
        NotImplementedError
            Always in this base class; concrete transcriptions must override it.
        """
        msg = 'Transcription {0} does not implement method get_parameter_connections.'
        raise NotImplementedError(msg.format(self.__class__.__name__))

    def check_config(self, phase, logger):
        """
        Warn about ODE-output constraints and timeseries outputs with an unknown shape.

        Path constraints, initial and final boundary constraints, and timeseries outputs
        whose variable ``phase.classify_var`` reports as ``'ode'`` and whose ``shape``
        option is None are assumed scalar; a warning is logged for each of them.

        Parameters
        ----------
        phase
            The phase object to which this transcription instance applies.
        logger
            Logger whose ``warning`` method receives the messages.
        """
        # (collection of variables and their options, description used in the warning)
        checks = (
            (phase._path_constraints, 'path constraint'),
            (phase._initial_boundary_constraints, 'boundary constraint'),
            (phase._final_boundary_constraints, 'boundary constraint'),
            (phase._timeseries_outputs, 'timeseries output'),
        )

        for var_dict, kind in checks:
            for var, options in iteritems(var_dict):
                # Only variables classified as ODE outputs are checked.
                if phase.classify_var(var) != 'ode':
                    continue

                if options['shape'] is None:
                    logger.warning(
                        'Unable to infer shape of {0} \'{1}\' in '
                        'phase \'{2}\'. Scalar assumed.  If this ODE output '
                        'is not scalar, connection errors will '
                        'result.'.format(kind, var, phase.name))
示例#15
0
class TestOptionsDict(unittest.TestCase):
    """Tests for declaring, validating, updating, and reading OptionsDictionary entries."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_type_checking(self):
        self.dict.declare('test', types=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Option 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_allow_none(self):
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        with self.assertRaises(Exception) as context:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(context.exception),
                         "'types' and 'values' were both specified for option 'test3'.")

    def test_isvalid(self):
        self.dict.declare('even_test', types=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_isvalid_deprecated_type(self):

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
            self.assertEqual(len(w), 1)
            self.assertEqual(str(w[-1].message),
                             "In declaration of option 'even_test' the '_type' arg is "
                             "deprecated.  Use 'types' instead.")

        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        self.dict.declare('test', default='Test value', types=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Option 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, types=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: '\[' is a regex escape, not a (invalid) Python string escape.
        expected_msg = (r"Option 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for option 'x'")
        assertRegex(self, str(context.exception), expected_msg)

    def test_undeclare(self):
        # create an entry in the dict
        self.dict.declare('test', types=int)
        self.dict['test'] = 1

        # prove it's in the dict
        self.assertEqual(self.dict['test'], 1)

        # remove entry from the dict
        self.dict.undeclare("test")

        # prove it is no longer in the dict
        with self.assertRaises(KeyError) as context:
            self.dict['test']

        expected_msg = "\"Option 'test' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))
    def __init__(self, **kwargs):
        """
        Declare the sensitivity-analysis options and install the matching SALib DOE generator.

        Parameters
        ----------
        **kwargs
            Driver option values, applied after the SALib-specific options are declared.

        Raises
        ------
        RuntimeError
            If the SALib library is not installed, or if ``sa_method_name`` is not a
            supported method.
        """
        super(SalibDOEDriver, self).__init__()

        if SALIB_NOT_INSTALLED:
            # Implicitly concatenated literals keep the message free of the source
            # indentation that a backslash continuation inside the string would embed.
            raise RuntimeError(
                "SALib library is not installed. "
                "cf. https://salib.readthedocs.io/en/latest/getting-started.html")

        self.options.declare(
            "sa_method_name",
            default="Morris",
            values=["Morris", "Sobol"],
            desc="either Morris or Sobol",
        )
        self.options.declare(
            "sa_doe_options",
            types=dict,
            default={},
            desc="options for given SMT sensitivity analysis method",
        )
        self.options.update(kwargs)

        # Method-specific settings, filled from the user-supplied sa_doe_options dict.
        self.sa_settings = OptionsDictionary()
        method_name = self.options["sa_method_name"]
        if method_name == "Morris":
            self.sa_settings.declare(
                "n_trajs",
                types=int,
                default=2,
                desc="number of trajectories to apply morris method",
            )
            self.sa_settings.declare("n_levels",
                                     types=int,
                                     default=4,
                                     desc="number of grid levels")
            self.sa_settings.update(self.options["sa_doe_options"])
            self.options["generator"] = SalibMorrisDOEGenerator(
                self.sa_settings["n_trajs"], self.sa_settings["n_levels"])
        elif method_name == "Sobol":
            self.sa_settings.declare(
                "n_samples",
                types=int,
                default=500,
                desc="number of samples to generate",
            )
            self.sa_settings.declare(
                "calc_second_order",
                types=bool,
                default=True,
                desc="calculate second-order sensitivities ",
            )
            self.sa_settings.update(self.options["sa_doe_options"])
            self.options["generator"] = SalibSobolDOEGenerator(
                self.sa_settings["n_samples"], self.sa_settings["calc_second_order"])
        else:
            # Defensive: the 'values' of sa_method_name should already reject other names.
            raise RuntimeError(
                "Bad sensitivity analysis method name '{}'".format(method_name))
示例#17
0
class TestOptionsDict(unittest.TestCase):
    """Tests for the 'type_'-based OptionsDictionary API (entry-style messages)."""

    def setUp(self):
        self.dict = OptionsDictionary()

    def test_type_checking(self):
        self.dict.declare('test', type_=int, desc='Test integer value')

        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as context:
            self.dict['test'] = ''

        class_or_type = 'class' if PY3 else 'type'
        expected_msg = "Entry 'test' has the wrong type (<{} 'int'>)".format(class_or_type)
        self.assertEqual(expected_msg, str(context.exception))

        # make sure bools work
        self.dict.declare('flag', default=False, type_=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

    def test_type_and_values(self):
        # Test with only type_
        self.dict.declare('test1', type_=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Test with only values
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Test with both type_ and values
        self.dict.declare('test3', type_=int, values=['a', 'b'])
        self.dict['test3'] = 1
        self.assertEqual(self.dict['test3'], 1)
        self.dict['test3'] = 'a'
        self.assertEqual(self.dict['test3'], 'a')

    def test_isvalid(self):
        self.dict.declare('even_test', type_=int, is_valid=lambda x: x % 2 == 0)
        self.dict['even_test'] = 2
        self.dict['even_test'] = 4

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        expected_msg = "Function is_valid returns False for {}.".format('even_test')
        self.assertEqual(expected_msg, str(context.exception))

    def test_unnamed_args(self):
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_contains(self):
        self.dict.declare('test')

        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        self.dict.declare('test', default='Test value', type_=object)

        obj = object()
        self.dict.update({'test': obj})
        self.assertIs(self.dict['test'], obj)

    def test_update_extra(self):
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # KeyError ends up with an extra set of quotes.
        expected_msg = "\"Key 'test' cannot be set because it has not been declared.\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_missing(self):
        with self.assertRaises(KeyError) as context:
            self.dict['missing']

        expected_msg = "\"Entry 'missing' cannot be found\""
        self.assertEqual(expected_msg, str(context.exception))

    def test_get_default(self):
        obj_def = object()
        obj_new = object()

        self.dict.declare('test', default=obj_def, type_=object)

        self.assertIs(self.dict['test'], obj_def)

        self.dict['test'] = obj_new
        self.assertIs(self.dict['test'], obj_new)

    def test_values(self):
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: '\[' is a regex escape, not a (invalid) Python string escape.
        expected_msg = (r"Entry 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        opt = OptionsDictionary(read_only=True)
        opt.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as context:
            opt['permanent'] = 4.0

        expected_msg = ("Tried to set 'permanent' on a read-only OptionsDictionary")
        assertRegex(self, str(context.exception), expected_msg)

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)
示例#18
0
from openmdao.api import OptionsDictionary

# Dymos-wide settings, each declared with its default and its allowed types/values.
options = OptionsDictionary()

# Whether dymos components take part in partial-derivative checks.
options.declare('include_check_partials',
                default=True,
                types=bool,
                desc='If True, include dymos components when checking partials.')

# Plotting backend used for generated output plots.
options.declare('plots',
                default='matplotlib',
                values=['matplotlib', 'bokeh'],
                desc='The plot library used to generate output plots for Dymos.')

# Notebook-enhanced plots and outputs toggle.
options.declare(
    'notebook_mode',
    default=False,
    types=bool,
    desc='If True, provide notebook-enhanced plots and outputs.',
)
class TestOptionsDict(unittest.TestCase):
    def setUp(self):
        """Give every test its own freshly constructed OptionsDictionary."""
        self.dict = OptionsDictionary()

    def test_reprs(self):
        """
        Check repr() and the table rendered by ``__str__(width=...)``.

        The expected tables list options sorted by name (comp, flag, long_desc, test)
        even though they are declared in a different order.
        """
        class MyComp(ExplicitComponent):
            pass

        my_comp = MyComp()

        self.dict.declare('test', values=['a', 'b'], desc='Test integer value')
        self.dict.declare('flag', default=False, types=bool)
        self.dict.declare('comp', default=my_comp, types=ExplicitComponent)
        self.dict.declare('long_desc',
                          types=str,
                          desc='This description is long and verbose, so it '
                          'takes up multiple lines in the options table.')

        # repr() of the dictionary is the repr of its underlying _dict.
        self.assertEqual(repr(self.dict), repr(self.dict._dict))

        # With width=83 the Description column is wrapped over several rows.
        self.assertEqual(
            self.dict.__str__(width=83), '\n'.join([
                "========= ============ ================= ===================== ====================",
                "Option    Default      Acceptable Values Acceptable Types      Description         ",
                "========= ============ ================= ===================== ====================",
                "comp      MyComp       N/A               ['ExplicitComponent']                     ",
                "flag      False        [True, False]     ['bool']                                  ",
                "long_desc **Required** N/A               ['str']               This description is ",
                "                                                               long and verbose, so",
                "                                                                it takes up multipl",
                "                                                               e lines in the optio",
                "                                                               ns table.",
                "test      **Required** ['a', 'b']        N/A                   Test integer value  ",
                "========= ============ ================= ===================== ====================",
            ]))

        # if the table can't be represented in specified width, then we get the full width version
        self.assertEqual(
            self.dict.__str__(width=40), '\n'.join([
                "========= ============ ================= ===================== ====================="
                "==================================================================== ",
                "Option    Default      Acceptable Values Acceptable Types      Description          "
                "                                                                     ",
                "========= ============ ================= ===================== ====================="
                "==================================================================== ",
                "comp      MyComp       N/A               ['ExplicitComponent']                      "
                "                                                                     ",
                "flag      False        [True, False]     ['bool']                                   "
                "                                                                     ",
                "long_desc **Required** N/A               ['str']               This description is l"
                "ong and verbose, so it takes up multiple lines in the options table. ",
                "test      **Required** ['a', 'b']        N/A                   Test integer value   "
                "                                                                     ",
                "========= ============ ================= ===================== ====================="
                "==================================================================== ",
            ]))

    def test_type_checking(self):
        """Values must match the declared type(s); bool options get (True, False) as values."""
        # A single allowed type.
        self.dict.declare('test', types=int, desc='Test integer value')
        self.dict['test'] = 1
        self.assertEqual(self.dict['test'], 1)

        with self.assertRaises(TypeError) as ctx:
            self.dict['test'] = ''
        self.assertEqual("Value ('') of option 'test' has type 'str', "
                         "but type 'int' was expected.", str(ctx.exception))

        # Several allowed types: each accepted value is stored with its own type.
        self.dict.declare('test_multi',
                          types=(int, float),
                          desc='Test multiple types')
        for value, value_type in ((1, int), (1.0, float)):
            self.dict['test_multi'] = value
            self.assertEqual(self.dict['test_multi'], value)
            self.assertEqual(type(self.dict['test_multi']), value_type)

        with self.assertRaises(TypeError) as ctx:
            self.dict['test_multi'] = ''
        self.assertEqual("Value ('') of option 'test_multi' has type 'str', "
                         "but one of types ('int', 'float') was expected.",
                         str(ctx.exception))

        # Bools work, and their allowed values are filled in automatically.
        self.dict.declare('flag', default=False, types=bool)
        self.assertEqual(self.dict['flag'], False)
        self.dict['flag'] = True
        self.assertEqual(self.dict['flag'], True)

        self.assertEqual(self.dict._dict['flag']['values'], (True, False))

    def test_allow_none(self):
        """An int option declared with allow_none=True accepts None."""
        self.dict.declare('test', desc='Test integer value', types=int, allow_none=True)
        self.dict['test'] = None
        self.assertEqual(self.dict['test'], None)

    def test_type_and_values(self):
        """Either 'types' or 'values' may constrain an option, but not both at once."""
        # Constrained by type only.
        self.dict.declare('test1', types=int)
        self.dict['test1'] = 1
        self.assertEqual(self.dict['test1'], 1)

        # Constrained by allowed values only.
        self.dict.declare('test2', values=['a', 'b'])
        self.dict['test2'] = 'a'
        self.assertEqual(self.dict['test2'], 'a')

        # Both at once is rejected when declaring.
        expected = "'types' and 'values' were both specified for option 'test3'."
        with self.assertRaises(Exception) as ctx:
            self.dict.declare('test3', types=int, values=['a', 'b'])
        self.assertEqual(str(ctx.exception), expected)

    def test_isvalid(self):
        """A check_valid callback accepts even values and rejects 3 with a ValueError."""
        self.dict.declare('even_test', types=int, check_valid=check_even)
        for even in (2, 4):
            self.dict['even_test'] = even

        with self.assertRaises(ValueError) as ctx:
            self.dict['even_test'] = 3

        self.assertEqual("Option 'even_test' with value 3 is not an even number.",
                         str(ctx.exception))

    def test_isvalid_deprecated_type(self):
        """Declaring with the deprecated 'type_' argument warns but still validates."""
        deprecation = ("In declaration of option 'even_test' the '_type' arg is deprecated.  "
                       "Use 'types' instead.")
        with assert_warning(DeprecationWarning, deprecation):
            self.dict.declare('even_test', type_=int, check_valid=check_even)

        for even in (2, 4):
            self.dict['even_test'] = even

        with self.assertRaises(ValueError) as ctx:
            self.dict['even_test'] = 3

        self.assertEqual("Option 'even_test' with value 3 is not an even number.",
                         str(ctx.exception))

    def test_unnamed_args(self):
        """Assigning to an undeclared option raises KeyError rather than creating it."""
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # str() of a KeyError wraps its message in an extra pair of quotes.
        self.assertEqual("\"Option 'test' cannot be set because it has not been declared.\"",
                         str(context.exception))

    def test_contains(self):
        """The ``in`` operator reports only options that have been declared."""
        self.dict.declare('test')

        # assertNotIn/assertIn name the member and container on failure,
        # unlike assertTrue(not ...), which only reports "False is not true".
        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        """update() assigns declared options from a mapping, preserving object identity."""
        self.dict.declare('test', default='Test value', types=object)

        sentinel = object()
        self.dict.update({'test': sentinel})
        self.assertIs(self.dict['test'], sentinel)

    def test_update_extra(self):
        """update() with an undeclared key raises KeyError like direct assignment does."""
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # str() of a KeyError wraps its message in an extra pair of quotes.
        self.assertEqual("\"Option 'test' cannot be set because it has not been declared.\"",
                         str(context.exception))

    def test_get_missing(self):
        """Reading an option that was never declared raises KeyError."""
        with self.assertRaises(KeyError) as context:
            _ = self.dict['missing']

        self.assertEqual("\"Option 'missing' cannot be found\"",
                         str(context.exception))

    def test_get_default(self):
        """An unset option yields its default object; assignment replaces it."""
        default_obj, replacement_obj = object(), object()

        self.dict.declare('test', default=default_obj, types=object)
        self.assertIs(self.dict['test'], default_obj)

        self.dict['test'] = replacement_obj
        self.assertIs(self.dict['test'], replacement_obj)

    def test_values(self):
        """Only the exact objects listed in ``values`` may be assigned."""
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: '\(', '\[' and '\]' are unrecognized escapes in ordinary
        # literals (SyntaxWarning since 3.12, to become an error). The final '\.'
        # matches the literal trailing period instead of any character.
        expected_msg = (
            r"Value \(<object object at 0x[0-9A-Fa-f]+>\) of option 'test' is not one of "
            r"\[<object object at 0x[0-9A-Fa-f]+>, <object object at 0x[0-9A-Fa-f]+>\]\.")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        """A read-only OptionsDictionary refuses assignment to declared options."""
        frozen_opts = OptionsDictionary(read_only=True)
        frozen_opts.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as ctx:
            frozen_opts['permanent'] = 4.0

        assertRegex(self, str(ctx.exception),
                    "Tried to set read-only option 'permanent'.")

    def test_bounds(self):
        """Values outside [lower, upper] are rejected with a descriptive ValueError."""
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        cases = [
            (3.0, "Value (3.0) of option 'x' exceeds maximum allowed value of 2.0."),
            (-3.0, "Value (-3.0) of option 'x' is less than minimum allowed value of 0.0."),
        ]
        for bad_value, expected_msg in cases:
            with self.assertRaises(ValueError) as context:
                self.dict['x'] = bad_value
            self.assertEqual(str(context.exception), expected_msg)

    def test_undeclare(self):
        """undeclare() removes a previously declared and assigned option."""
        opts = self.dict
        opts.declare('test', types=int)
        opts['test'] = 1
        self.assertEqual(opts['test'], 1)

        opts.undeclare("test")

        # After removal, lookups fail exactly as for a never-declared option.
        with self.assertRaises(KeyError) as context:
            opts['test']

        self.assertEqual("\"Option 'test' cannot be found\"", str(context.exception))
示例#20
0
文件: options.py 项目: wright/dymos
from openmdao.api import OptionsDictionary

options = OptionsDictionary()

# Module-level switch: whether dymos components are included when checking partials.
options.declare(
    'include_check_partials',
    types=bool,
    default=False,
    desc='If True, include dymos components when checking partials.',
)
 def setUp(self):
     """Give every test a fresh, empty OptionsDictionary."""
     fresh_options = OptionsDictionary()
     self.dict = fresh_options
示例#22
0
    def add_linkage(self,
                    name,
                    vars,
                    shape=(1, ),
                    equals=None,
                    lower=None,
                    upper=None,
                    units=None,
                    scaler=None,
                    adder=None,
                    ref0=None,
                    ref=None,
                    linear=False):
        """
        Add a linkage constraint to be managed by this component.

        .. math ::

            C_n = y_{n1} - y_{n0}

        where :math:`y_1` is the value of the variable at the beginning or end of phase 1,
        and :math:`y_0` is the value of the variable at the beginning or end of phase 0.
        The location of the source of the constraint can be set by the user based on
        connected indices.

        The name of each linkage constraint will be LNK_var where LNK is the name of the linkage
        and var are the vars in that linkage.

        Parameters
        ----------
        name : str
            The name of one or more linkage constraints to be added.
        vars : str or iterable
            The name of one or more linked variables to be added.  Any iterable
            (including a one-shot generator) is accepted.
        shape : tuple
            The shape of the constraint being formed.  Must be compliant with the shape
            of the variable.
        equals : float or ndarray
            The prescribed difference of y_1 - y_0, enforced by the optimizer.  If equals,
            lower, and upper are all None, each linkage defaults to equals = zeros(shape).
        lower : float or ndarray
            The minimum allowable difference of y_1 - y_0, enforced by the optimizer.
        upper : float or ndarray
            The maximum allowable difference of y_1 - y_0, enforced by the optimizer.
        units : str, dict, or None
            The units of the linkage constraint.  If given as a string, the units will
            apply to each variable in vars.  If given as a dict, it should be keyed
            with variables in var, and the associated value being the corresponding units.
            Default is None.
        scaler : float, ndarray, or None
            The scaler applied to this constraint by the driver.
        adder : float, ndarray, or None
            The adder applied to this constraint by the driver.
        ref0 : float, ndarray, or None
            The zero-reference value of this constraint, used for scaling by the driver.
        ref : float, ndarray, or None
            The one-reference value of this constraint, used for scaling by the driver.
        linear : bool
            If True, this constraint will be treated as a linear constraint by the optimizer.
            This should only be done if the *total derivative* of the constraint is linear.
            That is, the affected variables in each phase are design
            variables or linear functions of design variables.  Default is False.
        """
        # With no bound of any kind, each linkage becomes the equality y_1 - y_0 == 0.
        use_default_equals = equals is None and lower is None and upper is None

        # Materialize the variable names: they are iterated twice below, and a
        # generator would otherwise be exhausted by the first pass.
        if isinstance(vars, string_types):
            _vars = (vars, )
        else:
            _vars = tuple(vars)

        # A single units string (or None) applies to every var; otherwise expect a mapping.
        if isinstance(units, string_types) or units is None:
            _units = {var: units for var in _vars}
        else:
            _units = units

        for var in _vars:

            lnk = OptionsDictionary()

            lnk.declare('name', types=(string_types, ))
            lnk.declare('equals', types=(float, np.ndarray), allow_none=True)
            lnk.declare('lower', types=(float, np.ndarray), allow_none=True)
            lnk.declare('upper', types=(float, np.ndarray), allow_none=True)
            lnk.declare('units', types=string_types, allow_none=True)
            lnk.declare('scaler', types=(float, np.ndarray), allow_none=True)
            lnk.declare('adder', types=(float, np.ndarray), allow_none=True)
            lnk.declare('ref0', types=(float, np.ndarray), allow_none=True)
            lnk.declare('ref', types=(float, np.ndarray), allow_none=True)
            lnk.declare('linear', types=bool)
            lnk.declare('shape', types=tuple)
            lnk.declare('cond0_name', types=string_types)
            lnk.declare('cond1_name', types=string_types)

            lnk['name'] = '{0}_{1}'.format(name, var)
            # Allocate a separate default array per linkage so that no two linkages
            # share, and could accidentally mutate, the same ndarray.
            lnk['equals'] = np.zeros(shape) if use_default_equals else equals
            lnk['lower'] = lower
            lnk['upper'] = upper
            lnk['scaler'] = scaler
            lnk['adder'] = adder
            lnk['ref0'] = ref0
            lnk['ref'] = ref
            lnk['shape'] = shape
            lnk['linear'] = linear
            lnk['units'] = _units.get(var, None)
            lnk['cond0_name'] = '{0}:lhs'.format(lnk['name'])
            lnk['cond1_name'] = '{0}:rhs'.format(lnk['name'])

            self.options['linkages'].append(lnk)
示例#23
0
    def test_options_dictionary(self):
        """Exercise the legacy add_option API: lookup, typing, enums, bounds, attr access."""
        self.options = OptionsDictionary()
        opts = self.options
        # repr prefix of a builtin type differs between Python 2 and Python 3.
        class_or_type = 'type' if PY2 else 'class'

        # Keys that were never added cannot be read.
        with self.assertRaises(KeyError) as cm:
            opts['junk']

        self.assertEqual('"Option \'{}\' has not been added"'.format('junk'),
                         str(cm.exception))

        # A float option must not accept an int.
        opts.add_option('atol', 1e-6)
        self.assertEqual(opts['atol'], 1.0e-6)

        with self.assertRaises(ValueError) as cm:
            opts['atol'] = 1

        self.assertEqual("'atol' should be a '<{} 'float'>'".format(class_or_type),
                         str(cm.exception))

        # Enumerated values: every listed value is accepted, anything else is not.
        opts.add_option('iprint', 0, values=[0, 1, 2, 3])
        for allowed in [0, 1, 2, 3]:
            opts['iprint'] = allowed

        with self.assertRaises(ValueError) as cm:
            opts['iprint'] = 4

        self.assertEqual(
            "'iprint' must be one of the following values: '[0, 1, 2, 3]'",
            str(cm.exception))

        # A bool option accepts True/False but not an arbitrary string.
        opts.add_option('conmin_diff', True)
        for flag in (True, False):
            opts['conmin_diff'] = flag

        with self.assertRaises(ValueError) as cm:
            opts['conmin_diff'] = "YES!"

        self.assertEqual("'conmin_diff' should be a '<{} 'bool'>'".format(class_or_type),
                         str(cm.exception))

        # Lower/upper bounds are inclusive.
        opts.add_option('maxiter', 10, lower=0, upper=10)
        for in_range in range(0, 11):
            opts['maxiter'] = in_range

        with self.assertRaises(ValueError) as cm:
            opts['maxiter'] = 15

        self.assertEqual("maximum allowed value for 'maxiter' is '10'",
                         str(cm.exception))

        with self.assertRaises(ValueError) as cm:
            opts['maxiter'] = -1

        self.assertEqual("minimum allowed value for 'maxiter' is '0'",
                         str(cm.exception))

        # Attribute-style assignment is rejected; only item access is allowed.
        with self.assertRaises(ValueError) as cm:
            opts.maxiter = -1

        self.assertEqual("Use dict-like access for option 'maxiter'",
                         str(cm.exception))
示例#24
0
class TestOptionsDict(unittest.TestCase):

    def setUp(self):
        """Start every test from a fresh, empty OptionsDictionary."""
        fresh_options = OptionsDictionary()
        self.dict = fresh_options

    def test_type_checking(self):
        """Assignments are checked against the declared type; bools are supported."""
        opts = self.dict
        opts.declare('test', types=int, desc='Test integer value')

        opts['test'] = 1
        self.assertEqual(opts['test'], 1)

        with self.assertRaises(TypeError) as context:
            opts['test'] = ''

        # repr prefix of a builtin type differs between Python 2 and Python 3.
        kind = 'class' if PY3 else 'type'
        self.assertEqual("Entry 'test' has the wrong type (<{} 'int'>)".format(kind),
                         str(context.exception))

        # A bool option starts at its default and can be flipped.
        opts.declare('flag', default=False, types=bool)
        self.assertEqual(opts['flag'], False)
        opts['flag'] = True
        self.assertEqual(opts['flag'], True)

    def test_allow_none(self):
        """An option declared with allow_none=True accepts None despite its int type."""
        self.dict.declare('test', types=int, allow_none=True, desc='Test integer value')
        self.dict['test'] = None
        # assertIsNone checks identity with None rather than equality.
        self.assertIsNone(self.dict['test'])

    def test_type_and_values(self):
        """'types' and 'values' each work on their own but may not be combined."""
        opts = self.dict

        # Declared with a type constraint only.
        opts.declare('test1', types=int)
        opts['test1'] = 1
        self.assertEqual(opts['test1'], 1)

        # Declared with an allowed-values list only.
        opts.declare('test2', values=['a', 'b'])
        opts['test2'] = 'a'
        self.assertEqual(opts['test2'], 'a')

        # Declaring both at once is rejected at declaration time.
        with self.assertRaises(Exception) as context:
            opts.declare('test3', types=int, values=['a', 'b'])
        msg = "'types' and 'values' were both specified for option 'test3'."
        self.assertEqual(str(context.exception), msg)

    def test_isvalid(self):
        """The is_valid predicate runs on every assignment and can veto a value."""
        self.dict.declare('even_test', types=int, is_valid=lambda val: val % 2 == 0)
        for ok_value in (2, 4):
            self.dict['even_test'] = ok_value

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        self.assertEqual("Function is_valid returns False for {}.".format('even_test'),
                         str(context.exception))

    def test_isvalid_deprecated_type(self):
        """The deprecated ``type_`` keyword emits one warning but still constrains values."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            self.dict.declare('even_test', type_=int, is_valid=lambda val: val % 2 == 0)
            self.assertEqual(len(caught), 1)
            self.assertEqual(str(caught[-1].message), "In declaration of option 'even_test' the '_type' arg is deprecated.  Use 'types' instead.")

        for ok_value in (2, 4):
            self.dict['even_test'] = ok_value

        with self.assertRaises(ValueError) as context:
            self.dict['even_test'] = 3

        self.assertEqual("Function is_valid returns False for {}.".format('even_test'),
                         str(context.exception))

    def test_unnamed_args(self):
        """Assigning to an undeclared key raises KeyError rather than creating it."""
        with self.assertRaises(KeyError) as context:
            self.dict['test'] = 1

        # str() of a KeyError wraps its message in an extra pair of quotes.
        self.assertEqual("\"Key 'test' cannot be set because it has not been declared.\"",
                         str(context.exception))

    def test_contains(self):
        """The ``in`` operator reports only keys that have been declared."""
        self.dict.declare('test')

        # assertNotIn/assertIn name the member and container on failure,
        # unlike assertTrue(not ...), which only reports "False is not true".
        self.assertNotIn('undeclared', self.dict)
        self.assertIn('test', self.dict)

    def test_update(self):
        """update() assigns declared keys from a mapping, preserving object identity."""
        self.dict.declare('test', default='Test value', types=object)

        sentinel = object()
        self.dict.update({'test': sentinel})
        self.assertIs(self.dict['test'], sentinel)

    def test_update_extra(self):
        """update() with an undeclared key raises KeyError like direct assignment does."""
        with self.assertRaises(KeyError) as context:
            self.dict.update({'test': 2})

        # str() of a KeyError wraps its message in an extra pair of quotes.
        self.assertEqual("\"Key 'test' cannot be set because it has not been declared.\"",
                         str(context.exception))

    def test_get_missing(self):
        """Reading a key that was never declared raises KeyError."""
        with self.assertRaises(KeyError) as context:
            _ = self.dict['missing']

        self.assertEqual("\"Entry 'missing' cannot be found\"",
                         str(context.exception))

    def test_get_default(self):
        """An unset entry yields its default object; assignment replaces it."""
        default_obj, replacement_obj = object(), object()

        self.dict.declare('test', default=default_obj, types=object)
        self.assertIs(self.dict['test'], default_obj)

        self.dict['test'] = replacement_obj
        self.assertIs(self.dict['test'], replacement_obj)

    def test_values(self):
        """Only the exact objects listed in ``values`` may be assigned."""
        obj1 = object()
        obj2 = object()
        self.dict.declare('test', values=[obj1, obj2])

        self.dict['test'] = obj1
        self.assertIs(self.dict['test'], obj1)

        with self.assertRaises(ValueError) as context:
            self.dict['test'] = object()

        # Raw strings: '\[' and '\]' are unrecognized escapes in ordinary literals
        # (SyntaxWarning since 3.12, to become an error); the pattern is unchanged.
        expected_msg = (r"Entry 'test''s value is not one of \[<object object at 0x[0-9A-Fa-f]+>,"
                        r" <object object at 0x[0-9A-Fa-f]+>\]")
        assertRegex(self, str(context.exception), expected_msg)

    def test_read_only(self):
        """A read-only OptionsDictionary refuses assignment to declared entries."""
        frozen_opts = OptionsDictionary(read_only=True)
        frozen_opts.declare('permanent', 3.0)

        with self.assertRaises(KeyError) as ctx:
            frozen_opts['permanent'] = 4.0

        assertRegex(self, str(ctx.exception),
                    "Tried to set 'permanent' on a read-only OptionsDictionary")

    def test_bounds(self):
        self.dict.declare('x', default=1.0, lower=0.0, upper=2.0)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = 3.0

        expected_msg = ("Value of 3.0 exceeds maximum of 2.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)

        with self.assertRaises(ValueError) as context:
            self.dict['x'] = -3.0

        expected_msg = ("Value of -3.0 exceeds minimum of 0.0 for entry 'x'")
        assertRegex(self, str(context.exception), expected_msg)