Exemple #1
0
    def test_op_metadata_named(self):
        """Checks op metadata for ops reached through ``jax.named_call``.

        Each expected ``source_line`` is an offset from the line that sets
        ``user_frame``. The offsets below match the layout of ``f_callee`` and
        ``f_caller`` exactly as written here (blank lines included), so keep
        them in sync with any edit to those lines.
        """
        self.skipTest("include_xla_op_metadata not yet enabled")
        # Calling a jax.named_call
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        def f_callee(x):
            return jnp.cos(x)  # user_frame.line_num + 3

        def f_caller(x):
            y = jnp.tanh(x)  # user_frame.line_num + 6
            z = jax.named_call(f_callee, name="callee")(y)
            return jnp.sin(z)  # user_frame.line_num + 8

        x = np.ones((2, 3), np.float32)

        self.CheckOpMetadata(f_caller, x, [
            tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 6,
                                         op_name="jax2tf(f_caller)/tanh",
                                         op_type="tanh"),
            tf_test_util.OpMetadataGraph(
                tf_type="Cos",
                source_file=__file__,
                source_line=user_frame.line_num + 3,
                op_name="jax2tf(f_caller)/named(callee)/cos",
                op_type="cos"),
            tf_test_util.OpMetadataGraph(tf_type="Sin",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 8,
                                         op_name="jax2tf(f_caller)/sin",
                                         op_type="sin"),
        ])
Exemple #2
0
 def process_call(self, call_primitive, f, tracers, params):
     """Batch a call primitive whose operands may be BatchTracers.

     The call's ``name`` param defaults to ``f.__name__``; when the
     experimental name stack is disabled it is additionally wrapped with
     ``'vmap'``. If no operand is batched the primitive is bound directly on
     the raw values. Otherwise ``f`` is wrapped by ``batch_subtrace``, its
     input-type annotation is updated for the (single) mapped axis size, and
     the results are re-wrapped as BatchTracers.
     """
     assert call_primitive.multiple_results
     use_name_stack = config.jax_experimental_name_stack
     call_name = params.get('name', f.__name__)
     if not use_name_stack:
         call_name = wrap_name(call_name, 'vmap')
     params = dict(params, name=call_name)
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     if not any(bdim is not not_mapped for bdim in dims):
         return call_primitive.bind(f, *vals, **params)
     batched_fun, dims_out = batch_subtrace(f, self.main, dims)
     mapped_sizes = {
         val.shape[bdim]
         for val, bdim in zip(vals, dims)
         if bdim is not not_mapped
     }
     (axis_size,) = mapped_sizes
     batched_fun = _update_annotation(batched_fun, f.in_type, axis_size,
                                      self.axis_name, dims)
     vals_out = call_primitive.bind(batched_fun, *vals, **params)
     src = source_info_util.current()
     return [BatchTracer(self, val, bdim, src)
             for val, bdim in zip(vals_out, dims_out())]
