Ejemplo n.º 1
0
class TestClfDatumParser(absltest.TestCase):
  """Tests parsing of classification tfrecords produced by TFRecordWriter."""

  def setUp(self):
    # Fixed scratch directory; removed again in tearDown.
    self.tempdir = '/tmp/test/tfrecord_clf'
    self.serializer = DatumSerializer('image')
    Path(self.tempdir).mkdir(parents=True, exist_ok=True)
    clf_generator = image.ClfDatumGenerator('tests/dummy_data/clf')
    sparse_keys = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'is_difficult'
    ]
    # Write a single-example 'train' split, then build a parser over it.
    self.writer = TFRecordWriter(clf_generator,
                                 self.serializer,
                                 self.tempdir,
                                 'train',
                                 1,
                                 sparse_features=sparse_keys,
                                 image_set='ImageSets')
    self.writer.create_records()
    self.parser = DatumParser(self.tempdir)

  def tearDown(self):
    rmtree(self.tempdir)

  def test_parser(self):
    file_pattern = self.tempdir + '/train-*.tfrecord'
    dataset = tf.data.TFRecordDataset(tf.data.Dataset.list_files(file_pattern))
    dataset = dataset.map(self.parser.parse_fn).batch(1)
    for batch in dataset:
      self.assertEqual(list(batch.keys()),
                       ['image', 'label_test1', 'label_test2', 'label_test3', 'label_test4'])
      self.assertEqual(batch['image'].shape, [1, 2670, 2870, 3])
      self.assertEqual(batch['label_test1'].numpy(), [1])
Ejemplo n.º 2
0
def _test_create_textjson_records(path):
    """Create a 3-example text 'train' tfrecord split under `path`.

    Writes a dummy `train.json` to a temporary directory, feeds it through
    `TextJsonDatumGenerator` and `TFRecordWriter`, then removes the temporary
    input data.

    Args:
      path: Output directory for the generated tfrecords (created if missing).
    """
    data = {
        1: {
            'text': 'this is text file',
            'label': {
                'polarity': 1
            }
        },
        2: {
            'text': 'this is json file',
            'label': {
                'polarity': 2
            }
        },
        3: {
            'text': 'this is label file',
            'label': {
                'polarity': 0
            }
        },
    }
    tempdir = tempfile.mkdtemp()
    try:
        with open(os.path.join(tempdir, 'train.json'), 'w') as f:
            json.dump(data, f)
        serializer = DatumSerializer('text')
        Path(path).mkdir(parents=True, exist_ok=True)
        # NOTE: the original built a second, unused TextJsonDatumGenerator;
        # a single generator instance is sufficient.
        textjson_gen = text.TextJsonDatumGenerator(tempdir)
        writer = TFRecordWriter(textjson_gen, serializer, path, 'train', 3)
        writer.create_records()
    finally:
        # Always clean up the temporary input dir, even if conversion fails.
        rmtree(tempdir)
Ejemplo n.º 3
0
 def setUp(self):
   """Write a 2-example VOC detection 'train' split and build a parser for it."""
   # Build {class_name: 1-based id}. Use a context manager so the names
   # file handle is closed (the original open(...).read() leaked it).
   with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
     class_map = {
         name: idx + 1 for idx, name in enumerate(names_file.read().splitlines())
     }
   self.tempdir = '/tmp/test/tfrecord_det'
   self.serializer = DatumSerializer('image')
   Path(self.tempdir).mkdir(parents=True, exist_ok=True)
   DET_GEN = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                     gen_config=AttrDict(has_test_annotations=True,
                                                         class_map=class_map))
   gen_kwargs = {'image_set': 'ImageSets'}
   # Detection features that are stored as sparse tensors in the tfrecords.
   sparse_features = [
       'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'labels_difficult'
   ]
   self.writer = TFRecordWriter(DET_GEN,
                                self.serializer,
                                self.tempdir,
                                'train',
                                2,
                                sparse_features=sparse_features,
                                **gen_kwargs)
   self.writer.create_records()
   self.parser = DatumParser(self.tempdir)
