Example #1
def _primal_tangent_shapes_match(primal, tangent):
  if type(tangent) is not Zero:
    primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
    tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
    assert primal_aval.shape == tangent_aval.shape, (primal_aval.shape, tangent_aval.shape)
    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
    assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)
def _flatten_jvp(in_tree, *args):
  primals_in, tangents_in = split_list(args, [len(args) // 2])
  py_primals = tree_unflatten(in_tree, primals_in)
  py_tangents = tree_unflatten(in_tree, tangents_in)
  pair_out = yield (py_primals, py_tangents), {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    msg = ("Custom JVP rule must produce a pair (list or tuple of length two) "
           "representing primal and tangent outputs, got {}.")
    raise TypeError(msg.format(pair_out))
  py_primals_out, py_tangents_out = pair_out
  primals_out, out_tree = tree_flatten(py_primals_out)
  tangents_out, out_tree2 = tree_flatten(py_tangents_out)
  if out_tree != out_tree2:
    msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
           "container (pytree) structures, but got {} and {} respectively.")
    raise TypeError(msg.format(out_tree, out_tree2))
  primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out]
  tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) for t in tangents_out]
  if primal_avals_out != tangent_avals_out:
    if len(primal_avals_out) == 1:
      (av1,), (av2,) = primal_avals_out, tangent_avals_out
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "equal shapes and dtypes, but got {} and {} respectively.")
      raise TypeError(msg.format(av1.str_short(), av2.str_short()))
    else:
      msg = ("Custom JVP rule must produce primal and tangent outputs with "
             "equal shapes and dtypes, but got:\n{}")
      disagreements = (
          "  primal {} for tangent {}".format(av1.str_short(), av2.str_short())
          for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
      raise TypeError(msg.format('\n'.join(disagreements)))
  yield primals_out + tangents_out, out_tree
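The two helpers above appear to back jax.custom_jvp: the first checks that each tangent matches its primal's shape and expected tangent dtype, and the second flattens a user-supplied JVP rule and validates that it returns a (primals_out, tangents_out) pair with matching pytree structure. A minimal sketch of the user-facing API whose rules these checks guard (standard jax.custom_jvp usage; the function names are illustrative):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  # Returning anything other than a (primal_out, tangent_out) pair with a
  # matching pytree structure would trigger the TypeErrors raised above.
  return jnp.sin(x), jnp.cos(x) * x_dot

print(jax.jvp(f, (1.0,), (1.0,)))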
Example #3
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
    f, msgs = checkify_subtrace(f)
    f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    avals_in = [err_aval, code_aval, *in_avals]
    jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
    return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
Example #4
def checkify_jaxpr(jaxpr, error):
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f, msgs = check_errors_subtrace(f)
  f = check_errors_traceable(f, tuple(error.msgs.items()))
  err_aval = core.raise_to_shaped(core.get_aval(error.err))
  code_aval = core.raise_to_shaped(core.get_aval(error.code))
  avals_in = [err_aval, code_aval, *jaxpr.in_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
Example #5
def ignore_errors_jaxpr(jaxpr, error):
    """Constructs a jaxpr which takes two extra args but ignores them."""
    err_aval = core.raise_to_shaped(core.get_aval(error.err))
    code_aval = core.raise_to_shaped(core.get_aval(error.code))
    consts = jaxpr.consts
    jaxpr = jaxpr.jaxpr
    new_vars = core.gensym([jaxpr])
    new_invars = (new_vars(err_aval), new_vars(code_aval), *jaxpr.invars)
    new_jaxpr = core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars,
                           jaxpr.eqns)
    return core.ClosedJaxpr(new_jaxpr, consts)
Example #6
 def aval(self):
     aval = raise_to_shaped(core.get_aval(self.val))
     if self.batch_dim is not_mapped or aval is core.abstract_unit:
         return aval
     else:
         return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim,
                                 aval)
Example #7
def _sparsify_jaxpr(spenv, jaxpr, *argspecs):
    # TODO(jakevdp): currently this approach discards all information about
    #   shared data & indices when generating the sparsified jaxpr. The
    #   current approach produces valid sparsified while loops, but they
    #   don't work in corner cases (see associated TODO in sparsify_test.py)
    out_tree = None

    @lu.wrap_init
    def wrapped(*args_flat):
        nonlocal out_tree
        args = tree_unflatten(in_tree, args_flat)
        argspecs = arrays_to_argspecs(spenv, args)
        result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, argspecs, spenv)
        out = argspecs_to_arrays(spenv, result)
        out_flat, out_tree = tree_flatten(out)
        return out_flat

    args = argspecs_to_arrays(spenv, argspecs)
    args_flat, in_tree = tree_flatten(args)
    avals_flat = [
        core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat
    ]
    sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
    sp_jaxpr = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(sp_jaxpr), consts)
    return sp_jaxpr, out_tree
Example #8
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis,
                              axis_index_groups):
    input_aval = raise_to_shaped(x)
    shape = list(input_aval.shape)
    size = shape.pop(split_axis)
    shape.insert(concat_axis, size)
    return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
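A tiny standalone sketch of the shape bookkeeping this rule performs, using made-up dimensions (pure Python, no JAX needed):

# Hypothetical input shape (8, 2, 3) with split_axis=0 and concat_axis=2:
shape = [8, 2, 3]
size = shape.pop(0)    # shape is now [2, 3], size == 8
shape.insert(2, size)  # shape is now [2, 3, 8]
print(tuple(shape))    # (2, 3, 8)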
Example #9
def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
  # TODO(jakevdp): currently this approach discards all information about
  #   shared data & indices when generating the sparsified jaxpr. The
  #   current approach produces valid sparsified while loops, but they
  #   don't work in corner cases (see associated TODO in sparsify_test.py)
  out_tree = None

  @lu.wrap_init
  def wrapped(*args_flat):
    # TODO(frostig,jakevdp): This closes over `spenv`, which can bring
    # in buffers from the "outer scope" as constants. Is this a
    # problem for primitives like cond and while_loop, which always
    # convert constvars to invars when staging out their subjaxprs?
    nonlocal out_tree
    args = tree_unflatten(in_tree, args_flat)
    spvalues = arrays_to_spvalues(spenv, args)
    result = eval_sparse(jaxpr.jaxpr, jaxpr.consts, spvalues, spenv)
    out = spvalues_to_arrays(spenv, result)
    out_flat, out_tree = tree_flatten(out)
    return out_flat

  args = spvalues_to_arrays(spenv, spvalues)
  args_flat, in_tree = tree_flatten(args)
  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
  sp_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
  sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
  return sp_jaxpr, out_tree
Example #10
 def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue:
     if not self.fwd or not self.bwd:
         msg = "No VJP defined for custom_vjp function {} using defvjp."
         raise AttributeError(msg.format(self.__name__))
     args = _resolve_kwargs(self.fun, args, kwargs)
     if self.nondiff_argnums:
         for i in self.nondiff_argnums:
             _check_for_tracers(args[i])
         nondiff_argnums = set(self.nondiff_argnums)
         dyn_argnums = [
             i for i in range(len(args)) if i not in nondiff_argnums
         ]
         f_, dyn_args = argnums_partial(lu.wrap_init(self.fun), dyn_argnums,
                                        args)
         static_args = [args[i] for i in self.nondiff_argnums]
         fwd, _ = argnums_partial(lu.wrap_init(self.fwd), dyn_argnums, args)
         bwd = _add_args(lu.wrap_init(self.bwd), static_args)
     else:
         f_, dyn_args = lu.wrap_init(self.fun), args
         fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
     args_flat, in_tree = tree_flatten(dyn_args)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
     flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
     out_flat = custom_vjp_call_p.bind(flat_fun,
                                       flat_fwd,
                                       flat_bwd,
                                       *args_flat,
                                       out_trees=out_trees)
     fst, aux = lu.merge_linear_aux(out_tree, out_trees)
     out_tree = aux if fst else aux[0]
     return tree_unflatten(out_tree, out_flat)
Example #11
def _cond_transpose(reduce_axes, cts, *args, branches, linear):
    index, *ops = args
    in_avals = _map(raise_to_shaped, branches[0].in_avals)
    num_res = len(ops) - sum(linear)

    branches_trans = tuple(
        _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes)
        for jaxpr in branches)
    lin_in_avals = [
        raise_to_shaped(a, weak_type=False) for a, l in zip(in_avals, linear)
        if l
    ]
    assert all(
        core.typematch(out_aval, lin_in_aval) for jaxpr in branches_trans
        for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

    res = ops[:num_res]
    cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
    linear_trans = (False,) * num_res + (True,) * len(cts)

    out = cond_p.bind(index,
                      *res,
                      *cts,
                      branches=branches_trans,
                      linear=linear_trans)
    assert all(_map(core.typecheck, lin_in_avals, out))

    out_iter = iter(out)
    out = [next(out_iter) if l else None for l in linear]
    assert next(out_iter, None) is None
    return [None] + out
Example #12
def remat_transpose(reduce_axes, out_cts, *in_primals, jaxpr, **params):
  assert not jaxpr.constvars
  cell = lambda: None

  @lu.wrap_init
  def transposed(*args):
    in_primals, out_cts = tree_unflatten(treedef, args)
    in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x) else
                pe.PartialVal.known(x) for x in in_primals]
    primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
    t_jaxpr, _, consts = pe.trace_to_jaxpr_nounits(primal_fun, in_pvals, False)
    dummy_args = [ad.UndefinedPrimal(v.aval) for v in t_jaxpr.invars]
    in_cts = ad.backward_pass(t_jaxpr, reduce_axes, False, consts, dummy_args,
                              out_cts)
    in_cts_ = iter(in_cts)
    in_cts = [next(in_cts_) if ad.is_undefined_primal(x)
              else ad_util.Zero(x.aval) for x in in_primals]
    assert next(in_cts_, None) is None
    in_cts, cell.treedef = tree_flatten(in_cts)
    return in_cts

  args, treedef = tree_flatten((in_primals, out_cts))
  in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
  transposed_jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(transposed, in_avals)
  transposed_jaxpr = pe.convert_constvars_jaxpr(transposed_jaxpr_)
  in_cts = remat_p.bind(*consts, *args, jaxpr=transposed_jaxpr, **params)
  return tree_unflatten(cell.treedef, in_cts)  # type: ignore
Example #13
def typecheck_atom(env, x):
  if isinstance(x, Literal):
    return core.raise_to_shaped(core.get_aval(x.val))
  elif isinstance(x, Var):
    return typecheck_type(env, x.aval)
  else:
    raise TypeError(f'atom of unexpected type {x}')
def _custom_vjp_call_jaxpr_jvp(primals, tangents, *,
                               fun_jaxpr: core.ClosedJaxpr,
                               fwd_jaxpr_thunk: Callable[[],
                                                         Tuple[core.Jaxpr,
                                                               Sequence[Any]]],
                               bwd: lu.WrappedFun, out_trees: Callable,
                               num_consts: int):
    _, args = split_list(primals, [num_consts])
    consts_dot, args_dot = split_list(tangents, [num_consts])
    if any(type(t) is not Zero for t in consts_dot):
        raise ad.CustomVJPException()
    fwd_jaxpr, fwd_consts = fwd_jaxpr_thunk()  # consts can be tracers!
    out_tree, res_tree = out_trees()
    args_dot = map(ad.instantiate_zeros, args_dot)
    # Cast float0 to zeros with the primal dtype because custom vjp rules don't
    # currently handle float0s
    args_dot = map(ad.replace_float0s, args, args_dot)
    res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args)
    res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = ad.custom_lin_p.bind(*res,
                                        *args_dot,
                                        num_res=res_tree.num_leaves,
                                        bwd=bwd,
                                        avals_out=avals_out)
    tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
    return primals_out, tangents_out
Example #15
def array_result_handler(sticky_device: Optional[Device],
                         aval: core.ShapedArray):
    if aval.dtype == dtypes.float0:
        return lambda _, __: np.zeros(aval.shape, dtypes.float0)
    aval = core.raise_to_shaped(aval)
    handler = lambda _, b: _maybe_create_array_from_da(b, aval, sticky_device)
    handler.args = aval, sticky_device  # for C++ dispatch path in api.py
    return handler
Example #16
def _custom_ivjp(fun, ivjp, args):
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
                                   ivjp_jaxpr=ivjp_jaxpr)
Example #17
def shaped_abstractify(x):
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass

  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                          named_shape=named_shape)
Example #18
 def fun_remat(*args, **kwargs):
   args_flat, in_tree = tree_flatten((args, kwargs))
   flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
   debug = pe.debug_info(fun, in_tree, False, "checkpoint")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   out_flat = remat_p.bind(
       *consts, *args_flat, jaxpr=pe.convert_constvars_jaxpr(jaxpr),
       prevent_cse=prevent_cse, differentiated=False, policy=policy)
   return tree_unflatten(out_tree(), out_flat)
Example #19
 def __init__(self, trace, val, batch_dim: Optional[int]):
     if config.jax_enable_checks:
         assert type(batch_dim) in (int, NotMapped)
         if type(batch_dim) is int:
             aval = raise_to_shaped(core.get_aval(val))
             assert aval is core.abstract_unit or 0 <= batch_dim < len(
                 aval.shape)  # type: ignore
     self._trace = trace
     self.val = val
     self.batch_dim = batch_dim
Example #20
 def __init__(self, trace, val, batch_dim: Optional[int],
              source_info: Optional[source_info_util.SourceInfo] = None):
   if config.jax_enable_checks:
     assert type(batch_dim) in (int, NotMapped)
     if type(batch_dim) is int:
       aval = raise_to_shaped(core.get_aval(val))
       assert batch_dim is not_mapped or 0 <= batch_dim < len(aval.shape)  # type: ignore
   self._trace = trace
   self.val = val
   self.batch_dim = batch_dim
   self.source_info = source_info
Example #21
 def f_jitted(*args, **kwargs):
   args, in_tree = tree_flatten((args, kwargs))
   f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
   jaxpr, consts, unconverted_binders = trace_to_jaxpr_dynamic(f, in_avals)
   num_consts = len(consts)
   args = [*consts, *args]
   dim_vals, args = _extract_dim_vals(jaxpr.in_dim_binders, jaxpr.in_binders,
                                      unconverted_binders, args)
   out_flat = dynamic_xla_call_p.bind(*dim_vals, *args, jaxpr=jaxpr,
                                      num_consts=num_consts)
   return tree_unflatten(out_tree(), out_flat)
Example #22
 def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
   primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
   tangents_in = map(instantiate_zeros, tangents_in)
   res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
   out_tree, res_tree = out_trees()
   res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
   avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
   tangents_out = custom_lin_p.bind(
       *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
       out_avals=avals_out)
   tangents_out = map(recast_to_float0, primals_out, tangents_out)
   return map(partial(JVPTracer, self), primals_out, tangents_out)
Example #23
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
                        *idx: int):
    if not isinstance(ref_aval, ShapedArrayRef):
        raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
    val_aval = core.raise_to_shaped(val_aval)
    assert isinstance(val_aval, core.ShapedArray)
    expected_output_shape = ref_aval.shape[len(idx):]
    if expected_output_shape != val_aval.shape:
        raise ValueError("Invalid shape for `swap`. "
                         f"Ref shape: {ref_aval.shape}. "
                         f"Value shape: {val_aval.shape}. "
                         f"Indices: {idx}. ")
    return core.ShapedArray(ref_aval.shape[len(idx):], ref_aval.dtype), {State}
Example #24
 def fwd(*args):
   flat_args, in_tree = tree_flatten(args)
   in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg))) for arg in flat_args)
   fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
   # TODO: Don't warn if consts contain JVP tracers?
   if consts:
     warnings.warn("Values that an @invertible function closes over will not have their " +
                   "gradients computed correctly (their uses inside this function will be ignored)!")
   # TODO: This requires the body to be jittable, but this shouldn't be necessary.
   #       Is there a way to trace a jaxpr while running it?
   flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
   return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts, DontFlatten((jaxpr, in_tree)))
Example #25
 def __call__(self, *args, **kwargs):
   assert not kwargs
   args_flat, in_tree = tree_flatten(args)
   flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
   in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
   debug = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
   jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
   assert not len(consts)
   closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
   out_flat = custom_vmap_p.bind(*consts, *args_flat,
                                 call=closed_call,
                                 rule=self.vmap_rule,
                                 in_tree=in_tree)
   return tree_unflatten(out_tree(), out_flat)
Example #26
def _lu_abstract_eval(operand):
  operand = raise_to_shaped(operand)
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = operand.update(shape=batch_dims + (min(m, n),), dtype=jnp.int32)
    perm = operand.update(shape=batch_dims + (m,), dtype=jnp.int32)
  else:
    pivot = operand
    perm = operand
  return operand, pivot, perm
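The rule above fixes the output shapes of the LU primitive: for an (..., m, n) operand it produces pivots of shape (..., min(m, n)) and a permutation of shape (..., m), both int32. A small usage sketch through the public wrapper (assuming jax.lax.linalg.lu is available in this JAX version):

import jax.numpy as jnp
from jax.lax import linalg

a = jnp.arange(15.0).reshape(5, 3)
lu, pivots, perm = linalg.lu(a)
print(lu.shape, pivots.shape, perm.shape)  # (5, 3) (3,) (5,)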
Example #27
def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr,
                            reduce_axes, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  if not config.jax_experimental_name_stack:
    params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  if config.jax_dynamic_shapes:
    in_type = [(core.raise_to_shaped(core.get_aval(x)), True) for x in all_args]
    fun = lu.annotate(fun, tuple(in_type))
  out_flat = primitive.bind(fun, *all_args, **params)
  return tree_unflatten(out_tree(), out_flat)
Example #28
def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size):
  pivots = raise_to_shaped(pivots)
  if isinstance(pivots, ShapedArray):
    if pivots.ndim < 1 or pivots.dtype != np.dtype(np.int32):
      raise ValueError(
          'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype '
          'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype))

    if permutation_size < pivots.shape[-1]:
      raise ValueError(
          'Output permutation size {} has to exceed the trailing dimension of '
          'the pivots. Got shape {}'.format(permutation_size, pivots.shape))

    batch_dims = pivots.shape[:-1]
    permutations = pivots.update(shape=batch_dims + (permutation_size,))
  else:
    permutations = pivots

  return permutations
Example #29
    def __call__(self, residual_arg, linear_arg):
        res_arg, lin_arg = residual_arg, linear_arg
        _, res_tree = tree_flatten(res_arg)
        _, lin_tree = tree_flatten(lin_arg)
        args_flat, in_tree = tree_flatten((res_arg, lin_arg))

        flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun),
                                                  in_tree)
        in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
        debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
        assert not len(consts)
        closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
        out_flat = custom_transpose_p.bind(*consts,
                                           *args_flat,
                                           call=closed_call,
                                           rule=self.transpose,
                                           lin_tree=lin_tree,
                                           res_tree=res_tree,
                                           out_tree=out_tree())
        return tree_unflatten(out_tree(), out_flat)
def abstractify(x):
    return core.raise_to_shaped(core.get_aval(x))
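All of the examples above rely on the same idiom: core.get_aval extracts an abstract value from a concrete JAX value, and core.raise_to_shaped lifts it to a ShapedArray that records only shape, dtype, and weak type. A minimal standalone sketch (assuming a JAX version in which these helpers are still exposed under jax.core):

import jax.numpy as jnp
from jax import core

x = jnp.ones((3, 4), dtype=jnp.float32)
aval = core.raise_to_shaped(core.get_aval(x))
print(aval)                    # ShapedArray(float32[3,4])
print(aval.shape, aval.dtype)  # (3, 4) float32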