def testJson(self):
    """Check JSON parsing into, and serialization out of, HParams."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c="relu6", d=True)
    self.assertDictEqual(
        {"aaa": 1, "b": 2.0, "c_c": "relu6", "d": True}, hparams.values()
    )
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual("relu6", hparams.c_c)

    # parse_json() overwrites existing values; JSON `false` becomes False.
    hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
    self.assertDictEqual(
        {"aaa": 12, "b": 3.0, "c_c": "relu4", "d": False}, hparams.values()
    )
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(3.0, hparams.b)
    self.assertEqual("relu4", hparams.c_c)

    # The output of to_json() can be parsed back into another instance.
    json_str = hparams.to_json()
    hparams2 = hparam.HParams(aaa=10, b=20.0, c_c="hello", d=False)
    hparams2.parse_json(json_str)
    self.assertEqual(12, hparams2.aaa)
    self.assertEqual(3.0, hparams2.b)
    self.assertEqual("relu4", hparams2.c_c)
    self.assertEqual(False, hparams2.d)

    # to_json() formatting keywords: indent, separators, sort_keys.
    hparams3 = hparam.HParams(aaa=123)
    self.assertEqual('{"aaa": 123}', hparams3.to_json())
    # indent=2 prefixes each member with two spaces (json.dumps semantics).
    self.assertEqual('{\n  "aaa": 123\n}', hparams3.to_json(indent=2))
    self.assertEqual('{"aaa"=123}', hparams3.to_json(separators=(";", "=")))

    hparams4 = hparam.HParams(aaa=123, b="hello", c_c=False)
    self.assertEqual(
        '{"aaa": 123, "b": "hello", "c_c": false}', hparams4.to_json(sort_keys=True)
    )
def testSetFromMap(self):
    """override_from_dict() replaces only the keys it is given."""
    first = hparam.HParams(a=1, b=2.0, c="tanh")
    first.override_from_dict({"a": -2, "c": "identity"})
    # 'b' was not in the overrides, so it keeps its initial value.
    self.assertDictEqual({"a": -2, "c": "identity", "b": 2.0}, first.values())

    # A list value may be overridden by a list of a different length.
    second = hparam.HParams(x=1, b=2.0, d=[0.5])
    second.override_from_dict({"d": [0.1, 0.2, 0.3]})
    self.assertDictEqual({"d": [0.1, 0.2, 0.3], "x": 1, "b": 2.0}, second.values())
def testSetFromMap(self):
    """override_from_dict() replaces only the keys it is given."""
    scalars = hparam.HParams(a=1, b=2.0, c='tanh')
    scalars.override_from_dict({'a': -2, 'c': 'identity'})
    # 'b' is absent from the overrides and keeps its initial value.
    self.assertDictEqual({'a': -2, 'c': 'identity', 'b': 2.0},
                         scalars.values())

    # A list value may be replaced by a longer list.
    with_list = hparam.HParams(x=1, b=2.0, d=[0.5])
    with_list.override_from_dict({'d': [0.1, 0.2, 0.3]})
    self.assertDictEqual({'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0},
                         with_list.values())
def testBoolParsing(self):
    """parse() maps 'true'/'True'/'1' to True and 'false'/'False'/'0' to False.

    The result must not depend on the hyperparameter's initial value.
    """
    truthy = ['True', 'true', '1']
    for value in ('true', 'false', 'True', 'False', '1', '0'):
        for initial in (False, True):
            hp = hparam.HParams(use_gpu=initial)
            hp.parse('use_gpu=%s' % value)
            self.assertEqual(hp.use_gpu, value in truthy)
def testLists(self):
    """Check list-valued hyperparameters and the bracketed list syntax."""
    hparams = hparam.HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
    self.assertDictEqual({
        'aaa': [1],
        'b': [2.0, 3.0],
        'c_c': ['relu6']
    }, hparams.values())
    self.assertEqual([1], hparams.aaa)
    self.assertEqual([2.0, 3.0], hparams.b)
    self.assertEqual(['relu6'], hparams.c_c)

    hparams.parse('aaa=[12]')
    self.assertEqual([12], hparams.aaa)
    hparams.parse('aaa=[12,34,56]')
    self.assertEqual([12, 34, 56], hparams.aaa)
    # Several list assignments in one parse() call.
    hparams.parse('c_c=[relu4,relu12],b=[1.0]')
    self.assertEqual(['relu4', 'relu12'], hparams.c_c)
    self.assertEqual([1.0], hparams.b)
    # Empty lists and signed numbers.
    hparams.parse('c_c=[],aaa=[-34]')
    self.assertEqual([-34], hparams.aaa)
    self.assertEqual([], hparams.c_c)
    # String elements are kept verbatim, including quote characters.
    hparams.parse('c_c=[_12,3\'4"],aaa=[+3]')
    self.assertEqual([3], hparams.aaa)
    self.assertEqual(['_12', '3\'4"'], hparams.c_c)

    # Invalid inputs. assertRaisesRegex replaces the deprecated
    # assertRaisesRegexp alias, which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
        hparams.parse('x=[123]')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('aaa=[poipoi]')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('aaa=[1.0]')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('b=[12x]')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('b=[relu]')
    with self.assertRaisesRegex(ValueError, 'Must pass a list'):
        hparams.parse('aaa=123')
def testEmpty(self):
    """An HParams with no values parses '' and rejects unknown names."""
    hparams = hparam.HParams()
    self.assertDictEqual({}, hparams.values())
    hparams.parse('')
    self.assertDictEqual({}, hparams.values())
    # assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.
    with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
        hparams.parse('xyz=123')
def testSetHParamTypeMismatch(self):
    """set_hparam() raises ValueError when the new value's type mismatches."""
    hp = hparam.HParams(int_=1, str_="str", bool_=True, float_=1.1,
                        list_int=[1, 2], none=None)
    rejected = [
        ("str_", 2.2),
        ("int_", False),
        ("bool_", 1),
        ("int_", 2.2),
        ("list_int", [2, 3.3]),
        ("int_", "2"),
    ]
    for name, bad_value in rejected:
        with self.assertRaises(ValueError):
            hp.set_hparam(name, bad_value)

    # Casting int to float is OK
    hp.set_hparam("float_", 1)

    # A hyperparameter initialised to None accepts a str here.
    hp.set_hparam("none", "1")
    self.assertEqual("1", hp.none)