Ejemplo n.º 4
0
class TestSegDatumParser(absltest.TestCase):
  """Tests parsing of segmentation tfrecords written by TFRecordWriter."""

  def setUp(self):
    # Scratch output directory, deleted again in tearDown.
    self.tempdir = '/tmp/test/tfrecord_seg'
    self.serializer = DatumSerializer('image')
    Path(self.tempdir).mkdir(parents=True, exist_ok=True)
    seg_generator = image.SegDatumGenerator('tests/dummy_data/seg/voc')
    sparse_keys = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'is_difficult'
    ]
    # Single-example 'train' split followed by a parser over the output dir.
    self.writer = TFRecordWriter(seg_generator,
                                 self.serializer,
                                 self.tempdir,
                                 'train',
                                 1,
                                 sparse_features=sparse_keys,
                                 image_set='ImageSets')
    self.writer.create_records()
    self.parser = DatumParser(self.tempdir)

  def tearDown(self):
    rmtree(self.tempdir)

  def test_parser(self):
    pattern = self.tempdir + '/train-*.tfrecord'
    dataset = tf.data.TFRecordDataset(tf.data.Dataset.list_files(pattern))
    dataset = dataset.map(self.parser.parse_fn).batch(1)
    batch = next(iter(dataset))
    self.assertEqual(list(batch.keys()), ['image', 'label'])
    # Image and its per-pixel label map share the same spatial shape.
    self.assertEqual(batch['image'].shape, [1, 366, 500, 3])
    self.assertEqual(batch['label'].shape, [1, 366, 500, 3])
Ejemplo n.º 5
0
 def test_flush(self, *args):
   """Caches records from the generator, then flushes them to disk."""
   generator, num_examples = args
   self.writer = TFRecordWriter(generator,
                                self.serializer,
                                self.tempdir,
                                'train',
                                num_examples,
                                image_set='ImageSets')
   self.writer.cache_records()
   self.writer.flush()
Ejemplo n.º 6
0
def _test_create_det_records(path):
    """Create VOC detection tfrecords under `path`.

    Writes a 'train' split with 2 examples and a 'val' split with 1 example.

    Args:
      path: Output directory for the generated tfrecords (created if missing).
    """
    # Build {class_name: 1-based id}. Use a context manager so the names
    # file handle is closed (the original open(...).read() leaked it).
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
        class_map = {
            name: idx + 1
            for idx, name in enumerate(names_file.read().splitlines())
        }
    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    det_gen = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                      gen_config=AttrDict(
                                          has_test_annotations=True,
                                          class_map=class_map))
    # Detection features that are stored as sparse tensors in the tfrecords.
    sparse_features = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated',
        'labels_difficult'
    ]
    # One writer per split; only the example count differs between splits.
    for split, num_examples in (('train', 2), ('val', 1)):
        writer = TFRecordWriter(det_gen,
                                serializer,
                                path,
                                split,
                                num_examples,
                                sparse_features=sparse_features,
                                image_set='ImageSets')
        writer.create_records()
Ejemplo n.º 7
0
def _test_create_seg_records(path):
    """Write segmentation tfrecords for the 'train' and 'val' splits under `path`."""
    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    seg_gen = image.SegDatumGenerator('tests/dummy_data/seg/voc')
    # Both splits hold a single example and share identical writer kwargs.
    for split in ('train', 'val'):
        writer = TFRecordWriter(seg_gen, serializer, path, split, 1,
                                image_set='ImageSets')
        writer.create_records()
Ejemplo n.º 8
0
 def setUp(self):
   """Write a one-example classification 'train' split and build its parser."""
   # Fixed scratch directory for the generated records.
   self.tempdir = '/tmp/test/tfrecord_clf'
   self.serializer = DatumSerializer('image')
   Path(self.tempdir).mkdir(parents=True, exist_ok=True)
   clf_generator = image.ClfDatumGenerator('tests/dummy_data/clf')
   sparse_keys = [
       'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'is_difficult'
   ]
   self.writer = TFRecordWriter(clf_generator,
                                self.serializer,
                                self.tempdir,
                                'train',
                                1,
                                sparse_features=sparse_keys,
                                image_set='ImageSets')
   self.writer.create_records()
   self.parser = DatumParser(self.tempdir)
