예제 #1
0
    def testDel(self):
        """Checks that deleting a param frees its name for re-adding."""
        hp = HParams(aaa=1, b=2.0)

        # While 'aaa' exists and holds an int, both overwriting it with a
        # str and re-adding it under the same name are expected to raise.
        self.assertRaises(ValueError, hp.set_hparam, 'aaa', 'will fail')
        self.assertRaises(ValueError, hp.add_hparam, 'aaa', 'will fail')

        # Once deleted, the name can be added again with a str value, and
        # later str assignments to it succeed.
        hp.del_hparam('aaa')
        hp.add_hparam('aaa', 'will work')
        self.assertEqual('will work', hp.get('aaa'))

        hp.set_hparam('aaa', 'still works')
        self.assertEqual('still works', hp.get('aaa'))
예제 #2
0
    def testGet(self):
        """Checks HParams.get for present/absent names and typed defaults."""
        hp = HParams(aaa=1, b=2.0, c_c='relu6', d=True, e=[5.0, 6.0])

        # Present names, looked up with no default or an explicit None.
        for expected, call_args in (
                (1, ('aaa',)),
                (2.0, ('b',)),
                ('relu6', ('c_c',)),
                (True, ('d',)),
                ([5.0, 6.0], ('e', None)),
        ):
            self.assertEqual(expected, hp.get(*call_args))

        # Present names with a type-compatible default: the stored value is
        # returned, not the default (an int default is accepted for a float
        # param, and an int list for a float list).
        for expected, name, fallback in (
                (1, 'aaa', 2),
                (2.0, 'b', 3.0),
                (2.0, 'b', 3),
                ('relu6', 'c_c', 'default'),
                (True, 'd', True),
                ([5.0, 6.0], 'e', [1.0, 2.0, 3.0]),
                ([5.0, 6.0], 'e', [1, 2, 3]),
        ):
            self.assertEqual(expected, hp.get(name, fallback))

        # Present names with a default of an incompatible type must raise.
        for name, fallback in (
                ('aaa', 2.0),
                ('b', False),
                ('c_c', [1, 2, 3]),
                ('d', 'relu'),
                ('e', 123.0),
                ('e', ['a', 'b', 'c']),
        ):
            with self.assertRaises(ValueError):
                hp.get(name, fallback)

        # Absent names yield the default, or None when no default is given.
        self.assertEqual(None, hp.get('unknown'))
        for expected, fallback in ((123, 123), ([1, 2, 3], [1, 2, 3])):
            self.assertEqual(expected, hp.get('unknown', fallback))
예제 #3
0
    def test_hparams(self):
        """Tests the HParams class.

        Covers construction from ``default_hparams`` merged with user
        ``hparams``, ``items()``, ``len()``, ``in``, ``get()``,
        attribute-style updates, ``add_hparam`` (top-level and nested),
        ``todict()``, and a pickle round trip through a temporary file.
        """
        default_hparams = {
            "str": "str",
            "list": ['item1', 'item2'],
            "dict": {
                "key1": "value1",
                "key2": "value2"
            },
            "nested_dict": {
                "dict_l2": {
                    "key1_l2": "value1_l2"
                }
            },
            "type": "type",
            "kwargs": {
                "arg1": "argv1"
            },
        }

        # Test HParams.items() function
        hparams_ = HParams(hparams=None, default_hparams=default_hparams)
        names = [name for name, _ in hparams_.items()]
        self.assertEqual(set(names), set(default_hparams.keys()))

        hparams = {"dict": {"key1": "new_value"}, "kwargs": {"arg2": "argv2"}}

        hparams_ = HParams(hparams=hparams, default_hparams=default_hparams)

        # Test HParams construction
        self.assertEqual(hparams_.str, default_hparams["str"])
        self.assertEqual(hparams_.list, default_hparams["list"])
        self.assertEqual(hparams_.dict.key1, hparams["dict"]["key1"])
        self.assertEqual(hparams_.kwargs.arg2, hparams["kwargs"]["arg2"])
        self.assertEqual(hparams_.nested_dict.dict_l2.key1_l2,
                         default_hparams["nested_dict"]["dict_l2"]["key1_l2"])

        self.assertEqual(len(hparams_), len(default_hparams))

        # Expected merge result: user values override defaults key by key,
        # and new keys under "kwargs" are added alongside the defaults.
        new_hparams = copy.deepcopy(default_hparams)
        new_hparams["dict"]["key1"] = hparams["dict"]["key1"]
        new_hparams["kwargs"].update(hparams["kwargs"])
        self.assertEqual(hparams_.todict(), new_hparams)

        self.assertIn("dict", hparams_)

        self.assertIsNone(hparams_.get('not_existed_name', None))
        self.assertEqual(hparams_.get('str'), default_hparams['str'])

        # Test HParams update related operations
        hparams_.str = "new_str"
        hparams_.dict = {"key3": "value3"}
        self.assertEqual(hparams_.str, "new_str")
        self.assertEqual(hparams_.dict.key3, "value3")

        hparams_.add_hparam("added_str", "added_str")
        hparams_.add_hparam("added_dict", {"key4": "value4"})
        hparams_.kwargs.add_hparam("added_arg", "added_argv")
        self.assertEqual(hparams_.added_str, "added_str")
        self.assertEqual(hparams_.added_dict.todict(), {"key4": "value4"})
        self.assertEqual(hparams_.kwargs.added_arg, "added_argv")

        # Test HParams I/O. Write and read back through the same handle.
        # Flushing and rewinding makes the round trip independent of when
        # the buffered writer reaches disk. Not reopening the file by name
        # keeps it working on Windows, where an open NamedTemporaryFile
        # cannot be opened a second time.
        with tempfile.NamedTemporaryFile() as hparams_file:
            pickle.dump(hparams_, hparams_file)
            hparams_file.flush()
            hparams_file.seek(0)
            hparams_loaded = pickle.load(hparams_file)
        self.assertEqual(hparams_loaded.todict(), hparams_.todict())