def testInferFeatureSchema(self):
    """Infer a schema from `_InputColumn`s built on placeholders.

    The expected schema drops the leading dimension of each placeholder:
    (None,) becomes an empty shape, (1, 2, 3) becomes [2, 3], and an
    unknown-rank placeholder yields `LogicalShape(None)`.
    """
    # (column name, dtype, placeholder shape), in the original creation order.
    placeholder_specs = [
        ('a', tf.float32, (None, )),
        ('b', tf.string, (1, 2, 3)),
        ('c', tf.int64, None),
    ]
    columns = {
        name: api._InputColumn(tf.placeholder(dtype, shape), None)
        for name, dtype, shape in placeholder_specs
    }
    schema = impl_helper.infer_feature_schema(columns)

    def _fixed_column(dtype, axes):
        # `axes` of None stands for a shape whose rank is unknown.
        return sch.ColumnSchema(
            sch.LogicalColumnSchema(sch.dtype_to_domain(dtype),
                                    sch.LogicalShape(axes)),
            sch.FixedColumnRepresentation())

    expected_schema = sch.Schema(column_schemas={
        'a': _fixed_column(tf.float32, []),
        'b': _fixed_column(tf.string, [sch.Axis(2), sch.Axis(3)]),
        'c': _fixed_column(tf.int64, None),
    })
    self.assertEqual(schema, expected_schema)
def testInferFeatureSchemaWithSession(self):
    """Infer a schema in a session, honouring a min/max override on 'c'.

    The override bounds set on tensor 'c' (5 and 6) are expected to surface
    as a categorical `IntDomain` in the inferred schema.
    """
    with tf.Graph().as_default() as graph:
        tensors = {}
        tensors['a'] = tf.placeholder(tf.float32, (None, ))
        tensors['b'] = tf.placeholder(tf.string, (1, 2, 3))
        tensors['c'] = tf.placeholder(tf.int64, (None, ))
        min_value = tf.constant(5)
        max_value = tf.constant(6)
        schema_inference.set_tensor_schema_override(
            tensors['c'], min_value, max_value)
        with tf.Session(graph=graph) as session:
            schema = schema_inference.infer_feature_schema(
                tensors, graph, session)

    fixed = dataset_schema.FixedColumnRepresentation
    categorical_int = dataset_schema.IntDomain(
        tf.int64, 5, 6, is_categorical=True)
    expected_schema = dataset_schema.Schema(column_schemas={
        'a': dataset_schema.ColumnSchema(tf.float32, [], fixed()),
        'b': dataset_schema.ColumnSchema(tf.string, [2, 3], fixed()),
        'c': dataset_schema.ColumnSchema(categorical_int, [], fixed()),
    })
    self.assertEqual(schema, expected_schema)
def test_schema_equality(self):
    """Schemas are equal only when every column schema matches.

    schema3 differs in the dtype of 'var_float'; schema4 lacks that column.
    """
    def _bool_with_default():
        return sch.ColumnSchema(
            tf.bool, [1], sch.FixedColumnRepresentation(False))

    def _var_column(dtype):
        return sch.ColumnSchema(dtype, None, sch.ListColumnRepresentation())

    schema1 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
        'var_float': _var_column(tf.float32),
    })
    schema2 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
        'var_float': _var_column(tf.float32),
    })
    schema3 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
        'var_float': _var_column(tf.float64),
    })
    schema4 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
    })

    self.assertEqual(schema1, schema2)
    self.assertNotEqual(schema1, schema3)
    self.assertNotEqual(schema1, schema4)