def testSetHParam(self):
    """set_hparam() overwrites existing values, reflected by values()."""
    hp = hparam.HParams(aaa=1, b=2.0, c_c="relu6", d=True)
    self.assertDictEqual({"aaa": 1, "b": 2.0, "c_c": "relu6", "d": True},
                         hp.values())
    for name, want in (("aaa", 1), ("b", 2.0), ("c_c", "relu6")):
        self.assertEqual(want, getattr(hp, name))

    updates = {"aaa": 12, "b": 3.0, "c_c": "relu4", "d": False}
    for name, new_value in updates.items():
        hp.set_hparam(name, new_value)
    self.assertDictEqual({"aaa": 12, "b": 3.0, "c_c": "relu4", "d": False},
                         hp.values())
    for name, want in (("aaa", 12), ("b", 3.0), ("c_c", "relu4")):
        self.assertEqual(want, getattr(hp, name))
def testSetHParam(self):
    """set_hparam() overwrites existing values, reflected by values()."""
    params = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    initial = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
    self.assertDictEqual(initial, params.values())
    for key in ('aaa', 'b', 'c_c'):
        self.assertEqual(initial[key], getattr(params, key))

    params.set_hparam('aaa', 12)
    params.set_hparam('b', 3.0)
    params.set_hparam('c_c', 'relu4')
    params.set_hparam('d', False)

    updated = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
    self.assertDictEqual(updated, params.values())
    for key in ('aaa', 'b', 'c_c'):
        self.assertEqual(updated[key], getattr(params, key))