Exemple #3
0
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed.

  Returns ``jaxpr`` itself when there is no input token and no equation uses
  outfeed. Otherwise builds a new Jaxpr in which every outfeed-using
  equation is rewritten by ``_rewrite_eqn`` to consume the current token and
  produce a fresh one. With ``has_input_token`` the token is appended as an
  extra invar; without it, a ``create_token_p`` equation over all invars
  creates one. With ``has_output_token`` the last token is appended to the
  outvars.
  """
  assert has_input_token or not has_output_token

  needs_rewrite = has_input_token or xla.jaxpr_uses_outfeed(jaxpr)
  if not needs_rewrite:
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  new_eqns: List[core.JaxprEqn] = []
  # Token produced by the most recent outfeed (initially the incoming one).
  token_var = mk_new_var(core.abstract_token)
  if has_input_token:
    new_invars = jaxpr.invars + [token_var]
  else:
    new_invars = jaxpr.invars
    # No token is passed in: create one that depends on all invars.
    new_eqns.append(
        core.new_jaxpr_eqn(jaxpr.invars, [token_var], lax.create_token_p, {},
                           source_info_util.current()))

  for eqn in jaxpr.eqns:
    if xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      next_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, new_eqns, token_var, next_token_var, mk_new_var)
      token_var = next_token_var
    else:
      new_eqns.append(eqn)

  token_outvars = [token_var] if has_output_token else []
  return core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars + token_outvars,
                    new_eqns)
Exemple #4
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p``.

  Splits ``jaxpr`` with ``pe._partial_eval_jaxpr_custom`` based on which
  ``tracers`` are unknown and on ``params['policy']``. The known part is
  evaluated immediately (outside the remat primitive); the values it
  produces after the known outputs are residuals. The unknown part is
  staged as a new ``remat_p`` equation over the residuals and the input
  tracers it still uses, with ``differentiated=True`` in its params.

  Returns:
    The known-output tracers and unknown-output tracers merged by
    ``pe._zip_knowns`` according to ``out_unknowns``.
  """
  assert not jaxpr.constvars
  # With no (falsy) policy, fall back to one that returns False for everything.
  policy = params['policy'] or (lambda *_, **__: False)
  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  # DCE the known jaxpr keeping every output; in_used_known flags which of
  # its inputs are still read.
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
  # Select the out_unknowns flags at positions where out_inst is True and use
  # them as the output-keep mask when DCE-ing jaxpr_unknown.
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)

  # compute known outputs and residuals (hoisted out of remat primitive)
  # Known input values (second element of each known tracer's pval)...
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  # ...filtered down to those the DCE'd jaxpr_known still reads.
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_consts_ = iter(out_consts)
  # form known outputs and collect residual tracers
  # The first values are the known outputs (one per False in out_unknowns);
  # whatever remains in the iterator afterwards is treated as residuals.
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  residuals = list(out_consts_)

  # set up unknown outputs with a recipe to call remat
  # Candidate inputs for the staged eqn are the residuals followed by all
  # input tracers; in_used_unknown keeps only those jaxpr_unknown reads.
  res_tracers = map(trace.new_instantiated_const, residuals)
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe

  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
Exemple #5
0
 def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                           in_tracers, out_pvs, out_consts, out_keys, name,
                           is_map):
     """Takes a traced function and binds the Jaxpr to output tracers.

     The jaxpr's constvars are converted to leading invars, one
     ``UnzipTracer`` is created per ``(out_pv, out_const, out_key)`` triple,
     and a single equation recipe applying ``primitive`` to the const, env
     and input tracers (in that order) is attached to every output tracer.
     ``params`` is copied, never mutated.

     Returns:
       The list of output ``UnzipTracer``s.
     """
     call_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
     const_tracers = safe_map(self.new_instantiated_const, consts)
     env_tracers = safe_map(self.instantiate_const, env)
     out_tracers = [
         UnzipTracer(self, pe.PartialVal((out_pv, out_const)), None, out_key)
         for out_pv, out_const, out_key in safe_zip(out_pvs, out_consts,
                                                    out_keys)
     ]
     new_params = dict(params, name=name, call_jaxpr=call_jaxpr)
     if 'donated_invars' in params:
         # Consts and env values are never donated; an input keeps its flag
         # only when its pval is not known.
         # NOTE(review): plain zip truncates silently if donated_invars and
         # in_tracers differ in length -- confirm they always align.
         kept_donations = tuple(
             donated
             for donated, tracer in zip(params['donated_invars'], in_tracers)
             if not tracer.pval.is_known())
         new_params['donated_invars'] = (
             (False,) * len(const_tracers) + (False,) * len(env_tracers) +
             kept_donations)
     if is_map:
         out_axes = params['out_axes_thunk']()
         assert all(axis == 0 for axis in out_axes)
         new_params['out_axes'] = (0,) * len(out_tracers)
         del new_params['out_axes_thunk']
     eqn_inputs = tuple(const_tracers + env_tracers + in_tracers)
     eqn = pe.new_eqn_recipe(eqn_inputs, out_tracers, primitive, new_params,
                             source_info_util.current())  # pytype: disable=wrong-arg-types
     for tracer in out_tracers:
         tracer.recipe = eqn
     return out_tracers
