示例#1
0
def main(_):
    """Shows how FieldReferences share values and enforce types in a ConfigDict.

    The single positional argument is unused (NOTE(review): presumably argv
    from an app-runner style launcher; confirm against the caller).
    """
    shared_ref = ml_collections.FieldReference(0)
    cfg = ml_collections.ConfigDict()
    cfg.placeholder = shared_ref
    cfg.optional = ml_collections.FieldReference(0, field_type=int)
    cfg.nested = ml_collections.ConfigDict()
    cfg.nested.placeholder = shared_ref

    # `optional` is int-typed, so assigning a string raises TypeError.
    try:
        cfg.optional = 'tom'
    except TypeError as err:
        print(err)

    # An int value is accepted.
    cfg.optional = 1555
    # `placeholder` and `nested.placeholder` wrap the same reference, so this
    # one assignment changes both fields.
    cfg.placeholder = 1

    # Reading a field back through the ConfigDict yields the plain value, so
    # the FieldReference indirection is lost when copying fields that way.
    tied_ref = ml_collections.FieldReference(0)
    cfg.field1 = tied_ref
    cfg.field2 = tied_ref  # Shares the reference, so it stays tied to field1.
    cfg.field3 = cfg.field1  # Just an int field initialised to 0; not tied.

    print(cfg)
示例#2
0
    def testTypeChecking(self):
        """Mixing int and str references only fails when the result is read."""
        int_ref = ml_collections.FieldReference(1)
        str_ref = ml_collections.FieldReference('a')

        # Building the lazy sum succeeds; the TypeError surfaces on get().
        lazy_sum = int_ref + str_ref
        with self.assertRaises(TypeError):
            lazy_sum.get()
    def testInitFieldReferenceInList(self):
        """Freezing must reject FieldReferences nested in lists or tuples."""
        self.assertFrozenRaisesValueError([
            {'list': [1, 2, 3, ml_collections.FieldReference(4)]},
            {'tuple': (1, 2, 3, ml_collections.FieldReference('a'))},
        ])
示例#4
0
    def testControlFlowError(self):
        """Using a FieldReference as a boolean raises NotImplementedError."""
        true_ref = ml_collections.FieldReference(True)
        false_ref = ml_collections.FieldReference(False)

        with self.assertRaises(NotImplementedError):
            if true_ref:
                pass

        # `and`, `or` and `not` all need the truth value of a reference.
        boolean_exprs = (
            lambda: true_ref and false_ref,
            lambda: true_ref or false_ref,
            lambda: not true_ref,
        )
        for expr in boolean_exprs:
            with self.assertRaises(NotImplementedError):
                _ = expr()
示例#5
0
    def testLessEqual(self):
        """`<=` and `>` work between references and between ref and value."""
        one = ml_collections.FieldReference(1)
        other_one = ml_collections.FieldReference(1)
        two = ml_collections.FieldReference(2)

        # Reference versus plain ints, on either side of the operator.
        for upper in (1, 2):
            self.assertLessEqual(one, upper)
        for lower in (0, 1):
            self.assertLessEqual(lower, one)
        self.assertGreater(one, 0)

        # Reference versus reference.
        for rhs in (one, other_one, two):
            self.assertLessEqual(one, rhs)
        self.assertGreater(two, one)
示例#6
0
    def testToString(self):
        """`to_str` yields a str-typed reference tracking the source value."""
        self._test_unary_operator(
            12, lambda r: r.to_str(), '12', 0, '0')

        as_str = ml_collections.FieldReference(647).to_str()
        self.assertEqual(as_str.get(), '647')
        self.assertEqual(as_str._field_type, str)
示例#7
0
    def testToFloat(self):
        """`to_float` yields a float-typed reference tracking the source value."""
        self._test_unary_operator(
            12, lambda r: r.to_float(), 12.0, 0, 0.0)

        as_float = ml_collections.FieldReference(647).to_float()
        self.assertEqual(as_float.get(), 647.0)
        self.assertEqual(as_float._field_type, float)
def lazy_configdict():
  """Shows which ConfigDict fields become lazy when built with `get_ref`."""
  config = ml_collections.ConfigDict()
  config.reference_field = ml_collections.FieldReference(1)
  config.integer_field = 2
  config.float_field = 2.5

  # Plain attribute reads return values, so this product is fixed right now
  # (no lazy evaluation because get_ref() is not used).
  config.no_lazy = config.integer_field * config.float_field

  # Only integer_field is tracked lazily here.
  lazy_int = config.get_ref('integer_field')
  config.lazy_integer = lazy_int * config.float_field

  # Only float_field is tracked lazily here.
  eager_int = config.integer_field
  config.lazy_float = eager_int * config.get_ref('float_field')

  # Both integer_field and float_field are tracked lazily here.
  config.lazy_both = config.get_ref('integer_field') * config.get_ref(
      'float_field')

  config.integer_field = 3
  print(config.no_lazy)  # Prints 5.0 - still uses the original integer_field

  print(config.lazy_integer)  # Prints 7.5

  config.float_field = 3.5
  print(config.lazy_float)  # Prints 7.0
  print(config.lazy_both)  # Prints 10.5
