Example #1
0
    def testIntraOpParallelismThreads(self):
        """Intra-op thread count round-trips, then is frozen after init.

        The value written through ``config`` must match what the eager
        context reports; once an op has run, changing it must raise
        ``RuntimeError``.
        """
        requested_threads = 10
        config.set_intra_op_parallelism_threads(requested_threads)
        reported = config.get_intra_op_parallelism_threads()
        self.assertEqual(reported,
                         context.context().intra_op_parallelism_threads)

        # Running any op initializes the eager context; after that a
        # different thread count is rejected.
        constant_op.constant(1)
        with self.assertRaises(RuntimeError):
            config.set_intra_op_parallelism_threads(1)
Example #2
0
  def testIntraOpParallelismThreads(self):
    """Checks the intra-op thread setting round-trips and locks after init."""
    config.set_intra_op_parallelism_threads(10)
    from_config = config.get_intra_op_parallelism_threads()
    from_context = context.context().intra_op_parallelism_threads
    self.assertEqual(from_config, from_context)

    # Executing an op initializes the eager context; afterwards asking for a
    # different thread count must fail.
    constant_op.constant(1)
    with self.assertRaises(RuntimeError):
      config.set_intra_op_parallelism_threads(1)
Example #3
0
    def testDevicePolicy(self):
        """Exercises every device placement policy exposed via ``config``.

        For each policy name the test checks that ``config`` reports it back,
        that the eager context holds the matching ``DEVICE_PLACEMENT_*``
        value, and that adding a CPU tensor to its GPU copy either succeeds
        or fails with ``InvalidArgumentError`` as the policy dictates.
        """
        # Before any configuration in this test the context reports the
        # silent placement policy.
        self.assertEqual(context.DEVICE_PLACEMENT_SILENT,
                         context.context().device_policy)

        # Prior to context initialization the device policy can be changed
        # interleaved with other init-time settings.
        config.set_intra_op_parallelism_threads(1)
        config.set_device_policy('silent')
        config.set_intra_op_parallelism_threads(2)

        context.ensure_initialized()

        conflict_pattern = 'Tensors on conflicting devices'

        def add_across_devices(dtype=dtypes.int32):
            # Pin the source tensor to the CPU, copy it to the GPU, and sum
            # the two; whether the mixed-device add is allowed depends on the
            # active policy.
            with ops.device('CPU:0'):
                host_value = constant_op.constant(1, dtype=dtype)
            device_value = host_value.gpu()
            self.assertAllEqual(host_value + device_value, 2.0)

        def switch_policy(policy_name, expected_placement):
            # Set the policy, then confirm config and context agree on it.
            config.set_device_policy(policy_name)
            self.assertEqual(config.get_device_policy(), policy_name)
            self.assertEqual(expected_placement,
                             context.context().device_policy)

        switch_policy('silent', context.DEVICE_PLACEMENT_SILENT)
        add_across_devices()

        switch_policy('silent_for_int32',
                      context.DEVICE_PLACEMENT_SILENT_FOR_INT32)
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    conflict_pattern):
            add_across_devices(dtypes.float32)
        add_across_devices()

        switch_policy('warn', context.DEVICE_PLACEMENT_WARN)
        add_across_devices()

        switch_policy('explicit', context.DEVICE_PLACEMENT_EXPLICIT)
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    conflict_pattern):
            add_across_devices()

        # Setting None makes the policy read back as 'silent'.
        config.set_device_policy(None)
        self.assertEqual(config.get_device_policy(), 'silent')
