コード例 #1
0
    def testDel(self):
        hparams = HParams(aaa=1, b=2.0)

        with self.assertRaises(ValueError):
            hparams.set_hparam('aaa', 'will fail')

        with self.assertRaises(ValueError):
            hparams.add_hparam('aaa', 'will fail')

        hparams.del_hparam('aaa')
        hparams.add_hparam('aaa', 'will work')
        self.assertEqual('will work', hparams.get('aaa'))

        hparams.set_hparam('aaa', 'still works')
        self.assertEqual('still works', hparams.get('aaa'))
コード例 #2
0
 def testWithPeriodInVariableName(self):
     hparams = HParams()
     hparams.add_hparam(name='a.b', value=0.0)
     hparams.parse('a.b=1.0')
     self.assertEqual(1.0, getattr(hparams, 'a.b'))
     hparams.add_hparam(name='c.d', value=0.0)
     with self.assertRaisesRegexp(ValueError, 'Could not parse'):
         hparams.parse('c.d=abc')
     hparams.add_hparam(name='e.f', value='')
     hparams.parse('e.f=abc')
     self.assertEqual('abc', getattr(hparams, 'e.f'))
     hparams.add_hparam(name='d..', value=0.0)
     hparams.parse('d..=10.0')
     self.assertEqual(10.0, getattr(hparams, 'd..'))
コード例 #3
0
    def test_hparams(self):
        """Tests the HParams class.
        """
        default_hparams = {
            "str": "str",
            "list": ['item1', 'item2'],
            "dict": {
                "key1": "value1",
                "key2": "value2"
            },
            "nested_dict": {
                "dict_l2": {
                    "key1_l2": "value1_l2"
                }
            },
            "type": "type",
            "kwargs": {
                "arg1": "argv1"
            },
        }

        # Test HParams.items() function
        hparams_ = HParams(hparams=None, default_hparams=default_hparams)
        names = []
        for name, _ in hparams_.items():
            names.append(name)
        self.assertEqual(set(names), set(default_hparams.keys()))

        hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}

        hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)

        # Test HParams construction
        self.assertEqual(hparams_.str, default_hparams["str"])
        self.assertEqual(hparams_.list, default_hparams["list"])
        self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
        self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
        self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                         default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

        self.assertEqual(len(hparams_), len(default_hparams))

        new_hparams = copy.deepcopy(default_hparams)
        new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
        new_hparams["kwargs"].update(hparams["kwargs"])
        self.assertEqual(hparams_.todict(), new_hparams)

        self.assertTrue("dict" in hparams_)

        self.assertIsNone(hparams_.get('not_existed_name', None))
        self.assertEqual(hparams_.get('str'), default_hparams['str'])

        # Test HParams update related operations
        hparams_.str = "new_str"
        hparams_.dict = {"key3": "value3"}
        self.assertEqual(hparams_.str, "new_str")
        self.assertEqual(hparams_.dict.key3, "value3")

        hparams_.add_hparam("added_str", "added_str")
        hparams_.add_hparam("added_dict", {"key4": "value4"})
        hparams_.kwargs.add_hparam("added_arg", "added_argv")
        self.assertEqual(hparams_.added_str, "added_str")
        self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
        self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

        # Test HParams I/O
        hparams_file = tempfile.NamedTemporaryFile()
        pickle.dump(hparams_, hparams_file)
        with open(hparams_file.name, 'rb') as hparams_file:
            hparams_loaded = pickle.load(hparams_file)
        self.assertEqual(hparams_loaded.todict(), hparams_.todict())