예제 #1
0
    def testSetHParam(self):
        # set_hparam() is expected to replace each value given to the
        # constructor, visible both through values() and attribute access.
        hp = HParams(aaa=1, b=2.0, c_c='relu6', d=True)
        initial = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
        self.assertDictEqual(initial, hp.values())
        self.assertEqual(1, hp.aaa)
        self.assertEqual(2.0, hp.b)
        self.assertEqual('relu6', hp.c_c)

        # Overwrite every hyperparameter, one call per name.
        replacements = (('aaa', 12), ('b', 3.0), ('c_c', 'relu4'), ('d', False))
        for name, value in replacements:
            hp.set_hparam(name, value)

        updated = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
        self.assertDictEqual(updated, hp.values())
        self.assertEqual(12, hp.aaa)
        self.assertEqual(3.0, hp.b)
        self.assertEqual('relu4', hp.c_c)
예제 #2
0
 def testEmpty(self):
     # An HParams built without arguments is expected to expose no values,
     # to stay empty after parsing '', and to reject any name with a
     # ValueError matching 'Unknown hyperparameter'.
     empty = HParams()
     self.assertDictEqual({}, empty.values())

     empty.parse('')
     self.assertDictEqual({}, empty.values())

     unknown_name = self.assertRaisesRegexp(ValueError,
                                            'Unknown hyperparameter')
     with unknown_name:
         empty.parse('xyz=123')
예제 #3
0
 def testLists(self):
     # List-valued hyperparameters: checks the constructor values, a series
     # of parse() overrides, and the inputs parse() is expected to reject.
     hp = HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
     initial = {'aaa': [1], 'b': [2.0, 3.0], 'c_c': ['relu6']}
     self.assertDictEqual(initial, hp.values())
     self.assertEqual([1], hp.aaa)
     self.assertEqual([2.0, 3.0], hp.b)
     self.assertEqual(['relu6'], hp.c_c)

     # Overrides of a single name.
     hp.parse('aaa=[12]')
     self.assertEqual([12], hp.aaa)
     hp.parse('aaa=[12,34,56]')
     self.assertEqual([12, 34, 56], hp.aaa)

     # Overrides of two names in one spec string.
     hp.parse('c_c=[relu4,relu12],b=[1.0]')
     self.assertEqual(['relu4', 'relu12'], hp.c_c)
     self.assertEqual([1.0], hp.b)

     hp.parse('c_c=[],aaa=[-34]')
     self.assertEqual([-34], hp.aaa)
     self.assertEqual([], hp.c_c)

     hp.parse('c_c=[_12,3\'4"],aaa=[+3]')
     self.assertEqual([3], hp.aaa)
     self.assertEqual(['_12', '3\'4"'], hp.c_c)

     # Each spec below must raise ValueError with a message matching the
     # paired pattern; the order is the order they are attempted in.
     rejected = (
         ('Unknown hyperparameter', 'x=[123]'),
         ('Could not parse', 'aaa=[poipoi]'),
         ('Could not parse', 'aaa=[1.0]'),
         ('Could not parse', 'b=[12x]'),
         ('Could not parse', 'b=[relu]'),
         ('Must pass a list', 'aaa=123'),
     )
     for pattern, spec in rejected:
         with self.assertRaisesRegexp(ValueError, pattern):
             hp.parse(spec)
예제 #4
0
    def testSetFromMap(self):
        # override_from_dict() is expected to replace only the names present
        # in the supplied dict and leave the remaining values untouched.
        scalars = HParams(a=1, b=2.0, c='tanh')
        scalars.override_from_dict({'a': -2, 'c': 'identity'})
        expected_scalars = {'a': -2, 'b': 2.0, 'c': 'identity'}
        self.assertDictEqual(expected_scalars, scalars.values())

        # Same check with a list-valued hyperparameter.
        with_list = HParams(x=1, b=2.0, d=[0.5])
        with_list.override_from_dict({'d': [0.1, 0.2, 0.3]})
        expected_with_list = {'x': 1, 'b': 2.0, 'd': [0.1, 0.2, 0.3]}
        self.assertDictEqual(expected_with_list, with_list.values())
예제 #5
0
def copy_hparams(hparams: HParams):
    """Return a new HParams constructed from ``hparams.values()``.

    The attributes ``problem`` and ``problem_hparams`` are not passed to the
    constructor. When one of them is present on ``hparams`` and is not
    ``None``, the same object is assigned onto the new instance with
    ``setattr``.
    """
    duplicate = HParams(**hparams.values())
    # These attributes must also be carried over onto the new HParams.
    for extra in ("problem", "problem_hparams"):
        carried = getattr(hparams, extra, None)
        if carried is None:
            continue
        setattr(duplicate, extra, carried)
    return duplicate
