Example #1
    def trace_to_jaxpr_finalize(in_tracers,
                                out_tracers,
                                trace,
                                instantiate=True):
        # TODO: This is the final part of the partial_eval.trace_to_subjaxpr. Share.
        instantiate = [instantiate] * len(out_tracers)
        out_tracers = safe_map(trace.full_raise,
                               safe_map(core.full_lower, out_tracers))
        out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
                               instantiate, out_tracers)
        jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
        out_pvals = [t.pval for t in out_tracers]
        # TODO: this is from partial_eval.trace_to_jaxpr. Share.
        assert not env

        # TODO: this is from the final part of lax_control_flow._initial_style_jaxpr
        out_avals = safe_map(abstract_arrays.raise_to_shaped,
                             unzip2(out_pvals)[0])
        const_avals = tuple(
            abstract_arrays.raise_to_shaped(core.get_aval(c)) for c in consts)

        in_pvals = [t.pval for t in in_tracers]
        in_avals = tuple(
            safe_map(abstract_arrays.raise_to_shaped,
                     unzip2(in_pvals)[0]))

        typed_jaxpr = core.TypedJaxpr(pe.convert_constvars_jaxpr(jaxpr), (),
                                      const_avals + in_avals, out_avals)
        return typed_jaxpr, consts
Example #2
 def new(cls, val):
     if val is jax_core.unit:
         return InverseAndILDJ.unknown(jax_core.abstract_unit)
     val = np.array(val)
     aval = jax_core.get_aval(val)
     aval = abstract_arrays.raise_to_shaped(aval)
     ndslice = NDSlice.new(val, np.zeros_like(val))
     return InverseAndILDJ(aval, frozenset([ndslice]))
Example #3
def _initial_style_jaxpr(fun, in_tree, in_avals):
  in_pvals = [pe.PartialVal((aval, core.unit)) for aval in in_avals]
  fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True)
  out_avals = _map(raise_to_shaped, unzip2(out_pvals)[0])
  const_avals = tuple(raise_to_shaped(core.get_aval(c)) for c in consts)
  typed_jaxpr = core.TypedJaxpr(pe.closure_convert_jaxpr(jaxpr),
                                (), const_avals + in_avals, out_avals)
  return typed_jaxpr, consts, out_tree()
Example #4
 def instantiate_const_abstracted(self, tracer):
     pv, const = tracer.pval
     if isinstance(pv, jax_core.AbstractValue):
         return tracer
     elif pv is None:
         aval = abstract_arrays.raise_to_shaped(
             trace_util.get_shaped_aval(const), onp.isscalar(const))
         return UnzipTracer(self, pe.PartialVal.unknown(aval),
                            pe.ConstVar(const), tracer.is_key())
     else:
         raise TypeError(pv)
Example #5
 def handle_sow(self, *values, name, tag, tree, mode):
   """Stores a sow in the reaps dictionary."""
   del tag
   if name in self.reaps:
     raise ValueError(f'Variable has already been reaped: {name}')
   avals = tree_util.tree_unflatten(
       tree,
       [abstract_arrays.raise_to_shaped(jax_core.get_aval(v)) for v in values])
   self.reaps[name] = Reap(
       tree_util.tree_unflatten(tree, values), dict(mode=mode, aval=avals))
   return values
Example #6
def custom_jvp_call_jaxpr(fun, jvp, *args):
    """A convenience wrapper to apply the custom_jvp_call_jaxpr primitive."""
    in_avals = [
        abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
    ]
    fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
        fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = jax_core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(fun_jaxpr), ())
    jvp_jaxpr_thunk = pe._memoize(  # pylint: disable=protected-access
        lambda: cd._initial_style_jaxpr(jvp, in_avals * 2))  # pylint: disable=protected-access
    return cd.custom_jvp_call_jaxpr_p.bind(*consts,
                                           *args,
                                           fun_jaxpr=closed_fun_jaxpr,
                                           jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                                           num_consts=len(consts))
Example #7
def _get_harvest_metadata(closed_jaxpr, settings, *args):
    """Probes a jaxpr for metadata like its sown values."""
    fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr))
    with jax_core.new_main(HarvestTrace) as main:
        settings = HarvestSettings(settings.tag, settings.blocklist,
                                   settings.allowlist, True)
        fun = reap_function(fun, main, settings, True)
        fun, aux = _reap_metadata_wrapper(fun)
        flat_args, in_tree = tree_util.tree_flatten(args)
        flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
        in_avals = jax_util.safe_map(
            lambda a: abstract_arrays.raise_to_shaped(jax_core.get_aval(a)),
            flat_args)
        pe.trace_to_jaxpr_final(flat_fun, in_avals)
        metadata = aux()
        out_tree()
    return metadata
Example #8
def custom_vjp_call_jaxpr(fun, fwd, bwd, *args, out_trees):
    in_avals = [
        abstract_arrays.raise_to_shaped(jax_core.get_aval(x)) for x in args
    ]
    fun_jaxpr, consts = cd._initial_style_jaxpr(  # pylint: disable=protected-access
        fun, in_avals)  # consts can be tracers!
    closed_fun_jaxpr = jax_core.ClosedJaxpr(
        pe.convert_constvars_jaxpr(fun_jaxpr), ())
    fwd_jaxpr_thunk = pe._memoize(
        lambda: cd._initial_style_jaxpr(fwd, in_avals))  # pylint: disable=protected-access
    return cd.custom_vjp_call_jaxpr_p.bind(*consts,
                                           *args,
                                           fun_jaxpr=closed_fun_jaxpr,
                                           fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                                           bwd=bwd,
                                           out_trees=out_trees,
                                           num_consts=len(consts))
Example #9
 def aval(self):
     return abstract_arrays.raise_to_shaped(jax_core.get_aval(self.val))
Example #10
def get_shaped_aval(x):
    """Converts a JAX value type into a shaped abstract value."""
    if hasattr(x, 'dtype') and hasattr(x, 'shape'):
        return abstract_arrays.ShapedArray(x.shape,
                                           dtypes.canonicalize_dtype(x.dtype))
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
Example #11
 def abstractify(x):
     return abstract_arrays.raise_to_shaped(core.get_aval(x))
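
Examples #9 through #11 all use the same two-step idiom: fetch the abstract value of a concrete value with get_aval, then lift it to a ShapedArray with raise_to_shaped so that only shape and dtype remain. Below is a minimal, self-contained sketch of that idiom, assuming an older JAX release in which jax.abstract_arrays still exports raise_to_shaped:

import numpy as np
from jax import core
from jax import abstract_arrays

def abstractify(x):
    # get_aval on a concrete value returns a concrete abstract value;
    # raise_to_shaped drops the value and keeps only shape and dtype.
    return abstract_arrays.raise_to_shaped(core.get_aval(x))

# Yields a ShapedArray with shape (2, 3) and dtype float32.
print(abstractify(np.ones((2, 3), np.float32)))
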
Example #12
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis):
    input_aval = raise_to_shaped(x)
    shape = list(input_aval.shape)
    size = shape.pop(split_axis)
    shape.insert(concat_axis, size)
    return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
Example #13
        if axis_index_groups is not None:
            size = len(axis_index_groups[0])
        elif type(axis_name) is tuple:
            size = prod([core.axis_frame(name).size
                         for name in axis_name])  # type: ignore
        else:
            size = core.axis_frame(axis_name).size  # type: ignore
        return tuple(size * x for x in args)
    return core.Primitive.bind(psum_p,
                               *args,
                               axis_name=axis_name,
                               axis_index_groups=axis_index_groups)


pmax_p = core.Primitive('pmax')
pmax_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
xla.parallel_translations[pmax_p] = \
    partial(_allreduce_translation_rule, lax.max_p)
batching.split_axis_rules[pmax_p] = partial(_split_axis_comm_assoc, pmax_p)
batching.primitive_batchers[pmax_p] = partial(_collective_batcher, pmax_p)
batching.collective_rules[pmax_p] = \
  partial(_batched_reduction_collective,
          pmax_p,
          lambda v, d: v.max(d),
          lambda v, axis_size: v)

pmin_p = core.Primitive('pmin')
pmin_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
xla.parallel_translations[pmin_p] = \
    partial(_allreduce_translation_rule, lax.min_p)
batching.split_axis_rules[pmin_p] = partial(_split_axis_comm_assoc, pmin_p)
Example #14
def _gamma_batching_rule(batched_args, batch_dims):
    k, a = batched_args
    bk, ba = batch_dims
    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
                if i is not None)
    k = batching.bdim_at_front(k, bk, size)
    a = batching.bdim_at_front(a, ba, size)
    return random_gamma_p.bind(k, a), (0, )


random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.multiple_results = True
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a:
                                 (abstract_arrays.raise_to_shaped(a), ))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a:
           (tangent * _gamma_grad(ans[0], a), ))
xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule


def gamma(key, a, shape=None, dtype=onp.float64):
    """Sample Gamma random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
Example #15
  else:
    samples = vmap(_gamma_one)(keys, alphas)
  return jnp.reshape(samples, a_shape),

def _gamma_batching_rule(batched_args, batch_dims):
    k, a = batched_args
    bk, ba = batch_dims
    size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
    k = batching.bdim_at_front(k, bk, size)
    a = batching.bdim_at_front(a, ba, size)
    return random_gamma_p.bind(k, a), (0,)

random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.multiple_results = True
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a: (abstract_arrays.raise_to_shaped(a),))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a: (tangent * _gamma_grad(ans[0], a),))
xla.translations[random_gamma_p] = xla.lower_fun(_gamma_impl)
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule

def gamma(key, a, shape=None, dtype=np.float64):
  """Sample Gamma random values with given shape and float dtype.

  Args:
    key: a PRNGKey used as the random key.
    a: a float or array of floats broadcast-compatible with ``shape``
      representing the parameter of the distribution.
    shape: optional, a tuple of nonnegative integers specifying the result
      shape. Must be broadcast-compatible with ``a``. The default (None)
      produces a result shape equal to ``a.shape``.
    dtype: optional, a float dtype for the returned values (default float64 if
Example #16
 def new(cls, val):
   aval = jax_core.get_aval(val)
   if aval is jax_core.abstract_unit:
     return cls.unknown(aval)
   aval = abstract_arrays.raise_to_shaped(aval)
   return InverseAndILDJ(aval, val, np.array(0.))
Example #17
def typematch(aval1, aval2):
  return raise_to_shaped(aval1) == raise_to_shaped(aval2)
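
A hedged usage sketch for the typematch idiom above: raising both sides to ShapedArray lets a concrete abstract value and an already-shaped one compare equal when their shape and dtype agree. This again assumes an older JAX release where jax.abstract_arrays exports ShapedArray and raise_to_shaped:

import numpy as np
from jax import core
from jax import abstract_arrays

def typematch(aval1, aval2):
    # Compare avals by shape and dtype only, ignoring concreteness.
    return (abstract_arrays.raise_to_shaped(aval1) ==
            abstract_arrays.raise_to_shaped(aval2))

concrete = core.get_aval(np.zeros((3,), np.float32))          # concrete aval
shaped = abstract_arrays.ShapedArray((3,), np.dtype(np.float32))
assert typematch(concrete, shaped)
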
Example #18
def typecheck(aval, x):
  aval = raise_to_shaped(aval)
  try:
    return aval == core.lattice_join(aval, core.get_aval(x))
  except TypeError:
    return False
Example #19
def _abstractify(x):
  return raise_to_shaped(core.get_aval(x))
Example #20
def get_shaped_aval(x):
    if hasattr(x, 'dtype') and hasattr(x, 'shape'):
        return abstract_arrays.ShapedArray(x.shape, x.dtype)
    return abstract_arrays.raise_to_shaped(jax_core.get_aval(x))
Example #21
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body."""
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    def jaxpr_fun(carry, x):
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist)

    def body(carry, x):
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    remaining_linear = linear[num_consts:]
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        if mode == 'append':
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            val = tree_util.tree_map(lambda x: x[-1], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys