def main(_):
  """Demonstrates shared, typed, and ConfigDict-resolved FieldReferences."""
  shared_ref = ml_collections.FieldReference(0)

  cfg = ml_collections.ConfigDict()
  cfg.placeholder = shared_ref
  cfg.optional = ml_collections.FieldReference(0, field_type=int)
  cfg.nested = ml_collections.ConfigDict()
  cfg.nested.placeholder = shared_ref

  # `optional` was declared with field_type=int, so assigning a str raises a
  # TypeError; its message is printed.
  try:
    cfg.optional = 'tom'
  except TypeError as err:
    print(err)

  cfg.optional = 1555  # An int is accepted.

  # `placeholder` and `nested.placeholder` hold the same reference, so this
  # single assignment changes the value of both fields.
  cfg.placeholder = 1

  # Reading a field back through the ConfigDict yields its value rather than
  # the FieldReference, so the indirection is lost at that point.
  tied_ref = ml_collections.FieldReference(0)
  cfg.field1 = tied_ref
  cfg.field2 = tied_ref  # Tied to cfg.field1.
  cfg.field3 = cfg.field1  # Just an int field initialized to 0.

  print(cfg)
def testTypeChecking(self):
  """Adding an int reference to a str reference fails only on evaluation."""
  int_ref = ml_collections.FieldReference(1)
  str_ref = ml_collections.FieldReference('a')
  # Building the sum is lazy and must not raise; the TypeError surfaces on get().
  mixed_sum = int_ref + str_ref
  with self.assertRaises(TypeError):
    mixed_sum.get()
def testInitFieldReferenceInList(self):
  """Initialization must fail when a FieldReference sits in a list or tuple."""
  bad_inputs = [
      {'list': [1, 2, 3, ml_collections.FieldReference(4)]},
      {'tuple': (1, 2, 3, ml_collections.FieldReference('a'))},
  ]
  self.assertFrozenRaisesValueError(bad_inputs)
def testControlFlowError(self):
  """Using a FieldReference as a truth value must raise NotImplementedError.

  Each case coerces a reference to bool through a different syntactic form
  (`if`, `and`, `or`, `not`). All of them are expected to raise, because a
  lazily evaluated reference has no meaningful truth value at build time.
  """
  ref1 = ml_collections.FieldReference(True)
  ref2 = ml_collections.FieldReference(False)
  with self.assertRaises(NotImplementedError):
    if ref1:
      pass
  with self.assertRaises(NotImplementedError):
    _ = ref1 and ref2
  with self.assertRaises(NotImplementedError):
    _ = ref1 or ref2
  with self.assertRaises(NotImplementedError):
    _ = not ref1
def testLessEqual(self):
  """Ordering comparisons work between references and plain values."""
  one = ml_collections.FieldReference(1)
  another_one = ml_collections.FieldReference(1)
  two = ml_collections.FieldReference(2)

  # Reference against plain ints, with the reference on either side.
  self.assertLessEqual(one, 1)
  self.assertLessEqual(one, 2)
  self.assertLessEqual(0, one)
  self.assertLessEqual(1, one)
  self.assertGreater(one, 0)

  # Reference against references.
  self.assertLessEqual(one, one)
  self.assertLessEqual(one, another_one)
  self.assertLessEqual(one, two)
  self.assertGreater(two, one)
def testToString(self):
  """to_str() converts lazily and records str as the resulting field type."""
  self._test_unary_operator(12, lambda r: r.to_str(), '12', 0, '0')

  as_str = ml_collections.FieldReference(647).to_str()
  self.assertEqual(as_str.get(), '647')
  self.assertEqual(as_str._field_type, str)
def testToFloat(self):
  """to_float() converts lazily and records float as the resulting type."""
  self._test_unary_operator(12, lambda r: r.to_float(), 12.0, 0, 0.0)

  as_float = ml_collections.FieldReference(647).to_float()
  self.assertEqual(as_float.get(), 647.0)
  self.assertEqual(as_float._field_type, float)
def lazy_configdict():
  """Example usage of lazy computation with ConfigDict.

  Only operands obtained through `config.get_ref(...)` are re-read when the
  underlying field changes; operands read as plain attributes are captured
  by value at the moment the expression is built.
  """
  config = ml_collections.ConfigDict()
  config.reference_field = ml_collections.FieldReference(1)
  config.integer_field = 2
  config.float_field = 2.5

  # No lazy evaluations because we didn't use get_ref().
  config.no_lazy = config.integer_field * config.float_field

  # This will lazily evaluate ONLY config.integer_field (float_field is
  # captured as 2.5).
  config.lazy_integer = config.get_ref('integer_field') * config.float_field

  # This will lazily evaluate ONLY config.float_field (integer_field is
  # captured as 2).
  config.lazy_float = config.integer_field * config.get_ref('float_field')

  # This will lazily evaluate BOTH config.integer_field and config.float_field.
  config.lazy_both = (config.get_ref('integer_field') *
                      config.get_ref('float_field'))

  config.integer_field = 3
  print(config.no_lazy)  # Prints 5.0 - It uses integer_field's original value

  print(config.lazy_integer)  # Prints 7.5 (3 * 2.5)

  config.float_field = 3.5
  print(config.lazy_float)  # Prints 7.0 (2 * 3.5)
  print(config.lazy_both)  # Prints 10.5 (3 * 3.5)
