Example #1
    def testArbitraryReaderFunc(self):
        def MakeRecord(i, j):
            return compat.as_bytes('%04d-%04d' % (i, j))

        record_bytes = len(MakeRecord(10, 200))

        all_contents = []
        for i in range(_NUM_FILES):
            filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
            with open(filename, 'wb') as f:
                for j in range(_NUM_ENTRIES):
                    record = MakeRecord(i, j)
                    f.write(record)
                    all_contents.append(record)

        def FixedLengthFile(filename):
            return readers.FixedLengthRecordDataset(filename, record_bytes)

        dataset = datasets.StreamingFilesDataset(
            os.path.join(self.get_temp_dir(), 'fixed_length*'),
            filetype=FixedLengthFile)

        iterator = dataset.make_initializable_iterator()
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = []
        for _ in range(4 * len(all_contents)):
            retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

        self.assertEqual(set(all_contents), set(retrieved_values))
Example #2
    def testTextLineDataset(self):
        all_contents = []
        for i in range(_NUM_FILES):
            filename = os.path.join(self.get_temp_dir(),
                                    'text_line.%d.txt' % i)
            contents = []
            for j in range(_NUM_ENTRIES):
                contents.append(compat.as_bytes('%d: %d' % (i, j)))
            with open(filename, 'wb') as f:
                f.write(b'\n'.join(contents))
            all_contents.extend(contents)

        dataset = datasets.StreamingFilesDataset(
            os.path.join(self.get_temp_dir(), 'text_line.*.txt'),
            filetype='text')

        iterator = dataset.make_initializable_iterator()
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = []
        for _ in range(4 * len(all_contents)):
            retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

        self.assertEqual(set(all_contents), set(retrieved_values))
Example #3
    def testTFRecordDatasetFromDataset(self):
        filenames = []
        all_contents = []
        for i in range(_NUM_FILES):
            filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
            filenames.append(filename)
            writer = python_io.TFRecordWriter(filename)
            for j in range(_NUM_ENTRIES):
                record = compat.as_bytes('Record %d of file %d' % (j, i))
                writer.write(record)
                all_contents.append(record)
            writer.close()

        filenames = dataset_ops.Dataset.from_tensor_slices(filenames)

        dataset = datasets.StreamingFilesDataset(filenames,
                                                 filetype='tfrecord')

        iterator = dataset.make_initializable_iterator()
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = []
        for _ in range(4 * len(all_contents)):
            retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

        self.assertEqual(set(all_contents), set(retrieved_values))
Example #4
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import datasets as tpu_datasets


def tfrecord_ds(file_pattern: str,
                parser,
                batch_size: int,
                training: bool = True,
                shuffle_buf_sz: int = 50000,
                n_cores: int = 2,
                n_folds: int = 1,
                val_fold_idx: int = 0,
                streaming: bool = False) -> tf.data.Dataset:
    """
    Create a `tf.data` input pipeline from TFRecord files whose names satisfy a given pattern. Optionally
    partitions the data into training and validation sets according to k-fold cross-validation requirements.

    :param file_pattern: file pattern such as `data_train*.tfrec`
    :param parser: TFRecords parser function, which may also perform data augmentations.
    :param batch_size: Size of a data batch.
    :param training: Whether this is a training dataset, in which case the dataset is randomly shuffled and repeated.
    :param shuffle_buf_sz: Shuffle buffer size, for shuffling a training dataset. Default: 50k records.
    :param n_cores: Number of CPU cores, i.e., parallel threads.
    :param n_folds: Number of cross-validation folds. Default: 1, meaning no cross-validation.
    :param val_fold_idx: Index of the fold used as the validation set in cross-validation. Ignored when `n_folds` is 1.
    :param streaming: Whether to read files via `StreamingFilesDataset` (under construction).
    :return: a `tf.data` dataset configured as described above.
    """
    if streaming:
        # under construction
        dataset = tpu_datasets.StreamingFilesDataset(
            file_pattern, filetype='tfrecord', batch_transfer_size=batch_size)
    else:
        dataset = tf.data.Dataset.list_files(file_pattern)
        # `tfrecord_fetch_dataset` (assumed to be defined elsewhere in this
        # module) maps a filename to a `tf.data.TFRecordDataset`; files are
        # interleaved in parallel, in sloppy (non-deterministic) order.
        fetcher = tf.data.experimental.parallel_interleave(
            tfrecord_fetch_dataset, cycle_length=n_cores, sloppy=True)
        dataset = dataset.apply(fetcher)

    mapper_batcher = tf.data.experimental.map_and_batch(
        parser,
        batch_size=batch_size,
        num_parallel_batches=n_cores,
        drop_remainder=True)

    if n_folds > 1:
        # `crossval_ds` (assumed to be defined elsewhere in this module)
        # partitions records into training/validation folds.
        dataset = crossval_ds(dataset, n_folds, val_fold_idx, training)

    if training:
        dataset = dataset.shuffle(shuffle_buf_sz)
        dataset = dataset.repeat()

    dataset = dataset.apply(mapper_batcher)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
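
A minimal usage sketch for `tfrecord_ds`, assuming a hypothetical parser and file pattern (the feature spec, names, and values below are illustrative, not part of the original module):

def parse_example(serialized):
    # Hypothetical parser: decode one float feature per serialized
    # tf.train.Example record.
    features = {'x': tf.FixedLenFeature([], tf.float32)}
    return tf.parse_single_example(serialized, features)

train_ds = tfrecord_ds('data_train*.tfrec', parse_example,
                       batch_size=128, training=True)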
Example #5
    def testArbitraryReaderFuncFromDatasetGenerator(self):
        def my_generator():
            yield (1, [1] * 10)

        def gen_dataset(dummy):
            return dataset_ops.Dataset.from_generator(
                my_generator, (dtypes.int64, dtypes.int64),
                (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))

        dataset = datasets.StreamingFilesDataset(dataset_ops.Dataset.range(10),
                                                 filetype=gen_dataset)

        iterator = dataset.make_initializable_iterator()
        self._sess.run(iterator.initializer)
        get_next = iterator.get_next()

        retrieved_values = self._sess.run(get_next)

        self.assertIsInstance(retrieved_values, (list, tuple))
        self.assertEqual(len(retrieved_values), 2)
        self.assertEqual(retrieved_values[0], 1)
        self.assertItemsEqual(retrieved_values[1], [1] * 10)
Example #6
    def testUnexpectedFilesType(self):
        with self.assertRaises(ValueError):
            datasets.StreamingFilesDataset(123, filetype='tfrecord')
Example #7
    def testUnexpectedFiletypeType(self):
        with self.assertRaises(ValueError):
            datasets.StreamingFilesDataset(
                os.path.join(self.get_temp_dir(), '*'), filetype=3)
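
For contrast with the two error cases above, a short sketch of the argument combinations the earlier examples exercise successfully (paths here are illustrative): `files` may be a glob string or a `Dataset` of filenames, and `filetype` may be a supported string such as 'text' or 'tfrecord', or a callable mapping a filename to a `Dataset`.

# Glob string plus a known filetype string (cf. Example #2).
datasets.StreamingFilesDataset('/tmp/text_line.*.txt', filetype='text')

# Dataset of filenames plus a filetype string (cf. Example #3).
names = dataset_ops.Dataset.from_tensor_slices(['/tmp/tf_record.0'])
datasets.StreamingFilesDataset(names, filetype='tfrecord')

# Callable filetype building a reader per filename (cf. Example #1).
datasets.StreamingFilesDataset(
    '/tmp/fixed_length*',
    filetype=lambda f: readers.FixedLengthRecordDataset(f, 16))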