def _concrete_function_and_xla_comp(
    function_flat_tf,
    args_flat_sig_tf) -> Tuple[TfConcreteFunction, xla_client.XlaComputation]:
  """Traces `function_flat_tf` and lowers it to serialized (unoptimized) HLO.

  Args:
    function_flat_tf: a TF function taking flat arguments; it must support
      `experimental_get_compiler_ir` and `get_concrete_function`.
    args_flat_sig_tf: flat TF argument signatures, each with `.shape` and
      `.dtype`.

  Returns:
    A pair of the concrete function traced for placeholder arguments matching
    `args_flat_sig_tf`, and the `XlaComputation` built from its HLO.

  Raises:
    ValueError: if TensorFlow fails to produce the compiler IR.
  """

  # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
  def _placeholder(sig):
    # An all-zeros (all-False for booleans) tensor matching `sig`.
    fill_value = False if sig.dtype == tf.bool else 0
    return tf.constant(fill_value, shape=sig.shape, dtype=sig.dtype)

  placeholder_args = list(map(_placeholder, args_flat_sig_tf))

  # TODO(necula): For unoptimized HLO, does it make a difference which device we use?
  device = "/device:CPU:0"
  with jax2tf_internal.inside_call_tf():
    try:
      hlo_serialized = function_flat_tf.experimental_get_compiler_ir(
          *placeholder_args)(stage="hlo_serialized", device_name=device)
    except Exception as err:
      raise ValueError(
          "Error compiling TensorFlow function. call_tf can used "
          "in a staged context (under jax.jit, lax.scan, etc.) only with "
          "compileable functions.") from err
    # Producing the compiler IR traced the function and cached a
    # ConcreteFunction for these same arguments, so this lookup is a cache
    # hit rather than a fresh trace.
    concrete_fn = function_flat_tf.get_concrete_function(*placeholder_args)
  return concrete_fn, xla_client.XlaComputation(hlo_serialized)
def _call_tf_translation_rule(builder, *args_op, func_tf, func_tf_concrete,
                              args_treedef, args_tf_sig_flat, out_avals, **_):
  """XLA translation rule for `call_tf`.

  Compiles `func_tf` to HLO and emits a call to it in `builder`. The call
  receives `args_op` followed by the captured inputs of `func_tf_concrete`,
  which are embedded as constants.

  Args:
    builder: the XLA builder in which the call is emitted.
    *args_op: XLA ops for the flat arguments.
    func_tf: the TF callable to call. It is wrapped in
      `tf.function(jit_compile=True)` here.
    func_tf_concrete: a concrete function for `func_tf`. Its captured inputs
      (tensors, or variables for resource inputs) are passed as constants.
    args_treedef: treedef used to unflatten the placeholder arguments.
    args_tf_sig_flat: flat TF argument signatures (with `.shape`, `.dtype`).
    out_avals: JAX abstract values of the results.
    **_: remaining primitive parameters, ignored.

  Returns:
    An XLA tuple op with one element per entry of `out_avals`. An element is
    converted to its `out_aval.dtype` when the HLO result dtype differs.
  """
  # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
  args_tf_flat = [
      tf.constant((0 if a.dtype != tf.bool else False),
                  shape=a.shape, dtype=a.dtype)
      for a in args_tf_sig_flat
  ]
  args_tf = args_treedef.unflatten(args_tf_flat)
  func_tf_jit = tf.function(func_tf, jit_compile=True)

  captured_inputs = func_tf_concrete.captured_inputs
  if captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax for a discussion. "
        f"The following captures were found {captured_inputs}"
    )
    logging.warning(msg)

  captured_ops = []  # Same order as captured_inputs
  next_var_idx = 0
  for inp in captured_inputs:
    if inp.dtype == tf.resource:  # A variable; assume the next variable
      assert next_var_idx < len(func_tf_concrete.variables)
      var = func_tf_concrete.variables[next_var_idx]
      next_var_idx += 1
      # Make sure the resource input really is the handle of the variable we
      # picked; otherwise we would embed the wrong value.
      assert inp is var.handle
      inp_const = np.asarray(var)
    else:
      inp_const = np.asarray(inp)
    captured_ops.append(xops.ConstantLiteral(builder, inp_const))

  # TODO(necula): For unoptimized HLO, does it make a difference which device we use?
  tf_device_name = "/device:CPU:0"
  func_tf_hlo = func_tf_jit.experimental_get_compiler_ir(*args_tf)(
      stage="hlo_serialized", device_name=tf_device_name)
  callee_xla_comp = xla_client.XlaComputation(func_tf_hlo)
  res_tf = xops.Call(builder, callee_xla_comp, args_op + tuple(captured_ops))
  if len(out_avals) == 1:
    # TF does not wrap singletons as tuples, but JAX expects tuples because
    # call_tf is a multiple_results primitive.
    res_untupled = (res_tf,)
  else:
    res_untupled = tuple(
        xops.GetTupleElement(res_tf, i) for i in range(len(out_avals)))

  # We may have to cast the results to x32 for JAX
  def canonicalize_res(res, out_aval: core.AbstractValue):
    res_dtype = builder.get_shape(res).numpy_dtype()
    if res_dtype != out_aval.dtype:
      new_etype = xla_client.dtype_to_etype(out_aval.dtype)
      return xops.ConvertElementType(res, new_element_type=new_etype)
    return res

  canonical_res_untupled = tuple(map(canonicalize_res, res_untupled, out_avals))
  return xops.Tuple(builder, canonical_res_untupled)