def testInferFeatureSchema(self):
    """Infer a schema from raw tensors, respecting an explicit column schema.

    Tensor 'd' has unknown shape, but `api.set_column_schema` attaches a
    schema to it, and the expected result contains that schema unchanged.
    """
    d = tf.placeholder(tf.int64, None)
    tensors = {}
    tensors['a'] = tf.placeholder(tf.float32, (None, ))
    tensors['b'] = tf.placeholder(tf.string, (1, 2, 3))
    tensors['c'] = tf.placeholder(tf.int64, None)
    tensors['d'] = d

    fixed = sch.FixedColumnRepresentation
    api.set_column_schema(d, sch.ColumnSchema(tf.int64, [1, 2, 3], fixed()))
    schema = impl_helper.infer_feature_schema(tensors)

    expected_schema = sch.Schema(column_schemas={
        'a': sch.ColumnSchema(tf.float32, [], fixed()),
        'b': sch.ColumnSchema(tf.string, [2, 3], fixed()),
        'c': sch.ColumnSchema(tf.int64, None, fixed()),
        'd': sch.ColumnSchema(tf.int64, [1, 2, 3], fixed()),
    })
    self.assertEqual(schema, expected_schema)
def create_raw_metadata():
    """Build the `DatasetMetadata` describing the raw input features.

    Numeric features are scalar float32 columns; categorical features are
    scalar string columns defaulting to "null"; the target is a scalar
    string column with no default. Later groups overwrite earlier ones if
    a name appears in more than one group.
    """
    fixed = dataset_schema.FixedColumnRepresentation
    column_schemas = {}

    for key in metadata.NUMERIC_FEATURE_NAMES:
        column_schemas[key] = dataset_schema.ColumnSchema(
            tf.float32, [], fixed())

    for key in metadata.CATEGORICAL_FEATURE_NAMES:
        column_schemas[key] = dataset_schema.ColumnSchema(
            tf.string, [], fixed(default_value="null"))

    column_schemas[metadata.TARGET_FEATURE_NAME] = dataset_schema.ColumnSchema(
        tf.string, [], fixed())

    return dataset_metadata.DatasetMetadata(
        dataset_schema.Schema(column_schemas))
def test_column_representation_equality(self):
    """Column representations compare by kind and by their configuration."""
    def _sparse(idx2_flag):
        # Only the boolean of the second index field varies between cases.
        return sch.SparseColumnRepresentation('val', [
            sch.SparseIndexField('idx1', False),
            sch.SparseIndexField('idx2', idx2_flag),
        ])

    fixed1 = sch.FixedColumnRepresentation(1.1)
    fixed2 = sch.FixedColumnRepresentation(1.1)
    fixed3 = sch.FixedColumnRepresentation()
    list1 = sch.ListColumnRepresentation()
    list2 = sch.ListColumnRepresentation()
    sparse1 = _sparse(True)
    sparse2 = _sparse(True)
    sparse3 = _sparse(False)

    self.assertEqual(fixed1, fixed2)
    self.assertNotEqual(fixed1, fixed3)
    self.assertNotEqual(fixed1, list1)
    self.assertNotEqual(fixed1, sparse1)
    self.assertEqual(list1, list2)
    self.assertNotEqual(list1, sparse1)
    self.assertEqual(sparse1, sparse2)
    self.assertNotEqual(sparse1, sparse3)
def main(p=None,
         output_path='/Users/luoshixin/Personal/GCPStudy/src/tensorflow/tftransform/tmp'):
    """Analyze and transform a tiny in-memory dataset with tf.Transform.

    The transformed rows are pretty-printed and then written out as text via
    `beam.io.WriteToText`.

    Args:
      p: Unused; kept so existing callers that pass it keep working.
      output_path: File-path prefix handed to `beam.io.WriteToText`. Defaults
        to the location this script originally wrote to; pass a path of your
        own when running anywhere else.
    """
    def preprocessing_fn(inputs):
        """Preprocess input columns into transformed columns."""
        x = inputs['x']
        y = inputs['y']
        s = inputs['s']
        x_centered = x - tft.mean(x)
        y_normalized = tft.scale_to_0_1(y)
        s_integerized = tft.string_to_int(s)
        x_centered_times_y_normalized = (x_centered * y_normalized)
        return {
            'x_centered': x_centered,
            'y_normalized': y_normalized,
            'x_centered_times_y_normalized': x_centered_times_y_normalized,
            's_integerized': s_integerized
        }

    raw_data = [{
        'x': 1,
        'y': 1,
        's': 'hello'
    }, {
        'x': 2,
        'y': 2,
        's': 'world'
    }, {
        'x': 3,
        'y': 3,
        's': 'hello'
    }]

    raw_data_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.Schema({
            's': dataset_schema.ColumnSchema(
                tf.string, [], dataset_schema.FixedColumnRepresentation()),
            'y': dataset_schema.ColumnSchema(
                tf.float32, [], dataset_schema.FixedColumnRepresentation()),
            'x': dataset_schema.ColumnSchema(
                tf.float32, [], dataset_schema.FixedColumnRepresentation())
        }))

    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
        transformed_dataset, transform_fn = (  # pylint: disable=unused-variable
            (raw_data, raw_data_metadata)
            | beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))

    transformed_data, transformed_metadata = transformed_dataset  # pylint: disable=unused-variable

    pprint.pprint(transformed_data)

    # NOTE(review): transformed_data is presumably a materialized Python
    # collection here (it is pprint-ed above); this pipe relies on Beam
    # accepting a non-PCollection left operand — confirm for the Beam version.
    (transformed_data | beam.io.WriteToText(output_path))
