def save_computation_graphs(self, save_backward_graph):
    """Dump computation graphs to files.

    Traces ``self._model_predict_eval`` and ``self._jit_update_fn`` with
    ``jax.xla_computation`` on one batch pulled from ``self._train_stream`` and
    writes their HLO into ``self._output_dir`` as ``forward.txt``,
    ``forward.dot``, ``backward.txt`` and (optionally) ``backward.dot``.
    Does nothing unless exactly one device is configured.

    Args:
      save_backward_graph: if true, also write ``backward.dot``; this is
        guarded because backward graphs can be large.
    """
    # TODO(lukaszkaiser): make this work with more devices.
    if self._n_devices != 1:
        return
    # The early return above guarantees a single device here, so the former
    # `reshape_by_device` branch for `_n_devices > 1` was unreachable and has
    # been dropped; reinstate it together with multi-device support.
    next_train_batch = next(self._train_stream)
    output_dir = self._output_dir

    def _write(filename, contents):
        # Write one text artifact into the output directory.
        with gfile.GFile(os.path.join(output_dir, filename), "w") as f:
            f.write(contents)

    params = self._opt_state[0]
    forward_computation = jax.xla_computation(self._model_predict_eval)(
        next_train_batch, params=params, state=self._model_state,
        rng=self._rngs[0])
    _write("forward.txt", forward_computation.GetHloText())
    _write("forward.dot", forward_computation.GetHloDotGraph())

    backward_computation = jax.xla_computation(self._jit_update_fn)(
        self._step, self._opt_state, next_train_batch, self._model_state,
        self._rngs)
    _write("backward.txt", backward_computation.GetHloText())
    if save_backward_graph:
        # Backward graphs can be large so we guard it.
        _write("backward.dot", backward_computation.GetHloDotGraph())
def trace_bench(state):
    """Benchmark JAX tracing (`jax.xla_computation`) of the `init` function.

    `init` and `input_shape` are free names, presumably bound in an enclosing
    scope or module (not visible here). Each iteration of the benchmark
    `state` loop re-traces `init` on the same key and input.
    """
    inputs = jnp.ones(input_shape).block_until_ready()
    rng_key = jax.random.PRNGKey(42)
    while state:
        traced = jax.xla_computation(init)
        traced(rng_key, inputs)
def test_different_computations(self):
    """An add and a multiply computation must hash to different cache keys."""
    add_computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    mul_computation = jax.xla_computation(lambda x, y: x * y)(2, 2)
    options = jax.lib.xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    add_key = cc.get_cache_key(add_computation, options)
    mul_key = cc.get_cache_key(mul_computation, options)
    self.assertNotEqual(add_key, mul_key)
def save_computation_graphs(self, save_backward_graph):
    """Dump computation graphs to files.

    Traces ``self._model_predict_eval`` and ``self._jit_update_fn`` with
    ``jax.xla_computation`` on one batch pulled from ``self._train_stream`` and
    writes their HLO into ``self._output_dir`` as ``forward.txt``,
    ``forward.dot``, ``backward.txt`` and (optionally) ``backward.dot``.
    Does nothing unless exactly one device is configured.

    Args:
      save_backward_graph: if true, also write ``backward.dot``; this is
        guarded because backward graphs can be large.
    """
    # TODO(lukaszkaiser): make this work with more devices.
    if self.n_devices != 1:
        return
    # The early return above guarantees a single device here, so the former
    # `_reshape_by_device` branch for `n_devices > 1` was unreachable and has
    # been dropped; reinstate it together with multi-device support.
    batch = next(self._train_stream)
    output_dir = self._output_dir

    def _write(filename, contents):
        # Write one text artifact into the output directory.
        with tf.io.gfile.GFile(os.path.join(output_dir, filename), 'w') as f:
            f.write(contents)

    weights = self._opt_state[0][0]
    forward_computation = jax.xla_computation(self._model_predict_eval)(
        batch, weights=weights, state=self._model_state[0], rng=self._rngs[0])
    _write('forward.txt', forward_computation.GetHloText())
    _write('forward.dot', forward_computation.GetHloDotGraph())

    backward_computation = jax.xla_computation(self._jit_update_fn)(
        self._step, self._opt_state, batch, self._model_state, self._rngs)
    _write('backward.txt', backward_computation.GetHloText())
    if save_backward_graph:
        # Backward graphs can be large so we guard it.
        _write('backward.dot', backward_computation.GetHloDotGraph())
