Example #1
0
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tf2xla_pb2

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()


def f_jax(a, b):
  """JAX computation to export: sum of the element-wise sum of `a` and `b`."""
  return jax.lax.add(a, b).sum()


# Convert the JAX function to TF and trace it once to get a concrete graph.
f = tf.function(jax2tf.convert(f_jax))
a = b = tf.ones([1, 1])
cf = f.get_concrete_function(a, b)

# Serialize the traced graph for ahead-of-time compilation with tf2xla.
graph_def = cf.graph.as_graph_def()
with open('graph.pb', 'wb') as fp:
  fp.write(graph_def.SerializeToString())

config = tf2xla_pb2.Config()
batch_size = 1

# Ops whose names start with `jax2tf_arg` / `jax2tf_out` are used as the
# tf2xla feed and fetch nodes respectively.
ops = cf.graph.get_operations()
feeds = [o.name for o in ops if o.name.startswith('jax2tf_arg')]
fetches = [o.name for o in ops if o.name.startswith('jax2tf_out')]

# Feeds are paired with cf.inputs by position. This assumes the graph creation
# order matches the argument order, so fail loudly if the counts disagree.
if len(feeds) != len(cf.inputs):
  raise ValueError(
      f'Expected one jax2tf_arg op per input ({len(cf.inputs)}), '
      f'found {len(feeds)}: {feeds}')

for node_name, x in zip(feeds, cf.inputs):
  # Pin the leading (batch) dimension to a fixed size for AOT compilation.
  x.set_shape([batch_size] + list(x.shape)[1:])
  feed = config.feed.add()
  feed.id.node_name = node_name
  feed.shape.MergeFrom(x.shape.as_proto())

# Register the outputs too. tf2xla needs at least one fetch; previously the
# `fetches` list was computed but never added to the config.
for node_name in fetches:
  config.fetch.add().id.node_name = node_name
Example #2
0
 def test_simple(self):
     # Round-trip jnp.sin through TF and back into JAX; the values must agree.
     x = np.float32(0.7)
     round_tripped = jax2tf.call_tf(jax2tf.convert(jnp.sin))
     self.assertAllClose(jnp.sin(x), round_tripped(x))
Example #3
0
 def f_tf_outer(x_tf):
     """Applies TF sin, then the jax2tf-converted `f_jax`, then TF sin again."""
     inner = tf.math.sin(x_tf)
     return tf.math.sin(jax2tf.convert(f_jax)(inner))
Example #4
0
    def convert_and_save_model(
            jax_fn: tp.Callable[[tp.Any, tp.Any], tp.Any],
            params,
            model_dir: str,
            *,
            input_signatures: tp.Sequence[tf.TensorSpec],
            shape_polymorphic_input_spec: tp.Optional[str] = None,
            with_gradient: bool = False,
            enable_xla: bool = True,
            compile_model: bool = True,
            save_model_options: tp.Optional[
                tf.saved_model.SaveOptions] = None):
        """Converts a JAX function and saves it as a SavedModel.

        This is an example; for serious uses you will likely want to copy and
        expand it as needed (see the note at the top of the module).

        Use this function if you have a trained ML model that has both a
        prediction function and trained parameters, which you want to save
        separately from the function graph as variables (e.g., to avoid limits
        on the size of the GraphDef, or to enable fine-tuning). If you don't
        have such parameters, you can still use this library function but
        probably don't need it (see jax2tf/README.md for some simple examples).

        In order to use this wrapper you must first convert your model to a
        function with two arguments: the parameters and the input on which you
        want to do inference. Both arguments may be np.ndarray or (nested)
        tuples/lists/dictionaries thereof.

        See the README.md for a discussion of how to prepare Flax and Haiku
        models.

        Args:
            jax_fn: a JAX function taking two arguments, the parameters and the
                inputs. Both arguments may be (nested) tuples/lists/dictionaries
                of np.ndarray.
            params: the parameters, to be used as first argument for `jax_fn`.
                These must be (nested) tuples/lists/dictionaries of np.ndarray,
                and will be saved as the variables of the SavedModel.
            model_dir: the directory where the model should be saved.
            input_signatures: the input signatures for the second argument of
                `jax_fn` (the input). A signature must be a
                `tensorflow.TensorSpec` instance, or a (nested)
                tuple/list/dictionary thereof with a structure matching the
                second argument of `jax_fn`. The first input_signature will be
                saved as the default serving signature. The additional
                signatures will be used only to ensure that the `jax_fn` is
                traced and converted to TF for the corresponding input shapes.
            shape_polymorphic_input_spec: if given then it will be used as the
                `in_shapes` argument to jax2tf.convert for the second parameter
                of `jax_fn`. In this case, a single `input_signatures` is
                supported, and should have `None` in the polymorphic
                dimensions. Should be a string, or a (nested)
                tuple/list/dictionary thereof with a structure matching the
                second argument of `jax_fn`.
            with_gradient: whether the SavedModel should support gradients. If
                True, then a custom gradient is saved. If False, then a
                tf.raw_ops.PreventGradient is saved to error if a gradient is
                attempted. (At the moment due to a bug in SavedModel, custom
                gradients are not supported.) This flag also decides whether
                the parameter variables are created as trainable.
            enable_xla: whether the jax2tf converter is allowed to use TFXLA
                ops. If False, the conversion tries harder to use purely TF ops
                and raises an exception if it is not possible. (default: True)
            compile_model: use TensorFlow jit_compiler on the SavedModel. This
                is needed if the SavedModel will be used for TensorFlow serving.
            save_model_options: options to pass to savedmodel.save.

        Raises:
            ValueError: if `input_signatures` is empty, or if
                `shape_polymorphic_input_spec` is given together with more than
                one input signature.
        """
        if not input_signatures:
            raise ValueError("At least one input_signature must be given")
        if (shape_polymorphic_input_spec is not None
                and len(input_signatures) > 1):
            raise ValueError("For shape-polymorphic conversion a single "
                             "input_signature is supported.")
        # NOTE(review): newer jax2tf releases name this argument
        # `polymorphic_shapes`; confirm against the installed jax version.
        tf_fn = jax2tf.convert(
            jax_fn,
            with_gradient=with_gradient,
            in_shapes=[None, shape_polymorphic_input_spec],
            enable_xla=enable_xla,
        )

        # Create tf.Variables for the parameters. If you want more useful
        # variable names, you can use `tree.map_structure_with_path` from the
        # `dm-tree` package.
        param_vars = tf.nest.map_structure(
            # Due to a bug in SavedModel it is not possible to use
            # tf.GradientTape on a function converted with jax2tf and loaded
            # from SavedModel. Thus, unless `with_gradient` is requested, the
            # variables are marked as non-trainable so that users of the
            # SavedModel will not try to fine tune them.
            lambda param: tf.Variable(param, trainable=with_gradient),
            params,
        )
        # The parameters are captured as variables; only the inputs are
        # arguments of the saved function.
        tf_fun = tf.function(
            lambda inputs: tf_fn(param_vars, inputs),
            autograph=False,
            experimental_compile=compile_model,
        )

        signatures = {}
        # This signature is needed for TensorFlow Serving use.
        signatures[
            tf.saved_model.
            DEFAULT_SERVING_SIGNATURE_DEF_KEY] = tf_fun.get_concrete_function(
                input_signatures[0])

        for input_signature in input_signatures[1:]:
            # If there are more signatures, trace and cache a TF function for
            # each one.
            tf_fun.get_concrete_function(input_signature)

        wrapper = _ReusableSavedModelWrapper(tf_fun, param_vars)
        tf.saved_model.save(wrapper,
                            model_dir,
                            signatures=signatures,
                            options=save_model_options)
