예제 #1
0
파일: transform.py 프로젝트: yashk2810/jax
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
    """Re-trace ``jaxpr`` under sparse evaluation and return it closed.

    The inputs described by ``argspecs`` are converted to arrays and
    flattened, and ``jaxpr`` is traced through ``eval_sparse`` on the
    resulting avals.

    Args:
      spenv: sparse environment used to resolve ``argspecs``.
      jaxpr: closed jaxpr (its ``.jaxpr`` and ``.consts`` are read) to sparsify.
      *argspecs: argument specs for the inputs of ``jaxpr``.

    Returns:
      ``(sp_jaxpr, out_tree)``: the traced ``ClosedJaxpr`` and the output
      pytree definition recorded while tracing.
    """
    # TODO(jakevdp): currently this approach discards all information about
    #   shared data & indices when generating the sparsified jaxpr. The
    #   current approach produces valid sparsified while loops, but they
    #   don't work in corner cases (see associated TODO in sparsify_test.py)
    out_tree = None

    @lu.wrap_init
    def wrapped(*args_flat):
        nonlocal out_tree
        args = tree_unflatten(in_tree, args_flat)
        argspecs = arrays_to_argspecs(spenv, args)
        result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
        out = argspecs_to_arrays(spenv, result)
        out_flat, out_tree = tree_flatten(out)
        return out_flat

    args = argspecs_to_arrays(spenv, argspecs)
    args_flat, in_tree = tree_flatten(args)
    avals_flat = [
        core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat
    ]
    sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
    # `consts` belong to `sp_jaxpr.constvars`, so keep them as constvars.
    # Converting constvars to invars first would leave the ClosedJaxpr with
    # consts but no constvars to bind them to (and extra unexpected inputs).
    sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
    return sp_jaxpr, out_tree
예제 #2
0
 def fwd(*args, **kwargs):
     """Run ``fun`` and trace the rule it returns into ``Residuals``.

     ``fun`` returns ``(ans, rule)``. ``rule`` is flattened against the tree
     of ``(ans,)`` and traced on the ``at_least_vspace`` avals of the
     flattened answer; the traced jaxpr, its consts, and the tree
     definitions are packaged into ``Residuals``.
     """
     ans, rule = fun(*args, **kwargs)
     flat_ans, ans_tree = tree_flatten((ans, ))
     flat_rule, rule_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(rule),
                                                       ans_tree)
     vspace_avals = [core.get_aval(leaf).at_least_vspace() for leaf in flat_ans]
     rule_jaxpr, _, rule_consts = pe.trace_to_jaxpr_dynamic(flat_rule,
                                                            vspace_avals)
     residuals = Residuals(rule_jaxpr, rule_tree_thunk(), ans_tree, rule_consts)
     return ans, residuals
예제 #3
0
def _trace_to_jaxpr_with_refs(
    f, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
    """Trace ``f`` on ``state_avals`` into a jaxpr.

    The flat inputs are presented to ``f`` as a 2-tuple whose first element
    is a single leaf and whose second element has structure ``state_tree``.

    Returns:
      ``(jaxpr, consts, out_tree)``.
    """
    flat_f, get_out_tree = flatten_fun_nokwargs(
        lu.wrap_init(f), treedef_tuple((tree_structure(0), state_tree)))
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_f, state_avals)
    out_tree = get_out_tree()
    return jaxpr, consts, out_tree
예제 #4
0
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  """Re-trace ``jaxpr`` through ``eval_sparse`` and return it closed.

  Args:
    spenv: sparsify environment in which ``spvalues`` are resolved.
    jaxpr: closed jaxpr (its ``.jaxpr`` and ``.consts`` are read).
    *spvalues: sparsify values for the inputs of ``jaxpr``.

  Returns:
    ``(ClosedJaxpr, out_tree)`` where ``out_tree`` is the output pytree
    definition recorded while tracing.
  """
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    unflat_args = tree_unflatten(in_tree, args_flat)
    sp_result = eval_sparse(jaxpr.jaxpr, jaxpr.consts,
                            arrays_to_spvalues(spenv, unflat_args), spenv)
    out_flat, out_tree = tree_flatten(spvalues_to_arrays(spenv, sp_result))
    return out_flat

  args_flat, in_tree = tree_flatten(spvalues_to_arrays(spenv, spvalues))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
  traced, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, in_avals)
  return pe.ClosedJaxpr(traced, consts), out_tree
