Exemplo n.º 1
0
def get_parse_fn(
    tfrecord_type: str,
    feature_config: FeatureConfig,
    preprocessing_keys_to_fns: dict,
    max_sequence_size: int = 0,
    required_fields_only: bool = False,
    pad_sequence: bool = True,
):
    """
    Build a parse function for serialized TFRecord data described by a FeatureConfig.

    Parameters
    ----------
    tfrecord_type: str
        Either TFRecordTypeKey.EXAMPLE or TFRecordTypeKey.SEQUENCE_EXAMPLE;
        selects which parse-function factory is used.
    feature_config: `FeatureConfig` object
        Feature definitions forwarded to the parse-function factory.
    preprocessing_keys_to_fns: dict
        Mapping of names to preprocessing functions; registered on a new
        PreprocessingMap via `add_fns` before the parse function is built.
    max_sequence_size: int
        Forwarded only to the SequenceExample factory; unused for Example records.
    required_fields_only: bool
        Forwarded to both factories.
    pad_sequence: bool
        Forwarded only to the SequenceExample factory; unused for Example records.

    Returns
    -------
    Whatever `make_example_parse_fn` / `make_sequence_example_parse_fn` returns
    (presumably a callable that parses one serialized record — confirm against
    those factories).

    Raises
    ------
    KeyError
        If `tfrecord_type` equals neither TFRecordTypeKey.EXAMPLE nor
        TFRecordTypeKey.SEQUENCE_EXAMPLE.
    """
    # Preprocessing functions are registered up front, before the record type
    # is validated, so `add_fns` runs even when an invalid type is passed.
    fn_registry = PreprocessingMap()
    fn_registry.add_fns(preprocessing_keys_to_fns)

    # Arguments accepted by both parse-function factories.
    common_kwargs = dict(
        feature_config=feature_config,
        preprocessing_map=fn_registry,
        required_fields_only=required_fields_only,
    )

    if tfrecord_type == TFRecordTypeKey.EXAMPLE:
        return make_example_parse_fn(**common_kwargs)

    if tfrecord_type == TFRecordTypeKey.SEQUENCE_EXAMPLE:
        return make_sequence_example_parse_fn(
            max_sequence_size=max_sequence_size,
            pad_sequence=pad_sequence,
            **common_kwargs,
        )

    raise KeyError("Invalid TFRecord type specified: {}".format(tfrecord_type))
Exemplo n.º 2
0
def get_parse_fn(
    tfrecord_type: str,
    feature_config: FeatureConfig,
    preprocessing_keys_to_fns: dict,
    max_sequence_size: int = 0,
    required_fields_only: bool = False,
    pad_sequence: bool = True,
) -> tf.function:
    """
    Create the function that extracts features from serialized TFRecord data,
    as defined by a FeatureConfig.

    Parameters
    ----------
    tfrecord_type: {"example", "sequence_example"}
        Record flavour; must equal TFRecordTypeKey.EXAMPLE or
        TFRecordTypeKey.SEQUENCE_EXAMPLE. Selects TFRecordExampleParser or
        TFRecordSequenceExampleParser respectively.
    feature_config: `FeatureConfig` object
        Definitions of the features to extract.
    preprocessing_keys_to_fns: dict of (str, function)
        Preprocessing function names mapped to their definitions. They are
        registered on a new PreprocessingMap and so become available to the
        parser while records are loaded.
    max_sequence_size: int
        Maximum number of sequence records per query, used for padding. Only
        passed to the SequenceExample parser.
    required_fields_only: bool
        Whether only the required fields of `feature_config` are used.
        Passed to both parsers.
    pad_sequence: bool
        Whether to pad sequences. Only passed to the SequenceExample parser.

    Returns
    -------
    `tf.function`
        The result of the selected parser's `get_parse_fn()`. It takes a
        serialized SequenceExample or Example message and returns a
        dictionary of feature tensors.

    Raises
    ------
    KeyError
        If `tfrecord_type` is not one of the two supported TFRecordTypeKey values.
    """
    # The preprocessing map is populated before the record type is checked,
    # so `add_fns` runs even for an invalid `tfrecord_type`.
    fn_registry = PreprocessingMap()
    fn_registry.add_fns(preprocessing_keys_to_fns)

    # Keyword arguments accepted by both parser classes.
    shared_kwargs = dict(
        feature_config=feature_config,
        preprocessing_map=fn_registry,
        required_fields_only=required_fields_only,
    )

    parser: TFRecordParser
    if tfrecord_type == TFRecordTypeKey.EXAMPLE:
        parser = TFRecordExampleParser(**shared_kwargs)
    elif tfrecord_type == TFRecordTypeKey.SEQUENCE_EXAMPLE:
        parser = TFRecordSequenceExampleParser(
            max_sequence_size=max_sequence_size,
            pad_sequence=pad_sequence,
            **shared_kwargs,
        )
    else:
        raise KeyError("Invalid TFRecord type specified: {}".format(tfrecord_type))

    return parser.get_parse_fn()
Exemplo n.º 3
0
    def setUp(self) -> None:
        """
        Create the fixtures shared by the tests in this case.

        - ``self.dataset``: a ``tf.data.TFRecordDataset`` over DATASET_PATH.
        - ``self.proto``: the first serialized record of that dataset.
        - ``self.feature_config``: an Example-type FeatureConfig loaded from
          the YAML file at FEATURE_CONFIG_PATH.
        - ``self.parser``: a TFRecordExampleParser built from that config.
          It uses a freshly constructed PreprocessingMap (no functions are
          registered on it here) and ``required_fields_only=False``.
        """
        file_io = LocalIO()
        # No name argument, so this is the root logger.
        logger = logging.getLogger()

        self.dataset = tf.data.TFRecordDataset(DATASET_PATH)
        # Take the first record as a sample input.
        # This raises StopIteration if the dataset file contains no records.
        self.proto = next(iter(self.dataset))
        self.feature_config = FeatureConfig.get_instance(
            tfrecord_type=TFRecordTypeKey.EXAMPLE,
            feature_config_dict=file_io.read_yaml(FEATURE_CONFIG_PATH),
            logger=logger,
        )
        self.parser = TFRecordExampleParser(
            feature_config=self.feature_config,
            # A bare map: no extra preprocessing functions are registered for these tests.
            preprocessing_map=PreprocessingMap(),
            # Parse every configured field, not just the required ones.
            required_fields_only=False,
        )