def make_reference(
    config,
    field,
):
  """Returns a reference to a field for wiring up a config dict.

  The returned reference is "one-way": a change to the original field value
  propagates to the reference, but a change to the reference does not propagate
  back up.

  This works recursively. If config[field] is itself a ConfigDict (or a tree of
  ConfigDicts) then an identical tree is constructed with all the leaves
  referring to original leaves. Internal nodes are *not* references.

  See config_schema_test.MakeReferenceRecursiveTest for example usage.

  Args:
    config: A ConfigDict instance.
    field: The name of a field contained in `config`.

  Returns:
    If config[field] is anything other than a ConfigDict (e.g. a scalar),
    returns a one-way FieldReference to config[field]. If config[field] is a
    ConfigDict, returns a new ConfigDict whose fields are references to the
    fields in config[field].
  """
  field_value = config[field]
  if isinstance(field_value, ml_collections.ConfigDict):
    new_dict = ml_collections.ConfigDict()
    for subfield in field_value:
      new_dict[subfield] = make_reference(field_value, subfield)
    return new_dict
  else:
    # Wrapping get_ref() in a new FieldReference is what makes the link
    # one-way: set() on the wrapper replaces the wrapper's own value and
    # breaks its link, leaving the original field untouched.
    return ml_collections.FieldReference(config.get_ref(field))
def testFieldReferenceResolved(self):
  """FrozenConfigDict stores resolved values instead of FieldReferences."""
  source = ml_collections.ConfigDict({'fr': ml_collections.FieldReference(1)})
  frozen = ml_collections.FrozenConfigDict(source)
  self.assertNotIsInstance(frozen._fields['fr'],
                           ml_collections.FieldReference)
  # Once the reference is resolved the frozen config is hashable; this must
  # not raise.
  hash(frozen)
def testFieldReferenceCycle(self):
  """Tests that FieldReferences may not contain reference cycles."""
  # Acyclic: references to a sibling frozenset / list in the same dict.
  with_frozenset = {'frozenset': frozenset({1, 2})}
  with_frozenset['fr'] = ml_collections.FieldReference(
      with_frozenset['frozenset'])

  with_list = {'list': [1, 2]}
  with_list['fr'] = ml_collections.FieldReference(with_list['list'])

  # Cyclic: a reference to its own containing dict, and one to a grandparent.
  self_cycle = {'a': 1}
  self_cycle['fr'] = ml_collections.FieldReference(self_cycle)

  parent_cycle = {'dict': {}}
  parent_cycle['dict']['fr'] = ml_collections.FieldReference(parent_cycle)

  # Non-cyclic targets are accepted...
  ml_collections.FrozenConfigDict(with_frozenset)
  ml_collections.FrozenConfigDict(with_list)

  # ...but cycles are rejected.
  self.assertFrozenRaisesValueError([self_cycle, parent_cycle])
def get_config():
  """Returns a ConfigDict. Used for tests."""
  config = ml_collections.ConfigDict()

  # Scalars: one plain value, one held in a FieldReference.
  config.integer = 1
  config.reference = ml_collections.FieldReference(1)

  # Sequences, flat and nested.
  config.list = [1, 2, 3]
  config.nested_list = [[1, 2, 3]]

  # A child ConfigDict with a single scalar leaf.
  nested = ml_collections.ConfigDict()
  nested.integer = 1
  config.nested_configdict = nested

  # NOTE(review): UnusableConfig is defined elsewhere in this module; judging
  # by its name it is presumably a value the tests expect to fail on — confirm.
  config.unusable_config = UnusableConfig()
  return config
def testSetValue(self):
  """set() updates dependent results and enforces the reference's type."""
  base = ml_collections.FieldReference(1.0)
  addend = ml_collections.FieldReference(3)
  total = base + addend
  self.assertEqual(total.get(), 4.0)

  base.set(2.5)
  self.assertEqual(total.get(), 5.5)

  addend.set(110)
  self.assertEqual(total.get(), 112.5)

  # `addend` was created from an int, so str values must be rejected, whether
  # passed directly, wrapped in a reference, or as a str-typed empty reference.
  with self.assertRaises(TypeError):
    addend.set('this is a string')
  with self.assertRaises(TypeError):
    addend.set(ml_collections.FieldReference('this is a string'))
  with self.assertRaises(TypeError):
    addend.set(ml_collections.FieldReference(None, field_type=str))
def lazy_computation():
  """Simple example of lazy computation with `configdict.FieldReference`."""
  ref = ml_collections.FieldReference(1)
  print(ref.get())  # Prints 1

  add_ten = ref.get() + 10  # ref.get() is an integer and so is add_ten
  add_ten_lazy = ref + 10  # add_ten_lazy is a FieldReference - NOT an integer

  print(add_ten)  # Prints 11
  print(add_ten_lazy.get())  # Prints 11 because ref's value is 1

  # Addition is lazily computed for FieldReferences, so changing ref will
  # change the value that is used to compute add_ten_lazy. add_ten was
  # computed eagerly from a plain int and is not affected.
  ref.set(5)
  print(add_ten)  # Prints 11
  print(add_ten_lazy.get())  # Prints 15 because ref's value is 5