def test_jit_metadata(self):
    """HLO metadata carries op_type and an op_name scoped by the traced fn."""
    def foo(x):
        return jnp.sin(x)

    # (function to trace, expected op_name regex) pairs, checked in order.
    expectations = (
        (jnp.sin, 'op_name="xla_computation\\(sin\\)/sin"'),
        (foo, 'op_name="xla_computation\\(foo\\)/sin"'),
    )
    for traced_fn, op_name_pattern in expectations:
        hlo_text = (jax.xla_computation(traced_fn)(1.)
                    .get_hlo_module().to_string())
        self.assertRegex(hlo_text, 'op_type="sin"')
        self.assertRegex(hlo_text, op_name_pattern)
def test_diff_executables(self):
    """Executables cached for two different computations read back unequal."""
    with tempfile.TemporaryDirectory() as cache_dir:
        cc.initialize_cache(cache_dir)
        computations = [
            jax.xla_computation(lambda x, y: x + y)(1, 1),
            jax.xla_computation(lambda x, y: x * y)(2, 2),
        ]
        options = jax.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        backend = jax.lib.xla_bridge.get_backend()
        compiled = [backend.compile(comp, options) for comp in computations]
        for comp, executable in zip(computations, compiled):
            cc.put_executable(comp, options, executable)
        first, second = [cc.get_executable(comp, options)
                         for comp in computations]
        self.assertNotEqual(first, second)