def test_column_schema_equality(self):
    """Column schemas differ when either the default value or shape differs."""
    def _bool_column(shape, *default):
        # `default` is forwarded only when given, so FixedColumnRepresentation()
        # is called with no argument for the "no default" cases.
        return sch.ColumnSchema(
            tf.bool, shape, sch.FixedColumnRepresentation(*default))

    c1 = _bool_column([1], False)
    c2 = _bool_column([1], False)
    c3 = _bool_column([1])
    c4 = _bool_column([2])

    self.assertEqual(c1, c2)
    self.assertNotEqual(c1, c3)
    self.assertNotEqual(c3, c4)
def main():
    """Run tf.Transform's analyze and transform steps separately on toy data.

    `AnalyzeDataset` produces a transform function, `TransformDataset`
    applies it to the same raw data, and the transformed rows are printed.
    """
    def preprocessing_fn(inputs):
        """Preprocess input columns into transformed columns."""
        x, y, s = inputs['x'], inputs['y'], inputs['s']
        x_centered = x - tft.mean(x)
        y_normalized = tft.scale_to_0_1(y)
        s_integerized = tft.string_to_int(s)
        return {
            'x_centered': x_centered,
            'y_normalized': y_normalized,
            'x_centered_times_y_normalized': x_centered * y_normalized,
            's_integerized': s_integerized,
        }

    rows = [(1, 1, 'hello'), (2, 2, 'world'), (3, 3, 'hello')]
    raw_data = [{'x': x_val, 'y': y_val, 's': s_val}
                for x_val, y_val, s_val in rows]

    fixed = dataset_schema.FixedColumnRepresentation
    raw_data_metadata = dataset_metadata.DatasetMetadata(
        dataset_schema.Schema({
            's': dataset_schema.ColumnSchema(tf.string, [], fixed()),
            'y': dataset_schema.ColumnSchema(tf.float32, [], fixed()),
            'x': dataset_schema.ColumnSchema(tf.float32, [], fixed()),
        }))

    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
        analyze = beam_impl.AnalyzeDataset(preprocessing_fn)
        transform_fn = (raw_data, raw_data_metadata) | analyze
        transformed_dataset = (
            ((raw_data, raw_data_metadata), transform_fn)
            | beam_impl.TransformDataset())

    transformed_data, _ = transformed_dataset
    pprint.pprint(transformed_data)
def get_raw_metadata():
    """Return metadata for scalar string columns topic, raw_title, clean_title."""
    column_names = ('topic', 'raw_title', 'clean_title')
    column_schemas = {
        name: dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation())
        for name in column_names
    }
    return dataset_metadata.DatasetMetadata(
        dataset_schema.Schema(column_schemas))
def get_metadata():
    """Return metadata for scalar string columns title, content and topics."""
    from tensorflow_transform.tf_metadata import dataset_schema
    from tensorflow_transform.tf_metadata import dataset_metadata

    column_schemas = {}
    for name in ('title', 'content', 'topics'):
        column_schemas[name] = dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation())
    return dataset_metadata.DatasetMetadata(
        dataset_schema.Schema(column_schemas))