def testInitDictInList(self):
  """Initialization must fail on dicts/ConfigDicts nested in lists/tuples."""
  bad_inputs = [
      {'list': [1, 2, 3, {'a': 4, 'b': 5}]},
      {'tuple': (1, 2, 3, {'a': 4, 'b': 5})},
      {'list': [1, 2, 3, _test_configdict()]},
      {'tuple': (1, 2, 3, _test_configdict())},
      # A dict inside a list that is itself held by a FieldReference.
      {'fr': ml_collections.FieldReference([1, {'a': 2}])},
  ]
  self.assertFrozenRaisesValueError(bad_inputs)
def _test_binary_operator(self,
                          initial_value,
                          other_value,
                          op,
                          true_value,
                          new_initial_value,
                          new_true_value,
                          assert_fn=None):
  """Checks a binary operator on FieldReferences, standalone and in a config.

  Verifies that `op(initial_value, other_value)` compares (via `assert_fn`)
  to `true_value`, and that after the first operand is changed to
  `new_initial_value` the lazily computed result compares to
  `new_true_value`.

  Args:
    initial_value: Starting value of the `FieldReference` used as the first
      operand.
    other_value: The second operand.
    op: The binary operator under test.
    true_value: Expected result for `initial_value`.
    new_initial_value: Value the first operand is changed to.
    new_true_value: Expected result after the change.
    assert_fn: Comparison used on results; defaults to `self.assertEqual`.
  """
  check = self.assertEqual if assert_fn is None else assert_fn

  # Standalone reference.
  check(op(ml_collections.FieldReference(initial_value), other_value).get(),
        true_value)

  # Same operator wired through a ConfigDict. Only `a` is taken via get_ref,
  # so re-assigning `a` must propagate into `result`.
  config = ml_collections.ConfigDict()
  config.a = initial_value
  config.b = other_value
  config.result = op(config.get_ref('a'), config.b)
  check(config.result, true_value)

  config.a = new_initial_value
  check(config.result, new_true_value)
def testCycles(self):
  """Reference cycles are rejected; shared non-cyclic targets are fine."""
  config = ml_collections.ConfigDict()
  config.a = 1.
  config.b = config.get_ref('a') + 10
  config.c = config.get_ref('b') + 10
  self.assertEqual(config.b, 11.0)
  self.assertEqual(config.c, 21.0)

  # a <- c would close the loop a -> b -> c -> a (c as first operand).
  with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
    config.a = config.get_ref('c') - 1.0

  # Same loop, with c as the second operand.
  with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
    config.a = ml_collections.FieldReference(5.0) + config.get_ref('c')

  # Several FieldReferences may point at the same underlying object.
  shared = [0]
  config = ml_collections.ConfigDict()
  config.a = shared
  config.b = shared
  config.c = config.get_ref('a') + ['c']
  config.d = config.get_ref('b') + ['d']
  self.assertEqual(config.c, [0, 'c'])
  self.assertEqual(config.d, [0, 'd'])

  # Evaluating the lazy sums must not mutate the shared list.
  self.assertEqual(shared, [0])
  self.assertEqual(config.c, [0, 'c'])

  config.a = [1]
  config.b = [2]
  self.assertEqual(shared, [0])
  self.assertEqual(config.c, [1, 'c'])
  self.assertEqual(config.d, [2, 'd'])
def testSetResult(self):
  """Calling set() on a derived reference detaches it from its inputs."""
  ref = ml_collections.FieldReference(1.0)
  result = ref + 1.0
  second_result = result + 1.0

  def expect(ref_value, result_value, second_value):
    self.assertEqual(ref.get(), ref_value)
    self.assertEqual(result.get(), result_value)
    self.assertEqual(second_result.get(), second_value)

  expect(1.0, 2.0, 3.0)

  ref.set(2.0)
  expect(2.0, 3.0, 4.0)

  # Overriding `result` leaves `ref` alone but still feeds `second_result`.
  result.set(4.0)
  expect(2.0, 4.0, 5.0)

  # All references are broken at this point.
  ref.set(1.0)
  expect(1.0, 4.0, 5.0)
def testEqual(self):
  """Equality compares the referenced values, for scalars and ConfigDicts."""
  # Scalar values.
  one = ml_collections.FieldReference(1)
  also_one = ml_collections.FieldReference(1)
  two = ml_collections.FieldReference(2)
  self.assertEqual(one, 1)
  self.assertEqual(one, one)
  self.assertEqual(one, also_one)
  self.assertNotEqual(one, 2)
  self.assertNotEqual(one, two)

  # ConfigDict values held inside FieldReferences.
  cd_a1 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1}))
  other_cd_a1 = ml_collections.FieldReference(
      ml_collections.ConfigDict({'a': 1}))
  cd_a2 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 2}))
  self.assertEqual(cd_a1, ml_collections.ConfigDict({'a': 1}))
  self.assertEqual(cd_a1, cd_a1)
  self.assertEqual(cd_a1, other_cd_a1)
  self.assertNotEqual(cd_a1, ml_collections.ConfigDict({'a': 2}))
  self.assertNotEqual(cd_a1, cd_a2)
