  def _testNumThreadsHelper(self, num_threads, override_threadpool_fn):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    dataset = (
        dataset_ops.Dataset.range(1000).map(
            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
            num_parallel_calls=32).apply(unique.unique()))
    dataset = override_threadpool_fn(dataset)
    next_element = self.getNext(dataset, requires_initialization=True)

    thread_ids = []
    try:
      while True:
        thread_ids.append(self.evaluate(next_element()))
    except errors.OutOfRangeError:
      pass
    self.assertLen(thread_ids, len(set(thread_ids)))
    self.assertNotEmpty(thread_ids)
    if num_threads:
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
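A minimal sketch of a driver for this helper, reusing the `override_threadpool` pattern from the later examples (hypothetical wiring, not from the original):

  def testNumThreads(self, num_threads, max_intra_op_parallelism):

    def override_threadpool_fn(dataset):
      # Attach a private pool of `num_threads` threads to the dataset.
      return threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads,
              max_intra_op_parallelism=max_intra_op_parallelism,
              display_name="private_thread_pool_%d" % num_threads))

    self._testNumThreadsHelper(num_threads, override_threadpool_fn)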
Example 2
    def testOverrideThreadPool(self):
        def get_thread_id(_):
            # Python creates a dummy thread object to represent the current
            # thread when called from an "alien" thread (such as a
            # `PrivateThreadPool` thread in this case). It does not include
            # the TensorFlow-given display name, but it has a unique
            # identifier that maps one-to-one with the underlying OS thread.
            return np.array(threading.current_thread().ident).astype(np.int64)

        for num_threads in [1, 2, 4, 8, 16]:

            dataset = (Dataset.range(1000).map(
                lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
                num_parallel_calls=32).apply(unique.unique()))

            dataset = threadpool.override_threadpool(
                dataset,
                threadpool.PrivateThreadPool(
                    num_threads,
                    display_name='private_thread_pool_%d' % num_threads))

            thread_ids = []
            for next_element in datasets.Iterator(dataset):
                thread_ids.append(next_element)
            self.assertEqual(len(thread_ids), len(set(thread_ids)))
            self.assertGreater(len(thread_ids), 0)
            # NOTE(mrry): We don't control the thread pool scheduling, and
            # so cannot guarantee that all of the threads in the pool will
            # perform work.
            self.assertLessEqual(len(thread_ids), num_threads)
Example 3
  def testOverrideThreadPool(self):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    for num_threads in [1, 2, 4, 8, 16]:

      dataset = (
          Dataset.range(1000).map(
              lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
              num_parallel_calls=32).apply(unique.unique()))

      dataset = threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads, display_name='private_thread_pool_%d' % num_threads))

      thread_ids = []
      for next_element in datasets.Iterator(dataset):
        thread_ids.append(next_element)
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
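Note that `datasets.Iterator` above appears to be the old tf.contrib eager iterator; under TF 2.x-style eager execution the same loop can iterate the dataset directly (a sketch, assuming eager mode):

      thread_ids = []
      for next_element in dataset:
        # Each element is an EagerTensor holding one thread id.
        thread_ids.append(next_element.numpy())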
Example 4
  def _testSimpleHelper(self, dtype, test_cases):
    """Test the `unique()` transformation on a list of test cases.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of pairs of lists. The first component is the test
        input that will be passed to the transformation; the second component
        is the expected sequence of outputs from the transformation.
    """

    # The `current_test_case` will be updated when we loop over `test_cases`
    # below; declare it here so that the generator can capture it once.
    current_test_case = []
    dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
                                                 dtype).apply(unique.unique())
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      for test_case, expected in test_cases:
        current_test_case = test_case
        self.evaluate(iterator.initializer)
        for element in expected:
          if dtype == dtypes.string:
            element = compat.as_bytes(element)
          self.assertAllEqual(element, self.evaluate(next_element))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)
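A sketch of the kind of `test_cases` this helper consumes; these particular pairs are illustrative, not taken from the original tests:

    # Each pair is (input sequence, expected outputs in first-seen order).
    self._testSimpleHelper(dtypes.int64, [
        ([], []),
        ([1], [1]),
        ([1, 1, 1], [1]),
        ([1, 2, 1, 2, 3], [1, 2, 3]),
    ])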
Example 5
    def _testSimpleHelper(self, dtype, test_cases):
        """Test the `unique()` transformation on a list of test cases.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of pairs of lists. The first component is the test
        input that will be passed to the transformation; the second component
        is the expected sequence of outputs from the transformation.
    """

        # The `current_test_case` will be updated when we loop over `test_cases`
        # below; declare it here so that the generator can capture it once.
        current_test_case = []
        dataset = dataset_ops.Dataset.from_generator(
            lambda: current_test_case, dtype).apply(unique.unique())
        iterator = dataset_ops.make_initializable_iterator(dataset)
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            for test_case, expected in test_cases:
                current_test_case = test_case
                self.evaluate(iterator.initializer)
                for element in expected:
                    if dtype == dtypes.string:
                        element = compat.as_bytes(element)
                    self.assertAllEqual(element, self.evaluate(next_element))
                with self.assertRaises(errors.OutOfRangeError):
                    self.evaluate(next_element)
Example 6
  def testUnsupportedOpInPipeline(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.batch(5)
    dataset = dataset.apply(unique.unique())

    with self.assertRaises(errors.NotFoundError):
      dataset = distribute._AutoShardDataset(dataset, 2, 0)
      self.evaluate(self.getNext(dataset)())
Example 7
    def testUnsupportedOpInPipeline(self):
        dataset = dataset_ops.Dataset.list_files(self.test_filenames)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.batch(5)
        dataset = dataset.apply(unique.unique())

        with self.assertRaises(errors.NotFoundError):
            dataset = distribute._AutoShardDataset(dataset, 2, 0)
            self.evaluate(self.getNext(dataset)())
Example 8
    def testUnsupportedTypes(self):
        """Should raise TypeError when element type doesn't match with the

    dtypes.int64, dtypes.int32 or dtypes.string (supported types).
    """

        for dtype in [
                dtypes.bool, dtypes.double, dtypes.complex64, dtypes.float32,
                dtypes.float64, dtypes.qint16, dtypes.qint32
        ]:
            with self.assertRaises(TypeError):
                _ = dataset_ops.Dataset.from_generator(
                    lambda: [], dtype).apply(unique.unique())
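By contrast, building the same pipeline with a supported element type should not raise; a minimal smoke check (hypothetical, not part of the original test):

        # Supported dtype: constructing the pipeline raises no TypeError.
        _ = dataset_ops.Dataset.from_generator(
            lambda: [1, 2, 2, 3], dtypes.int64).apply(unique.unique())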
Example 9
  def testUnknownOpInPipelineStillShardsAtTheEnd(self):
    dataset = dataset_ops.Dataset.list_files(self.test_filenames, shuffle=False)
    dataset = dataset.flat_map(core_readers.TFRecordDataset)
    dataset = dataset.apply(unique.unique())

    dataset = distribute._AutoShardDataset(dataset, 5, 0)

    expected = [
        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
        for f in range(0, 10)
        for r in (0, 5)
    ]
    self.assertDatasetProduces(dataset, expected)
Example 10
    def testUnknownOpInPipelineStillShardsAtTheEnd(self):
        dataset = dataset_ops.Dataset.list_files(self._filenames,
                                                 shuffle=False)
        dataset = dataset.flat_map(core_readers.TFRecordDataset)
        dataset = dataset.apply(unique.unique())

        dataset = distribute._AutoShardDataset(dataset, 5, 0)

        expected = [
            b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
            for f in range(0, 10) for r in (0, 5)
        ]
        self.assertDatasetProduces(dataset, expected)
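A worked check of the `expected` list above: the 10 files of 10 records each are flattened in file-major order, `unique()` is a no-op because every record string is distinct, and sharding at the end keeps every 5th element for worker 0, i.e. records 0 and 5 of each file (standalone sketch):

# Element index = f * 10 + r; keeping indices 0, 5, 10, ... selects r in (0, 5).
records = [b"Record %d of file %d" % (r, f)
           for f in range(10) for r in range(10)]
assert records[0::5] == [b"Record %d of file %d" % (r, f)
                         for f in range(10) for r in (0, 5)]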
Example 11
def unique():
  """Creates a `Dataset` from another `Dataset`, discarding duplicates.

  Use this transformation to produce a dataset that contains one instance of
  each unique element in the input. For example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])

  # Using `unique()` will drop the duplicate elements.
  dataset = dataset.apply(tf.data.experimental.unique())  # ==> { 1, 37, 2 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  return experimental_unique.unique()
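A runnable usage sketch of this transformation (assumes a TF 2.x eager environment, where `Dataset.as_numpy_iterator` is available):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
dataset = dataset.apply(tf.data.experimental.unique())
print(list(dataset.as_numpy_iterator()))  # [1, 37, 2]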
Example 12
def unique():
    """Creates a `Dataset` from another `Dataset`, discarding duplicates.

  Use this transformation to produce a dataset that contains one instance of
  each unique element in the input. For example:

  ```python
  dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])

  # Using `unique()` will drop the duplicate elements.
  dataset = dataset.apply(tf.data.experimental.unique())  # ==> { 1, 37, 2 }
  ```

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
    return experimental_unique.unique()
Example 13
  def testNumThreads(self, num_threads, max_intra_op_parallelism):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    dataset = (
        dataset_ops.Dataset.range(1000).map(
            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
            num_parallel_calls=32).apply(unique.unique()))

    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            num_threads,
            max_intra_op_parallelism=max_intra_op_parallelism,
            display_name="private_thread_pool_%d" % num_threads))

    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer)
      thread_ids = []
      try:
        while True:
          thread_ids.append(sess.run(next_element))
      except errors.OutOfRangeError:
        pass
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
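A sketch of how a test like this is commonly driven with `absl.testing.parameterized`; the class name and parameter sets are illustrative, not from the original:

from absl.testing import parameterized

class ThreadpoolTest(parameterized.TestCase):

  # (case name, num_threads, max_intra_op_parallelism)
  @parameterized.named_parameters(
      ("One", 1, None),
      ("Two", 2, None),
      ("Eight", 8, None),
  )
  def testNumThreads(self, num_threads, max_intra_op_parallelism):
    pass  # Body as in Example 13 above.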
Example 14
  def _testSimpleHelper(self, dtype, test_cases):
    """Test the `unique()` transformation on a list of test cases.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of pairs of lists. The first component is the test
        input that will be passed to the transformation; the second component
        is the expected sequence of outputs from the transformation.
    """

    # The `current_test_case` will be updated when we loop over `test_cases`
    # below; declare it here so that the generator can capture it once.
    current_test_case = []
    dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
                                                 dtype).apply(unique.unique())

    for test_case, expected in test_cases:
      current_test_case = test_case
      self.assertDatasetProduces(dataset, [
          compat.as_bytes(element) if dtype == dtypes.string else element
          for element in expected
      ])
Example 15
    def _testSimpleHelper(self, dtype, test_cases):
        """Test the `unique()` transformation on a list of test cases.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of pairs of lists. The first component is the test
        input that will be passed to the transformation; the second component
        is the expected sequence of outputs from the transformation.
    """

        # The `current_test_case` will be updated when we loop over `test_cases`
        # below; declare it here so that the generator can capture it once.
        current_test_case = []
        dataset = dataset_ops.Dataset.from_generator(
            lambda: current_test_case, dtype).apply(unique.unique())

        for test_case, expected in test_cases:
            current_test_case = test_case
            self.assertDatasetProduces(dataset, [
                compat.as_bytes(element) if dtype == dtypes.string else element
                for element in expected
            ])
Example 16
  def testNumThreads(self, num_threads, max_intra_op_parallelism):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    dataset = (
        dataset_ops.Dataset.range(1000).map(
            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
            num_parallel_calls=32).apply(unique.unique()))

    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            num_threads,
            max_intra_op_parallelism=max_intra_op_parallelism,
            display_name="private_thread_pool_%d" % num_threads))

    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      self.evaluate(iterator.initializer)
      thread_ids = []
      try:
        while True:
          thread_ids.append(sess.run(next_element))
      except errors.OutOfRangeError:
        pass
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
Example 17
  def build_dataset(num_elements, unique_elem_range):
    return dataset_ops.Dataset.range(num_elements).map(
        lambda x: x % unique_elem_range).apply(unique.unique())
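A worked example of the builder above: `x % unique_elem_range` cycles through 0..unique_elem_range-1, so `unique()` keeps only the first pass over the residues (illustrative values):

ds = build_dataset(num_elements=100, unique_elem_range=10)
# In eager mode: list(ds.as_numpy_iterator()) == list(range(10))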
Example 18
    def build_dataset(num_elements, unique_elem_range):
        return dataset_ops.Dataset.range(num_elements).map(
            lambda x: x % unique_elem_range).apply(unique.unique())