def override_threadpool_fn(dataset):
  """Attaches threading options to `dataset` via `with_options`.

  `max_intra_op_parallelism` and `num_threads` are not parameters: they are
  free variables resolved from the enclosing scope. Each one is copied onto
  the ThreadingOptions only when it is not None; otherwise that field keeps
  its ThreadingOptions default.

  Args:
    dataset: the dataset whose options should be overridden.

  Returns:
    Whatever `dataset.with_options(...)` returns for the new options.
  """
  dataset_options = dataset_ops.Options()
  threading_config = threading_options.ThreadingOptions()

  if max_intra_op_parallelism is not None:
    threading_config.max_intra_op_parallelism = max_intra_op_parallelism

  if num_threads is not None:
    # The requested thread count is mapped onto the private pool size.
    threading_config.private_threadpool_size = num_threads

  dataset_options.experimental_threading = threading_config
  return dataset.with_options(dataset_options)
def testOptionsHaveDefaults(self):
  """Fresh Options objects carry distinct, default-valued sub-options."""
  first = dataset_ops.Options()
  second = dataset_ops.Options()

  # Two Options instances must not share the same sub-option objects.
  self.assertIsNot(first.experimental_optimization,
                   second.experimental_optimization)
  self.assertIsNot(first.threading, second.threading)

  # The sub-options start out equal to freshly constructed defaults.
  default_optimization = optimization_options.OptimizationOptions()
  self.assertEqual(first.experimental_optimization, default_optimization)
  default_threading = threading_options.ThreadingOptions()
  self.assertEqual(first.threading, default_threading)
def testSloppyInterleaveInOrder(self, num_elements, num_parallel_calls):
  """Checks elements come out in index order as events are set one by one.

  For each index `step`, the matching coordination event is set and the next
  element is expected to equal `step * step`. After `num_elements` reads the
  iterator must raise OutOfRangeError.
  """
  dataset, coordination_events = _make_coordinated_sloppy_dataset(
      num_elements, num_parallel_calls)

  # Private thread pool with one thread more than num_parallel_calls.
  pool_size = num_parallel_calls + 1
  options = dataset_ops.Options()
  options.experimental_threading = threading_options.ThreadingOptions()
  options.experimental_threading.private_threadpool_size = pool_size
  dataset = dataset.with_options(options)

  next_element = self.getNext(dataset, requires_initialization=True)
  for step in range(num_elements):
    coordination_events[step].set()
    expected = step * step
    self.assertEqual(expected, self.evaluate(next_element()))

  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element())
def testSloppyInterleaveInOrder(self, input_values, cycle_length,
                                block_length, num_parallel_calls):
  """Checks elements follow the deterministic interleave order.

  The expected order is `_interleave(_repeat(input_values, 2), cycle_length,
  block_length)`. For each expected value, its coordination event is set and
  the next element must equal that value squared. Afterwards the iterator
  must raise OutOfRangeError.
  """
  dataset, coordination_events = _make_coordinated_sloppy_dataset(
      input_values, cycle_length, block_length, num_parallel_calls)

  # Private thread pool with one thread more than num_parallel_calls.
  pool_size = num_parallel_calls + 1
  options = dataset_ops.Options()
  options.experimental_threading = threading_options.ThreadingOptions()
  options.experimental_threading.private_threadpool_size = pool_size
  dataset = dataset.with_options(options)

  next_element = self.getNext(dataset, requires_initialization=True)
  expected_order = _interleave(
      _repeat(input_values, 2), cycle_length, block_length)
  for value in expected_order:
    coordination_events[value].set()
    squared = value * value
    self.assertEqual(squared, self.evaluate(next_element()))

  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(next_element())