def get_config():
  """Returns a config with a defaulted reference and an int placeholder."""
  config = ml_collections.ConfigDict()
  config.ref = ml_collections.FieldReference(123)
  # Typed placeholder; its value is presumably unset (None) until a caller
  # fills it in.
  config.ref_nodefault = config_dict.placeholder(int)
  return config
def get_config():
  """Returns default config.

  The config holds run, wrapper, and training settings plus a nested `sac`
  ConfigDict with `critic` and `actor` sub-configs. The observation dimension,
  action dimension, and action range start as empty placeholders and are
  meant to be filled in at runtime.
  """
  config = ml_collections.ConfigDict()

  # ================================================= #
  # Placeholders.
  # ================================================= #
  # These values will be filled at runtime once the gym.Env is loaded.
  # Each placeholder is a single FieldReference object assigned to several
  # fields below (sac, sac.critic, sac.actor), so setting its value once
  # updates every field that holds it.
  obs_dim = ml_collections.FieldReference(None, field_type=int)
  action_dim = ml_collections.FieldReference(None, field_type=int)
  action_range = ml_collections.FieldReference(None, field_type=tuple)

  # ================================================= #
  # Main parameters.
  # ================================================= #
  config.save_dir = "/tmp/xirl/rl_runs/"

  # Set this to True to allow CUDA to find the best convolutional algorithm to
  # use for the given parameters. When False, cuDNN will deterministically
  # select the same algorithm at a possible cost in performance.
  config.cudnn_benchmark = True
  # Enforce CUDA convolution determinism. The algorithm itself might not be
  # deterministic so setting this to True ensures we make it repeatable.
  config.cudnn_deterministic = False

  # ================================================= #
  # Wrappers.
  # ================================================= #
  config.action_repeat = 1
  config.frame_stack = 3

  config.reward_wrapper = ml_collections.ConfigDict()
  config.reward_wrapper.pretrained_path = ""
  # Can be one of ['distance_to_goal', 'goal_classifier'].
  config.reward_wrapper.type = "distance_to_goal"

  # ================================================= #
  # Training parameters.
  # ================================================= #
  config.num_train_steps = 75_000
  config.replay_buffer_capacity = 1_000_000
  config.num_seed_steps = 5_000
  config.num_eval_episodes = 50
  config.eval_frequency = 5_000
  config.checkpoint_frequency = 50_000
  config.log_frequency = 10_000
  config.save_video = True

  # ================================================= #
  # SAC parameters.
  # ================================================= #
  config.sac = ml_collections.ConfigDict()

  config.sac.obs_dim = obs_dim
  config.sac.action_dim = action_dim
  config.sac.action_range = action_range
  config.sac.discount = 0.99
  config.sac.init_temperature = 0.1
  config.sac.alpha_lr = 1e-4
  config.sac.alpha_betas = [0.9, 0.999]
  config.sac.actor_lr = 1e-4
  config.sac.actor_betas = [0.9, 0.999]
  config.sac.actor_update_frequency = 1
  config.sac.critic_lr = 1e-4
  config.sac.critic_betas = [0.9, 0.999]
  config.sac.critic_tau = 0.005
  config.sac.critic_target_update_frequency = 2
  config.sac.batch_size = 1024
  config.sac.learnable_temperature = True

  # ================================================= #
  # Critic parameters.
  # ================================================= #
  config.sac.critic = ml_collections.ConfigDict()

  # Shares the obs_dim/action_dim placeholders with the sac config above.
  config.sac.critic.obs_dim = obs_dim
  config.sac.critic.action_dim = action_dim
  config.sac.critic.hidden_dim = 1024
  config.sac.critic.hidden_depth = 2

  # ================================================= #
  # Actor parameters.
  # ================================================= #
  config.sac.actor = ml_collections.ConfigDict()

  # Shares the obs_dim/action_dim placeholders with the sac config above.
  config.sac.actor.obs_dim = obs_dim
  config.sac.actor.action_dim = action_dim
  config.sac.actor.hidden_dim = 1024
  config.sac.actor.hidden_depth = 2
  config.sac.actor.log_std_bounds = [-5, 2]

  # ================================================= #

  return config