def test_same_hash_key(self):
    """`cc.get_cache_key` returns the same key for identical inputs."""
    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    bridge = jax._src.lib.xla_bridge
    options = bridge.get_compile_options(num_replicas=1, num_partitions=1)
    backend = bridge.get_backend()
    first_key = cc.get_cache_key(computation, options, backend)
    second_key = cc.get_cache_key(computation, options, backend)
    self.assertEqual(first_key, second_key)
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serialize a Python function containing JAX code as a TFF computation.

    Args:
      traced_fn: The Python function with JAX code; JAX traces it and the result
        is emitted as a TFF computation holding XLA code.
      arg_fn: An unpacking function mapping a TFF argument to an
        `(args, kwargs)` pair used to invoke `traced_fn` (for example, one built
        by `function_utils.create_argument_unpacking_fn`).
      parameter_type: A value accepted by `computation_types.to_type` that gives
        the TFF type of the computation parameter, or `None` if the function
        takes no parameters.
      context_stack: The context stack into which a JAX computation context is
        installed while `traced_fn` is traced.

    Returns:
      An instance of `pb.Computation` with the constructed computation.

    Raises:
      TypeError: if the arguments are of the wrong types.
    """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    packed_arg = None
    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    args, kwargs = arg_fn(packed_arg)

    # The fake parameters fed through args/kwargs may end up reordered in the
    # generated XLA code. Flatten them with the same tree-flattening JAX's
    # serializer uses, so that ordering can be captured in the parameter
    # binding. Results need no such care: multiple results always come back
    # as a tuple.
    leaves, _ = jax.tree_util.tree_flatten((args, kwargs))
    tensor_indexes = list(np.argsort([leaf.tensor_index for leaf in leaves]))

    with context_stack.install(jax_computation_context.JaxComputationContext()):
        tracer = jax.xla_computation(
            traced_fn, tuple_args=True, return_shape=True)
        compiled_xla, returned_shape = tracer(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        result_type = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
    else:
        result_type = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes,
        computation_types.FunctionType(parameter_type, result_type))
def _check_sharding_annotations(self,
                                f_jax,
                                args: Sequence[Any],
                                *,
                                expected: Sequence[str],
                                expected_opt: Sequence[str],
                                num_partitions=2):
    """Check sharding patterns in HLO from f_jax and its jax2tf conversion.

    GPU is skipped. The check also runs on CPU, which is handy for local
    debugging. Unoptimized HLO from both JAX and TF is matched against
    `expected`. Optimized HLO is matched against `expected_opt` only on TPU
    and only for JAX. The TF optimized HLO is computed, and logged when
    LOG_HLO is set, but not checked. See `self.AssertShardingAnnotations` for
    the meaning of `expected` and `expected_opt`.
    """
    if jtu.device_under_test() == "gpu":
        raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

    # JAX, before optimizations.
    jax_comp = jax.xla_computation(f_jax)(*args)
    jax_hlo = jax_comp.as_hlo_text()
    if LOG_HLO:
        logging.info(f"[{self._testMethodName}] got JAX HLO {jax_hlo}")
    self.AssertShardingAnnotations("JAX before optimizations", jax_hlo,
                                   expected)

    # JAX, after optimizations (TPU only).
    if jtu.device_under_test() == "tpu":
        backend = jax.lib.xla_bridge.get_backend()
        num_replicas = 1
        device_assignment = np.reshape(
            np.arange(num_partitions * num_replicas), (-1, num_partitions))
        compile_options = jax.lib.xla_bridge.get_compile_options(
            num_replicas=num_replicas,
            num_partitions=num_partitions,
            device_assignment=device_assignment,
            use_spmd_partitioning=num_partitions > 1,
        )
        executable = backend.compile(jax_comp, compile_options)
        jax_optimized_hlo = executable.hlo_modules()[0].to_string()
        if LOG_HLO:
            logging.info(f"[{self._testMethodName}] got JAX optimized HLO for "
                         f"platform {backend.platform} {jax_optimized_hlo}")
        self.AssertShardingAnnotations("JAX after optimizations",
                                       jax_optimized_hlo, expected_opt)

    # TF (via jax2tf), before optimizations.
    f_tf = jax2tf.convert(f_jax)
    device_name = f"/device:{jtu.device_under_test().upper()}:0"
    unoptimized_ir = tf.function(
        f_tf, jit_compile=True, autograph=False).experimental_get_compiler_ir(
            *args)
    tf_hlo = unoptimized_ir(stage="hlo", device_name=device_name)
    if LOG_HLO:
        logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}")
    self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)

    # TF, after optimizations: computed and optionally logged, not asserted.
    optimized_ir = tf.function(
        f_tf, jit_compile=True).experimental_get_compiler_ir(*args)
    tf_optimized_hlo = optimized_ir(stage="optimized_hlo",
                                    device_name=device_name)
    if LOG_HLO:
        logging.info(f"[{self._testMethodName}] got TF optimized HLO "
                     f"for {device_name}: {tf_optimized_hlo}")
def test_pjit_basic1D(self):
    """pjit with both inputs sharded on mesh axis "x" and no output resources.

    Checks the HLO sharding annotations via `self._check_sharding_annotations`.
    """
    @functools.partial(pjit.pjit,
                       in_axis_resources=(P("x"), P("x")),
                       out_axis_resources=None)
    def jax_func(x, y):
        return x + y

    shape = (8, 10)
    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    if LOG_HLO:
        # Debug dump only. Previously this was an unconditional print that
        # traced the function an extra time on every run; it is now gated and
        # logged like the output in _check_sharding_annotations.
        hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
        logging.info(f"HLO is {hlo}")
        logging.info(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
    self._check_sharding_annotations(
        jax_func, [x, x],
        expected=[
            r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
            r"f32\[8,10\].*sharding={replicated",  # output
        ],
        expected_opt=[
            r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
            # TODO: why don't we see "sharding={replicated"
            r"f32\[8,10\]",  # output
        ],
        num_partitions=2)
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serialize a Python function containing JAX code as a TFF computation.

    Args:
      traced_fn: The Python function with JAX code; JAX traces it and the result
        is emitted as a TFF computation holding XLA code.
      arg_fn: An unpacking function mapping a TFF argument to an
        `(args, kwargs)` pair used to invoke `traced_fn` (for example, one built
        by `function_utils.create_argument_unpacking_fn`).
      parameter_type: A value accepted by `computation_types.to_type` that gives
        the TFF type of the computation parameter, or `None` if the function
        takes no parameters.
      context_stack: The context stack into which a JAX computation context is
        installed while `traced_fn` is traced.

    Returns:
      An instance of `pb.Computation` with the constructed computation.

    Raises:
      TypeError: if the arguments are of the wrong types.
    """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    packed_arg = None
    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    raw_args, raw_kwargs = arg_fn(packed_arg)

    def _to_py_container(value):
        # Rebuild each argument as the Python container its TFF type describes.
        return type_conversions.type_to_py_container(value,
                                                     value.type_signature)

    args = list(map(_to_py_container, raw_args))
    kwargs = {name: _to_py_container(value)
              for name, value in raw_kwargs.items()}

    with context_stack.install(jax_computation_context.JaxComputationContext()):
        tracer = jax.xla_computation(
            traced_fn, tuple_args=True, return_shape=True)
        compiled_xla, returned_shape = tracer(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        result_type = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
    else:
        result_type = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    return xla_serialization.create_xla_tff_computation(
        compiled_xla,
        computation_types.FunctionType(parameter_type, result_type))
def test_grad_jit_metadata(self):
    """The gradient of a jitted sin keeps sin/cos/mul op_type metadata."""
    @jax.jit
    def foo(x):
        return jnp.sin(x)

    module = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module()
    hlo_text = module.to_string()
    for pattern in ('op_type="sin"', 'op_type="cos"', 'op_type="mul"'):
        self.assertRegex(hlo_text, pattern)
def test_get_no_executable(self):
    """Looking up a computation in a freshly initialized cache yields None."""
    with tempfile.TemporaryDirectory() as tmpdir:
        cc.initialize_cache(tmpdir)
        computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        compile_options = jax.lib.xla_bridge.get_compile_options(
            num_replicas=1, num_partitions=1)
        # Nothing was put into the cache, so the lookup must miss. Use an
        # identity check against None rather than relying on `==`.
        self.assertIsNone(cc.get_executable(computation, compile_options))
def test_different_hash_key(self):
    """Default compile options and `self.filled_compile_options()` hash apart."""
    computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
    default_options = jax.lib.xla_bridge.get_compile_options(
        num_replicas=1, num_partitions=1)
    filled_options = self.filled_compile_options()
    keys = [cc.get_cache_key(computation, options)
            for options in (default_options, filled_options)]
    self.assertNotEqual(keys[0], keys[1])
def get_flops(f: Callable, optimize: bool, *a, **kw) -> float:
    """Return the 'flops' entry of XLA's HLO cost analysis for `f`.

    Args:
      f: Function to trace with `jax.xla_computation`.
      optimize: If true, compile on the default backend and analyze the first
        optimized HLO module; otherwise analyze the unoptimized module.
      *a, **kw: Arguments used to trace `f`.
    """
    computation = jax.xla_computation(f)(*a, **kw)
    client = jax.lib.xla_bridge.get_backend()
    if not optimize:
        module = computation.as_hlo_module()
    else:
        module = client.compile(computation).hlo_modules()[0]
    cost = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, module)
    return cost['flops']