예제 #5
0
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  """Compute input cotangents for ``jaxpr`` by staging the transpose as remat.

  Builds a function that partially evaluates ``jaxpr`` on the known primal
  inputs and runs ``ad.backward_pass`` on the remainder with ``out_cts``,
  traces that function to a jaxpr, and binds it with ``remat_p`` using the
  same ``params``.

  Args:
    reduce_axes: forwarded to ``ad.backward_pass``.
    out_cts: cotangents for the outputs of ``jaxpr``.
    *in_primals: primal inputs; those for which ``ad.is_undefined_primal``
      holds receive cotangents.
    jaxpr: the jaxpr to transpose; must have no constvars.
    **params: forwarded unchanged to ``remat_p.bind``.

  Returns:
    One cotangent per entry of ``in_primals``; defined primals get an
    ``ad_util.Zero`` of their aval.
  """
  assert not jaxpr.constvars
  # A function object used only as a mutable attribute holder: `transposed`
  # stores its output treedef on it during tracing; it is read back below.
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    # Undefined primals become unknown partial values; the rest are known.
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    # Partially evaluate the primal; presumably `t_jaxpr` keeps only the part
    # that depends on unknown inputs (third positional arg False — confirm it
    # is `instantiate`).
    t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
    # `in_cts` holds one entry per undefined primal, in order; interleave them
    # back with Zero placeholders for the defined primals.
    in_cts_ = iter(in_cts)
    in_cts = [next(in_cts_) if ad.is_undefined_primal(x)
              else ad_util.Zero(x.aval) for x in in_primals]
    assert next(in_cts_, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  # Lift closed-over constants to leading inputs so they are passed to bind.
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
예제 #6
0
  def test_staging_nested(self):
    """A nested jit call threads the axis-size variable through all shapes."""
    size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
    x_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'), weak_type=False)
    y_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'), weak_type=False)

    @lu.wrap_init
    def f(x, y):
      @jax.jit
      def g(x, y, z, w):
        return (x, w)
      return g(x, y, x, y)

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

    # Outer jaxpr: one axis size var followed by the two array inputs.
    self.assertLen(jaxpr.invars, 1 + 2)
    size_var = jaxpr.invars[0]
    for var in jaxpr.invars[1:3]:
      self.assertEqual((size_var,), var.aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    for var in jaxpr.outvars[:2]:
      self.assertEqual((size_var,), var.aval.shape)

    self.assertLen(jaxpr.eqns, 1)
    (eqn,) = jaxpr.eqns
    self.assertIsInstance(eqn.primitive, core.CallPrimitive)
    inner_jaxpr = eqn.params['call_jaxpr']
    self.assertIsInstance(inner_jaxpr, core.Jaxpr)

    # Inner jaxpr: one axis size var followed by the four arguments of `g`.
    self.assertLen(inner_jaxpr.invars, 1 + 4)
    inner_size, *inner_arrays = inner_jaxpr.invars
    for var in inner_arrays:
      self.assertEqual((inner_size,), var.aval.shape)
예제 #7
0
  def test_staging_nested_including_shape_arg(self):
    """A nested jit that also receives ``x.shape[0]`` stages one size binder."""
    # This test covers the _get_tracers_only_in_shapes logic in partial_eval.py.
    size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
    x_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'), weak_type=False)
    y_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'), weak_type=False)

    @lu.wrap_init
    def f(x, y):
      @jax.jit
      def g(_, x, y, z, w):
        return (x, w)
      return g(x.shape[0], x, y, x, y)

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

    self.assertLen(jaxpr.eqns, 1)
    (eqn,) = jaxpr.eqns
    self.assertIsInstance(eqn.primitive, core.CallPrimitive)
    inner_jaxpr = eqn.params['call_jaxpr']
    self.assertIsInstance(inner_jaxpr, core.Jaxpr)

    # One axis size var, then four arrays whose shape is that var.
    self.assertLen(inner_jaxpr.invars, 1 + 4)
    inner_size, *inner_arrays = inner_jaxpr.invars
    for var in inner_arrays:
      self.assertEqual((inner_size,), var.aval.shape)
예제 #8
0
    def test_staging_primitive_applications(self):
        """mul, sin and reduce_sum stage with dynamically-shaped operands."""
        size_aval = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        x_aval, y_aval = (
            core.DShapedArray((pe.DBIdx(0), ), jnp.dtype('float32'),
                              weak_type=False)
            for _ in range(2))

        @lu.wrap_init
        def f(x, y):
            product = lax.mul(x, y)
            return (lax_internal._reduce_sum(lax.sin(product), [0]), )

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

        # One axis size var, two other inputs.
        self.assertLen(jaxpr.invars, 1 + 2)
        # mul, sin, reduce_sum.
        self.assertLen(jaxpr.eqns, 3)
        first_eqn = jaxpr.eqns[0]
        self.assertLen(first_eqn.outvars, 1)
        self.assertEqual(first_eqn.outvars[0].aval.shape,
                         jaxpr.invars[1].aval.shape)

        self.assertLen(jaxpr.outvars, 1)
        self.assertEqual(jaxpr.outvars[0].aval.shape, ())
예제 #9
0
파일: ode.py 프로젝트: tudorcebere/jax
def closure_convert(fun, in_tree, in_avals):
    """Trace ``fun`` and hoist its inexact-dtype constants into arguments.

    ``fun`` is traced on ``in_avals`` (inputs structured by ``in_tree``).
    Captured constants with an inexact (floating or complex) dtype are
    returned as ``hoisted_consts`` and must be passed back in explicitly;
    the remaining constants stay closed over.

    Args:
      fun: function to convert.
      in_tree: pytree definition of ``fun``'s inputs.
      in_avals: abstract values of the flattened inputs.

    Returns:
      ``(converted_fun, hoisted_consts)`` where ``converted_fun`` is called as
      ``converted_fun(y, t, *hoisted_consts, *args)``; the pytree structure of
      ``(y, t, *args)`` must equal ``in_tree``.
    """
    if config.omnistaging_enabled:
        wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                     in_tree)
        # The second return value is unused here.
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(
            wrapped_fun, in_avals)
    else:
        in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
        wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                     in_tree)
        with core.initial_style_staging():  # type: ignore
            jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
                wrapped_fun, in_pvals, instantiate=True,
                stage_out=False)  # type: ignore
    # `out_tree` is a thunk until tracing has happened; resolve it now.
    out_tree = out_tree()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(mattjj): revise this approach
    is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)
    # `merge` re-interleaves the two groups back into the original order.
    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_consts = len(hoisted_consts)

    def converted_fun(y, t, *hconsts_args):
        # Leading `num_consts` extra args are the hoisted consts.
        hoisted_consts, args = split_list(hconsts_args, [num_consts])
        consts = merge(closure_consts, hoisted_consts)
        all_args, in_tree2 = tree_flatten((y, t, *args))
        assert in_tree == in_tree2
        out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, hoisted_consts
예제 #10
0
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
    """Trace a checkified ``f`` into a ``ClosedJaxpr``.

    The traced jaxpr takes the avals of ``error.err`` and ``error.code`` as
    its two leading inputs, followed by ``in_avals``.

    Returns:
      ``(closed_jaxpr, msgs)`` where ``msgs`` is the value produced by the
      thunk returned from ``checkify_subtrace`` after tracing.
    """
    checked_fun, get_msgs = checkify_subtrace(f)
    checked_fun = checkify_traceable(checked_fun, tuple(error.msgs.items()),
                                     enabled_errors)
    to_aval = lambda v: core.raise_to_shaped(core.get_aval(v))
    avals_in = [to_aval(error.err), to_aval(error.code), *in_avals]
    jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(checked_fun, avals_in)
    return core.ClosedJaxpr(jaxpr_out, literals_out), get_msgs()
예제 #11
0
파일: common.py 프로젝트: xueeinstein/jax
def _initial_style_open_jaxpr(fun: Callable,
                              in_tree,
                              in_avals,
                              primitive_name: Optional[str] = None):
    """Trace ``fun`` to an open jaxpr (constants returned separately).

    Returns:
      ``(jaxpr, consts, out_tree)``.
    """
    flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    label = primitive_name or "<unknown>"
    debug = pe.debug_info(fun, in_tree, False, label)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    out_tree = get_out_tree()
    return jaxpr, consts, out_tree
