def testDel(self):
    """Check that a deleted hyperparameter can be added again.

    While 'aaa' is still registered, both ``set_hparam`` with a string
    value and ``add_hparam`` under the same name are expected to raise
    ValueError. Once ``del_hparam`` has removed it, the name can be added
    with a string value and then reassigned via ``set_hparam``.
    """
    hp = HParams(aaa=1, b=2.0)

    # Both mutations must be refused while 'aaa' (created as 1) exists.
    self.assertRaises(ValueError, hp.set_hparam, 'aaa', 'will fail')
    self.assertRaises(ValueError, hp.add_hparam, 'aaa', 'will fail')

    # After removal the name is free to be re-registered with a string.
    hp.del_hparam('aaa')
    hp.add_hparam('aaa', 'will work')
    self.assertEqual('will work', hp.get('aaa'))

    # ...and a further string assignment is accepted as well.
    hp.set_hparam('aaa', 'still works')
    self.assertEqual('still works', hp.get('aaa'))
def testGet(self):
    """Exercise ``HParams.get`` with and without a default argument.

    Four groups of cases are checked in order: existing names with no
    (or a ``None``) default, existing names with defaults the test
    expects to be accepted, existing names with defaults the test expects
    to raise ValueError, and unknown names, which yield the default.
    """
    hp = HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

    # Existing parameters with default=None.
    no_default_cases = (
        (1, 'aaa'),
        (2.0, 'b'),
        ('relu6', 'c_c'),
        (True, 'd'),
    )
    for expected, name in no_default_cases:
        self.assertEqual(expected, hp.get(name))
    # This one passes None explicitly rather than relying on the default.
    self.assertEqual([5.0, 6.0], hp.get('e', None))

    # Existing parameters with compatible defaults.
    compatible_cases = (
        (1, 'aaa', 2),
        (2.0, 'b', 3.0),
        (2.0, 'b', 3),
        ('relu6', 'c_c', 'default'),
        (True, 'd', True),
        ([5.0, 6.0], 'e', [1.0, 2.0, 3.0]),
        ([5.0, 6.0], 'e', [1, 2, 3]),
    )
    for expected, name, default in compatible_cases:
        self.assertEqual(expected, hp.get(name, default))

    # Existing parameters with incompatible defaults.
    incompatible_cases = (
        ('aaa', 2.0),
        ('b', False),
        ('c_c', [1, 2, 3]),
        ('d', 'relu'),
        ('e', 123.0),
        ('e', ['a', 'b', 'c']),
    )
    for name, default in incompatible_cases:
        with self.assertRaises(ValueError):
            hp.get(name, default)

    # Nonexistent parameters.
    self.assertEqual(None, hp.get('unknown'))
    for fallback in (123, [1, 2, 3]):
        self.assertEqual(fallback, hp.get('unknown', fallback))
def test_hparams(self):
    """Tests the HParams class.

    Covers, in order: iterating ``items()``, construction from
    user-supplied values layered over defaults, attribute assignment and
    ``add_hparam`` (including on a nested instance), and a pickle round
    trip through a temporary file.
    """
    default_hparams = {
        "str": "str",
        "list": ['item1', 'item2'],
        "dict": {
            "key1": "value1",
            "key2": "value2"
        },
        "nested_dict": {
            "dict_l2": {
                "key1_l2": "value1_l2"
            }
        },
        "type": "type",
        "kwargs": {
            "arg1": "argv1"
        },
    }

    # Test HParams.items() function
    hparams_ = HParams(hparams=None, default_hparams=default_hparams)
    names = {name for name, _ in hparams_.items()}
    self.assertEqual(names, set(default_hparams.keys()))

    hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}
    hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)

    # Test HParams construction
    self.assertEqual(hparams_.str, default_hparams["str"])
    self.assertEqual(hparams_.list, default_hparams["list"])
    self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
    self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
    self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                     default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

    self.assertEqual(len(hparams_), len(default_hparams))

    # Build the expected dict by hand: the defaults with the user-supplied
    # overrides applied on top.
    new_hparams = copy.deepcopy(default_hparams)
    new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
    new_hparams["kwargs"].update(hparams["kwargs"])
    self.assertEqual(hparams_.todict(), new_hparams)

    self.assertIn("dict", hparams_)

    self.assertIsNone(hparams_.get('not_existed_name', None))
    self.assertEqual(hparams_.get('str'), default_hparams['str'])

    # Test HParams update related operations
    hparams_.str = "new_str"
    hparams_.dict = {"key3": "value3"}
    self.assertEqual(hparams_.str, "new_str")
    self.assertEqual(hparams_.dict.key3, "value3")

    hparams_.add_hparam("added_str", "added_str")
    hparams_.add_hparam("added_dict", {"key4": "value4"})
    hparams_.kwargs.add_hparam("added_arg", "added_argv")
    self.assertEqual(hparams_.added_str, "added_str")
    self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
    self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

    # Test HParams I/O
    # The temporary file is managed by `with` so it is always closed, and
    # the pickle is read back through the same handle after an explicit
    # flush + rewind. Reopening the file by name while it is still open is
    # not portable, and without a flush the reader could see an empty file.
    with tempfile.NamedTemporaryFile() as hparams_file:
        pickle.dump(hparams_, hparams_file)
        hparams_file.flush()
        hparams_file.seek(0)
        hparams_loaded = pickle.load(hparams_file)
    self.assertEqual(hparams_loaded.todict(), hparams_.todict())