def test_remove_var(self):
    """Removing the only var leaves both the current and the initial stores empty."""
    gm = GlobalVarManager({"1": self.st2})
    gm.remove(["1"])
    self.assertEqual(list(gm.vars), [])
    self.assertEqual(list(gm.initial_values), [])
    # Check the removed *key* directly: iterating a mapping yields keys, so
    # testing the value self.st2 against list(gm.vars) could never fail.
    self.assertNotIn("1", gm.vars)
    self.assertNotIn("1", gm.initial_values)
def test_remove_var_that_does_not_exist_should_throw_error(self):
    """Removing an unknown key raises ValueError with a descriptive message."""
    manager = GlobalVarManager({"1": self.st1})
    expected_message = "^Can't remove var .* which does not exists$"
    with self.assertRaisesRegex(ValueError, expected_message):
        manager.remove(["2"])
def test_add_var_that_already_exists_should_throw_error(self):
    """Re-adding any already-registered key raises ValueError.

    Each key runs in its own subTest, so a failure reports the offending key
    and the remaining keys are still exercised.
    """
    gm = GlobalVarManager(self.all_d)
    for i, test_element in self.all_d.items():
        with self.subTest(key=i):
            with self.assertRaises(ValueError) as e:
                gm.add({i: test_element})
            self.assertRegex(str(e.exception), "^GlobalVarManager has already var.*")
def load(self, callable_classes):
    """Register the public, non-routine attributes of each class.

    Every attribute whose name neither starts nor ends with an underscore is
    stored in ``self.class_vars`` under the key "<ClassName>.<attribute>".
    ``self.gvm`` is then rebuilt as a GlobalVarManager over the collected vars.
    """
    def _is_not_routine(member):
        return not inspect.isroutine(member)

    for klass in callable_classes:
        for attr_name, attr_value in inspect.getmembers(klass, _is_not_routine):
            # Skip private/dunder-style names and names with a trailing underscore.
            if attr_name.startswith('_') or attr_name.endswith('_'):
                continue
            self.class_vars[klass.__name__ + "." + attr_name] = attr_value
    logger.debug('hook classes vars: {}'.format(self.class_vars))
    self.gvm = GlobalVarManager(self.class_vars)
def test_reset_one(self):
    """reset_one restores a single updated var to its initial value."""
    manager = GlobalVarManager({"1": self.i1})
    manager.update({"1": 999})
    values_after_update = list(manager.vars.values())
    self.assertEqual(values_after_update, [999])
    manager.reset_one("1")
    values_after_reset = list(manager.vars.values())
    self.assertEqual(values_after_reset, [self.i1])
def test_reset_one_var_does_not_exist_should_throw_error(self):
    """reset_one on an unknown key raises ValueError with a descriptive message."""
    manager = GlobalVarManager({"1": self.i1})
    manager.update({"1": 999})
    self.assertEqual(list(manager.vars.values()), [999])
    expected_message = "^Cannot reset var .* since it does not exist$"
    with self.assertRaisesRegex(ValueError, expected_message):
        manager.reset_one("2")