예제 #12
0
파일: ad.py 프로젝트: jbampton/jax
def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Build a ``ClosedJaxpr`` computing the JVP of ``jaxpr``.

  The result takes every primal input of ``jaxpr`` followed by a tangent for
  each position where ``nonzeros`` is true.

  Returns:
    ``(jvp_closed_jaxpr, out_nonzeros)``.
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  primal_fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  jvp_fun = jvp(primal_fun, instantiate=instantiate, transform_stack=False)
  jvp_fun, get_out_nonzeros = f_jvp_traceable(jvp_fun, nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = [*jaxpr.in_avals, *tangent_avals]
  jvp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(jvp_fun, avals_in)
  return core.ClosedJaxpr(jvp_jaxpr, consts), get_out_nonzeros()
예제 #13
0
파일: batching.py 프로젝트: xueeinstein/jax
def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest,
                      axis_name, main_type):
  """Trace a batched version of ``closed_jaxpr``.

  Inputs whose entry in ``in_axes`` is not ``not_mapped`` are given the aval
  produced by ``core.unmapped_aval`` along that axis; the rest keep their aval.

  Returns:
    ``(batched_closed_jaxpr, out_batched)``.
  """
  fun = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  fun, out_batched = _batch_jaxpr_inner(fun, axis_size, out_axes_dest)
  fun = _batch_jaxpr_outer(fun, axis_name, axis_size, in_axes, main_type)

  def _input_aval(aval, axis):
    if axis is not_mapped:
      return aval
    return core.unmapped_aval(axis_size, axis_name, axis, aval)

  avals_in = [_input_aval(aval, axis)
              for aval, axis in zip(closed_jaxpr.in_avals, in_axes)]
  batched_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, avals_in)
  return core.ClosedJaxpr(batched_jaxpr, consts), out_batched()
예제 #14
0
파일: checkify.py 프로젝트: rsepassi/jax
def checkify_jaxpr(jaxpr, error):
  """Trace an error-checking version of ``jaxpr`` into a ``ClosedJaxpr``.

  The new jaxpr takes the avals of ``error.err`` and ``error.code`` as its
  two leading inputs, followed by ``jaxpr.in_avals``.

  Returns:
    ``(closed_jaxpr, msgs)`` with ``msgs`` from the subtrace thunk.
  """
  fun, get_msgs = check_errors_subtrace(lu.wrap_init(core.jaxpr_as_fun(jaxpr)))
  fun = check_errors_traceable(fun, tuple(error.msgs.items()))
  err_aval, code_aval = [core.raise_to_shaped(core.get_aval(v))
                         for v in (error.err, error.code)]
  checked_jaxpr, _, literals = pe.trace_to_jaxpr_dynamic(
      fun, [err_aval, code_aval, *jaxpr.in_avals])
  return core.ClosedJaxpr(checked_jaxpr, literals), get_msgs()
