Example #1
    def testOverrideThreadPool(self):
        def get_thread_id(_):
            # Python creates a dummy thread object to represent the current
            # thread when called from an "alien" thread (such as a
            # `PrivateThreadPool` thread in this case). It does not include
            # the TensorFlow-given display name, but it has a unique
            # identifier that maps one-to-one with the underlying OS thread.
            return np.array(threading.current_thread().ident).astype(np.int64)

        for num_threads in [1, 2, 4, 8, 16]:

            dataset = (Dataset.range(1000).map(
                lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
                num_parallel_calls=32).apply(unique.unique()))

            dataset = threadpool.override_threadpool(
                dataset,
                threadpool.PrivateThreadPool(
                    num_threads,
                    display_name='private_thread_pool_%d' % num_threads))

            thread_ids = []
            for next_element in datasets.Iterator(dataset):
                thread_ids.append(next_element)
            self.assertEqual(len(thread_ids), len(set(thread_ids)))
            self.assertGreater(len(thread_ids), 0)
            # NOTE(mrry): We don't control the thread pool scheduling, and
            # so cannot guarantee that all of the threads in the pool will
            # perform work.
            self.assertLessEqual(len(thread_ids), num_threads)
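The assertions above rely on `threading.current_thread().ident` being a stable identifier that maps one-to-one to the underlying OS thread, regardless of which pool created that thread. A minimal standard-library sketch of that property (not TensorFlow code; the 4-worker pool and the 100 tasks are arbitrary illustration choices):

import threading
from concurrent.futures import ThreadPoolExecutor

def get_thread_id(_):
    # `ident` uniquely identifies the OS thread running this call, even when
    # the thread object was not created through the `threading` module itself.
    return threading.current_thread().ident

with ThreadPoolExecutor(max_workers=4) as pool:
    thread_ids = set(pool.map(get_thread_id, range(100)))

# Work is spread over at most 4 distinct threads, but the pool scheduler
# decides how many of them actually run tasks.
assert 1 <= len(thread_ids) <= 4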
  def _testSimpleHelper(self, dtype, test_cases):
    """Test the `unique()` transformation on a list of test cases.

    Args:
      dtype: The `dtype` of the elements in each test case.
      test_cases: A list of pairs of lists. The first component is the test
        input that will be passed to the transformation; the second component
        is the expected sequence of outputs from the transformation.
    """

    # The `current_test_case` will be updated when we loop over `test_cases`
    # below; declare it here so that the generator can capture it once.
    current_test_case = []
    dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
                                                 dtype).apply(unique.unique())
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      for test_case, expected in test_cases:
        current_test_case = test_case
        sess.run(iterator.initializer)
        for element in expected:
          if dtype == dtypes.string:
            element = compat.as_bytes(element)
          self.assertAllEqual(element, sess.run(next_element))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)
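Outside of a test harness, the same transformation can be applied directly to a pipeline. A minimal sketch, assuming a TF 2.x eager environment where the transformation is exposed publicly as `tf.data.experimental.unique()` (the counterpart of the internal `unique.unique()` helper used above):

import tensorflow as tf

# Duplicates are dropped and the order of first occurrence is preserved.
dataset = tf.data.Dataset.from_tensor_slices([1, 1, 2, 3, 3, 3, 2])
dataset = dataset.apply(tf.data.experimental.unique())

print(list(dataset.as_numpy_iterator()))  # [1, 2, 3]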
    def _testSimpleHelper(self, dtype, test_cases):
        """Test the `unique()` transformation on a list of test cases.

        Args:
          dtype: The `dtype` of the elements in each test case.
          test_cases: A list of pairs of lists. The first component is the test
            input that will be passed to the transformation; the second component
            is the expected sequence of outputs from the transformation.
        """

        # The `current_test_case` will be updated when we loop over `test_cases`
        # below; declare it here so that the generator can capture it once.
        current_test_case = []
        dataset = dataset_ops.Dataset.from_generator(
            lambda: current_test_case, dtype).apply(unique.unique())
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

        with self.test_session() as sess:
            for test_case, expected in test_cases:
                current_test_case = test_case
                sess.run(iterator.initializer)
                for element in expected:
                    if dtype == dtypes.string:
                        element = compat.as_bytes(element)
                    self.assertAllEqual(element, sess.run(next_element))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element)
Example #4
  def testOverrideThreadPool(self):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    for num_threads in [1, 2, 4, 8, 16]:

      dataset = (
          Dataset.range(1000).map(
              lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
              num_parallel_calls=32).apply(unique.unique()))

      dataset = threadpool.override_threadpool(
          dataset,
          threadpool.PrivateThreadPool(
              num_threads, display_name='private_thread_pool_%d' % num_threads))

      thread_ids = []
      for next_element in datasets.Iterator(dataset):
        thread_ids.append(next_element)
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
  def testNumThreads(self, num_threads, max_intra_op_parallelism):

    def get_thread_id(_):
      # Python creates a dummy thread object to represent the current
      # thread when called from an "alien" thread (such as a
      # `PrivateThreadPool` thread in this case). It does not include
      # the TensorFlow-given display name, but it has a unique
      # identifier that maps one-to-one with the underlying OS thread.
      return np.array(threading.current_thread().ident).astype(np.int64)

    dataset = (
        dataset_ops.Dataset.range(1000).map(
            lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
            num_parallel_calls=32).apply(unique.unique()))

    dataset = threadpool.override_threadpool(
        dataset,
        threadpool.PrivateThreadPool(
            num_threads,
            max_intra_op_parallelism=max_intra_op_parallelism,
            display_name="private_thread_pool_%d" % num_threads))

    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(iterator.initializer)
      thread_ids = []
      try:
        while True:
          thread_ids.append(sess.run(next_element))
      except errors.OutOfRangeError:
        pass
      self.assertEqual(len(thread_ids), len(set(thread_ids)))
      self.assertGreater(len(thread_ids), 0)
      # NOTE(mrry): We don't control the thread pool scheduling, and
      # so cannot guarantee that all of the threads in the pool will
      # perform work.
      self.assertLessEqual(len(thread_ids), num_threads)
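`threadpool.override_threadpool` and `PrivateThreadPool` are experimental helpers; in current TF 2.x releases the same two knobs are normally set through `tf.data.Options`. A hedged sketch (the attribute names below are the ones documented for `tf.data.experimental.ThreadingOptions`; older releases expose them under `options.experimental_threading` instead of `options.threading`):

import tensorflow as tf

options = tf.data.Options()
# Rough equivalent of PrivateThreadPool(num_threads=4) with
# max_intra_op_parallelism=1, as used in the test above.
options.threading.private_threadpool_size = 4
options.threading.max_intra_op_parallelism = 1

dataset = tf.data.Dataset.range(1000).with_options(options)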
Example #6
    def testNumThreads(self):
        def get_thread_id(_):
            # Python creates a dummy thread object to represent the current
            # thread when called from an "alien" thread (such as a
            # `PrivateThreadPool` thread in this case). It does not include
            # the TensorFlow-given display name, but it has a unique
            # identifier that maps one-to-one with the underlying OS thread.
            return np.array(threading.current_thread().ident).astype(np.int64)

        for num_threads in [1, 2, 4, 8, 16]:

            dataset = (dataset_ops.Dataset.range(1000).map(
                lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
                num_parallel_calls=32).apply(unique.unique()))

            dataset = threadpool.override_threadpool(
                dataset,
                threadpool.PrivateThreadPool(
                    num_threads,
                    display_name="private_thread_pool_%d" % num_threads))

            iterator = dataset.make_initializable_iterator()
            next_element = iterator.get_next()

            with self.test_session() as sess:
                sess.run(iterator.initializer)
                thread_ids = []
                try:
                    while True:
                        thread_ids.append(sess.run(next_element))
                except errors.OutOfRangeError:
                    pass
                self.assertEqual(len(thread_ids), len(set(thread_ids)))
                self.assertGreater(len(thread_ids), 0)
                # NOTE (mrry): We don't control the thread pool scheduling, and id:756
                # https://github.com/imdone/tensorflow/issues/757
                # so cannot guarantee that all of the threads in the pool will
                # perform work.
                self.assertLessEqual(len(thread_ids), num_threads)
  def build_dataset(num_elements, unique_elem_range):
    return dataset_ops.Dataset.range(num_elements).map(
        lambda x: x % unique_elem_range).apply(unique.unique())
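`build_dataset(num_elements, unique_elem_range)` reduces the range `[0, num_elements)` modulo `unique_elem_range` before deduplicating, so each residue appears exactly once in the output. A small worked check (hypothetical driver code, assuming TF 2.x eager mode and the public `tf.data.experimental.unique()`):

import tensorflow as tf

def build_dataset(num_elements, unique_elem_range):
  return tf.data.Dataset.range(num_elements).map(
      lambda x: x % unique_elem_range).apply(tf.data.experimental.unique())

# 0..9 mod 3 cycles through 0, 1, 2; only the first occurrence of each
# residue survives the `unique()` transformation.
print(list(build_dataset(10, 3).as_numpy_iterator()))  # [0, 1, 2]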