def compute_num_flops(f, optimize, *a, **kw):
    """Return XLA's HLO cost-analysis 'flops' for `f`, truncated to an int.

    Args:
      f: Function to trace with `jax.xla_computation`.
      optimize: If true, compile on the default backend and analyze the first
        optimized HLO module; otherwise analyze the unoptimized module.
      *a, **kw: Arguments used to trace `f`.
    """
    computation = jax.xla_computation(f)(*a, **kw)
    client = jax.lib.xla_bridge.get_backend()
    module = (client.compile(computation).hlo_modules()[0] if optimize
              else computation.as_hlo_module())
    cost = jax.lib.xla_extension.hlo_module_cost_analysis(client, module)
    return int(cost['flops'])
def jax_to_hlo(fn, input_shapes, constants=None):
    """Converts a JAX function to an HLO module.

    Args:
      fn: Function to convert.
      input_shapes: Iterable of tuples (arg name, xla_client.Shape), indicating
        the shapes of the arguments to fn. The order of parameters in the
        resulting XLA program will match the order of this sequence. It is
        materialized into a list once, so a generator is accepted too.
      constants: Dict mapping function argument name to a Python value. The
        named arguments are bound to these values as compile-time constants.

    Returns:
      A tuple (serialized_hlo_proto, hlo_text).

    Raises:
      ValueError: if an argument appears in both `input_shapes` and
        `constants`, if a shape is not an array shape, or if a shape has a
        non-default layout.
    """
    # input_shapes is traversed several times below (overlap check, arg
    # construction, wrapper names); a one-shot iterator would be silently
    # exhausted by the first pass.
    input_shapes = list(input_shapes)
    if not constants:
        constants = {}

    overlapping_args = ({arg_name for arg_name, _ in input_shapes} &
                        set(constants))
    if overlapping_args:
        raise ValueError(
            'Arguments appear in both `input_shapes` and `constants`: %s' %
            ', '.join(sorted(overlapping_args)))

    args = []
    for _, shape in input_shapes:
        if not shape.is_array():
            raise ValueError(
                'Shape %s is not an array, but currently only arrays '
                'are supported (i.e., no tuples, nor tokens).' % str(shape))

        # Check that `shape` either doesn't have a layout or has the default
        # layout.
        #
        # TODO(jlebar): This could be simpler if the Shape class exposed its
        # layout, or if Shape exposed a function to unconditionally use the
        # default layout.
        shape_with_default_layout = xla_client.Shape.array_shape(
            shape.xla_element_type(),
            shape.dimensions()).with_major_to_minor_layout_if_absent()
        if (shape.with_major_to_minor_layout_if_absent() !=
                shape_with_default_layout):
            raise ValueError('Shape %s has a non-default layout, but only '
                             'the default layout is allowed.' % str(shape))

        # Zero-filled placeholder with the requested dimensions and dtype, used
        # only to drive tracing.
        args.append(jnp.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))

    # Curry `constants` into the function.
    fn_curried = functools.partial(fn, **constants)
    arg_names = [arg_name for arg_name, _ in input_shapes]

    # Wrapper that takes in args in the order of `input_shapes` and converts
    # them to kwargs for calling `fn`.
    def ordered_wrapper(*args):
        return fn_curried(**dict(zip(arg_names, args)))

    comp = jax.xla_computation(ordered_wrapper)(*args)
    return (comp.as_serialized_hlo_module_proto(), comp.as_hlo_text())
