def make_sequence_example_parse_fn( feature_config: FeatureConfig, preprocessing_map: PreprocessingMap, max_sequence_size: int = 25, required_fields_only: bool = False, pad_sequence: bool = True, ) -> tf.function:
    """
    Create a parse function using the SequenceExample features spec

    Parameters
    ----------
    feature_config : `FeatureConfig`
        FeatureConfig object defining context and sequence feature information
    preprocessing_map : `PreprocessingMap`
        map of preprocessing feature functions
    max_sequence_size : int
        Maximum number of sequence per query. Used for padding
    required_fields_only : bool, optional
        Whether to only use required fields from the feature_config
    pad_sequence : bool
        Whether to pad sequence

    Returns
    -------
    `tf.function`
        Parsing function that takes in a serialized SequenceExample message
        and extracts a feature dictionary for context and sequence features
    """
    # Build the tf.io parsing specs once, outside the tf.function, so the
    # traced graph closes over plain python dicts
    context_features_spec = dict()
    sequence_features_spec = dict()

    for feature_info in feature_config.get_all_features():
        serving_info = feature_info["serving_info"]
        # When serving with required_fields_only, keep a feature if its serving_info
        # marks it required; "required" defaults to the feature's trainable flag
        if not required_fields_only or serving_info.get("required", feature_info["trainable"]):
            feature_name = feature_info["name"]
            dtype = feature_info["dtype"]
            default_value = feature_config.get_default_value(feature_info)
            if feature_info["tfrecord_type"] == SequenceExampleTypeKey.CONTEXT:
                # Context features are scalar per query -> fixed-length spec
                context_features_spec[feature_name] = io.FixedLenFeature(
                    [], dtype, default_value=default_value
                )
            elif feature_info["tfrecord_type"] == SequenceExampleTypeKey.SEQUENCE:
                # Sequence features vary in length per query -> sparse/var-length spec
                sequence_features_spec[feature_name] = io.VarLenFeature(dtype=dtype)

    @tf.function
    def _parse_sequence_example_fn(sequence_example_proto):
        """
        Parse the input `tf.SequenceExample` proto using the features_spec

        Parameters
        ----------
        sequence_example_proto : string
            serialized tfrecord SequenceExample protobuf message

        Returns
        -------
        features : dict
            parsed features as `tf.Tensor` objects extracted from the protobuf
        labels : `tf.Tensor`
            parsed label as a `tf.Tensor` object extracted from the protobuf
        """
        context_features, sequence_features = io.parse_single_sequence_example(
            serialized=sequence_example_proto,
            context_features=context_features_spec,
            sequence_features=sequence_features_spec,
        )

        features_dict = dict()

        # Handle context features
        for feature_info in feature_config.get_context_features():
            feature_node_name = feature_info.get("node_name", feature_info["name"])

            default_tensor = tf.constant(
                value=feature_config.get_default_value(feature_info),
                dtype=feature_info["dtype"],
            )
            # Fall back to the default tensor if the feature was excluded from the
            # spec above (e.g. required_fields_only dropped it)
            feature_tensor = context_features.get(feature_info["name"], default_tensor)

            feature_tensor = tf.expand_dims(feature_tensor, axis=0)

            # Preprocess features
            feature_tensor = preprocess_feature(feature_tensor, feature_info, preprocessing_map)

            features_dict[feature_node_name] = feature_tensor

        # Define mask to identify padded sequence
        if required_fields_only and not feature_config.get_rank("serving_info")["required"]:
            """
            Define dummy mask if the rank field is not a required field for serving

            NOTE: This masks all max_sequence_size as 1 as there is no real way to know
            the number of sequence in the query. There is no predefined required field,
            and hence we would need to do a full pass of all features to find the record shape.
            This approach might be unstable if different features have different shapes.

            Hence we just mask all sequence
            """
            features_dict["mask"] = tf.constant(
                value=1, shape=[max_sequence_size], dtype=feature_config.get_rank("dtype")
            )
            sequence_size = tf.constant(max_sequence_size, dtype=tf.int64)
        else:
            # Typically used at training time, to pad/clip to a fixed number of sequence per query

            # Use rank as a reference tensor to infer shape/sequence_size in query
            reference_tensor = sequence_features.get(feature_config.get_rank(key="node_name"))

            # Add mask for identifying padded sequence
            # NOTE(review): mask is built from the rank feature's dense shape —
            # presumably rank is always present for every record in the query
            mask = tf.ones_like(sparse.to_dense(sparse.reset_shape(reference_tensor)))
            sequence_size = tf.cast(tf.reduce_sum(mask), tf.int64)

            if pad_sequence:
                # Expand to [1, num_sequence, 1] so the tf.image pad/crop ops
                # (which expect HWC-style tensors) can be reused on the mask
                mask = tf.expand_dims(mask, axis=-1)

                def crop_fn():
                    # More records than max_sequence_size: warn and clip to the first
                    # max_sequence_size records
                    tf.print("\n[WARN] Bad query found. Number of sequence : ", tf.shape(mask)[1])
                    return image.crop_to_bounding_box(
                        mask,
                        offset_height=0,
                        offset_width=0,
                        target_height=1,
                        target_width=max_sequence_size,
                    )

                mask = tf.cond(
                    tf.shape(mask)[1] <= max_sequence_size,
                    # Pad if there are missing sequence
                    lambda: image.pad_to_bounding_box(
                        mask,
                        offset_height=0,
                        offset_width=0,
                        target_height=1,
                        target_width=max_sequence_size,
                    ),
                    # Crop if there are extra sequence
                    crop_fn,
                )
                mask = tf.squeeze(mask)
            else:
                mask = tf.squeeze(mask, axis=0)

            # Check validity of mask
            tf.debugging.assert_greater(sequence_size, tf.constant(0, dtype=tf.int64))

            features_dict["mask"] = mask
            sequence_size = max_sequence_size if pad_sequence else sequence_size

        # Pad sequence features to max_sequence_size
        for feature_info in feature_config.get_sequence_features():
            feature_node_name = feature_info.get("node_name", feature_info["name"])

            # Default tensor used when the feature was excluded from the parsing spec
            default_tensor = tf.fill(
                value=tf.constant(
                    value=feature_config.get_default_value(feature_info),
                    dtype=feature_info["dtype"],
                ),
                dims=[max_sequence_size if pad_sequence else sequence_size],
            )
            feature_tensor = sequence_features.get(feature_info["name"], default_tensor)

            if isinstance(feature_tensor, sparse.SparseTensor):
                # Densify to [1, sequence_size] then drop the leading batch dim
                feature_tensor = sparse.reset_shape(
                    feature_tensor,
                    new_shape=[1, max_sequence_size if pad_sequence else sequence_size],
                )
                feature_tensor = sparse.to_dense(feature_tensor)
                feature_tensor = tf.squeeze(feature_tensor, axis=0)

            # Preprocess features
            feature_tensor = preprocess_feature(feature_tensor, feature_info, preprocessing_map)

            features_dict[feature_node_name] = feature_tensor

        # The label is popped out of the features dict and returned separately
        labels = features_dict.pop(feature_config.get_label(key="name"))

        return features_dict, labels

    return _parse_sequence_example_fn