예제 #15
0
    def test_typecheck_staging_nested(self):
        """check_jaxpr rejects ill-typed operands and results of a staged call.

        Stages a jitted identity ``g`` applied to ``a`` with two independent
        size inputs and checks that the jaxpr type-checks. It then mutates
        the call equation in place: first an operand, then a result binder.
        For each mutation it confirms that ``core.check_jaxpr`` raises
        ``TypeError``. The operand is restored between the two checks.
        """
        n = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        m = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((DBIdx(1), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(a, b):
            @jax.jit
            def g(x):
                return x

            return g(a),

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, m, a, b], keep_inputs=[False, False, True, True])
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (e,) }
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce a type error by applying the called jaxpr to arguments
        # with types which aren't consistent with its input binders:
        _, _, c, d = jaxpr.invars
        jaxpr.eqns[0].invars[1] = d
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a d   !!! type error here !!!
        #   in (e,) }
        with self.assertRaisesRegex(TypeError, "passes operand"):
            core.check_jaxpr(jaxpr)

        # Restore the original jaxpr:
        jaxpr.eqns[0].invars[1] = c
        core.check_jaxpr(jaxpr)  # no problems here...

        # Let's introduce another type error by setting the call result let binders
        # to have the wrong type:
        jaxpr.eqns[0].outvars[0] = core.Var(0, '', d.aval)
        # { lambda ; a:i32[] b:i32[] c:f32[a] d:f32[b]. let
        #     e:f32[b] = xla_call[   !!! type error here !!!
        #       call_jaxpr={ lambda ; f:i32[] g:f32[f]. let  in (g,) }
        #       name=g
        #     ] a c
        #   in (h,) }
        with self.assertRaisesRegex(TypeError, "inconsistently typed as"):
            core.check_jaxpr(jaxpr)
예제 #16
0
 def wrapped(spenv: SparsifyEnv, *spvalues: SparsifyValue, **params: Any) -> Tuple[Sequence[SparsifyValue], Any]:
   """Trace ``f`` on the avals of ``spvalues`` and evaluate it sparsely.

   Args:
     spenv: sparsify environment in which ``spvalues`` live.
     *spvalues: inputs to ``f``; may be nested in pytrees (leaves are
       detected with ``_is_spvalue``).
     **params: keyword parameters bound into ``f`` via ``lu.wrap_init``.

   Returns:
     ``(result, out_tree)``: the outputs returned by ``eval_sparse`` and the
     output pytree definition of ``f``. (The annotation previously said
     ``bool`` for the second element, but it is ``out_tree()``.)

   Raises:
     Exception: if ``eval_sparse`` returns a different number of outputs
       than the traced jaxpr declares.
   """
   spvalues_flat, in_tree = tree_flatten(spvalues, is_leaf=_is_spvalue)
   in_avals_flat = spvalues_to_avals(spenv, spvalues_flat)
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
   jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
   result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
   if len(out_avals_flat) != len(result):
     # f-string so the actual values are interpolated into the message.
     raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                     f"Got {result} for avals {out_avals_flat}")
   return result, out_tree()
예제 #17
0
 def wrapped(spenv: SparseEnv, *argspecs: ArgSpec, **params: Any) -> Tuple[Sequence[ArgSpec], Any]:
   """Trace ``f`` on the avals of ``argspecs`` and evaluate it sparsely.

   Args:
     spenv: sparse environment in which ``argspecs`` are resolved.
     *argspecs: argument specs for the inputs of ``f``.
     **params: accepted but not used by this body.

   Returns:
     ``(result, out_tree)``: the outputs returned by ``eval_sparse`` and the
     output pytree definition of ``f``. (The annotation previously said
     ``bool`` for the second element, but it is ``out_tree()``.)

   Raises:
     Exception: if ``eval_sparse`` returns a different number of outputs
       than the traced jaxpr declares.
   """
   in_avals = argspecs_to_avals(spenv, argspecs)
   in_avals_flat, in_tree = tree_flatten(in_avals)
   wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
   jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
   result = eval_sparse(jaxpr, consts, argspecs, spenv)
   if len(out_avals_flat) != len(result):
     # f-string so the actual values are interpolated into the message.
     raise Exception("Internal: eval_sparse does not return expected number of arguments. "
                     f"Got {result} for avals {out_avals_flat}")
   return result, out_tree()