Ejemplo n.º 9
0
def main(_: Any) -> None:
    """Convert an input dataset to tfrecords, one split at a time.

    Loads the conversion config module from `FLAGS.config_path`; an optional
    comma-separated `FLAGS.splits` overrides the config's split list. Records
    for each split are written under `FLAGS.output_path`.
    """
    # Fixed typo in the log message ('comversion' -> 'conversion').
    logging.info('Loading datum conversion configuration file.')
    config = load_module('config', FLAGS.config_path).cnf
    splits = config.splits
    if FLAGS.splits:
        # Command-line override: comma-separated split names.
        splits = FLAGS.splits.split(',')
    generator = config.generator(FLAGS.input_path)
    Path(FLAGS.output_path).mkdir(parents=True, exist_ok=True)
    for split in splits:
        logging.info(f'Creating tfrecord writer for split: {split}.')
        tfr_writer = TFRecordWriter(generator,
                                    config.serializer,
                                    FLAGS.output_path,
                                    split,
                                    config.num_examples.get(split),
                                    sparse_features=config.sparse_features,
                                    **config.gen_kwargs)
        logging.info('Starting conversion process.')
        tfr_writer.create_records()
        logging.info(f'Completed tfrecord conversion for input split: {split}')
Ejemplo n.º 10
0
class TestDetDatumParser(absltest.TestCase):
  """Tests parsing of VOC detection tfrecords written by TFRecordWriter."""

  def setUp(self):
    """Write a 2-example detection 'train' split and build a parser over it."""
    # Build {class_name: 1-based id}. Use a context manager so the names
    # file handle is closed (the original open(...).read() leaked it).
    with open('tests/dummy_data/det/voc/voc2012.names') as names_file:
      class_map = {
          name: idx + 1 for idx, name in enumerate(names_file.read().splitlines())
      }
    self.tempdir = '/tmp/test/tfrecord_det'
    self.serializer = DatumSerializer('image')
    Path(self.tempdir).mkdir(parents=True, exist_ok=True)
    DET_GEN = image.DetDatumGenerator('tests/dummy_data/det/voc',
                                      gen_config=AttrDict(has_test_annotations=True,
                                                          class_map=class_map))
    gen_kwargs = {'image_set': 'ImageSets'}
    # Detection features that are stored as sparse tensors in the tfrecords.
    sparse_features = [
        'xmin', 'xmax', 'ymin', 'ymax', 'labels', 'pose', 'is_truncated', 'labels_difficult'
    ]
    self.writer = TFRecordWriter(DET_GEN,
                                 self.serializer,
                                 self.tempdir,
                                 'train',
                                 2,
                                 sparse_features=sparse_features,
                                 **gen_kwargs)
    self.writer.create_records()
    self.parser = DatumParser(self.tempdir)

  def tearDown(self):
    rmtree(self.tempdir)

  def test_parser(self):
    dataset = tf.data.TFRecordDataset(tf.data.Dataset.list_files(self.tempdir + '/train-*.tfrecord'))
    dataset = dataset.map(self.parser.parse_fn)
    dataset = dataset.batch(1)
    batch = next(iter(dataset))
    self.assertEqual(list(batch.keys()), [
        'is_truncated', 'labels', 'labels_difficult', 'pose', 'xmax', 'xmin', 'ymax', 'ymin', 'image'
    ])
    self.assertEqual(batch['image'].shape, [1, 500, 486, 3])
    self.assertEqual(batch['xmin'].numpy(), np.asarray([0.3580247], dtype=np.float32))