def compile_bench(state):
    """Benchmark XLA compilation of the traced `init` function.

    `init` and `input_shape` are free names, presumably bound in an enclosing
    scope or module (not visible here). The computation is traced once; each
    iteration of the benchmark `state` loop only recompiles it.
    """
    inputs = jnp.ones(input_shape).block_until_ready()
    rng_key = jax.random.PRNGKey(42)
    computation = jax.xla_computation(init)(rng_key, inputs)
    local_backend = jax.lib.xla_client.get_local_backend()
    while state:
        local_backend.compile(computation)
def testTranslationRule(self):
    """sharded_jit's translation rule emits the expected OpShardings."""
    @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
    def f(x, y):
        return x + y

    # The translation rule must run without error and leave the expected
    # OpShardings somewhere in the HLO.
    shape = (8, 8)
    hlo_text = jax.xla_computation(f)(np.ones(shape),
                                      np.ones(shape)).as_hlo_text()
    for annotation in ("sharding={devices=[2,1]0,1}", "sharding={replicated}"):
        self.assertIn(annotation, hlo_text)
def test_default_name(self):
    """Undecorated-name named calls show up in HLO under the function name."""
    @named_call.stateful_named_call
    def naming_things_is_hard(x):
        return x ** 2

    @jax.jit
    def f(x):
        return naming_things_is_hard(x) + naming_things_is_hard(x)

    hlo_text = jax.xla_computation(f)(2).as_hlo_text()
    self.assertIn('naming_things_is_hard', hlo_text)