예제 #18
0
 def fun_remat(*args, **kwargs):
   """Trace ``fun`` on the given arguments and apply it via ``remat_p``.

   Closed-over constants are lifted to leading inputs of the staged jaxpr
   and passed to ``remat_p.bind`` ahead of the flattened arguments.
   """
   flat_args, in_tree = tree_flatten((args, kwargs))
   flat_fun, get_out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
   debug = pe.debug_info(fun, in_tree, False, "checkpoint")
   traced, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals, debug)
   remat_jaxpr = pe.convert_constvars_jaxpr(traced)
   outs_flat = remat_p.bind(*consts, *flat_args, jaxpr=remat_jaxpr,
                            prevent_cse=prevent_cse, differentiated=False,
                            policy=policy)
   return tree_unflatten(get_out_tree(), outs_flat)
예제 #19
0
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue, ...],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool, ...]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
  """Trace ``fun`` to a jaxpr whose constants are lifted to leading inputs.

  Args:
    fun: function to trace, flattened with ``flatten_fun`` against ``in_tree``.
    in_tree: pytree definition of ``fun``'s flattened inputs.
    in_avals: one abstract value per flat input (a variable-length tuple;
      it may contain ``DBIdx`` references).
    keep_inputs: flags forwarded to ``pe.trace_to_jaxpr_dynamic``;
      presumably one per entry of ``in_avals``.

  Returns:
    ``(jaxpr, consts, out_tree)``: the jaxpr with constvars converted to
    invars, the constants passed through ``_canonicalize_arg``, and the
    output pytree definition.
  """
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  debug = pe.debug_info(fun, in_tree, False, "dex_jit")
  jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug,
                                                keep_inputs=keep_inputs)
  jaxpr = pe.convert_constvars_jaxpr(jaxpr_)
  consts = [_canonicalize_arg(c) for c in consts]
  return jaxpr, consts, out_tree()
예제 #20
0
파일: xla.py 프로젝트: John1Tang/jax
 def f_new(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
           avals_out: Optional[Sequence[core.AbstractValue]],
           *xla_args: xc.XlaOp,
           **params) -> Sequence[xc.XlaOp]:
   """Translate ``fun`` by tracing it to a jaxpr and emitting that jaxpr.

   When ``multiple_results`` is false, ``fun``'s single output is wrapped
   in a tuple before tracing. Tracing happens under the axis environment
   from ``ctx``.
   """
   traced = lu.wrap_init(fun, params)
   if not multiple_results:
     traced = _tuple_output(traced)
   axis_env = ctx.axis_env
   with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals_in)
   xla_consts = _xla_consts(ctx.builder, consts)
   return jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
예제 #21
0
파일: mlir.py 프로젝트: rsepassi/jax
 def f_lowered(ctx, avals_in, avals_out, *args, **params):
     """Lower ``fun`` by tracing it to a jaxpr and emitting it as MLIR.

     A single-result ``fun`` is wrapped so that it always returns a tuple.
     Tracing happens under the axis environment from ``ctx``.
     """
     f = fun if multiple_results else (lambda *a, **kw: (fun(*a, **kw), ))
     traced = lu.wrap_init(f, params)
     axis_env = ctx.axis_env
     with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals_in)
     ir_consts = _ir_consts(consts)
     return jaxpr_subcomp(ctx, jaxpr, ir_consts,
                          *map(wrap_singleton_ir_values, args))
예제 #22
0
def discharge_state(jaxpr: core.Jaxpr,
                    consts: Sequence[Any]) -> Tuple[core.Jaxpr, List[Any]]:
    """Converts a jaxpr that takes in `Ref`s into one that doesn't.

    Each input whose aval type is exactly ``ShapedArrayRef`` is replaced by a
    ``ShapedArray`` with the same shape and dtype. The jaxpr is then
    re-traced through ``_eval_jaxpr_discharge_state``.

    Returns:
      ``(new_jaxpr, new_consts)``.
    """
    def _discharged_aval(aval):
        if type(aval) is ShapedArrayRef:
            return core.ShapedArray(aval.shape, aval.dtype)
        return aval

    in_avals = [_discharged_aval(v.aval) for v in jaxpr.invars]
    discharged_fun = lu.wrap_init(
        partial(_eval_jaxpr_discharge_state, jaxpr, consts))
    new_jaxpr, _, new_consts = pe.trace_to_jaxpr_dynamic(discharged_fun,
                                                         in_avals)
    return new_jaxpr, new_consts