def testSetHParamTypeMismatch(self):
    """set_hparam() raises ValueError when the new value's type mismatches."""
    params = hparam.HParams(int_=1, str_='str', bool_=True, float_=1.1,
                            list_int=[1, 2], none=None)
    mismatches = (
        ('str_', 2.2),           # float into str
        ('int_', False),         # bool into int
        ('bool_', 1),            # int into bool
        ('int_', 2.2),           # float into int
        ('list_int', [2, 3.3]),  # list with a float element into list of int
        ('int_', '2'),           # str into int
    )
    for key, value in mismatches:
        with self.assertRaises(ValueError):
            params.set_hparam(key, value)

    # Casting int to float is OK
    params.set_hparam('float_', 1)

    # A hyperparameter initialised to None accepts a str here.
    params.set_hparam('none', '1')
    self.assertEqual('1', params.none)
def testLists(self):
    """Check list-valued hyperparameters and the bracketed list syntax."""
    hparams = hparam.HParams(aaa=[1], b=[2.0, 3.0], c_c=["relu6"])
    self.assertDictEqual({
        "aaa": [1],
        "b": [2.0, 3.0],
        "c_c": ["relu6"]
    }, hparams.values())
    self.assertEqual([1], hparams.aaa)
    self.assertEqual([2.0, 3.0], hparams.b)
    self.assertEqual(["relu6"], hparams.c_c)

    hparams.parse("aaa=[12]")
    self.assertEqual([12], hparams.aaa)
    hparams.parse("aaa=[12,34,56]")
    self.assertEqual([12, 34, 56], hparams.aaa)
    # Several list assignments in one parse() call.
    hparams.parse("c_c=[relu4,relu12],b=[1.0]")
    self.assertEqual(["relu4", "relu12"], hparams.c_c)
    self.assertEqual([1.0], hparams.b)
    # Empty lists and signed numbers.
    hparams.parse("c_c=[],aaa=[-34]")
    self.assertEqual([-34], hparams.aaa)
    self.assertEqual([], hparams.c_c)
    # String elements are kept verbatim, including quote characters.
    hparams.parse("c_c=[_12,3'4\"],aaa=[+3]")
    self.assertEqual([3], hparams.aaa)
    self.assertEqual(["_12", "3'4\""], hparams.c_c)

    # Invalid inputs. assertRaisesRegex replaces the deprecated
    # assertRaisesRegexp alias, which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "Unknown hyperparameter"):
        hparams.parse("x=[123]")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("aaa=[poipoi]")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("aaa=[1.0]")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("b=[12x]")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("b=[relu]")
    with self.assertRaisesRegex(ValueError, "Must pass a list"):
        hparams.parse("aaa=123")
def testBoolParsing(self):
    """parse() maps "true"/"True"/"1" to True and "false"/"False"/"0" to False.

    Every literal is tried against both possible initial values.
    """
    literals = ("true", "false", "True", "False", "1", "0")
    for literal in literals:
        expected = literal in ["True", "true", "1"]
        for start in (False, True):
            params = hparam.HParams(use_gpu=start)
            params.parse("use_gpu={}".format(literal))
            self.assertEqual(params.use_gpu, expected)
def test_parse_dict(self):
    """parse() accepts a dict of values plus a second positional argument."""
    params = hparam.HParams(aaa=1)
    # NOTE(review): this test makes no assertions; it only checks that the
    # call below does not raise. The meaning of the positional `True` is not
    # visible here (presumably it allows adding new names) — TODO confirm and
    # assert on the resulting values.
    new_values = {"b": 2.0, "c_c": "relu6", "d": True, "e": [5.0, 6.0]}
    params.parse(new_values, True)