def make_example_parse_fn( feature_config: FeatureConfig, preprocessing_map: PreprocessingMap, required_fields_only: bool = False, ) -> tf.function:
    """
    Create a parse function using the Example features spec

    Parameters
    ----------
    feature_config : `FeatureConfig`
        FeatureConfig object defining context and sequence feature information
    preprocessing_map : `PreprocessingMap` object
        map of preprocessing feature functions
    required_fields_only : bool, optional
        Whether to only use required fields from the feature_config

    Returns
    -------
    `tf.function`
        Parsing function that takes in a serialized Example message
        and extracts a feature dictionary
    """
    # Build the tf.io parsing spec once, outside the tf.function
    features_spec = dict()

    for feature_info in feature_config.get_all_features():
        serving_info = feature_info["serving_info"]
        # When serving with required_fields_only, keep a feature if its serving_info
        # marks it required; "required" defaults to the feature's trainable flag
        if not required_fields_only or serving_info.get("required", feature_info["trainable"]):
            feature_name = feature_info["name"]
            dtype = feature_info["dtype"]
            default_value = feature_config.get_default_value(feature_info)
            features_spec[feature_name] = io.FixedLenFeature(
                [], dtype, default_value=default_value
            )

    @tf.function
    def _parse_example_fn(example_proto):
        """
        Parse the input `tf.Example` proto using the features_spec

        Parameters
        ----------
        example_proto : string
            serialized tfrecord Example protobuf message

        Returns
        -------
        features : dict
            parsed features as `tf.Tensor` objects extracted from the protobuf
        labels : `tf.Tensor`
            parsed label as a `tf.Tensor` object extracted from the protobuf
        """
        features = io.parse_single_example(serialized=example_proto, features=features_spec)

        features_dict = dict()

        # Process all features, including label.
        for feature_info in feature_config.get_all_features():
            feature_node_name = feature_info.get("node_name", feature_info["name"])

            default_tensor = tf.constant(
                value=feature_config.get_default_value(feature_info),
                dtype=feature_info["dtype"],
            )
            # Fall back to the default tensor if the feature was excluded from the
            # spec above (e.g. required_fields_only dropped it)
            feature_tensor = features.get(feature_info["name"], default_tensor)

            feature_tensor = tf.expand_dims(feature_tensor, axis=0)

            # Preprocess features
            feature_tensor = preprocess_feature(feature_tensor, feature_info, preprocessing_map)

            features_dict[feature_node_name] = feature_tensor

        # The label is popped out of the features dict and returned separately
        labels = features_dict.pop(feature_config.get_label(key="name"))

        return features_dict, labels

    return _parse_example_fn