def make_reference(
    config,
    field,
):
  """Returns a one-way reference to `config[field]` for wiring up configs.

  "One-way" means a change to the original field propagates to the returned
  reference, but setting the returned reference does not propagate back.

  The function recurses into nested ConfigDicts. If `config[field]` is a
  ConfigDict (or a tree of them), a structurally identical tree of fresh
  ConfigDicts is built whose leaves refer to the original leaves. The internal
  nodes are *not* references.

  See config_schema_test.MakeReferenceRecursiveTest for example usage.

  Args:
    config: A ConfigDict instance.
    field: The name of a field contained in `config`.

  Returns:
    A one-way FieldReference to `config[field]` when it is not a ConfigDict;
    otherwise a new ConfigDict mirroring `config[field]` whose leaves are
    such references.
  """
  value = config[field]
  if not isinstance(value, ml_collections.ConfigDict):
    # Leaf: wrap the live reference so writes to the wrapper stay local.
    return ml_collections.FieldReference(config.get_ref(field))

  mirrored = ml_collections.ConfigDict()
  for key in value:
    mirrored[key] = make_reference(value, key)
  return mirrored
 def testFieldReferenceResolved(self):
     """Freezing a ConfigDict replaces its FieldReferences with values."""
     frozen_cfg = ml_collections.FrozenConfigDict(
         ml_collections.ConfigDict({'fr': ml_collections.FieldReference(1)}))
     self.assertNotIsInstance(frozen_cfg._fields['fr'],
                              ml_collections.FieldReference)
     # With the reference resolved, the frozen config is hashable.
     hash(frozen_cfg)
    def testFieldReferenceCycle(self):
        """Freezing rejects FieldReferences that point back into a cycle."""
        frozenset_fr = {'frozenset': frozenset({1, 2})}
        frozenset_fr['fr'] = ml_collections.FieldReference(
            frozenset_fr['frozenset'])
        list_fr = {'list': [1, 2]}
        list_fr['fr'] = ml_collections.FieldReference(list_fr['list'])

        # A reference to the very dict that contains it.
        self_cycle = {'a': 1}
        self_cycle['fr'] = ml_collections.FieldReference(self_cycle)
        # A reference, one level down, back to the enclosing dict.
        parent_cycle = {'dict': {}}
        parent_cycle['dict']['fr'] = ml_collections.FieldReference(
            parent_cycle)

        # References to non-cyclic objects are allowed...
        for acyclic in (frozenset_fr, list_fr):
            _ = ml_collections.FrozenConfigDict(acyclic)
        # ...but cycles are not.
        self.assertFrozenRaisesValueError([self_cycle, parent_cycle])
def get_config():
    """Builds the small ConfigDict fixture used by the tests."""
    config = ml_collections.ConfigDict()
    # Scalar, reference and list-valued fields.
    config.integer = 1
    config.reference = ml_collections.FieldReference(1)
    config.list = [1, 2, 3]
    config.nested_list = [[1, 2, 3]]
    # A nested ConfigDict holding a single int.
    config.nested_configdict = ml_collections.ConfigDict()
    config.nested_configdict.integer = 1
    # NOTE(review): UnusableConfig is defined elsewhere in the file.
    config.unusable_config = UnusableConfig()
    return config
示例#13
0
    def testSetValue(self):
        """`set` propagates to dependent expressions and type-checks values."""
        float_ref = ml_collections.FieldReference(1.0)
        int_ref = ml_collections.FieldReference(3)
        total = float_ref + int_ref
        self.assertEqual(total.get(), 4.0)

        float_ref.set(2.5)
        self.assertEqual(total.get(), 5.5)

        int_ref.set(110)
        self.assertEqual(total.get(), 112.5)

        # Strings, and str-typed references, are rejected by an int reference.
        wrong_type_values = (
            'this is a string',
            ml_collections.FieldReference('this is a string'),
            ml_collections.FieldReference(None, field_type=str),
        )
        for bad in wrong_type_values:
            with self.assertRaises(TypeError):
                int_ref.set(bad)
def lazy_computation():
  """Contrasts an eager `get()` result with a lazy `FieldReference` sum."""
  ref = ml_collections.FieldReference(1)
  print(ref.get())  # Prints 1

  # `ref.get()` is a plain int, so this sum is computed once, right now.
  eager_sum = ref.get() + 10
  # `ref + 10` is itself a FieldReference, recomputed on every get().
  lazy_sum = ref + 10

  print(eager_sum)  # Prints 11
  print(lazy_sum.get())  # Prints 11 because ref's value is 1

  # Changing ref changes only the lazily computed sum.
  ref.set(5)
  print(eager_sum)  # Prints 11
  print(lazy_sum.get())  # Prints 15 because ref's value is 5
    def testInitDictInList(self):
        """Freezing rejects dicts/ConfigDicts nested in lists or tuples."""
        self.assertFrozenRaisesValueError([
            {'list': [1, 2, 3, {'a': 4, 'b': 5}]},
            {'tuple': (1, 2, 3, {'a': 4, 'b': 5})},
            {'list': [1, 2, 3, _test_configdict()]},
            {'tuple': (1, 2, 3, _test_configdict())},
            # Also when the list sits behind a FieldReference.
            {'fr': ml_collections.FieldReference([1, {'a': 2}])},
        ])
示例#16
0
    def _test_binary_operator(self,
                              initial_value,
                              other_value,
                              op,
                              true_value,
                              new_initial_value,
                              new_true_value,
                              assert_fn=None):
        """Checks a binary operator on FieldReferences, directly and lazily.

        Verifies that:
          1. `op(FieldReference(initial_value), other_value).get() COMP
             true_value`, and the same result through a ConfigDict field built
             with `get_ref`;
          2. after that ConfigDict's `a` field is set to `new_initial_value`,
             the lazily computed field satisfies `COMP new_true_value`;
        where `COMP` is the comparison performed by `assert_fn`.

        Args:
          initial_value: Initial value for the `FieldReference`; the first
            operand of `op`.
          other_value: The second operand of `op`.
          op: The binary operator under test.
          true_value: Expected output of `op` for the initial operands.
          new_initial_value: The value the referenced field is changed to.
          new_true_value: Expected output of `op` after that change.
          assert_fn: Comparison used on outputs; defaults to assertEqual.
        """
        check = self.assertEqual if assert_fn is None else assert_fn

        direct = op(ml_collections.FieldReference(initial_value), other_value)
        check(direct.get(), true_value)

        config = ml_collections.ConfigDict()
        config.a = initial_value
        config.b = other_value
        config.result = op(config.get_ref('a'), config.b)
        check(config.result, true_value)

        # Re-assigning `a` must flow lazily into `result`.
        config.a = new_initial_value
        check(config.result, new_true_value)