def _make_raw_schema(shape):
    """Build a Schema of int64 fixed columns raw_a, raw_b, raw_label.

    Args:
      shape: Shape given to every column schema.
    """
    schema = sch.Schema()
    for name in ('raw_a', 'raw_b', 'raw_label'):
        schema.column_schemas[name] = sch.ColumnSchema(
            tf.int64, shape, sch.FixedColumnRepresentation())
    return schema
def _make_transformed_schema(shape):
    """Build a Schema of int64 fixed columns for the transformed features.

    Args:
      shape: Shape given to every column schema.
    """
    schema = sch.Schema()
    for name in ('transformed_a', 'transformed_b', 'transformed_label'):
        schema.column_schemas[name] = sch.ColumnSchema(
            tf.int64, shape, sch.FixedColumnRepresentation())
    return schema
def get_metadata():
    """Return metadata for scalar string columns "id" and "text"."""
    from tensorflow_transform.tf_metadata import dataset_schema
    from tensorflow_transform.tf_metadata import dataset_metadata

    def _scalar_string_column():
        return dataset_schema.ColumnSchema(
            tf.string, [], dataset_schema.FixedColumnRepresentation())

    schema = dataset_schema.Schema({
        "id": _scalar_string_column(),
        "text": _scalar_string_column(),
    })
    return dataset_metadata.DatasetMetadata(schema)
def make_tft_input_metadata(schema):
    """Create tf-transform metadata from given schema.

    Args:
      schema: Iterable of dicts, each with a 'type' and a 'name' key.

    Returns:
      A `DatasetMetadata` where NUMBER columns become scalar float32 columns
      defaulting to 0.0, and CATEGORY/TEXT/IMAGE_URL/KEY columns become
      scalar string columns defaulting to ''. Columns of any other type are
      left out of the result.
    """
    string_types = ['CATEGORY', 'TEXT', 'IMAGE_URL', 'KEY']
    fixed = dataset_schema.FixedColumnRepresentation
    tft_schema = {}
    for column in schema:
        column_type = column['type']
        column_name = column['name']
        if column_type == 'NUMBER':
            dtype, representation = tf.float32, fixed(default_value=0.0)
        elif column_type in string_types:
            dtype, representation = tf.string, fixed(default_value='')
        else:
            # Unrecognised column types are skipped without error.
            continue
        tft_schema[column_name] = dataset_schema.ColumnSchema(
            dtype, [], representation)
    return dataset_metadata.DatasetMetadata(dataset_schema.Schema(tft_schema))
def _from_feature_dict(feature_dict):
    """Translate a JSON feature dict into a `ColumnSchema`.

    Args:
      feature_dict: Dict decoded from the protobuf-JSON form of a feature. It
        must contain 'domain' and 'parsingOptions' -> 'tfOptions'. It may
        contain 'fixedShape' (axes of known size) or 'valueCount' (treated as
        a 1-D feature of unknown size).

    Returns:
      A `sch.ColumnSchema` built from the parsed domain, axes and
      representation.

    Raises:
      ValueError: if 'tfOptions' has neither a 'fixedLenFeature' nor a
        'varLenFeature' entry.
    """
    domain = _from_domain_dict(feature_dict['domain'])

    axes = []
    if 'fixedShape' in feature_dict:
        # Protobuf JSON omits empty repeated fields, so a scalar fixed shape
        # can arrive as `{}` with no 'axis' key at all.
        for axis in feature_dict['fixedShape'].get('axis', []):
            # int() is needed because protobuf JSON encodes int64 as string
            axes.append(sch.Axis(int(axis.get('size'))))
    elif 'valueCount' in feature_dict:
        # Value_count always means a 1-D feature of unknown size.
        # We don't support value_count.min and value_count.max yet.
        axes.append(sch.Axis(None))

    tf_options = feature_dict['parsingOptions']['tfOptions']
    fixed_len_feature = tf_options.get('fixedLenFeature')
    if fixed_len_feature is not None:
        # The first default present wins, checked in int, string, float order;
        # with none present the column has no default (None).
        default_value = None
        if 'intDefaultValue' in fixed_len_feature:
            # int() is needed because protobuf JSON encodes int64 as string
            default_value = int(fixed_len_feature['intDefaultValue'])
        elif 'stringDefaultValue' in fixed_len_feature:
            default_value = fixed_len_feature['stringDefaultValue']
        elif 'floatDefaultValue' in fixed_len_feature:
            default_value = fixed_len_feature['floatDefaultValue']
        representation = sch.FixedColumnRepresentation(default_value)
    elif tf_options.get('varLenFeature') is not None:
        representation = sch.ListColumnRepresentation()
    else:
        raise ValueError('Could not interpret tfOptions: {}'.format(tf_options))

    return sch.ColumnSchema(domain, axes, representation)