Exemple #6
0
def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
  """Convert a vmapped input ``x`` into the per-element value for the body.

  Types registered in ``to_elt_handlers`` are delegated to their handler,
  which receives this function (partially applied) for recursion. Any other
  value becomes a ``BatchTracer`` batched along ``spec`` -- canonicalized
  against ``x``'s rank when ``spec`` is truthy -- or is returned unchanged
  when ``spec`` is None.
  """
  custom_handler = to_elt_handlers.get(type(x))
  if custom_handler:
    recurse = partial(to_elt, trace, get_idx)
    return custom_handler(recurse, get_idx, x, spec)
  # Falsy specs (e.g. None or 0) are left as-is, exactly like `spec and ...`.
  if spec:
    spec = canonicalize_axis(spec, len(np.shape(x)))
  if spec is None:
    return x
  return BatchTracer(trace, x, spec, source_info_util.current())
Exemple #7
0
def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals):
  """Generator body that runs a wrapped function under a batch trace.

  ``in_dims`` and ``out_dim_dests`` may each be a value or a thunk. Inputs
  are converted with ``to_elt`` (all sharing one lazily built iota index
  tracer); the function's outputs are converted back with ``from_elt`` and
  yielded.
  """
  if callable(in_dims):
    in_dims = in_dims()
  trace = main.with_cur_sublevel()

  def make_index_tracer():
    return BatchTracer(trace, make_iota(axis_size), 0,
                       source_info_util.current())

  idx = memoize(make_index_tracer)
  elt_converter = partial(to_elt, trace, idx)
  outs = yield map(elt_converter, in_vals, in_dims), {}
  if callable(out_dim_dests):
    out_dim_dests = out_dim_dests()
  yield map(partial(from_elt, trace, axis_size), outs, out_dim_dests)
Exemple #8
0
def batch_subtrace(main, in_dims, *in_vals):
  """Generator body that batches a traced function's inputs and outputs.

  Used in e.g. ``process_call``. ``in_dims`` may be a thunk. Inputs whose
  dim is not None are wrapped in a ``BatchTracer`` at the current sublevel
  of ``main``; others pass through unchanged. The function's outputs are
  raised into this trace and yielded as ``(out_vals, out_dims)``.
  """
  trace = main.with_cur_sublevel()
  if callable(in_dims):
    in_dims = in_dims()

  def wrap(val, dim):
    if dim is None:
      return val
    return BatchTracer(trace, val, dim, source_info_util.current())

  in_tracers = [wrap(val, dim) for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  raised = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in raised)
  yield out_vals, out_dims
Exemple #9
0
    def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
        """Batch a map primitive over BatchTracer operands.

        With no batched operand, ``map_primitive`` is bound on the raw values.
        Otherwise the operands' batch dims, the primitive's ``in_axes`` and its
        ``out_axes_thunk`` are adjusted so the vmapped dimension and the
        primitive's mapped axis stay consistent (see the table below), and the
        outputs are re-wrapped as BatchTracers.
        """
        vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
        if all(dim is not_mapped for dim in dims):
            return map_primitive.bind(f, *vals, **params)

        # All batched operands must agree on the vmapped axis size.
        assert len({
            x.shape[d]
            for x, d in zip(vals, dims) if d is not not_mapped
        }) == 1

        # The logic for the dimension math below is as follows:
        # ╔═════════════╦════════════════════════════════════════╦═══════════╗
        # ║ d / in_axis ║ None                                   ║ int       ║
        # ╠═════════════╬════════════════════════════════════════╩═══════════╣
        # ║ None        ║ No extra axis, so in_axis unaffected               ║
        # ╠═════════════╬════════════════════════════════════════╦═══════════╣
        # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
        # ╚═════════════╩════════════════════════════════════════╩═══════════╝
        # When both d and in_axis are defined then:
        # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
        # - If `d >  in_axis`, we have to decrement `d` (as `in_axis` will get removed).
        def both_mapped(in_out_axis, d):
            return in_out_axis is not None and d is not not_mapped

        new_in_axes = tuple(
            in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
            for d, in_axis in zip(dims, params['in_axes']))
        new_dims = tuple(
            d - 1 if both_mapped(in_axis, d) and in_axis < d else d
            for d, in_axis in zip(dims, params['in_axes']))
        f, dims_out = batch_subtrace(f, self.main, new_dims)
        out_axes_thunk = params['out_axes_thunk']

        # NOTE: This assumes that the choice of the dimensions over which outputs
        #       are batched is entirely dependent on the function and not e.g. on the
        #       data or its shapes.
        @as_hashable_function(closure=out_axes_thunk)
        def new_out_axes_thunk():
            return tuple(
                out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                for out_axis, d in zip(out_axes_thunk(), dims_out()))

        new_params = dict(params,
                          in_axes=new_in_axes,
                          out_axes_thunk=new_out_axes_thunk)
        vals_out = map_primitive.bind(f, *vals, **new_params)
        # Bind the adjusted dims to a NEW name: new_out_axes_thunk closes over
        # `dims_out` (late binding), so rebinding it here would make any later
        # call of the thunk fail with "generator object is not callable".
        final_dims = [
            d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
            for d, out_axis in zip(dims_out(), out_axes_thunk())
        ]
        src = source_info_util.current()
        return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, final_dims)]
