Пример #1
0
    def test_map_nested(self):
        """map_nested applies the function to each leaf of dicts, lists and scalars."""
        def times_ten(value):
            return value * 10

        # (input, expected output) pairs, checked in order: nested dict with an
        # inner list, a flat list, and a bare scalar.
        cases = [
            (
                {'a': 1, 'b': {'c': 2, 'e': [3, 4, 5]}},
                {'a': 10, 'b': {'c': 20, 'e': [30, 40, 50]}},
            ),
            ([1, 2, 3], [10, 20, 30]),
            (1, 10),
        ]
        for nested_input, expected in cases:
            mapped = py_utils.map_nested(times_ten, nested_input)
            self.assertEqual(mapped, expected)
Пример #2
0
def shapes_are_compatible(
    shapes0: type_utils.TreeDict[type_utils.Shape],
    shapes1: type_utils.TreeDict[type_utils.Shape],
) -> bool:
  """Checks whether two nested shape trees are compatible leaf by leaf.

  Args:
    shapes0: First tree of shapes (a `TreeDict` of `Shape` values).
    shapes1: Second tree of shapes, matched leaf-by-leaf against `shapes0`.

  Returns:
    True when every pair of corresponding leaves satisfies
    `tf.TensorShape.is_compatible_with`, False otherwise. A tree with no
    leaves yields True (`all` of an empty sequence).
  """

  def _as_tensor_shapes(tree):
    # Shapes are themselves tuples/lists, so `tf.nest` would recurse into them.
    # `dict_only=True` only descends through dicts, turning each whole shape
    # into a single `tf.TensorShape` leaf.
    return py_utils.map_nested(tf.TensorShape, tree, dict_only=True)

  def _leaves_compatible(lhs, rhs):
    return lhs.is_compatible_with(rhs)

  tensor_shapes0 = _as_tensor_shapes(shapes0)
  tensor_shapes1 = _as_tensor_shapes(shapes1)
  # NOTE(review): `tf.nest.map_structure` raises (rather than yielding False)
  # when the two trees differ in structure -- confirm callers expect that.
  compat_tree = tf.nest.map_structure(
      _leaves_compatible, tensor_shapes0, tensor_shapes1
  )
  return all(flag for flag in tf.nest.flatten(compat_tree))
Пример #3
0
    def test_dict_only(self):
        """With dict_only=True, tuples and lists are treated as leaves.

        First zips two trees with matching dict structure, then merges each
        zipped leaf pair by concatenation/addition.
        """
        def combine_pair(pair):
            return pair[0] + pair[1]

        left = {'a': (1, 2), 'b': {'c': 2, 'e': [3, 4, 5]}}
        right = {'a': (10, 20), 'b': {'c': 20, 'e': [30, 40, 50]}}

        zipped = py_utils.zip_nested(left, right, dict_only=True)
        expected_zipped = {
            'a': ((1, 2), (10, 20)),
            'b': {'c': (2, 20), 'e': ([3, 4, 5], [30, 40, 50])},
        }
        self.assertEqual(zipped, expected_zipped)

        # Each leaf of `zipped` is a (left, right) pair; the tuple/list leaves
        # are concatenated while the int leaves are added.
        merged = py_utils.map_nested(combine_pair, zipped, dict_only=True)
        expected_merged = {
            'a': (1, 2, 10, 20),
            'b': {'c': 22, 'e': [3, 4, 5, 30, 40, 50]},
        }
        self.assertEqual(merged, expected_merged)
Пример #4
0
 def dtype(self):
     """Returns the `dtype` after decoding.

     Takes `self.feature.get_tensor_info()` and replaces each leaf with its
     `.dtype` attribute via `py_utils.map_nested`, so the result keeps the
     nesting of the tensor info.
     """
     def _leaf_dtype(info):
         return info.dtype

     decoded_info = self.feature.get_tensor_info()
     return py_utils.map_nested(_leaf_dtype, decoded_info)
Пример #5
0
 def dtype(self):
     """Returns the dtype(s) of the feature's serialized representation.

     Maps every leaf of `self.feature.get_serialized_info()` to its `.dtype`
     attribute with `py_utils.map_nested`, preserving the nesting.
     """
     def _leaf_dtype(info):
         return info.dtype

     return py_utils.map_nested(_leaf_dtype, self.feature.get_serialized_info())