コード例 #1
0
ファイル: hparam_test.py プロジェクト: Tubbz-alt/athena-1
    def testJson(self):
        """Round-trip HParams through parse_json()/to_json() and check dump options."""
        initial = {"aaa": 1, "b": 2.0, "c_c": "relu6", "d": True}
        hparams = hparam.HParams(**initial)
        self.assertDictEqual(initial, hparams.values())
        for name in ("aaa", "b", "c_c"):
            self.assertEqual(initial[name], getattr(hparams, name))

        # parse_json() overwrites every field, including the JSON `false` boolean.
        hparams.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
        parsed = {"aaa": 12, "b": 3.0, "c_c": "relu4", "d": False}
        self.assertDictEqual(parsed, hparams.values())
        for name in ("aaa", "b", "c_c"):
            self.assertEqual(parsed[name], getattr(hparams, name))

        # A dump from one instance can be loaded into another with different values.
        dumped = hparams.to_json()
        target = hparam.HParams(aaa=10, b=20.0, c_c="hello", d=False)
        target.parse_json(dumped)
        for name in ("aaa", "b", "c_c", "d"):
            self.assertEqual(parsed[name], getattr(target, name))

        # Formatting keyword arguments of to_json() are honoured.
        single = hparam.HParams(aaa=123)
        self.assertEqual('{"aaa": 123}', single.to_json())
        self.assertEqual('{\n  "aaa": 123\n}', single.to_json(indent=2))
        self.assertEqual('{"aaa"=123}', single.to_json(separators=(";", "=")))

        several = hparam.HParams(aaa=123, b="hello", c_c=False)
        self.assertEqual(
            '{"aaa": 123, "b": "hello", "c_c": false}', several.to_json(sort_keys=True)
        )
コード例 #2
0
ファイル: hparam_test.py プロジェクト: Tubbz-alt/athena-1
    def testSetFromMap(self):
        """override_from_dict() replaces only the listed values, including lists."""
        scalar_params = hparam.HParams(a=1, b=2.0, c="tanh")
        scalar_params.override_from_dict({"a": -2, "c": "identity"})
        expected = {"a": -2, "c": "identity", "b": 2.0}
        self.assertDictEqual(expected, scalar_params.values())

        list_params = hparam.HParams(x=1, b=2.0, d=[0.5])
        list_params.override_from_dict({"d": [0.1, 0.2, 0.3]})
        expected = {"d": [0.1, 0.2, 0.3], "x": 1, "b": 2.0}
        self.assertDictEqual(expected, list_params.values())
