def build_input_data(input_size):
    """Create the evaluation source dataset described by ``--validation_files``.

    A value ending in ``.csv`` is read with ``dataloader.CSVInputProcessor``,
    which additionally requires ``--dataset_base_dir``. Any other value is
    passed as a file pattern to ``dataloader.TFRecordWBBoxInputProcessor``.
    Both processors run in non-training mode with ``BATCH_SIZE`` and are
    asked to provide filenames.

    Args:
        input_size: Forwarded to the processor as ``output_size``.

    Returns:
        The first element of the 3-tuple produced by
        ``make_source_dataset()``.

    Raises:
        RuntimeError: If CSV input is selected but ``--dataset_base_dir`` is
            not set.
    """
    source = FLAGS.validation_files
    use_csv = source.endswith('.csv')
    if use_csv and FLAGS.dataset_base_dir is None:
        raise RuntimeError('To use CSV files as input, you must specify'
                           ' --dataset_base_dir')

    # Arguments common to both processor implementations.
    shared_kwargs = dict(
        batch_size=BATCH_SIZE,
        is_training=False,
        output_size=input_size,
        num_classes=FLAGS.num_classes,
        provide_filename=True,
    )

    if use_csv:
        processor = dataloader.CSVInputProcessor(
            csv_file=source,
            data_dir=FLAGS.dataset_base_dir,
            **shared_kwargs,
        )
    else:
        # TFRecord path: no per-example instance annotations are requested.
        processor = dataloader.TFRecordWBBoxInputProcessor(
            file_pattern=source,
            num_instances=0,
            **shared_kwargs,
        )

    # Strict unpacking: the processor is expected to return exactly 3 items.
    dataset, _unused_second, _unused_third = processor.make_source_dataset()
    return dataset
def _build_dataset():
    """Build the representative dataset named by ``--representative_dataset``.

    ``--input_size`` must be set. A ``.csv`` path is loaded through
    ``dataloader.CSVInputProcessor``, which also needs ``--dataset_base_dir``.
    Any other value is treated as a TFRecord file pattern for
    ``dataloader.TFRecordWBBoxInputProcessor``. Both processors run in
    non-training mode and honour ``--resize_with_pad`` and
    ``--num_classes``.

    Returns:
        The first element of the 3-tuple produced by
        ``make_source_dataset()``.

    Raises:
        RuntimeError: If ``--input_size`` is unset, or if CSV input is
            selected without ``--dataset_base_dir``.
    """
    if FLAGS.input_size is None:
        raise RuntimeError('To use a representative dataset, you must specify'
                           ' --input_size')

    source = FLAGS.representative_dataset
    from_csv = source.endswith('.csv')
    if from_csv and FLAGS.dataset_base_dir is None:
        raise RuntimeError('To use CSV files as input, you must specify'
                           ' --dataset_base_dir')

    common = {
        'batch_size': BATCH_SIZE,
        'is_training': False,
        'output_size': FLAGS.input_size,
        'resize_with_pad': FLAGS.resize_with_pad,
        'num_classes': FLAGS.num_classes,
    }

    if from_csv:
        processor = dataloader.CSVInputProcessor(
            csv_file=source, data_dir=FLAGS.dataset_base_dir, **common)
    else:
        processor = dataloader.TFRecordWBBoxInputProcessor(
            file_pattern=source, num_instances=0, **common)

    dataset, _, _ = processor.make_source_dataset()
    return dataset
def build_csv_input_data(csv_file, is_training=False):
    """Build a source dataset from ``csv_file`` via ``CSVInputProcessor``.

    Images are resolved relative to ``--dataset_base_dir``. Batch size,
    output size, RandAugment settings and the random seed come from the
    corresponding flags.

    Args:
        csv_file: CSV path forwarded to the processor as ``csv_file``.
        is_training: Forwarded to the processor's ``is_training`` argument.

    Returns:
        The value of ``make_source_dataset()``, returned unmodified.

    Raises:
        RuntimeError: If ``--dataset_base_dir`` is not set.
    """
    base_dir = FLAGS.dataset_base_dir
    if base_dir is None:
        raise RuntimeError('To use CSV files as input, you must specify'
                           ' --dataset_base_dir')

    processor = dataloader.CSVInputProcessor(
        csv_file=csv_file,
        data_dir=base_dir,
        batch_size=FLAGS.batch_size,
        is_training=is_training,
        output_size=FLAGS.input_size,
        randaug_num_layers=FLAGS.randaug_num_layers,
        randaug_magnitude=FLAGS.randaug_magnitude,
        seed=FLAGS.random_seed,
    )
    return processor.make_source_dataset()