Ejemplo n.º 1
0
 def _check_split(self, dataset):
   """Assert each component of `dataset` has the type and shape in SPEC.

   `self.SPEC` maps a component name to an `(expected_type,
   expected_shapes)` pair. For every entry, the dataset's type for that
   component is compared with `self.assertEqual`, then the shape check is
   delegated to `tf_utils.assert_shape_match`.

   Args:
     dataset: the split under test. Its `output_types` and `output_shapes`
       are indexed by the component names that key `self.SPEC`
       (presumably a `tf.data.Dataset`; confirm against callers).
   """
   for name, spec in self.SPEC.items():
     want_type, want_shapes = spec
     # Looked up per component (not hoisted) so an empty SPEC never touches
     # the dataset's output_* attributes.
     got_type = dataset.output_types[name]
     failure_msg = "Component %s doesn't have type %s, but %s." % (
         name, want_type, got_type)
     self.assertEqual(want_type, got_type, failure_msg)
     got_shapes = dataset.output_shapes[name]
     tf_utils.assert_shape_match(got_shapes, want_shapes)
Ejemplo n.º 2
0
def compare_shapes_and_types(tensor_info, output_types, output_shapes):
    """Check dataset types/shapes against a (possibly nested) TensorInfo spec.

    Walks ``tensor_info`` key by key. A value that is a ``dict`` is treated as
    a nested group and compared recursively against the matching entries of
    ``output_types`` and ``output_shapes``. Any other value is a leaf whose
    ``dtype`` must equal the dataset type and whose ``shape`` is checked with
    ``tf_utils.assert_shape_match``.

    Args:
        tensor_info: mapping of feature name to either a nested mapping of the
            same form or a leaf exposing ``dtype`` and ``shape`` (presumably a
            TensorInfo; confirm against callers).
        output_types: mapping of feature name to dataset type, nested the same
            way as ``tensor_info``.
        output_shapes: mapping of feature name to dataset shape, nested the
            same way as ``tensor_info``.

    Raises:
        TypeError: if a leaf's dtype differs from the dataset's type.
        KeyError: if a feature name is missing from ``output_types`` or
            ``output_shapes``.
        Any error raised by ``tf_utils.assert_shape_match`` on a shape mismatch.
    """
    for name, info in tensor_info.items():
        if isinstance(info, dict):
            # Nested feature group: descend into the matching sub-structures.
            compare_shapes_and_types(
                info, output_types[name], output_shapes[name])
            continue

        want_dtype = info.dtype
        got_dtype = output_types[name]
        if want_dtype != got_dtype:
            message = "Feature %s has type %s but expected %s" % (
                name, got_dtype, want_dtype)
            raise TypeError(message)

        want_shape = info.shape
        got_shape = output_shapes[name]
        tf_utils.assert_shape_match(want_shape, got_shape)