Example #5
0
def _make_trained_model(train_data: tf.data.Dataset,
                        eval_data: tf.data.Dataset,
                        num_epochs: int, steps_per_epoch: int,
                        eval_steps_per_epoch: int, tensorboard_log_dir: str):
  """Execute model training and evaluation loop.

  Trains `_FlaxPenguinModel` with Adam and logs per-epoch train/eval metrics
  to absl logging and TensorBoard. It then converts the prediction function to
  TF with jax2tf, using a polymorphic batch dimension for every input.

  Args:
    train_data: a dataset with training pairs (_InputBatch, _LabelBatch).
    eval_data: a dataset with evaluation pairs (_InputBatch, _LabelBatch).
    num_epochs: number of training epochs.
    steps_per_epoch: number of steps for a training epoch. Should be the number
       of samples in your train_data divided by the batch size.
    eval_steps_per_epoch: number of steps for evaluation at the end of each
       training epoch. Should be the number of samples in your eval_data
       divided by the batch size.
    tensorboard_log_dir: Directory where the tensorboard summaries are written.

  Returns:
    A `_SavedModelWrapper` built from the converted prediction `tf.function`
    and the trained parameters stored as non-trainable `tf.Variable`s.
  """
  learning_rate = 1e-2

  # Only the split-off key is used (for parameter initialization).
  _, init_rng = jax.random.split(jax.random.PRNGKey(0))

  summary_writer = tensorboard.SummaryWriter(tensorboard_log_dir)
  summary_writer.hparams(
      dict(
          learning_rate=learning_rate,
          num_epochs=num_epochs,
          steps_per_epoch=steps_per_epoch,
          eval_steps_per_epoch=eval_steps_per_epoch))

  # Initialize with some fake data of the proper shape.
  init_val = {feature: jnp.array([[1.]], dtype=jnp.float32)
              for feature in _FEATURE_KEYS_XF}
  model = _FlaxPenguinModel()
  params = model.init(init_rng, init_val)['params']

  optimizer_def = flax.optim.Adam(learning_rate=learning_rate)
  optimizer = optimizer_def.create(params)

  for epoch in range(1, num_epochs + 1):
    optimizer, train_metrics = _train_epoch(model, optimizer, train_data,
                                            steps_per_epoch)
    absl.logging.info('Flax train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                      train_metrics['loss'], train_metrics['accuracy'] * 100)

    eval_metrics = _eval_epoch(model, optimizer.target, eval_data,
                               eval_steps_per_epoch)
    absl.logging.info('Flax eval epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
                      eval_metrics['loss'], eval_metrics['accuracy'] * 100)
    summary_writer.scalar('epoch_train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('epoch_train_accuracy', train_metrics['accuracy'],
                          epoch)
    summary_writer.scalar('epoch_eval_loss', eval_metrics['loss'], epoch)
    summary_writer.scalar('epoch_eval_accuracy', eval_metrics['accuracy'],
                          epoch)

  summary_writer.flush()

  # The prediction function for the trained model
  def predict(params: _Params, inputs: _InputBatch):
    return model.apply({'params': params}, inputs)

  trained_params = optimizer.target

  # Convert the prediction function to TF, with a variable batch dimension
  # for all inputs.
  tf_fn = jax2tf.convert(predict, with_gradient=False, enable_xla=True,
                         polymorphic_shapes=(None, '(b, 1)'))

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree`
  # package.
  param_vars = tf.nest.map_structure(
      # Due to a bug in SavedModel it is not possible to use tf.GradientTape
      # on a function converted with jax2tf and loaded from SavedModel.
      # Thus, we mark the variables as non-trainable to ensure that users of
      # the SavedModel will not try to fine tune them.
      lambda param: tf.Variable(param, trainable=False),
      trained_params)
  tf_graph = tf.function(
      lambda inputs: tf_fn(param_vars, inputs),
      autograph=False,
      experimental_compile=True)
  return _SavedModelWrapper(tf_graph, param_vars)
Example #6
0
 def test_bfloat16_passed_by_tf(self):
   # Both arguments share one bfloat16 spec; tracing must succeed.
   spec = tf.TensorSpec([512, 512], tf.bfloat16)
   add_tf = tf.function(jax2tf.convert(lambda a, b: a + b),
                        input_signature=[spec, spec])
   self.assertIsNotNone(add_tf.get_concrete_function())
Example #7
0
def jax2tf_xla(test_case: ModelTestCase):
    """Converts `test_case.apply` (bound to its variables) with enable_xla=True
    and compares the TF result against JAX."""
    jax_fn = functools.partial(test_case.apply, test_case.variables)
    _compare(test_case,
             jax_fn,
             jax2tf.convert(jax_fn, enable_xla=True),
             "JAX vs TF (enable_xla=True)")
Example #8
0
def main(_):
    """Trains a Flax MNIST model, converts it to TF Lite and evaluates it.

    Produces a float and a quantized TF Lite model, prints their sizes and
    accuracies, and writes the quantized model to `FLAGS.tflite_file_path`.
    """
    logging.info('Loading the MNIST TensorFlow dataset')
    train_ds = mnist_lib.load_mnist(tfds.Split.TRAIN,
                                    batch_size=mnist_lib.train_batch_size)
    test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
                                   batch_size=FLAGS.serving_batch_size)

    (flax_predict,
     flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds,
                                              FLAGS.num_epochs)

    def predict(image):
        return flax_predict(flax_params, image)

    # Convert Flax model to TF function. enable_xla=False keeps the graph free
    # of XLA-only ops so it can be handled by the TF Lite converter.
    tf_predict = tf.function(
        jax2tf.convert(predict, enable_xla=False),
        input_signature=[
            tf.TensorSpec(shape=[FLAGS.serving_batch_size, 28, 28, 1],
                          dtype=tf.float32,
                          name='input')
        ],
        autograph=False)

    # Convert TF function to TF Lite format.
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [tf_predict.get_concrete_function()])
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]
    tflite_float_model = converter.convert()

    # Show model size in KBs.
    float_model_size = len(tflite_float_model) / 1024
    print('Float model size = %dKBs.' % float_model_size)

    # Re-convert the model to TF Lite using quantization.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_quantized_model = converter.convert()

    # Show model size in KBs.
    quantized_model_size = len(tflite_quantized_model) / 1024
    print('Quantized model size = %dKBs,' % quantized_model_size)
    print('which is about %d%% of the float model size.' %
          (quantized_model_size * 100 / float_model_size))

    # Evaluate the TF Lite float model. You'll find that its accuracy is
    # identical to the original Flax model because they are essentially the
    # same model stored in different format.
    float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
    print('Float model accuracy = %.4f' % float_accuracy)

    # Evaluate the TF Lite quantized model.
    # Don't be surprised if you see quantized model accuracy is higher than
    # the original float model. It happens sometimes :)
    quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
    print('Quantized model accuracy = %.4f' % quantized_accuracy)
    print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))

    # Use a context manager so the file is closed even if the write fails.
    with open(FLAGS.tflite_file_path, 'wb') as f:
        f.write(tflite_quantized_model)
Example #9
0
 def test_round_trip_reverse(self):
     # Round-trip tf.math.sin through JAX and back into TF.
     x = np.float32(0.7)
     tf_sin_rt = jax2tf.convert(jax2tf.call_tf(tf.math.sin))
     expected = tf.math.sin(x).numpy()
     self.assertAllClose(expected, tf_sin_rt(x).numpy())
Example #10
0
    def ConvertAndCompare(self,
                          func_jax: Callable,
                          *args,
                          enable_xla: bool = True,
                          limitations: Sequence = ()):
        """Compares jax_func(*args) with convert(jax_func)(*args).

        It compares the result of JAX, TF ("eager" mode),
        TF with tf.function ("graph" mode), and TF with
        tf.function(jit_compile=True) ("compiled" mode). In each mode,
        either we expect to encounter a known limitation, or the value should
        match the value from the JAX execution.

        Args:
          func_jax: the function to invoke (``func_jax(*args)``)
          args: the arguments.
          enable_xla: if True, allows the use of XLA ops in jax2tf.convert
            (default: True).
          limitations: the set of limitations for this harness (not yet filtered
            by mode).

        Returns:
          A tuple ``(result_jax, result_tf)``. ``result_tf`` is the value from
          the most recent mode whose TF run succeeded. It has been converted to
          NumPy if that mode reached the comparison step. It is None if no TF
          run succeeded.

        Raises:
          Exception: re-raises a TF error from a mode that has no limitation
            expecting a TF error.
        """
        # Run JAX. Should not fail, we assume that the harness has been filtered
        # already by JAX unimplemented primitives.
        result_jax = func_jax(*args)  # JAX
        # Stays None unless at least one TF run below succeeds.
        result_tf = None

        func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)
        tf_args = _make_tf_args(args)

        # Modes that succeeded even though a limitation expected a TF error.
        unexpected_successes: List[str] = []
        # Run the "compiled" mode first, it is most important
        for mode in ("compiled", "eager", "graph"):

            def log_message(extra):
                return f"[{self._testMethodName}] mode={mode}: {extra}"

            # Keep only the limitations that apply to the current mode.
            jax2tf_limits = tuple(
                filter(lambda l: l.filter(mode=mode), limitations))

            skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
            if skip_tf_run:
                logging.info(
                    log_message(
                        f"Skip TF run due to limitations {skip_tf_run}"))
                continue

            try:
                result_tf = _run_tf_function(func_tf, *tf_args, mode=mode)
                tf_exception = None
            except Exception as e:
                # Captured here; whether it is fatal depends on the
                # expect_tf_error limitations checked just below.
                tf_exception = e

            expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
            if tf_exception:
                if expect_tf_error:
                    logging.info(
                        log_message(
                            "Found expected TF error with enabled limitations "
                            f"{expect_tf_error}; TF error is {tf_exception}"))
                    continue
                else:
                    raise tf_exception
            else:
                if expect_tf_error:
                    # It is more ergonomic to print all successful modes once
                    logging.warning(
                        log_message(
                            f"Unexpected success with known limitations {expect_tf_error}"
                        ))
                    unexpected_successes.append(f"{mode}: {expect_tf_error}")

            skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
            if skip_comparison:
                logging.warning(
                    log_message(
                        f"Skip result comparison due to {skip_comparison}"))
                continue

            # The tolerance comes from the limitation chosen by
            # get_max_tolerance_limitation. When no limitation applies, None is
            # passed as atol/rtol (presumably the assertion's default tolerance).
            max_tol = None
            max_tol_lim = None if not jax2tf_limits else jax2tf_limits[
                0].get_max_tolerance_limitation(jax2tf_limits)
            if max_tol_lim is not None:
                max_tol = max_tol_lim.tol
                logging.info(
                    log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

            # Convert results to np.arrays
            result_tf = tf.nest.map_structure(lambda t: t.numpy(),
                                              result_tf)  # type: ignore

            custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
            assert len(
                custom_assert_lim
            ) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}"

            if custom_assert_lim:
                logging.info(
                    log_message(
                        f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"
                    ))
                custom_assert_lim[0].custom_assert(self,
                                                   result_jax,
                                                   result_tf,
                                                   args=args,
                                                   tol=max_tol)
            else:
                logging.info(
                    log_message(f"Running default assert with tol={max_tol}"))
                # In compiled mode we expect the same result as JAX by default
                self.assertAllClose(result_jax,
                                    result_tf,
                                    atol=max_tol,
                                    rtol=max_tol)

        # end "for mode"

        # Unexpected successes are only logged, not treated as failures.
        if unexpected_successes:
            msg = (f"[{self._testMethodName}] The following are unexpected "
                   "successful modes:\n" + "\n".join(unexpected_successes))
            logging.warning(msg)
            # Uncomment the below if you want to see warnings as failures
            # self.assertEmpty(msg)
        return result_jax, result_tf
Example #11
0
    def ConvertAndCompare(self,
                          func_jax: Callable,
                          *args,
                          custom_assert: Optional[Callable] = None,
                          always_custom_assert: bool = False,
                          expect_tf_exceptions: bool = False,
                          atol=None,
                          rtol=None) -> Tuple[Any, Any]:
        """Compares jax_func(*args) with convert(jax_func)(*args).

        It compares the result of JAX, TF ("eager" mode),
        TF with tf.function ("graph" mode), and TF with
        tf.function(experimental_compile=True) ("compiled" mode). In each mode,
        either we expect an exception (see `expect_tf_exceptions`) or the value
        should match the value from the JAX execution.

        `jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl` is monkey-patched so
        that each primitive implementation is wrapped by
        `correctness_stats.collect_limitations`. The original is restored by a
        test cleanup.

        Args:
          func_jax: the JAX function to invoke as ``func_jax(*args)``.
          *args: the arguments. Values with a `dtype` attribute are converted
            to TF tensors before the TF runs. bfloat16 values go through a
            float32 NumPy array first.
          custom_assert: a function that will be called
            `custom_assert(result_jax, result_tf)` to assert equality of the
            results. Use this function when JAX and TF produce different results.
            This function is only used for "eager" and "graph" modes by default,
            not for the "compiled" mode, because in that case we expect the
            results to be equal.
          always_custom_assert: if True, custom_assert is also called in
            "compiled" mode. This is useful in cases where JAX and TF produce
            different but equally valid results.
          expect_tf_exceptions: if True, there may be exceptions in some
            evaluation modes; when there is no exception the result should be
            the same as in JAX.
          atol: absolute tolerance passed to `assertAllClose`.
          rtol: relative tolerance passed to `assertAllClose`.

        Returns:
          A tuple ``(result_jax, result_tf)``. ``result_tf`` is the result of
          the last TF mode that did not raise, or None if every mode raised.

        Raises:
          Exception: re-raises a TF exception when `expect_tf_exceptions` is
            False and no newly recorded limitation reports missing TF support
            on the current device.
        """
        original_impl = jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl

        # Monkey-patch jax2tf.TensorFlowTrace.get_primitive_impl to wrap the
        # resulting primitive in a categorizer.
        wrapper = correctness_stats.collect_limitations
        jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl = (  # type: ignore
            lambda s, p: wrapper(p, original_impl(s, p)))

        def restore_get_primitive_impl():
            jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl = original_impl

        # Restore the original jax2tf.TensorFlowTrace.get_primitive_impl
        # implementation at the end of the test.
        self.addCleanup(restore_get_primitive_impl)

        # Run JAX
        result_jax = func_jax(*args)
        # Run TF in all execution modes
        func_tf = jax2tf.convert(func_jax)

        def convert_if_bfloat16(v):
            if hasattr(v, "dtype"):
                return tf.convert_to_tensor(
                    np.array(v, jnp.float32) if v.dtype == jnp.bfloat16 else v,
                    jax2tf.jax2tf.to_tf_dtype(v.dtype))
            return v

        tf_args = tuple(map(convert_if_bfloat16, args))

        def run_tf(mode):
            if mode == "eager":
                return func_tf(*tf_args)
            elif mode == "graph":
                return tf.function(func_tf, autograph=False)(*tf_args)
            elif mode == "compiled":
                return tf.function(func_tf,
                                   autograph=False,
                                   experimental_compile=True)(*tf_args)
            # An explicit raise (unlike `assert False`) survives `python -O`.
            raise ValueError(f"Unknown mode: {mode}")

        def is_tf_exception(lim: correctness_stats.Limitation):
            return (lim.error_type == 'Missing TF support'
                    and self.tf_default_device.device_type in lim.devices)

        result_tf = None
        for mode in ("eager", "graph", "compiled"):
            # Remember how many limitations exist so that we can inspect only
            # the ones recorded while running this mode.
            current_limitations_len = len(correctness_stats.all_limitations)
            try:
                result_tf = run_tf(mode)
            except Exception as e:
                new_limitations = (correctness_stats.
                                   all_limitations[current_limitations_len:])
                detected_tf_exception = any(
                    map(is_tf_exception, new_limitations))

                if not (expect_tf_exceptions or detected_tf_exception):
                    raise
                for lim in new_limitations:
                    print("Detected limitation: {} for {} devices.".format(
                        lim.error_string, ', '.join(lim.devices)))

                print(
                    f"Encountered expected exception for mode={mode}: {e}")
                continue

            if custom_assert is not None and (mode in ("eager", "graph")
                                              or always_custom_assert):
                custom_assert(result_jax, result_tf)
            else:
                # In compiled mode we expect the same result as JAX by default
                self.assertAllClose(result_jax,
                                    result_tf,
                                    atol=atol,
                                    rtol=rtol)

        return (result_jax, result_tf)
Example #12
0
 def test_variable_input(self):
     def sin_of_cos(x):
         return jnp.sin(jnp.cos(x))

     # A tf.Variable argument must be accepted and produce a tf.Tensor.
     sin_of_cos_tf = jax2tf.convert(sin_of_cos)
     var = tf.Variable(0.7, dtype=dtypes.canonicalize_dtype(jnp.float_))
     self.assertIsInstance(sin_of_cos_tf(var), tf.Tensor)
     self.assertAllClose(sin_of_cos(0.7), sin_of_cos_tf(var))
Example #13
0
 def _tf_grad(a, b):
     """Returns the converted `f_jax(a, b)` and its TF gradient w.r.t. `a`."""
     with tf.GradientTape() as tape:
         tape.watch(a)
         out = jax2tf.convert(f_jax)(a, b)
     grad_a = tape.gradient(out, a)
     return out, grad_a
Example #14
0
 def __init__(self, apply_fn, params, name=None):
     super().__init__(name=name)
     # TF-callable version of the JAX apply function.
     self.apply_fn = jax2tf.convert(apply_fn)
     # Every leaf of the parameter structure becomes a tf.Variable.
     self.params = tf.nest.map_structure(lambda p: tf.Variable(p), params)
Example #15
0
 def test_variable_input(self):
   def sin_of_cos(x):
     return jnp.sin(jnp.cos(x))

   # A tf.Variable argument must be accepted and produce a tf.Tensor.
   sin_of_cos_tf = jax2tf.convert(sin_of_cos)
   var = tf.Variable(0.7)
   self.assertIsInstance(sin_of_cos_tf(var), tf.Tensor)
   self.assertAllClose(sin_of_cos(0.7), sin_of_cos_tf(var))
Example #16
0
  def test_gradients_int_argument(self, with_function=True):
    # https://github.com/google/jax/issues/6975
    # Also issue #6975.
    # An expanded version of test_gradients_unused_argument
    state = dict(
        float_used=np.array([0.7, 0.9], dtype=np.float32),
        float_passthrough=np.float16(1.),
        float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
        int_used=np.int16(5),
        int_passthrough=np.int8(7),
        int_unused=np.array([1, 2, 3], dtype=np.uint32),
        bool_used=np.array([True, False, False, True], dtype=np.bool_),
        bool_passthrough=np.array([True, False, False, True, False], dtype=np.bool_),
        bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
    )
    def jax_f(state):
      res = dict(state,
                 float_used=2. * state["float_used"],
                 int_used=3 * state["int_used"],
                 bool_used=(state["bool_used"] == state["bool_used"]))
      del res["float_unused"]
      del res["int_unused"]
      del res["bool_unused"]
      return res

    args = (state,)
    res_jax = jax_f(*args)
    # Native JAX AD
    vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
    grad_jax, = vjp_jax_fun(*args_vjp)

    def compare_with_overrides(*, what, expected, **expected_overrides):
      """Asserts that dict `what` matches `expected` key by key.

      `expected_overrides[k]` replaces the expected value for key `k`:
      "ZERO" means zeros like the actual value, "ZERO_INT32" means int32 zeros
      of the actual shape, and "ONE" means ones like the actual value. Any
      other value is used verbatim; None asserts that the actual value is None.
      """
      what_keys = set(what.keys())
      expected_keys = set(expected.keys())
      self.assertEqual(what_keys, expected_keys)
      for k, w in what.items():
        e = expected[k]
        if k in expected_overrides:
          if expected_overrides[k] == "ZERO":
            e = np.zeros_like(w)
          elif expected_overrides[k] == "ZERO_INT32":
            e = np.zeros(np.shape(w), dtype=np.int32)
          elif expected_overrides[k] == "ONE":
            e = np.ones_like(w)
          else:
            e = expected_overrides[k]

        if e is None:
          self.assertIsNone(w, msg=k)
        else:
          self.assertIsNotNone(w, msg=k)
        # Compare the actual value itself. Previously a non-Tensor `w` was
        # replaced by `e`, which made the comparison below a tautology.
        w = w.numpy() if isinstance(w, tf.Tensor) else w
        e = e.numpy() if isinstance(e, tf.Tensor) else e
        try:
          self.assertAllClose(e, w, err_msg=k)
        except Exception:
          print(f"Failed at {k}")
          raise


    # compare_with_overrides(g_jax, {},
    #   bool_passthrough=np.zeros(state["bool_passthrough"].shape, dtype=dtypes.float0),
    #   bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0),
    #   bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0),
    #   float_passthrough=np.ones_like(state["float_passthrough"]),
    #   float_unused=np.zeros_like(state["float_unused"]),
    #   float_used=np.ones_like(state["float_used"]) * np.array(2., dtype=state["float_used"].dtype),
    #   int_passthrough=np.zeros(state["int_passthrough"].shape, dtype=dtypes.float0),
    #   int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0),
    #   int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0))


    # Now native TF gradients, only to test how native TF AD works
    _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
        jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
    compare_with_overrides(what=grad_tf_0,
                           expected=grad_jax,
                           float_unused="ZERO",
                           bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO",
                           int_used="ZERO", int_passthrough="ONE", int_unused="ZERO")

    _, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
        jax_f, args,
        unconnected_gradients=tf.UnconnectedGradients.NONE)
    compare_with_overrides(what=grad_tf_None,
                           expected=grad_tf_0,
                           float_unused=None, int_used=None, int_unused=None,
                           bool_used=None, bool_unused=None)

    f_tf_jax = jax2tf.convert(jax_f)
    if with_function:
      f_tf_jax = tf.function(f_tf_jax, autograph=False)

    _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
    # Same results as TF native AD with tf.UnconnectedGradients.ZERO
    compare_with_overrides(what=grad_tf_jax_0,
                           expected=grad_tf_0,
                           int_passthrough="ZERO", bool_passthrough="ZERO")

    _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
        f_tf_jax, args,
        unconnected_gradients=tf.UnconnectedGradients.NONE)
    compare_with_overrides(what=grad_tf_jax_None,
                           expected=grad_tf_0,
                           int_used=None, int_passthrough=None, int_unused=None,
                           bool_unused=None, bool_used=None, bool_passthrough=None)

    # Now convert the JAX gradient function
    tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
    grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
    compare_with_overrides(what=grad_tf_vjp_jax,
                           expected=grad_tf_0,
                           bool_passthrough="ZERO_INT32",
                           bool_unused="ZERO_INT32", bool_used="ZERO_INT32",
                           int_passthrough="ZERO_INT32", int_unused="ZERO_INT32",
                           int_used="ZERO_INT32")
Example #17
0
 def test_nested_jit(self):
   # A jitted function calling another jitted function must convert cleanly.
   inner_cos = jax.jit(jnp.cos)
   f_jax = jax.jit(lambda x: jnp.sin(inner_cos(x)))
   f_tf = jax2tf.convert(f_jax)
   x = 0.7
   np.testing.assert_allclose(f_jax(x), f_tf(x))
Example #18
0
 def test_nested_convert_error(self):
   def outer(y):
     # The inner convert receives JAX tracers, which must be rejected.
     return jax2tf.convert(jnp.sin)(y)

   arg = np.ones((4,))
   with self.assertRaisesRegex(
       ValueError, "convert must be used outside all JAX transformations"):
     jax2tf.convert(outer)(arg)
Example #19
0
 def test_bfloat16_returned_by_jax(self):
   def add_as_bf16(a, b):
     return (a + b).astype(jnp.bfloat16)

   # A bfloat16 JAX result must come back as a tf.bfloat16 tensor.
   result = jax2tf.convert(add_as_bf16)(1., 2.)
   self.assertEqual(result.dtype, tf.bfloat16)
Example #20
0
 def outer(y):
   sin_tf = jax2tf.convert(jnp.sin)
   return sin_tf(y)  # Inner convert takes tracer args
Example #21
0
    def ConvertAndCompare(self,
                          func_jax: Callable,
                          *args,
                          custom_assert: Optional[Callable] = None,
                          expect_tf_exceptions: bool = False,
                          atol=None,
                          rtol=None) -> Tuple[Any, Any]:
        """Compares jax_func(*args) with convert(jax_func)(*args).

        It compares the result of JAX, TF ("eager" mode),
        TF with tf.function ("graph" mode), and TF with
        tf.function(experimental_compile=True) ("compiled" mode). In each mode,
        either we expect an exception (see `expect_tf_exceptions`) or the value
        should match the value from the JAX execution.

        Args:
          func_jax: the JAX function under test, invoked as ``func_jax(*args)``.
          *args: positional arguments passed both to ``func_jax`` and to its
            TF conversion.
          custom_assert: a function that will be called
            `custom_assert(result_jax, result_tf)` to assert equality of the
            results. Use this function when JAX and TF produce different results.
            This function is only used for "eager" and "graph" modes, not for the
            "compiled" mode, because in that case we expect always the results
            to be equal.
          expect_tf_exceptions: if True, there may be exceptions in some evaluation
            modes; when there is no exception the result should be the same
            as in JAX.
          atol: absolute tolerance forwarded to ``self.assertAllClose`` (not to
            ``custom_assert``).
          rtol: relative tolerance forwarded to ``self.assertAllClose`` (not to
            ``custom_assert``).

        Returns:
          A ``(result_jax, result_tf)`` tuple, where ``result_tf`` is the result
          of the last TF mode that ran without raising, or ``None`` if every
          mode raised.
        """
        # Run JAX
        result_jax = func_jax(*args)
        # Run TF in all execution modes
        func_tf = jax2tf.convert(func_jax)

        def run_tf(mode):
            """Run ``func_tf(*args)`` in the given TF execution mode."""
            if mode == "eager":
                return func_tf(*args)
            if mode == "graph":
                return tf.function(func_tf, autograph=False)(*args)
            if mode == "compiled":
                return tf.function(func_tf,
                                   autograph=False,
                                   experimental_compile=True)(*args)
            # An explicit raise instead of `assert False`, which `python -O`
            # strips (that would make this silently return None).
            raise ValueError(f"Unknown TF execution mode: {mode}")

        result_tf = None
        for mode in ("eager", "graph", "compiled"):
            try:
                result_tf = run_tf(mode)
            except Exception as e:
                if not expect_tf_exceptions:
                    raise  # re-raise the original exception unchanged
                print(f"Encountered exception for mode={mode}: {e}")
                continue

            if custom_assert is not None and mode in ("eager", "graph"):
                custom_assert(result_jax, result_tf)
            else:
                # In compiled mode we always expect the same result as JAX
                self.assertAllClose(result_jax,
                                    result_tf,
                                    atol=atol,
                                    rtol=rtol)

        return (result_jax, result_tf)
Example #22
0
 def outer(y):
   """Return ``y`` plus ``sin(1.)`` computed through a nested `jax2tf.convert`."""
   converted_sin = jax2tf.convert(jnp.sin)
   # The inner convert receives the Python float 1., not a tracer.
   return y + converted_sin(1.)
Example #23
0
    def _check_sharding_annotations(self,
                                    f_jax,
                                    args: Sequence[Any],
                                    *,
                                    expected: Sequence[str],
                                    expected_opt: Sequence[str],
                                    num_partitions=2):
        """Check expected patterns in the HLO generated from f_jax and its conversion.

        We run this check on CPU also, which is useful for debugging locally.
        We currently check the unoptimized HLO against `expected` on CPU and TPU,
        and we check the optimized HLO against `expected_opt` on TPU only and
        only for JAX. The optimized TF HLO is produced (and logged when
        `LOG_HLO` is set) but not checked.

        See `self.AssertShardingAnnotations` for documentation of `expected`
        and `expected_opt`.

        Args:
          f_jax: the JAX function whose HLO is inspected.
          args: positional arguments used to trace ``f_jax`` and its TF
            conversion.
          expected: patterns checked in the unoptimized JAX and TF HLO.
          expected_opt: patterns checked in the optimized JAX HLO (TPU only).
          num_partitions: number of partitions used when compiling the JAX
            computation on TPU; SPMD partitioning is enabled when > 1.

        Raises:
          unittest.SkipTest: when the device under test is a GPU.
        """
        # Query the device once; it is used for skipping, the TPU-only branch
        # and the TF device name.
        device = jtu.device_under_test()
        if device == "gpu":
            raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

        jax_comp = jax.xla_computation(f_jax)(*args)
        jax_hlo = jax_comp.as_hlo_text()
        if LOG_HLO:
            logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
        self.AssertShardingAnnotations("JAX before optimizations", jax_hlo,
                                       expected)

        if device == "tpu":
            backend = jax._src.lib.xla_bridge.get_backend()
            num_replicas = 1
            device_assignment = np.arange(num_partitions * num_replicas)
            device_assignment = np.reshape(device_assignment,
                                           (-1, num_partitions))
            use_spmd_partitioning = num_partitions > 1
            compile_options = jax._src.lib.xla_bridge.get_compile_options(
                num_replicas=num_replicas,
                num_partitions=num_partitions,
                device_assignment=device_assignment,
                use_spmd_partitioning=use_spmd_partitioning,
            )
            jax_optimized_hlo = backend.compile(
                jax_comp, compile_options).hlo_modules()[0].to_string()
            if LOG_HLO:
                logging.info("[%s] got JAX optimized HLO for platform %s %s",
                             self._testMethodName, backend.platform,
                             jax_optimized_hlo)
            self.AssertShardingAnnotations("JAX after optimizations",
                                           jax_optimized_hlo, expected_opt)

        f_tf = jax2tf.convert(f_jax)
        device_name = f"/device:{device.upper()}:0"
        # One compiled tf.function (autograph disabled) traced once; its
        # compiler-IR generator is queried for both the unoptimized and the
        # optimized stage so the two HLO dumps come from the same function.
        tf_compiler_ir = tf.function(
            f_tf, jit_compile=True,
            autograph=False).experimental_get_compiler_ir(*args)
        tf_hlo = tf_compiler_ir(stage="hlo", device_name=device_name)
        if LOG_HLO:
            logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
        self.AssertShardingAnnotations("TF before optimizations", tf_hlo,
                                       expected)
        tf_optimized_hlo = tf_compiler_ir(stage="optimized_hlo",
                                          device_name=device_name)
        if LOG_HLO:
            logging.info("[%s] got TF optimized HLO for %s: %s",
                         self._testMethodName, device_name, tf_optimized_hlo)
Example #24
0
 def test_convert_argument_non_callable_error(self):
   """`jax2tf.convert` rejects a non-callable argument with a TypeError."""
   not_callable = 5.
   with self.assertRaisesRegex(TypeError, "Expected a callable value"):
     jax2tf.convert(not_callable)
Example #25
0
    def ConvertAndCompare(self,
                          func_jax: Callable,
                          *args,
                          enable_xla: bool = True,
                          limitations: Sequence = ()):
        """Compares jax_func(*args) with convert(jax_func)(*args).

        It compares the result of JAX, TF ("eager" mode),
        TF with tf.function ("graph" mode), and TF with
        tf.function(jit_compile=True) ("compiled" mode). In each mode,
        either we expect to encounter a known limitation, or the value should
        match the value from the JAX execution.

        Args:
          func_jax: the function to invoke (``func_jax(*args)``)
          args: the arguments.
          enable_xla: if True, allows the use of XLA ops in jax2tf.convert
            (default: True).
          limitations: the set of limitations for this harness (not yet filtered
            by mode).

        Returns:
          A ``(result_jax, result_tf)`` tuple. ``result_tf`` holds the result of
          the last mode whose TF run succeeded (converted to NumPy arrays if that
          mode reached the comparison step), or ``None`` if no mode ran TF
          successfully.

        Raises:
          Exception: a mode's TF exception, re-raised when no applicable
            limitation has ``expect_tf_error`` set.
          AssertionError: when the JAX and TF results do not match. In
            "compiled" mode or on TPU, JAX and TF HLO are logged before
            re-raising.
        """
        # Run JAX. Should not fail, we assume that the harness has been filtered
        # already by JAX unimplemented primitives.
        result_jax = func_jax(*args)  # JAX
        result_tf = None

        func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)

        # Modes where a limitation predicted a TF error but TF succeeded; they
        # are reported together (as a warning) after the loop.
        unexpected_successes: List[str] = []
        # Run the "compiled" mode first, it is most important
        for mode in ("compiled", "eager", "graph"):

            def log_message(extra):
                return f"[{self._testMethodName}] mode={mode}: {extra}"

            # Keep only the limitations that apply to the current mode.
            jax2tf_limits = tuple(
                filter(lambda l: l.filter(mode=mode), limitations))

            skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
            if skip_tf_run:
                logging.info(
                    log_message(
                        f"Skip TF run due to limitations {skip_tf_run}"))
                continue

            try:
                result_tf = _run_tf_function(func_tf, *args, mode=mode)
                tf_exception = None
            except Exception as e:
                tf_exception = e

            expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
            if tf_exception:
                if expect_tf_error:
                    logging.info(
                        log_message(
                            "Found expected TF error with enabled limitations "
                            f"{expect_tf_error}; TF error is {tf_exception}"))
                    continue
                else:
                    raise tf_exception
            else:
                if expect_tf_error:
                    # It is more ergonomic to print all successful modes once
                    logging.warning(
                        log_message(
                            f"Unexpected success with known limitations {expect_tf_error}"
                        ))
                    unexpected_successes.append(f"{mode}: {expect_tf_error}")

            if (jtu.device_under_test() == "gpu"
                    and "dot_general_preferred" in self._testMethodName):
                logging.info(
                    log_message(
                        f"Arguments are {args}, JAX result is {result_jax}\nand TF result is {result_tf}"
                    ))

            skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
            if skip_comparison:
                logging.warning(
                    log_message(
                        f"Skip result comparison due to {skip_comparison}"))
                continue

            # Tolerance comes from whichever applicable limitation
            # get_max_tolerance_limitation selects (judging by the name,
            # presumably the loosest one -- confirm); None means default tolerance.
            max_tol = None
            max_tol_lim = None if not jax2tf_limits else jax2tf_limits[
                0].get_max_tolerance_limitation(jax2tf_limits)
            if max_tol_lim is not None:
                max_tol = max_tol_lim.tol
                logging.info(
                    log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

            # Convert results to np.arrays
            result_tf = tf.nest.map_structure(lambda t: t.numpy(),
                                              result_tf)  # type: ignore

            custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
            assert len(
                custom_assert_lim
            ) <= 1, f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}"

            try:
                err_msg = f"TF mode {mode}."
                # HLO dumps on failure are only produced for compiled mode or TPU.
                log_hlo_on_error = mode == "compiled" or jtu.device_under_test(
                ) == "tpu"
                if log_hlo_on_error:
                    err_msg += " See the logs for JAX and TF HLO comparisons."
                if custom_assert_lim:
                    logging.info(
                        log_message(
                            f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"
                        ))
                    custom_assert_lim[0].custom_assert(self,
                                                       result_jax,
                                                       result_tf,
                                                       args=args,
                                                       tol=max_tol,
                                                       err_msg=err_msg)
                else:
                    logging.info(
                        log_message(
                            f"Running default assert with tol={max_tol}"))
                    self.assertAllClose(result_jax,
                                        result_tf,
                                        atol=max_tol,
                                        rtol=max_tol,
                                        err_msg=err_msg)
            except AssertionError as e:
                # Print the HLO for comparison
                if not log_hlo_on_error:
                    print(
                        f"[{self._testMethodName}] Not logging HLO because the "
                        f"mode was {mode}")
                    raise

                # Dump JAX unoptimized, TF unoptimized, JAX optimized and (when
                # possible) TF optimized HLO, then re-raise the assertion.
                logging.info(
                    f"[{self._testMethodName}] Logging HLO for exception in mode {mode}: {e}"
                )
                jax_comp = jax.xla_computation(func_jax)(*args)
                jax_hlo = jax_comp.as_hlo_text()
                logging.info(f"[{self._testMethodName}] "
                             f"JAX NON_OPT HLO\n{jax_hlo}")

                tf_args_signature = _make_tf_input_signature(*args)
                # If we give the signature, we cannot pass scalars
                tf_args_no_scalars = tuple(
                    map(
                        lambda a, sig: tf.convert_to_tensor(
                            a, dtype=sig.dtype), args, tf_args_signature))

                tf_func_compiled = tf.function(
                    func_tf,
                    autograph=False,
                    jit_compile=True,
                    input_signature=tf_args_signature)
                tf_hlo = tf_func_compiled.experimental_get_compiler_ir(
                    *tf_args_no_scalars)(stage="hlo")
                logging.info(
                    f"[{self._testMethodName}] TF NON OPT HLO\n{tf_hlo}")

                backend = jax.lib.xla_bridge.get_backend()
                modules = backend.compile(jax_comp).hlo_modules()
                jax_opt_hlo = modules[0].to_string()
                logging.info(f"[{self._testMethodName}] "
                             f"JAX OPT HLO\n{jax_opt_hlo}")

                # TODO(b/189265364): Remove this workaround
                if (jtu.device_under_test() == "gpu"
                        and "dot_general" in self._testMethodName):
                    print(
                        f"[{self._testMethodName}] Not logging TF OPT HLO because of "
                        f"crash in tf.experimental_get_compiler_ir (b/189265364)"
                    )
                else:
                    tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(
                        *tf_args_no_scalars)(stage="optimized_hlo")
                    logging.info(
                        f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")

                raise

        # end "for mode"

        if unexpected_successes:
            msg = (f"[{self._testMethodName}] The following are unexpected "
                   "successful modes:\n" + "\n".join(unexpected_successes))
            logging.warning(msg)
            # Uncomment the below if you want to see warnings as failures
            # self.assertEmpty(msg)
        return result_jax, result_tf
Example #26
0
 def test_convert_argument_non_tensor_error(self):
   """Passing a function (instead of an array) as an argument is a TypeError."""
   def not_an_array(y):
     return y
   expected_msg = "Argument.*should be NumPy array"
   with self.assertRaisesRegex(TypeError, expected_msg):
     jax2tf.convert(lambda x: x)(not_an_array)
Example #27
0
 def test_shape_poly(self):
     """A call_tf round trip of a shape-polymorphic conversion matches JAX."""
     converted = jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, ...)"])
     round_trip = jax2tf.call_tf(converted)
     x = np.array([0.7, 0.8], dtype=np.float32)
     self.assertAllClose(jnp.sin(x), round_trip(x))
Example #28
0
 def test_argument_eager_tensor(self):
   """The TF tensor returned by one converted function feeds another one."""
   sin_tf = jax2tf.convert(jnp.sin)
   cos_tf = jax2tf.convert(jnp.cos)
   eager_result = sin_tf(1.)
   cos_tf(eager_result)  # No error
Example #29
0
 def test_stop_gradient(self):
     """The converted `lax.stop_gradient` returns its input value."""
     stop_grad_tf = jax2tf.convert(lax.stop_gradient)
     ones = tf.ones([])
     self.assertEqual(stop_grad_tf(ones), 1.)
def convert_and_save_model(jax_fn,
                           params,
                           model_dir,
                           *,
                           input_signatures,
                           polymorphic_shapes=None,
                           with_gradient=False,
                           enable_xla=True,
                           compile_model=True,
                           saved_model_options=None):
    """Convert a JAX function and save it as a SavedModel.

    This is an example, we do not promise backwards compatibility for this code.
    For serious uses, please copy and expand it as needed (see note at the top
    of the module).

    Use this function if you have a trained ML model that has both a prediction
    function and trained parameters, which you want to save separately from the
    function graph as variables (e.g., to avoid limits on the size of the
    GraphDef, or to enable fine-tuning.) If you don't have such parameters,
    you can still use this library function but probably don't need it
    (see jax2tf/README.md for some simple examples).

    In order to use this wrapper you must first convert your model to a function
    with two arguments: the parameters and the input on which you want to do
    inference. Both arguments may be np.ndarray or (nested)
    tuples/lists/dictionaries thereof.

    See the README.md for a discussion of how to prepare Flax and Haiku models.

    Args:
      jax_fn: a JAX function taking two arguments, the parameters and the inputs.
        Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
      params: the parameters, to be used as first argument for `jax_fn`. These
        must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
        saved as the variables of the SavedModel.
      model_dir: the directory where the model should be saved.
      input_signatures: the input signatures for the second argument of `jax_fn`
        (the input). A signature must be a `tensorflow.TensorSpec` instance, or a
        (nested) tuple/list/dictionary thereof with a structure matching the
        second argument of `jax_fn`. The first input_signature will be saved as
        the default serving signature. The additional signatures will be used
        only to ensure that the `jax_fn` is traced and converted to TF for the
        corresponding input shapes.
      polymorphic_shapes: if given then it will be used as the
        `polymorphic_shapes` argument to jax2tf.convert for the second parameter
        of `jax_fn`. In this case, a single `input_signatures` is supported, and
        should have `None` in the polymorphic dimensions.
      with_gradient: the value to use for the `with_gradient` parameter for
        `jax2tf.convert`. Also controls whether the parameter variables are
        trainable.
      enable_xla: the value to use for the `enable_xla` parameter for
        `jax2tf.convert`.
      compile_model: use TensorFlow jit_compiler on the SavedModel. This
        is needed if the SavedModel will be used for TensorFlow serving.
      saved_model_options: options to pass to `tf.saved_model.save`. The
        caller's object is not modified: when `with_gradient` is True, a copy
        with `experimental_custom_gradients=True` is used.

    Raises:
      ValueError: if `input_signatures` is empty, or if `polymorphic_shapes` is
        given together with more than one input signature.
    """
    import copy

    if not input_signatures:
        raise ValueError("At least one input_signature must be given")
    if polymorphic_shapes is not None and len(input_signatures) > 1:
        raise ValueError("For shape-polymorphic conversion a single "
                         "input_signature is supported.")
    tf_fn = jax2tf.convert(jax_fn,
                           with_gradient=with_gradient,
                           polymorphic_shapes=[None, polymorphic_shapes],
                           enable_xla=enable_xla)

    # Create tf.Variables for the parameters. If you want more useful variable
    # names, you can use `tree.map_structure_with_path` from the `dm-tree` package
    param_vars = tf.nest.map_structure(
        lambda param: tf.Variable(param, trainable=with_gradient), params)
    tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
                           autograph=False,
                           jit_compile=compile_model)

    # This signature is needed for TensorFlow Serving use.
    signatures = {
        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf_graph.get_concrete_function(input_signatures[0])
    }
    for input_signature in input_signatures[1:]:
        # If there are more signatures, trace and cache a TF function for each one
        tf_graph.get_concrete_function(input_signature)
    wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
    if with_gradient:
        if not saved_model_options:
            saved_model_options = tf.saved_model.SaveOptions(
                experimental_custom_gradients=True)
        else:
            # Work on a copy so the caller's SaveOptions object is not mutated.
            saved_model_options = copy.copy(saved_model_options)
            saved_model_options.experimental_custom_gradients = True
    tf.saved_model.save(wrapper,
                        model_dir,
                        signatures=signatures,
                        options=saved_model_options)