def test_contains(self):
    """Check that keys stored in a State are found by the ``in`` operator."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    s[key1] = 1
    s[key2] = 2
    # ``in`` dispatches to State.__contains__; assertIn (unlike
    # assertTrue) reports both operands when it fails.
    self.assertIn(key1, s)
    self.assertIn(key2, s)
def test_update(self):
    """Check that State.update stores every key/value pair from a dict."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    new_s = {key1: 1, key2: 2}
    s.update(new_s)
    # assertEqual shows the actual stored value on failure, which
    # assertTrue(a == b) would hide.
    self.assertIn(key1, s)
    self.assertEqual(s[key1], 1)
    self.assertIn(key2, s)
    self.assertEqual(s[key2], 2)
def test_warn(self):
    """Check that only the plain-string key assignment emits a warning.

    Two assignments use StateKey objects and one uses the raw string
    'bad_key'; exactly one warning is expected, mentioning string access.
    """
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    with warnings.catch_warnings(record=True) as w:
        # 'always' ensures a previously-seen warning is not suppressed
        # by the default once-per-location filter.
        warnings.simplefilter('always')
        s[key1] = 'key_1'
        s[key2] = 'key_2'
        s['bad_key'] = 'bad_key'
    # The recorded list stays valid after the context manager exits.
    self.assertEqual(len(w), 1)
    self.assertIn('State was accessed with a string', str(w[-1].message))
def test_delete(self):
    """Check that ``del`` removes a key from a State and leaves others intact."""
    s = State()
    key1 = torchbearer.state_key('test_a')
    key2 = torchbearer.state_key('test_b')
    s[key1] = 1
    s[key2] = 2
    self.assertIn(key1, s)
    # ``del s[k]`` dispatches to State.__delitem__, the idiomatic form of
    # calling the dunder directly.
    del s[key1]
    self.assertNotIn(key1, s)
    # Deleting key1 must not disturb the unrelated key2.
    self.assertIn(key2, s)