示例#17
0
    def testCycles(self):
        """Reference cycles are rejected; shared list values stay unmutated."""
        config = ml_collections.ConfigDict()
        config.a = 1.
        config.b = config.get_ref('a') + 10
        config.c = config.get_ref('b') + 10
        self.assertEqual(config.b, 11.0)
        self.assertEqual(config.c, 21.0)

        # Making `a` depend on `c` closes a cycle, whether `c` is the first
        # or the second operand.
        with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
            config.a = config.get_ref('c') - 1.0
        with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
            config.a = ml_collections.FieldReference(5.0) + config.get_ref('c')

        # Several FieldReferences may point at the same object.
        shared = [0]
        config = ml_collections.ConfigDict()
        config.a = shared
        config.b = shared
        config.c = config.get_ref('a') + ['c']
        config.d = config.get_ref('b') + ['d']
        self.assertEqual(config.c, [0, 'c'])
        self.assertEqual(config.d, [0, 'd'])

        # Evaluating the sums must not have mutated the shared list.
        self.assertEqual(shared, [0])
        self.assertEqual(config.c, [0, 'c'])

        # Rebinding the source fields updates the sums but not `shared`.
        config.a = [1]
        config.b = [2]
        self.assertEqual(shared, [0])
        self.assertEqual(config.c, [1, 'c'])
        self.assertEqual(config.d, [2, 'd'])
示例#18
0
    def testSetResult(self):
        """Calling `set` on a derived reference detaches it from its input."""
        base = ml_collections.FieldReference(1.0)
        plus_one = base + 1.0
        plus_two = plus_one + 1.0

        def check(expected_base, expected_one, expected_two):
            self.assertEqual(base.get(), expected_base)
            self.assertEqual(plus_one.get(), expected_one)
            self.assertEqual(plus_two.get(), expected_two)

        check(1.0, 2.0, 3.0)

        base.set(2.0)
        check(2.0, 3.0, 4.0)

        # Overwrite the derived value; its own dependent still follows it.
        plus_one.set(4.0)
        check(2.0, 4.0, 5.0)

        # plus_one no longer tracks base, so changing base leaves both
        # plus_one and plus_two unchanged.
        base.set(1.0)
        check(1.0, 4.0, 5.0)
示例#19
0
    def testEqual(self):
        """References compare equal by value, to raw values and to each other."""

        def check(ref1, ref2, ref3, equal_raw, unequal_raw):
            # ref1 and ref2 hold equal values; ref3 holds a different value.
            self.assertEqual(ref1, equal_raw)
            self.assertEqual(ref1, ref1)
            self.assertEqual(ref1, ref2)
            self.assertNotEqual(ref1, unequal_raw)
            self.assertNotEqual(ref1, ref3)

        # Simple case: scalar values.
        check(ml_collections.FieldReference(1),
              ml_collections.FieldReference(1),
              ml_collections.FieldReference(2),
              1, 2)

        # ConfigDict inside FieldReference.
        check(ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1})),
              ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1})),
              ml_collections.FieldReference(ml_collections.ConfigDict({'a': 2})),
              ml_collections.ConfigDict({'a': 1}),
              ml_collections.ConfigDict({'a': 2}))
def get_config():
    """Returns a ConfigDict with a defaulted and a default-less int reference."""
    config = ml_collections.ConfigDict()
    config.ref = ml_collections.FieldReference(123)
    # An int-typed placeholder: declared, but with no default value.
    config.ref_nodefault = config_dict.placeholder(int)
    return config
