def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
  # On GPU we use dlpack to avoid copies of data to the host.
  def _arg_jax_to_tf(arg_jax):
    if (isinstance(arg_jax, xla.DeviceArray) and
        arg_jax.device_buffer.client.platform in _DLPACK_PLATFORMS and
        arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
      arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
      return tf.experimental.dlpack.from_dlpack(arg_dlpack)
    # The following avoids copies to the host on CPU, always for DeviceArray
    # and even for ndarray if they are sufficiently aligned.
    # TODO(necula): on TPU this copies to the host!
    return tf.constant(np.asarray(arg_jax))

  args_tf_flat = tuple(map(_arg_jax_to_tf, args_jax_flat))
  with jax2tf_internal.inside_call_tf():
    # Call in TF eager mode
    res_tf_flat = callable_flat_tf(*args_tf_flat)

  def _res_tf_to_jax(res_tf: TfVal):
    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
      res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
      res_jax_platform = res_tf_platform.lower()
      if res_jax_platform in _DLPACK_PLATFORMS:
        res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
        return jax.dlpack.from_dlpack(res_dlpack)

    return jax.device_put(np.asarray(res_tf))

  return list(map(_res_tf_to_jax, res_tf_flat))
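# A minimal, hypothetical sketch (not part of this module) of the DLPack
# round-trip that _call_tf_impl relies on to move buffers between JAX and TF
# without copying through the host, when both sides support DLPack for the
# buffer's device and dtype. The helper name and the float32 input are
# illustrative only.
def _dlpack_roundtrip_sketch():
  import jax
  import jax.dlpack
  import jax.numpy as jnp
  import tensorflow as tf

  x_jax = jnp.arange(8, dtype=jnp.float32)
  # JAX -> TF: export a DLPack capsule and import it into TF.
  x_tf = tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
  # TF -> JAX: the reverse path, as done in _res_tf_to_jax above.
  y_jax = jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(x_tf))
  return y_jax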
def _jax2tf_call_tf(*args: TfVal,
                    callable_flat_tf: Callable,
                    **_) -> TfVal:
  with jax2tf_internal.inside_call_tf():
    res_tf_flat = callable_flat_tf(*args)
  return res_tf_flat
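# Hypothetical usage sketch (not part of this module): _jax2tf_call_tf is the
# jax2tf lowering of the call_tf primitive, i.e., what runs when a
# call_tf-wrapped function is converted back to TF with jax2tf.convert.
# The function names below are illustrative.
def _call_tf_roundtrip_sketch():
  import tensorflow as tf
  from jax.experimental import jax2tf

  def tf_cos(x):
    return tf.math.cos(x)

  jax_cos = jax2tf.call_tf(tf_cos)        # TF function callable from JAX
  tf_cos_again = jax2tf.convert(jax_cos)  # back to TF; exercises _jax2tf_call_tf
  return tf_cos_again(tf.constant(0.0))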
def _code_generator_and_avals(
    function_flat_tf,
    args_flat_sig_tf,
    code_gen_optional=False
) -> Tuple[Optional[Callable[[mlir.ModuleContext, Sequence[ir.Value]],
                             Sequence[ir.Value]]],
           Sequence[core.ShapedArray]]:
  # Returns and caches a code generator (taking a builder and the
  # XlaOps for the arguments) and a sequence of result abstract shapes.

  # It turns out that both for abstract evaluation and for actual compilation
  # it is useful to actually generate the HLO. This is true because in some
  # cases just TF-level shape inference is not precise enough to recover the
  # output shapes (e.g., b/128924522), even in situations where XLA can compile
  # the code, from which we can get the shapes.

  # Due to bugs like b/193754660, the compilation may fail. To work around this
  # issue we set `code_gen_optional` when in an abstract evaluation context,
  # in which case we fall back on TF shape inference. Luckily, it seems that we
  # are never under a tf.function when we call the XLA translation rule for
  # call_tf: the latter happens only for jax.jit, and a jax.jit under a
  # tf.function must be under jax2tf.convert, which unrolls the jit.

  # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
  # We know of one case when TF is sensitive to the values of the tensors that
  # affect shapes in the computation. In those cases, however, those tensors
  # are inlined in the computation, which we detect below.
  args_tf_flat = [
      tf.constant((0 if a.dtype != tf.bool else False),
                  shape=a.shape,
                  dtype=a.dtype) for a in args_flat_sig_tf]

  # TODO(necula): We should use the proper device, because in some cases we
  # generate different HLO for different devices.
  # One example is when the code refers to variables on one device. Or, for
  # sharding annotations (only supported on TPU).
  # For now we just use the default device, but ideally we should pass the
  # intended platform in. The problem is that we want to reuse and cache this
  # function across abstract_eval and XLA translation, but for abstract_eval
  # we do not know the platform.
  tf_device_name = f"/device:{jax.default_backend().upper()}:0"
  with jax2tf_internal.inside_call_tf():
    concrete_function_flat_tf = function_flat_tf.get_concrete_function(*args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}")
    logging.warning(msg)
    for inp in concrete_function_flat_tf.captured_inputs:
      if inp.dtype == tf.resource:  # A variable; look it up by handle.
        inp_vars = [v for v in concrete_function_flat_tf.variables if inp is v.handle]
        assert len(inp_vars) == 1, f"Found {inp_vars}"
        captured_inputs.append(inp_vars[0])
      else:
        captured_inputs.append(inp)

  with jax2tf_internal.inside_call_tf():
    # The above has traced the function and in fact has cached a ConcreteFunction.
    # Grab it now, so that we don't have to construct `args_tf_flat` only to
    # get a cache hit.
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(*args_tf_flat)(
          stage="hlo_serialized", device_name=tf_device_name)
    except Exception as e:
      # TODO(b/193754660): This is a workaround. Use a more robust mechanism
      # instead of relying on error messages.
      # Check two different error messages, to ensure the code works internally
      # (with "out of scope") and also in OSS (with "An op outside ...").
      if type(e) is TypeError and (
          "out of scope" in str(e) or
          "An op outside of the function building code" in str(e)):
        # TODO(b/193754660): this may happen if we are in a function context.
        # Try to salvage the situation if we are just doing abstract_eval, maybe
        # for jax2tf.convert. We can do that if all the output_shapes are known.
        def is_fully_known_shape(s):
          return s.rank is not None and all([d is not None for d in s])
        if code_gen_optional and (
            all([is_fully_known_shape(s)
                 for s in concrete_function_flat_tf.output_shapes])):
          result_avals = [
              # We convert to a JAX type, and canonicalize to 32-bit if necessary.
              core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
              for dtype, shape in zip(concrete_function_flat_tf.output_dtypes,
                                      concrete_function_flat_tf.output_shapes)]
          return None, result_avals

      msg = ("Error compiling TensorFlow function. call_tf can be used "
             "in a staged context (under jax.jit, lax.scan, etc.) only with "
             "compilable functions with static output shapes. "
             "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
      raise ValueError(msg) from e

  xla_comp = xla_client.XlaComputation(func_tf_hlo)

  # Check that the function does not have compile-time constant inputs that
  # have been inlined in the compiled code.
  xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
  found_parameter_avals = [
      core.ShapedArray(found_xla_shape.dimensions(),
                       dtypes.canonicalize_dtype(found_xla_shape.numpy_dtype()))
      for found_xla_shape in xla_comp_parameter_shapes
  ]
  # Add the captured_inputs to args_flat_sig_tf.
  expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
  expected_parameter_avals = [
      core.ShapedArray(tuple(arg_sig.shape.as_list()),
                       dtypes.canonicalize_dtype(arg_sig.dtype.as_numpy_dtype))
      for arg_sig in expected_args_flat_sig_tf]
  if found_parameter_avals != expected_parameter_avals:
    msg = ("Compiled TensorFlow function has unexpected parameter types "
           f"{found_parameter_avals}, while the expected types are "
           f"{expected_parameter_avals}. Perhaps the TensorFlow function "
           "has shape-influencing inputs, and thus needs to be recompiled "
           "for each value of some inputs. "
           "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
    raise ValueError(msg)

  # Canonicalize the results; e.g., make them x32 if JAX is in 32-bit mode.
  def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
    if not res_shape.is_static():
      msg = ("Compiled TensorFlow function has dynamic output shape "
             f"{res_shape}. call_tf can be used "
             "in a staged context (under jax.jit, lax.scan, etc.) only with "
             "compilable functions with static output shapes. "
             "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
      raise ValueError(msg)

    res_dtype = res_shape.numpy_dtype()
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

  result_shape = xla_comp.program_shape().result_shape()
  if not result_shape.is_tuple():
    # TF does not wrap singletons as tuples, but JAX expects tuples because
    # call_tf is a multiple_results primitive.
    result_shapes = (result_shape,)
  else:
    result_shapes = result_shape.tuple_shapes()  # type: ignore

  result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore

  def code_gen(ctx: mlir.ModuleContext,
               args_op: Sequence[ir.Value]) -> Sequence[ir.Value]:
    captured_ops = tuple(mlir.ir_constant(np.asarray(inp),
                                          canonicalize_types=False)
                         for inp in captured_inputs)
    submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
    symtab = ir.SymbolTable(submodule.operation)
    callee_result_types = symtab["main"].type.results
    fn = mlir.merge_mhlo_modules(ctx.module,
                                 f"call_tf_{function_flat_tf.name}",
                                 submodule)
    call = func_dialect.CallOp(callee_result_types,
                               ir.FlatSymbolRefAttr.get(fn),
                               tuple(args_op) + captured_ops)
    if result_shape.is_tuple():
      flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
                      for i in range(len(result_shapes))]
    else:
      flat_results = call.results

    outputs = []
    for op, res_aval, res_shape in zip(flat_results, result_avals,
                                       result_shapes):
      if res_aval.dtype != res_shape.numpy_dtype():
        op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
      outputs.append(op)
    return outputs

  return code_gen, result_avals
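# Illustrative sketch (hypothetical, not part of this module) of the dtype
# canonicalization performed in canonical_res_aval above: with the default
# jax_enable_x64=False, 64-bit result dtypes are narrowed to JAX's 32-bit
# defaults before building the result avals.
def _canonicalize_result_aval_sketch():
  import numpy as np
  from jax import core, dtypes

  res_dtype = np.dtype(np.float64)                      # dtype reported by the compiled HLO
  jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)  # float32 unless x64 is enabled
  return core.ShapedArray((2, 3), jax_res_dtype)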