예제 #23
0
 def __call__(self, *args, **kwargs):
   """Stage ``self.fun`` to a jaxpr and bind it with ``custom_vmap_p``.

   Keyword arguments are not supported (asserted), and the traced function
   must not capture any constants (asserted).
   """
   assert not kwargs
   flat_args, in_tree = tree_flatten(args)
   flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
   avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
   debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
   traced, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals, debug)
   assert not consts
   call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(traced), ())
   outs_flat = custom_vmap_p.bind(*consts, *flat_args,
                                  call=call,
                                  rule=self.vmap_rule,
                                  in_tree=in_tree)
   return tree_unflatten(get_out_tree(), outs_flat)
예제 #24
0
 def f_with_avals(c, avals, xla_args, params):
     """Translate ``fun`` into XLA ops on builder ``c``.

     Traces ``fun`` under a trivial axis environment. It returns an
     ``xops.Tuple`` when ``multiple_results`` is set or when any output maps
     to more than one XLA shape; otherwise it returns the single output op.
     """
     # parallelism is only supported via the new-style API.
     axis_env = AxisEnv(1, (), ())
     traced = lu.wrap_init(fun, params)
     if not multiple_results:
         traced = _tuple_output(traced)
     with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals)
     ctx = TranslationContext(c, backend, axis_env, new_name_stack())
     outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
     needs_tuple = multiple_results or any(
         len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)
     if needs_tuple:
         return xops.Tuple(c, outs)
     assert len(outs) == 1, outs
     return outs[0]
예제 #25
0
 def wrapped(*args, **kwargs):
   """Trace ``f`` (with ``kwargs`` bound) into a ``ClosedJaxpr``.

   Uses ``pe.trace_to_jaxpr_dynamic`` when ``dynamic`` is truthy, otherwise
   ``pe.trace_to_jaxpr`` with all inputs unknown.

   Returns:
     ``(closed_jaxpr, (in_tree, out_tree))``.
   """
   fun = lu.wrap_init(f, kwargs)
   flat_args, in_tree = tree_util.tree_flatten(args)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
   flat_avals = safe_map(get_shaped_aval, flat_args)
   if not dynamic:
     unknown_pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals]
     jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, unknown_pvals,
                                          instantiate=True)
   else:
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
   return jax_core.ClosedJaxpr(jaxpr, consts), (in_tree, out_tree())
예제 #26
0
    def test_staging_nested_including_shape_arg(self):
        """A nested jit that also receives ``x.shape[0]`` stages correctly.

        The assertions expect the inner call jaxpr to have one leading size
        binder followed by four arrays whose shape is that binder.
        """
        n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
        a = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)
        b = core.DShapedArray((pe.DBIdx(0), ),
                              jnp.dtype('float32'),
                              weak_type=False)

        @lu.wrap_init
        def f(x, y):
            @jax.jit
            def g(_, x, y, z, w):
                return (x, w)

            return g(x.shape[0], x, y, x, y)

        jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
            f, [n, a, b], keep_inputs=[False, True, True])
        # (Removed a leftover debug `print(jaxpr)` that only added test noise.)

        # { lambda ; a:i32[] b:f32[a] c:f32[a]. let
        #     d:f32[a] e:f32[a] = xla_call[
        #       call_jaxpr={ lambda ; f:i32[] g:i32[] h:f32[f] i:f32[f] j:f32[f] k:f32[f]. let
        #
        #         in (h, k) }
        #       name=g
        #     ] a a b c b c
        #   in (d, e) }
        # NOTE(review): the sketch above shows two leading i32 binders (six
        # inner invars), while the assertions below expect 1 + 4; confirm
        # which reflects current staging behavior.

        self.assertLen(jaxpr.eqns, 1)
        eqn = jaxpr.eqns[0]
        self.assertIsInstance(eqn.primitive, core.CallPrimitive)
        inner_jaxpr = eqn.params['call_jaxpr']
        self.assertIsInstance(inner_jaxpr, core.Jaxpr)

        self.assertLen(inner_jaxpr.invars, 1 + 4)  # one axis size var
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[1].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[2].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[3].aval.shape)
        self.assertEqual((inner_jaxpr.invars[0], ),
                         inner_jaxpr.invars[4].aval.shape)