def testShardingConstraintAnnotation(self):
    """Both the sharding constraint and sharded_jit annotate the HLO."""
    @partial(sharded_jit, in_parts=None, out_parts=None)
    def f(x):
        constrained = with_sharding_constraint(x + 1, P(2, 1))
        return constrained * 2

    hlo_text = jax.xla_computation(f)(np.ones((8, 8))).as_hlo_text()
    # Annotation from with_sharding_constraint
    self.assertIn("sharding={devices=[2,1]0,1}", hlo_text)
    # Annotation from sharded_jit
    self.assertIn("sharding={replicated}", hlo_text)
def test_cond_metadata(self):
    """lax.cond and the ops in each branch carry op_type/op_name metadata."""
    def true_fun(x):
        return jnp.sin(x)

    def false_fun(x):
        return jnp.cos(x)

    def f(x):
        return jax.lax.cond(True, x, true_fun, x, false_fun)

    hlo_text = jax.xla_computation(f)(1.).get_hlo_module().to_string()
    patterns = (
        'op_type="cond"',
        'op_name=".*cond\\[ linear=\\(False, False\\) \\]"',
        'op_type="cos"',
        'op_name=".*cond/branch_0_fun/cos"',
        'op_type="sin"',
        'op_name=".*cond/branch_1_fun/sin"',
    )
    for pattern in patterns:
        self.assertRegex(hlo_text, pattern)
def test_named_call_default_name(self):
    """A named_call without an explicit name appears in HLO metadata."""
    @stateful.named_call
    def naming_things_is_hard(x):
        return x ** 2

    @jax.jit
    def f(x):
        return naming_things_is_hard(x) + naming_things_is_hard(x)

    computation = jax.xla_computation(f)(1.)
    options = jax.xla.xe.HloPrintOptions.short_parsable()
    # Metadata printing is switched on explicitly before rendering the module.
    options.print_metadata = True
    self.assertIn("naming_things_is_hard",
                  computation.as_hlo_module().to_string(options))
def test_put_executable(self):
    """An executable round-trips through the cache and computes the same result."""
    with tempfile.TemporaryDirectory() as cache_dir:
        cc.initialize_cache(cache_dir)
        bridge = jax._src.lib.xla_bridge
        client = jax._src.lib.xla_client
        computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
        options = bridge.get_compile_options(num_replicas=1, num_partitions=1)
        backend = bridge.get_backend()
        executable = backend.compile(computation, options)
        cc.put_executable("alambda", computation, options, executable, backend)
        restored = cc.get_executable(computation, options, backend)
        inputs = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
        expected = client.execute_with_python_values(executable, inputs,
                                                     backend)
        actual = client.execute_with_python_values(restored, inputs, backend)
        self.assertEqual(expected, actual)
def test_grad_jit_metadata(self):
    """Grad-of-jit HLO carries op_type and jvp/transpose-scoped op_names."""
    @jax.jit
    def foo(x):
        return jnp.sin(x)

    hlo_text = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module().to_string()
    patterns = (
        'op_type="sin"',
        'op_type="cos"',
        'op_type="mul"',
        'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"',
        'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"',
        'op_name=".*jit\\(transpose\\(jvp\\(foo\\)\\)\\)/mul"',
    )
    for pattern in patterns:
        self.assertRegex(hlo_text, pattern)
def aot(function, *args, **options):
    """Traces and compiles a function, flattening the input args.

    This is intended to be a lower-level interface for compiling a JAX function
    to IREE without setting up the runtime bindings to use it within Python. A
    common use case for this is compiling to Android (and similar targets).

    Args:
      function: The function to compile.
      *args: The positional inputs to trace and compile the function for; they
        are passed directly to `jax.xla_computation(function)`.
      **options: Keyword args corresponding to xla.ImportOptions or
        CompilerOptions, forwarded unchanged to
        `iree.compiler.xla.compile_str`.

    Returns:
      The result of `iree.compiler.xla.compile_str` applied to the serialized
      HLO module proto (presumably the compiled artifact; confirm against the
      IREE API).
    """
    xla_comp = jax.xla_computation(function)(*args)
    hlo_proto = xla_comp.as_serialized_hlo_module_proto()
    return iree.compiler.xla.compile_str(hlo_proto, **options)
