Example #1
0
File: xla.py Project: John1Tang/jax
 def f_new(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
           avals_out: Optional[Sequence[core.AbstractValue]],
           *xla_args: xc.XlaOp,
           **params) -> Sequence[xc.XlaOp]:
   """Trace ``fun`` to a jaxpr and lower it with ``jaxpr_subcomp``.

   Args:
     ctx: translation context. Its ``axis_env`` names/sizes are installed
       while tracing, and its ``builder`` is used to materialize constants.
     avals_in: abstract values used as the tracing signature of ``fun``.
     avals_out: accepted for interface compatibility; not read here.
     *xla_args: XLA operands forwarded to ``jaxpr_subcomp`` as jaxpr inputs.
     **params: keyword parameters bound to ``fun`` via ``lu.wrap_init``.

   Returns:
     The value produced by ``jaxpr_subcomp`` (annotated as a sequence of
     XLA ops).
   """
   traced = lu.wrap_init(fun, params)
   if not multiple_results:
     # NOTE(review): `_tuple_output` presumably adapts a single-result
     # function to the multiple-results convention; confirm at its definition.
     traced = _tuple_output(traced)
   env = ctx.axis_env
   with core.extend_axis_env_nd(zip(env.names, env.sizes)):
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals_in)
   # Constants are converted before the operands are forwarded, matching the
   # left-to-right argument evaluation of a single call expression.
   xla_consts = _xla_consts(ctx.builder, consts)
   return jaxpr_subcomp(ctx, jaxpr, xla_consts, *xla_args)
Example #2
0
File: mlir.py Project: rsepassi/jax
 def f_lowered(ctx, avals_in, avals_out, *args, **params):
     """Trace ``fun`` to a jaxpr and lower it with ``jaxpr_subcomp``.

     Args:
       ctx: lowering context; its ``axis_env`` names/sizes are installed
         while tracing.
       avals_in: abstract values used as the tracing signature of ``fun``.
       avals_out: accepted for interface compatibility; not read here.
       *args: lowered operands; each is passed through
         ``wrap_singleton_ir_values`` before reaching ``jaxpr_subcomp``.
       **params: keyword parameters bound to ``fun`` via ``lu.wrap_init``.

     Returns:
       The value produced by ``jaxpr_subcomp`` for the traced jaxpr.
     """
     # A single-result `fun` is wrapped so tracing always sees a tuple of
     # outputs (a 1-tuple here).
     f = fun if multiple_results else (lambda *a, **kw: (fun(*a, **kw),))
     traced = lu.wrap_init(f, params)
     env = ctx.axis_env
     with core.extend_axis_env_nd(zip(env.names, env.sizes)):
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals_in)
     # Constants are materialized before the operands are wrapped, preserving
     # the original argument evaluation order.
     ir_consts = _ir_consts(consts)
     return jaxpr_subcomp(ctx, jaxpr, ir_consts,
                          *map(wrap_singleton_ir_values, args))
Example #3
0
 def f_with_avals(c, avals, xla_args, params):
     """Trace ``fun`` at ``avals`` and lower it into the XLA builder ``c``.

     Args:
       c: builder passed to ``TranslationContext``, ``_xla_consts`` and
         ``xops.Tuple``.
       avals: abstract input values used as the tracing signature.
       xla_args: XLA operands forwarded to ``jaxpr_subcomp``.
       params: keyword parameters bound to ``fun`` via ``lu.wrap_init``.

     Returns:
       ``xops.Tuple(c, outs)`` when ``multiple_results`` is set or any jaxpr
       output aval maps to more than one XLA shape; otherwise the single
       output op.
     """
     # Parallelism is only supported via the new-style API, so this path
     # always traces under a fixed AxisEnv(1, (), ()) -- presumably a single
     # replica with no named axes; confirm against AxisEnv's field order.
     fixed_env = AxisEnv(1, (), ())
     traced = lu.wrap_init(fun, params)
     if not multiple_results:
         traced = _tuple_output(traced)
     with core.extend_axis_env_nd(zip(fixed_env.names, fixed_env.sizes)):
         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traced, avals)
     ctx = TranslationContext(c, backend, fixed_env, new_name_stack())
     outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
     # An output whose aval expands to several XLA shapes cannot be returned
     # as a lone op, so it also forces tuple packaging.
     needs_tuple = multiple_results or any(
         len(aval_to_xla_shapes(var.aval)) > 1 for var in jaxpr.outvars)
     if not needs_tuple:
         assert len(outs) == 1, outs
         return outs[0]
     return xops.Tuple(c, outs)