Exemple #10
0
 def process_custom_jvp_call(self, prim, fun, jvp, tracers):
   """Batch a custom_jvp call by wrapping both ``fun`` and ``jvp``.

   The output batch dims come from whichever wrapped function actually ran
   (as reported by ``lu.merge_linear_aux``). If it was ``jvp``, its dims are
   asserted to be two identical halves and only the first half is kept.
   """
   in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
   batched_fun, fun_dims = batch_subtrace(fun, self.main, in_dims)
   batched_jvp, jvp_dims = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
   out_vals = prim.bind(batched_fun, batched_jvp, *in_vals)
   fun_ran, out_dims = lu.merge_linear_aux(fun_dims, jvp_dims)
   if not fun_ran:
     half = len(out_dims) // 2
     assert out_dims == out_dims[:half] * 2
     out_dims = out_dims[:half]
   src = source_info_util.current()
   return [BatchTracer(self, val, dim, src)
           for val, dim in zip(out_vals, out_dims)]
Exemple #11
0
 def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
   """Batch a custom_vjp call by wrapping ``fun``, ``fwd`` and ``bwd``.

   If no operand is batched, ``prim`` is bound on the raw values (as
   ``process_call`` does); previously the axis-size unpacking below raised
   ``ValueError`` on the resulting empty set. Otherwise all batched operands
   must share one axis size, and the output batch dims are taken from
   whichever of ``fun`` / ``fwd`` actually ran.
   """
   in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
   if all(d is not_mapped for d in in_dims):
     return prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
   axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims)
                 if d is not not_mapped}
   fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
   fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
   bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size,
                              out_dims2, in_dims, self.main.trace_type)
   out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
   fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
   if not fst:
     # fwd ran: keep only the trailing len(out_vals) dims (presumably fwd also
     # reports residual dims first). The start is 0 when the lengths match.
     out_dims = out_dims[-len(out_vals) % len(out_dims):]
   src = source_info_util.current()
   return [BatchTracer(self, v, d, src) for v, d in zip(out_vals, out_dims)]
Exemple #12
0
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
     """Batch a call primitive over BatchTracer operands.

     The call's ``name`` param (defaulting to ``f.__name__``) is wrapped with
     ``'vmap'``. With no batched operand the primitive is bound on the raw
     values; otherwise ``f`` is wrapped by ``batch_subtrace`` and the outputs
     are re-wrapped as BatchTracers.
     """
     assert call_primitive.multiple_results
     vmap_name = wrap_name(params.get('name', f.__name__), 'vmap')
     params = dict(params, name=vmap_name)
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     if not any(bdim is not not_mapped for bdim in dims):
         return call_primitive.bind(f, *vals, **params)
     batched_f, dims_out = batch_subtrace(f, self.main, dims)
     vals_out = call_primitive.bind(batched_f, *vals, **params)
     src = source_info_util.current()
     return [BatchTracer(self, val, bdim, src)
             for val, bdim in zip(vals_out, dims_out())]
