def test_map_nested(self):
  """Checks that `py_utils.map_nested` maps over dicts, lists and scalars."""

  def times_ten(value):
    return value * 10

  # Nested dict: every leaf (including list elements) is mapped and the
  # dict/list structure is preserved.
  nested_input = {
      'a': 1,
      'b': {
          'c': 2,
          'e': [3, 4, 5],
      },
  }
  expected_nested = {
      'a': 10,
      'b': {
          'c': 20,
          'e': [30, 40, 50],
      },
  }
  self.assertEqual(
      py_utils.map_nested(times_ten, nested_input), expected_nested)

  # Top-level list.
  self.assertEqual(py_utils.map_nested(times_ten, [1, 2, 3]), [10, 20, 30])

  # Bare scalar: the function is applied directly to the value.
  self.assertEqual(py_utils.map_nested(times_ten, 1), 10)
def shapes_are_compatible(
    shapes0: type_utils.TreeDict[type_utils.Shape],
    shapes1: type_utils.TreeDict[type_utils.Shape],
) -> bool:
  """Checks two nested shape dicts leaf-by-leaf for compatibility.

  Args:
    shapes0: Nested dict whose leaves are shapes (tuples/lists).
    shapes1: Nested dict with the same structure as `shapes0`.

  Returns:
    True if every pair of corresponding shapes is compatible according to
    `tf.TensorShape.is_compatible_with`, False otherwise.
  """
  # A shape is itself a tuple/list, so `tf.nest` would descend into it and
  # treat each dimension as a separate leaf. `py_utils.map_nested` with
  # `dict_only=True` recurses through dicts only, keeping each shape whole.
  tf_shapes0, tf_shapes1 = (
      py_utils.map_nested(tf.TensorShape, tree, dict_only=True)
      for tree in (shapes0, shapes1)
  )

  def _pair_is_compatible(lhs, rhs):
    return lhs.is_compatible_with(rhs)

  # Once every shape is a `tf.TensorShape` leaf, `tf.nest` can safely walk
  # both trees in lockstep.
  compat_tree = tf.nest.map_structure(
      _pair_is_compatible, tf_shapes0, tf_shapes1)
  return all(tf.nest.flatten(compat_tree))
def test_dict_only(self):
  """Checks `zip_nested` and `map_nested` with `dict_only=True`.

  With `dict_only=True` only dicts are recursed into. Tuples and lists are
  treated as leaves and handed to the zip/map function as whole values.
  """

  def concat_pair(pair):
    return pair[0] + pair[1]

  lhs = {
      'a': (1, 2),
      'b': {
          'c': 2,
          'e': [3, 4, 5],
      },
  }
  rhs = {
      'a': (10, 20),
      'b': {
          'c': 20,
          'e': [30, 40, 50],
      },
  }

  # Zipping pairs up matching leaves; tuple/list leaves stay intact.
  zipped = py_utils.zip_nested(lhs, rhs, dict_only=True)
  self.assertEqual(
      zipped,
      {
          'a': ((1, 2), (10, 20)),
          'b': {
              'c': (2, 20),
              'e': ([3, 4, 5], [30, 40, 50]),
          },
      },
  )

  # Mapping over the zipped pairs concatenates tuples/lists and adds ints.
  concatenated = py_utils.map_nested(concat_pair, zipped, dict_only=True)
  self.assertEqual(
      concatenated,
      {
          'a': (1, 2, 10, 20),
          'b': {
              'c': 22,
              'e': [3, 4, 5, 30, 40, 50],
          },
      },
  )
def dtype(self):
  """Returns the `dtype` after decoding.

  Each entry of `self.feature.get_tensor_info()` is replaced by its `.dtype`
  attribute. The nesting is produced by `py_utils.map_nested`.
  """

  def _leaf_dtype(info):
    return info.dtype

  decoded_info = self.feature.get_tensor_info()
  return py_utils.map_nested(_leaf_dtype, decoded_info)
def dtype(self):
  """Returns the `dtype` of each entry in the feature's serialized info.

  Each entry of `self.feature.get_serialized_info()` is replaced by its
  `.dtype` attribute. The nesting is produced by `py_utils.map_nested`.
  """

  def _leaf_dtype(info):
    return info.dtype

  serialized_info = self.feature.get_serialized_info()
  return py_utils.map_nested(_leaf_dtype, serialized_info)