def testNestedMeshSPMD(self):
    """Nested xmaps produce a tiled HLO sharding whose tile factors are {1, 2}."""
    inner = xmap(lambda y: (jnp.sin(y) * np.arange(y.size),
                            lax.psum(y, ('a', 'b', 'c'))),
                 in_axes={0: 'c'},
                 out_axes=({1: 'c'}, {}),
                 axis_resources={'c': 'z'})
    outer = xmap(lambda x: inner(x * 2),
                 in_axes=[None, 'a', 'b', ...],
                 out_axes=(['a', 'b', ...], {}),
                 axis_resources={'a': 'x', 'b': 'y'})
    xshape = (8, 2, 4, 5)
    x = jnp.arange(np.prod(xshape)).reshape(xshape)
    outer(x)  # executed once eagerly; the result itself is not inspected
    hlo_text = jax.xla_computation(outer)(x).as_hlo_text()
    match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo_text)
    self.assertIsNot(match, None)
    tile_factors = {int(factor) for factor in match.group(1).split(',')}
    self.assertEqual(tile_factors, {1, 2})
def test_precision(self, precision):
    """A non-default `precision` shows up as operand_precision on the dot op.

    When `precision` is None or DEFAULT, no operand_precision is emitted.
    """
    def f(x):
        return basic.Linear(1)(x, precision=precision)

    transformed = transform.transform(f)
    x = np.ones([1, 1])
    params = transformed.init(jax.random.PRNGKey(42), x)
    computation = jax.xla_computation(
        lambda x: transformed.apply(params, None, x))(x)
    dot_line = next(line for line in computation.as_hlo_text().split("\n")
                    if "dot(" in line)
    explicit = (precision is not None
                and precision != jax.lax.Precision.DEFAULT)
    if not explicit:
        self.assertNotIn("operand_precision", dot_line)
        return
    name = str(precision).lower()
    self.assertRegex(dot_line, f"operand_precision={{{name},{name}}}")
def load_hlo_proto_from_jax_fn(fn, *fn_args, **fn_kwargs):
    """Trace a JAX function and parse its HLO into an HloModuleProto.

    Args:
      fn: a jax function.
      *fn_args: Positional arguments used to trace fn.
      **fn_kwargs: Keyword arguments used to trace fn.

    Returns:
      An HloModuleProto object parsed from the serialized HLO module.
    """
    serialized = (jax.xla_computation(fn)(*fn_args, **fn_kwargs)
                  .as_serialized_hlo_module_proto())
    return hlo_pb2.HloModuleProto.FromString(serialized)
def save_computation_graphs(self):
    """Dump computation graphs to files.

    Traces ``self._model_predict_eval`` with ``jax.xla_computation`` on one
    batch pulled from ``self._train_stream`` and writes its HLO text and HLO
    dot graph into ``self._output_dir`` as ``forward.txt`` and
    ``forward.dot``. Does nothing unless exactly one device is configured.
    """
    # TODO(lukaszkaiser): make this work with more devices.
    if self.n_devices != 1:
        return
    # The early return above guarantees a single device here, so the former
    # `_reshape_by_device` branch for `n_devices > 1` was unreachable and has
    # been dropped; reinstate it together with multi-device support.
    batch = next(self._train_stream)
    output_dir = self._output_dir

    def _write(filename, contents):
        # Write one text artifact into the output directory.
        with tf.io.gfile.GFile(os.path.join(output_dir, filename), 'w') as f:
            f.write(contents)

    weights = self._opt_state[0][0]
    forward_computation = jax.xla_computation(self._model_predict_eval)(
        batch, weights=weights, state=self._model_state[0], rng=self._rngs[0])
    _write('forward.txt', forward_computation.as_hlo_text())
    _write('forward.dot', forward_computation.as_hlo_dot_graph())