コード例 #1
0
  def testFixedLengthReader(self):
    """Shards a FixedLengthRecordDataset opened directly over every file."""
    record_files = self._createFixedLengthRecordFiles()
    records = readers.FixedLengthRecordDataset(record_files, self._record_bytes)
    sharded = input_ops.auto_shard_dataset(
        records, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._fixed_length_record)
コード例 #2
0
  def testTextLineReaderWithFlatMap(self):
    """Shards text files that are expanded into lines through flat_map."""
    file_names = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles())
    lines = file_names.flat_map(readers.TextLineDataset)
    sharded = input_ops.auto_shard_dataset(
        lines, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._text_line)
コード例 #3
0
    def testFixedLengthReader(self):
        """Shards a FixedLengthRecordDataset opened directly over every file."""
        record_files = self._createFixedLengthRecordFiles()
        records = readers.FixedLengthRecordDataset(record_files,
                                                   self._record_bytes)
        sharded = input_ops.auto_shard_dataset(records, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._fixed_length_record)
コード例 #4
0
    def __init__(self,
                 dataset_fn,
                 worker_device_map,
                 prefetch_on_device=None,
                 auto_shard=False):
        """Build one `PerDeviceDataset` for every worker in the map.

        Args:
          dataset_fn: a callable returning a `tf.data.Dataset`. It is called
            once per worker, inside that worker's `ops.device` scope.
          worker_device_map: a dict mapping each worker to the list of devices
            that belong to that worker.
          prefetch_on_device: forwarded unchanged to each `PerDeviceDataset`.
          auto_shard: when true, the dataset built for the worker at position
            i (in `worker_device_map` iteration order) is passed through
            `input_ops.auto_shard_dataset` with `len(worker_device_map)`
            shards and shard index i.
        """
        self._worker_device_map = worker_device_map
        self._datasets = {}
        shard_count = len(worker_device_map)
        # TODO(yuefengz, priyag): support different set of jobs for input
        # processing.
        for shard_id, (worker, devices) in enumerate(
                six.iteritems(worker_device_map)):
            with ops.device(worker):
                per_worker_input = dataset_fn()
                if auto_shard:
                    per_worker_input = input_ops.auto_shard_dataset(
                        per_worker_input, shard_count, shard_id)
                self._datasets[worker] = PerDeviceDataset(
                    per_worker_input,
                    devices,
                    prefetch_on_device=prefetch_on_device)
コード例 #5
0
    def testFlatMap(self):
        """Shards TFRecord files that are opened inside a flat_map."""
        paths = dataset_ops.Dataset.from_tensor_slices(
            self._createTFRecordFiles())
        records = paths.flat_map(readers.TFRecordDataset)
        sharded = input_ops.auto_shard_dataset(records, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._record)
コード例 #6
0
    def testTextLineReaderWithFlatMap(self):
        """Shards text files that are expanded into lines through flat_map."""
        file_names = dataset_ops.Dataset.from_tensor_slices(
            self._createTextFiles())
        lines = file_names.flat_map(readers.TextLineDataset)
        sharded = input_ops.auto_shard_dataset(lines, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._text_line)
コード例 #7
0
  def testFlatMap(self):
    """Shards TFRecord files that are opened inside a flat_map."""
    paths = dataset_ops.Dataset.from_tensor_slices(self._createTFRecordFiles())
    records = paths.flat_map(readers.TFRecordDataset)
    sharded = input_ops.auto_shard_dataset(
        records, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._record)
コード例 #8
0
    def testZip(self):
        """Shards a zip of a TFRecord reader and a text-line reader."""
        tf_records = readers.TFRecordDataset(self._createTFRecordFiles())
        text_lines = readers.TextLineDataset(self._createTextFiles())
        zipped = dataset_ops.Dataset.zip((tf_records, text_lines))
        sharded = input_ops.auto_shard_dataset(zipped, self._num_shards,
                                               self._shard_index)

        def expected_pair(r, f):
            # Each zipped element pairs the expected TFRecord and text line
            # for the same (record, file) indices.
            return self._record(r, f), self._text_line(r, f)

        self._verifySimpleShardingOutput(sharded, expected_pair)