def testInferFeatureSchema(self):
    """Infer fixed column schemas from `_InputColumn` placeholders.

    The leading dimension is dropped from each placeholder shape, and an
    unknown-rank placeholder yields a column shape of None.
    """
    columns = {}
    columns['a'] = api._InputColumn(tf.placeholder(tf.float32, (None,)), None)
    columns['b'] = api._InputColumn(tf.placeholder(tf.string, (1, 2, 3)), None)
    columns['c'] = api._InputColumn(tf.placeholder(tf.int64, None), None)
    schema = impl_helper.infer_feature_schema(columns)

    fixed = sch.FixedColumnRepresentation
    expected_schema = sch.Schema(column_schemas={
        'a': sch.ColumnSchema(tf.float32, [], fixed()),
        'b': sch.ColumnSchema(tf.string, [2, 3], fixed()),
        'c': sch.ColumnSchema(tf.int64, None, fixed()),
    })
    self.assertEqual(schema, expected_schema)
def test_feature_spec_unsupported_dtype(self):
    """Constructing a Schema with a tf.float64 column raises ValueError."""
    with self.assertRaisesRegexp(ValueError, 'invalid dtype'):
        # Built inside the context so an error from either constructor counts.
        float64_column = sch.ColumnSchema(
            tf.float64, [], sch.FixedColumnRepresentation())
        sch.Schema({'fixed_float': float64_column})
def testCreatePhasesWithUnwrappedLoop(self):
    """A while_loop not wrapped in apply_function makes create_phases fail.

    The expected failure is a ValueError matching 'Cycle detected'.
    """
    # Test a preprocessing function with control flow.
    #
    # The loop represents
    #
    #   i = 0
    #   while i < 10:
    #     i += 1
    #     x -= 1
    #
    # We need to call an analyzer after the loop because only the transitive
    # parents of analyzers are inspected by create_phases.
    def preprocessing_fn(inputs):
        def _subtract_ten(x):
            i = tf.constant(0)

            # cond must accept every loop variable even though it reads only i.
            def cond(i, x):  # pylint: disable=unused-argument
                return tf.less(i, 10)

            def body(i, x):
                return (tf.add(i, 1), tf.add(x, -1))

            return tf.while_loop(cond, body, [i, x])[1]

        scaled_to_0_1 = mappers.scale_to_0_1(_subtract_ten(inputs['x']))
        return {'x_scaled': scaled_to_0_1}

    input_schema = sch.Schema({
        'x': sch.ColumnSchema(tf.int32, [], sch.FixedColumnRepresentation())
    })
    graph, _, _ = impl_helper.run_preprocessing_fn(
        preprocessing_fn, input_schema)
    with self.assertRaisesRegexp(ValueError, 'Cycle detected'):
        _ = impl_helper.create_phases(graph)
def testCreatePhasesWithLoop(self):
    """A while_loop wrapped in apply_function yields one phase, two analyzers."""
    # Test a preprocessing function with control flow.
    #
    # The loop represents
    #
    #   i = 0
    #   while i < 10:
    #     i += 1
    #     x -= 1
    #
    # To get an error in the case where apply_function is not called, we have
    # to call an analyzer first (see testCreatePhasesWithUnwrappedLoop). So
    # we also do so here.
    def preprocessing_fn(inputs):
        def _subtract_ten(x):
            i = tf.constant(0)

            # cond must accept every loop variable even though it reads only i.
            def cond(i, x):  # pylint: disable=unused-argument
                return tf.less(i, 10)

            def body(i, x):
                return (tf.add(i, 1), tf.add(x, -1))

            return tf.while_loop(cond, body, [i, x])[1]

        scaled_to_0_1 = mappers.scale_to_0_1(
            api.apply_function(_subtract_ten, inputs['x']))
        return {'x_scaled': scaled_to_0_1}

    input_schema = sch.Schema({
        'x': sch.ColumnSchema(tf.int32, [], sch.FixedColumnRepresentation())
    })
    graph, _, _ = impl_helper.run_preprocessing_fn(
        preprocessing_fn, input_schema)
    phases = impl_helper.create_phases(graph)
    self.assertEqual(len(phases), 1)
    self.assertEqual(len(phases[0].analyzers), 2)
