def override_threadpool_fn(dataset):
    """Return ``dataset`` with a threading-options override applied.

    Reads ``max_intra_op_parallelism`` and ``num_threads`` from a scope that
    is not visible here (presumably an enclosing function — verify). Each
    value is copied onto a fresh ``ThreadingOptions`` only when it is not
    ``None``, so unset values keep the ``ThreadingOptions`` defaults.

    Args:
      dataset: the dataset to reconfigure.

    Returns:
      The result of ``dataset.with_options(...)`` carrying the new
      threading options.
    """
    threading_opts = threading_options.ThreadingOptions()
    # Ordered attribute -> requested value; None means "leave as default".
    requested = {
        "max_intra_op_parallelism": max_intra_op_parallelism,
        "private_threadpool_size": num_threads,
    }
    for attr_name, attr_value in requested.items():
        if attr_value is not None:
            setattr(threading_opts, attr_name, attr_value)

    dataset_options = dataset_ops.Options()
    dataset_options.experimental_threading = threading_opts
    return dataset.with_options(dataset_options)
# Example #2
# 0
 def testOptionsHaveDefaults(self):
     """Fresh ``Options`` objects get independent, default-valued sub-options.

     Two separately constructed ``Options`` must not share their
     optimization/threading sub-option objects. The first one's
     sub-options must equal freshly constructed defaults.
     """
     first = dataset_ops.Options()
     second = dataset_ops.Options()
     # Identity check: each Options instance must own its sub-option objects.
     for sub_option in ("experimental_optimization", "threading"):
         self.assertIsNot(getattr(first, sub_option),
                          getattr(second, sub_option))
     # Value check: the sub-options compare equal to default-constructed ones.
     self.assertEqual(first.experimental_optimization,
                      optimization_options.OptimizationOptions())
     self.assertEqual(first.threading,
                      threading_options.ThreadingOptions())
# Example #3
# 0
 def testSloppyInterleaveInOrder(self, num_elements, num_parallel_calls):
     """Elements released one at a time are produced in release order.

     Event ``idx`` is set just before the ``idx``-th fetch, and that fetch
     must yield ``idx * idx``. Once all ``num_elements`` values are consumed,
     the next fetch must raise ``OutOfRangeError``.
     """
     dataset, coordination_events = _make_coordinated_sloppy_dataset(
         num_elements, num_parallel_calls)

     options = dataset_ops.Options()
     options.experimental_threading = threading_options.ThreadingOptions()
     # Private pool sized one larger than the number of parallel calls.
     pool_size = num_parallel_calls + 1
     options.experimental_threading.private_threadpool_size = pool_size
     configured = dataset.with_options(options)

     get_next = self.getNext(configured, requires_initialization=True)
     for idx in range(num_elements):
         # Unblock exactly one element, then expect it next.
         coordination_events[idx].set()
         self.assertEqual(idx * idx, self.evaluate(get_next()))

     # All elements consumed: the iterator must now be exhausted.
     with self.assertRaises(errors.OutOfRangeError):
         self.evaluate(get_next())
# Example #4
# 0
    def testSloppyInterleaveInOrder(self, input_values, cycle_length,
                                    block_length, num_parallel_calls):
        """Sloppy interleave yields elements in order when released in order.

        The expected sequence comes from the ``_interleave``/``_repeat``
        helpers (defined elsewhere). Each element's coordination event is set
        just before it is fetched, and the fetched value must equal the
        element squared. After the sequence ends, the iterator must raise
        ``OutOfRangeError``.
        """
        dataset, coordination_events = _make_coordinated_sloppy_dataset(
            input_values, cycle_length, block_length, num_parallel_calls)

        options = dataset_ops.Options()
        options.experimental_threading = threading_options.ThreadingOptions()
        # Private pool sized one larger than the number of parallel calls.
        pool_size = num_parallel_calls + 1
        options.experimental_threading.private_threadpool_size = pool_size
        configured = dataset.with_options(options)

        get_next = self.getNext(configured, requires_initialization=True)
        expected_order = _interleave(_repeat(input_values, 2), cycle_length,
                                     block_length)
        for element in expected_order:
            # Events are indexed by element value; release this one only.
            coordination_events[element].set()
            squared = element * element
            self.assertEqual(squared, self.evaluate(get_next()))

        # Every expected element has been consumed; nothing should remain.
        with self.assertRaises(errors.OutOfRangeError):
            self.evaluate(get_next())