def build_input_data(input_size):
    """Create the dataset described by ``FLAGS.validation_files``.

    A value ending in ``.csv`` is read through
    ``dataloader.CSVInputProcessor`` (which additionally needs
    ``FLAGS.dataset_base_dir``); any other value is handed to
    ``dataloader.TFRecordWBBoxInputProcessor`` as a file pattern. Both
    processors are configured with ``is_training=False`` and
    ``provide_filename=True``.

    Args:
        input_size: Forwarded to the processor as ``output_size``.

    Returns:
        The first of the three values returned by the processor's
        ``make_source_dataset()``.

    Raises:
        RuntimeError: If a CSV file is given but ``FLAGS.dataset_base_dir``
            is ``None``.
    """
    source = FLAGS.validation_files
    reads_csv = source.endswith('.csv')

    # Fail before any processor is constructed.
    if reads_csv and FLAGS.dataset_base_dir is None:
        raise RuntimeError(
            'To use CSV files as input, you must specify --dataset_base_dir')

    # Settings that are identical for both input formats.
    common_options = dict(
        batch_size=BATCH_SIZE,
        is_training=False,
        output_size=input_size,
        num_classes=FLAGS.num_classes,
        provide_filename=True,
    )

    if reads_csv:
        processor = dataloader.CSVInputProcessor(
            csv_file=source,
            data_dir=FLAGS.dataset_base_dir,
            **common_options,
        )
    else:
        processor = dataloader.TFRecordWBBoxInputProcessor(
            file_pattern=source,
            num_instances=0,
            **common_options,
        )

    # Only the first of the three returned values is handed back.
    source_dataset, _, _ = processor.make_source_dataset()
    return source_dataset
예제 #2
0
def _build_dataset():
  """Create the dataset described by ``FLAGS.representative_dataset``.

  A value ending in ``.csv`` is read through ``dataloader.CSVInputProcessor``
  (which additionally needs ``FLAGS.dataset_base_dir``); any other value is
  handed to ``dataloader.TFRecordWBBoxInputProcessor`` as a file pattern.
  Both processors are configured with ``is_training=False``.

  Returns:
    The first of the three values returned by the processor's
    ``make_source_dataset()``.

  Raises:
    RuntimeError: If ``FLAGS.input_size`` is ``None``, or if a CSV file is
      given but ``FLAGS.dataset_base_dir`` is ``None``.
  """
  if FLAGS.input_size is None:
    raise RuntimeError(
        'To use a representative dataset, you must specify --input_size')

  source = FLAGS.representative_dataset
  reads_csv = source.endswith('.csv')

  # Fail before any processor is constructed.
  if reads_csv and FLAGS.dataset_base_dir is None:
    raise RuntimeError(
        'To use CSV files as input, you must specify --dataset_base_dir')

  # Settings that are identical for both input formats.
  common_options = dict(
      batch_size=BATCH_SIZE,
      is_training=False,
      output_size=FLAGS.input_size,
      resize_with_pad=FLAGS.resize_with_pad,
      num_classes=FLAGS.num_classes,
  )

  if reads_csv:
    processor = dataloader.CSVInputProcessor(
        csv_file=source,
        data_dir=FLAGS.dataset_base_dir,
        **common_options,
    )
  else:
    processor = dataloader.TFRecordWBBoxInputProcessor(
        file_pattern=source,
        num_instances=0,
        **common_options,
    )

  # Only the first of the three returned values is handed back.
  source_dataset, _, _ = processor.make_source_dataset()
  return source_dataset
예제 #3
0
def build_tfrecord_input_data(file_pattern, num_instances, is_training=False):
    """Build a source dataset with ``TFRecordWBBoxInputProcessor``.

    Batch size, output size, RandAugment settings and the seed are taken from
    ``FLAGS``; the remaining settings come from the arguments.

    Args:
        file_pattern: Forwarded to the processor as ``file_pattern``.
        num_instances: Forwarded to the processor as ``num_instances``.
        is_training: Forwarded to the processor as ``is_training``.

    Returns:
        Whatever the processor's ``make_source_dataset()`` returns,
        unmodified.

    Raises:
        RuntimeError: If ``FLAGS.num_classes`` is ``None``.
    """
    class_count = FLAGS.num_classes
    if class_count is None:
        raise RuntimeError(
            'To use TFRecords as input, you must specify --num_classes')

    processor_options = {
        'file_pattern': file_pattern,
        'batch_size': FLAGS.batch_size,
        'num_classes': class_count,
        'num_instances': num_instances,
        'is_training': is_training,
        'output_size': FLAGS.input_size,
        'randaug_num_layers': FLAGS.randaug_num_layers,
        'randaug_magnitude': FLAGS.randaug_magnitude,
        'seed': FLAGS.random_seed,
    }
    processor = dataloader.TFRecordWBBoxInputProcessor(**processor_options)

    # The full result is returned as-is; callers do the unpacking.
    return processor.make_source_dataset()
예제 #4
0
def build_input_data():
    """Create the dataset described by ``FLAGS.test_files``.

    The files are read through ``dataloader.TFRecordWBBoxInputProcessor``
    with ``is_training=False`` and ``provide_instance_id=True``. The
    ``provide_validity_info_output`` and ``provide_coord_date_encoded_input``
    options are switched on only when ``FLAGS.geo_prior_ckpt_dir`` is set.

    Returns:
        The first of the three values returned by the processor's
        ``make_source_dataset()``.
    """
    # Both geo-related options are driven by the same flag.
    geo_prior_enabled = FLAGS.geo_prior_ckpt_dir is not None

    processor_options = {
        'file_pattern': FLAGS.test_files,
        'batch_size': FLAGS.batch_size,
        'is_training': False,
        'output_size': FLAGS.input_size,
        'num_classes': FLAGS.num_classes,
        'num_instances': 0,
        'provide_validity_info_output': geo_prior_enabled,
        'provide_coord_date_encoded_input': geo_prior_enabled,
        'provide_instance_id': True,
        'provide_coordinates_input': FLAGS.use_coordinates_inputs,
    }
    processor = dataloader.TFRecordWBBoxInputProcessor(**processor_options)

    # Only the first of the three returned values is handed back.
    source_dataset, _, _ = processor.make_source_dataset()
    return source_dataset