Exemplo n.º 1
0
def load_seed_data(data_dir, logger, file_io: LocalIO):
    """Read every file listed under ``data_dir`` via ``file_io``.

    Args:
        data_dir: Directory expected to contain the seed data files.
        logger: Logger used to report a missing data directory.
        file_io: I/O helper used to check the path, list the directory
            contents and read the listed files.

    Returns:
        The result of ``file_io.read_df_list`` for the files found in
        ``data_dir``, or ``None`` when ``data_dir`` does not exist (the
        problem is logged as an error, not raised).
    """
    # Guard clause: a missing directory is reported and yields None.
    if not file_io.path_exists(data_dir):
        logger.error("Error! Data directory must exist and be specified")
        return None

    # NOTE(review): the original comment says only CSV loading is supported
    # for now, but no extension filter is passed here - confirm whether
    # get_files_in_directory restricts to CSV by default.
    seed_files = file_io.get_files_in_directory(data_dir)
    return file_io.read_df_list(seed_files)
Exemplo n.º 2
0
def main(args):
    """Convert CSV files into tfrecord Example/SequenceExample files.

    Args:
        args: Parsed command-line arguments. The attributes read here are
            ``csv_dir``, ``csv_files``, ``feature_config``, ``tfmode``,
            ``out_dir`` and ``keep_single_files``.

    Returns:
        None. The TFRecord files are written under ``args.out_dir``.
    """
    # Setup logging
    logger: Logger = setup_logging()
    file_io = LocalIO(logger)

    # Get all CSV files to be converted, depending on user's arguments:
    # a directory takes precedence over an explicit list of files.
    csv_files: List[str]
    if args.csv_dir:
        csv_files = file_io.get_files_in_directory(
            indir=args.csv_dir, extension="*.csv")
    else:
        csv_files = args.csv_files

    # Resolve the record type once; it is needed both for the feature
    # config and for every write call below.
    tfrecord_type = MODES[args.tfmode]

    # Load feature config
    feature_config: FeatureConfig = FeatureConfig.get_instance(
        tfrecord_type=tfrecord_type,
        feature_config_dict=file_io.read_yaml(args.feature_config),
        logger=logger,
    )

    # Convert to TFRecord protobufs and save
    tfrecord_file: str
    if args.keep_single_files:
        # Convert each CSV file individually - better performance
        for csv_file in csv_files:
            # Strip only a trailing ".csv" extension. A plain
            # str.replace(".csv", "") would also remove ".csv" occurring
            # in the middle of the name and could make distinct inputs
            # collide on the same output file.
            base_name = os.path.basename(csv_file)
            if base_name.endswith(".csv"):
                base_name = base_name[: -len(".csv")]
            tfrecord_file = os.path.join(
                args.out_dir, f"{base_name}.tfrecord")
            write_from_files(
                csv_files=[csv_file],
                tfrecord_file=tfrecord_file,
                feature_config=feature_config,
                logger=logger,
                tfrecord_type=tfrecord_type,
                # Pass the same I/O helper as the combined branch so both
                # code paths call write_from_files consistently.
                file_io=file_io,
            )
    else:
        # Convert all CSV files at once - expensive groupby operation
        tfrecord_file = os.path.join(args.out_dir, "combined.tfrecord")
        write_from_files(
            csv_files=csv_files,
            tfrecord_file=tfrecord_file,
            feature_config=feature_config,
            logger=logger,
            tfrecord_type=tfrecord_type,
            file_io=file_io,
        )