Ejemplo n.º 11
0
class TestClfTFRecordWriter(absltest.TestCase):
  """Exercises TFRecordWriter caching and flushing for classification data."""

  def setUp(self):
    self.serializer = DatumSerializer('image')
    # Unique scratch dir per test, removed in tearDown.
    self.tempdir = tempfile.mkdtemp()
    Path(self.tempdir).mkdir(parents=True, exist_ok=True)

  def tearDown(self):
    rmtree(self.tempdir)

  def test_cache_records(self, *args):
    generator, num_examples = args
    self.writer = TFRecordWriter(generator,
                                 self.serializer,
                                 self.tempdir,
                                 'train',
                                 num_examples,
                                 image_set='ImageSets')
    self.writer.cache_records()
    # Every generated example should have been cached.
    self.assertEqual(self.writer.current_examples, num_examples)

  def test_flush(self, *args):
    generator, num_examples = args
    self.writer = TFRecordWriter(generator,
                                 self.serializer,
                                 self.tempdir,
                                 'train',
                                 num_examples,
                                 image_set='ImageSets')
    self.writer.cache_records()
    self.writer.flush()
Ejemplo n.º 12
0
def _test_create_clf_records(path):
    """Write classification tfrecords for the 'train' and 'val' splits under `path`."""
    serializer = DatumSerializer('image')
    Path(path).mkdir(parents=True, exist_ok=True)
    clf_gen = image.ClfDatumGenerator('tests/dummy_data/clf')
    # Both splits hold a single example and share identical writer kwargs.
    for split in ('train', 'val'):
        writer = TFRecordWriter(clf_gen,
                                serializer,
                                path,
                                split,
                                1,
                                sparse_features=None,
                                image_set='ImageSets')
        writer.create_records()
Ejemplo n.º 13
0
def export_to_tfrecord(input_path: str, output_path: str, problem_type: str,
                       write_configs: ConfigBase) -> None:
    """Export data to tfrecord format.

  Args:
    input_path: Root path to the input data folder.
    output_path: Path to store output tfrecords and generated metadata.
    problem_type: Type of the problem, see `datum.problem.types` for available
      problems.
    write_configs: Write configuration with the following attributes:
      generator: Generator class.
      serializer: Serializer instance.
      splits: A dict mapping split names to split attributes. Supported split
        attributes:
          num_examples: Number of examples in the split.
          extension: Input image extension for image data; all inputs must
            share the same extension. Default for image data: `.jpg`.
          image_dir: Directory containing the data, used for image
            classification. Default: the split name (e.g. `train`); for
            detection `JPEGImages`, as per the VOC12 folder structure.
          csv_path: Path to the ground-truth csv file, used for
            classification. Default: split name with a `.csv` extension,
            e.g. `train.csv`.
          set_dir: Image-set directory for VOC12-style datasets.
            Default: `ImageSets`, as per the VOC12 folder structure.
          annotation_dir: Directory with annotations, used for VOC12-style
            datasets. Default: `Annotations`, as per the VOC12 folder
            structure.
          label_dir: Directory with label images, used in segmentation.
            Default: `SegmentationClass`, as per the VOC12 folder structure.
          image_extension: Extension of input images, used in segmentation.
            Default: `.jpg`.
          label_extension: Extension of label images, used in segmentation.
            Default: `.png`.

  Raises:
    ValueError: If the splits information is not in dict format.
  """
    # Detection needs the class-names file to build its label map.
    label_names_file = (os.path.join(input_path, "classes.names")
                        if problem_type == types.IMAGE_DET else None)
    # Merge the caller's overrides onto the problem-type defaults.
    write_configs = get_default_write_configs(
        problem_type, label_names_file).merge(write_configs)
    splits = write_configs.splits
    if not isinstance(splits, dict):
        raise ValueError(
            f"Splits must be a dict in the input config: {write_configs}")
    generator = write_configs.generator(input_path)
    Path(output_path).mkdir(parents=True, exist_ok=True)
    for split, split_kwargs in splits.items():
        logging.info(f'Creating tfrecord writer for split: {split}.')
        tfr_writer = TFRecordWriter(
            generator,
            write_configs.serializer,
            output_path,
            split,
            split_kwargs["num_examples"],
            sparse_features=write_configs.sparse_features,
            **split_kwargs)
        logging.info('Starting conversion process.')
        tfr_writer.create_records()
        logging.info(f'Completed tfrecord conversion for input split: {split}')