def testExecutionMode(self):
    """Check that the sync/async flag and the eager context's mode agree.

    While no op has executed, both the execution mode and init-time settings
    (here, intra-op parallelism) can be changed freely and repeatedly.
    """

    def assert_mode(expect_sync, expected_mode):
        # The config getter and the eager context must report the same mode.
        current = config.get_synchronous_execution()
        if expect_sync:
            self.assertTrue(current)
        else:
            self.assertFalse(current)
        self.assertEqual(expected_mode, context.context().execution_mode)

    assert_mode(True, context.SYNC)

    # Nothing has run yet, so the execution mode and init-time configs can
    # still be toggled back and forth without error.
    for thread_count, synchronous in ((1, False), (2, True)):
        config.set_intra_op_parallelism_threads(thread_count)
        config.set_synchronous_execution(synchronous)
    assert_mode(True, context.SYNC)

    config.set_synchronous_execution(False)
    assert_mode(False, context.ASYNC)
def test_collective_reduce_async_context(self):
    """Run a two-replica collective sum with async execution enabled.

    The test re-creates the eager context in async mode and re-runs setUp().
    It then checks that the collective sum of -1.5 and 3.5 yields 2.0 on each
    component device, placed on that device. The prior synchronous-execution
    setting is captured first and restored after a context reset,
    regardless of outcome.
    """
    saved_sync_setting = config.get_synchronous_execution()
    try:
        context._reset_context()
        config.set_synchronous_execution(False)
        self.setUp()
        # Parallel-device ops do not currently run asynchronously; the point
        # here is only that enabling async mode does not cause a deadlock.
        packed = self.device.pack(
            [constant_op.constant(-1.5), constant_op.constant(3.5)])
        with self.device:
            total = _collective_sum(packed, num_replicas=2)
        per_device = self.device.unpack(total)
        self.assertAllClose([2., 2.], per_device)
        for index in (0, 1):
            self.assertIn(self.device.components[index],
                          per_device[index].backing_device)
    finally:
        context._reset_context()
        config.set_synchronous_execution(saved_sync_setting)