コード例 #9
0
  def testFixedLengthReaderWithFlatMap(self):
    """Shards fixed-length record files opened inside a flat_map."""

    def open_file(filename):
      return readers.FixedLengthRecordDataset(filename, self._record_bytes)

    filenames = dataset_ops.Dataset.from_tensor_slices(
        self._createFixedLengthRecordFiles())
    records = filenames.flat_map(open_file)
    sharded = input_ops.auto_shard_dataset(
        records, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._fixed_length_record)
コード例 #10
0
    def testFixedLengthReaderWithFlatMap(self):
        """Shards fixed-length record files opened inside a flat_map."""

        def open_file(filename):
            return readers.FixedLengthRecordDataset(filename,
                                                    self._record_bytes)

        filenames = dataset_ops.Dataset.from_tensor_slices(
            self._createFixedLengthRecordFiles())
        records = filenames.flat_map(open_file)
        sharded = input_ops.auto_shard_dataset(records, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._fixed_length_record)
コード例 #11
0
  def testZip(self):
    """Shards a zip of a TFRecord reader and a text-line reader."""
    tf_records = readers.TFRecordDataset(self._createTFRecordFiles())
    text_lines = readers.TextLineDataset(self._createTextFiles())
    zipped = dataset_ops.Dataset.zip((tf_records, text_lines))
    sharded = input_ops.auto_shard_dataset(
        zipped, self._num_shards, self._shard_index)

    def expected_pair(r, f):
      # Each zipped element pairs the expected TFRecord and text line for the
      # same (record, file) indices.
      return self._record(r, f), self._text_line(r, f)

    self._verifySimpleShardingOutput(sharded, expected_pair)
コード例 #12
0
  def testInterleave(self):
    """Shards TFRecord files that are read through interleave."""
    paths = dataset_ops.Dataset.from_tensor_slices(self._createTFRecordFiles())
    # block_length equals the number of records per file, so every file is
    # emitted as one contiguous run and files keep their order; that is what
    # lets the simple in-order sharding check be reused here.
    records = paths.interleave(
        readers.TFRecordDataset,
        cycle_length=4,
        block_length=self._num_records)
    sharded = input_ops.auto_shard_dataset(
        records, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._record)
コード例 #13
0
    def testInterleave(self):
        """Shards TFRecord files that are read through interleave."""
        paths = dataset_ops.Dataset.from_tensor_slices(
            self._createTFRecordFiles())
        # block_length equals the number of records per file, so every file
        # is emitted as one contiguous run and files keep their order; that
        # is what lets the simple in-order sharding check be reused here.
        records = paths.interleave(readers.TFRecordDataset,
                                   cycle_length=4,
                                   block_length=self._num_records)
        sharded = input_ops.auto_shard_dataset(records, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._record)
