def build_input_data(input_size):
  """Builds the validation source dataset from --validation_files.

  A path ending in '.csv' is read with `dataloader.CSVInputProcessor`
  (images resolved against --dataset_base_dir). Any other value is treated as
  a TFRecord file pattern and read with
  `dataloader.TFRecordWBBoxInputProcessor`. Both processors are configured
  for evaluation (`is_training=False`) with filenames requested
  (`provide_filename=True`).

  Args:
    input_size: Value forwarded as `output_size` to the input processor.

  Returns:
    The first element of the 3-tuple returned by the processor's
    `make_source_dataset()`.

  Raises:
    RuntimeError: if a CSV file is given but --dataset_base_dir is unset.
  """
  validation_files = FLAGS.validation_files
  reads_csv = validation_files.endswith('.csv')

  # Fail fast, before touching any other flag, when CSV input lacks a root dir.
  if reads_csv and FLAGS.dataset_base_dir is None:
    raise RuntimeError('To use CSV files as input, you must specify'
                       ' --dataset_base_dir')

  # Options shared by both processor types.
  shared_options = dict(
      batch_size=BATCH_SIZE,
      is_training=False,
      output_size=input_size,
      num_classes=FLAGS.num_classes,
      provide_filename=True,
  )

  if reads_csv:
    processor = dataloader.CSVInputProcessor(
        csv_file=validation_files,
        data_dir=FLAGS.dataset_base_dir,
        **shared_options)
  else:
    # num_instances=0 is passed through unchanged; its meaning is defined by
    # dataloader (presumably "no bounding-box instances") — TODO confirm.
    processor = dataloader.TFRecordWBBoxInputProcessor(
        file_pattern=validation_files,
        num_instances=0,
        **shared_options)

  validation_dataset, _, _ = processor.make_source_dataset()
  return validation_dataset
def _build_dataset():
  """Builds the representative dataset named by --representative_dataset.

  A path ending in '.csv' is loaded with `dataloader.CSVInputProcessor`
  relative to --dataset_base_dir. Anything else is treated as a TFRecord file
  pattern for `dataloader.TFRecordWBBoxInputProcessor`. Both processors run in
  non-training mode and honour --input_size and --resize_with_pad.

  Returns:
    The first element of the 3-tuple produced by `make_source_dataset()`.

  Raises:
    RuntimeError: if --input_size is unset, or if a CSV file is supplied
      without --dataset_base_dir.
  """
  if FLAGS.input_size is None:
    raise RuntimeError('To use a representative dataset, you must specify'
                       ' --input_size')

  source = FLAGS.representative_dataset
  from_csv = source.endswith('.csv')
  if from_csv and FLAGS.dataset_base_dir is None:
    raise RuntimeError('To use CSV files as input, you must specify'
                       ' --dataset_base_dir')

  # Configuration common to the CSV and TFRecord processors.
  common = dict(
      batch_size=BATCH_SIZE,
      is_training=False,
      output_size=FLAGS.input_size,
      resize_with_pad=FLAGS.resize_with_pad,
      num_classes=FLAGS.num_classes,
  )

  if from_csv:
    processor = dataloader.CSVInputProcessor(
        csv_file=source, data_dir=FLAGS.dataset_base_dir, **common)
  else:
    processor = dataloader.TFRecordWBBoxInputProcessor(
        file_pattern=source, num_instances=0, **common)

  representative, _, _ = processor.make_source_dataset()
  return representative
def build_tfrecord_input_data(file_pattern, num_instances, is_training=False):
  """Creates a TFRecord input processor and returns its source dataset output.

  Batch size, input size, RandAugment settings and the random seed are all
  taken from command-line flags.

  Args:
    file_pattern: TFRecord file pattern handed to the processor.
    num_instances: Forwarded unchanged as the processor's `num_instances`.
    is_training: Forwarded unchanged as the processor's `is_training`.

  Returns:
    Whatever `make_source_dataset()` returns, without unpacking it.

  Raises:
    RuntimeError: if --num_classes is unset.
  """
  if FLAGS.num_classes is None:
    raise RuntimeError('To use TFRecords as input, you must specify'
                       ' --num_classes')

  processor_config = {
      'file_pattern': file_pattern,
      'batch_size': FLAGS.batch_size,
      'num_classes': FLAGS.num_classes,
      'num_instances': num_instances,
      'is_training': is_training,
      'output_size': FLAGS.input_size,
      'randaug_num_layers': FLAGS.randaug_num_layers,
      'randaug_magnitude': FLAGS.randaug_magnitude,
      'seed': FLAGS.random_seed,
  }
  processor = dataloader.TFRecordWBBoxInputProcessor(**processor_config)
  return processor.make_source_dataset()
def build_input_data():
  """Builds the evaluation dataset from the TFRecord pattern in --test_files.

  When --geo_prior_ckpt_dir is set, the processor is additionally asked for
  validity-info outputs and coordinate/date-encoded inputs. Instance ids are
  always requested. Coordinate inputs follow --use_coordinates_inputs.

  Returns:
    The first element of the 3-tuple produced by `make_source_dataset()`.
  """
  # Geo-related outputs are only needed when a geo prior checkpoint is given.
  wants_geo = FLAGS.geo_prior_ckpt_dir is not None

  processor = dataloader.TFRecordWBBoxInputProcessor(
      file_pattern=FLAGS.test_files,
      batch_size=FLAGS.batch_size,
      is_training=False,
      output_size=FLAGS.input_size,
      num_classes=FLAGS.num_classes,
      num_instances=0,
      provide_validity_info_output=wants_geo,
      provide_coord_date_encoded_input=wants_geo,
      provide_instance_id=True,
      provide_coordinates_input=FLAGS.use_coordinates_inputs,
  )
  test_dataset, _, _ = processor.make_source_dataset()
  return test_dataset