Example #4
0
  def testDevicePolicy(self):
    """Exercises each device placement policy settable through ``config``.

    For every policy name, verifies that ``config`` reports it back, that the
    eager context holds the matching ``DEVICE_PLACEMENT_*`` value, and that
    adding a tensor to its GPU copy succeeds or raises
    ``InvalidArgumentError`` as that policy requires.
    """
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT,
                     context.context().device_policy)

    # If no op has been executed we should be able to set the device policy as
    # well as any init-time configs.
    config.set_intra_op_parallelism_threads(1)
    config.set_device_policy('silent')
    config.set_intra_op_parallelism_threads(2)

    # Execute a dummy op to ensure that the context has been initialized.
    constant_op.constant(1)

    def copy_tensor(dtype=dtypes.int32):
      # Add a tensor to its GPU copy; whether this mixed-device add is allowed
      # is decided by the active device policy.
      # NOTE(review): no explicit device scope pins this constant to the CPU;
      # confirm it is placed on the host as its name implies.
      cpu_tensor = constant_op.constant(1, dtype=dtype)
      gpu_tensor = cpu_tensor.gpu()
      self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0)

    config.set_device_policy('silent')
    self.assertEqual(config.get_device_policy(), 'silent')
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT,
                     context.context().device_policy)
    copy_tensor()

    config.set_device_policy('silent_for_int32')
    self.assertEqual(config.get_device_policy(), 'silent_for_int32')
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT_FOR_INT32,
                     context.context().device_policy)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                'Tensors on conflicting devices'):
      copy_tensor(dtypes.float32)
    copy_tensor()

    config.set_device_policy('warn')
    self.assertEqual(config.get_device_policy(), 'warn')
    self.assertEqual(context.DEVICE_PLACEMENT_WARN,
                     context.context().device_policy)
    copy_tensor()

    config.set_device_policy('explicit')
    self.assertEqual(config.get_device_policy(), 'explicit')
    self.assertEqual(context.DEVICE_PLACEMENT_EXPLICIT,
                     context.context().device_policy)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                'Tensors on conflicting devices'):
      copy_tensor()

    # Setting None makes the policy read back as 'silent'.
    config.set_device_policy(None)
    self.assertEqual(config.get_device_policy(), 'silent')
Example #5
0
    def testExecutionMode(self):
        """Synchronous execution toggles the context between SYNC and ASYNC.

        Checks the initial mode is synchronous, that the setting can be
        changed before any op runs (alongside other init-time configs), and
        that ``config`` and the eager context agree after each toggle.
        """

        def assert_mode(synchronous):
            # config and the eager context must agree on the current mode.
            reported = config.get_synchronous_execution()
            if synchronous:
                self.assertTrue(reported)
                expected = context.SYNC
            else:
                self.assertFalse(reported)
                expected = context.ASYNC
            self.assertEqual(expected, context.context().execution_mode)

        assert_mode(True)

        # No op has executed yet, so the execution mode may be changed
        # interleaved with other init-time settings.
        config.set_intra_op_parallelism_threads(1)
        config.set_synchronous_execution(False)
        config.set_intra_op_parallelism_threads(2)

        for synchronous in (True, False):
            config.set_synchronous_execution(synchronous)
            assert_mode(synchronous)
Example #6
0
  def testExecutionMode(self):
    """Synchronous-execution flag and eager execution mode stay in sync."""
    expected_modes = {True: context.SYNC, False: context.ASYNC}

    def verify(synchronous):
      # Check the config flag's truthiness, then the context's mode enum.
      truth_check = self.assertTrue if synchronous else self.assertFalse
      truth_check(config.get_synchronous_execution())
      self.assertEqual(expected_modes[synchronous],
                       context.context().execution_mode)

    # Initially the mode is synchronous.
    verify(True)

    # Before any op has executed, the execution mode can be changed along
    # with other init-time configs.
    config.set_intra_op_parallelism_threads(1)
    config.set_synchronous_execution(False)
    config.set_intra_op_parallelism_threads(2)

    config.set_synchronous_execution(True)
    verify(True)
    config.set_synchronous_execution(False)
    verify(False)
Example #7
0
    def testIntraOpParallelismThreads(self):
        """Checks the intra-op thread setting before and after context init.

        The value set through ``config`` must match what the eager context
        reports. After ``context.ensure_initialized()``, requesting a
        different value raises ``RuntimeError``, while setting the current
        value (10) again is accepted without error.
        """
        config.set_intra_op_parallelism_threads(10)
        self.assertEqual(config.get_intra_op_parallelism_threads(),
                         context.context().intra_op_parallelism_threads)

        # Initialize the context explicitly so init-time settings are locked.
        context.ensure_initialized()

        # A different thread count is rejected once the context exists.
        with self.assertRaises(RuntimeError):
            config.set_intra_op_parallelism_threads(1)

        # Re-applying the already-active value must not raise.
        config.set_intra_op_parallelism_threads(10)