def testArbitraryReaderFunc(self):

  def MakeRecord(i, j):
    return compat.as_bytes('%04d-%04d' % (i, j))

  record_bytes = len(MakeRecord(10, 200))

  all_contents = []
  for i in range(_NUM_FILES):
    filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
    with open(filename, 'wb') as f:
      for j in range(_NUM_ENTRIES):
        record = MakeRecord(i, j)
        f.write(record)
        all_contents.append(record)

  def FixedLengthFile(filename):
    return readers.FixedLengthRecordDataset(filename, record_bytes)

  dataset = datasets.StreamingFilesDataset(
      os.path.join(self.get_temp_dir(), 'fixed_length*'),
      filetype=FixedLengthFile)

  iterator = dataset.make_initializable_iterator()
  self._sess.run(iterator.initializer)
  get_next = iterator.get_next()

  retrieved_values = []
  for _ in range(4 * len(all_contents)):
    retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

  self.assertEqual(set(all_contents), set(retrieved_values))

def testTextLineDataset(self):
  all_contents = []
  for i in range(_NUM_FILES):
    filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i)
    contents = []
    for j in range(_NUM_ENTRIES):
      contents.append(compat.as_bytes('%d: %d' % (i, j)))
    with open(filename, 'wb') as f:
      f.write(b'\n'.join(contents))
    all_contents.extend(contents)

  dataset = datasets.StreamingFilesDataset(
      os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text')

  iterator = dataset.make_initializable_iterator()
  self._sess.run(iterator.initializer)
  get_next = iterator.get_next()

  retrieved_values = []
  for _ in range(4 * len(all_contents)):
    retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

  self.assertEqual(set(all_contents), set(retrieved_values))

def testTFRecordDatasetFromDataset(self):
  filenames = []
  all_contents = []
  for i in range(_NUM_FILES):
    filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
    filenames.append(filename)
    writer = python_io.TFRecordWriter(filename)
    for j in range(_NUM_ENTRIES):
      record = compat.as_bytes('Record %d of file %d' % (j, i))
      writer.write(record)
      all_contents.append(record)
    writer.close()

  filenames = dataset_ops.Dataset.from_tensor_slices(filenames)

  dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord')

  iterator = dataset.make_initializable_iterator()
  self._sess.run(iterator.initializer)
  get_next = iterator.get_next()

  retrieved_values = []
  for _ in range(4 * len(all_contents)):
    retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))

  self.assertEqual(set(all_contents), set(retrieved_values))

def tfrecord_ds(file_pattern: str, parser, batch_size: int,
                training: bool = True, shuffle_buf_sz: int = 50000,
                n_cores: int = 2, n_folds: int = 1, val_fold_idx: int = 0,
                streaming: bool = False) -> tf.data.Dataset:
    """Create a `tf.data` input pipeline from TFRecord files whose names match a given pattern.

    Optionally partitions the data into training and validation sets according
    to k-fold cross-validation requirements.

    :param file_pattern: File pattern such as `data_train*.tfrec`.
    :param parser: TFRecord parser function, which may also perform data augmentation.
    :param batch_size: Size of a data batch.
    :param training: Whether this is a training dataset, in which case the
        dataset is randomly shuffled and repeated.
    :param shuffle_buf_sz: Shuffle buffer size, for shuffling a training
        dataset. Default: 50k records.
    :param n_cores: Number of CPU cores, i.e., parallel threads.
    :param n_folds: Number of cross-validation folds. Default: 1, meaning no
        cross-validation.
    :param val_fold_idx: Fold index of the validation set, in cross-validation.
        Ignored when `n_folds` is 1.
    :param streaming: Under construction.
    :return: A `tf.data` dataset satisfying the above description.
    """
    if streaming:  # under construction
        dataset = tpu_datasets.StreamingFilesDataset(
            file_pattern, filetype='tfrecord', batch_transfer_size=batch_size)
    else:
        dataset = tf.data.Dataset.list_files(file_pattern)
        fetcher = tf.data.experimental.parallel_interleave(
            tfrecord_fetch_dataset, cycle_length=n_cores, sloppy=True)
        dataset = dataset.apply(fetcher)

    mapper_batcher = tf.data.experimental.map_and_batch(
        parser, batch_size=batch_size, num_parallel_batches=n_cores,
        drop_remainder=True)

    if n_folds > 1:
        dataset = crossval_ds(dataset, n_folds, val_fold_idx, training)

    if training:
        dataset = dataset.shuffle(shuffle_buf_sz)
        dataset = dataset.repeat()

    dataset = dataset.apply(mapper_batcher)
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset

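# A minimal usage sketch for `tfrecord_ds`, not part of the pipeline above. The
# parser `parse_example`, the feature names 'image'/'label', and the file
# pattern 'train*.tfrec' are hypothetical; the sketch assumes the TF 1.x API
# that `tfrecord_ds` itself relies on (tf.contrib is present).
def _example_usage():  # hypothetical helper, illustration only
    def parse_example(serialized):
        # Assumed record layout: a raw image byte string and an int64 label.
        features = tf.parse_single_example(
            serialized,
            {'image': tf.FixedLenFeature([], tf.string),
             'label': tf.FixedLenFeature([], tf.int64)})
        image = tf.decode_raw(features['image'], tf.uint8)
        return image, features['label']

    # 5-fold cross-validation over the same file pattern: fold 0 is held out
    # for validation, the remaining folds form the (shuffled, repeated)
    # training set.
    train_ds = tfrecord_ds('train*.tfrec', parse_example, batch_size=128,
                           training=True, n_folds=5, val_fold_idx=0)
    val_ds = tfrecord_ds('train*.tfrec', parse_example, batch_size=128,
                         training=False, n_folds=5, val_fold_idx=0)
    return train_ds, val_ds
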
def testArbitraryReaderFuncFromDatasetGenerator(self):

  def my_generator():
    yield (1, [1] * 10)

  def gen_dataset(dummy):
    return dataset_ops.Dataset.from_generator(
        my_generator, (dtypes.int64, dtypes.int64),
        (tensor_shape.TensorShape([]), tensor_shape.TensorShape([10])))

  dataset = datasets.StreamingFilesDataset(
      dataset_ops.Dataset.range(10), filetype=gen_dataset)

  iterator = dataset.make_initializable_iterator()
  self._sess.run(iterator.initializer)
  get_next = iterator.get_next()

  retrieved_values = self._sess.run(get_next)

  self.assertIsInstance(retrieved_values, (list, tuple))
  self.assertEqual(len(retrieved_values), 2)
  self.assertEqual(retrieved_values[0], 1)
  self.assertItemsEqual(retrieved_values[1], [1] * 10)

def testUnexpectedFilesType(self):
  with self.assertRaises(ValueError):
    datasets.StreamingFilesDataset(123, filetype='tfrecord')

def testUnexpectedFiletypeType(self):
  with self.assertRaises(ValueError):
    datasets.StreamingFilesDataset(
        os.path.join(self.get_temp_dir(), '*'), filetype=3)