Exemple #13
0
  def _axis_index_bind(*, axis_name):
    """Stage an ``axis_index_p`` eqn on the pmap trace bound to ``axis_name``.

    Looks up ``axis_name`` in the thread-local dynamic axis env, records
    ``nreps`` and the sizes of all frames up to and including the found one
    in the eqn params, and returns an unknown int32 scalar JaxprTracer whose
    recipe is that eqn.
    """
    axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = axis_env[axis_name]
    frame_pos = axis_env.index(frame)
    eqn_params = dict(nreps=axis_env.nreps,
                      sizes=axis_env.sizes[:frame_pos + 1],
                      axis_name=axis_name)
    scalar_aval = ShapedArray((), np.int32)
    result = pe.JaxprTracer(frame.pmap_trace,
                            pe.PartialVal.unknown(scalar_aval), None)
    result.recipe = pe.new_eqn_recipe([], [result], axis_index_p, eqn_params,
                                      source_info_util.current())
    return result
Exemple #14
0
    def test_op_metadata_simple(self):
        """Checks op metadata for a single ``jnp.sin`` call.

        ``source_line`` is an offset from the line that sets ``user_frame``
        and must match the layout of ``f_simple`` as written here (blank line
        included).
        """
        self.skipTest("include_xla_op_metadata not yet enabled")
        # A simple example
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        def f_simple(x):
            return jnp.sin(x)  # user_frame.line_num + 3

        x = np.ones((2, 3), np.float32)
        self.CheckOpMetadata(f_simple, x, [
            tf_test_util.OpMetadataGraph(tf_type="Sin",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 3,
                                         op_name="jax2tf(f_simple)/sin",
                                         op_type="sin")
        ])
Exemple #15
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
    """Partial-evaluation rule for ``remat_p``.

    Splits ``jaxpr`` with ``pe.partial_eval_jaxpr_custom`` using
    ``params['policy']`` (or ``nothing_saveable`` when that is falsy). Both
    halves are DCE'd; the known half is evaluated now to produce the known
    outputs plus the residuals the staged half still uses, and the staged
    half becomes a new ``remat_p`` equation (``differentiated=True``) whose
    inputs are those residuals followed by the used input tracers.

    Returns:
      Known output values (plain values, not tracers) merged with the
      unknown output tracers according to ``out_unknowns``.
    """
    assert not jaxpr.constvars
    policy = params['policy'] or nothing_saveable
    in_unknowns = [not t.is_known() for t in tracers]
    jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

    # DCE jaxpr_staged, keeping only instantiated outputs which are unknown
    _, out_inst_unknown = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged,
                                                 out_inst_unknown)
    # The staged jaxpr's inputs are num_res residuals followed by the original
    # inputs; split its input-usage mask at that boundary.
    used_res, in_used_staged = split_list(in_used_staged, [num_res])

    # DCE jaxpr_known, keeping all known outputs but discarding dce'd res
    out_used_known = [True
                      ] * (len(out_unknowns) - sum(out_unknowns)) + used_res
    jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
    # From here on, only the residuals the staged jaxpr still reads count.
    num_res = sum(used_res)

    # compute known outputs and residuals (hoisted out of remat primitive)
    _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
    _, in_consts = partition_list(in_used_known, in_consts_)
    out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
    # Known outputs come first; the last num_res values are the residuals.
    out_knowns, residuals = split_list(out_consts, [len(out_consts) - num_res])

    # set up unknown outputs with a recipe to call remat
    res_tracers = map(trace.new_instantiated_const, residuals)
    # Keep only the input tracers the DCE'd staged jaxpr reads.
    _, tracers_staged = partition_list(in_used_staged, tracers)
    in_jaxpr_tracers = res_tracers + map(trace.instantiate_const,
                                         tracers_staged)
    out_jaxpr_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
        for x in jaxpr_unknown.outvars
    ]
    new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
    recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                               new_params, jaxpr_unknown.effects,
                               source_info_util.current())
    for t in out_jaxpr_tracers:
        t.recipe = recipe

    # zip together known and unknown outputs
    return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
