예제 #1
0
def write_from_df(
    df: DataFrame,
    tfrecord_file: str,
    feature_config: FeatureConfig,
    tfrecord_type: str,
    logger: Logger = None,
) -> None:
    """
    Serialize the contents of a pandas DataFrame into a tfrecord file

    Parameters
    ----------
    df : `pd.DataFrame`
        pandas DataFrame to be converted to TFRecordDataset
    tfrecord_file : str
        tfrecord file path to write the output
    feature_config : `FeatureConfig`
        FeatureConfig object that defines the features to be loaded in the dataset
        and the preprocessing functions to be applied to each of them
    tfrecord_type : {"example", "sequence_example"}
        Type of the TFRecord protobuf message to be used for TFRecordDataset.
        With `TFRecordTypeKey.EXAMPLE` one proto is written per row; with
        `TFRecordTypeKey.SEQUENCE_EXAMPLE` one proto is written per group of
        rows sharing the same context feature values
    logger : `Logger`, optional
        logging handler for status messages

    Raises
    ------
    ValueError
        If `tfrecord_type` is not one of the supported TFRecordTypeKey values
    """
    # Validate before doing any work so that an unsupported mode never reaches
    # the writer.
    if tfrecord_type not in (TFRecordTypeKey.EXAMPLE, TFRecordTypeKey.SEQUENCE_EXAMPLE):
        raise ValueError("You have entered {} as tfrecords write mode. "
                         "We only support {} and {}.".format(
                             tfrecord_type, TFRecordTypeKey.EXAMPLE,
                             TFRecordTypeKey.SEQUENCE_EXAMPLE))

    if logger:
        # Report the type actually being written instead of always claiming
        # SequenceExample.
        logger.info(
            "Writing {} protobufs to : {}".format(tfrecord_type, tfrecord_file))

    if df.empty:
        # `apply` on an empty frame can hand back a DataFrame rather than a
        # Series of protos, and iterating that yields column labels. Write an
        # empty file instead.
        protos = []
    elif tfrecord_type == TFRecordTypeKey.EXAMPLE:
        # The feature list does not depend on the row, so fetch it once
        all_features = feature_config.get_all_features()
        protos = df.apply(
            lambda row: get_example_proto(row=row, features=all_features),
            axis=1,
        )
    else:
        # Group pandas dataframe on query_id/query key and
        # convert each group to a single sequence example proto
        context_feature_names = feature_config.get_context_features(
            key="name")
        context_features = feature_config.get_context_features()
        sequence_features = feature_config.get_sequence_features()
        # NOTE(review): pandas groupby sorts group keys and skips rows whose
        # key is NaN by default - confirm that is intended for this data.
        protos = df.groupby(context_feature_names).apply(
            lambda g: get_sequence_example_proto(
                group=g,
                context_features=context_features,
                sequence_features=sequence_features,
            ))

    # Write to disk; the writer is opened only once every proto has been built
    with io.TFRecordWriter(tfrecord_file) as tf_writer:
        for proto in protos:
            tf_writer.write(proto.SerializeToString())
예제 #2
0
def write_from_df(
    df: DataFrame,
    tfrecord_file: str,
    feature_config: FeatureConfig,
    tfrecord_type: str,
    logger: Logger = None,
) -> None:
    """
    Serializes the contents of a pandas DataFrame into a tfrecord file.
    Output data protobuf format depends on tfrecord_type: one proto per row
    for TFRecordTypeKey.EXAMPLE, one proto per group of rows sharing the same
    context feature values for TFRecordTypeKey.SEQUENCE_EXAMPLE.

    Args:
        df: pandas DataFrame
        tfrecord_file: tfrecord file path to write the output
        feature_config: FeatureConfig object providing the context, sequence
            and full feature lists used to build the protobufs
        tfrecord_type: TFRecordTypeKey.EXAMPLE or TFRecordTypeKey.SEQUENCE_EXAMPLE
        logger: logging object

    Raises:
        ValueError: if tfrecord_type is not one of the supported
            TFRecordTypeKey values

    NOTE: This method should be moved out of ml4ir and into the preprocessing pipeline
    """
    # Validate before doing any work so that an unsupported mode never reaches
    # the writer.
    if tfrecord_type not in (TFRecordTypeKey.EXAMPLE, TFRecordTypeKey.SEQUENCE_EXAMPLE):
        raise ValueError("You have entered {} as tfrecords write mode. "
                         "We only support {} and {}.".format(
                             tfrecord_type, TFRecordTypeKey.EXAMPLE,
                             TFRecordTypeKey.SEQUENCE_EXAMPLE))

    if logger:
        # Report the type actually being written instead of always claiming
        # SequenceExample.
        logger.info(
            "Writing {} protobufs to : {}".format(tfrecord_type, tfrecord_file))

    if df.empty:
        # `apply` on an empty frame can hand back a DataFrame rather than a
        # Series of protos, and iterating that yields column labels. Write an
        # empty file instead.
        protos = []
    elif tfrecord_type == TFRecordTypeKey.EXAMPLE:
        # The feature list does not depend on the row, so fetch it once
        all_features = feature_config.get_all_features()
        protos = df.apply(
            lambda row: get_example_proto(row=row, features=all_features),
            axis=1,
        )
    else:
        # Group pandas dataframe on query_id/query key and
        # convert each group to a single sequence example proto
        context_feature_names = feature_config.get_context_features(
            key="name")
        context_features = feature_config.get_context_features()
        sequence_features = feature_config.get_sequence_features()
        # NOTE(review): pandas groupby sorts group keys and skips rows whose
        # key is NaN by default - confirm that is intended for this data.
        protos = df.groupby(context_feature_names).apply(
            lambda g: get_sequence_example_proto(
                group=g,
                context_features=context_features,
                sequence_features=sequence_features,
            ))

    # Write to disk; the writer is opened only once every proto has been built
    with io.TFRecordWriter(tfrecord_file) as tf_writer:
        for proto in protos:
            tf_writer.write(proto.SerializeToString())