コード例 #14
0
  def testConcat(self):
    """Shards a TFRecord reader concatenated with a text-line reader."""
    tf_records = readers.TFRecordDataset(self._createTFRecordFiles())
    text_lines = readers.TextLineDataset(self._createTextFiles())
    sharded = input_ops.auto_shard_dataset(
        tf_records.concatenate(text_lines), self._num_shards,
        self._shard_index)

    next_element = sharded.make_one_shot_iterator().get_next()
    shard_files = range(self._shard_index, self._num_files, self._num_shards)
    with self.cached_session() as sess:
      # All of this shard's TFRecords are expected first, then all of its
      # text lines, mirroring the order of the concatenation.
      for expected_fn in (self._record, self._text_line):
        for f in shard_files:
          for r in range(self._num_records):
            self.assertAllEqual(expected_fn(r, f), sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
コード例 #15
0
  def testListfiles(self):
    """Shards TFRecord files discovered through `Dataset.list_files`.

    The files are listed unshuffled, so this shard is expected to read files
    shard_index, shard_index + num_shards, ... in order, each one fully.
    """
    import os  # Builds the glob pattern with the platform's path separator.

    filenames = self._createTFRecordFiles()
    # Take the directory with os.path rather than splitting on "/": a
    # backslash-separated path has no "/" to split on, which yields a bogus
    # pattern.
    file_pattern = os.path.join(
        os.path.dirname(filenames[0]), "tf_record.*.txt")
    dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      actual, expected = [], []
      for f in range(self._shard_index, self._num_files, self._num_shards):
        for r in range(self._num_records):
          actual.append(sess.run(next_element))
          expected.append(self._record(r, f))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self.assertAllEqual(expected, actual)
コード例 #16
0
    def testListfiles(self):
        """Shards TFRecord files discovered through `Dataset.list_files`.

        The files are listed unshuffled, so this shard is expected to read
        files shard_index, shard_index + num_shards, ... in order, each one
        fully.
        """
        import os  # Builds the glob pattern with the platform's separator.

        filenames = self._createTFRecordFiles()
        # Take the directory with os.path rather than splitting on "/": a
        # backslash-separated path has no "/" to split on, which yields a
        # bogus pattern.
        file_pattern = os.path.join(os.path.dirname(filenames[0]),
                                    "tf_record.*.txt")
        dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            actual, expected = [], []
            for f in range(self._shard_index, self._num_files,
                           self._num_shards):
                for r in range(self._num_records):
                    actual.append(sess.run(next_element))
                    expected.append(self._record(r, f))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
            self.assertAllEqual(expected, actual)
コード例 #17
0
ファイル: values.py プロジェクト: sonnyhu/tensorflow
  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
    """Create a sharded `PerDeviceDataset` for every worker in the map.

    Args:
      dataset_fn: a callable returning a `tf.data.Dataset`. It is called once
        per worker, inside that worker's `ops.device` scope.
      worker_device_map: a dict mapping each worker to the list of devices
        that belong to that worker.
      prefetch_on_device: forwarded unchanged to each `PerDeviceDataset`.

    Every worker's dataset goes through `input_ops.auto_shard_dataset`: with
    N workers, the worker at position i (in `worker_device_map` iteration
    order) gets shard index i out of N.
    """
    self._worker_device_map = worker_device_map
    self._datasets = {}
    num_workers = len(worker_device_map)
    # TODO(yuefengz, priyag): support different set of jobs for input
    # processing.
    for shard, (worker, devices) in enumerate(
        six.iteritems(worker_device_map)):
      with ops.device(worker):
        sharded_input = input_ops.auto_shard_dataset(
            dataset_fn(), num_workers, shard)
        self._datasets[worker] = PerDeviceDataset(
            sharded_input, devices, prefetch_on_device=prefetch_on_device)
コード例 #18
0
    def testConcat(self):
        """Shards a TFRecord reader concatenated with a text-line reader."""
        tf_records = readers.TFRecordDataset(self._createTFRecordFiles())
        text_lines = readers.TextLineDataset(self._createTextFiles())
        sharded = input_ops.auto_shard_dataset(
            tf_records.concatenate(text_lines), self._num_shards,
            self._shard_index)

        next_element = sharded.make_one_shot_iterator().get_next()
        shard_files = range(self._shard_index, self._num_files,
                            self._num_shards)
        with self.cached_session() as sess:
            # All of this shard's TFRecords are expected first, then all of
            # its text lines, mirroring the order of the concatenation.
            for expected_fn in (self._record, self._text_line):
                for f in shard_files:
                    for r in range(self._num_records):
                        self.assertAllEqual(expected_fn(r, f),
                                            sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)
コード例 #19
0
  def testComplexPipeline(self):
    """Auto-shards a shuffled, repeated, mapped and batched file pipeline.

    Shuffling makes the output order nondeterministic, so the records read
    from this shard are compared with the expected ones as sorted lists.
    """
    # Setup a complex input pipeline.
    batch_size = 2
    num_epochs = 5
    dataset = dataset_ops.Dataset.from_tensor_slices(
        self._createTFRecordFiles())
    dataset = dataset.shuffle(buffer_size=self._num_files)
    dataset = dataset.flat_map(readers.TFRecordDataset)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.shuffle(2 * self._num_files * self._num_records)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(lambda x: x)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)

    # Auto shard.
    dataset = input_ops.auto_shard_dataset(
        dataset, self._num_shards, self._shard_index)

    # Records this shard should produce: like the other sharding tests, the
    # shard owns files shard_index, shard_index + num_shards, ... (not files
    # starting at 0), each read once per epoch.
    expected = []
    for f in range(self._shard_index, self._num_files, self._num_shards):
      for r in range(self._num_records):
        expected.append(self._record(r, f))
    expected *= num_epochs

    # Verify output.
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    with self.cached_session() as sess:
      actual = []
      # batch() is called without drop_remainder, so a final partial batch
      # is emitted when the shard's record count is not a multiple of
      # batch_size; round up. This also stays correct when num_files is not
      # a multiple of num_shards.
      num_iterations = (len(expected) + batch_size - 1) // batch_size
      for _ in range(num_iterations):
        actual.extend(sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

      self.assertAllEqual(sorted(expected), sorted(actual))
コード例 #20
0
    def testComplexPipeline(self):
        """Auto-shards a shuffled, repeated, mapped and batched pipeline.

        Shuffling makes the output order nondeterministic, so the records
        read from this shard are compared with the expected ones as sorted
        lists.
        """
        # Setup a complex input pipeline.
        batch_size = 2
        num_epochs = 5
        dataset = dataset_ops.Dataset.from_tensor_slices(
            self._createTFRecordFiles())
        dataset = dataset.shuffle(buffer_size=self._num_files)
        dataset = dataset.flat_map(readers.TFRecordDataset)
        dataset = dataset.prefetch(buffer_size=batch_size)
        dataset = dataset.shuffle(2 * self._num_files * self._num_records)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.map(lambda x: x)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=None)

        # Auto shard.
        dataset = input_ops.auto_shard_dataset(dataset, self._num_shards,
                                               self._shard_index)

        # Records this shard should produce: like the other sharding tests,
        # the shard owns files shard_index, shard_index + num_shards, ...
        # (not files starting at 0), each read once per epoch.
        expected = []
        for f in range(self._shard_index, self._num_files, self._num_shards):
            for r in range(self._num_records):
                expected.append(self._record(r, f))
        expected *= num_epochs

        # Verify output.
        iterator = dataset.make_one_shot_iterator()
        next_element = iterator.get_next()
        with self.cached_session() as sess:
            actual = []
            # batch() is called without drop_remainder, so a final partial
            # batch is emitted when the shard's record count is not a
            # multiple of batch_size; round up. This also stays correct when
            # num_files is not a multiple of num_shards.
            num_iterations = (len(expected) + batch_size - 1) // batch_size
            for _ in range(num_iterations):
                actual.extend(sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

            self.assertAllEqual(sorted(expected), sorted(actual))
コード例 #21
0
  def testTFRecordDataset(self):
    """Shards a TFRecordDataset opened directly over every file."""
    records = readers.TFRecordDataset(self._createTFRecordFiles())
    sharded = input_ops.auto_shard_dataset(
        records, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._record)
コード例 #22
0
    def testTextLineReader(self):
        """Shards a TextLineDataset opened directly over every file."""
        lines = readers.TextLineDataset(self._createTextFiles())
        sharded = input_ops.auto_shard_dataset(lines, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._text_line)
コード例 #23
0
    def testTFRecordDataset(self):
        """Shards a TFRecordDataset opened directly over every file."""
        records = readers.TFRecordDataset(self._createTFRecordFiles())
        sharded = input_ops.auto_shard_dataset(records, self._num_shards,
                                               self._shard_index)
        self._verifySimpleShardingOutput(sharded, self._record)
コード例 #24
0
  def testTextLineReader(self):
    """Shards a TextLineDataset opened directly over every file."""
    lines = readers.TextLineDataset(self._createTextFiles())
    sharded = input_ops.auto_shard_dataset(
        lines, self._num_shards, self._shard_index)
    self._verifySimpleShardingOutput(sharded, self._text_line)