Exemple #16
0
def gather_error_check(error, operand, start_indices, *,
                       dimension_numbers, slice_sizes, unique_indices,
                       indices_are_sorted, mode, fill_value):
  """Bind ``gather_p`` and check that every start index is in bounds.

  A start index is in bounds when it lies in ``[0, dim - slice_size]`` for
  the operand dimension it addresses (per ``start_index_map``).

  Returns:
    ``(out, new_error)``: the gather result and the value returned by
    ``assert_func(error, all_inbounds, msg)``.
  """
  out = slicing.gather_p.bind(
      operand, start_indices, dimension_numbers=dimension_numbers,
      slice_sizes=slice_sizes, unique_indices=unique_indices,
      indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)

  # compare to OOB masking logic in lax._gather_translation_rule
  dnums = dimension_numbers
  operand_dims = np.array(operand.shape)
  # Explicit integer dtype: an empty start_index_map would otherwise become a
  # float64 array, which numpy refuses to use as an index.
  index_map = np.array(dnums.start_index_map, dtype=int)

  upper_bound = operand_dims[index_map] - np.array(slice_sizes)[index_map]
  # Compare in the indices' own dtype so the numpy bound (int64 by default)
  # does not force dtype promotion against e.g. int32/uint32 indices.
  # NOTE(review): assumes the index vector is the last axis of start_indices,
  # so the bound broadcasts against it.
  upper_bound = upper_bound.astype(start_indices.dtype)
  all_inbounds = jnp.all((start_indices >= 0) & (start_indices <= upper_bound))

  summary = source_info_util.summarize(source_info_util.current())
  msg = f"out-of-bounds indexing at {summary}"
  return out, assert_func(error, all_inbounds, msg)
Exemple #17
0
    def test_op_metadata_batched_while(self):
        """Checks op metadata for a ``lax.while_loop`` under ``jax.vmap``.

        Each ``source_line`` is an offset from the line that sets
        ``user_frame`` and must match the layout of ``f_while`` as written
        here (blank lines included).

        NOTE(review): the loop predicate starts true (carry == x) and sin(v) <= v
        for v >= 0, so with this non-negative input the loop would not
        terminate if CheckOpMetadata actually executes the function -- confirm
        it only inspects the converted graph.
        """
        self.skipTest("include_xla_op_metadata not yet enabled")
        # An example with a batched while loop
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        @jax.vmap
        def f_while(x):
            def body_fun(carry):
                new_carry = jnp.sin(carry)  # We look for "sin" in the graph
                return new_carry

            carry = lax.while_loop(
                lambda carry: jnp.all(carry <= x),  # We look for "le" in the graph
                body_fun,
                x)
            return carry

        shape = (3, 2)
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

        self.CheckOpMetadata(f_while, x, [
            tf_test_util.OpMetadataGraph(
                tf_type="Sin",
                source_file=__file__,
                source_line=user_frame.line_num + 5,
                op_name="jax2tf(f_while)/while/body/sin",
                op_type="sin"),
            tf_test_util.OpMetadataGraph(
                tf_type="LessEqual",
                source_file=__file__,
                source_line=user_frame.line_num + 9,
                op_name="jax2tf(f_while)/while/body_pred/le",
                op_type="le"),
        ])
Exemple #18
0
 def process_primitive(self, primitive, tracers, params):
   """Apply ``primitive`` to BatchTracer operands via the matching batcher.

   Axis primitives for which ``_main_trace_for_axis_names(self.main, ...)``
   holds use the axis-primitive batcher. Otherwise, if no operand is batched
   the primitive is bound on the raw values, and if some operand is batched
   the regular primitive batcher is used. Results are wrapped as
   BatchTracers (a list when ``primitive.multiple_results``).
   """
   vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
   is_axis_primitive = primitive in axis_primitive_batchers
   used_names = core.used_axis_names(primitive, params)
   if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names):
     frame = self.get_frame(vals_in, dims_in)
     batcher_primitive = self.get_axis_primitive_batcher(primitive, frame)
     val_out, dim_out = batcher_primitive(vals_in, dims_in, **params)
   elif all(bdim is not_mapped for bdim in dims_in):
     return primitive.bind(*vals_in, **params)
   else:
     frame = self.get_frame(vals_in, dims_in)
     batched_primitive = self.get_primitive_batcher(primitive, frame)
     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
   src = source_info_util.current()
   if primitive.multiple_results:
     # (An unreachable second `return map(...)` that followed here was removed.)
     return [BatchTracer(self, x, d, src) for x, d in zip(val_out, dim_out)]
   else:
     return BatchTracer(self, val_out, dim_out, src)
