def testRequestNotToCompile(self):
  """Checks that `_XlaCompile: False` lets a device pin inside `f` take effect.

  `f` places its multiplication on 'device:CPU:0'. Run through
  `def_function.function` (treated here as wholly compiled), both outputs are
  expected on `self.device`. Built with `_XlaCompile` set to False and run
  op-by-op, the pinned output `y` is expected on 'device:CPU:0' instead.
  """
  with self.test_scope():

    def f(x):
      with ops.device('device:CPU:0'):
        y = 2.0 * x
      return x, y

    wholly_compiled_f = def_function.function(f)
    op_by_op_f = function.defun_with_attributes(
        f, attributes={'_XlaCompile': False})

    x = constant_op.constant([0.0, 2.0], name='data')

    # When function is wholly compiled, all outputs will be on the
    # device on which it is run.
    r_x, r_y = wholly_compiled_f(x)
    self.assertAllEqual([0.0, 2.0], r_x)
    self.assertAllEqual([0.0, 4.0], r_y)
    if context.executing_eagerly():
      # backing_device is only available for eager tensors.
      # assertRegex replaces the assertRegexpMatches alias, which was
      # deprecated since Python 3.2 and removed in Python 3.12.
      self.assertRegex(r_x.backing_device, self.device)
      self.assertRegex(r_y.backing_device, self.device)

    # When function is executed op-by-op, requested devices will be
    # respected.
    r_x, r_y = op_by_op_f(x)
    self.assertAllEqual([0.0, 2.0], r_x)
    self.assertAllEqual([0.0, 4.0], r_y)
    if context.executing_eagerly():
      # backing_device is only available for eager tensors.
      self.assertRegex(r_x.backing_device, self.device)
      self.assertRegex(r_y.backing_device, 'device:CPU:0')
def _generate_defun_backend(unique_api_name, preferred_device, func):
  """Wraps `func` with `function.defun_with_attributes`, tagging two attributes.

  Args:
    unique_api_name: value stored under the `_DEFUN_API_NAME_ATTRIBUTE` key.
    preferred_device: value stored under the `_DEFUN_DEVICE_ATTRIBUTE` key.
    func: the Python callable to wrap.

  Returns:
    Whatever `function.defun_with_attributes` returns for `func` and the
    attribute mapping above.
  """
  attrs = {_DEFUN_API_NAME_ATTRIBUTE: unique_api_name}
  attrs[_DEFUN_DEVICE_ATTRIBUTE] = preferred_device
  return function.defun_with_attributes(func=func, attributes=attrs)
def _defun(self, fn):
  """Returns a defun generated from the input function.

  The attributes handed to `function_lib.defun_with_attributes` are:
    * None when `self._experimental_compile` is None;
    * {"_XlaCompile": bool(self._experimental_compile)} otherwise.

  Args:
    fn: the Python function to wrap.

  Returns:
    The object produced by `function_lib.defun_with_attributes`.
  """
  compile_flag = self._experimental_compile
  attributes = (None if compile_flag is None
                else {"_XlaCompile": bool(compile_flag)})
  # NOTE(review): relax-shapes is read from the public
  # `experimental_relax_shapes`, unlike the private attributes used for the
  # other options — confirm that is intentional.
  return function_lib.defun_with_attributes(
      fn,
      input_signature=self.input_signature,
      attributes=attributes,
      autograph=self._autograph,
      experimental_autograph_options=self._experimental_autograph_options,
      experimental_relax_shapes=self.experimental_relax_shapes)
def _defun(self, fn):
  """Returns a defun generated from the input function.

  Builds an attribute mapping from two optional settings:
    * `self._implements`, when not None, is stored under
      `function_lib.IMPLEMENTS_ATTRIBUTE_NAME`;
    * `self._experimental_compile`, when not None, is stored as a bool under
      "_XlaCompile".
  If neither is set, None is passed as the attributes.

  Args:
    fn: the Python function to wrap.

  Returns:
    The object produced by `function_lib.defun_with_attributes`.
  """
  attrs = {}

  implements = self._implements
  if implements is not None:
    attrs[function_lib.IMPLEMENTS_ATTRIBUTE_NAME] = implements

  compile_flag = self._experimental_compile
  if compile_flag is not None:
    attrs["_XlaCompile"] = bool(compile_flag)

  return function_lib.defun_with_attributes(
      fn,
      input_signature=self.input_signature,
      attributes=attrs or None,  # An empty mapping is passed as None.
      autograph=self._autograph,
      experimental_autograph_options=self._experimental_autograph_options,
      experimental_relax_shapes=self._experimental_relax_shapes)