    def test_infer_feature_schema_with_ragged_tensor(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            return {
                'foo':
                tf.RaggedTensor.from_row_splits(values=tf.constant(
                    [3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
                                                row_splits=[0, 4, 4, 7, 8, 8]),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        expected_schema_ascii = """feature {
name: "foo"
type: INT
annotation {
tag: "ragged_tensor"
}
}
"""
        expected_schema = text_format.Parse(expected_schema_ascii,
                                            schema_pb2.Schema())
        schema_utils_legacy.set_generate_legacy_feature_spec(
            expected_schema, False)
        self.assertProtoEquals(expected_schema, schema)
        with self.assertRaisesRegexp(ValueError,
                                     'Feature "foo" had tag "ragged_tensor"'):
            schema_utils.schema_as_feature_spec(schema)
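For reference, a minimal standalone sketch (not part of the test above) of the row_splits convention used in preprocessing_fn: row i spans values[row_splits[i]:row_splits[i + 1]], so repeated split points yield empty rows.

import tensorflow as tf

ragged = tf.RaggedTensor.from_row_splits(
    values=tf.constant([3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
    row_splits=[0, 4, 4, 7, 8, 8])
# ragged.to_list() == [[3, 1, 4, 1], [], [5, 9, 2], [6], []]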
 def test_handle_ragged_batch(self, ragged_tensor, spec,
                              expected_components):
     test_case.skip_if_not_tf2('RaggedFeature is not available in TF 1.x')
     result = impl_helper._handle_ragged_batch(ragged_tensor,
                                               spec,
                                               name='ragged')
     np.testing.assert_equal(result, expected_components)
Example #3
    def test_convert_to_arrow(self, feature_spec, instances, feed_dict,
                              feed_eager_tensors):
        if feed_eager_tensors:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        schema = schema_utils.schema_from_feature_spec(feature_spec)
        converter = impl_helper.make_tensor_to_arrow_converter(schema)
        feed_dict_local = copy.copy(feed_dict)
        if feed_eager_tensors:
            for key, value in six.iteritems(feed_dict_local):
                if isinstance(value, tf.compat.v1.SparseTensorValue):
                    feed_dict_local[key] = tf.sparse.SparseTensor.from_value(
                        value)
                else:
                    feed_dict_local[key] = tf.constant(value)
        arrow_columns, arrow_schema = impl_helper.convert_to_arrow(
            schema, converter, feed_dict_local)
        record_batch = pa.RecordBatch.from_arrays(arrow_columns,
                                                  schema=arrow_schema)

        # Merge and flatten expected instance dicts.
        expected = collections.defaultdict(list)
        for instance_dict in instances:
            for key, value in instance_dict.items():
                expected[key].append(np.ravel(value))
        actual = record_batch.to_pydict()
        self.assertEqual(len(actual), len(expected))
        for key, expected_value in expected.items():
            # Floating-point error breaks exact equality for some float values,
            # but approximate equality testing does not work on strings.
            if np.issubdtype(expected_value[0].dtype, np.number):
                self.assertAllClose(actual[key], expected_value)
            else:
                np.testing.assert_equal(actual[key], expected_value)
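The comparison above hinges on pyarrow's RecordBatch round-trip. A minimal pyarrow-only sketch, independent of impl_helper and with made-up column names:

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([[1.0], [2.0, 3.0]]), pa.array([[b'a'], [b'b']])],
    names=['x', 's'])
# batch.to_pydict() == {'x': [[1.0], [2.0, 3.0]], 's': [[b'a'], [b'b']]}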
    def test_get_analysis_cache_entry_keys(self, use_tf_compat_v1):
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        full_dataset_keys = ['a', 'b']

        def preprocessing_fn(inputs):
            return {'x': tft.scale_to_0_1(inputs['x'])}

        mocked_cache_entry_key = 'A'

        def mocked_make_cache_entry_key(_):
            return mocked_cache_entry_key

        feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
        specs = (feature_spec if use_tf_compat_v1 else
                 impl_helper.get_type_specs_from_feature_specs(feature_spec))
        with mock.patch(
                'tensorflow_transform.beam.analysis_graph_builder.'
                'analyzer_cache.make_cache_entry_key',
                side_effect=mocked_make_cache_entry_key):
            cache_entry_keys = (
                analysis_graph_builder.get_analysis_cache_entry_keys(
                    preprocessing_fn,
                    specs,
                    full_dataset_keys,
                    force_tf_compat_v1=use_tf_compat_v1))

        dot_string = nodes.get_dot_graph(
            [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
        self.WriteRenderedDotFile(dot_string)
        self.assertCountEqual(cache_entry_keys, [mocked_cache_entry_key])
Example #5
    def test_supply_missing_tensor_inputs(self, batch_size, dtype):
        test_case.skip_if_not_tf2('Tensorflow 2.x required.')

        @tf.function(input_signature=[{
            'x_1':
            tf.TensorSpec([None], dtype=tf.int32),
            'x_2':
            tf.TensorSpec([None], dtype=dtype),
        }])
        def foo(inputs):
            return inputs

        conc_fn = foo.get_concrete_function()
        # structured_input_signature is a tuple of (args, kwargs). [0][0] retrieves
        # the structure of the first arg, which for `foo` is `inputs`.
        structured_inputs = tf.nest.pack_sequence_as(
            conc_fn.structured_input_signature[0][0],
            conc_fn.inputs,
            expand_composites=True)
        missing_keys = ['x_2']
        result = tf2_utils.supply_missing_inputs(structured_inputs, batch_size,
                                                 missing_keys)

        self.assertCountEqual(missing_keys, result.keys())
        self.assertIsInstance(result['x_2'], tf.Tensor)
        self.assertEqual((batch_size, ), result['x_2'].shape)
        self.assertEqual(dtype, result['x_2'].dtype)
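A hedged sketch of what the structured_input_signature comment above refers to, using a toy tf.function (the identity function and the 'x' key are assumptions, not part of the test):

import tensorflow as tf

@tf.function(input_signature=[{'x': tf.TensorSpec([None], tf.int32)}])
def identity(inputs):
    return inputs

args, kwargs = identity.get_concrete_function().structured_input_signature
# args is a one-element tuple holding the input_signature dict of TensorSpecs
# and kwargs == {}, which is why the test indexes
# structured_input_signature[0][0] to recover the input structure.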
    def test_global_annotation(self, use_compat_v1):
        # pylint: enable=g-import-not-at-top
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            # Annotate an arbitrary proto at the schema level (not sure what global
            # schema boundaries would mean, but hey I'm just a test).
            boundaries = tf.constant([[1.0]])
            message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name
            sizes = tf.expand_dims([tf.size(boundaries)], axis=0)
            message_proto = tf.raw_ops.EncodeProto(
                sizes=sizes,
                values=[tf.cast(boundaries, tf.float32)],
                field_names=['boundaries'],
                message_type=message_type)[0]
            type_url = os.path.join('type.googleapis.com', message_type)
            schema_inference.annotate(type_url, message_proto)
            return {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
                'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.annotation.extra_metadata, 1)
        for annotation in schema.annotation.extra_metadata:
            # Extract the annotated message and validate its contents
            message = annotations_pb2.BucketBoundaries()
            annotation.Unpack(message)
            self.assertAllClose(list(message.boundaries), [1])
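The extra_metadata entries above are protobuf Any messages. A minimal sketch of the Pack/Unpack round-trip the assertion relies on, with google.protobuf.FloatValue standing in for BucketBoundaries:

from google.protobuf import any_pb2
from google.protobuf import wrappers_pb2

original = wrappers_pb2.FloatValue(value=1.0)
packed = any_pb2.Any()
packed.Pack(original)
# packed.type_url == 'type.googleapis.com/google.protobuf.FloatValue'
unpacked = wrappers_pb2.FloatValue()
assert packed.Unpack(unpacked) and unpacked.value == 1.0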
    def test_vocab_annotation(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            analyzers._maybe_annotate_vocab_metadata(
                'file1', tf.constant(100, dtype=tf.int64),
                tf.constant(75, dtype=tf.int64))
            analyzers._maybe_annotate_vocab_metadata(
                'file2', tf.constant(200, dtype=tf.int64),
                tf.constant(175, dtype=tf.int64))
            return {
                'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.annotation.extra_metadata, 2)
        unfiltered_sizes = {}
        filtered_sizes = {}
        for annotation in schema.annotation.extra_metadata:
            message = annotations_pb2.VocabularyMetadata()
            annotation.Unpack(message)
            unfiltered_sizes[
                message.file_name] = message.unfiltered_vocabulary_size
            filtered_sizes[
                message.file_name] = message.filtered_vocabulary_size
        self.assertDictEqual(unfiltered_sizes, {'file1': 100, 'file2': 200})
        self.assertDictEqual(filtered_sizes, {'file1': 75, 'file2': 175})
 def test_to_instance_dicts(self, feature_spec, instances, record_batch,
                            feed_dict, feed_eager_tensors):
     del record_batch
     if feed_eager_tensors:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     schema = schema_utils.schema_from_feature_spec(feature_spec)
     feed_dict_local = (_eager_tensor_from_values(feed_dict)
                        if feed_eager_tensors else copy.copy(feed_dict))
     result = impl_helper.to_instance_dicts(schema, feed_dict_local)
     np.testing.assert_equal(instances, result)
 def test_infer_feature_schema_bad_rank(self, use_compat_v1):
     if not use_compat_v1:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     inputs = {'x': 0}
     input_signature = {'x': tf.TensorSpec([], dtype=tf.float32)}
     with self.assertRaises(ValueError):
         self._get_schema(_make_tensors,
                          use_compat_v1,
                          inputs=inputs,
                          input_signature=input_signature)
 def test_convert_to_arrow(self, feature_spec, instances, record_batch,
                           feed_dict, feed_eager_tensors):
     del instances
     if feed_eager_tensors:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     schema = schema_utils.schema_from_feature_spec(feature_spec)
     converter = impl_helper.make_tensor_to_arrow_converter(schema)
     feed_dict_local = (_eager_tensor_from_values(feed_dict)
                        if feed_eager_tensors else copy.copy(feed_dict))
     arrow_columns, arrow_schema = impl_helper.convert_to_arrow(
         schema, converter, feed_dict_local)
     actual = pa.RecordBatch.from_arrays(arrow_columns, schema=arrow_schema)
     expected = pa.RecordBatch.from_arrays(list(record_batch.values()),
                                           names=list(record_batch.keys()))
     np.testing.assert_equal(actual.to_pydict(), expected.to_pydict())
Example #11
 def test_to_instance_dicts(self, feature_spec, instances, feed_dict,
                            feed_eager_tensors):
     if feed_eager_tensors:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     schema = schema_utils.schema_from_feature_spec(feature_spec)
     feed_dict_local = copy.copy(feed_dict)
     if feed_eager_tensors:
         for key, value in six.iteritems(feed_dict_local):
             if isinstance(value, tf.compat.v1.SparseTensorValue):
                 feed_dict_local[key] = tf.sparse.SparseTensor.from_value(
                     value)
             else:
                 feed_dict_local[key] = tf.constant(value)
     np.testing.assert_equal(
         instances, impl_helper.to_instance_dicts(schema, feed_dict_local))
    @classmethod
    def setUpClass(cls):
        test_case.skip_if_not_tf2('Tensorflow 2.x required.')
        input_specs = {
            'x': tf.TensorSpec([
                None,
            ], dtype=tf.float32)
        }

        def preprocessing_fn(inputs):
            output = (inputs['x'] - 2.0) / 5.0
            return {'x_scaled': output}

        cls._saved_model_path_v1 = _create_test_saved_model(
            True, input_specs, preprocessing_fn, 'export_v1')
        cls._saved_model_path_v2 = _create_test_saved_model(
            False, input_specs, preprocessing_fn, 'export_v2')
Example #13
  def test_object_tracker(self):
    test_case.skip_if_not_tf2('Tensorflow 2.x required')

    trackable_object = base.Trackable()

    @tf.function
    def preprocessing_fn():
      _ = annotators.make_and_track_object(lambda: trackable_object)
      return 1

    object_tracker = annotators.ObjectTracker()
    with annotators.object_tracker_scope(object_tracker):
      _ = preprocessing_fn()

    self.assertLen(object_tracker.trackable_objects, 1)
    self.assertEqual(trackable_object, object_tracker.trackable_objects[0])
Example #14
 def test_make_feed_list(self, feature_spec, instances, feed_dict,
                         produce_eager_tensors):
     if produce_eager_tensors:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     schema = schema_utils.schema_from_feature_spec(feature_spec)
     feature_names = list(feature_spec.keys())
     expected_feed_list = [feed_dict[key] for key in feature_names]
     evaluated_feed_list = impl_helper.make_feed_list(
         feature_names,
         schema,
         instances,
         produce_eager_tensors=produce_eager_tensors)
     np.testing.assert_equal(
         evaluated_feed_list if not produce_eager_tensors else
         _get_value_from_eager_tensors(evaluated_feed_list),
         expected_feed_list)
Example #15
    def test_analyze_in_place_with_analyzers_raises_error(
            self, force_tf_compat_v1):
        if not force_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(inputs):
            return {'x_add_1': analyzers.mean(inputs['x'])}

        feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)}
        type_spec = {
            'x': tf.TensorSpec(dtype=tf.int64, shape=[
                None,
            ])
        }
        output_path = os.path.join(self.get_temp_dir(), self._testMethodName)
        with self.assertRaisesRegexp(RuntimeError,
                                     'analyzers found when tracing'):
            impl_helper.analyze_in_place(preprocessing_fn, force_tf_compat_v1,
                                         feature_spec, type_spec, output_path)
Example #16
    @classmethod
    def setUpClass(cls):
        test_case.skip_if_not_tf2('Tensorflow 2.x required.')
        input_specs = {
            'x': tf.TensorSpec([
                None,
            ], dtype=tf.float32)
        }

        def foo(inputs):
            output = (inputs['x'] - 2.0) / 5.0
            return {'x_scaled': output}

        cls._saved_model_path_v1 = _create_test_saved_model(
            True, input_specs, foo, 'export_v1')
        cls._saved_model_loader_v1 = saved_transform_io_v2.SavedModelLoader(
            cls._saved_model_path_v1)
        cls._saved_model_path_v2 = _create_test_saved_model(
            False, input_specs, foo, 'export_v2')
        cls._saved_model_loader_v2 = saved_transform_io_v2.SavedModelLoader(
            cls._saved_model_path_v2)
    def test_column_inference(self, preprocessing_fn,
                              expected_analyze_input_columns,
                              expected_transform_input_columns,
                              force_tf_compat_v1):
        if not force_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
            specs = _TYPE_SPEC
        else:
            specs = _FEATURE_SPEC

        analyze_input_columns = (
            inspect_preprocessing_fn.get_analyze_input_columns(
                preprocessing_fn, specs, force_tf_compat_v1))
        transform_input_columns = (
            inspect_preprocessing_fn.get_transform_input_columns(
                preprocessing_fn, specs, force_tf_compat_v1))
        self.assertCountEqual(analyze_input_columns,
                              expected_analyze_input_columns)
        self.assertCountEqual(transform_input_columns,
                              expected_transform_input_columns)
 def test_infer_feature_schema(self,
                               make_tensors_fn,
                               feature_spec,
                               use_compat_v1,
                               domains=None,
                               create_session=False):
     if not use_compat_v1:
         test_case.skip_if_not_tf2('Tensorflow 2.x required')
     x_val = '0' if feature_spec['x'].dtype == tf.string else 0
     inputs = {'x': [x_val]}
     input_signature = {
         'x': tf.TensorSpec([None], dtype=feature_spec['x'].dtype)
     }
     schema = self._get_schema(make_tensors_fn,
                               use_compat_v1,
                               inputs=inputs,
                               input_signature=input_signature,
                               create_session=create_session)
     expected_schema = schema_utils.schema_from_feature_spec(
         feature_spec, domains)
     self.assertEqual(schema, expected_schema)
    def test_bucketization_annotation(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            inputs = {
                'foo': tf.convert_to_tensor([0, 1, 2, 3]),
                'bar': tf.convert_to_tensor([0, 2, 0, 2]),
            }
            boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]),
                                            axis=0)
            boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]),
                                            axis=0)
            outputs = {}
            # tft.apply_buckets will annotate the feature in the output schema to
            # indicate the bucket boundaries that were applied.
            outputs['Bucketized_foo'] = mappers.apply_buckets(
                inputs['foo'], boundaries_foo)
            outputs['Bucketized_bar'] = mappers.apply_buckets(
                inputs['bar'], boundaries_bar)
            return outputs

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        self.assertLen(schema.feature, 2)
        for feature in schema.feature:
            self.assertLen(feature.annotation.extra_metadata, 1)
            for annotation in feature.annotation.extra_metadata:

                # Extract the annotated message and validate its contents
                message = annotations_pb2.BucketBoundaries()
                annotation.Unpack(message)
                if feature.name == 'Bucketized_foo':
                    self.assertAllClose(list(message.boundaries), [.5, 1.5])
                elif feature.name == 'Bucketized_bar':
                    self.assertAllClose(list(message.boundaries), [.1, .2])
                else:
                    raise RuntimeError('Unexpected features in schema')
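A hedged sketch of the boundary convention asserted above, using plain tf.searchsorted instead of tft.apply_buckets (assumption: the same left-inclusive bucketing into len(boundaries) + 1 buckets):

import tensorflow as tf

values = tf.constant([0., 1., 2., 3.])
boundaries = tf.constant([.5, 1.5])
buckets = tf.searchsorted(boundaries, values, side='left')
# buckets == [0, 1, 2, 2]: 0 falls below .5, 1 lands between the two
# boundaries, and 2 and 3 fall above 1.5.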
    def test_annotate_asset(self, use_tf_compat_v1):
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def foo():
            annotators.annotate_asset('scope/my_key', 'scope/my_value')
            annotators.annotate_asset('my_key2', 'should_be_replaced')
            annotators.annotate_asset('my_key2', 'my_value2')

        if use_tf_compat_v1:
            with tf.Graph().as_default() as graph:
                foo()
        else:
            graph = tf.function(foo).get_concrete_function().graph

        self.assertDictEqual(annotators.get_asset_annotations(graph), {
            'my_key': 'my_value',
            'my_key2': 'my_value2'
        })

        annotators.clear_asset_annotations(graph)
        self.assertDictEqual(annotators.get_asset_annotations(graph), {})
    def test_get_analysis_dataset_keys(self, preprocessing_fn,
                                       full_dataset_keys, cached_dataset_keys,
                                       expected_dataset_keys,
                                       use_tf_compat_v1):
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        full_dataset_keys = [
            analysis_graph_builder.analyzer_cache.DatasetKey(k)
            for k in full_dataset_keys
        ]
        # We force all dataset keys that have entries in the cache dict to get a
        # cache hit.
        mocked_cache_entry_key = b'M'
        input_cache = {
            key: {
                mocked_cache_entry_key: 'C'
            }
            for key in cached_dataset_keys
        }
        feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)}
        specs = (feature_spec if use_tf_compat_v1 else
                 impl_helper.get_type_specs_from_feature_specs(feature_spec))
        with mock.patch(
                'tensorflow_transform.beam.analysis_graph_builder.'
                'analyzer_cache.make_cache_entry_key',
                return_value=mocked_cache_entry_key):
            dataset_keys = (analysis_graph_builder.get_analysis_dataset_keys(
                preprocessing_fn,
                specs,
                full_dataset_keys,
                input_cache,
                force_tf_compat_v1=use_tf_compat_v1))

        dot_string = nodes.get_dot_graph(
            [analysis_graph_builder._ANALYSIS_GRAPH]).to_string()
        self.WriteRenderedDotFile(dot_string)
        self.assertCountEqual(expected_dataset_keys, dataset_keys)
Example #22
    def test_analyze_in_place(self, force_tf_compat_v1):
        if not force_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(inputs):
            return {'x_add_1': inputs['x'] + 1}

        feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)}
        type_spec = {
            'x': tf.TensorSpec(dtype=tf.int64, shape=[
                None,
            ])
        }
        output_path = os.path.join(self.get_temp_dir(), self._testMethodName)
        impl_helper.analyze_in_place(preprocessing_fn, force_tf_compat_v1,
                                     feature_spec, type_spec, output_path)

        tft_output = TFTransformOutput(output_path)
        expected_value = np.array([2], dtype=np.int64)
        if force_tf_compat_v1:
            with tf.Graph().as_default() as graph:
                with tf.compat.v1.Session(graph=graph).as_default():
                    transformed_features = tft_output.transform_raw_features(
                        {'x': tf.constant([1], dtype=tf.int64)})
                    transformed_value = transformed_features['x_add_1'].eval()
        else:
            transformed_features = tft_output.transform_raw_features(
                {'x': tf.constant([1], dtype=tf.int64)})
            transformed_value = transformed_features['x_add_1'].numpy()
        self.assertEqual(transformed_value, expected_value)

        transformed_feature_spec = tft_output.transformed_feature_spec()
        expected_feature_spec = {
            'x_add_1': tf.io.FixedLenFeature([], tf.int64)
        }
        self.assertEqual(transformed_feature_spec, expected_feature_spec)
    def test_build(self, feature_spec, preprocessing_fn,
                   expected_dot_graph_str, expected_dot_graph_str_tf2,
                   use_tf_compat_v1):
        if not use_tf_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')
        specs = (feature_spec if use_tf_compat_v1 else
                 impl_helper.get_type_specs_from_feature_specs(feature_spec))
        graph, structured_inputs, structured_outputs = (
            impl_helper.trace_preprocessing_function(
                preprocessing_fn,
                specs,
                use_tf_compat_v1=use_tf_compat_v1,
                base_temp_dir=os.path.join(self.get_temp_dir(),
                                           self._testMethodName)))
        transform_fn_future, unused_cache = analysis_graph_builder.build(
            graph, structured_inputs, structured_outputs)

        dot_string = nodes.get_dot_graph([transform_fn_future]).to_string()
        self.WriteRenderedDotFile(dot_string)
        self.assertMultiLineEqual(
            msg='Result dot graph is:\n{}'.format(dot_string),
            first=dot_string,
            second=(expected_dot_graph_str
                    if use_tf_compat_v1 else expected_dot_graph_str_tf2))
Example #24
 @classmethod
 def setUpClass(cls):
     tft_test_case.skip_if_not_tf2('Tensorflow 2.x required.')
     cls._test_saved_model = _create_test_saved_model()
     cls._saved_model_loader = saved_transform_io_v2.SavedModelLoader(
         cls._test_saved_model)
    def test_infer_feature_schema_with_ragged_tensor(self, use_compat_v1):
        if not use_compat_v1:
            test_case.skip_if_not_tf2('Tensorflow 2.x required')

        def preprocessing_fn(_):
            return {
                'foo':
                tf.RaggedTensor.from_row_splits(values=tf.constant(
                    [3, 1, 4, 1, 5, 9, 2, 6], tf.int64),
                                                row_splits=[0, 4, 4, 7, 8, 8]),
                'bar':
                tf.RaggedTensor.from_row_splits(
                    values=tf.RaggedTensor.from_row_splits(
                        values=tf.ones([5], tf.float32),
                        row_splits=[0, 2, 3, 5]),
                    row_splits=[0, 0, 0, 2, 2, 4]),
                'baz':
                tf.RaggedTensor.from_row_splits(values=tf.ones([5, 3],
                                                               tf.float32),
                                                row_splits=[0, 2, 3, 5]),
                'qux':
                tf.RaggedTensor.from_row_splits(
                    values=tf.RaggedTensor.from_row_splits(
                        values=tf.ones([5, 7], tf.float32),
                        row_splits=[0, 2, 3, 5]),
                    row_splits=[0, 0, 0, 2, 2, 4]),
            }

        schema = self._get_schema(preprocessing_fn,
                                  use_compat_v1,
                                  create_session=True)
        if common_types.is_ragged_feature_available():
            expected_schema_ascii = """
        feature {
          name: "bar$ragged_values"
          type: FLOAT
        }
        feature {
          name: "bar$row_lengths_1"
          type: INT
        }
        feature {
          name: "baz$ragged_values"
          type: FLOAT
        }
        feature {
          name: "foo$ragged_values"
          type: INT
        }
        feature {
          name: "qux$ragged_values"
          type: FLOAT
        }
        feature {
          name: "qux$row_lengths_1"
          type: INT
        }
        tensor_representation_group {
          key: ""
          value {
            tensor_representation {
              key: "foo"
              value {
                ragged_tensor {
                  feature_path { step: "foo$ragged_values" }
                }
              }
            }
            tensor_representation {
              key: "bar"
              value {
                ragged_tensor {
                  feature_path { step: "bar$ragged_values" }
                  partition { row_length: "bar$row_lengths_1"}
                }
              }
            }
            tensor_representation {
              key: "baz"
              value {
                ragged_tensor {
                  feature_path { step: "baz$ragged_values" }
                  partition { uniform_row_length: 3}
                }
              }
            }
            tensor_representation {
              key: "qux"
              value {
                ragged_tensor {
                  feature_path { step: "qux$ragged_values" }
                  partition { row_length: "qux$row_lengths_1"}
                  partition { uniform_row_length: 7}
                }
              }
            }
          }
        }
        """
        else:
            expected_schema_ascii = """
        feature {
          name: "bar"
          type: FLOAT
          annotation {
            tag: "ragged_tensor"
          }
        }
        feature {
          name: "baz"
          type: FLOAT
          annotation {
            tag: "ragged_tensor"
          }
        }
        feature {
          name: "foo"
          type: INT
          annotation {
            tag: "ragged_tensor"
          }
        }
        feature {
          name: "qux"
          type: FLOAT
          annotation {
            tag: "ragged_tensor"
          }
        }
        """
        expected_schema = text_format.Parse(expected_schema_ascii,
                                            schema_pb2.Schema())
        schema_utils_legacy.set_generate_legacy_feature_spec(
            expected_schema, False)
        self.assertProtoEquals(expected_schema, schema)
        if not common_types.is_ragged_feature_available():
            with self.assertRaisesRegexp(
                    ValueError, 'Feature "bar" had tag "ragged_tensor"'):
                schema_utils.schema_as_feature_spec(schema)
 def setUp(self):
     super(CensusExampleV2Test, self).setUp()
     tft_test_case.skip_if_not_tf2('Tensorflow 2.x required.')