def testRequestNotToCompile(self):
    """Checks that disabling `_XlaCompile` lets requested devices stand.

    `f` pins its multiplication to the CPU. Traced by `def_function.function`
    inside `self.test_scope()`, both outputs are asserted to live on
    `self.device`; wrapped by `function.defun_with_attributes` with
    `_XlaCompile=False`, the CPU-pinned output is asserted to stay on the CPU.
    """
    cpu_device = 'device:CPU:0'

    with self.test_scope():
      def f(x):
        with ops.device(cpu_device):
          y = 2.0 * x
        return x, y

      wholly_compiled_f = def_function.function(f)
      op_by_op_f = function.defun_with_attributes(
          f, attributes={'_XlaCompile': False})

      x = constant_op.constant([0.0, 2.0], name='data')

      # When function is wholly compiled, all outputs will be on the
      # device on which it is run.
      r_x, r_y = wholly_compiled_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        # assertRegex replaces the deprecated assertRegexpMatches alias,
        # which was removed from unittest.TestCase in Python 3.12.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, self.device)

      # When function is executed op-by-op, requested devices will be
      # respected.
      r_x, r_y = op_by_op_f(x)
      self.assertAllEqual([0.0, 2.0], r_x)
      self.assertAllEqual([0.0, 4.0], r_y)
      if context.executing_eagerly():
        # backing_device is only available for eager tensors.
        self.assertRegex(r_x.backing_device, self.device)
        self.assertRegex(r_y.backing_device, cpu_device)
def _generate_defun_backend(unique_api_name, preferred_device, func):
  """Wrap `func` in a defun tagged with an API name and a preferred device.

  Args:
    unique_api_name: value stored under `_DEFUN_API_NAME_ATTRIBUTE`.
    preferred_device: value stored under `_DEFUN_DEVICE_ATTRIBUTE`.
    func: the Python function passed to `function.defun_with_attributes`.

  Returns:
    Whatever `function.defun_with_attributes` returns for `func`.
  """
  backend_attrs = {}
  backend_attrs[_DEFUN_API_NAME_ATTRIBUTE] = unique_api_name
  backend_attrs[_DEFUN_DEVICE_ATTRIBUTE] = preferred_device
  return function.defun_with_attributes(
      func=func,
      attributes=backend_attrs,
  )
Example #3
0
 def _defun(self, fn):
     """Build the defun that backs this function object.

     No attributes are attached when `self._experimental_compile` is None;
     otherwise `_XlaCompile` is set to the truthiness of that setting.

     Args:
       fn: the Python callable to wrap.

     Returns:
       The result of `function_lib.defun_with_attributes` for `fn`.
     """
     compile_setting = self._experimental_compile
     attributes = (None if compile_setting is None
                   else {"_XlaCompile": bool(compile_setting)})
     # NOTE(review): this reads the public `experimental_relax_shapes`, unlike
     # the private attributes used above -- confirm the enclosing class
     # defines it under this name.
     return function_lib.defun_with_attributes(
         fn,
         input_signature=self.input_signature,
         attributes=attributes,
         autograph=self._autograph,
         experimental_autograph_options=self._experimental_autograph_options,
         experimental_relax_shapes=self.experimental_relax_shapes)
Example #4
0
 def _defun(self, fn):
   """Build the defun that backs this function object.

   Attaches `function_lib.IMPLEMENTS_ATTRIBUTE_NAME` when `self._implements`
   is set and `_XlaCompile` (the truthiness of `self._experimental_compile`)
   when that setting is not None. If neither applies, `attributes=None` is
   passed instead of an empty dict.

   Args:
     fn: the Python callable to wrap.

   Returns:
     The result of `function_lib.defun_with_attributes` for `fn`.
   """
   pending = []
   if self._implements is not None:
     pending.append(
         (function_lib.IMPLEMENTS_ATTRIBUTE_NAME, self._implements))
   if self._experimental_compile is not None:
     pending.append(("_XlaCompile", bool(self._experimental_compile)))
   # An empty mapping collapses to None, matching "no attributes requested".
   attributes = dict(pending) or None
   return function_lib.defun_with_attributes(
       fn,
       input_signature=self.input_signature,
       attributes=attributes,
       autograph=self._autograph,
       experimental_autograph_options=self._experimental_autograph_options,
       experimental_relax_shapes=self._experimental_relax_shapes)