Example No. 1
def _concrete_function_and_xla_comp(
        function_flat_tf, args_flat_sig_tf
) -> Tuple[TfConcreteFunction, xla_client.XlaComputation]:
    # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
    args_tf_flat = [
        tf.constant((0 if a.dtype != tf.bool else False),
                    shape=a.shape,
                    dtype=a.dtype) for a in args_flat_sig_tf
    ]

    # TODO(necula): For unoptimized HLO, does it make a difference which device we use?
    tf_device_name = "/device:CPU:0"
    with jax2tf_internal.inside_call_tf():
        try:
            func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
                *args_tf_flat)(stage="hlo_serialized",
                               device_name=tf_device_name)
        except Exception as e:
            msg = (
                "Error compiling TensorFlow function. call_tf can be used "
                "in a staged context (under jax.jit, lax.scan, etc.) only with "
                "compilable functions.")
            raise ValueError(msg) from e

        # The call above has traced the function and cached a ConcreteFunction.
        # Grab it now, while `args_tf_flat` is still at hand, so that we do not
        # have to reconstruct the arguments later just to get a cache hit.
        concrete_function_flat_tf = function_flat_tf.get_concrete_function(
            *args_tf_flat)
    return concrete_function_flat_tf, xla_client.XlaComputation(func_tf_hlo)
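A minimal usage sketch for this helper, assuming it is called from within the same module; the function `add_one` and the signature list below are illustrative, not taken from the original source:

import tensorflow as tf

def add_one(x):
    return x + 1.0

# Hypothetical inputs: a flattened tf.function and the flat TensorSpec signature
# of its arguments, matching what the helper expects.
function_flat_tf = tf.function(add_one, autograph=False)
args_flat_sig_tf = [tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]

concrete_fn, xla_comp = _concrete_function_and_xla_comp(
    function_flat_tf, args_flat_sig_tf)
print(xla_comp.as_hlo_text())  # Human-readable HLO of the compiled function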
Example No. 2
def _call_tf_translation_rule(builder, *args_op, func_tf, func_tf_concrete,
                              args_treedef, args_tf_sig_flat, out_avals, **_):
    # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
    args_tf_flat = [
        tf.constant((0 if a.dtype != tf.bool else False),
                    shape=a.shape,
                    dtype=a.dtype) for a in args_tf_sig_flat
    ]
    args_tf = args_treedef.unflatten(args_tf_flat)
    func_tf = tf.function(func_tf, jit_compile=True)
    #func_tf_concrete = func_tf.get_concrete_function(*args_tf)
    captured_ops = []  # Same order as captured_inputs
    if func_tf_concrete.captured_inputs:
        # The function uses either captured variables or tensors.
        msg = (
            "call_tf works best with a TensorFlow function that does not capture "
            "variables or tensors from the context. "
            "See https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax for a discussion. "
            f"The following captures were found {func_tf_concrete.captured_inputs}"
        )
        logging.warning(msg)

        next_var_idx = 0
        for inp in func_tf_concrete.captured_inputs:
            if inp.dtype == tf.resource:  # A variable handle; assume it is the next variable in order
                assert next_var_idx < len(func_tf_concrete.variables)
                # TODO(necula): better checking that we are picking the right variable
                var = func_tf_concrete.variables[next_var_idx]
                next_var_idx += 1
                inp_const = np.asarray(var)
            else:
                inp_const = np.asarray(inp)
            captured_ops.append(
                xops.ConstantLiteral(builder, np.asarray(inp_const)))

    # TODO(necula): For unoptimized HLO, does it make a difference which device we use?
    tf_device_name = "/device:CPU:0"
    func_tf_hlo = func_tf.experimental_get_compiler_ir(*args_tf)(
        stage="hlo_serialized", device_name=tf_device_name)
    callee_xla_comp = xla_client.XlaComputation(func_tf_hlo)
    res_tf = xops.Call(builder, callee_xla_comp, args_op + tuple(captured_ops))
    if len(out_avals) == 1:
        # TF does not wrap singletons as tuples, but JAX expects tuples because
        # call_tf is a multiple_results primitive.
        res_untupled = (res_tf, )
    else:
        res_untupled = tuple(
            xops.GetTupleElement(res_tf, i) for i in range(len(out_avals)))
    # We may have to cast the results to x32 for JAX
    def canonicalize_res(res, out_aval: core.AbstractValue):
        res_dtype = builder.get_shape(res).numpy_dtype()
        if res_dtype != out_aval.dtype:
            new_etype = xla_client.dtype_to_etype(out_aval.dtype)
            return xops.ConvertElementType(res, new_element_type=new_etype)
        else:
            return res

    canonical_res_untupled = tuple(
        map(canonicalize_res, res_untupled, out_avals))
    return xops.Tuple(builder, canonical_res_untupled)
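A hedged sketch of how a caller might prepare the `args_treedef` and `args_tf_sig_flat` parameters used by this rule, assuming the TF function's arguments are described by a pytree of `tf.TensorSpec` objects; `args_tf_sig` below is an illustrative name, not from the source:

import tensorflow as tf
import jax

# Hypothetical pytree of TensorSpecs describing the TF function's arguments.
args_tf_sig = {"x": tf.TensorSpec((3,), tf.float32),
               "y": tf.TensorSpec((3,), tf.float32)}

# Flatten once; the treedef is what the rule uses (via `unflatten`) to rebuild
# the structured arguments before tracing and compiling the TF function.
args_tf_sig_flat, args_treedef = jax.tree_util.tree_flatten(args_tf_sig)
assert args_treedef.unflatten(args_tf_sig_flat) == args_tf_sig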
Example No. 3
def unpack_xla_computation(any_pb):
    """Unpacks an `Any` proto to an `XlaComputation`.

  Args:
    any_pb: An instance of `google.protobuf.Any` to unpack.

  Returns:
    The unpacked instance of `xla_client.XlaComputation`.

  Raises:
    TypeError: if `any_pb` is not an `Any` protocol buffer message.
    ValueError: if the object packed into `any_pb` cannot be unpacked.
  """
    py_typecheck.check_type(any_pb, any_pb2.Any)
    if any_pb.type_url != _HLO_MODULE_PROTO_URI:
        raise ValueError('Not a serialized `HloModuleProto`: {}.'.format(
            str(any_pb.type_url)))
    return xla_client.XlaComputation(any_pb.value)
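A hedged round-trip sketch. The `type_url` below is assumed to match the module's `_HLO_MODULE_PROTO_URI` constant (adjust if the module defines it differently), and `jax.xla_computation` is used only to obtain a serialized `HloModuleProto` for illustration:

import jax
import jax.numpy as jnp
from google.protobuf import any_pb2

# Build some XlaComputation and serialize its HloModuleProto.
comp = jax.xla_computation(lambda x: x + 1.0)(jnp.float32(0.0))
hlo_bytes = comp.as_serialized_hlo_module_proto()

# Pack it into an Any proto and unpack it again.
any_pb = any_pb2.Any(type_url='type.googleapis.com/xla.HloModuleProto',
                     value=hlo_bytes)
unpacked = unpack_xla_computation(any_pb)
print(unpacked.as_hlo_text())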
Example No. 4
def _code_generator_and_avals(
    function_flat_tf,
    args_flat_sig_tf,
    code_gen_optional=False
) -> Tuple[Optional[Callable[[xla.XlaComputationBuilder, Sequence[xla.XlaOp]],
                             xla.XlaOp]], Sequence[core.ShapedArray]]:
    # Returns and caches a code generator (taking a builder and the
    # XlaOps for the arguments) and a sequence of result abstract shapes.

    # It turns out that both for abstract evaluation and for actual compilation
    # it is useful to actually generate the HLO: in some cases TF-level shape
    # inference alone is not precise enough to recover the output shapes
    # (e.g., b/128924522), whereas XLA can compile the code and the shapes can
    # be read off the compiled computation.

    # Due to bugs like b/193754660, the compilation may fail. To work around
    # this we pass `code_gen_optional=True` when in an abstract evaluation
    # context, in which case we fall back on TF shape inference. Luckily, it
    # seems it is never the case that we are under tf.function when we call
    # the XLA translation rule for call_tf: the latter happens only for
    # jax.jit, but jax.jit under a tf.function must be under jax2tf.convert,
    # which unrolls the jit.

    # TODO(necula): It seems that we need concrete tensors for get_compiler_ir?
    # We know of one case when TF is sensitive to the values of the tensors that
    # affect shapes in the computation. In those cases, however, those tensors
    # are inlined in the computation, which we detect below.
    args_tf_flat = [
        tf.constant((0 if a.dtype != tf.bool else False),
                    shape=a.shape,
                    dtype=a.dtype) for a in args_flat_sig_tf
    ]

    # TODO(necula): We should use the proper device, because in some cases we
    # generate different HLO for different devices.
    # One example is when the code refers to variables on one device. Or, for
    # sharding annotations (only supported on TPU).
    # For now we just use the default device, but ideally we should pass the
    # intended platform in. The problem is that we want to reuse and cache this
    # function across abstract_eval and XLA translation, but for abstract_eval
    # we do not know the platform.
    tf_device_name = f"/device:{jax.default_backend().upper()}:0"
    with jax2tf_internal.inside_call_tf():
        concrete_function_flat_tf = function_flat_tf.get_concrete_function(
            *args_flat_sig_tf)

    captured_inputs = []
    if concrete_function_flat_tf.captured_inputs:
        # The function uses either captured variables or tensors.
        msg = (
            "call_tf works best with a TensorFlow function that does not capture "
            "variables or tensors from the context. "
            "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion. "
            f"The following captures were found {concrete_function_flat_tf.captured_inputs}"
        )
        logging.warning(msg)
        next_var_idx = 0
        for inp in concrete_function_flat_tf.captured_inputs:
            if inp.dtype == tf.resource:  # A variable handle; assume it is the next variable in order
                assert next_var_idx < len(concrete_function_flat_tf.variables)
                var = concrete_function_flat_tf.variables[next_var_idx]
                next_var_idx += 1
                assert inp is var.handle  # For extra safety
                captured_inputs.append(var)
            else:
                captured_inputs.append(inp)

    with jax2tf_internal.inside_call_tf():
        # The concrete function has already been obtained above; here we only
        # need the compiler IR, for which `experimental_get_compiler_ir` takes
        # the concrete `args_tf_flat`.
        try:
            func_tf_hlo = function_flat_tf.experimental_get_compiler_ir(
                *args_tf_flat)(stage="hlo_serialized",
                               device_name=tf_device_name)
        except Exception as e:
            if (type(e) is TypeError and
                    "An op outside of the function building code" in str(e)):
                # TODO(b/193754660): this may happen if we are in a function context
                # Try to salvage the situation if we are just doing abstract_eval, maybe
                # for jax2tf.convert. We can do that if all the output_shapes are known.
                def is_fully_known_shape(s):
                    return s.rank is not None and all(
                        d is not None for d in s)

                if code_gen_optional and all(
                        is_fully_known_shape(s)
                        for s in concrete_function_flat_tf.output_shapes):
                    result_avals = [
                        # We convert to JAX type, and canonicalize to 32-bit if necessary
                        core.ShapedArray(shape,
                                         jax2tf_internal._to_jax_dtype(dtype))
                        for dtype, shape in zip(
                            concrete_function_flat_tf.output_dtypes,
                            concrete_function_flat_tf.output_shapes)
                    ]
                    return None, result_avals
            msg = (
                "Error compiling TensorFlow function. call_tf can be used "
                "in a staged context (under jax.jit, lax.scan, etc.) only with "
                "compilable functions. "
                "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion."
            )
            raise ValueError(msg) from e

    xla_comp = xla_client.XlaComputation(func_tf_hlo)
    # Check that the function does not have compile-time constant inputs that
    # have been inlined in the compiled code.
    xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
    found_parameter_avals = [
        core.ShapedArray(found_xla_shape.dimensions(),
                         found_xla_shape.numpy_dtype())
        for found_xla_shape in xla_comp_parameter_shapes
    ]
    # Add the captured_inputs to args_flat_sig_tf
    expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
    expected_parameter_avals = [
        core.ShapedArray(tuple(arg_sig.shape.as_list()),
                         arg_sig.dtype.as_numpy_dtype)
        for arg_sig in expected_args_flat_sig_tf
    ]
    if found_parameter_avals != expected_parameter_avals:
        msg = (
            "Compiled TensorFlow function has unexpected parameter types " +
            f"{found_parameter_avals}, while the expected types are " +
            f"{expected_parameter_avals}. Perhaps the TensorFlow function " +
            "has shape-influencing inputs, and thus needs to be recompiled " +
            "for each value of some inputs. " +
            "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion."
        )
        raise ValueError(msg)

    # Canonicalize the results; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_res_aval(res_shape: xla.XlaShape) -> core.ShapedArray:
        res_dtype = res_shape.numpy_dtype()
        jax_res_dtype = dtypes.canonicalize_dtype(res_dtype)
        return core.ShapedArray(res_shape.dimensions(), jax_res_dtype)

    result_shape = xla_comp.program_shape().result_shape()
    if not result_shape.is_tuple():
        # TF does not wrap singletons as tuples, but JAX expects tuples because
        # call_tf is a multiple_results primitive.
        result_shapes = (result_shape, )
    else:
        result_shapes = result_shape.tuple_shapes()  # type: ignore

    result_avals = tuple(map(canonical_res_aval,
                             result_shapes))  # type: ignore

    def code_gen(builder: xla.XlaComputationBuilder,
                 args_op: Sequence[xla.XlaOp]) -> xla.XlaOp:
        captured_ops = [
            xops.ConstantLiteral(builder, np.asarray(inp))
            for inp in captured_inputs
        ]

        res_tf = xops.Call(builder, xla_comp,
                           args_op + tuple(captured_ops))  # type: ignore

        def post_process_result(idx: int, res_aval: core.ShapedArray,
                                res_shape: xla.XlaShape):
            res_op = res_tf
            if result_shape.is_tuple():
                res_op = xops.GetTupleElement(res_tf, idx)
            if res_aval.dtype != res_shape.numpy_dtype():
                res_op = xops.ConvertElementType(
                    res_op,
                    new_element_type=xla_client.dtype_to_etype(res_aval.dtype))
            return res_op

        results = [
            post_process_result(i, res_aval, res_shape)
            for i, (res_aval,
                    res_shape) in enumerate(zip(result_avals, result_shapes))
        ]
        return xops.Tuple(builder, results)

    return code_gen, result_avals
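A hedged sketch (illustrative names and signatures, not the actual JAX source) of the two consumers described in the comments above: abstract evaluation, which needs only the result avals and can tolerate a missing code generator, and the XLA translation rule, which requires it:

def _call_tf_abstract_eval_sketch(*_, function_flat_tf, args_flat_sig_tf, **__):
    # Under abstract evaluation we may be inside a tf.function (b/193754660),
    # so allow the code generator to be absent and rely on TF shape inference.
    _, result_avals = _code_generator_and_avals(
        function_flat_tf, args_flat_sig_tf, code_gen_optional=True)
    return tuple(result_avals)

def _call_tf_translation_rule_sketch(builder, *args_op, function_flat_tf,
                                     args_flat_sig_tf, **__):
    # For actual compilation the code generator must be present.
    code_gen, _ = _code_generator_and_avals(
        function_flat_tf, args_flat_sig_tf, code_gen_optional=False)
    assert code_gen is not None
    return code_gen(builder, args_op)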