class FieldReferenceTest(parameterized.TestCase):
  """Tests for FieldReference operators, type handling and lazy evaluation.

  Most operator tests are parameterized over plain values, FieldReference
  operands and typed placeholders. The two helpers below run each case both
  on a standalone FieldReference and on a reference wired through a
  ConfigDict. In the placeholder cases the expected result is None until the
  placeholder is given a value.
  """

  def _test_binary_operator(self,
                            initial_value,
                            other_value,
                            op,
                            true_value,
                            new_initial_value,
                            new_true_value,
                            assert_fn=None):
    """Helper for testing binary operators.

    Generally speaking this checks that:
      1. `op(initial_value, other_value) COMP true_value`
      2. `op(new_initial_value, other_value) COMP new_true_value`
    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the binary operator.
      other_value: The second argument for the binary operator.
      op: The binary operator.
      true_value: The expected output of the binary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the binary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    # Standalone reference.
    ref = ml_collections.FieldReference(initial_value)
    new_ref = op(ref, other_value)
    assert_fn(new_ref.get(), true_value)

    # Through a ConfigDict: only `a` is taken via get_ref (`b` is read
    # eagerly), so re-assigning `a` must propagate into `result`.
    config = ml_collections.ConfigDict()
    config.a = initial_value
    config.b = other_value
    config.result = op(config.get_ref('a'), config.b)
    assert_fn(config.result, true_value)

    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def _test_unary_operator(self,
                           initial_value,
                           op,
                           true_value,
                           new_initial_value,
                           new_true_value,
                           assert_fn=None):
    """Helper for testing unary operators.

    Generally speaking this checks that:
      1. `op(initial_value) COMP true_value`
      2. `op(new_initial_value) COMP new_true_value`
    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the unary operator.
      op: The unary operator.
      true_value: The expected output of the unary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the unary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    # Standalone reference.
    ref = ml_collections.FieldReference(initial_value)
    new_ref = op(ref)
    assert_fn(new_ref.get(), true_value)

    # Through a ConfigDict: `result` must follow later assignments to `a`.
    config = ml_collections.ConfigDict()
    config.a = initial_value
    config.result = op(config.get_ref('a'))
    assert_fn(config.result, true_value)

    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def testBasic(self):
    ref = ml_collections.FieldReference(1)
    self.assertEqual(ref.get(), 1)

  def testGetRef(self):
    # Chained references: c depends on b, which depends on a.
    config = ml_collections.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10
    self.assertEqual(config.c, 21.0)

  def testFunction(self):
    # Plain Python functions applied to references stay lazy.

    def fn(x):
      return x + 5

    config = ml_collections.ConfigDict()
    config.a = 1
    config.b = fn(config.get_ref('a'))
    config.c = fn(config.get_ref('b'))

    self.assertEqual(config.b, 6)
    self.assertEqual(config.c, 11)
    config.a = 2
    self.assertEqual(config.b, 7)
    self.assertEqual(config.c, 12)

  def testCycles(self):
    config = ml_collections.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10

    self.assertEqual(config.b, 11.0)
    self.assertEqual(config.c, 21.0)

    # Introduce a cycle
    with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
      config.a = config.get_ref('c') - 1.0

    # Introduce a cycle on second operand
    with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
      config.a = ml_collections.FieldReference(5.0) + config.get_ref('c')

    # We can create multiple FieldReferences that all point to the same object
    l = [0]
    config = ml_collections.ConfigDict()
    config.a = l
    config.b = l
    config.c = config.get_ref('a') + ['c']
    config.d = config.get_ref('b') + ['d']

    self.assertEqual(config.c, [0, 'c'])
    self.assertEqual(config.d, [0, 'd'])

    # Make sure nothing was mutated
    self.assertEqual(l, [0])
    self.assertEqual(config.c, [0, 'c'])

    config.a = [1]
    config.b = [2]
    self.assertEqual(l, [0])
    self.assertEqual(config.c, [1, 'c'])
    self.assertEqual(config.d, [2, 'd'])

  @parameterized.parameters(
      {
          'initial_value': 1,
          'other_value': 2,
          'true_value': 3,
          'new_initial_value': 10,
          'new_true_value': 12
      },
      {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 4.5,
          'new_initial_value': 3.7,
          'new_true_value': 6.2
      },
      {
          'initial_value': 'hello, ',
          'other_value': 'world!',
          'true_value': 'hello, world!',
          'new_initial_value': 'foo, ',
          'new_true_value': 'foo, world!'
      },
      {
          'initial_value': ['hello'],
          'other_value': ['world'],
          'true_value': ['hello', 'world'],
          'new_initial_value': ['foo'],
          'new_true_value': ['foo', 'world']
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5.0),
          'true_value': 15.0,
          'new_initial_value': 12,
          'new_true_value': 17.0
      },
      {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 19.0
      },
      {
          # The placeholder is the second operand and is never filled in, so
          # the result stays None even after the first operand changes.
          'initial_value': 5.0,
          'other_value': config_dict.placeholder(float),
          'true_value': None,
          'new_initial_value': 8.0,
          'new_true_value': None
      },
      {
          'initial_value': config_dict.placeholder(str),
          'other_value': 'tail',
          'true_value': None,
          'new_initial_value': 'head',
          'new_true_value': 'headtail'
      })
  def testAdd(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.add,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 5,
          'other_value': 3,
          'true_value': 2,
          'new_initial_value': -1,
          'new_true_value': -4
      },
      {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': -0.5,
          'new_initial_value': 12.3,
          'new_true_value': 9.8
      },
      {
          'initial_value': set(['hello', 123, 4.5]),
          'other_value': set([123]),
          'true_value': set(['hello', 4.5]),
          'new_initial_value': set([123]),
          'new_true_value': set([])
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5.0),
          'true_value': 5.0,
          'new_initial_value': 12,
          'new_true_value': 7.0
      },
      {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 5.0
      })
  def testSub(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.sub,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 1,
          'other_value': 2,
          'true_value': 2,
          'new_initial_value': 3,
          'new_true_value': 6
      },
      {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 5.0,
          'new_initial_value': 3.5,
          'new_true_value': 8.75
      },
      {
          'initial_value': ['hello'],
          'other_value': 3,
          'true_value': ['hello', 'hello', 'hello'],
          'new_initial_value': ['foo'],
          'new_true_value': ['foo', 'foo', 'foo']
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5.0),
          'true_value': 50.0,
          'new_initial_value': 1,
          'new_true_value': 5.0
      },
      {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 84.0
      })
  def testMul(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.mul,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1.5,
          'new_initial_value': 10,
          'new_true_value': 5.0
      },
      {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 0.8,
          'new_initial_value': 6.3,
          'new_true_value': 2.52
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5.0),
          'true_value': 2.0,
          'new_initial_value': 13,
          'new_true_value': 2.6
      },
      {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 17.5,
          'new_true_value': 2.5
      })
  def testTrueDiv(self, initial_value, other_value, true_value,
                  new_initial_value, new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.truediv,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1,
          'new_initial_value': 7,
          'new_true_value': 3
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5),
          'true_value': 2,
          'new_initial_value': 28,
          'new_true_value': 5
      },
      {
          'initial_value': config_dict.placeholder(int),
          'other_value': 7,
          'true_value': None,
          'new_initial_value': 25,
          'new_true_value': 3
      })
  def testFloorDiv(self, initial_value, other_value, true_value,
                   new_initial_value, new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.floordiv,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 9,
          'new_initial_value': 10,
          'new_true_value': 100
      },
      {
          'initial_value': 2.7,
          'other_value': 3.2,
          'true_value': 24.0084457245,
          'new_initial_value': 6.5,
          'new_true_value': 399.321543621
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5),
          'true_value': 1e5,
          'new_initial_value': 2,
          'new_true_value': 32
      },
      {
          'initial_value': config_dict.placeholder(float),
          'other_value': 3.0,
          'true_value': None,
          'new_initial_value': 7.0,
          'new_true_value': 343.0
      })
  def testPow(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # Float powers are compared approximately.
    self._test_binary_operator(initial_value,
                               other_value,
                               operator.pow,
                               true_value,
                               new_initial_value,
                               new_true_value,
                               assert_fn=self.assertAlmostEqual)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1,
          'new_initial_value': 10,
          'new_true_value': 0
      },
      {
          'initial_value': 5.3,
          'other_value': 3.2,
          'true_value': 2.0999999999999996,
          'new_initial_value': 77,
          'new_true_value': 0.2
      },
      {
          'initial_value': ml_collections.FieldReference(10),
          'other_value': ml_collections.FieldReference(5),
          'true_value': 0,
          'new_initial_value': 32,
          'new_true_value': 2
      },
      {
          'initial_value': config_dict.placeholder(int),
          'other_value': 7,
          'true_value': None,
          'new_initial_value': 25,
          'new_true_value': 4
      })
  def testMod(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # Float remainders are compared approximately.
    self._test_binary_operator(initial_value,
                               other_value,
                               operator.mod,
                               true_value,
                               new_initial_value,
                               new_true_value,
                               assert_fn=self.assertAlmostEqual)

  @parameterized.parameters(
      {
          'initial_value': True,
          'other_value': True,
          'true_value': True,
          'new_initial_value': False,
          'new_true_value': False
      },
      {
          'initial_value': ml_collections.FieldReference(False),
          'other_value': ml_collections.FieldReference(False),
          'true_value': False,
          'new_initial_value': True,
          'new_true_value': False
      },
      {
          'initial_value': config_dict.placeholder(bool),
          'other_value': True,
          'true_value': None,
          'new_initial_value': False,
          'new_true_value': False
      })
  def testAnd(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # operator.and_ is the bitwise `&`; the `and` keyword is not overloadable
    # (see testControlFlowError).
    self._test_binary_operator(initial_value, other_value, operator.and_,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': False,
          'other_value': False,
          'true_value': False,
          'new_initial_value': True,
          'new_true_value': True
      },
      {
          'initial_value': ml_collections.FieldReference(True),
          'other_value': ml_collections.FieldReference(True),
          'true_value': True,
          'new_initial_value': False,
          'new_true_value': True
      },
      {
          'initial_value': config_dict.placeholder(bool),
          'other_value': False,
          'true_value': None,
          'new_initial_value': True,
          'new_true_value': True
      })
  def testOr(self, initial_value, other_value, true_value, new_initial_value,
             new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.or_,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': False,
          'other_value': True,
          'true_value': True,
          'new_initial_value': True,
          'new_true_value': False
      },
      {
          'initial_value': ml_collections.FieldReference(True),
          'other_value': ml_collections.FieldReference(True),
          'true_value': False,
          'new_initial_value': False,
          'new_true_value': True
      },
      {
          'initial_value': config_dict.placeholder(bool),
          'other_value': True,
          'true_value': None,
          'new_initial_value': True,
          'new_true_value': False
      })
  def testXor(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.xor,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'true_value': -3,
          'new_initial_value': -22,
          'new_true_value': 22
      },
      {
          'initial_value': 15.3,
          'true_value': -15.3,
          'new_initial_value': -0.2,
          'new_true_value': 0.2
      },
      {
          'initial_value': ml_collections.FieldReference(7),
          'true_value': ml_collections.FieldReference(-7),
          'new_initial_value': 123,
          'new_true_value': -123
      },
      {
          'initial_value': config_dict.placeholder(int),
          'true_value': None,
          'new_initial_value': -6,
          'new_true_value': 6
      })
  def testNeg(self, initial_value, true_value, new_initial_value,
              new_true_value):
    self._test_unary_operator(initial_value, operator.neg, true_value,
                              new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'true_value': 3,
          'new_initial_value': -101,
          'new_true_value': 101
      },
      {
          'initial_value': -15.3,
          'true_value': 15.3,
          'new_initial_value': 7.3,
          'new_true_value': 7.3
      },
      {
          'initial_value': ml_collections.FieldReference(-7),
          'true_value': ml_collections.FieldReference(7),
          'new_initial_value': 3,
          'new_true_value': 3
      },
      {
          'initial_value': config_dict.placeholder(float),
          'true_value': None,
          'new_initial_value': -6.25,
          'new_true_value': 6.25
      })
  def testAbs(self, initial_value, true_value, new_initial_value,
              new_true_value):
    self._test_unary_operator(initial_value, operator.abs, true_value,
                              new_initial_value, new_true_value)

  def testToInt(self):
    # Conversion truncates toward zero (25.3 -> 25, 27.9 -> 27).
    self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27)
    ref = ml_collections.FieldReference(64.7)
    ref = ref.to_int()
    self.assertEqual(ref.get(), 64)
    self.assertEqual(ref._field_type, int)

  def testToFloat(self):
    self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0)

    ref = ml_collections.FieldReference(647)
    ref = ref.to_float()
    self.assertEqual(ref.get(), 647.0)
    self.assertEqual(ref._field_type, float)

  def testToString(self):
    self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0')

    ref = ml_collections.FieldReference(647)
    ref = ref.to_str()
    self.assertEqual(ref.get(), '647')
    self.assertEqual(ref._field_type, str)

  def testSetValue(self):
    ref = ml_collections.FieldReference(1.0)
    other = ml_collections.FieldReference(3)
    ref_plus_other = ref + other

    self.assertEqual(ref_plus_other.get(), 4.0)

    ref.set(2.5)
    self.assertEqual(ref_plus_other.get(), 5.5)

    other.set(110)
    self.assertEqual(ref_plus_other.get(), 112.5)

    # Type checking
    with self.assertRaises(TypeError):
      other.set('this is a string')

    with self.assertRaises(TypeError):
      other.set(ml_collections.FieldReference('this is a string'))

    with self.assertRaises(TypeError):
      other.set(ml_collections.FieldReference(None, field_type=str))

  def testSetResult(self):
    ref = ml_collections.FieldReference(1.0)
    result = ref + 1.0
    second_result = result + 1.0

    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 2.0)
    self.assertEqual(second_result.get(), 3.0)

    ref.set(2.0)
    self.assertEqual(ref.get(), 2.0)
    self.assertEqual(result.get(), 3.0)
    self.assertEqual(second_result.get(), 4.0)

    # Setting a derived reference replaces its value; `second_result` still
    # reads from `result`.
    result.set(4.0)
    self.assertEqual(ref.get(), 2.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

    # All references are broken at this point.
    ref.set(1.0)
    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

  def testTypeChecking(self):
    # The sum itself is built lazily; the type error surfaces on get().
    ref = ml_collections.FieldReference(1)
    string_ref = ml_collections.FieldReference('a')

    x = ref + string_ref
    with self.assertRaises(TypeError):
      x.get()

  def testNoType(self):
    # A non-type value (0) passed as field_type must be rejected.
    self.assertRaisesRegex(TypeError, 'field_type should be a type.*',
                           ml_collections.FieldReference, None, 0)

  def testEqual(self):
    # Simple case
    ref1 = ml_collections.FieldReference(1)
    ref2 = ml_collections.FieldReference(1)
    ref3 = ml_collections.FieldReference(2)
    self.assertEqual(ref1, 1)
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, 2)
    self.assertNotEqual(ref1, ref3)

    # ConfigDict inside FieldReference
    ref1 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1}))
    ref2 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1}))
    ref3 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 2}))
    self.assertEqual(ref1, ml_collections.ConfigDict({'a': 1}))
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, ml_collections.ConfigDict({'a': 2}))
    self.assertNotEqual(ref1, ref3)

  def testLessEqual(self):
    # Simple case
    ref1 = ml_collections.FieldReference(1)
    ref2 = ml_collections.FieldReference(1)
    ref3 = ml_collections.FieldReference(2)
    self.assertLessEqual(ref1, 1)
    self.assertLessEqual(ref1, 2)
    self.assertLessEqual(0, ref1)
    self.assertLessEqual(1, ref1)
    self.assertGreater(ref1, 0)

    self.assertLessEqual(ref1, ref1)
    self.assertLessEqual(ref1, ref2)
    self.assertLessEqual(ref1, ref3)
    self.assertGreater(ref3, ref1)

  def testControlFlowError(self):
    # Every implicit bool coercion of a reference must raise.
    ref1 = ml_collections.FieldReference(True)
    ref2 = ml_collections.FieldReference(False)

    with self.assertRaises(NotImplementedError):
      if ref1:
        pass
    with self.assertRaises(NotImplementedError):
      _ = ref1 and ref2
    with self.assertRaises(NotImplementedError):
      _ = ref1 or ref2
    with self.assertRaises(NotImplementedError):
      _ = not ref1
