def _get_feature_types_from_schema(
        schema: schema_pb2.Schema,
        column_names: List[types.FeatureName]) -> List[csv_decoder.ColumnInfo]:
    """Build CSV column descriptors whose types are taken from `schema`.

    Args:
      schema: Schema whose features supply the column types. Every feature in
        it must have type INT, FLOAT or BYTES.
      column_names: Names of the CSV columns, in column order.

    Returns:
      One `csv_decoder.ColumnInfo` per entry of `column_names`, in the same
      order. A column with no same-named schema feature gets type `None`.

    Raises:
      KeyError: If any schema feature has a type other than INT, FLOAT or
        BYTES. Every feature is checked, not only those named in
        `column_names`.
    """
    # Schema feature type -> decoder column type (BYTES maps to STRING).
    stats_type_for = {
        schema_pb2.INT: csv_decoder.ColumnType.INT,
        schema_pb2.FLOAT: csv_decoder.ColumnType.FLOAT,
        schema_pb2.BYTES: csv_decoder.ColumnType.STRING,
    }
    # If a feature name repeats in the schema, its last occurrence wins.
    type_by_name = {
        feature.name: stats_type_for[feature.type] for feature in schema.feature
    }
    return [
        csv_decoder.ColumnInfo(name, type_by_name.get(name))
        for name in column_names
    ]
def _check_types(actual):
    """Assert that `actual` is a single batch of the expected column infos.

    `actual` must contain exactly one element. That element must hold the
    same `csv_decoder.ColumnInfo` values, in any order, as those built by
    pairing the enclosing scope's `column_names` with `expected_types`.
    """
    # Check the length before indexing actual[0] below.
    self.assertLen(actual, 1)
    expected_infos = [
        csv_decoder.ColumnInfo(name, col_type)
        for name, col_type in zip(column_names, expected_types)
    ]
    self.assertCountEqual(expected_infos, actual[0])