def testDel(self):
    """del_hparam frees a name so it can be re-added with a different value.

    While 'aaa' still exists (holding the int 1), both set_hparam and
    add_hparam with a string value are expected to raise ValueError. After
    del_hparam('aaa'), the name accepts a string via add_hparam and can then
    be overwritten via set_hparam.
    """
    params = HParams(aaa=1, b=2.0)

    # Both mutators must raise ValueError while 'aaa' still exists.
    for rejected_op in (params.set_hparam, params.add_hparam):
        with self.assertRaises(ValueError):
            rejected_op('aaa', 'will fail')

    params.del_hparam('aaa')

    # Once deleted, re-adding and then overwriting with strings succeeds.
    for accepted_op, text in ((params.add_hparam, 'will work'),
                              (params.set_hparam, 'still works')):
        accepted_op('aaa', text)
        self.assertEqual(text, params.get('aaa'))
def testWithPeriodInVariableName(self):
    """Hparam names containing periods can be added, parsed and read back.

    Such names are not valid Python identifiers, so values are read with
    getattr. Also checks that parse() raises a ValueError mentioning
    'Could not parse' when a float-valued hparam is given 'abc'.
    """
    hparams = HParams()
    hparams.add_hparam(name='a.b', value=0.0)
    hparams.parse('a.b=1.0')
    self.assertEqual(1.0, getattr(hparams, 'a.b'))

    hparams.add_hparam(name='c.d', value=0.0)
    # assertRaisesRegexp is a deprecated alias that was removed in
    # Python 3.12; assertRaisesRegex is the supported name (Python 3.2+).
    with self.assertRaisesRegex(ValueError, 'Could not parse'):
        hparams.parse('c.d=abc')

    hparams.add_hparam(name='e.f', value='')
    hparams.parse('e.f=abc')
    self.assertEqual('abc', getattr(hparams, 'e.f'))

    # Names with consecutive / trailing periods are expected to work too.
    hparams.add_hparam(name='d..', value=0.0)
    hparams.parse('d..=10.0')
    self.assertEqual(10.0, getattr(hparams, 'd..'))
def test_hparams(self):
    """Tests the HParams class.

    Covers: items() over defaults-only HParams; construction with partial
    overrides merged onto defaults (including nested dicts); todict(),
    len(), ``in`` and get(); attribute assignment; add_hparam on both the
    top-level object and a nested HParams; and a pickle round trip through
    a temporary file.
    """
    default_hparams = {
        "str": "str",
        "list": ['item1', 'item2'],
        "dict": {
            "key1": "value1",
            "key2": "value2"
        },
        "nested_dict": {
            "dict_l2": {
                "key1_l2": "value1_l2"
            }
        },
        "type": "type",
        "kwargs": {
            "arg1": "argv1"
        },
    }

    # Test HParams.items() function
    hparams_ = HParams(hparams=None, default_hparams=default_hparams)
    names = [name for name, _ in hparams_.items()]
    self.assertEqual(set(names), set(default_hparams.keys()))

    hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)

    # Test HParams construction
    self.assertEqual(hparams_.str, default_hparams["str"])
    self.assertEqual(hparams_.list, default_hparams["list"])
    self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
    self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
    self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                     default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

    self.assertEqual(len(hparams_), len(default_hparams))

    # Expected merged dict: defaults with the overrides applied on top.
    new_hparams = copy.deepcopy(default_hparams)
    new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
    new_hparams["kwargs"].update(hparams["kwargs"])
    self.assertEqual(hparams_.todict(), new_hparams)

    self.assertIn("dict", hparams_)

    self.assertIsNone(hparams_.get('not_existed_name', None))
    self.assertEqual(hparams_.get('str'), default_hparams['str'])

    # Test HParams update related operations
    hparams_.str = "new_str"
    hparams_.dict = {"key3": "value3"}
    self.assertEqual(hparams_.str, "new_str")
    self.assertEqual(hparams_.dict.key3, "value3")

    hparams_.add_hparam("added_str", "added_str")
    hparams_.add_hparam("added_dict", {"key4": "value4"})
    hparams_.kwargs.add_hparam("added_arg", "added_argv")
    self.assertEqual(hparams_.added_str, "added_str")
    self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
    self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

    # Test HParams I/O: pickle round trip through a real temporary file.
    # Write, flush and rewind the same handle before loading. Re-opening a
    # NamedTemporaryFile by name while it is open is not portable (it fails
    # on Windows), and without a flush the written bytes may still be
    # sitting in the write buffer.
    with tempfile.NamedTemporaryFile() as hparams_file:
        pickle.dump(hparams_, hparams_file)
        hparams_file.flush()
        hparams_file.seek(0)
        hparams_loaded = pickle.load(hparams_file)
    self.assertEqual(hparams_loaded.todict(), hparams_.todict())