def main(_):
  """Walks through ConfigDict initialization from dicts and ConfigDicts."""

  def value_ids(cfg):
    # Materialize the values first so every object stays alive while its id
    # is taken.
    values = list(cfg.values())
    return {id(value) for value in values}

  shared_inner = {'list': [1, 2], 'tuple': (1, 2, [3, 4], (5, 6))}
  example_dict = {
      'string': 'tom',
      'int': 2,
      'list': [1, 2],
      'set': {1, 2},
      'tuple': (1, 2),
      'ref': ml_collections.FieldReference({'int': 0}),
      'inner_dict_1': shared_inner,
      'inner_dict_2': shared_inner
  }

  print_section('Initializing on dictionary.')
  from_dict = ml_collections.ConfigDict(example_dict)
  # Nested dict fields become ConfigDicts too...
  print(type(from_dict.inner_dict_1))
  # ...and two keys sharing one dict still share one ConfigDict.
  print(id(from_dict.inner_dict_1) == id(from_dict.inner_dict_2))

  print_section('Initializing on ConfigDict.')
  from_configdict = ml_collections.ConfigDict(from_dict)
  # Equal contents...
  print(from_dict == from_configdict)
  # ...in a distinct object.
  print(id(from_dict) == id(from_configdict))
  # Attribute identities differ because the FieldReference is dropped on the
  # second initialization.
  print(value_ids(from_dict) == value_ids(from_configdict))

  print_section('Initializing on self-referencing dictionary.')
  looped = copy.deepcopy(example_dict)
  looped['self'] = looped
  looped_cd = ml_collections.ConfigDict(looped)
  # The self-reference is reproduced in the ConfigDict.
  print(id(looped_cd) == id(looped_cd.self))

  print_section('Unexpected initialization behavior.')
  # Initialization does not look inside lists, so a dict inside a list is not
  # converted to a ConfigDict.
  listed_dict_cd = ml_collections.ConfigDict({'list': [{'troublemaker': 0}]})
  print(type(listed_dict_cd.list[0]))

  # Consequently, a dict referenced both directly and from inside a list ends
  # up as two different objects.
  referred = {'key': 'value'}
  split_cd = ml_collections.ConfigDict({
      'referred_dict': referred,
      'list': [referred]
  })
  print(id(split_cd.referred_dict) == id(split_cd.list[0]))
