def testSetHParam(self):
    """set_hparam() replaces existing values.

    The new values must be visible through both values() and attribute access.
    """
    before = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
    after = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
    checked_attrs = ('aaa', 'b', 'c_c')

    hparams = HParams(**before)
    self.assertDictEqual(before, hparams.values())
    for attr_name in checked_attrs:
        self.assertEqual(before[attr_name], getattr(hparams, attr_name))

    # Apply every update in order: aaa, b, c_c, d.
    for attr_name, new_value in after.items():
        hparams.set_hparam(attr_name, new_value)

    self.assertDictEqual(after, hparams.values())
    for attr_name in checked_attrs:
        self.assertEqual(after[attr_name], getattr(hparams, attr_name))
def testEmpty(self):
    """An HParams with no hyperparameters.

    It parses '' as a no-op and rejects unknown names with a ValueError.
    """
    hparams = HParams()
    self.assertDictEqual({}, hparams.values())
    hparams.parse('')
    self.assertDictEqual({}, hparams.values())
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
        hparams.parse('xyz=123')
def testLists(self):
    """Tests list-valued hyperparameters.

    Covers construction, attribute access, parse() updates and parse()
    rejections.
    """
    hparams = HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
    self.assertDictEqual({
        'aaa': [1],
        'b': [2.0, 3.0],
        'c_c': ['relu6']
    }, hparams.values())
    self.assertEqual([1], hparams.aaa)
    self.assertEqual([2.0, 3.0], hparams.b)
    self.assertEqual(['relu6'], hparams.c_c)

    hparams.parse('aaa=[12]')
    self.assertEqual([12], hparams.aaa)
    hparams.parse('aaa=[12,34,56]')
    self.assertEqual([12, 34, 56], hparams.aaa)
    hparams.parse('c_c=[relu4,relu12],b=[1.0]')
    self.assertEqual(['relu4', 'relu12'], hparams.c_c)
    self.assertEqual([1.0], hparams.b)
    hparams.parse('c_c=[],aaa=[-34]')
    self.assertEqual([-34], hparams.aaa)
    self.assertEqual([], hparams.c_c)
    # A leading '+' is accepted for int elements, and string elements keep
    # quote characters verbatim.
    hparams.parse('c_c=[_12,3\'4"],aaa=[+3]')
    self.assertEqual([3], hparams.aaa)
    self.assertEqual(['_12', '3\'4"'], hparams.c_c)

    # Each invalid input must raise a ValueError whose message matches the
    # paired regex. assertRaisesRegex replaces the deprecated
    # assertRaisesRegexp alias, which was removed in Python 3.12.
    invalid_inputs = [
        ('x=[123]', 'Unknown hyperparameter'),
        ('aaa=[poipoi]', 'Could not parse'),
        ('aaa=[1.0]', 'Could not parse'),
        ('b=[12x]', 'Could not parse'),
        ('b=[relu]', 'Could not parse'),
        ('aaa=123', 'Must pass a list'),
    ]
    for bad_input, message in invalid_inputs:
        with self.subTest(bad_input=bad_input):
            with self.assertRaisesRegex(ValueError, message):
                hparams.parse(bad_input)
def testSetFromMap(self):
    """override_from_dict() replaces only the keys it is given.

    This holds for scalar values and for list values.
    """
    scalar_hparams = HParams(a=1, b=2.0, c='tanh')
    scalar_hparams.override_from_dict({'a': -2, 'c': 'identity'})
    expected_scalars = {'a': -2, 'b': 2.0, 'c': 'identity'}
    self.assertDictEqual(expected_scalars, scalar_hparams.values())

    list_hparams = HParams(x=1, b=2.0, d=[0.5])
    list_hparams.override_from_dict({'d': [0.1, 0.2, 0.3]})
    expected_lists = {'x': 1, 'b': 2.0, 'd': [0.1, 0.2, 0.3]}
    self.assertDictEqual(expected_lists, list_hparams.values())
def copy_hparams(hparams: HParams):
    """Return a new HParams built from ``hparams.values()``.

    The attributes ``problem`` and ``problem_hparams`` are also copied onto
    the new object when the source has them set to a non-None value. They
    are presumably not included in ``values()``, which is why they are
    carried over by hand; confirm against the HParams implementation.
    Values are passed through as-is; nothing is deep-copied.
    """
    duplicate = HParams(**hparams.values())
    for extra_name in ("problem", "problem_hparams"):
        extra_value = getattr(hparams, extra_name, None)
        if extra_value is None:
            # Missing or unset on the source: leave it off the copy as well.
            continue
        setattr(duplicate, extra_name, extra_value)
    return duplicate