コード例 #3
0
    def testSetFromMap(self):
        """Overriding from a dict touches only the keys present in that dict."""
        # (constructor kwargs, overrides, expected values() afterwards)
        cases = [
            (dict(a=1, b=2.0, c='tanh'),
             {'a': -2, 'c': 'identity'},
             {'a': -2, 'c': 'identity', 'b': 2.0}),
            (dict(x=1, b=2.0, d=[0.5]),
             {'d': [0.1, 0.2, 0.3]},
             {'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0}),
        ]
        for ctor_kwargs, overrides, expected in cases:
            params = hparam.HParams(**ctor_kwargs)
            params.override_from_dict(overrides)
            self.assertDictEqual(expected, params.values())
コード例 #4
0
 def testBoolParsing(self):
     """Each boolean spelling parses to the right value from either start value."""
     truthy_spellings = ['True', 'true', '1']
     for text in ('true', 'false', 'True', 'False', '1', '0'):
         expected = text in truthy_spellings
         for start in (False, True):
             params = hparam.HParams(use_gpu=start)
             params.parse('use_gpu=' + text)
             self.assertEqual(params.use_gpu, expected)
コード例 #5
0
 def testLists(self):
     """List-valued hyperparameters parse element-wise and reject bad input.

     Uses ``assertRaisesRegex``; the old ``assertRaisesRegexp`` alias was
     removed from ``unittest`` in Python 3.12.
     """
     hparams = hparam.HParams(aaa=[1], b=[2.0, 3.0], c_c=['relu6'])
     self.assertDictEqual({
         'aaa': [1],
         'b': [2.0, 3.0],
         'c_c': ['relu6']
     }, hparams.values())
     self.assertEqual([1], hparams.aaa)
     self.assertEqual([2.0, 3.0], hparams.b)
     self.assertEqual(['relu6'], hparams.c_c)
     hparams.parse('aaa=[12]')
     self.assertEqual([12], hparams.aaa)
     hparams.parse('aaa=[12,34,56]')
     self.assertEqual([12, 34, 56], hparams.aaa)
     hparams.parse('c_c=[relu4,relu12],b=[1.0]')
     self.assertEqual(['relu4', 'relu12'], hparams.c_c)
     self.assertEqual([1.0], hparams.b)
     hparams.parse('c_c=[],aaa=[-34]')
     self.assertEqual([-34], hparams.aaa)
     self.assertEqual([], hparams.c_c)
     # Quotes and a leading underscore/plus sign survive list parsing.
     hparams.parse('c_c=[_12,3\'4"],aaa=[+3]')
     self.assertEqual([3], hparams.aaa)
     self.assertEqual(['_12', '3\'4"'], hparams.c_c)
     with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
         hparams.parse('x=[123]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=[poipoi]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=[1.0]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=[12x]')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=[relu]')
     with self.assertRaisesRegex(ValueError, 'Must pass a list'):
         hparams.parse('aaa=123')
コード例 #6
0
 def testEmpty(self):
     """An HParams with no fields parses '' and rejects unknown names.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams()
     self.assertDictEqual({}, hparams.values())
     hparams.parse('')
     self.assertDictEqual({}, hparams.values())
     with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
         hparams.parse('xyz=123')
コード例 #7
0
    def testSetHParamTypeMismatch(self):
        """set_hparam() rejects values whose type does not match the declared one."""
        hparams = hparam.HParams(int_=1,
                                 str_="str",
                                 bool_=True,
                                 float_=1.1,
                                 list_int=[1, 2],
                                 none=None)

        # (name, value) pairs that must each raise ValueError.
        rejected = [
            ("str_", 2.2),
            ("int_", False),
            ("bool_", 1),
            ("int_", 2.2),
            ("list_int", [2, 3.3]),
            ("int_", "2"),
        ]
        for name, value in rejected:
            with self.assertRaises(ValueError):
                hparams.set_hparam(name, value)

        # Widening an int into a float field is accepted.
        hparams.set_hparam("float_", 1)

        # A field declared as None accepts any value (here a string).
        hparams.set_hparam("none", "1")
        self.assertEqual("1", hparams.none)
コード例 #8
0
    def testSetHParam(self):
        """set_hparam() overwrites each value with one of the same type."""
        hparams = hparam.HParams(aaa=1, b=2.0, c_c="relu6", d=True)
        before = {"aaa": 1, "b": 2.0, "c_c": "relu6", "d": True}
        self.assertDictEqual(before, hparams.values())
        for name in ("aaa", "b", "c_c"):
            self.assertEqual(before[name], getattr(hparams, name))

        after = {"aaa": 12, "b": 3.0, "c_c": "relu4", "d": False}
        for name, value in after.items():
            hparams.set_hparam(name, value)
        self.assertDictEqual(after, hparams.values())
        for name in ("aaa", "b", "c_c"):
            self.assertEqual(after[name], getattr(hparams, name))
コード例 #9
0
    def testSetHParam(self):
        """Values written via set_hparam() show up in values() and as attributes."""
        hp = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
        self.assertDictEqual(dict(aaa=1, b=2.0, c_c='relu6', d=True), hp.values())
        self.assertEqual((1, 2.0, 'relu6'), (hp.aaa, hp.b, hp.c_c))

        hp.set_hparam('aaa', 12)
        hp.set_hparam('b', 3.0)
        hp.set_hparam('c_c', 'relu4')
        hp.set_hparam('d', False)
        self.assertDictEqual(dict(aaa=12, b=3.0, c_c='relu4', d=False), hp.values())
        self.assertEqual((12, 3.0, 'relu4'), (hp.aaa, hp.b, hp.c_c))
コード例 #10
0
    def testSetHParamTypeMismatch(self):
        """Type-mismatched assignments raise ValueError; a few coercions are allowed."""
        hparams = hparam.HParams(int_=1,
                                 str_='str',
                                 bool_=True,
                                 float_=1.1,
                                 list_int=[1, 2],
                                 none=None)

        def assert_rejected(name, value):
            with self.assertRaises(ValueError):
                hparams.set_hparam(name, value)

        assert_rejected('str_', 2.2)            # float into a str field
        assert_rejected('int_', False)          # bool into an int field
        assert_rejected('bool_', 1)             # int into a bool field
        assert_rejected('int_', 2.2)            # float into an int field
        assert_rejected('list_int', [2, 3.3])   # float element in an int list
        assert_rejected('int_', '2')            # str into an int field

        # int -> float is an accepted widening.
        hparams.set_hparam('float_', 1)

        # A None-typed field takes whatever is assigned first.
        hparams.set_hparam('none', '1')
        self.assertEqual('1', hparams.none)
コード例 #11
0
 def testLists(self):
     """List-valued hyperparameters parse element-wise and reject bad input.

     Uses ``assertRaisesRegex``; the old ``assertRaisesRegexp`` alias was
     removed from ``unittest`` in Python 3.12.
     """
     hparams = hparam.HParams(aaa=[1], b=[2.0, 3.0], c_c=["relu6"])
     self.assertDictEqual({
         "aaa": [1],
         "b": [2.0, 3.0],
         "c_c": ["relu6"]
     }, hparams.values())
     self.assertEqual([1], hparams.aaa)
     self.assertEqual([2.0, 3.0], hparams.b)
     self.assertEqual(["relu6"], hparams.c_c)
     hparams.parse("aaa=[12]")
     self.assertEqual([12], hparams.aaa)
     hparams.parse("aaa=[12,34,56]")
     self.assertEqual([12, 34, 56], hparams.aaa)
     hparams.parse("c_c=[relu4,relu12],b=[1.0]")
     self.assertEqual(["relu4", "relu12"], hparams.c_c)
     self.assertEqual([1.0], hparams.b)
     hparams.parse("c_c=[],aaa=[-34]")
     self.assertEqual([-34], hparams.aaa)
     self.assertEqual([], hparams.c_c)
     # Quotes and a leading underscore/plus sign survive list parsing.
     hparams.parse("c_c=[_12,3'4\"],aaa=[+3]")
     self.assertEqual([3], hparams.aaa)
     self.assertEqual(["_12", "3'4\""], hparams.c_c)
     with self.assertRaisesRegex(ValueError, "Unknown hyperparameter"):
         hparams.parse("x=[123]")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("aaa=[poipoi]")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("aaa=[1.0]")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("b=[12x]")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("b=[relu]")
     with self.assertRaisesRegex(ValueError, "Must pass a list"):
         hparams.parse("aaa=123")
コード例 #12
0
 def testBoolParsing(self):
     """Boolean parsing accepts true/false/True/False/1/0 regardless of start value."""
     combinations = [(spelling, start)
                     for spelling in ("true", "false", "True", "False", "1", "0")
                     for start in (False, True)]
     for spelling, start in combinations:
         hp = hparam.HParams(use_gpu=start)
         hp.parse("use_gpu=" + spelling)
         wants_true = spelling in ["True", "true", "1"]
         self.assertEqual(hp.use_gpu, wants_true)
コード例 #13
0
 def test_parse_dict(self):
     """Smoke test: parse() accepts a dict plus a positional True flag.

     NOTE(review): no assertions are made, so this only checks that the call
     does not raise. The meaning of the second positional argument is not
     visible here (presumably it permits adding new names); confirm.
     """
     params = hparam.HParams(aaa=1)
     incoming = {
         "b": 2.0,
         "c_c": "relu6",
         "d": True,
         "e": [5.0, 6.0],
     }
     params.parse(incoming, True)
コード例 #14
0
 def test_parse_dict(self):
     """Smoke test: feeding parse() a dict with a True flag does not raise.

     NOTE(review): nothing is asserted afterwards; the role of the second
     positional argument is not visible here and should be confirmed.
     """
     params = hparam.HParams(aaa=1)
     params.parse(dict(b=2.0, c_c='relu6', d=True, e=[5.0, 6.0]), True)
コード例 #15
0
    def testFunction(self):
        """A function-valued hyperparameter is stored but left out of to_json()."""
        def identity(value):
            return value

        hparams = hparam.HParams(function=identity)
        self.assertEqual(hparams.function, identity)
        self.assertEqual(hparams.to_json(), '{}')
コード例 #16
0
    def testJson(self):
        """Exercise parse_json()/to_json(), including the formatting keyword args."""

        def check_fields(params, expected, names):
            for name in names:
                self.assertEqual(expected[name], getattr(params, name))

        before = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
        hp = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True)
        self.assertDictEqual(before, hp.values())
        check_fields(hp, before, ('aaa', 'b', 'c_c'))

        hp.parse_json('{"aaa": 12, "b": 3.0, "c_c": "relu4", "d": false}')
        after = {'aaa': 12, 'b': 3.0, 'c_c': 'relu4', 'd': False}
        self.assertDictEqual(after, hp.values())
        check_fields(hp, after, ('aaa', 'b', 'c_c'))

        # Load the dump into an instance that starts with different values.
        dumped = hp.to_json()
        clone = hparam.HParams(aaa=10, b=20.0, c_c='hello', d=False)
        clone.parse_json(dumped)
        check_fields(clone, after, ('aaa', 'b', 'c_c', 'd'))

        single = hparam.HParams(aaa=123)
        self.assertEqual('{"aaa": 123}', single.to_json())
        self.assertEqual('{\n  "aaa": 123\n}', single.to_json(indent=2))
        self.assertEqual('{"aaa"=123}', single.to_json(separators=(';', '=')))

        mixed = hparam.HParams(aaa=123, b='hello', c_c=False)
        self.assertEqual('{"aaa": 123, "b": "hello", "c_c": false}',
                         mixed.to_json(sort_keys=True))
コード例 #17
0
 def testWithPeriodInVariableName(self):
     """Names containing '.' can be added, parsed and read back via getattr.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams()
     hparams.add_hparam(name="a.b", value=0.0)
     hparams.parse("a.b=1.0")
     # Dotted names are not valid identifiers, so attribute syntax can't be used.
     self.assertEqual(1.0, getattr(hparams, "a.b"))
     hparams.add_hparam(name="c.d", value=0.0)
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("c.d=abc")
     hparams.add_hparam(name="e.f", value="")
     hparams.parse("e.f=abc")
     self.assertEqual("abc", getattr(hparams, "e.f"))
     hparams.add_hparam(name="d..", value=0.0)
     hparams.parse("d..=10.0")
     self.assertEqual(10.0, getattr(hparams, "d.."))
コード例 #18
0
 def testWithPeriodInVariableName(self):
     """Names containing '.' can be added, parsed and read back via getattr.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams()
     hparams.add_hparam(name='a.b', value=0.0)
     hparams.parse('a.b=1.0')
     # Dotted names are not valid identifiers, so attribute syntax can't be used.
     self.assertEqual(1.0, getattr(hparams, 'a.b'))
     hparams.add_hparam(name='c.d', value=0.0)
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('c.d=abc')
     hparams.add_hparam(name='e.f', value='')
     hparams.parse('e.f=abc')
     self.assertEqual('abc', getattr(hparams, 'e.f'))
     hparams.add_hparam(name='d..', value=0.0)
     hparams.parse('d..=10.0')
     self.assertEqual(10.0, getattr(hparams, 'd..'))
コード例 #19
0
    def testDel(self):
        """After del_hparam() a name can be re-added with a different type."""
        hparams = hparam.HParams(aaa=1, b=2.0)

        # While the int 'aaa' exists, neither re-typing nor re-adding it works.
        for mutator in (hparams.set_hparam, hparams.add_hparam):
            with self.assertRaises(ValueError):
                mutator('aaa', 'will fail')

        hparams.del_hparam('aaa')
        hparams.add_hparam('aaa', 'will work')
        self.assertEqual('will work', hparams.get('aaa'))

        hparams.set_hparam('aaa', 'still works')
        self.assertEqual('still works', hparams.get('aaa'))
コード例 #20
0
    def testDel(self):
        """del_hparam() frees a name so it can be re-added as a string."""
        hparams = hparam.HParams(aaa=1, b=2.0)

        with self.assertRaises(ValueError):
            hparams.set_hparam("aaa", "will fail")  # str into the int 'aaa'
        with self.assertRaises(ValueError):
            hparams.add_hparam("aaa", "will fail")  # 'aaa' already exists

        hparams.del_hparam("aaa")
        for write, text in ((hparams.add_hparam, "will work"),
                            (hparams.set_hparam, "still works")):
            write("aaa", text)
            self.assertEqual(text, hparams.get("aaa"))
コード例 #21
0
    def testGet(self):
        """get() returns stored values, type-checks defaults, and falls back for unknown names."""
        hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

        # (positional arguments to get(), expected result)
        lookups = [
            # Existing parameters with default=None.
            (('aaa',), 1),
            (('b',), 2.0),
            (('c_c',), 'relu6'),
            (('d',), True),
            (('e', None), [5.0, 6.0]),
            # Existing parameters with compatible defaults.
            (('aaa', 2), 1),
            (('b', 3.0), 2.0),
            (('b', 3), 2.0),
            (('c_c', 'default'), 'relu6'),
            (('d', True), True),
            (('e', [1.0, 2.0, 3.0]), [5.0, 6.0]),
            (('e', [1, 2, 3]), [5.0, 6.0]),
        ]
        for args, expected in lookups:
            self.assertEqual(expected, hparams.get(*args))

        # Existing parameters whose default has an incompatible type.
        incompatible = [
            ('aaa', 2.0),
            ('b', False),
            ('c_c', [1, 2, 3]),
            ('d', 'relu'),
            ('e', 123.0),
            ('e', ['a', 'b', 'c']),
        ]
        for name, default in incompatible:
            with self.assertRaises(ValueError):
                hparams.get(name, default)

        # Unknown names return the supplied default (None if omitted).
        self.assertEqual(None, hparams.get('unknown'))
        self.assertEqual(123, hparams.get('unknown', 123))
        self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
コード例 #22
0
    def testGet(self):
        """get() honours compatible defaults and rejects incompatible ones."""
        stored = {"aaa": 1, "b": 2.0, "c_c": "relu6", "d": True, "e": [5.0, 6.0]}
        hparams = hparam.HParams(**stored)

        # Existing parameters with default=None.
        for name in ("aaa", "b", "c_c", "d"):
            self.assertEqual(stored[name], hparams.get(name))
        self.assertEqual(stored["e"], hparams.get("e", None))

        # Existing parameters with compatible defaults: the stored value wins.
        compatible = [
            ("aaa", 2),
            ("b", 3.0),
            ("b", 3),
            ("c_c", "default"),
            ("d", True),
            ("e", [1.0, 2.0, 3.0]),
            ("e", [1, 2, 3]),
        ]
        for name, default in compatible:
            self.assertEqual(stored[name], hparams.get(name, default))

        # Existing parameters with incompatible defaults.
        for name, default in (("aaa", 2.0), ("b", False), ("c_c", [1, 2, 3]),
                              ("d", "relu"), ("e", 123.0), ("e", ["a", "b", "c"])):
            with self.assertRaises(ValueError):
                hparams.get(name, default)

        # Nonexistent parameters.
        self.assertEqual(None, hparams.get("unknown"))
        self.assertEqual(123, hparams.get("unknown", 123))
        self.assertEqual([1, 2, 3], hparams.get("unknown", [1, 2, 3]))
コード例 #23
0
 def test_append(self):
     """Smoke test: append() with another HParams does not raise.

     NOTE(review): the combined object is only printed, never asserted on.
     """
     base = hparam.HParams(aaa=1)
     base.append(hparam.HParams(b=2.0, c_c='relu6', d=True, e=[5.0, 6.0]))
     print(base)
コード例 #24
0
 def testContains(self):
     """The `in` operator reports whether a hyperparameter name is defined.

     assertIn/assertNotIn show both operands on failure, unlike
     assertTrue(x in y), which only reports "False is not true".
     """
     hparams = hparam.HParams(foo=1)
     self.assertIn("foo", hparams)
     self.assertNotIn("bar", hparams)
コード例 #25
0
 def testSetHParamListNonListMismatch(self):
     """Scalar fields reject lists and list fields reject scalars.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams(a=1, b=[2.0, 3.0])
     with self.assertRaisesRegex(ValueError, r'Must not pass a list'):
         hparams.set_hparam('a', [1.0])
     with self.assertRaisesRegex(ValueError, r'Must pass a list'):
         hparams.set_hparam('b', 1.0)
コード例 #26
0
 def testSomeValues(self):
     """Scalar parsing: str(), comma-separated overrides, and error messages.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams(aaa=1, b=2.0, c_c="relu6", d="/a/b=c/d")
     self.assertDictEqual(
         {
             "aaa": 1,
             "b": 2.0,
             "c_c": "relu6",
             "d": "/a/b=c/d"
         }, hparams.values())
     expected_str = ("[('aaa', 1), ('b', 2.0), ('c_c', 'relu6'), "
                     "('d', '/a/b=c/d')]")
     self.assertEqual(expected_str, str(hparams.__str__()))
     self.assertEqual(expected_str, str(hparams))
     self.assertEqual(1, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual("relu6", hparams.c_c)
     self.assertEqual("/a/b=c/d", hparams.d)
     hparams.parse("aaa=12")
     self.assertDictEqual(
         {
             "aaa": 12,
             "b": 2.0,
             "c_c": "relu6",
             "d": "/a/b=c/d"
         }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual("relu6", hparams.c_c)
     self.assertEqual("/a/b=c/d", hparams.d)
     hparams.parse("c_c=relu4, b=-2.0e10")
     self.assertDictEqual(
         {
             "aaa": 12,
             "b": -2.0e10,
             "c_c": "relu4",
             "d": "/a/b=c/d"
         }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(-2.0e10, hparams.b)
     self.assertEqual("relu4", hparams.c_c)
     self.assertEqual("/a/b=c/d", hparams.d)
     # Empty values and a trailing comma are accepted.
     hparams.parse("c_c=,b=0,")
     self.assertDictEqual({
         "aaa": 12,
         "b": 0,
         "c_c": "",
         "d": "/a/b=c/d"
     }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(0.0, hparams.b)
     self.assertEqual("", hparams.c_c)
     self.assertEqual("/a/b=c/d", hparams.d)
     hparams.parse('c_c=2.3",b=+2,')
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     hparams.parse("d=/a/b/c/d,aaa=11,")
     self.assertEqual(11, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual("/a/b/c/d", hparams.d)
     # '=' inside a value is kept; only the first '=' separates name from value.
     hparams.parse("b=1.5,d=/a=b/c/d,aaa=10,")
     self.assertEqual(10, hparams.aaa)
     self.assertEqual(1.5, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual("/a=b/c/d", hparams.d)
     with self.assertRaisesRegex(ValueError, "Unknown hyperparameter"):
         hparams.parse("x=123")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("aaa=poipoi")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("aaa=1.0")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("b=12x")
     with self.assertRaisesRegex(ValueError, "Could not parse"):
         hparams.parse("b=relu")
     with self.assertRaisesRegex(ValueError, "Must not pass a list"):
         hparams.parse("aaa=[123]")
     # Failed parses leave previously set values untouched.
     self.assertEqual(10, hparams.aaa)
     self.assertEqual(1.5, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual("/a=b/c/d", hparams.d)
コード例 #27
0
 def testSomeValues(self):
     """Scalar parsing: str(), comma-separated overrides, and error messages.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
     self.assertDictEqual(
         {
             'aaa': 1,
             'b': 2.0,
             'c_c': 'relu6',
             'd': '/a/b=c/d'
         }, hparams.values())
     expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), '
                     '(\'d\', \'/a/b=c/d\')]')
     self.assertEqual(expected_str, str(hparams.__str__()))
     self.assertEqual(expected_str, str(hparams))
     self.assertEqual(1, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('relu6', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     hparams.parse('aaa=12')
     self.assertDictEqual(
         {
             'aaa': 12,
             'b': 2.0,
             'c_c': 'relu6',
             'd': '/a/b=c/d'
         }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('relu6', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     hparams.parse('c_c=relu4, b=-2.0e10')
     self.assertDictEqual(
         {
             'aaa': 12,
             'b': -2.0e10,
             'c_c': 'relu4',
             'd': '/a/b=c/d'
         }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(-2.0e10, hparams.b)
     self.assertEqual('relu4', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     # Empty values and a trailing comma are accepted.
     hparams.parse('c_c=,b=0,')
     self.assertDictEqual({
         'aaa': 12,
         'b': 0,
         'c_c': '',
         'd': '/a/b=c/d'
     }, hparams.values())
     self.assertEqual(12, hparams.aaa)
     self.assertEqual(0.0, hparams.b)
     self.assertEqual('', hparams.c_c)
     self.assertEqual('/a/b=c/d', hparams.d)
     hparams.parse('c_c=2.3",b=+2,')
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     hparams.parse('d=/a/b/c/d,aaa=11,')
     self.assertEqual(11, hparams.aaa)
     self.assertEqual(2.0, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual('/a/b/c/d', hparams.d)
     # '=' inside a value is kept; only the first '=' separates name from value.
     hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
     self.assertEqual(10, hparams.aaa)
     self.assertEqual(1.5, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual('/a=b/c/d', hparams.d)
     with self.assertRaisesRegex(ValueError, 'Unknown hyperparameter'):
         hparams.parse('x=123')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=poipoi')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('aaa=1.0')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=12x')
     with self.assertRaisesRegex(ValueError, 'Could not parse'):
         hparams.parse('b=relu')
     with self.assertRaisesRegex(ValueError, 'Must not pass a list'):
         hparams.parse('aaa=[123]')
     # Failed parses leave previously set values untouched.
     self.assertEqual(10, hparams.aaa)
     self.assertEqual(1.5, hparams.b)
     self.assertEqual('2.3"', hparams.c_c)
     self.assertEqual('/a=b/c/d', hparams.d)
コード例 #28
0
 def testContains(self):
     """The `in` operator reports whether a hyperparameter name is defined.

     assertIn/assertNotIn show both operands on failure, unlike
     assertTrue(x in y), which only reports "False is not true".
     """
     hparams = hparam.HParams(foo=1)
     self.assertIn('foo', hparams)
     self.assertNotIn('bar', hparams)
コード例 #29
0
 def testBoolParsingFail(self):
     """An unrecognised boolean spelling raises an error naming the field.

     Uses ``assertRaisesRegex``; ``assertRaisesRegexp`` was removed in
     Python 3.12.
     """
     hparams = hparam.HParams(use_gpu=True)
     with self.assertRaisesRegex(ValueError, r'Could not parse.*use_gpu'):
         hparams.parse('use_gpu=yep')