예제 #1
0
 def test_empty_state_evaluates_to_false(self):
     """A State with no entries is falsy; one with entries is truthy."""
     # Both the no-argument form and an explicit empty dict must be empty.
     empties = (bd.State(), bd.State({}))
     populated = bd.State({'foo': 10})
     for candidate in empties:
         assert not candidate
     assert populated
예제 #2
0
 def create_criteria_from_cfg(self, cfg=None, **module_kwargs):
     """Build a ``bd.State`` holding criterion modules and their loss weights.

     ``cfg`` is normalized with ``_prepare_cfg(cfg, CRITERIA_KEYS)``. When the
     resulting ``cfg.criteria`` is empty, nothing is logged and ``None`` is
     returned.

     Keyword-argument overrides in ``module_kwargs`` are matched
     case-insensitively. The mapping under ``all`` (if present) is applied to
     every criterion first. The mapping keyed by the lowercased criterion name
     is applied afterwards, so per-criterion values win.

     Returns:
         A ``bd.State`` mapping each criterion name to the module produced by
         ``bd.magic_module``. It also has a ``weights`` entry mapping each name
         to ``cfg.g[name].get('criterion_weight')``, read inside
         ``cfg.group_fallback()``. The weight may be ``None`` when the key is
         absent (assuming dict-like ``get`` semantics — confirm).
     """
     cfg = _prepare_cfg(cfg, CRITERIA_KEYS)
     if not cfg.criteria:
         return None
     bd.print_separator()
     bd.log('Building criteria from cfg')
     built = bd.State({'weights': {}})
     module_lines = []
     weight_lines = []
     # Normalize override keys so lookups by criterion name are case-insensitive.
     lowered = {key.lower(): value for key, value in module_kwargs.items()}
     for name in cfg.criteria:
         merged_kwargs = {}
         # Shared overrides first, then criterion-specific ones so they take precedence.
         for source_key in ('all', name.lower()):
             if source_key in lowered:
                 merged_kwargs.update(lowered[source_key])
         # NOTE(review): presumably instantiates the criterion class called `name`.
         module = bd.magic_module([name, {'kwargs': merged_kwargs}])
         with cfg.group_fallback():
             weight = cfg.g[name].get('criterion_weight')
         module_lines.append(f'\t{module}')
         weight_lines.append(f'\t{name}={weight}')
         built[name] = module
         built.weights[name] = weight
     bd.write('Criteria:\n' + '\n'.join(module_lines))
     bd.write('Weights:\n' + '\n'.join(weight_lines))
     bd.print_separator()
     return built
예제 #3
0
 def test_can_check_membership(self):
     """``in`` understands plain and dotted keys but does not descend into tuples."""
     state = bd.State({'a': {'b': 10, 'd': ('f', 'g')}})
     for present in ('a', 'a.b', 'a.d'):
         assert present in state
     # 'a.d.f' is absent even though 'f' is an element of the tuple at 'a.d'.
     for absent in ('c', 'a.c', 'a.d.f'):
         assert absent not in state
예제 #4
0
 def default_compute_losses(self, prediction, target):
     """Evaluate every criterion on ``(prediction, target)`` and sum the weighted results.

     ``self.criteria`` is read as a State of ``name -> loss callable`` plus a
     ``weights`` entry of ``name -> weight`` (presumably as built by
     ``create_criteria_from_cfg`` — confirm).

     Returns:
         A ``bd.State`` holding each criterion's weighted loss under its name,
         plus ``total`` with their sum (``0`` if there are no criteria).
     """
     result = bd.State()
     running = 0
     criteria = self.criteria
     weight_map = criteria.weights
     for name, criterion in criteria.items():
         # The 'weights' entry is bookkeeping, not a loss function.
         if name != 'weights':
             # NOTE(review): a None weight would make this multiplication fail — confirm weights are always set.
             weighted = weight_map[name] * criterion(prediction, target)
             result[name] = weighted
             running = running + weighted
     result.total = running
     return result
예제 #5
0
 def test_update_works_and_is_recursive(self):
     """``update`` merges nested States key by key instead of replacing subtrees."""
     original = bd.State({
         'foo': {
             'baz': {'bonk': 5, 'bonkers': 3},
             'qux': 2,
         },
         'bar': 1,
     })
     patch = bd.State({'foo': {'baz': {'bonk': 9}}, 'bar': 7})

     # Starting values, before the merge.
     assert original.foo.baz.bonk == 5
     assert original.foo.baz.bonkers == 3
     assert original.foo.qux == 2
     assert original.bar == 1
     assert patch.foo.baz.bonk == 9
     assert patch.bar == 7

     original.update(patch)

     # Keys present in both take the patch's value; nested siblings that the
     # patch does not mention must survive the merge.
     assert original.foo.baz.bonk == 9
     assert original.foo.baz.bonkers == 3
     assert original.foo.qux == 2
     assert original.bar == 7
예제 #6
0
 def test_can_access_nested_with_dotted_string(self):
     """A dotted string subscript reaches into a nested entry."""
     state = bd.State()
     state.a = {'foo': 10}
     dotted_key = 'a.foo'
     assert state[dotted_key] == 10
