def build_input_data(input_size):
    """Build the validation source dataset selected by ``--validation_files``.

    A path ending in ``'.csv'`` is read with ``dataloader.CSVInputProcessor``
    (which additionally needs ``--dataset_base_dir``); any other value is
    treated as a TFRecord file pattern for
    ``dataloader.TFRecordWBBoxInputProcessor`` with ``num_instances=0``.
    Both processors run in non-training mode, use the module-level
    ``BATCH_SIZE``, and are asked to provide filenames.

    Args:
        input_size: Forwarded unchanged as the processor's ``output_size``.

    Returns:
        The first element of the 3-tuple produced by
        ``make_source_dataset()``.

    Raises:
        RuntimeError: If a CSV input is requested but
            ``--dataset_base_dir`` was not set.
    """
    source = FLAGS.validation_files
    reads_csv = source.endswith('.csv')

    # The CSV reader resolves rows relative to a base directory, so fail
    # early when that directory is missing.
    if reads_csv and FLAGS.dataset_base_dir is None:
        raise RuntimeError('To use CSV files as input, you must specify'
                           ' --dataset_base_dir')

    shared_kwargs = dict(
        batch_size=BATCH_SIZE,
        is_training=False,
        output_size=input_size,
        num_classes=FLAGS.num_classes,
        provide_filename=True,
    )

    if reads_csv:
        processor = dataloader.CSVInputProcessor(
            csv_file=source,
            data_dir=FLAGS.dataset_base_dir,
            **shared_kwargs
        )
    else:
        processor = dataloader.TFRecordWBBoxInputProcessor(
            file_pattern=source,
            num_instances=0,
            **shared_kwargs
        )

    # Only the dataset itself is needed; the other two elements are dropped.
    dataset, _, _ = processor.make_source_dataset()
    return dataset
コード例 #2
0
def _build_dataset():
  """Build the source dataset named by ``--representative_dataset``.

  ``--input_size`` is required. A value ending in ``'.csv'`` is read with
  ``dataloader.CSVInputProcessor`` (which also needs ``--dataset_base_dir``);
  anything else is used as a TFRecord file pattern for
  ``dataloader.TFRecordWBBoxInputProcessor`` with ``num_instances=0``.
  Both processors run in non-training mode with the module-level
  ``BATCH_SIZE`` and honour ``--resize_with_pad``.

  Returns:
    The first element of the 3-tuple produced by ``make_source_dataset()``.

  Raises:
    RuntimeError: If ``--input_size`` is unset, or if a CSV input is
      requested without ``--dataset_base_dir``.
  """
  if FLAGS.input_size is None:
    raise RuntimeError('To use a representative dataset, you must specify'
                       ' --input_size')

  source = FLAGS.representative_dataset
  reads_csv = source.endswith('.csv')

  # CSV rows are resolved against a base directory, so it must be known.
  if reads_csv and FLAGS.dataset_base_dir is None:
    raise RuntimeError('To use CSV files as input, you must specify'
                       ' --dataset_base_dir')

  shared_kwargs = dict(
      batch_size=BATCH_SIZE,
      is_training=False,
      output_size=FLAGS.input_size,
      resize_with_pad=FLAGS.resize_with_pad,
      num_classes=FLAGS.num_classes,
  )

  if reads_csv:
    processor = dataloader.CSVInputProcessor(
        csv_file=source,
        data_dir=FLAGS.dataset_base_dir,
        **shared_kwargs
    )
  else:
    processor = dataloader.TFRecordWBBoxInputProcessor(
        file_pattern=source,
        num_instances=0,
        **shared_kwargs
    )

  # Keep only the dataset; the remaining two tuple elements are unused here.
  dataset, _, _ = processor.make_source_dataset()
  return dataset
コード例 #3
0
def build_csv_input_data(csv_file, is_training=False):
    """Create a ``dataloader.CSVInputProcessor`` and return its source dataset.

    Batch size, output size, RandAugment settings and the random seed come
    from the corresponding command-line flags.

    Args:
        csv_file: Path of the CSV file handed to the processor.
        is_training: Forwarded as the processor's ``is_training`` argument.

    Returns:
        Whatever ``make_source_dataset()`` returns, passed through unchanged.

    Raises:
        RuntimeError: If ``--dataset_base_dir`` was not set.
    """
    base_dir = FLAGS.dataset_base_dir
    if base_dir is None:
        raise RuntimeError('To use CSV files as input, you must specify'
                           ' --dataset_base_dir')

    processor_kwargs = {
        'csv_file': csv_file,
        'data_dir': base_dir,
        'batch_size': FLAGS.batch_size,
        'is_training': is_training,
        'output_size': FLAGS.input_size,
        'randaug_num_layers': FLAGS.randaug_num_layers,
        'randaug_magnitude': FLAGS.randaug_magnitude,
        'seed': FLAGS.random_seed,
    }
    processor = dataloader.CSVInputProcessor(**processor_kwargs)
    return processor.make_source_dataset()