예제 #1
0
    def testSetHParam(self):
        """Check that set_hparam overwrites values and values() reflects it."""
        initial = {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': True}
        hparams = HParams(**initial)
        self.assertDictEqual(initial, hparams.values())
        for attr, expected in (('aaa', 1), ('b', 2.0), ('c_c', 'relu6')):
            self.assertEqual(expected, getattr(hparams, attr))

        # Apply each update in a fixed order, then verify the whole mapping
        # as well as individual attribute access.
        updates = (('aaa', 12), ('b', 3.0), ('c_c', 'relu4'), ('d', False))
        for name, new_value in updates:
            hparams.set_hparam(name, new_value)
        self.assertDictEqual(dict(updates), hparams.values())
        for attr, expected in (('aaa', 12), ('b', 3.0), ('c_c', 'relu4')):
            self.assertEqual(expected, getattr(hparams, attr))
예제 #2
0
    def testDel(self):
        """Check that a deleted hparam name can be re-added with a new value type."""
        hparams = HParams(aaa=1, b=2.0)

        # While 'aaa' still holds an int, both assigning a string to it and
        # adding a second parameter with the same name must be rejected.
        for rejecting_call in (hparams.set_hparam, hparams.add_hparam):
            with self.assertRaises(ValueError):
                rejecting_call('aaa', 'will fail')

        # Once deleted, the name is free to be added again with a string.
        hparams.del_hparam('aaa')
        hparams.add_hparam('aaa', 'will work')
        self.assertEqual('will work', hparams.get('aaa'))

        # The re-added parameter accepts further string updates.
        hparams.set_hparam('aaa', 'still works')
        self.assertEqual('still works', hparams.get('aaa'))
예제 #3
0
    def testSetHParamTypeMismatch(self):
        """Check that set_hparam rejects values incompatible with the original type."""
        hparams = HParams(int_=1,
                          str_='str',
                          bool_=True,
                          float_=1.1,
                          list_int=[1, 2],
                          none=None)

        # (parameter name, incompatible value) pairs, checked in order.
        incompatible = [
            ('str_', 2.2),
            ('int_', False),
            ('bool_', 1),
            ('int_', 2.2),
            ('list_int', [2, 3.3]),
            ('int_', '2'),
        ]
        for name, bad_value in incompatible:
            with self.assertRaises(ValueError):
                hparams.set_hparam(name, bad_value)

        # Widening an int into a float-valued parameter is permitted.
        hparams.set_hparam('float_', 1)

        # A parameter created with a None value carries no usable type, so
        # even a string is accepted and stored as-is.
        hparams.set_hparam('none', '1')
        self.assertEqual('1', hparams.none)
예제 #4
0
 def testSetHParamListNonListMismatch(self):
     """Check that set_hparam refuses to swap between list and scalar values.

     Uses ``assertRaisesRegex``; the old ``assertRaisesRegexp`` alias was
     deprecated and removed from ``unittest.TestCase`` in Python 3.12.
     """
     hparams = HParams(a=1, b=[2.0, 3.0])
     # 'a' was created from a scalar, so a list value must be rejected.
     with self.assertRaisesRegex(ValueError, r'Must not pass a list'):
         hparams.set_hparam('a', [1.0])
     # 'b' was created from a list, so a scalar value must be rejected.
     with self.assertRaisesRegex(ValueError, r'Must pass a list'):
         hparams.set_hparam('b', 1.0)