def unpack_xla_computation(any_pb):
  """Extracts an `XlaComputation` from an `Any` proto wrapping an HLO module.

  Args:
    any_pb: a `google.protobuf.Any` message expected to carry a serialized
      `HloModuleProto`.

  Returns:
    An `xla_client.XlaComputation` constructed from `any_pb.value`.

  Raises:
    TypeError: if `any_pb` is not an `Any` protocol buffer message.
    ValueError: if the `type_url` of `any_pb` is not the `HloModuleProto`
      type URI, i.e. the packed object cannot be unpacked.
  """
  py_typecheck.check_type(any_pb, any_pb2.Any)
  type_url = any_pb.type_url
  if type_url == _HLO_MODULE_PROTO_URI:
    return xla_client.XlaComputation(any_pb.value)
  raise ValueError(f'Not a serialized `HloModuleProto`: {type_url!s}.')
def _code_generator_and_avals(
    function_flat_tf,
    args_flat_sig_tf,
    code_gen_optional=False
) -> Tuple[Optional[Callable[[xla.XlaComputationBuilder, Sequence[xla.XlaOp]],
                             xla.XlaOp]],
           Sequence[core.ShapedArray]]:
  """Returns a code generator and the result avals for `function_flat_tf`.

  The code generator takes an XLA builder and the XlaOps for the flat
  arguments. It emits a call to the HLO compiled from `function_flat_tf`,
  passing the function's captured inputs as trailing constants. It returns an
  XLA tuple of the results, each converted to the canonical JAX dtype (e.g.,
  32-bit when JAX runs in 32-bit mode).

  Args:
    function_flat_tf: a TF function taking flat arguments; it must support
      `get_concrete_function` and `experimental_get_compiler_ir`.
    args_flat_sig_tf: flat TF argument signatures (with `.shape`, `.dtype`).
    code_gen_optional: if True, compilation may fail with the known
      "An op outside of the function building code" TypeError. When that
      happens and all TF output shapes are fully known, the result avals come
      from TF shape inference and the returned code generator is `None`.

  Returns:
    A pair `(code_gen, result_avals)`. `code_gen` is `None` only in the
    fallback case described for `code_gen_optional`.

  Raises:
    ValueError: in two cases. First, if the function cannot be compiled and
      no fallback applies. Second, if the compiled HLO has parameter types
      other than those of the arguments plus captured inputs (e.g.,
      shape-influencing inputs were inlined).
  """
  # Returns and caches a code generator (taking a builder and the
  # XlaOps for the arguments) and a sequence of result abstract shapes.

  # It turns out that both for abstract evaluation and for actual compilation
  # it is useful to actually generate the HLO. This is true because in some
  # cases just TF-level shape inference is not precise enough to recover the
  # output shapes (e.g., b/128924522), even in situations where XLA can compile
  # the code, from which we can get the shapes.

  # Due to bugs like b/193754660, the compilation may fail. To work around this
  # issue we pass the `code_gen_optional` when in an abstract evaluation context
  # in which case we fallback on TF shape inference. Luckily it seems that
  # it is never the case that we are under tf.function, and we call the
  # XLA translation rule for call_tf. The latter happens only for jax.jit, but
  # jax.jit under a tf.function must be under jax2tf.convert, which unrolls
  # the jit.

  # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
  # We know of one case when TF is sensitive to the values of the tensors that
  # affect shapes in the computation. In those cases, however, those tensors
  # are inlined in the computation, which we detect below.
  args_tf_flat = [
      tf.constant((0 if a.dtype != tf.bool else False),
                  shape=a.shape, dtype=a.dtype)
      for a in args_flat_sig_tf
  ]

  # TODO(necula): We should use the proper device, because in some cases we
  # generate different HLO for different devices.
  # One example is when the code refers to variables on one device. Or, for
  # sharding annotations (only supported on TPU).
  # For now we just use the default device, but ideally we should pass the
  # intended platform in. The problem is that we want to reuse and cache this
  # function across abstract_eval and XLA translation, but for abstract_eval
  # we do not know the platform.
  tf_device_name = f"/device:{jax.default_backend().upper()}:0"
  with jax2tf_internal.inside_call_tf():
    concrete_function_flat_tf = function_flat_tf.get_concrete_function(
        *args_flat_sig_tf)

  captured_inputs = []
  if concrete_function_flat_tf.captured_inputs:
    # The function uses either captured variables or tensors.
    msg = (
        "call_tf works best with a TensorFlow function that does not capture "
        "variables or tensors from the context. "
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion. "
        f"The following captures were found {concrete_function_flat_tf.captured_inputs}"
    )
    logging.warning(msg)
  next_var_idx = 0
  for inp in concrete_function_flat_tf.captured_inputs:
    if inp.dtype == tf.resource:  # A variable; assume the next variable
      assert next_var_idx < len(concrete_function_flat_tf.variables)
      var = concrete_function_flat_tf.variables[next_var_idx]
      next_var_idx += 1
      assert inp is var.handle  # For extra safety
      captured_inputs.append(var)
    else:
      captured_inputs.append(inp)

  with jax2tf_internal.inside_call_tf():
    try:
      func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
          *args_tf_flat)(stage="hlo_serialized", device_name=tf_device_name)
    except Exception as e:
      if (isinstance(e, TypeError) and
          "An op outside of the function building code" in str(e)):
        # TODO(b/193754660): this may happen if we are in a function context
        # Try to salvage the situation if we are just doing abstract_eval, maybe
        # for jax2tf.convert. We can do that if all the output_shapes are known.
        def is_fully_known_shape(s):
          return s.rank is not None and all(d is not None for d in s)

        if code_gen_optional and all(
            is_fully_known_shape(s)
            for s in concrete_function_flat_tf.output_shapes):
          # A tuple, like the main path below, so that a cached result
          # cannot be mutated by callers.
          result_avals = tuple(
              # We convert to JAX type, and canonicalize to 32-bit if necessary
              core.ShapedArray(shape, jax2tf_internal._to_jax_dtype(dtype))
              for dtype, shape in zip(
                  concrete_function_flat_tf.output_dtypes,
                  concrete_function_flat_tf.output_shapes))
          return None, result_avals
      msg = (
          "Error compiling TensorFlow function. call_tf can used " +
          "in a staged context (under jax.jit, lax.scan, etc.) only with " +
          "compileable functions. " +
          "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion."
      )
      raise ValueError(msg) from e

  xla_comp = xla_client.XlaComputation(func_tf_hlo)
  # Check that the function does not have compile-time constant inputs that
  # have been inlined in the compiled code.
  xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
  found_parameter_avals = [
      core.ShapedArray(found_xla_shape.dimensions(),
                       found_xla_shape.numpy_dtype())
      for found_xla_shape in xla_comp_parameter_shapes
  ]
  # The compiled HLO takes the captured inputs as trailing parameters.
  expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
  expected_parameter_avals = [
      core.ShapedArray(tuple(arg_sig.shape.as_list()),
                       arg_sig.dtype.as_numpy_dtype)
      for arg_sig in expected_args_flat_sig_tf
  ]
  if found_parameter_avals != expected_parameter_avals:
    msg = (
        "Compiled TensorFlow function has unexpected parameter types " +
        f"{found_parameter_avals}, while the expected types are " +
        f"{expected_parameter_avals}. Perhaps the TensorFlow function " +
        "has shape-influencing inputs, and thus needs to be recompiled " +
        "for each value of some inputs. " +
        "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion."
    )
    raise ValueError(msg)

  # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
  def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
    res_dtype = res_shape.numpy_dtype()
    jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
    return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

  result_shape = xla_comp.program_shape().result_shape()
  if not result_shape.is_tuple():
    # TF does not wrap singletons as tuples, but JAX expects tuples because
    # call_tf is a multiple_results primitive.
    result_shapes = (result_shape,)
  else:
    result_shapes = result_shape.tuple_shapes()  # type: ignore

  result_avals = tuple(map(canonical_res_aval, result_shapes))  # type: ignore

  def code_gen(builder: xla.XlaComputationBuilder,
               args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
    # Captured values (including variable values) are read here, at code
    # generation time, and embedded as constants.
    captured_ops = tuple(
        xops.ConstantLiteral(builder, np.asarray(inp))
        for inp in captured_inputs)
    # `args_op` may be any sequence (e.g., a list), so convert it to a tuple
    # before appending the captured constants.
    res_tf = xops.Call(builder, xla_comp, tuple(args_op) + captured_ops)  # type: ignore

    def post_process_result(idx: int, res_aval: core.ShapedArray,
                            res_shape: xla.XlaShape):
      res_op = res_tf
      if result_shape.is_tuple():
        res_op = xops.GetTupleElement(res_tf, idx)
      if res_aval.dtype != res_shape.numpy_dtype():
        res_op = xops.ConvertElementType(
            res_op,
            new_element_type=xla_client.dtype_to_etype(res_aval.dtype))
      return res_op

    results = [
        post_process_result(i, res_aval, res_shape)
        for i, (res_aval, res_shape) in enumerate(
            zip(result_avals, result_shapes))
    ]
    return xops.Tuple(builder, results)

  return code_gen, result_avals