class HookClassVarsHandler:
    """Discover estimator hook classes and manage their public class variables.

    On construction, every ``.py`` module in the hooks directory (except
    private modules and this file itself) is imported relative to the
    ``estimator_hook`` package. One class is selected from each module, and
    that class's public non-routine attributes are registered in a
    GlobalVarManager under "<ClassName>.<attribute>" keys so they can later be
    reset.
    """

    def __init__(self):
        self.gvm = None
        # "<ClassName>.<attribute>" -> attribute value captured by load().
        self.class_vars = dict()
        # NOTE(review): relative path, so discovery depends on the current
        # working directory (presumably the repository root) -- confirm.
        hooks_dir = 'rnnlm/utils/estimator/estimator_hook'
        # Sorted so that discovery order (and therefore logging and var order)
        # is deterministic; os.listdir order is arbitrary.
        files = sorted(os.listdir(hooks_dir))
        # Class __name__ -> class object. Keyed by __name__ so it matches the
        # prefix that load() puts in every class_vars key (used by reset_all).
        self.class_name_to_callable = dict()
        logger.info('loading hooks from {}'.format(hooks_dir))
        for f in files:
            if f.startswith('_') or f == os.path.basename(__file__):
                continue
            module_name, ext = os.path.splitext(f)
            # Only plain source modules are hooks. splitext avoids the mangling
            # that replace('.py', '') causes on names such as "x.pyc".
            if ext != '.py':
                continue
            train_hook = importlib.import_module(name="." + module_name, package=estimator_hook.__name__)
            callable_class = self._find_hook_class(train_hook)
            if callable_class is None:
                logger.warning('no class found in hook module {}; skipping'.format(train_hook.__name__))
                continue
            self.class_name_to_callable[callable_class.__name__] = callable_class
        logger.debug('found hook classes {}'.format(self.class_name_to_callable))
        self.load(self.class_name_to_callable.values())

    @staticmethod
    def _find_hook_class(module):
        """Return the hook class of *module*, or None if it contains no class.

        Classes defined in *module* itself are preferred over classes it only
        imports. Among the candidates, the first in inspect.getmembers order
        (sorted by name) is chosen.
        """
        classes = inspect.getmembers(module, inspect.isclass)
        own_classes = [(name, cls) for name, cls in classes if cls.__module__ == module.__name__]
        candidates = own_classes or classes
        if not candidates:
            return None
        return candidates[0][1]

    def load(self, callable_classes):
        """Register the public, non-routine attributes of each class.

        Attributes whose names neither start nor end with an underscore are
        stored in ``self.class_vars`` as "<ClassName>.<attribute>". ``self.gvm``
        is then rebuilt as a GlobalVarManager over the collected vars.
        """
        for cl in callable_classes:
            attributes = inspect.getmembers(cl, lambda x: not (inspect.isroutine(x)))
            for attr_name, value in attributes:
                if not (attr_name.startswith('_') or attr_name.endswith('_')):
                    unique_name = cl.__name__ + "." + attr_name
                    # NOTE(review): stored by reference, not copied, so in-place
                    # mutations of mutable class attributes are not undone by
                    # reset_all.
                    self.class_vars[unique_name] = value
        logger.debug('hook classes vars: {}'.format(self.class_vars))
        self.gvm = GlobalVarManager(self.class_vars)

    def add(self, class_name, var_name, var_initial_value):
        """Register a new var "<class_name>.<var_name>" in the manager only.

        The class attribute itself is not set.
        """
        unique_name = class_name + "." + var_name
        self.gvm.add({unique_name: var_initial_value})

    def remove(self, class_name, var_name):
        """Remove the var "<class_name>.<var_name>" from the manager."""
        unique_name = class_name + "." + var_name
        self.gvm.remove([unique_name])

    def reset(self, class_name, var_name):
        """Reset "<class_name>.<var_name>" to its initial value in the manager.

        NOTE(review): unlike reset_all, this does not write the value back onto
        the hook class attribute -- confirm that this is intended.
        """
        unique_name = class_name + "." + var_name
        self.gvm.reset_one(unique_name)

    def reset_all(self):
        """Reset every managed var, and restore loaded hook class attributes."""
        logger.debug('resetting hook classes vars to their initial values: {}'.format(self.gvm.initial_values))
        self.gvm.reset_all()
        for k, value in self.class_vars.items():
            # A class __name__ never contains '.', so the first dot separates
            # the class name from the attribute name.
            class_name, _, attribute = k.partition('.')
            setattr(self.class_name_to_callable[class_name], attribute, value)
def test_reset_all(self):
    """reset_all restores every var to its initial value after a bulk update."""
    manager = GlobalVarManager(self.all_d)
    new_values = [
        ["sdf"], [1], self.d2, self.d1, {"123"}, {"21"}, True, True,
        [False, False, True], {"a": True, "b": False}, -1, "15.3", 0,
        "None", "abc123", "True",
    ]
    manager.update({key: value for key, value in zip(self.all_d.keys(), new_values)})
    # Current values reflect the update, while the initial snapshot is untouched.
    self.assertEqual(list(manager.vars.values()), new_values)
    self.assertEqual(list(manager.initial_values.values()), self.all)
    manager.reset_all()
    self.assertEqual(list(manager.vars.values()), self.all)
def test_retrieve_original_value_after_it_has_changed(self):
    """Updating a var does not alter its recorded initial value."""
    manager = GlobalVarManager({"1": self.st2})
    manager.update({"1": "hello"})
    original_values = list(manager.initial_values.values())
    self.assertEqual([self.st2], original_values)
def test_update_var_that_does_not_exists_should_throw_error(self):
    """Updating an unknown key on an empty manager raises ValueError."""
    manager = GlobalVarManager()
    expected_message = "^Can't update var .* which does not exist. Not updating anything$"
    with self.assertRaisesRegex(ValueError, expected_message):
        manager.update({"1": "hello"})
def test_update_var(self):
    """update replaces the current value of an existing var."""
    manager = GlobalVarManager({"1": self.st2})
    manager.update({"1": "hello"})
    current_values = list(manager.vars.values())
    self.assertEqual(current_values, ["hello"])
def test_add_var_after_creation(self):
    """Vars added after construction appear, in order, in current and initial values."""
    manager = GlobalVarManager({"1": self.st1, "2": self.dl1})
    manager.add({"3": self.st2})
    expected = [self.st1, self.dl1, self.st2]
    self.assertEqual(list(manager.vars.values()), expected)
    self.assertEqual(list(manager.initial_values.values()), expected)
def test_successful_creation(self):
    """A freshly built manager exposes the given values and snapshots them as initial."""
    manager = GlobalVarManager(self.all_d)
    current_values = list(manager.vars.values())
    self.assertEqual(current_values, self.all)
    self.assertEqual(manager.initial_values, manager.vars)
def test_reset_all_when_no_vars(self):
    """reset_all on an empty manager succeeds and leaves it empty."""
    gm = GlobalVarManager()
    gm.reset_all()
    # Previously this test only checked that no exception was raised; also pin
    # down that resetting nothing does not create any entries.
    self.assertEqual(list(gm.vars), [])
    self.assertEqual(list(gm.initial_values), [])