예제 #7
0
    def test_is_mapping(self):
        """State instances satisfy the ``collections.abc.Mapping`` ABC."""
        import collections.abc

        instance = bd.State()
        assert isinstance(instance, collections.abc.Mapping)
예제 #8
0
 def test_can_access_nested_attributes(self):
     """Nested values are reachable by item access, attribute access, or a mix."""
     state = bd.State()
     state.a = {'foo': 10}
     # All three access styles must agree on the same nested value.
     assert state['a']['foo'] == 10
     assert state['a'].foo == 10
     assert state.a.foo == 10
예제 #9
0
 def test_can_create_from_state_with_data(self):
     """Wrapping a populated State in State() (even twice) keeps its nested data."""
     source = bd.State()
     source.a = {'b': 10}
     wrapped_once = bd.State(source)
     wrapped_twice = bd.State(wrapped_once)
     assert wrapped_twice.a.b == 10
예제 #10
0
 def test_can_get_using_get_function(self):
     """``get`` resolves flat and dotted keys and falls back to a default."""
     state = bd.State({'foo': {'bar': 10}, 'baz': 3})
     # Keys that exist, at the top level and nested.
     assert state.get('baz') == 3
     assert state.get('foo.bar') == 10
     # Missing keys, at the top level and nested, yield the supplied default.
     assert state.get('qux', None) is None
     assert state.get('foo.qux', 5) == 5
예제 #11
0
 def test_can_assign_dict(self):
     """A dict assigned as an attribute is readable by key and stored as a State."""
     state = bd.State()
     state.a = {'foo': 10}
     stored = state['a']
     assert stored['foo'] == 10
     assert isinstance(state['a'], bd.State)
예제 #12
0
 def test_can_assign_tuple(self):
     """A tuple assigned as an attribute reads back equal to the same tuple."""
     expected = (10,)
     state = bd.State()
     state.a = (10,)
     assert state['a'] == expected
예제 #13
0
 def test_can_assign_list(self):
     """A list assigned as an attribute reads back equal to the same list."""
     expected = [10]
     state = bd.State()
     state.a = [10]
     assert state['a'] == expected
예제 #14
0
 def test_can_assign_member(self):
     """A scalar assigned as an attribute is readable by item access."""
     state = bd.State()
     state.a = 10
     value = state['a']
     assert value == 10
예제 #15
0
 def test_can_access_nested_with_sequence(self):
     """A sequence of keys (tuple or list) walks into nested entries."""
     state = bd.State()
     state['a'] = {'b': 10}
     # `state['a', 'b']` is syntactic sugar for a tuple subscript.
     assert state['a', 'b'] == 10
     assert state[['a', 'b']] == 10
     key_path = ('a', 'b')
     assert state[key_path] == 10
예제 #16
0
 def test_non_sequence_key_raises_error(self):
     """An unordered collection such as a set is rejected as a key path."""
     state = bd.State()
     state['a'] = {'b': 10}
     unordered_key = {'a', 'b'}
     with pytest.raises(TypeError):
         state[unordered_key]
예제 #17
0
 def test_assigned_nested_dicts_become_state(self):
     """Every level of an assigned nested dict is stored as a State."""
     state = bd.State()
     state.a = {'foo': {'bar': 10}}
     assert state['a']['foo']['bar'] == 10
     outer = state['a']
     assert isinstance(outer, bd.State)
     assert isinstance(outer['foo'], bd.State)
예제 #18
0
 def test_can_create_from_state(self):
     """Constructing a State from an empty State yields an equal empty State."""
     assert bd.State(bd.State()) == bd.State()
예제 #19
0
 def test_can_assign_tensor(self):
     """A tensor assigned as an attribute reads back with the same shape and elements.

     ``torch.equal`` is used rather than ``==``. Tensor ``==`` is element-wise
     and returns a tensor, so ``assert`` on it only works for single-element
     tensors and raises for larger ones.
     """
     s = bd.State()
     s.a = torch.Tensor((10, ))
     assert torch.equal(s['a'], torch.Tensor((10, )))
예제 #20
0
 def test_can_create_from_dict(self):
     """A nested dict passed to the constructor is reachable via attributes."""
     raw = {'a': {'b': 10}}
     state = bd.State(raw)
     assert state.a.b == 10
예제 #21
0
 def test_dicts_become_state_type(self):
     """Assigning a plain dict stores it as a State."""
     state = bd.State()
     state.a = {'foo': 10}
     stored = state['a']
     assert isinstance(stored, bd.State)
예제 #22
0
 def test_can_access_member_as_attribute(self):
     """A value set by item assignment is readable as an attribute."""
     state = bd.State()
     state['a'] = 10
     value = state.a
     assert value == 10
예제 #23
0
 def test_can_only_use_valid_identifiers_as_keys(self):
     """Keys that are not valid Python identifiers are rejected with KeyError.

     ``match`` searches ``str()`` of the raised exception itself. Checking
     ``str()`` of the ``ExceptionInfo`` wrapper instead would only pass
     because that wrapper's repr happens to embed the exception's message.
     """
     with pytest.raises(KeyError, match='valid Python identifiers'):
         bd.State({'1a': 10})