示例#21
0
def get_config():
    """Returns the default config for an RL training run.

    The config holds top-level run settings, environment-wrapper settings,
    training-loop settings, and a `sac` sub-config with nested `critic` and
    `actor` sub-configs.
    """
    config = ml_collections.ConfigDict()

    # ================================================= #
    # Placeholders.
    # ================================================= #
    # These values will be filled at runtime once the gym.Env is loaded.
    # obs_dim and action_dim are assigned to sac, sac.critic and sac.actor
    # below (action_range only to sac). Because the same reference objects are
    # shared, setting one of them once updates every field that holds it.
    obs_dim = ml_collections.FieldReference(None, field_type=int)
    action_dim = ml_collections.FieldReference(None, field_type=int)
    action_range = ml_collections.FieldReference(None, field_type=tuple)

    # ================================================= #
    # Main parameters.
    # ================================================= #
    config.save_dir = "/tmp/xirl/rl_runs/"

    # Set this to True to allow CUDA to find the best convolutional algorithm to
    # use for the given parameters. When False, cuDNN will deterministically
    # select the same algorithm at a possible cost in performance.
    config.cudnn_benchmark = True
    # Enforce CUDA convolution determinism. The algorithm itself might not be
    # deterministic, so setting this to True makes runs repeatable.
    config.cudnn_deterministic = False

    # ================================================= #
    # Wrappers.
    # ================================================= #
    config.action_repeat = 1
    config.frame_stack = 3

    config.reward_wrapper = ml_collections.ConfigDict()
    config.reward_wrapper.pretrained_path = ""
    # Can be one of ['distance_to_goal', 'goal_classifier'].
    config.reward_wrapper.type = "distance_to_goal"

    # ================================================= #
    # Training parameters.
    # ================================================= #
    # NOTE(review): the step/frequency units are not defined here; presumably
    # environment steps. Confirm against the training loop.
    config.num_train_steps = 75_000
    config.replay_buffer_capacity = 1_000_000
    config.num_seed_steps = 5_000
    config.num_eval_episodes = 50
    config.eval_frequency = 5_000
    config.checkpoint_frequency = 50_000
    config.log_frequency = 10_000
    config.save_video = True

    # ================================================= #
    # SAC parameters.
    # ================================================= #
    config.sac = ml_collections.ConfigDict()

    # Shared placeholder references (filled at runtime).
    config.sac.obs_dim = obs_dim
    config.sac.action_dim = action_dim
    config.sac.action_range = action_range
    config.sac.discount = 0.99
    config.sac.init_temperature = 0.1
    config.sac.alpha_lr = 1e-4
    config.sac.alpha_betas = [0.9, 0.999]
    config.sac.actor_lr = 1e-4
    config.sac.actor_betas = [0.9, 0.999]
    config.sac.actor_update_frequency = 1
    config.sac.critic_lr = 1e-4
    config.sac.critic_betas = [0.9, 0.999]
    config.sac.critic_tau = 0.005
    config.sac.critic_target_update_frequency = 2
    config.sac.batch_size = 1024
    config.sac.learnable_temperature = True

    # ================================================= #
    # Critic parameters.
    # ================================================= #
    config.sac.critic = ml_collections.ConfigDict()

    # Same reference objects as config.sac.obs_dim / config.sac.action_dim.
    config.sac.critic.obs_dim = obs_dim
    config.sac.critic.action_dim = action_dim
    config.sac.critic.hidden_dim = 1024
    config.sac.critic.hidden_depth = 2

    # ================================================= #
    # Actor parameters.
    # ================================================= #
    config.sac.actor = ml_collections.ConfigDict()

    # Same reference objects as config.sac.obs_dim / config.sac.action_dim.
    config.sac.actor.obs_dim = obs_dim
    config.sac.actor.action_dim = action_dim
    config.sac.actor.hidden_dim = 1024
    config.sac.actor.hidden_depth = 2
    config.sac.actor.log_std_bounds = [-5, 2]

    # ================================================= #

    return config