def testJson(self):
    """Tests JSON round-tripping and to_json() formatting options.

    Values go through parse_json()/to_json(), then the indent, separators
    and sort_keys options of to_json() are checked.
    """
    hparams = HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    self.assertDictEqual({
        'aaa': 1,
        'b': 2.0,
        'c_c': 'relu6',
        'd': True
    }, hparams.values())
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)

    hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
    self.assertDictEqual({
        'aaa': 12,
        'b': 3.0,
        'c_c': 'relu4',
        'd': False
    }, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(3.0, hparams.b)
    self.assertEqual('relu4', hparams.c_c)

    # The serialized form must load back into an HParams that started with
    # different values.
    json_str = hparams.to_json()
    hparams2 = HParams(aaa=10, b=20.0, c_c='hello', d=False)
    hparams2.parse_json(json_str)
    self.assertEqual(12, hparams2.aaa)
    self.assertEqual(3.0, hparams2.b)
    self.assertEqual('relu4', hparams2.c_c)
    self.assertEqual(False, hparams2.d)

    # These expected strings assume to_json() forwards indent/separators/
    # sort_keys to json.dumps. With indent=2, json.dumps indents each member
    # by TWO spaces; the previous single-space expectation could never match.
    hparams3 = HParams(aaa=123)
    self.assertEqual('{"aaa": 123}', hparams3.to_json())
    self.assertEqual('{\n  "aaa": 123\n}', hparams3.to_json(indent=2))
    self.assertEqual('{"aaa"=123}', hparams3.to_json(separators=(';', '=')))

    hparams4 = HParams(aaa=123, b='hello', c_c=False)
    self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                     hparams4.to_json(sort_keys=True))
def create_hparams_from_json(json_path, hparams=None):
    """Load hparams from a JSON file, optionally merging into existing hparams.

    Args:
      json_path: Path, opened with ``tf.gfile.Open``, of a JSON document.
        The document must decode to a mapping of hyperparameter name to
        value, since it is expanded into ``HParams(**values)``.
      hparams: Optional existing HParams-like object. When given, every JSON
        key that ``hparams`` already has as an attribute is overwritten in
        place if its value differs, and each overwrite is logged. JSON keys
        that ``hparams`` lacks are ignored.

    Returns:
      ``hparams``, updated in place, when it was supplied. Otherwise a new
      HParams built from the JSON values.
    """
    # Lazy logging args: the message is only formatted if it is emitted.
    tf.logging.info("Loading hparams from existing json %s", json_path)
    with tf.gfile.Open(json_path, "r") as f:
        hparams_values = json.load(f)
    new_hparams = HParams(**hparams_values)

    # Explicit sentinel check against the ``None`` default.
    if hparams is None:
        return new_hparams

    # Some keys are in new_hparams but not hparams, so we need to be more
    # careful than simply using parse_json() from HParams: only keys the
    # caller's hparams already has are copied over.
    for key in sorted(new_hparams.values()):
        if not hasattr(hparams, key):
            continue
        value = getattr(hparams, key)
        new_value = getattr(new_hparams, key)
        if value != new_value:
            tf.logging.info("Overwrite key %s: %s -> %s", key, value, new_value)
            setattr(hparams, key, new_value)
    return hparams
def testSomeValues(self):
    """Tests scalar hyperparameters.

    Covers str() output, successive parse() updates, and rejected inputs
    leaving the state unchanged.
    """
    hparams = HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
    self.assertDictEqual(
        {
            'aaa': 1,
            'b': 2.0,
            'c_c': 'relu6',
            'd': '/a/b=c/d'
        }, hparams.values())
    expected_str = ('{\'aaa\': 1, \'b\': 2.0, \'c_c\': \'relu6\', '
                    '\'d\': \'/a/b=c/d\'}')
    self.assertEqual(expected_str, hparams.__str__())
    self.assertEqual(expected_str, str(hparams))
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    hparams.parse('aaa=12')
    self.assertDictEqual(
        {
            'aaa': 12,
            'b': 2.0,
            'c_c': 'relu6',
            'd': '/a/b=c/d'
        }, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    hparams.parse('c_c=relu4, b=-2.0e10')
    self.assertDictEqual(
        {
            'aaa': 12,
            'b': -2.0e10,
            'c_c': 'relu4',
            'd': '/a/b=c/d'
        }, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(-2.0e10, hparams.b)
    self.assertEqual('relu4', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    # An empty right-hand side sets a string to ''; trailing commas are fine.
    hparams.parse('c_c=,b=0,')
    self.assertDictEqual({
        'aaa': 12,
        'b': 0,
        'c_c': '',
        'd': '/a/b=c/d'
    }, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(0.0, hparams.b)
    self.assertEqual('', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    hparams.parse('c_c=2.3",b=+2,')
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)

    # Values may themselves contain '/' and '='.
    hparams.parse('d=/a/b/c/d,aaa=11,')
    self.assertEqual(11, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual('/a/b/c/d', hparams.d)
    hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
    self.assertEqual(10, hparams.aaa)
    self.assertEqual(1.5, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual('/a=b/c/d', hparams.d)

    # Each invalid input must raise a ValueError whose message matches the
    # paired regex. assertRaisesRegex replaces the deprecated
    # assertRaisesRegexp alias, which was removed in Python 3.12.
    invalid_inputs = [
        ('x=123', 'Unknown hyperparameter'),
        ('aaa=poipoi', 'Could not parse'),
        ('aaa=1.0', 'Could not parse'),
        ('b=12x', 'Could not parse'),
        ('b=relu', 'Could not parse'),
        ('aaa=[123]', 'Must not pass a list'),
    ]
    for bad_input, message in invalid_inputs:
        with self.subTest(bad_input=bad_input):
            with self.assertRaisesRegex(ValueError, message):
                hparams.parse(bad_input)

    # The rejected parses must not have modified any value.
    self.assertEqual(10, hparams.aaa)
    self.assertEqual(1.5, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual('/a=b/c/d', hparams.d)