def _create_raw_metadata():
    """Build the DatasetMetadata describing the raw (pre-transform) data.

    Categorical features and the label are scalar string columns; numeric
    features are scalar float32 columns.
    """
    def _scalar(dtype):
        return dataset_schema.ColumnSchema(
            dtype, [], dataset_schema.FixedColumnRepresentation())

    column_schemas = {}
    for key in CATEGORICAL_FEATURE_KEYS:
        column_schemas[key] = _scalar(tf.string)
    for key in NUMERIC_FEATURE_KEYS:
        column_schemas[key] = _scalar(tf.float32)
    column_schemas[LABEL_KEY] = _scalar(tf.string)

    return dataset_metadata.DatasetMetadata(
        dataset_schema.Schema(column_schemas))
def _make_raw_schema(shape, should_add_unused_feature=False):
    """Build a Schema of int64 fixed columns with per-column default values.

    Args:
      shape: Shape given to every column schema.
      should_add_unused_feature: When true, also add a 'raw_unused' column
        (default 1).
    """
    defaults = [('raw_a', 0), ('raw_b', 1), ('raw_label', -1)]
    if should_add_unused_feature:
        defaults.append(('raw_unused', 1))

    schema = sch.Schema()
    for name, default in defaults:
        schema.column_schemas[name] = sch.ColumnSchema(
            tf.int64, shape, sch.FixedColumnRepresentation(default_value=default))
    return schema
def _make_transformed_schema():
    """Build a Schema of int64 fixed columns whose logical shape is [1]."""
    schema = sch.Schema()
    for name in ('transformed_a', 'transformed_b', 'transformed_label'):
        logical = sch.LogicalColumnSchema(
            sch.dtype_to_domain(tf.int64), sch.LogicalShape([sch.Axis(1)]))
        schema.column_schemas[name] = sch.ColumnSchema(
            logical, sch.FixedColumnRepresentation())
    return schema
def test_feature_spec_unsupported_dtype(self):
    """as_feature_spec() rejects a tf.float64 column with a descriptive error."""
    float64_column = sch.ColumnSchema(
        tf.float64, [1], sch.FixedColumnRepresentation(0.0))
    schema = sch.Schema()
    schema.column_schemas['fixed_float_with_default'] = float64_column

    expected_message = ('tf.Example parser supports only types '
                        r'\[tf.string, tf.int64, tf.float32, tf.bool\]'
                        ', so it is invalid to generate a feature_spec'
                        ' with type tf.float64.')
    with self.assertRaisesRegexp(ValueError, expected_message):
        schema.as_feature_spec()
def test_schema_equality(self):
    """Schemas built with the logical API compare column by column.

    schema3 differs in the dtype of 'var_float'; schema4 lacks that column.
    """
    def _column(dtype, axis_size, representation):
        logical = sch.LogicalColumnSchema(
            sch.dtype_to_domain(dtype),
            sch.LogicalShape([sch.Axis(axis_size)]))
        return sch.ColumnSchema(logical, representation)

    def _bool_with_default():
        return _column(tf.bool, 1, sch.FixedColumnRepresentation(False))

    def _var(dtype):
        return _column(dtype, None, sch.ListColumnRepresentation())

    schema1 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
        'var_float': _var(tf.float32),
    })
    schema2 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
        'var_float': _var(tf.float32),
    })
    schema3 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
        'var_float': _var(tf.float64),
    })
    schema4 = sch.Schema(column_schemas={
        'fixed_bool_with_default': _bool_with_default(),
    })

    self.assertEqual(schema1, schema2)
    self.assertNotEqual(schema1, schema3)
    self.assertNotEqual(schema1, schema4)
