예제 #1
0
  def testExecutionMode(self):
    """Checks that sync/async execution can be toggled and is reported back."""
    self.assertTrue(config.get_synchronous_execution())
    initial_mode = context.context().execution_mode
    self.assertEqual(context.SYNC, initial_mode)

    # No op has run yet, so the execution mode and init-time settings (here
    # the intra-op thread count) should still be changeable, interleaved.
    config.set_intra_op_parallelism_threads(1)
    config.set_synchronous_execution(False)
    config.set_intra_op_parallelism_threads(2)

    # Every toggle has to show up both in the config getter and in the
    # execution mode reported by the context.
    for synchronous, expected_mode in ((True, context.SYNC),
                                       (False, context.ASYNC)):
      config.set_synchronous_execution(synchronous)
      check_flag = self.assertTrue if synchronous else self.assertFalse
      check_flag(config.get_synchronous_execution())
      self.assertEqual(expected_mode, context.context().execution_mode)
예제 #2
0
    def testExecutionMode(self):
        """Verifies the synchronous-execution flag round-trips via the context."""
        set_sync = config.set_synchronous_execution
        is_sync = config.get_synchronous_execution

        def current_mode():
            # Look the context up on every call rather than caching it.
            return context.context().execution_mode

        # Starting point: synchronous execution.
        self.assertTrue(is_sync())
        self.assertEqual(context.SYNC, current_mode())

        # As long as no op has been executed, changing the execution mode and
        # init-time configs (intra-op threads here) should both be allowed.
        config.set_intra_op_parallelism_threads(1)
        set_sync(False)
        config.set_intra_op_parallelism_threads(2)

        # Switch to synchronous and confirm both views agree.
        set_sync(True)
        self.assertTrue(is_sync())
        self.assertEqual(context.SYNC, current_mode())

        # Switch to asynchronous and confirm both views agree.
        set_sync(False)
        self.assertFalse(is_sync())
        self.assertEqual(context.ASYNC, current_mode())
예제 #3
0
 def test_collective_reduce_async_context(self):
     """Runs a two-replica collective sum with synchronous execution off."""
     was_synchronous = config.get_synchronous_execution()
     try:
         context._reset_context()
         config.set_synchronous_execution(False)
         # NOTE(review): setUp is re-run after the context reset, presumably
         # so that self.device is rebuilt against the new context - confirm.
         self.setUp()
         # Ops on the parallel device do not currently execute
         # asynchronously, so this only checks that nothing deadlocks.
         first = constant_op.constant(-1.5)
         second = constant_op.constant(3.5)
         packed = self.device.pack([first, second])
         with self.device:
             summed = _collective_sum(packed, num_replicas=2)
             per_replica = self.device.unpack(summed)
         self.assertAllClose([2., 2.], per_replica)
         # Each unpacked result should be backed by its matching component.
         for index in (0, 1):
             self.assertIn(self.device.components[index],
                           per_replica[index].backing_device)
     finally:
         # Put back the execution mode that was active on entry, whether or
         # not the body above succeeded.
         context._reset_context()
         config.set_synchronous_execution(was_synchronous)