def test_parse_dict(self):
    """parse() accepts a dict of values plus a second positional argument."""
    hp = hparam.HParams(aaa=1)
    # NOTE(review): no assertions — this only checks that parse() does not
    # raise. What the positional `True` controls is not visible here
    # (presumably permission to add new names) — TODO confirm.
    payload = {'b': 2.0, 'c_c': 'relu6', 'd': True, 'e': [5.0, 6.0]}
    hp.parse(payload, True)
def testFunction(self):
    """A callable can be stored as a value but is left out of to_json()."""
    def identity(x):
        return x

    hp = hparam.HParams(function=identity)
    self.assertEqual(hp.function, identity)
    # The only hyperparameter is a function, so the JSON object is empty.
    self.assertEqual(hp.to_json(), '{}')
def testJson(self):
    """Check JSON parsing into, and serialization out of, HParams."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
    self.assertDictEqual({
        'aaa': 1,
        'b': 2.0,
        'c_c': 'relu6',
        'd': True
    }, hparams.values())
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)

    # parse_json() overwrites existing values; JSON `false` becomes False.
    hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
    self.assertDictEqual({
        'aaa': 12,
        'b': 3.0,
        'c_c': 'relu4',
        'd': False
    }, hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(3.0, hparams.b)
    self.assertEqual('relu4', hparams.c_c)

    # The output of to_json() can be parsed back into another instance.
    json_str = hparams.to_json()
    hparams2 = hparam.HParams(aaa=10, b=20.0, c_c='hello', d=False)
    hparams2.parse_json(json_str)
    self.assertEqual(12, hparams2.aaa)
    self.assertEqual(3.0, hparams2.b)
    self.assertEqual('relu4', hparams2.c_c)
    self.assertEqual(False, hparams2.d)

    # to_json() formatting keywords: indent, separators, sort_keys.
    hparams3 = hparam.HParams(aaa=123)
    self.assertEqual('{"aaa": 123}', hparams3.to_json())
    # indent=2 prefixes each member with two spaces (json.dumps semantics).
    self.assertEqual('{\n  "aaa": 123\n}', hparams3.to_json(indent=2))
    self.assertEqual('{"aaa"=123}', hparams3.to_json(separators=(';', '=')))

    hparams4 = hparam.HParams(aaa=123, b='hello', c_c=False)
    self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                     hparams4.to_json(sort_keys=True))
def testWithPeriodInVariableName(self):
    """Hyperparameter names may contain periods, including trailing ones."""
    hparams = hparam.HParams()
    hparams.add_hparam(name="a.b", value=0.0)
    hparams.parse("a.b=1.0")
    # Dotted names are not valid attribute syntax, so read them via getattr.
    self.assertEqual(1.0, getattr(hparams, "a.b"))

    hparams.add_hparam(name="c.d", value=0.0)
    # assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("c.d=abc")

    hparams.add_hparam(name="e.f", value="")
    hparams.parse("e.f=abc")
    self.assertEqual("abc", getattr(hparams, "e.f"))

    hparams.add_hparam(name="d..", value=0.0)
    hparams.parse("d..=10.0")
    self.assertEqual(10.0, getattr(hparams, "d.."))
def testWithPeriodInVariableName(self):
    """Hyperparameter names may contain periods, including trailing ones."""
    hparams = hparam.HParams()
    hparams.add_hparam(name='a.b', value=0.0)
    hparams.parse('a.b=1.0')
    # Dotted names are not valid attribute syntax, so read them via getattr.
    self.assertEqual(1.0, getattr(hparams, 'a.b'))

    hparams.add_hparam(name='c.d', value=0.0)
    # assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('c.d=abc')

    hparams.add_hparam(name='e.f', value='')
    hparams.parse('e.f=abc')
    self.assertEqual('abc', getattr(hparams, 'e.f'))

    hparams.add_hparam(name='d..', value=0.0)
    hparams.parse('d..=10.0')
    self.assertEqual(10.0, getattr(hparams, 'd..'))
def testDel(self):
    """After del_hparam(), a name can be re-added with a different type."""
    hp = hparam.HParams(aaa=1, b=2.0)
    # While 'aaa' holds an int, setting a str raises, and so does re-adding
    # the existing name.
    with self.assertRaises(ValueError):
        hp.set_hparam('aaa', 'will fail')
    with self.assertRaises(ValueError):
        hp.add_hparam('aaa', 'will fail')

    hp.del_hparam('aaa')
    hp.add_hparam('aaa', 'will work')
    self.assertEqual('will work', hp.get('aaa'))
    # The re-added name now has str type, so a str update succeeds.
    hp.set_hparam('aaa', 'still works')
    self.assertEqual('still works', hp.get('aaa'))
def testDel(self):
    """del_hparam() frees a name so it can be re-added with another type."""
    params = hparam.HParams(aaa=1, b=2.0)
    failing_calls = (params.set_hparam, params.add_hparam)
    for call in failing_calls:
        with self.assertRaises(ValueError):
            call("aaa", "will fail")

    params.del_hparam("aaa")
    params.add_hparam("aaa", "will work")
    self.assertEqual("will work", params.get("aaa"))
    params.set_hparam("aaa", "still works")
    self.assertEqual("still works", params.get("aaa"))
def testGet(self):
    """get() returns stored values, type-checks defaults, and falls back."""
    hp = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

    # Existing parameters with default=None.
    self.assertEqual(1, hp.get('aaa'))
    self.assertEqual(2.0, hp.get('b'))
    self.assertEqual('relu6', hp.get('c_c'))
    self.assertEqual(True, hp.get('d'))
    self.assertEqual([5.0, 6.0], hp.get('e', None))

    # Existing parameters with compatible defaults: the stored value wins.
    compatible = [
        (1, 'aaa', 2),
        (2.0, 'b', 3.0),
        (2.0, 'b', 3),
        ('relu6', 'c_c', 'default'),
        (True, 'd', True),
        ([5.0, 6.0], 'e', [1.0, 2.0, 3.0]),
        ([5.0, 6.0], 'e', [1, 2, 3]),
    ]
    for expected, name, default in compatible:
        self.assertEqual(expected, hp.get(name, default))

    # Existing parameters with incompatible defaults.
    incompatible = [
        ('aaa', 2.0),
        ('b', False),
        ('c_c', [1, 2, 3]),
        ('d', 'relu'),
        ('e', 123.0),
        ('e', ['a', 'b', 'c']),
    ]
    for name, default in incompatible:
        with self.assertRaises(ValueError):
            hp.get(name, default)

    # Nonexistent parameters return the supplied default (None if omitted).
    self.assertEqual(None, hp.get('unknown'))
    self.assertEqual(123, hp.get('unknown', 123))
    self.assertEqual([1, 2, 3], hp.get('unknown', [1, 2, 3]))
def testGet(self):
    """get() returns stored values, type-checks defaults, and falls back."""
    params = hparam.HParams(aaa=1, b=2.0, c_c="relu6", d=True, e=[5.0, 6.0])

    # Existing parameters with default=None.
    for name, stored in (("aaa", 1), ("b", 2.0), ("c_c", "relu6"), ("d", True)):
        self.assertEqual(stored, params.get(name))
    self.assertEqual([5.0, 6.0], params.get("e", None))

    # Existing parameters with compatible defaults.
    self.assertEqual(1, params.get("aaa", 2))
    self.assertEqual(2.0, params.get("b", 3.0))
    # An int default is compatible with a float value.
    self.assertEqual(2.0, params.get("b", 3))
    self.assertEqual("relu6", params.get("c_c", "default"))
    self.assertEqual(True, params.get("d", True))
    self.assertEqual([5.0, 6.0], params.get("e", [1.0, 2.0, 3.0]))
    self.assertEqual([5.0, 6.0], params.get("e", [1, 2, 3]))

    # Existing parameters with incompatible defaults.
    for name, bad_default in (
        ("aaa", 2.0),
        ("b", False),
        ("c_c", [1, 2, 3]),
        ("d", "relu"),
        ("e", 123.0),
        ("e", ["a", "b", "c"]),
    ):
        with self.assertRaises(ValueError):
            params.get(name, bad_default)

    # Nonexistent parameters.
    self.assertEqual(None, params.get("unknown"))
    self.assertEqual(123, params.get("unknown", 123))
    self.assertEqual([1, 2, 3], params.get("unknown", [1, 2, 3]))
def test_append(self):
    """Exercise append() with a second HParams and print the result."""
    base = hparam.HParams(aaa=1)
    extra = hparam.HParams(b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])
    base.append(extra)
    # NOTE(review): no assertions — this only checks that append() does not
    # raise; the combined object is printed for manual inspection.
    print(base)
def testContains(self):
    """The `in` operator reports whether a hyperparameter name exists."""
    params = hparam.HParams(foo=1)
    self.assertIn("foo", params)
    self.assertNotIn("bar", params)
def testSetHParamListNonListMismatch(self):
    """set_hparam() rejects list/scalar mismatches in either direction."""
    hparams = hparam.HParams(a=1, b=[2.0, 3.0])
    # assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.
    # A list cannot replace a scalar...
    with self.assertRaisesRegex(ValueError, r'Must not pass a list'):
        hparams.set_hparam('a', [1.0])
    # ...and a scalar cannot replace a list.
    with self.assertRaisesRegex(ValueError, r'Must pass a list'):
        hparams.set_hparam('b', 1.0)
def testSomeValues(self):
    """Check scalar values, str() output and the comma-separated parse() syntax."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c="relu6", d="/a/b=c/d")
    self.assertDictEqual(
        {"aaa": 1, "b": 2.0, "c_c": "relu6", "d": "/a/b=c/d"},
        hparams.values())
    expected_str = ("[('aaa', 1), ('b', 2.0), ('c_c', 'relu6'), "
                    "('d', '/a/b=c/d')]")
    self.assertEqual(expected_str, str(hparams.__str__()))
    self.assertEqual(expected_str, str(hparams))
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual("relu6", hparams.c_c)
    self.assertEqual("/a/b=c/d", hparams.d)

    hparams.parse("aaa=12")
    self.assertDictEqual(
        {"aaa": 12, "b": 2.0, "c_c": "relu6", "d": "/a/b=c/d"},
        hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual("relu6", hparams.c_c)
    self.assertEqual("/a/b=c/d", hparams.d)

    # A space after the comma and scientific notation are accepted.
    hparams.parse("c_c=relu4, b=-2.0e10")
    self.assertDictEqual(
        {"aaa": 12, "b": -2.0e10, "c_c": "relu4", "d": "/a/b=c/d"},
        hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(-2.0e10, hparams.b)
    self.assertEqual("relu4", hparams.c_c)
    self.assertEqual("/a/b=c/d", hparams.d)

    # Empty string values and a trailing comma are accepted.
    hparams.parse("c_c=,b=0,")
    self.assertDictEqual(
        {"aaa": 12, "b": 0, "c_c": "", "d": "/a/b=c/d"},
        hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(0.0, hparams.b)
    self.assertEqual("", hparams.c_c)
    self.assertEqual("/a/b=c/d", hparams.d)

    # A leading '+' is accepted; string values are stored verbatim.
    hparams.parse('c_c=2.3",b=+2,')
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)

    hparams.parse("d=/a/b/c/d,aaa=11,")
    self.assertEqual(11, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual("/a/b/c/d", hparams.d)

    # An '=' inside a value is kept as part of that value.
    hparams.parse("b=1.5,d=/a=b/c/d,aaa=10,")
    self.assertEqual(10, hparams.aaa)
    self.assertEqual(1.5, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual("/a=b/c/d", hparams.d)

    # Invalid inputs raise ValueError. assertRaisesRegex replaces the
    # deprecated assertRaisesRegexp alias, which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "Unknown hyperparameter"):
        hparams.parse("x=123")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("aaa=poipoi")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("aaa=1.0")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("b=12x")
    with self.assertRaisesRegex(ValueError, "Could not parse"):
        hparams.parse("b=relu")
    with self.assertRaisesRegex(ValueError, "Must not pass a list"):
        hparams.parse("aaa=[123]")

    # The failed parse() calls above left the previous values untouched.
    self.assertEqual(10, hparams.aaa)
    self.assertEqual(1.5, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual("/a=b/c/d", hparams.d)
def testSomeValues(self):
    """Check scalar values, str() output and the comma-separated parse() syntax."""
    hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
    self.assertDictEqual(
        {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
        hparams.values())
    expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), '
                    '(\'d\', \'/a/b=c/d\')]')
    self.assertEqual(expected_str, str(hparams.__str__()))
    self.assertEqual(expected_str, str(hparams))
    self.assertEqual(1, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    hparams.parse('aaa=12')
    self.assertDictEqual(
        {'aaa': 12, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
        hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('relu6', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    # A space after the comma and scientific notation are accepted.
    hparams.parse('c_c=relu4, b=-2.0e10')
    self.assertDictEqual(
        {'aaa': 12, 'b': -2.0e10, 'c_c': 'relu4', 'd': '/a/b=c/d'},
        hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(-2.0e10, hparams.b)
    self.assertEqual('relu4', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    # Empty string values and a trailing comma are accepted.
    hparams.parse('c_c=,b=0,')
    self.assertDictEqual(
        {'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'},
        hparams.values())
    self.assertEqual(12, hparams.aaa)
    self.assertEqual(0.0, hparams.b)
    self.assertEqual('', hparams.c_c)
    self.assertEqual('/a/b=c/d', hparams.d)

    # A leading '+' is accepted; string values are stored verbatim.
    hparams.parse('c_c=2.3",b=+2,')
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)

    hparams.parse('d=/a/b/c/d,aaa=11,')
    self.assertEqual(11, hparams.aaa)
    self.assertEqual(2.0, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual('/a/b/c/d', hparams.d)

    # An '=' inside a value is kept as part of that value.
    hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
    self.assertEqual(10, hparams.aaa)
    self.assertEqual(1.5, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual('/a=b/c/d', hparams.d)

    # Invalid inputs raise ValueError. assertRaisesRegex replaces the
    # deprecated assertRaisesRegexp alias, which was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
        hparams.parse('x=123')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('aaa=poipoi')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('aaa=1.0')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('b=12x')
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('b=relu')
    with self.assertRaisesRegex(ValueError, 'Must not pass a list'):
        hparams.parse('aaa=[123]')

    # The failed parse() calls above left the previous values untouched.
    self.assertEqual(10, hparams.aaa)
    self.assertEqual(1.5, hparams.b)
    self.assertEqual('2.3"', hparams.c_c)
    self.assertEqual('/a=b/c/d', hparams.d)
def testContains(self):
    """Membership tests on HParams check hyperparameter names."""
    hp = hparam.HParams(foo=1)
    self.assertIn('foo', hp)
    self.assertNotIn('bar', hp)
def testBoolParsingFail(self):
    """An unrecognized boolean literal makes parse() raise ValueError."""
    hparams = hparam.HParams(use_gpu=True)
    # The error message must name the offending hyperparameter.
    # assertRaisesRegex: the assertRaisesRegexp alias was removed in 3.12.
    with self.assertRaisesRegex(ValueError, r'Could not parse.*use_gpu'):
        hparams.parse('use_gpu=yep')