def testToInt(self):
  """to_int() truncates lazily and records int as the resulting field type."""
  self._test_unary_operator(25.3, lambda r: r.to_int(), 25, 27.9, 27)

  as_int = ml_collections.FieldReference(64.7).to_int()
  self.assertEqual(as_int.get(), 64)
  self.assertEqual(as_int._field_type, int)
# Shared fixture covering scalars, nested containers (including lists inside
# tuples and vice versa) and a FieldReference. Treat as read-only; callers
# that need to mutate should go through the helpers below.
_TEST_DICT = {
    'int': 2,
    'list': [1, 2],
    'nested_list': [[1, [2]]],
    'set': {1, 2},
    'tuple': (1, 2),
    'frozenset': frozenset({1, 2}),
    'dict': {
        'float': -1.23,
        'list': [1, 2],
        'dict': {},
        'tuple_containing_list': (1, 2, (3, [4, 5], (6, 7))),
        'list_containing_tuple': [1, 2, [3, 4], (5, 6)],
    },
    'ref': ml_collections.FieldReference({'int': 0})
}


def _test_dict_deepcopy():
  """Returns an independent deep copy of `_TEST_DICT`."""
  return copy.deepcopy(_TEST_DICT)


def _test_configdict():
  """Returns a ConfigDict built from a fresh deep copy of `_TEST_DICT`.

  ConfigDict initialization does not copy the contents of lists, and fields
  assigned a FieldReference stay tied to that shared reference. Building from
  a deep copy keeps mutations of the returned config from leaking back into
  the module-level fixture (and from there into other tests).
  """
  return ml_collections.ConfigDict(_test_dict_deepcopy())


def _test_frozenconfigdict():
  """Returns a FrozenConfigDict built from `_TEST_DICT`."""
  return ml_collections.FrozenConfigDict(_TEST_DICT)
def testBasic(self):
  """A freshly created FieldReference returns its initial value."""
  self.assertEqual(ml_collections.FieldReference(1).get(), 1)