Example #1
    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        _test_create_det_records(self.tempdir)
        configs = DatasetConfigs()
        configs.batch_size_train = 1
        configs.batch_size_val = 1
        self._dataset = Dataset(self.tempdir, configs)
Example #2
class TestClfDataset(absltest.TestCase):
    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        _test_create_clf_records(self.tempdir)
        configs = DatasetConfigs()
        configs.batch_size_train = 1
        configs.batch_size_val = 1
        self._dataset = Dataset(self.tempdir, configs)

    def tearDown(self):
        rmtree(self.tempdir)

    def test_train_fn(self):
        ds = self._dataset.train_fn('train', False)
        batch = next(iter(ds))
        self.assertEqual(batch['image'].shape, [1, 2670, 2870, 3])
        self.assertEqual(batch['label_test1'], 1)

    def test_train_fn_shuffle(self):
        ds = self._dataset.train_fn('train', True)
        batch = next(iter(ds))
        self.assertEqual(batch['image'].shape, [1, 2670, 2870, 3])
        self.assertEqual(batch['label_test1'], 1)

    def test_val_fn(self):
        ds = self._dataset.train_fn('val', False)
        batch = next(iter(ds))
        self.assertEqual(batch['image'].shape, [1, 2670, 2870, 3])
        self.assertEqual(batch['label_test1'], 1)

    def test_padded_shapes(self):
        exp = {
            'image': [None] * 3,
            'label_test1': [],
            'label_test2': [],
            'label_test3': [],
            'label_test4': []
        }
        self.assertEqual(self._dataset.padded_shapes, exp)

    def test_dataset_configs_prop(self):
        configs = self._dataset.dataset_configs
        self.assertEqual(configs.batch_size_train, 1)
        self.assertEqual(configs.batch_size_val, 1)
        configs.batch_size_train = 16
        configs.batch_size_val = 16
        self._dataset.dataset_configs = configs
        configs = self._dataset.dataset_configs
        self.assertEqual(configs.batch_size_train, 16)
        self.assertEqual(configs.batch_size_val, 16)
Example #3
class TestSegDataset(absltest.TestCase):
    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        _test_create_seg_records(self.tempdir)
        configs = DatasetConfigs()
        configs.batch_size_train = 1
        configs.batch_size_val = 1
        self._dataset = Dataset(self.tempdir, configs)

    def tearDown(self):
        rmtree(self.tempdir)

    def test_train_fn(self):
        ds = self._dataset.train_fn('train', False)
        batch = next(iter(ds))
        self.assertEqual(batch['image'].shape, [1, 281, 500, 3])
        self.assertEqual(batch['label'].shape, [1, 281, 500, 3])
Example #4
class TestTextJsonDataset(absltest.TestCase):
    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        _test_create_textjson_records(self.tempdir)
        configs = DatasetConfigs()
        configs.batch_size_train = 3
        self._dataset = Dataset(self.tempdir, configs)

    def tearDown(self):
        rmtree(self.tempdir)

    def test_train_fn(self):
        ds = self._dataset.train_fn('train', False)
        batch = next(iter(ds))
        self.assertEqual(batch['text'].shape, [3])
        self.assertEqual(batch['polarity'].shape, [3])
        self.assertTrue(np.array_equal(batch['polarity'].numpy(), [1, 2, 0]))
        self.assertEqual(list(batch['text'].numpy()), [
            b'this is label file', b'this is json file', b'this is text file'
        ])
Example #5
class TestDetDataset(absltest.TestCase):
    def setUp(self):
        self.tempdir = tempfile.mkdtemp()
        _test_create_det_records(self.tempdir)
        configs = DatasetConfigs()
        configs.batch_size_train = 1
        configs.batch_size_val = 1
        self._dataset = Dataset(self.tempdir, configs)

    def tearDown(self):
        rmtree(self.tempdir)

    def test_train_fn(self):
        ds = self._dataset.train_fn('train', False)
        batch = next(iter(ds))
        self.assertEqual(batch['image'].shape, [1, 281, 500, 3])
        self.assertTrue(
            np.array_equal(
                batch['xmin'].numpy(),
                np.array([0.208, 0.266, 0.39, 0.052], dtype=np.float32)))
        self.assertTrue(
            np.array_equal(
                batch['pose'].numpy(),
                np.asarray([[b'frontal', b'left', b'rear', b'rear']],
                           dtype=np.bytes_)))
Example #6
def load(path: str,
         dataset_configs: Optional[ConfigBase] = None) -> DatasetType:
    """Load tfrecord dataset as `tf.data.Daatset`.

  Args:
    path: path to the storage location with tfrecord files and metadata.
    dataset_configs: A DatasetConfigs object used to control the parameters of
      the output tf.data.Dataset. This is designed to give the end-user extensive
      control over the dataset pre- and post-processing operations.

    dataset_configs has the following configurable attributes.
      buffer_size: Representing the number of elements from this dataset from which the
           new dataset will sample, default: 100.
      seed: Random seed for tfrecord files based randomness, default: 6052020.
      full_dataset: Returns the dataset as a single batch for datasets with only one element,
        default: False.
      batch_size_train: Batch size for training data, default: 32.
      batch_size_val: Batch size for validation data, default: 32.
      batch_size_test: Batch size for test data, default: 32.
      shuffle_files: Shuffle tfrecord input files, default: True.
      reshuffle_each_iteration: If true, indicates that the dataset should be pseudorandomly
        reshuffled each time it is iterated over, default: False.
      cache: If true, the first time the dataset is iterated over its elements will be cached
        either in the specified file or in memory. Subsequent iterations will use the
        cached data, default: False.
      cache_filename: Representing the name of a directory on the file system to use for caching
        elements in this Dataset, default: ''.
      bucket_op: The sequence-length-based bucketing operation options.
        It has the following sub-attributes:
        bucket_boundaries: Upper length boundaries of the buckets, default: [0].
        bucket_batch_sizes: Batch size per bucket. Length should be len(bucket_boundaries) + 1,
         default: [32, 32].
      bucket_fn: Function from an element in the Dataset to tf.int32; determines the length of
        the element, which in turn determines the bucket it goes into, default: None.
      pre_batching_callback_train: Preprocessing operation to use on a single case of the dataset
        before batching, default: None.
      post_batching_callback_train: Processing operation to use on a batch of the dataset after
        batching, default: None.
      pre_batching_callback_val: Preprocessing operation to use on a single case of the dataset
        before batching, default: None.
      post_batching_callback_val: Processing operation to use on a batch of the dataset after
        batching, default: None.
      pre_batching_callback_test: Preprocessing operation to use on a single case of the dataset
        before batching, default: None.
      post_batching_callback_test: Processing operation to use on a batch of the dataset after
        batching, default: None.
      read_config: A TFRReadconfigs object used to control the parameters required to read
        tfrecord files to construct a tf.data.Dataset.
        It supports the following sub-attributes:
        experimental_interleave_sort_fn: Dataset interleave sort function, default: None.
        shuffle_reshuffle_each_iteration: Shuffle files each iteration before reading,
          default: True.
        interleave_cycle_length: The number of input elements that will be processed concurrently,
          default: -1.
        interleave_block_length: The number of consecutive elements to produce from each input
          element before cycling to another input element, default: 1.
        seed: Random seed for tfrecord files based randomness, default: 6052020.
        options: TensorFlow data API options for dataset prep and reading,
          default: tf.data.Options().

  Returns:
    A `tf.data.Dataset`.
  """
    if not tf.io.gfile.isdir(path):
        message = (f'Input path: {path} is not a directory. Provide a path where '
                   'tfrecord files and metadata are stored.')
        logging.error(message)
        raise ValueError(message)
    if not dataset_configs:
        logging.info('Using default dataset configuration.')
        dataset_configs = DatasetConfigs()
    return Dataset(path, dataset_configs)
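
A minimal usage sketch for `load` follows. The path `./records` is hypothetical, the config attribute names come from the DatasetConfigs options documented in the docstring above, and the `image` key assumes an image dataset like the ones exercised in the tests above.

# Minimal usage sketch (hypothetical ./records directory that already holds
# the tfrecord files and metadata produced by the dataset creation step).
configs = DatasetConfigs()
configs.batch_size_train = 8   # documented default is 32
configs.shuffle_files = True   # shuffle the input tfrecord files
configs.cache = True           # cache elements in memory after the first pass

dataset = load('./records', dataset_configs=configs)
train_ds = dataset.train_fn('train', True)  # shuffled training split

batch = next(iter(train_ds))
print(batch['image'].shape)  # assumes an image dataset, as in the tests above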