Exemple #19
0
    def default_process_primitive(self, primitive, tracers, params):
        """Partially evaluate primitives and saves variable recipes.

        If every input is known, ``primitive`` is bound on the known values.
        Otherwise all inputs are instantiated as ``UnzipTracer``s, output avals
        come from ``primitive.abstract_eval``, and one equation recipe is
        attached to fresh output tracers. When all inputs are keys and the
        primitive is ``harvest.sow_p`` with a tag equal to the current
        settings' tag, a ``VariableRecipe`` is attached as well.

        Returns:
          The output tracer, or a list of them if the primitive has multiple
          results.
        """
        pvs, consts = jax_util.unzip2(t.pval for t in tracers)
        if all(pv is None for pv in pvs):
            return primitive.bind(*consts, **params)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const, tracers)
        assert all(isinstance(t, UnzipTracer) for t in tracers), (
            'default_process_primitive expects only UnzipTracer inputs')
        key = all(t.is_key() for t in tracers)
        avals = [t.aval for t in tracers]
        ans = primitive.abstract_eval(*avals, **params)
        if not primitive.multiple_results:
            ans = [ans]
        out_tracers = [
            UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
            for aval in ans
        ]
        # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
        eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                                source_info_util.current())  # pytype: disable=wrong-arg-types
        for t in out_tracers:
            t.recipe = eqn

        is_variable = (key and primitive is harvest.sow_p
                       and params['tag'] == settings.tag)
        # This block is where UnzipTrace mainly differs from pe.JaxprTrace. Where
        # JaxprTrace will just return out_tracers, UnzipTrace will record an
        # additional VariableRecipe into the tracers, which will be used after
        # the trace is complete to construct init/apply Jaxprs.
        if is_variable:
            name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
                tracers, out_tracers, **params)
            variable_recipe = VariableRecipe(name, var_in_tracers,
                                             var_out_tracers)
            for t in out_tracers:
                t.variable_recipe = variable_recipe

        if primitive.multiple_results:
            return out_tracers
        return out_tracers[0]
Exemple #20
0
  def test_op_metadata_while_and_cond(self):
    """Checks op metadata for ops inside a while_loop body and a cond branch.

    Each expected ``source_line`` is an offset from the line that sets
    ``user_frame``; do not add or remove lines between that line and the ops
    below, or the offsets (+5 cos, +6 rem, +7 sin) go stale.
    """
    self.skipTest("include_xla_op_metadata not yet enabled")
    # An example with while and cond
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_while_cond(x):
      def body_fun(i_acc):
        i, acc = i_acc
        return (i + 1,
                (jnp.cos(acc) +
                 lax.cond(jnp.mod(i, 2) == 0,
                          lambda acc: jnp.sin(acc),
                          lambda acc: acc,
                          acc)))

      _, acc = lax.while_loop(
          lambda i_acc: i_acc[0] <= 5,
          body_fun, (0, x))
      return acc

    x = np.ones((2, 3), np.float32)
    self.CheckOpMetadata(
        f_while_cond, x,
        [tf_test_util.OpMetadataGraph(tf_type="Cos",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 5,
                                      op_name="jax2tf(f_while_cond)/while/body/cos",
                                      op_type="cos"),
         tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 7,
                                      op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
                                      op_type="sin"),
         tf_test_util.OpMetadataGraph(tf_type="FloorMod",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 6,
                                      op_name="jax2tf(f_while_cond)/while/body/rem",
                                      op_type="rem"),
         ]
    )
Exemple #21
0
def nan_error_check(prim, error, *in_vals, **params):
  """Bind ``prim`` and check its output for NaNs.

  Returns ``(out, assert_func(error, no_nans, msg))``, where ``no_nans`` is
  true when no element of the output is NaN and ``msg`` names the primitive
  and the current source location.
  """
  out = prim.bind(*in_vals, **params)
  any_nan = jnp.any(jnp.isnan(out))
  location = source_info_util.summarize(source_info_util.current())
  message = f"nan generated by primitive {prim.name} at {location}"
  return out, assert_func(error, jnp.logical_not(any_nan), message)
Exemple #22
0
def summary() -> str:
    """Return the summarized current source location as a string."""
    current_info = source_info_util.current()
    return str(source_info_util.summarize(current_info))
Exemple #23
0
 def sublift(self, val):
   """Re-wrap ``val``'s value and batch dim as a BatchTracer of this trace."""
   inner_val, batch_dim = val.val, val.batch_dim
   return BatchTracer(self, inner_val, batch_dim, source_info_util.current())