def make_example_parse_fn( feature_config: FeatureConfig, preprocessing_map: PreprocessingMap, required_fields_only: bool = False, ) -> tf.function:
    """
    Create a parse function using the Example features spec

    Parameters
    ----------
    feature_config : `FeatureConfig`
        FeatureConfig object defining context and sequence feature information
    preprocessing_map : `PreprocessingMap` object
        map of preprocessing feature functions
    required_fields_only : bool, optional
        Whether to only use required fields from the feature_config

    Returns
    -------
    `tf.function`
        Parsing function that takes in a serialized Example message
        and extracts a feature dictionary
    """
    # Build the tf.io parsing spec once, outside the tf.function
    features_spec = dict()

    for feature_info in feature_config.get_all_features():
        serving_info = feature_info["serving_info"]
        # When serving with required_fields_only, keep a feature if its serving_info
        # marks it required; "required" defaults to the feature's trainable flag
        if not required_fields_only or serving_info.get(
                "required", feature_info["trainable"]):
            feature_name = feature_info["name"]
            dtype = feature_info["dtype"]
            default_value = feature_config.get_default_value(feature_info)
            features_spec[feature_name] = io.FixedLenFeature(
                [], dtype, default_value=default_value)

    @tf.function
    def _parse_example_fn(example_proto):
        """
        Parse the input `tf.Example` proto using the features_spec

        Parameters
        ----------
        example_proto : string
            serialized tfrecord Example protobuf message

        Returns
        -------
        features : dict
            parsed features as `tf.Tensor` objects extracted from the protobuf
        labels : `tf.Tensor`
            parsed label as a `tf.Tensor` object extracted from the protobuf
        """
        features = io.parse_single_example(serialized=example_proto, features=features_spec)

        features_dict = dict()

        # Process all features, including label.
        for feature_info in feature_config.get_all_features():
            feature_node_name = feature_info.get("node_name", feature_info["name"])

            default_tensor = tf.constant(
                value=feature_config.get_default_value(feature_info),
                dtype=feature_info["dtype"],
            )
            # Fall back to the default tensor if the feature was excluded from the
            # spec above (e.g. required_fields_only dropped it)
            feature_tensor = features.get(feature_info["name"], default_tensor)

            feature_tensor = tf.expand_dims(feature_tensor, axis=0)

            # Preprocess features
            feature_tensor = preprocess_feature(feature_tensor, feature_info, preprocessing_map)

            features_dict[feature_node_name] = feature_tensor

        # The label is popped out of the features dict and returned separately
        labels = features_dict.pop(feature_config.get_label(key="name"))

        return features_dict, labels

    return _parse_example_fn