Example #1
 def testIteratorResourceCleanup(self):
     filename = os.path.join(self.get_temp_dir(), "text.txt")
     with open(filename, "wt") as f:
         for i in range(3):
             f.write("%d\n" % (i, ))
     with context.eager_mode():
         first_iterator = iter(readers.TextLineDataset(filename))
         self.assertEqual(b"0", next(first_iterator).numpy())
         second_iterator = iter(readers.TextLineDataset(filename))
         self.assertEqual(b"0", next(second_iterator).numpy())
         # Eager kernel caching is based on op attributes, which includes the
         # Dataset's output shape. Create a different kernel to test that they
         # don't create resources with the same names.
         different_kernel_iterator = iter(
             readers.TextLineDataset(filename).repeat().batch(16))
         self.assertEqual([16], next(different_kernel_iterator).shape)
         # Remove our references to the Python Iterator objects, which (assuming no
         # reference cycles) is enough to trigger DestroyResourceOp and close the
         # partially-read files.
         del first_iterator
         del second_iterator
         del different_kernel_iterator
         if not psutil_import_succeeded:
             self.skipTest(
                 "psutil is required to check that we've closed our files.")
         open_files = psutil.Process().open_files()
         self.assertNotIn(filename,
                          [open_file.path for open_file in open_files])
Example #2
 def interleave_fn(filename):
     # Test function that uses control flow. The true branch is never taken.
     concat = string_ops.string_join([filename, "abc"])
     return control_flow_ops.cond(
         math_ops.equal(filename, "abc"),
         lambda: reader_ops.TextLineDataset(concat),
         lambda: reader_ops.TextLineDataset(filename))
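
A function like this is meant to be passed to `Dataset.interleave`, which maps each filename to its own `TextLineDataset` and interleaves the resulting line streams. A minimal sketch using the public `tf.data` API (file names are hypothetical):

import tensorflow as tf

# One dataset element per input file.
filenames = tf.data.Dataset.from_tensor_slices(["a.txt", "b.txt"])

# cycle_length controls how many files are read from concurrently;
# AUTOTUNE lets the runtime pick the parallelism.
lines = filenames.interleave(
    lambda f: tf.data.TextLineDataset(f),
    cycle_length=2,
    num_parallel_calls=tf.data.AUTOTUNE)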
Example #3
 def testName(self):
     files = self._createFiles(1, 5)
     expected_output = [self._lineText(0, i) for i in range(5)]
     ds = readers.TextLineDataset(files, name="text_line_dataset")
     self.assertDatasetProduces(ds,
                                expected_output=expected_output,
                                assert_items_equal=True)
Example #4
 def filename_to_dataset(filename):
     ds = core_readers.TextLineDataset(filename)
     if header:
         ds = ds.skip(1)
     if comment is not None:
         ds = ds.filter(filter_fn)
     return ds
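
A self-contained sketch of how a `filename_to_dataset` helper like the one above could be wired up, assuming a hypothetical `filter_fn` that drops comment lines and `flat_map` to chain the per-file datasets (file names and the comment character are illustrative):

import tensorflow as tf

header = True
comment = "#"

def filter_fn(line):
    # Keep lines whose first character is not the comment marker.
    return tf.not_equal(tf.strings.substr(line, 0, 1), comment)

def filename_to_dataset(filename):
    ds = tf.data.TextLineDataset(filename)
    if header:
        ds = ds.skip(1)  # drop the header row
    if comment is not None:
        ds = ds.filter(filter_fn)  # drop comment lines
    return ds

lines = tf.data.Dataset.from_tensor_slices(
    ["train.csv", "eval.csv"]).flat_map(filename_to_dataset)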
Example #5
 def testFileNamesDatasetMustContainStrings(self):
   with self.assertRaisesRegex(
       TypeError,
       "The `filenames` argument must contain `tf.string` elements. Got a "
       "dataset of `tf.int32` elements."):
     filenames = dataset_ops.Dataset.from_tensors(0)
     readers.TextLineDataset(filenames)
Example #6
 def filename_to_dataset(filename):
   ds = core_readers.TextLineDataset(filename)
   if skip > 0:
     ds = ds.skip(skip)
   if filter_fn is not None:
     ds = ds.filter(filter_fn)
   return ds
Example #7
  def testTextLineReader(self):
    dataset = readers.TextLineDataset(self._createTextFiles())

    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    self._verifySimpleShardingOutput(dataset, self._text_line)
Example #8
 def dataset_fn(filenames, num_epochs, batch_size=None):
     repeat_dataset = readers.TextLineDataset(
         filenames,
         compression_type=compression_type).repeat(num_epochs)
     if batch_size:
         return repeat_dataset.batch(batch_size)
     return repeat_dataset
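
The `compression_type` argument lets the same pipeline read plain, ZLIB-compressed, or GZIP-compressed text. A minimal sketch that writes a small gzip file and reads it back (the path is illustrative):

import gzip
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "lines.txt.gz")
with gzip.open(path, "wt") as f:
    f.write("first\nsecond\n")

dataset = tf.data.TextLineDataset(path, compression_type="GZIP").repeat(2)
for line in dataset:
    print(line.numpy())  # b'first', b'second', b'first', b'second'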
Example #9
 def testFileNamesMustBeScalars(self):
   with self.assertRaisesRegex(
       TypeError,
       "The `filenames` argument must contain `tf.string` elements of shape "
       r"\[\] \(i.e. scalars\)."):
     filenames = dataset_ops.Dataset.from_tensors([["File 1", "File 2"],
                                                   ["File 3", "File 4"]])
     readers.TextLineDataset(filenames)
Example #10
  def testPathlib(self):
    files = self._createFiles(1, 5)
    files = [pathlib.Path(f) for f in files]

    expected_output = [self._lineText(0, i) for i in range(5)]
    ds = readers.TextLineDataset(files)
    self.assertDatasetProduces(
        ds, expected_output=expected_output, assert_items_equal=True)
Example #11
 def _make_test_datasets(self, inputs, **kwargs):
     # Test by comparing its output to what we could get with map->decode_csv
     filenames = self._setup_files(inputs)
     dataset_expected = core_readers.TextLineDataset(filenames)
     dataset_expected = dataset_expected.map(
         lambda l: parsing_ops.decode_csv(l, **kwargs))
     dataset_actual = readers.CsvDataset(filenames, **kwargs)
     return (dataset_actual, dataset_expected)
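
The equivalence this helper relies on can also be expressed with the public API: parsing a CSV with `CsvDataset` or with `TextLineDataset` plus `tf.io.decode_csv` should produce the same records. A sketch assuming a two-column float CSV at a hypothetical path:

import tensorflow as tf

record_defaults = [[0.0], [0.0]]  # two float columns

# Fused reader and parser.
fused = tf.data.experimental.CsvDataset("data.csv", record_defaults)

# Equivalent pipeline over raw text lines.
mapped = tf.data.TextLineDataset("data.csv").map(
    lambda line: tf.io.decode_csv(line, record_defaults))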
Example #12
  def testBuffering(self):
    test_filenames = self._createFiles(2, 5, crlf=True)

    repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
    expected_output = []
    for j in range(2):
      expected_output.extend([self._lineText(j, i) for i in range(5)])
    self.assertDatasetProduces(repeat_dataset, expected_output=expected_output)
Example #13
 def benchmarkMapWithStrings(self):
   self._setUp(self.STR_VAL)
   for i in range(len(self._filenames)):
     num_cols = self._num_cols[i]
     kwargs = {'record_defaults': [['']] * num_cols}
     dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
     dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs))  # pylint: disable=cell-var-from-loop
     self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
   self._tearDown()
Example #14
    def testDirectFilenameTextLineReaderPipeline(self):
        dataset = core_readers.TextLineDataset(self.test_filenames)
        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"%d: %d" % (f, r)  # pylint:disable=g-complex-comprehension
            for f in (0, 5) for r in range(0, 10)
        ]
        self.assertDatasetProduces(dataset, expected)
Example #15
    def testZip(self):
        dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset2 = readers.TextLineDataset(self._createTextFiles())
        dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f))
        self._verifySimpleShardingOutput(dataset, record_fn)
Example #16
 def benchmarkCsvDatasetWithStrings(self):
   self._setUp(self.STR_VAL)
   for i in range(len(self._filenames)):
     num_cols = self._num_cols[i]
     kwargs = {'record_defaults': [['']] * num_cols}
     dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
     dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
     self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
   self._tearDown()
Example #17
 def benchmark_csv_dataset_with_floats(self):
     self._set_up(self.FLOAT_VAL)
     for i in range(len(self._filenames)):
         num_cols = self._num_cols[i]
         kwargs = {'record_defaults': [[0.0]] * num_cols}
         dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
         dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
         self._run_benchmark(dataset, num_cols, 'csv_float_fused_dataset')
     self._tear_down()
Example #18
 def testParallelRead(self):
   test_filenames = self._createFiles(10, 10)
   files = dataset_ops.Dataset.from_tensor_slices(test_filenames).repeat(10)
   expected_output = []
   for j in range(10):
     expected_output.extend(self._lineText(j, i) for i in range(10))
   dataset = readers.TextLineDataset(files, num_parallel_reads=4)
   self.assertDatasetProduces(
       dataset, expected_output=expected_output * 10, assert_items_equal=True)
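
When `filenames` is itself a dataset, as above, `num_parallel_reads` opens several files concurrently and interleaves their lines, which is why the test compares outputs with `assert_items_equal=True` rather than expecting a fixed order. A minimal sketch with hypothetical paths:

import tensorflow as tf

files = tf.data.Dataset.from_tensor_slices(["a.txt", "b.txt", "c.txt"])
# Up to four files are read at the same time; lines from different files interleave.
lines = tf.data.TextLineDataset(files, num_parallel_reads=4)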
Example #19
 def benchmarkCsvDataset(self):
     self._setUp()
     for i in range(len(self._filenames)):
         num_cols = self._num_cols[i]
         kwargs = {'record_defaults': [[0.0]] * num_cols}
         dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
         dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
         dataset = dataset.batch(self._batch_size)
         self._runBenchmark(dataset, num_cols, 'csv_fused_dataset')
     self._tearDown()
Example #20
 def benchmark_map_with_floats(self):
     self._set_up(self.FLOAT_VAL)
     for i in range(len(self._filenames)):
         num_cols = self._num_cols[i]
         kwargs = {'record_defaults': [[0.0]] * num_cols}
         dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
         dataset = dataset.map(
             lambda l: parsing_ops.decode_csv(l, **kwargs))  # pylint: disable=cell-var-from-loop
         self._run_benchmark(dataset, num_cols, 'csv_float_map_decode_csv')
     self._tear_down()
Example #21
 def benchmarkBatchThenMap(self):
     self._setUp()
     for i in range(len(self._filenames)):
         num_cols = self._num_cols[i]
         kwargs = {'record_defaults': [[0.0]] * num_cols}
         dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
         dataset = dataset.map(
             lambda l: gen_parsing_ops.decode_csv(l, **kwargs))  # pylint: disable=cell-var-from-loop
         dataset = dataset.batch(self._batch_size)
         self._runBenchmark(dataset, num_cols, 'csv_map_then_batch')
     self._tearDown()
Example #22
 def benchmark_csv_dataset_with_strings(self):
     self._set_up(self.STR_VAL)
     for i in range(len(self._filenames)):
         num_cols = self._num_cols[i]
         kwargs = {'record_defaults': [['']] * num_cols}
         dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
         dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat()  # pylint: disable=cell-var-from-loop
         self._run_benchmark(dataset=dataset,
                             num_cols=num_cols,
                             prefix='csv_strings_fused_dataset',
                             benchmark_id=4)
     self._tear_down()
Example #23
    def test_from_file(self):
        vocabulary_file = self._createVocabFile("test.txt",
                                                ("one", "two", "three"))
        ds = reader_ops.TextLineDataset(vocabulary_file)
        ds = ds.enumerate(start=1)
        init = lookup_ops.DatasetInitializer(ds)
        table = self.getHashTable()(init, default_value="")
        self.initialize_table(table)

        output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
        result = self.evaluate(output)
        self.assertAllEqual(["two", "three", ""], result)
Example #24
  def testTextLineDatasetBuffering(self):
    test_filenames = self._createFiles(2, 5, crlf=True)

    repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
    iterator = repeat_dataset.make_one_shot_iterator()

    with self.test_session() as sess:
      for j in range(2):
        for i in range(5):
          self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(iterator.get_next())
Example #25
    def __init__(self, filenames, compression_type=None, buffer_size=None):
        """Creates a `TextLineDataset`.

        Args:
          filenames: A `tf.string` tensor containing one or more filenames.
          compression_type: (Optional.) A `tf.string` scalar evaluating to one
            of `""` (no compression), `"ZLIB"`, or `"GZIP"`.
          buffer_size: (Optional.) A `tf.int64` scalar denoting the number of
            bytes to buffer. A value of 0 results in the default buffering
            values chosen based on the compression type.
        """
        dataset = readers.TextLineDataset(filenames, compression_type,
                                          buffer_size)
        super(TextLineDataset, self).__init__(dataset)
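
In the public API the same parameters are exposed directly on `tf.data.TextLineDataset`. A minimal sketch (paths illustrative):

import tensorflow as tf

dataset = tf.data.TextLineDataset(
    ["a.txt", "b.txt"],
    compression_type="",          # "", "ZLIB", or "GZIP"
    buffer_size=8 * 1024 * 1024)  # read-ahead buffer, in bytes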
Example #26
def make_libsvm_dataset(file_names,
                        num_features,
                        dtype=None,
                        label_dtype=None,
                        batch_size=1,
                        compression_type='',
                        buffer_size=None,
                        num_parallel_parser_calls=None,
                        drop_final_batch=False,
                        prefetch_buffer_size=0):
  """Reads LibSVM files into a dataset.

  Args:
    file_names: A `tf.string` tensor containing one or more filenames.
    num_features: The number of features.
    dtype: (Optional.) The type of the output feature tensor. Defaults to
      `tf.float32`.
    label_dtype: (Optional.) The type of the output label tensor. Defaults to
      `tf.int64`.
    batch_size: (Optional.) An int representing the number of records to
      combine in a single batch. Defaults to 1.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`.
    buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
      to buffer. A value of 0 results in the default buffering values chosen
      based on the compression type.
    num_parallel_parser_calls: (Optional.) Number of records to parse in
      parallel. Defaults to an automatic selection.
    drop_final_batch: (Optional.) Whether the last batch should be
      dropped in case its size is smaller than `batch_size`; the
      default behavior is not to drop the smaller batch.
    prefetch_buffer_size: (Optional.) An int specifying the number of feature
      batches to prefetch for performance improvement. Defaults to 0, which
      disables prefetching.
  """
  dataset = core_readers.TextLineDataset(file_names,
                                         compression_type=compression_type, 
                                         buffer_size=buffer_size)
  def parsing_func(content):
    return decode_libsvm(content, num_features, dtype, label_dtype)

  dataset = dataset.apply(batching.map_and_batch(
                                        parsing_func, 
                                        batch_size, 
                                        num_parallel_calls=num_parallel_parser_calls,
                                        drop_remainder=drop_final_batch))
  if prefetch_buffer_size == 0:
    return dataset
  else:
    return dataset.prefetch(buffer_size=prefetch_buffer_size)
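
A usage sketch for the helper above, assuming it is importable and the `decode_libsvm` op is available in the environment (the file name, feature count, and batch size are illustrative):

dataset = make_libsvm_dataset(
    file_names=["train.libsvm"],
    num_features=128,
    batch_size=32)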
Example #27
  def testConcat(self):
    dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
    dataset2 = readers.TextLineDataset(self._createTextFiles())

    dataset = dataset1.concatenate(dataset2)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    next_element_fn = self._getNext(dataset)
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._record(r, f), self.evaluate(next_element_fn()))
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        self.assertAllEqual(
            self._text_line(r, f), self.evaluate(next_element_fn()))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element_fn())
Example #28
    def testConcat(self):
        dataset1 = readers.TFRecordDataset(self._createTFRecordFiles())
        dataset2 = readers.TextLineDataset(self._createTextFiles())
        dataset = dataset1.concatenate(dataset2)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    self.assertAllEqual(self._record(r, f),
                                        sess.run(next_element))
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    self.assertAllEqual(self._text_line(r, f),
                                        sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
Example #29
def _TextLineDataset(filename):
  buffer_size = 8 * 1024 * 1024  # 8 MiB per file
  dataset = readers.TextLineDataset(filename, buffer_size=buffer_size)
  return dataset
Example #30
 def testTextLineInputs(self):
     dataset = readers.TextLineDataset("")
     self.checkNumInputs(dataset, 0)