Exemple #24
0
 def lift(self, val):
   """Wrap a plain value as an unbatched (``not_mapped``) BatchTracer."""
   unbatched = not_mapped
   return BatchTracer(self, val, unbatched, source_info_util.current())
Exemple #25
0
def _cond_partial_eval(trace, *tracers, branches, linear):
    """Partial-evaluation rule for ``cond_p``.

    If the branch index is unknown, the whole cond is staged out through
    ``trace.default_process_primitive``. Otherwise each branch is split with
    ``pe.partial_eval_jaxpr_nounits``: the known halves run now as one
    ``cond_p`` that yields the known outputs followed by residuals, and the
    unknown halves are staged as a new ``cond_p`` equation whose inputs are
    the index, the residuals and the unknown operands.

    Returns:
      Known output values merged with unknown output tracers according to
      ``out_uks`` (via ``util.merge_lists``).
    """
    # An input is unknown when the first element of its pval is not None.
    in_unknowns = [t.pval[0] is not None for t in tracers]
    index_uk, *ops_uk = in_unknowns

    if index_uk:
        # When the branch index is unknown, we stage out the whole cond.
        # TODO(mattjj): remove this path when old remat is removed
        params = dict(branches=branches, linear=linear)
        return trace.default_process_primitive(cond_p, tracers, params)

    # First pass: find each branch's unknown outputs. An output is treated as
    # unknown if it is unknown in any branch.
    branches_out_uks = []
    for branch_jaxpr in branches:
        _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(branch_jaxpr,
                                                         ops_uk,
                                                         instantiate=False)
        branches_out_uks.append(out_uks)
    out_uks = [any(uks) for uks in zip(*branches_out_uks)]

    # Second pass: split every branch using the combined out_uks as the
    # instantiate mask, collecting each branch's residual avals.
    branches_known, branches_unknown, branch_res_avals = [], [], []
    for branch_jaxpr in branches:
        branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
            pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
        branches_known.append(branch_jaxpr_known)
        branches_unknown.append(branch_jaxpr_unknown)
        branch_res_avals.append(res_avals)

    # Combine the per-branch residual avals into one list, then adjust the
    # known branches' outputs and the staged branches' inputs to it; the assert
    # below checks that all known branches end up with matching output avals.
    all_res_avals, res_avals_per_branch = _merge_branch_residuals(
        branch_res_avals)
    num_res = len(all_res_avals)

    num_known_outs = len(out_uks) - sum(out_uks)
    branches_known = _join_cond_outputs(branches_known, all_res_avals,
                                        res_avals_per_branch, num_known_outs)
    branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
        branches_unknown, all_res_avals, res_avals_per_branch)
    assert all(
        all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
        for j in branches_known[1:])

    # Run the known halves now. The index (tracers[0]) is known on this path,
    # so it is the first entry of in_consts.
    in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
    linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
    out_consts_res = cond_p.bind(*in_consts,
                                 branches=branches_known,
                                 linear=tuple(linear_known))
    # The known cond returns the known outputs followed by num_res residuals.
    out_consts, res = split_list(out_consts_res,
                                 [len(out_consts_res) - num_res])

    # Stage the unknown halves: inputs are the index, the residuals, then the
    # unknown operands. Residuals are never marked linear.
    index_tracer = trace.instantiate_const(tracers[0])
    ops_tracers = [
        trace.instantiate_const(t)
        for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk
    ]
    res_tracers = _map(trace.new_instantiated_const, res)
    out_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
        for aval in branches_unknown[0].out_avals
    ]
    linear_unknown = ([False] * num_res +
                      [l for l, uk in zip(linear, in_unknowns[1:]) if uk])
    params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
    # Attribute the eqn to the part of the current name stack beyond the
    # trace's own name stack.
    name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
    source = source_info_util.current().replace(name_stack=name_stack)
    eqn = pe.new_eqn_recipe([index_tracer] + res_tracers + ops_tracers,
                            out_tracers, cond_p, params, core.no_effects,
                            source)
    for t in out_tracers:
        t.recipe = eqn
    return util.merge_lists(out_uks, out_consts, out_tracers)