import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tf2xla_pb2
import tensorflow.compat.v2 as tf

tf.enable_v2_behavior()


def f(a, b):
  return jax.lax.add(a, b).sum()


# Convert the JAX function to a tf.function and trace it to a concrete graph.
f = tf.function(jax2tf.convert(f))
a = b = tf.ones([1, 1])
cf = f.get_concrete_function(a, b)

# Serialize the traced graph.
graph_def = cf.graph.as_graph_def()
with open('graph.pb', 'wb') as fp:
  fp.write(graph_def.SerializeToString())

# Build a tf2xla Config describing the feeds (inputs) of the graph.
config = tf2xla_pb2.Config()
batch_size = 1

feeds = [o.name for o in cf.graph.get_operations()
         if o.name.startswith('jax2tf_arg')]
fetches = [o.name for o in cf.graph.get_operations()
           if o.name.startswith('jax2tf_out')]

for idx, x in enumerate(cf.inputs):
  x.set_shape([batch_size] + list(x.shape)[1:])
  feed = config.feed.add()
  feed.id.node_name = feeds[idx]
  feed.shape.MergeFrom(x.shape.as_proto())

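# A possible continuation (a sketch, not part of the original snippet): the
# `fetches` list computed above is typically recorded in the same tf2xla
# Config, and the config is written out next to graph.pb for ahead-of-time
# compilation. Field names follow the standard
# tensorflow/compiler/tf2xla/tf2xla.proto definition; adjust if yours differs.
for name in fetches:
  fetch = config.fetch.add()
  fetch.id.node_name = name

with open('config.pbtxt', 'w') as fp:
  fp.write(str(config))
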
def test_simple(self):
  f_jax = jnp.sin
  f_jax_rt = jax2tf.call_tf(jax2tf.convert(f_jax))
  x = np.float32(0.7)
  self.assertAllClose(f_jax(x), f_jax_rt(x))

def f_tf_outer(x_tf):
  y_tf = tf.math.sin(x_tf)
  z_tf = jax2tf.convert(f_jax)(y_tf)
  return tf.math.sin(z_tf)

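# Minimal usage sketch for f_tf_outer (in the original context f_jax is
# defined elsewhere; a trivial stand-in is used here).
f_jax = jnp.cos
x = tf.constant(0.7)
print(f_tf_outer(x))                                # eager execution
print(tf.function(f_tf_outer, autograph=False)(x))  # staged to a TF graph
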
def convert_and_save_model(
    jax_fn: tp.Callable[[tp.Any, tp.Any], tp.Any],
    params,
    model_dir: str,
    *,
    input_signatures: tp.Sequence[tf.TensorSpec],
    shape_polymorphic_input_spec: tp.Optional[str] = None,
    with_gradient: bool = False,
    enable_xla: bool = True,
    compile_model: bool = True,
    save_model_options: tp.Optional[tf.saved_model.SaveOptions] = None):
  """Converts a JAX function and saves a SavedModel.

  This is an example; for serious uses you will likely want to copy and
  expand it as needed (see note at the top of the module).

  Use this function if you have a trained ML model that has both a prediction
  function and trained parameters, which you want to save separately from the
  function graph as variables (e.g., to avoid limits on the size of the
  GraphDef, or to enable fine-tuning). If you don't have such parameters,
  you can still use this library function but probably don't need it
  (see jax2tf/README.md for some simple examples).

  In order to use this wrapper you must first convert your model to a function
  with two arguments: the parameters and the input on which you want to do
  inference. Both arguments may be np.ndarray or (nested)
  tuples/lists/dictionaries thereof.

  See the README.md for a discussion of how to prepare Flax and Haiku models.

  Args:
    jax_fn: a JAX function taking two arguments, the parameters and the inputs.
      Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
    params: the parameters, to be used as first argument for `jax_fn`. These
      must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
      saved as the variables of the SavedModel.
    model_dir: the directory where the model should be saved.
    input_signatures: the input signatures for the second argument of `jax_fn`
      (the input). A signature must be a `tensorflow.TensorSpec` instance, or
      a (nested) tuple/list/dictionary thereof with a structure matching the
      second argument of `jax_fn`. The first input_signature will be saved as
      the default serving signature. The additional signatures will be used
      only to ensure that the `jax_fn` is traced and converted to TF for the
      corresponding input shapes.
    shape_polymorphic_input_spec: if given then it will be used as the
      `in_shapes` argument to jax2tf.convert for the second parameter of
      `jax_fn`. In this case, a single `input_signatures` is supported, and
      should have `None` in the polymorphic dimensions. Should be a string, or
      a (nested) tuple/list/dictionary thereof with a structure matching the
      second argument of `jax_fn`.
    with_gradient: whether the SavedModel should support gradients. If True,
      then a custom gradient is saved. If False, then a
      tf.raw_ops.PreventGradient is saved to error if a gradient is attempted.
      (At the moment, due to a bug in SavedModel, custom gradients are not
      supported.)
    enable_xla: whether the jax2tf converter is allowed to use TFXLA ops. If
      False, the conversion tries harder to use purely TF ops and raises an
      exception if it is not possible. (default: True)
    compile_model: use TensorFlow jit_compiler on the SavedModel. This is
      needed if the SavedModel will be used for TensorFlow serving.
    save_model_options: options to pass to tf.saved_model.save.
  """
  if not input_signatures:
    raise ValueError("At least one input_signature must be given")
  if shape_polymorphic_input_spec is not None:
    if len(input_signatures) > 1:
      raise ValueError("For shape-polymorphic conversion a single "
                       "input_signature is supported.")
  tf_fn = jax2tf.convert(
      jax_fn,
      with_gradient=with_gradient,
      in_shapes=[None, shape_polymorphic_input_spec],
      enable_xla=enable_xla)

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree`
  # package.
  param_vars = tf.nest.map_structure(
      # Due to a bug in SavedModel it is not possible to use tf.GradientTape
      # on a function converted with jax2tf and loaded from SavedModel. Thus,
      # we mark the variables as non-trainable to ensure that users of the
      # SavedModel will not try to fine tune them.
      lambda param: tf.Variable(param, trainable=with_gradient),
      params)
  tf_fun = tf.function(
      lambda inputs: tf_fn(param_vars, inputs),
      autograph=False,
      experimental_compile=compile_model)

  signatures = {}
  # This signature is needed for TensorFlow Serving use.
  signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
      tf_fun.get_concrete_function(input_signatures[0]))

  for input_signature in input_signatures[1:]:
    # If there are more signatures, trace and cache a TF function for each one.
    tf_fun.get_concrete_function(input_signature)

  wrapper = _ReusableSavedModelWrapper(tf_fun, param_vars)
  tf.saved_model.save(wrapper, model_dir, signatures=signatures,
                      options=save_model_options)

def _make_trained_model(train_data: tf.data.Dataset,
                        eval_data: tf.data.Dataset,
                        num_epochs: int,
                        steps_per_epoch: int,
                        eval_steps_per_epoch: int,
                        tensorboard_log_dir: str):
  """Executes the model training and evaluation loop.

  Args:
    train_data: a dataset with training pairs (_InputBatch, _LabelBatch).
    eval_data: a dataset with evaluation pairs (_InputBatch, _LabelBatch).
    num_epochs: number of training epochs.
    steps_per_epoch: number of steps for a training epoch. Should be the
      number of samples in your train_data divided by the batch size.
    eval_steps_per_epoch: number of steps for evaluation at the end of each
      training epoch. Should be the number of samples in your eval_data
      divided by the batch size.
    tensorboard_log_dir: Directory where the tensorboard summaries are written.

  Returns:
    An instance of a tf.Module (_SavedModelWrapper) wrapping the trained model.
  """
  learning_rate = 1e-2
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(tensorboard_log_dir)
  summary_writer.hparams(
      dict(
          learning_rate=learning_rate,
          num_epochs=num_epochs,
          steps_per_epoch=steps_per_epoch,
          eval_steps_per_epoch=eval_steps_per_epoch))

  rng, init_rng = jax.random.split(rng)
  # Initialize with some fake data of the proper shape.
  init_val = dict((feature, jnp.array([[1.]], dtype=jnp.float32))
                  for feature in _FEATURE_KEYS_XF)
  model = _FlaxPenguinModel()
  params = model.init(init_rng, init_val)['params']

  optimizer_def = flax.optim.Adam(learning_rate=learning_rate)
  optimizer = optimizer_def.create(params)

  for epoch in range(1, num_epochs + 1):
    optimizer, train_metrics = _train_epoch(model, optimizer, train_data,
                                            steps_per_epoch)
    absl.logging.info('Flax train epoch: %d, loss: %.4f, accuracy: %.2f',
                      epoch, train_metrics['loss'],
                      train_metrics['accuracy'] * 100)
    eval_metrics = _eval_epoch(model, optimizer.target, eval_data,
                               eval_steps_per_epoch)
    absl.logging.info('Flax eval epoch: %d, loss: %.4f, accuracy: %.2f',
                      epoch, eval_metrics['loss'],
                      eval_metrics['accuracy'] * 100)

    summary_writer.scalar('epoch_train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('epoch_train_accuracy', train_metrics['accuracy'],
                          epoch)
    summary_writer.scalar('epoch_eval_loss', eval_metrics['loss'], epoch)
    summary_writer.scalar('epoch_eval_accuracy', eval_metrics['accuracy'],
                          epoch)
  summary_writer.flush()

  # The prediction function for the trained model.
  def predict(params: _Params, inputs: _InputBatch):
    return model.apply({'params': params}, inputs)

  trained_params = optimizer.target

  # Convert the prediction function to TF, with a variable batch dimension
  # for all inputs.
  tf_fn = jax2tf.convert(predict,
                         with_gradient=False,
                         enable_xla=True,
                         polymorphic_shapes=(None, '(b, 1)'))

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree`
  # package.
  param_vars = tf.nest.map_structure(
      # Due to a bug in SavedModel it is not possible to use tf.GradientTape
      # on a function converted with jax2tf and loaded from SavedModel.
      # Thus, we mark the variables as non-trainable to ensure that users of
      # the SavedModel will not try to fine tune them.
      lambda param: tf.Variable(param, trainable=False),
      trained_params)
  tf_graph = tf.function(
      lambda inputs: tf_fn(param_vars, inputs),
      autograph=False,
      experimental_compile=True)

  return _SavedModelWrapper(tf_graph, param_vars)

def test_bfloat16_passed_by_tf(self):
  f_jax = lambda a, b: a + b
  f_tf = tf.function(
      jax2tf.convert(f_jax),
      input_signature=[tf.TensorSpec([512, 512], tf.bfloat16),
                       tf.TensorSpec([512, 512], tf.bfloat16)])
  self.assertIsNotNone(f_tf.get_concrete_function())

def jax2tf_xla(test_case: ModelTestCase):
  """Converts the given `module` using the jax2tf emitter with enable_xla=True."""
  jax_fn = functools.partial(test_case.apply, test_case.variables)
  tf_fn = jax2tf.convert(jax_fn, enable_xla=True)
  _compare(test_case, jax_fn, tf_fn, "JAX vs TF (enable_xla=True)")

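# A plausible counterpart for completeness (a sketch, not from the original
# source): the same comparison with enable_xla=False, which restricts the
# converter to plain TF ops.
def jax2tf_non_xla(test_case: ModelTestCase):
  """Converts the given `module` using the jax2tf emitter with enable_xla=False."""
  jax_fn = functools.partial(test_case.apply, test_case.variables)
  tf_fn = jax2tf.convert(jax_fn, enable_xla=False)
  _compare(test_case, jax_fn, tf_fn, "JAX vs TF (enable_xla=False)")
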
def main(_):
  logging.info('Loading the MNIST TensorFlow dataset')
  train_ds = mnist_lib.load_mnist(
      tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
  test_ds = mnist_lib.load_mnist(
      tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)

  (flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
      train_ds, test_ds, FLAGS.num_epochs)

  def predict(image):
    return flax_predict(flax_params, image)

  # Convert the Flax model to a TF function.
  tf_predict = tf.function(
      jax2tf.convert(predict, enable_xla=False),
      input_signature=[
          tf.TensorSpec(shape=[FLAGS.serving_batch_size, 28, 28, 1],
                        dtype=tf.float32,
                        name='input')
      ],
      autograph=False)

  # Convert the TF function to the TF Lite format.
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
  ]
  tflite_float_model = converter.convert()

  # Show the model size in KBs.
  float_model_size = len(tflite_float_model) / 1024
  print('Float model size = %dKBs.' % float_model_size)

  # Re-convert the model to TF Lite using quantization.
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  tflite_quantized_model = converter.convert()

  # Show the model size in KBs.
  quantized_model_size = len(tflite_quantized_model) / 1024
  print('Quantized model size = %dKBs,' % quantized_model_size)
  print('which is about %d%% of the float model size.' %
        (quantized_model_size * 100 / float_model_size))

  # Evaluate the TF Lite float model. You'll find that its accuracy is
  # identical to the original Flax model because they are essentially the same
  # model stored in different formats.
  float_accuracy = evaluate_tflite_model(tflite_float_model, test_ds)
  print('Float model accuracy = %.4f' % float_accuracy)

  # Evaluate the TF Lite quantized model.
  # Don't be surprised if you see that the quantized model accuracy is higher
  # than that of the original float model. It happens sometimes :)
  quantized_accuracy = evaluate_tflite_model(tflite_quantized_model, test_ds)
  print('Quantized model accuracy = %.4f' % quantized_accuracy)
  print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))

  with open(FLAGS.tflite_file_path, 'wb') as f:
    f.write(tflite_quantized_model)

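# A sketch of running inference with the quantized TF Lite model written by
# main() above (assumes the single 'input' tensor defined there; the all-zero
# image batch is only illustrative).
interpreter = tf.lite.Interpreter(model_path=FLAGS.tflite_file_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
images = np.zeros((FLAGS.serving_batch_size, 28, 28, 1), dtype=np.float32)
interpreter.set_tensor(input_details['index'], images)
interpreter.invoke()
predictions = interpreter.get_tensor(output_details['index'])
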
def test_round_trip_reverse(self):
  f_tf = tf.math.sin
  f_tf_rt = jax2tf.convert(jax2tf.call_tf(f_tf))
  x = np.float32(0.7)
  self.assertAllClose(f_tf(x).numpy(), f_tf_rt(x).numpy())

def ConvertAndCompare(self,
                      func_jax: Callable,
                      *args,
                      enable_xla: bool = True,
                      limitations: Sequence = ()):
  """Compares jax_func(*args) with convert(jax_func)(*args).

  It compares the result of JAX, TF ("eager" mode), TF with tf.function
  ("graph" mode), and TF with tf.function(jit_compile=True) ("compiled" mode).
  In each mode, either we expect to encounter a known limitation, or the value
  should match the value from the JAX execution.

  Args:
    func_jax: the function to invoke (``func_jax(*args)``)
    args: the arguments.
    enable_xla: if True, allows the use of XLA ops in jax2tf.convert
      (default: True).
    limitations: the set of limitations for this harness (not yet filtered
      by mode).
  """
  # Run JAX. Should not fail, we assume that the harness has been filtered
  # already by JAX unimplemented primitives.
  result_jax = func_jax(*args)  # JAX
  result_tf = None

  func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)
  tf_args = _make_tf_args(args)

  unexpected_successes: List[str] = []
  # Run the "compiled" mode first, it is most important
  for mode in ("compiled", "eager", "graph"):
    def log_message(extra):
      return f"[{self._testMethodName}] mode={mode}: {extra}"

    jax2tf_limits = tuple(filter(lambda l: l.filter(mode=mode), limitations))

    skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
    if skip_tf_run:
      logging.info(log_message(f"Skip TF run due to limitations {skip_tf_run}"))
      continue

    try:
      result_tf = _run_tf_function(func_tf, *tf_args, mode=mode)
      tf_exception = None
    except Exception as e:
      tf_exception = e

    expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
    if tf_exception:
      if expect_tf_error:
        logging.info(log_message(
            "Found expected TF error with enabled limitations "
            f"{expect_tf_error}; TF error is {tf_exception}"))
        continue
      else:
        raise tf_exception
    else:
      if expect_tf_error:
        # It is more ergonomic to print all successful modes once
        logging.warning(log_message(
            f"Unexpected success with known limitations {expect_tf_error}"))
        unexpected_successes.append(f"{mode}: {expect_tf_error}")

    skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
    if skip_comparison:
      logging.warning(log_message(
          f"Skip result comparison due to {skip_comparison}"))
      continue

    max_tol = None
    max_tol_lim = None if not jax2tf_limits else (
        jax2tf_limits[0].get_max_tolerance_limitation(jax2tf_limits))
    if max_tol_lim is not None:
      max_tol = max_tol_lim.tol
      logging.info(log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

    # Convert results to np.arrays
    result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf)  # type: ignore

    custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
    assert len(custom_assert_lim) <= 1, (
        f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}")

    if custom_assert_lim:
      logging.info(log_message(
          f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
      custom_assert_lim[0].custom_assert(self, result_jax, result_tf,
                                         args=args, tol=max_tol)
    else:
      logging.info(log_message(f"Running default assert with tol={max_tol}"))
      # In compiled mode we expect the same result as JAX by default
      self.assertAllClose(result_jax, result_tf, atol=max_tol, rtol=max_tol)
  # end "for mode"

  if unexpected_successes:
    msg = (f"[{self._testMethodName}] The following are unexpected "
           "successful modes:\n" + "\n".join(unexpected_successes))
    logging.warning(msg)
    # Uncomment the below if you want to see warnings as failures
    # self.assertEmpty(msg)
  return result_jax, result_tf

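# Hypothetical usage of ConvertAndCompare in a test method (the harness
# normally supplies `limitations`; a plain call is enough for simple
# primitives).
def test_sin(self):
  self.ConvertAndCompare(jnp.sin, np.float32(0.7))
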
def ConvertAndCompare(self,
                      func_jax: Callable,
                      *args,
                      custom_assert: Optional[Callable] = None,
                      always_custom_assert: bool = False,
                      expect_tf_exceptions: bool = False,
                      atol=None,
                      rtol=None) -> Tuple[Any, Any]:
  """Compares jax_func(*args) with convert(jax_func)(*args).

  It compares the result of JAX, TF ("eager" mode), TF with tf.function
  ("graph" mode), and TF with tf.function(experimental_compile=True)
  ("compiled" mode). In each mode, either we expect an exception (see
  `expect_tf_exceptions`) or the value should match the value from the JAX
  execution.

  Args:
    custom_assert: a function that will be called
      `custom_assert(result_jax, result_tf)` to assert equality of the
      results. Use this function when JAX and TF produce different results.
      This function is only used for "eager" and "graph" modes by default,
      not for the "compiled" mode, because in that case we expect the results
      to be equal.
    always_custom_assert: if True, custom_assert is also called in "compiled"
      mode. This is useful in cases where JAX and TF produce different but
      equally valid results.
    expect_tf_exceptions: if True, there may be exceptions in some evaluation
      modes; when there is no exception the result should be the same as in
      JAX.
  """
  original_impl = jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl

  # Monkey-patch jax2tf.TensorFlowTrace.get_primitive_impl to wrap the
  # resulting primitive in a categorizer.
  wrapper = correctness_stats.collect_limitations
  jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl = (  # type: ignore
      lambda s, p: wrapper(p, original_impl(s, p)))

  def restore_get_primitive_impl():
    jax2tf.jax2tf.TensorFlowTrace.get_primitive_impl = original_impl

  # Restore the original jax2tf.TensorFlowTrace.get_primitive_impl
  # implementation at the end of the test.
  self.addCleanup(restore_get_primitive_impl)

  # Run JAX
  result_jax = func_jax(*args)

  # Run TF in all execution modes
  func_tf = jax2tf.convert(func_jax)

  def convert_if_bfloat16(v):
    if hasattr(v, "dtype"):
      return tf.convert_to_tensor(
          np.array(v, jnp.float32) if v.dtype == jnp.bfloat16 else v,
          jax2tf.jax2tf.to_tf_dtype(v.dtype))
    return v

  tf_args = tuple(map(convert_if_bfloat16, args))

  def run_tf(mode):
    if mode == "eager":
      return func_tf(*tf_args)
    elif mode == "graph":
      return tf.function(func_tf, autograph=False)(*tf_args)
    elif mode == "compiled":
      return tf.function(func_tf, autograph=False,
                         experimental_compile=True)(*tf_args)
    else:
      assert False

  def is_tf_exception(lim: correctness_stats.Limitation):
    return (lim.error_type == 'Missing TF support' and
            self.tf_default_device.device_type in lim.devices)

  result_tf = None
  for mode in ("eager", "graph", "compiled"):
    current_limitations_len = len(correctness_stats.all_limitations)
    try:
      result_tf = run_tf(mode)
    except Exception as e:
      new_limitations = (
          correctness_stats.all_limitations[current_limitations_len:])
      detected_tf_exception = any(map(is_tf_exception, new_limitations))
      if not (expect_tf_exceptions or detected_tf_exception):
        raise e
      else:
        for lim in new_limitations:
          print("Detected limitation: {} for {} devices.".format(
              lim.error_string, ', '.join(lim.devices)))
        print(f"Encountered expected exception for mode={mode}: {e}")
        continue

    if custom_assert is not None and (mode in ("eager", "graph") or
                                      always_custom_assert):
      custom_assert(result_jax, result_tf)
    else:
      # In compiled mode we expect the same result as JAX by default
      self.assertAllClose(result_jax, result_tf, atol=atol, rtol=rtol)

  return (result_jax, result_tf)

def test_variable_input(self):
  f_jax = lambda x: jnp.sin(jnp.cos(x))
  f_tf = jax2tf.convert(f_jax)
  v = tf.Variable(0.7, dtype=dtypes.canonicalize_dtype(jnp.float_))
  self.assertIsInstance(f_tf(v), tf.Tensor)
  self.assertAllClose(f_jax(0.7), f_tf(v))

def _tf_grad(a, b):
  with tf.GradientTape() as tape:
    tape.watch(a)
    result = jax2tf.convert(f_jax)(a, b)
  return result, tape.gradient(result, a)

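# Usage sketch for _tf_grad (in the original context f_jax is defined nearby;
# a simple two-argument function stands in for it here).
f_jax = lambda a, b: a * b
a = tf.constant(3.0)
b = tf.constant(4.0)
value, grad_a = _tf_grad(a, b)  # value == 12.0, grad_a == b == 4.0
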
def __init__(self, apply_fn, params, name=None):
  super().__init__(name=name)
  self.apply_fn = jax2tf.convert(apply_fn)
  self.params = tf.nest.map_structure(tf.Variable, params)

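# A sketch of how such a wrapper is typically completed (the class name and
# __call__ below are illustrative, not the original code): the converted
# function is applied to the tf.Variable parameters so that TF tracks them as
# model state.
class JaxModule(tf.Module):

  def __init__(self, apply_fn, params, name=None):
    super().__init__(name=name)
    self.apply_fn = jax2tf.convert(apply_fn)
    self.params = tf.nest.map_structure(tf.Variable, params)

  @tf.function(autograph=False)
  def __call__(self, inputs):
    return self.apply_fn(self.params, inputs)
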
def test_variable_input(self):
  f_jax = lambda x: jnp.sin(jnp.cos(x))
  f_tf = jax2tf.convert(f_jax)
  v = tf.Variable(0.7)
  self.assertIsInstance(f_tf(v), tf.Tensor)
  self.assertAllClose(f_jax(0.7), f_tf(v))

def test_gradients_int_argument(self, with_function=True):
  # https://github.com/google/jax/issues/6975
  # Also issue #6975.
  # An expanded version of test_gradients_unused_argument
  state = dict(
      float_used=np.array([0.7, 0.9], dtype=np.float32),
      float_passthrough=np.float16(1.),
      float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
      int_used=np.int16(5),
      int_passthrough=np.int8(7),
      int_unused=np.array([1, 2, 3], dtype=np.uint32),
      bool_used=np.array([True, False, False, True], dtype=np.bool_),
      bool_passthrough=np.array([True, False, False, True, False],
                                dtype=np.bool_),
      bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
  )

  def jax_f(state):
    res = dict(state,
               float_used=2. * state["float_used"],
               int_used=3 * state["int_used"],
               bool_used=(state["bool_used"] == state["bool_used"]))
    del res["float_unused"]
    del res["int_unused"]
    del res["bool_unused"]
    return res

  args = (state,)
  res_jax = jax_f(*args)

  # Native JAX AD
  vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
  grad_jax, = vjp_jax_fun(*args_vjp)

  def compare_with_overrides(*, what, expected, **expected_overrides):
    what_keys = set(what.keys())
    expected_keys = set(expected.keys())
    self.assertEqual(what_keys, expected_keys)
    for k, w in what.items():
      e = expected[k]
      if k in expected_overrides:
        if expected_overrides[k] == "ZERO":
          e = np.zeros_like(w)
        elif expected_overrides[k] == "ZERO_INT32":
          e = np.zeros(np.shape(w), dtype=np.int32)
        elif expected_overrides[k] == "ONE":
          e = np.ones_like(w)
        else:
          e = expected_overrides[k]

      if e is None:
        self.assertIsNone(w, msg=k)
      else:
        self.assertIsNotNone(w, msg=k)

      w = w.numpy() if isinstance(w, tf.Tensor) else w
      e = e.numpy() if isinstance(e, tf.Tensor) else e
      try:
        self.assertAllClose(e, w, err_msg=k)
      except:
        print(f"Failed at {k}")
        raise

  # compare_with_overrides(g_jax, {},
  #     bool_passthrough=np.zeros(state["bool_passthrough"].shape, dtype=dtypes.float0),
  #     bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0),
  #     bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0),
  #     float_passthrough=np.ones_like(state["float_passthrough"]),
  #     float_unused=np.zeros_like(state["float_unused"]),
  #     float_used=np.ones_like(state["float_used"]) * np.array(2., dtype=state["float_used"].dtype),
  #     int_passthrough=np.zeros(state["int_passthrough"].shape, dtype=dtypes.float0),
  #     int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0),
  #     int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0))

  # Now native TF gradients, only to test how native TF AD works
  _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
      jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
  compare_with_overrides(what=grad_tf_0,
                         expected=grad_jax,
                         float_unused="ZERO",
                         bool_used="ZERO", bool_passthrough="ONE",
                         bool_unused="ZERO",
                         int_used="ZERO", int_passthrough="ONE",
                         int_unused="ZERO")

  _, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
      jax_f, args, unconnected_gradients=tf.UnconnectedGradients.NONE)
  compare_with_overrides(what=grad_tf_None,
                         expected=grad_tf_0,
                         float_unused=None,
                         int_used=None, int_unused=None,
                         bool_used=None, bool_unused=None)

  f_tf_jax = jax2tf.convert(jax_f)
  if with_function:
    f_tf_jax = tf.function(f_tf_jax, autograph=False)

  _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
  # Same results as TF native AD with tf.UnconnectedGradients.ZERO
  compare_with_overrides(what=grad_tf_jax_0,
                         expected=grad_tf_0,
                         int_passthrough="ZERO", bool_passthrough="ZERO")

  _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
      f_tf_jax, args, unconnected_gradients=tf.UnconnectedGradients.NONE)
  compare_with_overrides(what=grad_tf_jax_None,
                         expected=grad_tf_0,
                         int_used=None, int_passthrough=None, int_unused=None,
                         bool_unused=None, bool_used=None,
                         bool_passthrough=None)

  # Now convert the JAX gradient function
  tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
  grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
  compare_with_overrides(what=grad_tf_vjp_jax,
                         expected=grad_tf_0,
                         bool_passthrough="ZERO_INT32",
                         bool_unused="ZERO_INT32", bool_used="ZERO_INT32",
                         int_passthrough="ZERO_INT32",
                         int_unused="ZERO_INT32", int_used="ZERO_INT32")

def test_nested_jit(self):
  f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x)))
  f_tf = jax2tf.convert(f_jax)
  np.testing.assert_allclose(f_jax(0.7), f_tf(0.7))

def test_nested_convert_error(self):
  def outer(y):
    return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args

  with self.assertRaisesRegex(
      ValueError, "convert must be used outside all JAX transformations"):
    jax2tf.convert(outer)(np.ones((4,)))

def test_bfloat16_returned_by_jax(self):
  f_jax = lambda a, b: (a + b).astype(jnp.bfloat16)
  f_tf = jax2tf.convert(f_jax)
  self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16)

def outer(y):
  return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args

def ConvertAndCompare(self,
                      func_jax: Callable,
                      *args,
                      custom_assert: Optional[Callable] = None,
                      expect_tf_exceptions: bool = False,
                      atol=None,
                      rtol=None) -> Tuple[Any, Any]:
  """Compares jax_func(*args) with convert(jax_func)(*args).

  It compares the result of JAX, TF ("eager" mode), TF with tf.function
  ("graph" mode), and TF with tf.function(experimental_compile=True)
  ("compiled" mode). In each mode, either we expect an exception (see
  `expect_tf_exceptions`) or the value should match the value from the JAX
  execution.

  Args:
    custom_assert: a function that will be called
      `custom_assert(result_jax, result_tf)` to assert equality of the
      results. Use this function when JAX and TF produce different results.
      This function is only used for "eager" and "graph" modes, not for the
      "compiled" mode, because in that case we always expect the results to
      be equal.
    expect_tf_exceptions: if True, there may be exceptions in some evaluation
      modes; when there is no exception the result should be the same as in
      JAX.
  """
  # Run JAX
  result_jax = func_jax(*args)

  # Run TF in all execution modes
  func_tf = jax2tf.convert(func_jax)

  def run_tf(mode):
    if mode == "eager":
      return func_tf(*args)
    elif mode == "graph":
      return tf.function(func_tf, autograph=False)(*args)
    elif mode == "compiled":
      return tf.function(func_tf, autograph=False,
                         experimental_compile=True)(*args)
    else:
      assert False

  result_tf = None
  for mode in ("eager", "graph", "compiled"):
    try:
      result_tf = run_tf(mode)
    except Exception as e:
      if not expect_tf_exceptions:
        raise e
      else:
        print(f"Encountered exception for mode={mode}: {e}")
        continue

    if custom_assert is not None and mode in ("eager", "graph"):
      custom_assert(result_jax, result_tf)
    else:
      # In compiled mode we always expect the same result as JAX
      self.assertAllClose(result_jax, result_tf, atol=atol, rtol=rtol)

  return (result_jax, result_tf)

def outer(y):
  sin_1 = jax2tf.convert(jnp.sin)(1.)  # Inner convert takes non-tracer arg
  return y + sin_1

def _check_sharding_annotations(self,
                                f_jax,
                                args: Sequence[Any],
                                *,
                                expected: Sequence[str],
                                expected_opt: Sequence[str],
                                num_partitions=2):
  """Checks expected patterns in the HLO generated from f_jax and its conversion.

  We run this check on CPU also, which is useful for debugging locally.
  We currently check the unoptimized HLO against `expected` on CPU and TPU,
  and we check the optimized HLO against `expected_opt` on TPU only and
  only for JAX.

  See `self.AssertShardingAnnotations` for documentation of `expected`
  and `expected_opt`.
  """
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

  jax_comp = jax.xla_computation(f_jax)(*args)
  jax_hlo = jax_comp.as_hlo_text()
  if LOG_HLO:
    logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
  self.AssertShardingAnnotations("JAX before optimizations", jax_hlo, expected)

  if jtu.device_under_test() == "tpu":
    backend = jax._src.lib.xla_bridge.get_backend()
    num_replicas = 1
    device_assignment = np.arange(num_partitions * num_replicas)
    device_assignment = np.reshape(device_assignment, (-1, num_partitions))
    use_spmd_partitioning = num_partitions > 1
    compile_options = jax._src.lib.xla_bridge.get_compile_options(
        num_replicas=num_replicas,
        num_partitions=num_partitions,
        device_assignment=device_assignment,
        use_spmd_partitioning=use_spmd_partitioning,
    )
    jax_optimized_hlo = backend.compile(
        jax_comp, compile_options).hlo_modules()[0].to_string()
    if LOG_HLO:
      logging.info("[%s] got JAX optimized HLO for platform %s %s",
                   self._testMethodName, backend.platform, jax_optimized_hlo)
    self.AssertShardingAnnotations("JAX after optimizations",
                                   jax_optimized_hlo, expected_opt)

  f_tf = jax2tf.convert(f_jax)
  device_name = f"/device:{jtu.device_under_test().upper()}:0"
  tf_hlo = (tf.function(f_tf, jit_compile=True, autograph=False)
            .experimental_get_compiler_ir(*args)(stage="hlo",
                                                 device_name=device_name))
  if LOG_HLO:
    logging.info("[%s] got TF HLO %s", self._testMethodName, tf_hlo)
  self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)

  tf_optimized_hlo = (
      tf.function(f_tf, jit_compile=True)
      .experimental_get_compiler_ir(*args)(stage="optimized_hlo",
                                           device_name=device_name))
  if LOG_HLO:
    logging.info("[%s] got TF optimized HLO for %s: %s", self._testMethodName,
                 device_name, tf_optimized_hlo)

def test_convert_argument_non_callable_error(self):
  with self.assertRaisesRegex(TypeError, "Expected a callable value"):
    jax2tf.convert(5.)

def ConvertAndCompare(self,
                      func_jax: Callable,
                      *args,
                      enable_xla: bool = True,
                      limitations: Sequence = ()):
  """Compares jax_func(*args) with convert(jax_func)(*args).

  It compares the result of JAX, TF ("eager" mode), TF with tf.function
  ("graph" mode), and TF with tf.function(jit_compile=True) ("compiled" mode).
  In each mode, either we expect to encounter a known limitation, or the value
  should match the value from the JAX execution.

  Args:
    func_jax: the function to invoke (``func_jax(*args)``)
    args: the arguments.
    enable_xla: if True, allows the use of XLA ops in jax2tf.convert
      (default: True).
    limitations: the set of limitations for this harness (not yet filtered
      by mode).
  """
  # Run JAX. Should not fail, we assume that the harness has been filtered
  # already by JAX unimplemented primitives.
  result_jax = func_jax(*args)  # JAX
  result_tf = None

  func_tf = jax2tf.convert(func_jax, enable_xla=enable_xla)

  unexpected_successes: List[str] = []
  # Run the "compiled" mode first, it is most important
  for mode in ("compiled", "eager", "graph"):
    def log_message(extra):
      return f"[{self._testMethodName}] mode={mode}: {extra}"

    jax2tf_limits = tuple(filter(lambda l: l.filter(mode=mode), limitations))

    skip_tf_run = [l for l in jax2tf_limits if l.skip_tf_run]
    if skip_tf_run:
      logging.info(log_message(f"Skip TF run due to limitations {skip_tf_run}"))
      continue

    try:
      result_tf = _run_tf_function(func_tf, *args, mode=mode)
      tf_exception = None
    except Exception as e:
      tf_exception = e

    expect_tf_error = [l for l in jax2tf_limits if l.expect_tf_error]
    if tf_exception:
      if expect_tf_error:
        logging.info(log_message(
            "Found expected TF error with enabled limitations "
            f"{expect_tf_error}; TF error is {tf_exception}"))
        continue
      else:
        raise tf_exception
    else:
      if expect_tf_error:
        # It is more ergonomic to print all successful modes once
        logging.warning(log_message(
            f"Unexpected success with known limitations {expect_tf_error}"))
        unexpected_successes.append(f"{mode}: {expect_tf_error}")

    if (jtu.device_under_test() == "gpu" and
        "dot_general_preferred" in self._testMethodName):
      logging.info(log_message(
          f"Arguments are {args}, JAX result is {result_jax}\nand TF result is {result_tf}"))

    skip_comparison = [l for l in jax2tf_limits if l.skip_comparison]
    if skip_comparison:
      logging.warning(log_message(
          f"Skip result comparison due to {skip_comparison}"))
      continue

    max_tol = None
    max_tol_lim = None if not jax2tf_limits else (
        jax2tf_limits[0].get_max_tolerance_limitation(jax2tf_limits))
    if max_tol_lim is not None:
      max_tol = max_tol_lim.tol
      logging.info(log_message(f"Using tol={max_tol} due to {max_tol_lim}"))

    # Convert results to np.arrays
    result_tf = tf.nest.map_structure(lambda t: t.numpy(), result_tf)  # type: ignore

    custom_assert_lim = [l for l in jax2tf_limits if l.custom_assert]
    assert len(custom_assert_lim) <= 1, (
        f"Expecting at most one applicable limitation with custom_assert, found {custom_assert_lim}")

    try:
      err_msg = f"TF mode {mode}."
      log_hlo_on_error = mode == "compiled" or jtu.device_under_test() == "tpu"
      if log_hlo_on_error:
        err_msg += " See the logs for JAX and TF HLO comparisons."

      if custom_assert_lim:
        logging.info(log_message(
            f"Running custom_assert with tol={max_tol} due to {custom_assert_lim[0]}"))
        custom_assert_lim[0].custom_assert(self, result_jax, result_tf,
                                           args=args, tol=max_tol,
                                           err_msg=err_msg)
      else:
        logging.info(log_message(f"Running default assert with tol={max_tol}"))
        self.assertAllClose(result_jax, result_tf,
                            atol=max_tol, rtol=max_tol, err_msg=err_msg)
    except AssertionError as e:
      # Print the HLO for comparison
      if not log_hlo_on_error:
        print(f"[{self._testMethodName}] Not logging HLO because the "
              f"mode was {mode}")
        raise

      logging.info(
          f"[{self._testMethodName}] Logging HLO for exception in mode {mode}: {e}")
      jax_comp = jax.xla_computation(func_jax)(*args)
      jax_hlo = jax_comp.as_hlo_text()
      logging.info(f"[{self._testMethodName}] JAX NON_OPT HLO\n{jax_hlo}")

      tf_args_signature = _make_tf_input_signature(*args)
      # If we give the signature, we cannot pass scalars
      tf_args_no_scalars = tuple(
          map(lambda a, sig: tf.convert_to_tensor(a, dtype=sig.dtype),
              args, tf_args_signature))
      tf_func_compiled = tf.function(
          func_tf,
          autograph=False,
          jit_compile=True,
          input_signature=tf_args_signature)
      tf_hlo = tf_func_compiled.experimental_get_compiler_ir(
          *tf_args_no_scalars)(stage="hlo")
      logging.info(f"[{self._testMethodName}] TF NON OPT HLO\n{tf_hlo}")

      backend = jax.lib.xla_bridge.get_backend()
      modules = backend.compile(jax_comp).hlo_modules()
      jax_opt_hlo = modules[0].to_string()
      logging.info(f"[{self._testMethodName}] JAX OPT HLO\n{jax_opt_hlo}")

      # TODO(b/189265364): Remove this workaround
      if (jtu.device_under_test() == "gpu" and
          "dot_general" in self._testMethodName):
        print(f"[{self._testMethodName}] Not logging TF OPT HLO because of "
              f"crash in tf.experimental_get_compiler_ir (b/189265364)")
      else:
        tf_opt_hlo = tf_func_compiled.experimental_get_compiler_ir(
            *tf_args_no_scalars)(stage="optimized_hlo")
        logging.info(f"[{self._testMethodName}] TF OPT HLO\n{tf_opt_hlo}")
      raise
  # end "for mode"

  if unexpected_successes:
    msg = (f"[{self._testMethodName}] The following are unexpected "
           "successful modes:\n" + "\n".join(unexpected_successes))
    logging.warning(msg)
    # Uncomment the below if you want to see warnings as failures
    # self.assertEmpty(msg)
  return result_jax, result_tf

def test_convert_argument_non_tensor_error(self):
  with self.assertRaisesRegex(TypeError, "Argument.*should be NumPy array"):
    jax2tf.convert(lambda x: x)(lambda y: y)

def test_shape_poly(self):
  f_jax = jnp.sin
  f_jax_rt = jax2tf.call_tf(
      jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"]))
  x = np.array([0.7, 0.8], dtype=np.float32)
  self.assertAllClose(f_jax(x), f_jax_rt(x))

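# A sketch of using the shape-polymorphic conversion directly with TF (the
# function and input spec are illustrative): with polymorphic_shapes=["(b, ...)"]
# a single traced tf.function can serve any batch size.
f_tf = tf.function(
    jax2tf.convert(jnp.sin, polymorphic_shapes=["(b, ...)"]),
    autograph=False,
    input_signature=[tf.TensorSpec([None], tf.float32)])
print(f_tf(tf.ones([3], tf.float32)))
print(f_tf(tf.ones([16], tf.float32)))  # reuses the same traced function
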
def test_argument_eager_tensor(self):
  x = jax2tf.convert(jnp.sin)(1.)
  jax2tf.convert(jnp.cos)(x)  # No error

def test_stop_gradient(self):
  f = jax2tf.convert(lax.stop_gradient)
  self.assertEqual(f(tf.ones([])), 1.)

def convert_and_save_model(jax_fn,
                           params,
                           model_dir,
                           *,
                           input_signatures,
                           polymorphic_shapes=None,
                           with_gradient=False,
                           enable_xla=True,
                           compile_model=True,
                           saved_model_options=None):
  """Converts a JAX function and saves a SavedModel.

  This is an example; we do not promise backwards compatibility for this code.
  For serious uses, please copy and expand it as needed (see note at the top
  of the module).

  Use this function if you have a trained ML model that has both a prediction
  function and trained parameters, which you want to save separately from the
  function graph as variables (e.g., to avoid limits on the size of the
  GraphDef, or to enable fine-tuning). If you don't have such parameters,
  you can still use this library function but probably don't need it
  (see jax2tf/README.md for some simple examples).

  In order to use this wrapper you must first convert your model to a function
  with two arguments: the parameters and the input on which you want to do
  inference. Both arguments may be np.ndarray or (nested)
  tuples/lists/dictionaries thereof.

  See the README.md for a discussion of how to prepare Flax and Haiku models.

  Args:
    jax_fn: a JAX function taking two arguments, the parameters and the inputs.
      Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
    params: the parameters, to be used as first argument for `jax_fn`. These
      must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
      saved as the variables of the SavedModel.
    model_dir: the directory where the model should be saved.
    input_signatures: the input signatures for the second argument of `jax_fn`
      (the input). A signature must be a `tensorflow.TensorSpec` instance, or
      a (nested) tuple/list/dictionary thereof with a structure matching the
      second argument of `jax_fn`. The first input_signature will be saved as
      the default serving signature. The additional signatures will be used
      only to ensure that the `jax_fn` is traced and converted to TF for the
      corresponding input shapes.
    with_gradient: the value to use for the `with_gradient` parameter for
      `jax2tf.convert`.
    enable_xla: the value to use for the `enable_xla` parameter for
      `jax2tf.convert`.
    compile_model: use TensorFlow jit_compiler on the SavedModel. This is
      needed if the SavedModel will be used for TensorFlow serving.
    polymorphic_shapes: if given then it will be used as the
      `polymorphic_shapes` argument to jax2tf.convert for the second parameter
      of `jax_fn`. In this case, a single `input_signatures` is supported, and
      should have `None` in the polymorphic dimensions.
    saved_model_options: options to pass to tf.saved_model.save.
  """
  if not input_signatures:
    raise ValueError("At least one input_signature must be given")
  if polymorphic_shapes is not None:
    if len(input_signatures) > 1:
      raise ValueError("For shape-polymorphic conversion a single "
                       "input_signature is supported.")
  tf_fn = jax2tf.convert(jax_fn,
                         with_gradient=with_gradient,
                         polymorphic_shapes=[None, polymorphic_shapes],
                         enable_xla=enable_xla)

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree`
  # package.
  param_vars = tf.nest.map_structure(
      lambda param: tf.Variable(param, trainable=with_gradient),
      params)
  tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
                         autograph=False,
                         jit_compile=compile_model)

  signatures = {}
  # This signature is needed for TensorFlow Serving use.
  signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
      tf_graph.get_concrete_function(input_signatures[0]))

  for input_signature in input_signatures[1:]:
    # If there are more signatures, trace and cache a TF function for each one.
    tf_graph.get_concrete_function(input_signature)

  wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
  if with_gradient:
    if not saved_model_options:
      saved_model_options = tf.saved_model.SaveOptions(
          experimental_custom_gradients=True)
    else:
      saved_model_options.experimental_custom_gradients = True
  tf.saved_model.save(wrapper, model_dir,
                      signatures=signatures,
                      options=saved_model_options)

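# A sketch of loading and serving a SavedModel written by
# convert_and_save_model above. The path, input shape, and the input name
# 'inputs' (which follows the `lambda inputs: ...` wrapper) are assumptions
# for illustration.
restored = tf.saved_model.load('/tmp/jax2tf_saved_model')
serving_fn = restored.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
outputs = serving_fn(inputs=tf.ones([1, 28 * 28], tf.float32))
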