Example #1
 def save_computation_graphs(self, save_backward_graph):
     """Dump computation graphs to files."""
     if self._n_devices != 1:
         return  # TODO(lukaszkaiser): make this work with more devices.
     next_train_batch = next(self._train_stream)
     output_dir = self._output_dir
     if self._n_devices > 1:
         next_train_batch = reshape_by_device(next_train_batch,
                                              self._n_devices)
     params = self._opt_state[0]
     forward_computation = jax.xla_computation(self._model_predict_eval)(
         next_train_batch,
         params=params,
         state=self._model_state,
         rng=self._rngs[0])
     with gfile.GFile(os.path.join(output_dir, "forward.txt"), "w") as f:
         f.write(forward_computation.GetHloText())
     with gfile.GFile(os.path.join(output_dir, "forward.dot"), "w") as f:
         f.write(forward_computation.GetHloDotGraph())
     backward_computation = jax.xla_computation(
         self._jit_update_fn)(self._step, self._opt_state, next_train_batch,
                              self._model_state, self._rngs)
     with gfile.GFile(os.path.join(output_dir, "backward.txt"), "w") as f:
         f.write(backward_computation.GetHloText())
     if save_backward_graph:  # Backward graphs can be large so we guard it.
         with gfile.GFile(os.path.join(output_dir, "backward.dot"),
                          "w") as f:
             f.write(backward_computation.GetHloDotGraph())
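A minimal standalone sketch of the same pattern, outside any trainer class: jax.xla_computation traces a function into an XLA computation whose HLO can be written out as text or as a Graphviz dot graph. Older jaxlib builds expose GetHloText()/GetHloDotGraph() as used above, while newer ones use as_hlo_text()/as_hlo_dot_graph() as in Example #30 below; the sketch assumes the newer names.

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

w = jnp.ones((3, 2))
x = jnp.ones((5, 3))
computation = jax.xla_computation(loss)(w, x)
with open("forward.txt", "w") as f:
    f.write(computation.as_hlo_text())       # human-readable HLO
with open("forward.dot", "w") as f:
    f.write(computation.as_hlo_dot_graph())  # Graphviz dot graph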
Example #2
  def trace_bench(state):
    """Benchmark Jax trace of hk.init_fn of model."""
    x = jnp.ones(input_shape).block_until_ready()
    k = jax.random.PRNGKey(42)

    while state:
      jax.xla_computation(init)(k, x)
Example #3
 def test_different_computations(self):
     computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
     computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
     compile_options = jax.lib.xla_bridge.get_compile_options(
         num_replicas=1, num_partitions=1)
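     # `cc` is assumed to be jax.experimental.compilation_cache.compilation_cache,
     # imported under that alias as in JAX's compilation-cache tests.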
     self.assertNotEqual(cc.get_cache_key(computation1, compile_options),
                         cc.get_cache_key(computation2, compile_options))
Example #4
 def save_computation_graphs(self, save_backward_graph):
     """Dump computation graphs to files."""
     if self.n_devices != 1:
         return  # TODO(lukaszkaiser): make this work with more devices.
     batch = next(self._train_stream)
     output_dir = self._output_dir
     if self.n_devices > 1:
         batch = _reshape_by_device(batch, self.n_devices)
     weights = self._opt_state[0][0]
     forward_computation = jax.xla_computation(self._model_predict_eval)(
         batch,
         weights=weights,
         state=self._model_state[0],
         rng=self._rngs[0])
     with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'),
                            'w') as f:
         f.write(forward_computation.GetHloText())
     with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'),
                            'w') as f:
         f.write(forward_computation.GetHloDotGraph())
     backward_computation = jax.xla_computation(self._jit_update_fn)(
         self._step, self._opt_state, batch, self._model_state, self._rngs)
     with tf.io.gfile.GFile(os.path.join(output_dir, 'backward.txt'),
                            'w') as f:
         f.write(backward_computation.GetHloText())
     if save_backward_graph:  # Backward graphs can be large so we guard it.
         with tf.io.gfile.GFile(os.path.join(output_dir, 'backward.dot'),
                                'w') as f:
             f.write(backward_computation.GetHloDotGraph())
Example #5
 def test_jit_metadata(self):
   hlo = jax.xla_computation(jnp.sin)(1.).get_hlo_module().to_string()
   self.assertRegex(hlo, 'op_type="sin"')
   self.assertRegex(hlo, 'op_name="xla_computation\\(sin\\)/sin"')
   def foo(x):
     return jnp.sin(x)
   hlo = jax.xla_computation(foo)(1.).get_hlo_module().to_string()
   self.assertRegex(hlo, 'op_type="sin"')
   self.assertRegex(hlo, 'op_name="xla_computation\\(foo\\)/sin"')
Example #6
 def test_diff_executables(self):
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         computation1 = jax.xla_computation(lambda x, y: x + y)(1, 1)
         computation2 = jax.xla_computation(lambda x, y: x * y)(2, 2)
         compile_options = jax.lib.xla_bridge.get_compile_options(
             num_replicas=1, num_partitions=1)
         backend = jax.lib.xla_bridge.get_backend()
         executable1 = backend.compile(computation1, compile_options)
         executable2 = backend.compile(computation2, compile_options)
         cc.put_executable(computation1, compile_options, executable1)
         cc.put_executable(computation2, compile_options, executable2)
         self.assertNotEqual(
             cc.get_executable(computation1, compile_options),
             cc.get_executable(computation2, compile_options))
Example #7
 def test_same_hash_key(self):
   computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
   compile_options = jax._src.lib.xla_bridge.get_compile_options(
                      num_replicas=1, num_partitions=1)
   backend = jax._src.lib.xla_bridge.get_backend()
   self.assertEqual(cc.get_cache_key(computation, compile_options, backend),
                    cc.get_cache_key(computation, compile_options, backend))
Example #8
def serialize_jax_computation(traced_fn, arg_fn, parameter_type, context_stack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument, and returns a combo
      of (args, kwargs) to invoke `traced_fn` with (e.g., as the one constructed
      by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
  else:
    packed_arg = None

  args, kwargs = arg_fn(packed_arg)

  # While the fake parameters are fed via args/kwargs during serialization,
  # it is possible for them to get reordered in the actual generated XLA code.
  # We use the same flattening function here that the JAX serializer uses to
  # determine the ordering, so that the ordering can be captured in the
  # parameter binding. We do not need to do anything special for the
  # results, since the results, if multiple, are always returned as a tuple.
  flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
  tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))

  context = jax_computation_context.JaxComputationContext()
  with context_stack.install(context):
    tracer_callable = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)
    compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    returned_type_spec = computation_types.to_type(
        structure.map_structure(
            _jax_shape_dtype_struct_to_tff_tensor,
            structure.from_container(returned_shape, recursive=True)))

  computation_type = computation_types.FunctionType(parameter_type,
                                                    returned_type_spec)
  return xla_serialization.create_xla_tff_computation(compiled_xla,
                                                      tensor_indexes,
                                                      computation_type)
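The flatten-and-argsort step above fixes a stable parameter ordering for the XLA binding. A small self-contained sketch of just that step, using a hypothetical FakeParam stand-in for the serializer's packed arguments:

import jax
import numpy as np

class FakeParam:
    """Hypothetical stand-in for a packed argument carrying a tensor index."""

    def __init__(self, tensor_index):
        self.tensor_index = tensor_index

args = ([FakeParam(2), FakeParam(0)],)
kwargs = {"w": FakeParam(1)}
# Leaves come out in JAX's pytree traversal order; argsort over their
# tensor_index values yields the positions needed to realign them with the
# parameter binding.
flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))
print(tensor_indexes)  # [1, 2, 0] for this input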
Example #9
  def _check_sharding_annotations(self,
                                  f_jax,
                                  args: Sequence[Any],
                                  *,
                                  expected: Sequence[str],
                                  expected_opt: Sequence[str],
                                  num_partitions=2):
    """Check expected patterns in the HLO generated from f_jax and its conversion.

    We run this check on CPU also, which is useful for debugging locally.
    We currently check the unoptimized HLO against `expected` on CPU and TPU,
    and we check the optimized HLO against `expected_opt` on TPU only and
    only for JAX.

    See `self.AssertShardingAnnotations` for documentation of `expected`
    and `expected_opt`.
    """
    if jtu.device_under_test() == "gpu":
      raise unittest.SkipTest("Sharding HLO tests not useful for GPU")

    jax_comp = jax.xla_computation(f_jax)(*args)
    jax_hlo = jax_comp.as_hlo_text()
    if LOG_HLO:
      logging.info(f"[{self._testMethodName}] got JAX HLO {jax_hlo}")
    self.AssertShardingAnnotations("JAX before optimizations", jax_hlo, expected)

    if jtu.device_under_test() == "tpu":
      backend = jax.lib.xla_bridge.get_backend()
      num_replicas = 1
      device_assignment = np.arange(num_partitions * num_replicas)
      device_assignment = np.reshape(device_assignment, (-1, num_partitions))
      use_spmd_partitioning = num_partitions > 1
      compile_options = jax.lib.xla_bridge.get_compile_options(
          num_replicas=num_replicas,
          num_partitions=num_partitions,
          device_assignment=device_assignment,
          use_spmd_partitioning=use_spmd_partitioning,
      )
      jax_optimized_hlo = backend.compile(
          jax_comp, compile_options).hlo_modules()[0].to_string()
      if LOG_HLO:
        logging.info(f"[{self._testMethodName}] got JAX optimized HLO for "
                     f"platform {backend.platform} {jax_optimized_hlo}")
      self.AssertShardingAnnotations("JAX after optimizations",
                                     jax_optimized_hlo, expected_opt)

    f_tf = jax2tf.convert(f_jax)
    device_name = f"/device:{jtu.device_under_test().upper()}:0"
    tf_hlo = tf.function(f_tf, jit_compile=True, autograph=False).\
      experimental_get_compiler_ir(*args)(stage="hlo",
                                          device_name=device_name)
    if LOG_HLO:
      logging.info(f"[{self._testMethodName}] got TF HLO {tf_hlo}")
    self.AssertShardingAnnotations("TF before optimizations", tf_hlo, expected)
    tf_optimized_hlo = tf.function(f_tf, jit_compile=True).\
      experimental_get_compiler_ir(*args)(stage="optimized_hlo",
                                          device_name=device_name)
    if LOG_HLO:
      logging.info(f"[{self._testMethodName}] got TF optimized HLO "
                   f"for {device_name}: {tf_optimized_hlo}")
Example #10
    def test_pjit_basic1D(self):
        @functools.partial(pjit.pjit,
                           in_axis_resources=(P("x"), P("x")),
                           out_axis_resources=None)
        def jax_func(x, y):
            return x + y

        shape = (8, 10)
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
        hlo = jax.xla_computation(jax_func)(x, x).as_hlo_text()
        print(f"HLO is {hlo}")
        print(f"JAXPR is {jax.make_jaxpr(jax_func)(x, x)}")
        self._check_sharding_annotations(
            jax_func,
            [x, x],
            expected=[
                r"f32\[8,10\].*sharding={devices=\[2,1\]",  # x and y
                r"f32\[8,10\].*sharding={replicated",  # output
            ],
            expected_opt=[
                r"f32\[4,10\].*sharding={devices=\[2,1\]",  # x and y
                # TODO: why don't we see "sharding={replicated"
                r"f32\[8,10\]",  # output
            ],
            num_partitions=2)
Example #11
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument, and returns a combo
      of (args, kwargs) to invoke `traced_fn` with (e.g., as the one constructed
      by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    else:
        packed_arg = None

    args, kwargs = arg_fn(packed_arg)

    def _adjust_arg(x):
        return type_conversions.type_to_py_container(x, x.type_signature)

    args = [_adjust_arg(x) for x in args]
    kwargs = {k: _adjust_arg(v) for k, v in kwargs.items()}

    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, computation_type)
Example #12
 def test_grad_jit_metadata(self):
   @jax.jit
   def foo(x):
     return jnp.sin(x)
   hlo = jax.xla_computation(jax.grad(foo))(1.).get_hlo_module().to_string()
   self.assertRegex(hlo, 'op_type="sin"')
   self.assertRegex(hlo, 'op_type="cos"')
   self.assertRegex(hlo, 'op_type="mul"')
Example #13
 def test_get_no_executable(self):
     with tempfile.TemporaryDirectory() as tmpdir:
         cc.initialize_cache(tmpdir)
         computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
         compile_options = jax.lib.xla_bridge.get_compile_options(
             num_replicas=1, num_partitions=1)
         self.assertEqual(cc.get_executable(computation, compile_options),
                          None)
Example #14
 def test_different_hash_key(self):
     computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     compile_options_not_filled = jax.lib.xla_bridge.get_compile_options(
         num_replicas=1, num_partitions=1)
     compile_options_filled = self.filled_compile_options()
     self.assertNotEqual(
         cc.get_cache_key(computation, compile_options_not_filled),
         cc.get_cache_key(computation, compile_options_filled))
Example #15
def get_flops(f: Callable, optimize: bool, *a, **kw) -> float:
    m = jax.xla_computation(f)(*a, **kw)
    client = jax.lib.xla_bridge.get_backend()
    if optimize:
        m = client.compile(m).hlo_modules()[0]
    else:
        m = m.as_hlo_module()
    analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, m)
    return analysis['flops']
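A hedged usage sketch for get_flops above, assuming the private jax.lib helpers it relies on are present in the installed jaxlib: count the FLOPs of a small matrix multiply without running XLA's optimization passes.

import numpy as np
import jax.numpy as jnp

x = np.ones((128, 128), dtype=np.float32)
print(get_flops(jnp.matmul, False, x, x))  # roughly 2 * 128**3 for a matmul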
Example #16
def compute_num_flops(f, optimize, *a, **kw):
    m = jax.xla_computation(f)(*a, **kw)
    client = jax.lib.xla_bridge.get_backend()
    if optimize:
        m = client.compile(m).hlo_modules()[0]
    else:
        m = m.as_hlo_module()
    analysis = jax.lib.xla_extension.hlo_module_cost_analysis(client, m)
    return int(analysis['flops'])
Example #17
def jax_to_hlo(fn, input_shapes, constants=None):
    """Converts a JAX function to an HLO module.

  Args:
    fn: Function to convert.
    input_shapes: List of tuples (arg name, xla_client.Shape),
      indicating the shapes of the arguments to fn.  The order of parameters in
      the resulting XLA program will match the order in this list.
    constants: Dict mapping function argument name to a Python value. Arguments
      named here are fixed to these values as compile-time constants.

  Returns:
    A tuple (serialized_hlo_proto, hlo_text).
  """
    if not constants:
        constants = {}

    overlapping_args = (
        {arg_name for arg_name, _ in input_shapes} & set(constants.keys()))
    if overlapping_args:
        raise ValueError(
            'Arguments appear in both `input_shapes` and `constants`: %s' %
            ', '.join(sorted(overlapping_args)))

    args = []
    for arg_name, shape in input_shapes:
        if not shape.is_array():
            raise ValueError(
                'Shape %s is not an array, but currently only arrays '
                'are supported (i.e., no tuples, nor tokens).' % str(shape))

        # Check that `shape` either doesn't have a layout or has the default layout.
        #
        # TODO(jlebar): This could be simpler if the Shape class exposed its layout,
        # or if Shape exposed a function to unconditionally use the default layout.
        shape_with_default_layout = xla_client.Shape.array_shape(
            shape.xla_element_type(),
            shape.dimensions()).with_major_to_minor_layout_if_absent()
        if (shape.with_major_to_minor_layout_if_absent() !=
                shape_with_default_layout):
            raise ValueError('Shape %s has a non-default layout, but only '
                             'the default layout is allowed.' % str(shape))

        args.append(jnp.zeros(shape.dimensions(), dtype=shape.numpy_dtype()))

    # Curry `constants` into the function.
    fn_curried = functools.partial(fn, **constants)

    # Wrapper that takes in args in the order of `input_shapes` and converts them
    # to kwargs for calling `fn`.
    def ordered_wrapper(*args):
        arg_names = [arg_name for arg_name, _ in input_shapes]
        return fn_curried(**dict(zip(arg_names, args)))

    comp = jax.xla_computation(ordered_wrapper)(*args)
    return (comp.as_serialized_hlo_module_proto(), comp.as_hlo_text())
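A hedged usage sketch for jax_to_hlo above; it assumes xla_client.Shape.array_shape accepts a numpy dtype plus a dimensions tuple, which holds for the jaxlib versions these examples target.

import numpy as np
from jax.lib import xla_client

def axpy(a, x, y):
    return a * x + y

f32 = np.dtype(np.float32)
hlo_proto, hlo_text = jax_to_hlo(
    axpy,
    [('x', xla_client.Shape.array_shape(f32, (4,))),
     ('y', xla_client.Shape.array_shape(f32, (4,)))],
    constants={'a': 2.0})  # 'a' is baked in as a compile-time constant
print(hlo_text)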
Example #18
  def compile_bench(state):
    """Benchmark Jax compile of hk.init_fn of model."""
    x = jnp.ones(input_shape).block_until_ready()
    k = jax.random.PRNGKey(42)

    c = jax.xla_computation(init)(k, x)
    b = jax.lib.xla_client.get_local_backend()

    while state:
      b.compile(c)
Example #19
    def testTranslationRule(self):
        @partial(sharded_jit, in_parts=(P(2, 1), P(2, 1)), out_parts=None)
        def f(x, y):
            return x + y

        # Test that the translation rule runs without error and produces the
        # OpShardings we expect somewhere.
        shape = (8, 8)
        hlo = jax.xla_computation(f)(np.ones(shape), np.ones(shape))
        self.assertIn("sharding={devices=[2,1]0,1}", hlo.as_hlo_text())
        self.assertIn("sharding={replicated}", hlo.as_hlo_text())
Example #20
  def test_default_name(self):
    @named_call.stateful_named_call
    def naming_things_is_hard(x):
      return x ** 2

    @jax.jit
    def f(x):
      return naming_things_is_hard(x) + naming_things_is_hard(x)

    c = jax.xla_computation(f)(2)
    self.assertIn('naming_things_is_hard', c.as_hlo_text())
Example #21
    def testShardingConstraintAnnotation(self):
        @partial(sharded_jit, in_parts=None, out_parts=None)
        def f(x):
            y = x + 1
            y = with_sharding_constraint(y, P(2, 1))
            return y * 2

        shape = (8, 8)
        hlo = jax.xla_computation(f)(np.ones(shape))
        # Annotation from with_sharding_constraint
        self.assertIn("sharding={devices=[2,1]0,1}", hlo.as_hlo_text())
        # Annotation from sharded_jit
        self.assertIn("sharding={replicated}", hlo.as_hlo_text())
Example #22
 def test_cond_metadata(self):
   def true_fun(x):
     return jnp.sin(x)
   def false_fun(x):
     return jnp.cos(x)
   def f(x):
     return jax.lax.cond(True, x, true_fun, x, false_fun)
   hlo = jax.xla_computation(f)(1.).get_hlo_module().to_string()
   self.assertRegex(hlo, 'op_type="cond"')
   self.assertRegex(hlo, 'op_name=".*cond\\[ linear=\\(False, False\\) \\]"')
   self.assertRegex(hlo, 'op_type="cos"')
   self.assertRegex(hlo, 'op_name=".*cond/branch_0_fun/cos"')
   self.assertRegex(hlo, 'op_type="sin"')
   self.assertRegex(hlo, 'op_name=".*cond/branch_1_fun/sin"')
Example #23
  def test_named_call_default_name(self):
    @stateful.named_call
    def naming_things_is_hard(x):
      return x ** 2

    @jax.jit
    def f(x):
      return naming_things_is_hard(x) + naming_things_is_hard(x)

    c = jax.xla_computation(f)(1.)
    print_opts = jax.xla.xe.HloPrintOptions.short_parsable()
    print_opts.print_metadata = True
    hlo_text = c.as_hlo_module().to_string(print_opts)
    self.assertIn("naming_things_is_hard", hlo_text)
Example #24
 def test_put_executable(self):
   with tempfile.TemporaryDirectory() as tmpdir:
     cc.initialize_cache(tmpdir)
     computation = jax.xla_computation(lambda x, y: x + y)(1, 1)
     compile_options = jax._src.lib.xla_bridge.get_compile_options(
         num_replicas=1, num_partitions=1)
     backend = jax._src.lib.xla_bridge.get_backend()
     executable = backend.compile(computation, compile_options)
     cc.put_executable("alambda", computation, compile_options, executable,
                       backend)
     deserialized_executable = cc.get_executable(computation, compile_options, backend)
     inputs_to_executable = (np.array(1, dtype=np.int32), np.array(2, dtype=np.int32))
     expected = jax._src.lib.xla_client.execute_with_python_values(executable, inputs_to_executable, backend)
     actual = jax._src.lib.xla_client.execute_with_python_values(deserialized_executable, inputs_to_executable, backend)
     self.assertEqual(expected, actual)
Example #25
    def test_grad_jit_metadata(self):
        @jax.jit
        def foo(x):
            return jnp.sin(x)

        hlo = jax.xla_computation(
            jax.grad(foo))(1.).get_hlo_module().to_string()
        self.assertRegex(hlo, 'op_type="sin"')
        self.assertRegex(hlo, 'op_type="cos"')
        self.assertRegex(hlo, 'op_type="mul"')
        self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/sin"')
        self.assertRegex(hlo, 'op_name=".*jit\\(jvp\\(foo\\)\\)/cos"')
        self.assertRegex(
            hlo, 'op_name=".*jit\\(transpose\\('
            'jvp\\(foo\\)\\)\\)/mul"')
Example #26
def aot(function, *args, **options):
    """Traces and compiles a function, flattening the input args.

  This is intended to be a lower-level interface for compiling a JAX function to
  IREE without setting up the runtime bindings to use it within Python. A common
  use case for this is compiling to Android (and similar targets).

  Args:
    function: The function to compile.
    args: The inputs to trace and compile the function for.
    **options: Keyword args corresponding to xla.ImportOptions or CompilerOptions.
  """
    xla_comp = jax.xla_computation(function)(*args)
    hlo_proto = xla_comp.as_serialized_hlo_module_proto()
    return iree.compiler.xla.compile_str(hlo_proto, **options)
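A hedged usage sketch for aot above. The target_backends value is an assumption and depends on the installed iree.compiler release (older releases used backend names such as 'dylib-llvm-aot'):

import numpy as np
import jax.numpy as jnp

x = np.ones((4,), dtype=np.float32)
# Ahead-of-time compile a simple elementwise function for a CPU target.
compiled_flatbuffer = aot(jnp.tanh, x, target_backends=["llvm-cpu"])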
Example #27
 def testNestedMeshSPMD(self):
   h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
            in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
            axis_resources={'c': 'z'})
   f = xmap(lambda x: h(x * 2),
            in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
            axis_resources={'a': 'x', 'b': 'y'})
   xshape = (8, 2, 4, 5)
   x = jnp.arange(np.prod(xshape)).reshape(xshape)
   y = f(x)
   hlo = jax.xla_computation(f)(x).as_hlo_text()
   match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
   self.assertIsNot(match, None)
   tile_factors = [int(s) for s in match.group(1).split(',')]
   self.assertEqual(set(tile_factors), {1, 2})
Example #28
    def test_precision(self, precision):
        def f(x):
            return basic.Linear(1)(x, precision=precision)

        f = transform.transform(f)
        rng = jax.random.PRNGKey(42)
        x = np.ones([1, 1])
        params = f.init(rng, x)
        c = jax.xla_computation(lambda x: f.apply(params, None, x))(x)
        hlo = c.as_hlo_text()
        op_line = next(l for l in hlo.split("\n") if "dot(" in l)
        if precision is not None and precision != jax.lax.Precision.DEFAULT:
            name = str(precision).lower()
            self.assertRegex(op_line, f"operand_precision={{{name},{name}}}")
        else:
            self.assertNotIn("operand_precision", op_line)
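A hedged, Haiku-free sketch of the same check: request a non-default precision on a dot and look for the operand_precision attribute in the unoptimized HLO text (the exact HLO formatting can vary across jaxlib versions).

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

hlo = jax.xla_computation(f)(jnp.ones((2, 2)), jnp.ones((2, 2))).as_hlo_text()
print("operand_precision" in hlo)  # expected True for non-default precision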
Example #29
def load_hlo_proto_from_jax_fn(fn, *fn_args, **fn_kwargs):
    """Loads HLO proto object from jax function.

  Args:
    fn: a jax function.
    *fn_args: Arguments to fn.
    **fn_kwargs: Keyword arguments to fn.

  Returns:
    An HloModuleProto object.

  """
    computation = jax.xla_computation(fn)(*fn_args, **fn_kwargs)
    serialized_hlo = computation.as_serialized_hlo_module_proto()
    hlo_module_proto = hlo_pb2.HloModuleProto.FromString(serialized_hlo)
    return hlo_module_proto
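A hedged usage sketch for load_hlo_proto_from_jax_fn above, assuming hlo_pb2 is the standard XLA hlo.proto binding; it inspects the computation names inside the returned HloModuleProto.

import jax.numpy as jnp

def dense(w, x):
    return jnp.tanh(x @ w)

module_proto = load_hlo_proto_from_jax_fn(dense, jnp.ones((3, 2)), jnp.ones((5, 3)))
print(module_proto.name)                  # module name
for computation in module_proto.computations:
    print(computation.name)               # one entry per HLO computation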
Example #30
 def save_computation_graphs(self):
   """Dump computation graphs to files."""
   if self.n_devices != 1:
     return  # TODO(lukaszkaiser): make this work with more devices.
   batch = next(self._train_stream)
   output_dir = self._output_dir
   if self.n_devices > 1:
     batch = _reshape_by_device(batch, self.n_devices)
   weights = self._opt_state[0][0]
   forward_computation = jax.xla_computation(self._model_predict_eval)(
       batch, weights=weights, state=self._model_state[0],
       rng=self._rngs[0])
   with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.txt'), 'w') as f:
     f.write(forward_computation.as_hlo_text())
   with tf.io.gfile.GFile(os.path.join(output_dir, 'forward.dot'), 'w') as f:
     f.write(forward_computation.as_hlo_dot_graph())