示例#22
0
class FieldReferenceTest(parameterized.TestCase):
    def _test_binary_operator(self,
                              initial_value,
                              other_value,
                              op,
                              true_value,
                              new_initial_value,
                              new_true_value,
                              assert_fn=None):
        """Checks a binary operator on FieldReferences, directly and lazily.

        Verifies that:
          1. `op(FieldReference(initial_value), other_value).get() COMP
             true_value`, and the same result through a ConfigDict field built
             with `get_ref`;
          2. after that ConfigDict's `a` field is set to `new_initial_value`,
             the lazily computed field satisfies `COMP new_true_value`;
        where `COMP` is the comparison performed by `assert_fn`.

        Args:
          initial_value: Initial value for the `FieldReference`; the first
            operand of `op`.
          other_value: The second operand of `op`.
          op: The binary operator under test.
          true_value: Expected output of `op` for the initial operands.
          new_initial_value: The value the referenced field is changed to.
          new_true_value: Expected output of `op` after that change.
          assert_fn: Comparison used on outputs; defaults to assertEqual.
        """
        check = self.assertEqual if assert_fn is None else assert_fn

        direct = op(ml_collections.FieldReference(initial_value), other_value)
        check(direct.get(), true_value)

        config = ml_collections.ConfigDict()
        config.a = initial_value
        config.b = other_value
        config.result = op(config.get_ref('a'), config.b)
        check(config.result, true_value)

        # Re-assigning `a` must flow lazily into `result`.
        config.a = new_initial_value
        check(config.result, new_true_value)

    def _test_unary_operator(self,
                             initial_value,
                             op,
                             true_value,
                             new_initial_value,
                             new_true_value,
                             assert_fn=None):
        """Checks a unary operator on FieldReferences, directly and lazily.

        Verifies that:
          1. `op(FieldReference(initial_value)).get() COMP true_value`, and the
             same result through a ConfigDict field built with `get_ref`;
          2. after that ConfigDict's `a` field is set to `new_initial_value`,
             the lazily computed field satisfies `COMP new_true_value`;
        where `COMP` is the comparison performed by `assert_fn`.

        Args:
          initial_value: Initial value for the `FieldReference`; the operand
            of `op`.
          op: The unary operator under test.
          true_value: Expected output of `op` for the initial operand.
          new_initial_value: The value the referenced field is changed to.
          new_true_value: Expected output of `op` after that change.
          assert_fn: Comparison used on outputs; defaults to assertEqual.
        """
        check = self.assertEqual if assert_fn is None else assert_fn

        direct = op(ml_collections.FieldReference(initial_value))
        check(direct.get(), true_value)

        config = ml_collections.ConfigDict()
        config.a = initial_value
        config.result = op(config.get_ref('a'))
        check(config.result, true_value)

        # Re-assigning `a` must flow lazily into `result`.
        config.a = new_initial_value
        check(config.result, new_true_value)

    def testBasic(self):
        """A FieldReference returns the value it was constructed with."""
        self.assertEqual(ml_collections.FieldReference(1).get(), 1)

    def testGetRef(self):
        """Fields can be chained as lazy expressions of other fields."""
        config = ml_collections.ConfigDict()
        config.a = 1.
        # b = a + 10, then c = b + 10.
        for source, target in (('a', 'b'), ('b', 'c')):
            config[target] = config.get_ref(source) + 10
        self.assertEqual(config.c, 21.0)

    def testFunction(self):
        """Ordinary Python functions compose lazily over references."""

        def add_five(value):
            return value + 5

        config = ml_collections.ConfigDict()
        config.a = 1
        config.b = add_five(config.get_ref('a'))
        config.c = add_five(config.get_ref('b'))

        self.assertEqual(config.b, 6)
        self.assertEqual(config.c, 11)

        # Changing the root field updates the whole chain.
        config.a = 2
        self.assertEqual(config.b, 7)
        self.assertEqual(config.c, 12)

    def testCycles(self):
        """Reference cycles are rejected; shared list values stay unmutated."""
        config = ml_collections.ConfigDict()
        config.a = 1.
        config.b = config.get_ref('a') + 10
        config.c = config.get_ref('b') + 10
        self.assertEqual(config.b, 11.0)
        self.assertEqual(config.c, 21.0)

        # Making `a` depend on `c` closes a cycle, whether `c` is the first
        # or the second operand.
        with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
            config.a = config.get_ref('c') - 1.0
        with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
            config.a = ml_collections.FieldReference(5.0) + config.get_ref('c')

        # Several FieldReferences may point at the same object.
        shared = [0]
        config = ml_collections.ConfigDict()
        config.a = shared
        config.b = shared
        config.c = config.get_ref('a') + ['c']
        config.d = config.get_ref('b') + ['d']
        self.assertEqual(config.c, [0, 'c'])
        self.assertEqual(config.d, [0, 'd'])

        # Evaluating the sums must not have mutated the shared list.
        self.assertEqual(shared, [0])
        self.assertEqual(config.c, [0, 'c'])

        # Rebinding the source fields updates the sums but not `shared`.
        config.a = [1]
        config.b = [2]
        self.assertEqual(shared, [0])
        self.assertEqual(config.c, [1, 'c'])
        self.assertEqual(config.d, [2, 'd'])

    @parameterized.parameters(
        dict(initial_value=1, other_value=2, true_value=3,
             new_initial_value=10, new_true_value=12),
        dict(initial_value=2.0, other_value=2.5, true_value=4.5,
             new_initial_value=3.7, new_true_value=6.2),
        dict(initial_value='hello, ', other_value='world!',
             true_value='hello, world!',
             new_initial_value='foo, ', new_true_value='foo, world!'),
        dict(initial_value=['hello'], other_value=['world'],
             true_value=['hello', 'world'],
             new_initial_value=['foo'], new_true_value=['foo', 'world']),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5.0),
             true_value=15.0, new_initial_value=12, new_true_value=17.0),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(float), other_value=7.0,
             true_value=None, new_initial_value=12, new_true_value=19.0),
        dict(initial_value=5.0, other_value=config_dict.placeholder(float),
             true_value=None, new_initial_value=8.0, new_true_value=None),
        dict(initial_value=config_dict.placeholder(str), other_value='tail',
             true_value=None, new_initial_value='head',
             new_true_value='headtail'),
    )
    def testAdd(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Lazy `+` across numbers, sequences, references and placeholders."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.add,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value)

    @parameterized.parameters(
        dict(initial_value=5, other_value=3, true_value=2,
             new_initial_value=-1, new_true_value=-4),
        dict(initial_value=2.0, other_value=2.5, true_value=-0.5,
             new_initial_value=12.3, new_true_value=9.8),
        dict(initial_value={'hello', 123, 4.5}, other_value={123},
             true_value={'hello', 4.5},
             new_initial_value={123}, new_true_value=set()),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5.0),
             true_value=5.0, new_initial_value=12, new_true_value=7.0),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(float), other_value=7.0,
             true_value=None, new_initial_value=12, new_true_value=5.0),
    )
    def testSub(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Lazy `-` across numbers, sets, references and placeholders."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.sub,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value)

    @parameterized.parameters(
        dict(initial_value=1, other_value=2, true_value=2,
             new_initial_value=3, new_true_value=6),
        dict(initial_value=2.0, other_value=2.5, true_value=5.0,
             new_initial_value=3.5, new_true_value=8.75),
        dict(initial_value=['hello'], other_value=3,
             true_value=['hello', 'hello', 'hello'],
             new_initial_value=['foo'], new_true_value=['foo', 'foo', 'foo']),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5.0),
             true_value=50.0, new_initial_value=1, new_true_value=5.0),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(float), other_value=7.0,
             true_value=None, new_initial_value=12, new_true_value=84.0),
    )
    def testMul(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Lazy `*` across numbers, list repetition, refs and placeholders."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.mul,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value)

    @parameterized.parameters(
        dict(initial_value=3, other_value=2, true_value=1.5,
             new_initial_value=10, new_true_value=5.0),
        dict(initial_value=2.0, other_value=2.5, true_value=0.8,
             new_initial_value=6.3, new_true_value=2.52),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5.0),
             true_value=2.0, new_initial_value=13, new_true_value=2.6),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(float), other_value=7.0,
             true_value=None, new_initial_value=17.5, new_true_value=2.5),
    )
    def testTrueDiv(self, initial_value, other_value, true_value,
                    new_initial_value, new_true_value):
        """Lazy `/` across numbers, references and placeholders."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.truediv,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value)

    @parameterized.parameters(
        dict(initial_value=3, other_value=2, true_value=1,
             new_initial_value=7, new_true_value=3),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5),
             true_value=2, new_initial_value=28, new_true_value=5),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(int), other_value=7,
             true_value=None, new_initial_value=25, new_true_value=3),
    )
    def testFloorDiv(self, initial_value, other_value, true_value,
                     new_initial_value, new_true_value):
        """Lazy `//` across ints, references and placeholders."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.floordiv,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value)

    @parameterized.parameters(
        dict(initial_value=3, other_value=2, true_value=9,
             new_initial_value=10, new_true_value=100),
        dict(initial_value=2.7, other_value=3.2, true_value=24.0084457245,
             new_initial_value=6.5, new_true_value=399.321543621),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5),
             true_value=1e5, new_initial_value=2, new_true_value=32),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(float), other_value=3.0,
             true_value=None, new_initial_value=7.0, new_true_value=343.0),
    )
    def testPow(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Lazy `**`; float cases are compared approximately."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.pow,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value,
                                   assert_fn=self.assertAlmostEqual)

    @parameterized.parameters(
        dict(initial_value=3, other_value=2, true_value=1,
             new_initial_value=10, new_true_value=0),
        dict(initial_value=5.3, other_value=3.2,
             true_value=2.0999999999999996,
             new_initial_value=77, new_true_value=0.2),
        dict(initial_value=ml_collections.FieldReference(10),
             other_value=ml_collections.FieldReference(5),
             true_value=0, new_initial_value=32, new_true_value=2),
        # An unset placeholder operand makes the result None.
        dict(initial_value=config_dict.placeholder(int), other_value=7,
             true_value=None, new_initial_value=25, new_true_value=4),
    )
    def testMod(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Lazy `%`; float cases are compared approximately."""
        self._test_binary_operator(initial_value=initial_value,
                                   other_value=other_value,
                                   op=operator.mod,
                                   true_value=true_value,
                                   new_initial_value=new_initial_value,
                                   new_true_value=new_true_value,
                                   assert_fn=self.assertAlmostEqual)

    @parameterized.parameters(
        dict(initial_value=True,
             other_value=True,
             true_value=True,
             new_initial_value=False,
             new_true_value=False),
        dict(initial_value=ml_collections.FieldReference(False),
             other_value=ml_collections.FieldReference(False),
             true_value=False,
             new_initial_value=True,
             new_true_value=False),
        dict(initial_value=config_dict.placeholder(bool),
             other_value=True,
             true_value=None,
             new_initial_value=False,
             new_true_value=False),
    )
    def testAnd(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Checks `&` on bool, reference and bool-placeholder operands."""
        self._test_binary_operator(
            initial_value, other_value, operator.and_, true_value,
            new_initial_value, new_true_value)

    @parameterized.parameters(
        dict(initial_value=False,
             other_value=False,
             true_value=False,
             new_initial_value=True,
             new_true_value=True),
        dict(initial_value=ml_collections.FieldReference(True),
             other_value=ml_collections.FieldReference(True),
             true_value=True,
             new_initial_value=False,
             new_true_value=True),
        dict(initial_value=config_dict.placeholder(bool),
             other_value=False,
             true_value=None,
             new_initial_value=True,
             new_true_value=True),
    )
    def testOr(self, initial_value, other_value, true_value, new_initial_value,
               new_true_value):
        """Checks `|` on bool, reference and bool-placeholder operands."""
        self._test_binary_operator(
            initial_value, other_value, operator.or_, true_value,
            new_initial_value, new_true_value)

    @parameterized.parameters(
        dict(initial_value=False,
             other_value=True,
             true_value=True,
             new_initial_value=True,
             new_true_value=False),
        dict(initial_value=ml_collections.FieldReference(True),
             other_value=ml_collections.FieldReference(True),
             true_value=False,
             new_initial_value=False,
             new_true_value=True),
        dict(initial_value=config_dict.placeholder(bool),
             other_value=True,
             true_value=None,
             new_initial_value=True,
             new_true_value=False),
    )
    def testXor(self, initial_value, other_value, true_value,
                new_initial_value, new_true_value):
        """Checks `^` on bool, reference and bool-placeholder operands."""
        self._test_binary_operator(
            initial_value, other_value, operator.xor, true_value,
            new_initial_value, new_true_value)

    @parameterized.parameters(
        dict(initial_value=3,
             true_value=-3,
             new_initial_value=-22,
             new_true_value=22),
        dict(initial_value=15.3,
             true_value=-15.3,
             new_initial_value=-0.2,
             new_true_value=0.2),
        dict(initial_value=ml_collections.FieldReference(7),
             true_value=ml_collections.FieldReference(-7),
             new_initial_value=123,
             new_true_value=-123),
        dict(initial_value=config_dict.placeholder(int),
             true_value=None,
             new_initial_value=-6,
             new_true_value=6),
    )
    def testNeg(self, initial_value, true_value, new_initial_value,
                new_true_value):
        """Checks unary minus on int, float, reference and placeholder inputs."""
        self._test_unary_operator(
            initial_value, operator.neg, true_value,
            new_initial_value, new_true_value)

    @parameterized.parameters(
        dict(initial_value=3,
             true_value=3,
             new_initial_value=-101,
             new_true_value=101),
        dict(initial_value=-15.3,
             true_value=15.3,
             new_initial_value=7.3,
             new_true_value=7.3),
        dict(initial_value=ml_collections.FieldReference(-7),
             true_value=ml_collections.FieldReference(7),
             new_initial_value=3,
             new_true_value=3),
        dict(initial_value=config_dict.placeholder(float),
             true_value=None,
             new_initial_value=-6.25,
             new_true_value=6.25),
    )
    def testAbs(self, initial_value, true_value, new_initial_value,
                new_true_value):
        """Checks abs() on int, float, reference and placeholder inputs."""
        self._test_unary_operator(
            initial_value, operator.abs, true_value,
            new_initial_value, new_true_value)

    def testToInt(self):
        """to_int() truncates float values and yields an int-typed reference."""
        self._test_unary_operator(25.3, lambda r: r.to_int(), 25, 27.9, 27)
        truncated = ml_collections.FieldReference(64.7).to_int()
        self.assertEqual(truncated.get(), 64)
        self.assertEqual(truncated._field_type, int)

    def testToFloat(self):
        """to_float() converts int values and yields a float-typed reference."""
        self._test_unary_operator(12, lambda r: r.to_float(), 12.0, 0, 0.0)

        as_float = ml_collections.FieldReference(647).to_float()
        self.assertEqual(as_float.get(), 647.0)
        self.assertEqual(as_float._field_type, float)

    def testToString(self):
        """to_str() stringifies int values and yields a str-typed reference."""
        self._test_unary_operator(12, lambda r: r.to_str(), '12', 0, '0')

        as_str = ml_collections.FieldReference(647).to_str()
        self.assertEqual(as_str.get(), '647')
        self.assertEqual(as_str._field_type, str)

    def testSetValue(self):
        """set() propagates into derived references and enforces field types."""
        base = ml_collections.FieldReference(1.0)
        addend = ml_collections.FieldReference(3)
        total = base + addend

        self.assertEqual(total.get(), 4.0)

        # Updating either operand is reflected in the derived sum.
        base.set(2.5)
        self.assertEqual(total.get(), 5.5)

        addend.set(110)
        self.assertEqual(total.get(), 112.5)

        # Type checking: the int-typed `addend` rejects a raw string, a string
        # reference, and a str-typed reference holding None. Each bad value is
        # built lazily so its construction happens inside assertRaises.
        bad_value_factories = (
            lambda: 'this is a string',
            lambda: ml_collections.FieldReference('this is a string'),
            lambda: ml_collections.FieldReference(None, field_type=str),
        )
        for make_bad_value in bad_value_factories:
            with self.assertRaises(TypeError):
                addend.set(make_bad_value())

    def testSetResult(self):
        """Calling set() on a derived reference detaches it from its input."""
        ref = ml_collections.FieldReference(1.0)
        result = ref + 1.0
        second_result = result + 1.0

        def check(expected_ref, expected_result, expected_second):
            # Assert all three values, always in the same order.
            self.assertEqual(ref.get(), expected_ref)
            self.assertEqual(result.get(), expected_result)
            self.assertEqual(second_result.get(), expected_second)

        check(1.0, 2.0, 3.0)

        ref.set(2.0)
        check(2.0, 3.0, 4.0)

        # Overriding `result` still feeds `second_result`.
        result.set(4.0)
        check(2.0, 4.0, 5.0)

        # All references are broken at this point: `result` no longer follows
        # `ref`, so changing `ref` leaves the derived values alone.
        ref.set(1.0)
        check(1.0, 4.0, 5.0)

    def testTypeChecking(self):
        """int + str references only fail once the combined value is read."""
        int_ref = ml_collections.FieldReference(1)
        str_ref = ml_collections.FieldReference('a')
        mixed = int_ref + str_ref

        with self.assertRaises(TypeError):
            mixed.get()

    def testNoType(self):
        """A field_type that is not a type is rejected with TypeError."""
        expected_message = 'field_type should be a type.*'
        with self.assertRaisesRegex(TypeError, expected_message):
            ml_collections.FieldReference(None, 0)

    def testEqual(self):
        """References compare by value against raw values and other refs."""

        def check(make_value, make_other):
            # Factories give every reference and comparand its own fresh object,
            # so equality is exercised rather than identity.
            ref1 = ml_collections.FieldReference(make_value())
            ref2 = ml_collections.FieldReference(make_value())
            ref3 = ml_collections.FieldReference(make_other())
            self.assertEqual(ref1, make_value())
            self.assertEqual(ref1, ref1)
            self.assertEqual(ref1, ref2)
            self.assertNotEqual(ref1, make_other())
            self.assertNotEqual(ref1, ref3)

        # Simple case.
        check(lambda: 1, lambda: 2)

        # ConfigDict inside FieldReference.
        check(lambda: ml_collections.ConfigDict({'a': 1}),
              lambda: ml_collections.ConfigDict({'a': 2}))

    def testLessEqual(self):
        """Ordering comparisons work between references and plain ints."""
        one = ml_collections.FieldReference(1)
        another_one = ml_collections.FieldReference(1)
        two = ml_collections.FieldReference(2)

        # A reference on either side of a comparison with a plain int.
        for lhs, rhs in ((one, 1), (one, 2), (0, one), (1, one)):
            self.assertLessEqual(lhs, rhs)
        self.assertGreater(one, 0)

        # Reference compared with reference.
        for lhs, rhs in ((one, one), (one, another_one), (one, two)):
            self.assertLessEqual(lhs, rhs)
        self.assertGreater(two, one)

    def testControlFlowError(self):
        """Using a reference as a boolean raises NotImplementedError."""
        truthy = ml_collections.FieldReference(True)
        falsy = ml_collections.FieldReference(False)

        # `if`, `and`, `or` and `not` all evaluate the truthiness of `truthy`.
        boolean_uses = (
            lambda: 1 if truthy else 0,
            lambda: truthy and falsy,
            lambda: truthy or falsy,
            lambda: not truthy,
        )
        for use in boolean_uses:
            with self.assertRaises(NotImplementedError):
                use()
示例#23
0
def main(_):
    """Demonstrate how ConfigDict initialization treats its input.

    Every step prints a type or a boolean; the printed output is the
    demonstration. Relies on a module-level ``print_section`` helper defined
    elsewhere, and on the module's ``copy`` and ``ml_collections`` imports.

    Args:
      _: Unused.
    """
    inner_dict = {'list': [1, 2], 'tuple': (1, 2, [3, 4], (5, 6))}
    example_dict = {
        'string': 'tom',
        'int': 2,
        'list': [1, 2],
        'set': {1, 2},
        'tuple': (1, 2),
        'ref': ml_collections.FieldReference({'int': 0}),
        # The same dict object appears under two keys on purpose, to show
        # below that the shared reference survives conversion.
        'inner_dict_1': inner_dict,
        'inner_dict_2': inner_dict,
    }

    def value_ids(cfg):
        """Return the set of ids of the top-level values of ``cfg``."""
        # Materialize the values first so each object stays alive while its id
        # is taken: ids are only unique among simultaneously-live objects.
        values = list(cfg.values())
        return {id(value) for value in values}

    print_section('Initializing on dictionary.')
    # ConfigDict can be initialized on example_dict.
    example_cd = ml_collections.ConfigDict(example_dict)

    # Dictionary fields are also converted to ConfigDict.
    print(type(example_cd.inner_dict_1))

    # And the reference structure is preserved. `is` keeps both operands alive
    # during the check, unlike `id(a) == id(b)`, which can report a false
    # match when a temporary is freed and its id reused.
    print(example_cd.inner_dict_1 is example_cd.inner_dict_2)

    print_section('Initializing on ConfigDict.')

    # ConfigDict can also be initialized on a ConfigDict...
    example_cd_cd = ml_collections.ConfigDict(example_cd)

    # ...yielding an equal result...
    print(example_cd == example_cd_cd)

    # ...that is nevertheless a different object.
    print(example_cd is example_cd_cd)

    # The ids of the values are not the same either, because of the
    # FieldReference, which gets removed on the second initialization.
    print(value_ids(example_cd) == value_ids(example_cd_cd))

    print_section('Initializing on self-referencing dictionary.')

    # Initialization works on a self-referencing dict...
    self_ref_dict = copy.deepcopy(example_dict)
    self_ref_dict['self'] = self_ref_dict
    self_ref_cd = ml_collections.ConfigDict(self_ref_dict)

    # ...and the reference structure is replicated.
    print(self_ref_cd is self_ref_cd.self)

    print_section('Unexpected initialization behavior.')

    # ConfigDict initialization doesn't look inside lists, so a dict inside a
    # list is not converted to ConfigDict.
    dict_in_list_in_dict = {'list': [{'troublemaker': 0}]}
    dict_in_list_in_dict_cd = ml_collections.ConfigDict(dict_in_list_in_dict)
    print(type(dict_in_list_in_dict_cd.list[0]))

    # This can cause the reference structure to not be replicated.
    referred_dict = {'key': 'value'}
    bad_reference = {'referred_dict': referred_dict, 'list': [referred_dict]}
    bad_reference_cd = ml_collections.ConfigDict(bad_reference)
    print(bad_reference_cd.referred_dict is bad_reference_cd.list[0])
示例#24
0
 def testToInt(self):
     """to_int() truncates float values and yields an int-typed reference."""
     self._test_unary_operator(25.3, lambda r: r.to_int(), 25, 27.9, 27)
     truncated = ml_collections.FieldReference(64.7).to_int()
     self.assertEqual(truncated.get(), 64)
     self.assertEqual(truncated._field_type, int)
# Shared test fixture: scalars, assorted containers (list, nested list, set,
# tuple, frozenset), a nested dict mixing tuples and lists, and a
# FieldReference wrapping a dict.
# NOTE(review): module-level and shared between helpers; presumably tests
# should not mutate it directly (_test_dict_deepcopy returns a deep copy).
_TEST_DICT = dict(
    int=2,
    list=[1, 2],
    nested_list=[[1, [2]]],
    set={1, 2},
    tuple=(1, 2),
    frozenset=frozenset({1, 2}),
    dict=dict(
        float=-1.23,
        list=[1, 2],
        dict=dict(),
        # A mutable list nested inside tuples.
        tuple_containing_list=(1, 2, (3, [4, 5], (6, 7))),
        # Lists and a tuple nested inside a list.
        list_containing_tuple=[1, 2, [3, 4], (5, 6)],
    ),
    ref=ml_collections.FieldReference({'int': 0}),
)


def _test_dict_deepcopy():
    """Return an independent deep copy of ``_TEST_DICT``."""
    fresh_copy = copy.deepcopy(_TEST_DICT)
    return fresh_copy


def _test_configdict():
    """Return a ConfigDict built directly from ``_TEST_DICT``.

    NOTE(review): the input is not copied first, so mutable members (lists,
    the FieldReference) may be shared with ``_TEST_DICT``; confirm before
    mutating the result.
    """
    cfg = ml_collections.ConfigDict(_TEST_DICT)
    return cfg


def _test_frozenconfigdict():
    """Return a FrozenConfigDict built from ``_TEST_DICT``."""
    frozen_cfg = ml_collections.FrozenConfigDict(_TEST_DICT)
    return frozen_cfg

示例#26
0
 def testBasic(self):
     ref = ml_collections.FieldReference(1)
     self.assertEqual(ref.get(), 1)