예제 #6
0
    def testJson(self):
        # Covers parse_json() overriding values, a to_json()/parse_json()
        # round trip into a second object, and to_json() formatting options.
        first = HParams(aaa=1, b=2.0, c_c='relu6', d=True)
        self.assertDictEqual(
            {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}, first.values())
        self.assertEqual(1, first.aaa)
        self.assertEqual(2.0, first.b)
        self.assertEqual('relu6', first.c_c)

        first.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
        self.assertDictEqual(
            {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}, first.values())
        self.assertEqual(12, first.aaa)
        self.assertEqual(3.0, first.b)
        self.assertEqual('relu4', first.c_c)

        # Serialize, then load into an object that starts from other values.
        serialized = first.to_json()
        second = HParams(aaa=10, b=20.0, c_c='hello', d=False)
        second.parse_json(serialized)
        self.assertEqual(12, second.aaa)
        self.assertEqual(3.0, second.b)
        self.assertEqual('relu4', second.c_c)
        self.assertEqual(False, second.d)

        # Keyword arguments of to_json(): default, indent and separators.
        single = HParams(aaa=123)
        self.assertEqual('{"aaa": 123}', single.to_json())
        self.assertEqual('{\n  "aaa": 123\n}', single.to_json(indent=2))
        self.assertEqual('{"aaa"=123}',
                         single.to_json(separators=(';', '=')))

        # sort_keys=True is expected to order the names alphabetically.
        mixed = HParams(aaa=123, b='hello', c_c=False)
        self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                         mixed.to_json(sort_keys=True))
예제 #7
0
def create_hparams_from_json(json_path, hparams=None):
    """Load hyperparameters from a JSON file, optionally merging into hparams.

    Args:
        json_path: Path of the JSON file; it is opened with ``tf.gfile.Open``.
        hparams: Optional existing object. When it is truthy, every name from
            the file that already exists on it as an attribute and holds a
            different value is overwritten in place (each overwrite is
            logged). Names from the file that it lacks are ignored.

    Returns:
        ``hparams`` itself when it was supplied and truthy, otherwise a new
        ``HParams`` built from the JSON contents.
    """
    tf.logging.info("Loading hparams from existing json %s" % json_path)
    with tf.gfile.Open(json_path, "r") as json_file:
        loaded = HParams(**json.load(json_file))
        if not hparams:
            # Nothing to merge into: the freshly loaded object is the result.
            hparams = loaded
        else:
            # The file may hold names that hparams does not have, so copy
            # name by name instead of relying on HParams.parse_json().
            for name in sorted(loaded.values()):
                if not hasattr(hparams, name):
                    continue
                current = getattr(hparams, name)
                incoming = getattr(loaded, name)
                if current != incoming:
                    tf.logging.info("Overwrite key %s: %s -> %s" %
                                    (name, current, incoming))
                    setattr(hparams, name, incoming)

    return hparams
예제 #8
0
 def testSomeValues(self):
     # Scalar hyperparameters: checks the constructor values, the string
     # form, a sequence of parse() overrides, and the inputs parse() is
     # expected to reject without changing the stored values.
     hp = HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')

     def check_attrs(aaa, b, c_c, d):
         # Assert all four attributes, always in this order.
         self.assertEqual(aaa, hp.aaa)
         self.assertEqual(b, hp.b)
         self.assertEqual(c_c, hp.c_c)
         self.assertEqual(d, hp.d)

     self.assertDictEqual(
         {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
         hp.values())
     # The string form is expected to look like a dict repr rather than a
     # list of (name, value) pairs.
     expected_str = ('{\'aaa\': 1, \'b\': 2.0, \'c_c\': \'relu6\', '
                     '\'d\': \'/a/b=c/d\'}')
     self.assertEqual(expected_str, str(hp.__str__()))
     self.assertEqual(expected_str, str(hp))
     check_attrs(1, 2.0, 'relu6', '/a/b=c/d')

     hp.parse('aaa=12')
     self.assertDictEqual(
         {'aaa': 12, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
         hp.values())
     check_attrs(12, 2.0, 'relu6', '/a/b=c/d')

     hp.parse('c_c=relu4, b=-2.0e10')
     self.assertDictEqual(
         {'aaa': 12, 'b': -2.0e10, 'c_c': 'relu4', 'd': '/a/b=c/d'},
         hp.values())
     check_attrs(12, -2.0e10, 'relu4', '/a/b=c/d')

     hp.parse('c_c=,b=0,')
     self.assertDictEqual(
         {'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'},
         hp.values())
     check_attrs(12, 0.0, '', '/a/b=c/d')

     # Only b and c_c are verified after this particular override.
     hp.parse('c_c=2.3",b=+2,')
     self.assertEqual(2.0, hp.b)
     self.assertEqual('2.3"', hp.c_c)

     hp.parse('d=/a/b/c/d,aaa=11,')
     check_attrs(11, 2.0, '2.3"', '/a/b/c/d')

     hp.parse('b=1.5,d=/a=b/c/d,aaa=10,')
     check_attrs(10, 1.5, '2.3"', '/a=b/c/d')

     # Each spec below must raise ValueError with a message matching the
     # paired pattern; the order is the order they are attempted in.
     rejected = (
         ('Unknown hyperparameter', 'x=123'),
         ('Could not parse', 'aaa=poipoi'),
         ('Could not parse', 'aaa=1.0'),
         ('Could not parse', 'b=12x'),
         ('Could not parse', 'b=relu'),
         ('Must not pass a list', 'aaa=[123]'),
     )
     for pattern, spec in rejected:
         with self.assertRaisesRegexp(ValueError, pattern):
             hp.parse(spec)

     # The rejected specs must have left the last good values in place.
     check_attrs(10, 1.5, '2.3"', '/a=b/c/d')