def test_schema_with_unsupported_dtype(self):
    """Building a Schema with a tf.float64 column raises a descriptive error."""
    expected_message = (
        r'tf.Example parser supports only types \[tf.string, tf.int64, '
        r'tf.float32, tf.bool\], so it is invalid to generate a feature_spec'
        r' with type tf.float64.')
    with self.assertRaisesRegexp(ValueError, expected_message):
        # Built inside the context so an error from either constructor counts.
        float64_column = sch.ColumnSchema(
            tf.float64, [1], sch.FixedColumnRepresentation([-1]))
        sch.Schema(
            column_schemas={'fixed_float_with_default': float64_column})
def test_column_schema_equality(self):
    """Logical column schemas differ by default value or by axis size."""
    def _bool_column(axis_size, representation):
        logical = sch.LogicalColumnSchema(
            sch.dtype_to_domain(tf.bool),
            sch.LogicalShape([sch.Axis(axis_size)]))
        return sch.ColumnSchema(logical, representation)

    c1 = _bool_column(1, sch.FixedColumnRepresentation(False))
    c2 = _bool_column(1, sch.FixedColumnRepresentation(False))
    c3 = _bool_column(1, sch.FixedColumnRepresentation())
    c4 = _bool_column(2, sch.FixedColumnRepresentation())

    self.assertEqual(c1, c2)
    self.assertNotEqual(c1, c3)
    self.assertNotEqual(c3, c4)
def get_manually_created_schema():
    """Return a test Schema assembled directly from the Schema classes.

    It holds fixed-length columns (including a categorical int64 column with
    an IntDomain of -5..10) and variable-length list columns.
    """
    def _fixed(domain_or_dtype, shape):
        return sch.ColumnSchema(
            domain_or_dtype, shape, sch.FixedColumnRepresentation())

    def _var(dtype):
        return sch.ColumnSchema(dtype, None, sch.ListColumnRepresentation())

    return sch.Schema({
        # FixedLenFeatures
        'fixed_categorical_int_with_range':
            _fixed(sch.IntDomain(tf.int64, -5, 10, True), []),
        'fixed_int': _fixed(tf.int64, [5]),
        'fixed_float': _fixed(tf.float32, [5]),
        'fixed_string': _fixed(tf.string, [5]),
        # VarLenFeatures
        'var_int': _var(tf.int64),
        'var_float': _var(tf.float32),
        'var_string': _var(tf.string),
    })
def testCreatePhasesWithDegenerateFunctionApplication(self):
    """An identity apply_function (inputs and outputs overlap) needs no phases."""
    def preprocessing_fn(inputs):
        identity_output = api.apply_function(lambda x: x, inputs['a'])
        return {'index': identity_output}

    string_column = sch.ColumnSchema(
        tf.string, [], sch.FixedColumnRepresentation())
    input_schema = sch.Schema({'a': string_column})
    # NOTE(review): other tests in this chunk unpack three values from
    # run_preprocessing_fn and pass the graph to create_phases; confirm this
    # two-value / no-argument form matches the impl_helper version in use.
    _, _ = impl_helper.run_preprocessing_fn(preprocessing_fn, input_schema)
    phases = impl_helper.create_phases()
    self.assertEqual(len(phases), 0)
def test_infer_column_schema_from_tensor(self):
    """Dense tensors infer fixed column schemas; sparse ones infer lists.

    The leading dimension of the dense [2, 2] tensor is not part of the
    inferred column shape, which is expected to be [2].
    """
    dense = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32, shape=[2, 2])
    dense_schema = sch.infer_column_schema_from_tensor(dense)
    self.assertEqual(
        sch.ColumnSchema(tf.float32, [2], sch.FixedColumnRepresentation()),
        dense_schema)

    varlen = tf.sparse_placeholder(tf.string)
    varlen_schema = sch.infer_column_schema_from_tensor(varlen)
    self.assertEqual(
        sch.ColumnSchema(tf.string, [None], sch.ListColumnRepresentation()),
        varlen_schema)