def testSetItemOverlap(self):
    """Ensure insertion fails if a key overlaps with an existing key.

    After ``x`` has been used as a key, inserting the tuple key ``(x, y)``
    must raise ``ValueError``. The failed insertion must leave the dictionary
    untouched: ``x`` keeps its original value and ``y`` is not added as a key.
    """
    with tf.Graph().as_default():
        lp_dict = layer_collection.LayerParametersDict()
        x = tf.constant(0)
        y = tf.constant(0)
        lp_dict[x] = 'value'
        with self.assertRaises(ValueError):
            lp_dict[(x, y)] = 'value'
        # The rejected insertion must not have partially registered 'y', nor
        # disturbed the existing entry for 'x'.
        self.assertIn(x, lp_dict)
        self.assertEqual(lp_dict[x], 'value')
        self.assertNotIn(y, lp_dict)
def testSetItem(self):
    """Ensure insertion, containment and retrieval work for supported key types.

    Covers three key forms: a single tensor, a tuple of tensors and a list of
    tensors. Each key is stored as its own value, so retrieval is verified by
    object identity.
    """
    with tf.Graph().as_default():
        lp_dict = layer_collection.LayerParametersDict()
        x = tf.constant(0)
        y0 = tf.constant(0)
        y1 = tf.constant(0)
        z0 = tf.constant(0)
        z1 = tf.constant(0)
        keys = [x, (y0, y1), [z0, z1]]
        for key in keys:
            lp_dict[key] = key
        for key in keys:
            self.assertIn(key, lp_dict)
            # Compare by identity rather than `==`: the stored value is the key
            # object itself, and identity avoids depending on Tensor equality
            # semantics (which differ between graph mode and TF2 behavior).
            # Assumes the dict returns stored values unmodified — TODO confirm.
            self.assertIs(lp_dict[key], key)