예제 #27
0
 def wrapped(*args, **kwargs):
     """Trace ``f`` (with ``kwargs`` bound) into a ``ClosedJaxpr``.

     Raises:
       ValueError: if JAX omnistaging is not enabled.

     Returns:
       ``(closed_jaxpr, (in_tree, out_tree))``.
     """
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)
     if not jax.config.omnistaging_enabled:
         raise ValueError('Oryx must be used with JAX omnistaging enabled.')
     if not dynamic:
         unknown_pvals = [
             pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals
         ]
         jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, unknown_pvals,
                                              instantiate=True)
     else:
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
     return jax_core.ClosedJaxpr(jaxpr, consts), (in_tree, out_tree())
예제 #28
0
def _closure_convert_for_avals(fun, in_tree, in_avals):
  """Trace ``fun`` and hoist constants selected by ``_maybe_perturbed``.

  Returns:
    ``(converted_fun, hoisted_consts)``. ``converted_fun`` takes the original
    arguments followed by ``hoisted_consts`` (the last ``len(hoisted_consts)``
    positional args) and evaluates the traced jaxpr.
  """
  flat_fun, get_out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
  out_tree = get_out_tree()

  (closure_consts, hoisted_consts), merge = partition_list(_maybe_perturbed, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(*args_hconsts):
    split_at = len(args_hconsts) - num_consts
    args, hoisted = split_list(args_hconsts, [split_at])
    all_consts = merge(closure_consts, hoisted)
    flat_args, in_tree2 = tree_flatten(tuple(args))
    assert in_tree == in_tree2
    return tree_unflatten(out_tree, core.eval_jaxpr(jaxpr, all_consts, *flat_args))

  return converted_fun, hoisted_consts
예제 #29
0
 def wrapped(*args, **kwargs):
     """Trace ``f`` (with ``kwargs`` bound) into a ``TypedJaxpr``.

     The dynamic tracer is used only when both ``dynamic`` and omnistaging
     are enabled; otherwise inputs are traced as unknown partial values.

     Returns:
       ``(typed_jaxpr, (in_tree, out_tree))``.
     """
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)
     use_dynamic = dynamic and jax.config.omnistaging_enabled
     if not use_dynamic:
         unknown_pvals = [
             pe.PartialVal((aval, jax_core.unit)) for aval in flat_avals
         ]
         jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun,
                                                      unknown_pvals,
                                                      instantiate=True,
                                                      stage_out=True)
         out_avals = [pval.get_aval() for pval in out_pvals]
     else:
         jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
             flat_fun, flat_avals)
     typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
     return typed_jaxpr, (in_tree, out_tree())
예제 #30
0
  def test_staging_basic(self):
    """Staging an identity function keeps the size var in all array shapes."""
    size_aval = core.ShapedArray((), jnp.dtype('int32'), weak_type=False)
    x_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'), weak_type=False)
    y_aval = core.DShapedArray((size_aval,), jnp.dtype('float32'), weak_type=False)

    @lu.wrap_init
    def f(x, y):
      return x, y

    jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
        f, [size_aval, x_aval, y_aval], keep_inputs=[False, True, True])

    self.assertLen(jaxpr.invars, 3)
    size_var = jaxpr.invars[0]
    for var in jaxpr.invars[1:]:
      self.assertEqual((size_var,), var.aval.shape)

    self.assertLen(jaxpr.outvars, 2)
    for var in jaxpr.outvars:
      self.assertEqual((size_var,), var.aval.shape)