Esempio n. 1
0
 def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                           in_tracers, out_pvs, out_consts, out_keys, name,
                           is_map):
     """Creates output tracers for a traced call and attaches its recipe.

     The const vars of `jaxpr` are converted with
     `pe.convert_constvars_jaxpr`, `consts` and `env` are turned into tracers,
     and one `UnzipTracer` is built per `(out_pv, out_const, out_key)` triple.
     A single equation recipe binding `primitive` over the const, env and
     input tracers is then stored on every output tracer.

     Returns:
       The list of freshly created output tracers.
     """
     lifted_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
     const_tracers = safe_map(self.new_instantiated_const, consts)
     env_tracers = safe_map(self.instantiate_const, env)

     out_tracers = []
     for out_pv, out_const, out_key in safe_zip(out_pvs, out_consts, out_keys):
         out_pval = pe.PartialVal((out_pv, out_const))
         out_tracers.append(UnzipTracer(self, out_pval, None, out_key))

     new_params = dict(params)
     new_params.update(name=name, call_jaxpr=lifted_jaxpr)

     if 'donated_invars' in params:
         # Const and env tracers are never donated; of the original inputs
         # only the flags whose tracer is not known are kept.
         not_donated = (False, ) * len(const_tracers) + (False, ) * len(env_tracers)
         kept_flags = tuple(
             donated
             for donated, tracer in zip(params['donated_invars'], in_tracers)
             if not tracer.pval.is_known())
         new_params['donated_invars'] = not_donated + kept_flags

     if is_map:
         # Force the thunk, require every output axis to be 0, and replace the
         # thunk with a concrete tuple of zeros.
         forced_axes = params['out_axes_thunk']()
         assert all(axis == 0 for axis in forced_axes)
         new_params['out_axes'] = (0, ) * len(out_tracers)
         new_params.pop('out_axes_thunk')

     eqn_inputs = tuple(const_tracers + env_tracers + in_tracers)
     eqn = pe.new_eqn_recipe(eqn_inputs, out_tracers, primitive, new_params,
                             source_info_util.current())  # pytype: disable=wrong-arg-types
     for out_tracer in out_tracers:
         out_tracer.recipe = eqn
     return out_tracers
Esempio n. 2
0
def _make_typed_jaxpr(traceable, in_avals):
    """Traces `traceable` on abstract inputs and wraps it in a `TypedJaxpr`.

    Every entry of `in_avals` is paired with `core.unit` to form a
    `PartialVal`; tracing runs with `instantiate=True`, and the first
    components of the resulting partial values become the output avals.
    """
    pvals = []
    for in_aval in in_avals:
        pvals.append(pe.PartialVal((in_aval, core.unit)))

    traced = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
    jaxpr, pvals_out, consts = traced
    out_avals, _unused_consts = unzip2(pvals_out)
    return core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
Esempio n. 3
0
def _make_typed_jaxpr(traceable, in_avals):
    """Traces `traceable` on abstract inputs and wraps it in a `TypedJaxpr`.

    Every entry of `in_avals` is paired with `core.unit` to form a
    `PartialVal` and tracing runs with `instantiate=True`. The single output
    partial value must carry a `core.AbstractValue` as its first component,
    which is used as the output aval.
    """
    pvals = []
    for in_aval in in_avals:
        pvals.append(pe.PartialVal((in_aval, core.unit)))

    traced = pe.trace_to_jaxpr(traceable, pvals, instantiate=True)
    jaxpr, pval_out, consts = traced
    out_aval, _unused_const = pval_out
    assert isinstance(out_aval, core.AbstractValue)
    return core.TypedJaxpr(jaxpr, consts, in_avals, out_aval)
Esempio n. 4
0
def _initial_style_jaxpr(fun, in_tree, in_avals):
  """Traces `fun` on abstract inputs into a closure-converted typed jaxpr.

  Returns:
    A triple `(typed_jaxpr, consts, out_tree)`. The typed jaxpr takes the
    shaped avals of the traced constants followed by `in_avals` as inputs and
    carries no constants of its own.
  """
  in_pvals = []
  for in_aval in in_avals:
    in_pvals.append(pe.PartialVal((in_aval, core.unit)))

  flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
      flat_fun, in_pvals, instantiate=True)

  out_pvs, _unused_consts = unzip2(out_pvals)
  out_avals = _map(raise_to_shaped, out_pvs)
  const_avals = tuple(raise_to_shaped(core.get_aval(const)) for const in consts)

  # The traced constants become leading inputs of the converted jaxpr.
  converted_jaxpr = pe.closure_convert_jaxpr(jaxpr)
  all_in_avals = const_avals + in_avals
  typed_jaxpr = core.TypedJaxpr(converted_jaxpr, (), all_in_avals, out_avals)
  return typed_jaxpr, consts, out_tree_thunk()
Esempio n. 5
0
    def start_tracing_body(self):
        """Prepares the scope for tracing the loop body.

        Snapshots the scope's mutable state, opens a subtrace, and swaps each
        mutable-state entry for a fresh trace argument built from its
        abstracted initial value. A trace argument for the loop index is
        created the same way from the value 0.
        """
        # Shallow copy of the mutable state as it is before tracing starts.
        self.carried_state_initial = copy.copy(self.scope._mutable_state)
        # All of the state is carried; names are kept in sorted order.
        self.carried_state_names = sorted(self.scope._mutable_state.keys())

        # TODO: This is the first part of partial_eval.trace_to_subjaxpr. Share.
        self.trace = self.scope.start_subtrace()

        def new_traced_arg(value):
            # Abstract the value, pair it with core.unit, and register the
            # resulting partial value as a new argument of the subtrace.
            pval = pe.PartialVal((_BodyTracer.abstractify(value), core.unit))
            return self.trace.new_arg(pval)

        # Point every entry of scope._mutable_state at a tracing variable.
        for state_name, initial_value in self.carried_state_initial.items():
            traced_var = new_traced_arg(initial_value)
            self.carried_state_vars[state_name] = traced_var
            self.scope._mutable_state[state_name] = traced_var

        self._index_var = new_traced_arg(0)
Esempio n. 6
0
def _tscan(f, a, bs, fields=(0, )):
    """
    Scan `f` over `bs` starting from `a`, like jax.lax.scan, with an extra
    `fields` argument that selects which fields of `a`'s structure are
    needed. By default only the first field is selected; the remaining
    fields are filled with None.
    """
    # Note: code is copied and modified from lax.scan implementation in
    # [JAX](https://github.com/google/jax) to support the additional `fields`
    # arg. Original code has the following copyright:
    #
    # Copyright 2018 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License")

    def unit_pval(aval):
        # Partial value made of an abstract value paired with core.unit.
        return partial_eval.PartialVal((aval, core.unit))

    # Flatten the pytree arguments (and the function) to flat jaxtuples.
    flat_a, a_tree = pytree_to_flatjaxtuple(a)
    flat_bs, b_tree = pytree_to_flatjaxtuple(bs)
    flat_fields, _unused_tree = pytree_to_flatjaxtuple(fields)
    flat_f, out_tree = pytree_fun_to_flatjaxtuple_fun(wrap_init(f),
                                                      (a_tree, b_tree))

    # Abstract values of the inputs.
    a_aval, _unused_a = lax._abstractify(flat_a)
    bs_aval, _unused_bs = lax._abstractify(flat_bs)

    # Per-step abstract value of `bs`: same dtype, leading dimension dropped.
    step_avals = []
    for stacked in bs_aval:
        step_avals.append(ShapedArray(stacked.shape[1:], stacked.dtype))
    b_aval = core.AbstractTuple(step_avals)

    # Trace the flattened function on partial values to obtain the jaxpr.
    a_pval = unit_pval(a_aval)
    b_pval = unit_pval(b_aval)
    jaxpr, pval_out, traced_consts = partial_eval.trace_to_jaxpr(
        flat_f, (a_pval, b_pval))
    aval_out, _unused_out = pval_out
    packed_consts = core.pack(traced_consts)

    flat_out = tscan_p.bind(flat_a, flat_bs, flat_fields, packed_consts,
                            aval_out=aval_out, jaxpr=jaxpr)
    return tree_unflatten(out_tree(), flat_out)
Esempio n. 7
0
 def wrapped(*args, **kwargs):
     """Traces the enclosing `f` on `args` and returns a typed jaxpr.

     `kwargs` are handed to `lu.wrap_init` together with `f`; the positional
     arguments are flattened and replaced by their shaped avals for tracing.

     Returns:
       A pair `(typed_jaxpr, (in_tree, out_tree))`.
     """
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)

     pvals = []
     for flat_aval in flat_avals:
         pvals.append(pe.PartialVal((flat_aval, jax_core.unit)))

     trace_options = dict(instantiate=True,
                          stage_out=True,
                          trace_type=pe.StagingJaxprTrace)
     jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, pvals,
                                                  **trace_options)

     out_avals = []
     for out_pval in out_pvals:
         out_avals.append(out_pval.get_aval())

     typed_jaxpr = jax_core.TypedJaxpr(jaxpr, consts, flat_avals, out_avals)
     trees = (in_tree, out_tree_thunk())
     return typed_jaxpr, trees
Esempio n. 8
0
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Splits a `scan_p` application into a part bound now and a staged part.

  Inputs whose partial value has a non-None first component are treated as
  unknown. `pe.partial_eval_jaxpr` is re-run until the unknown-ness of the
  carry stops changing, yielding `jaxpr_1` (bound immediately through
  `scan_p.bind`) and `jaxpr_2` (recorded in an equation recipe attached to
  the returned tracers).

  Args:
    trace: trace used to create constants, literals and the output tracers.
    *tracers: input tracers, ordered as consts, then carry, then xs (they are
      split with `[num_consts, num_carry]`).
    **kwargs: scan parameters read through `split_dict`: "forward", "length",
      "num_consts", "num_carry", "jaxpr" and "linear".

  Returns:
    A list of `pe.JaxprTracer`s: the carry outputs followed by the stacked
    `ys` outputs.

  Raises:
    FixedPointError: if the carry unknown-ness has not stabilized after 1000
      iterations.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  # NOTE(review): num_xs is computed but never read again in this function.
  num_xs = len(jaxpr.in_avals) - num_carry - num_consts
  num_ys = len(jaxpr.out_avals) - num_carry

  # An input counts as unknown when the first component of its pval is set.
  # NOTE(review): original_unknowns is never read again in this function.
  unknowns = original_unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Fixed-point iteration: the carry outputs feed the carry inputs, so the
  # partial evaluation is repeated until the carry unknown-ness it reports
  # equals the one it was given.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out, ys_uk = out_uk[:num_carry], out_uk[num_carry:]
    if carry_uk_out == carry_uk:
      break
    else:
      carry_uk = carry_uk_out
  else:
    # Reached only when the loop ran all 1000 iterations without a break.
    raise FixedPointError

  # Values for the immediate call: unknown inputs are replaced by core.unit,
  # known inputs contribute the second component of their pval.
  in_consts = [core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)]
  # Tracers for the staged call: unknown inputs are instantiated, known ones
  # are replaced by a unit literal.
  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  # Output avals of the whole scan: carries as-is, per-step ys promoted with
  # `_promote_aval_rank` and `length`.
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  # Only unknown outputs keep an aval in the first pval component.
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  # For the immediate call an input is flagged linear if it already was, or
  # if it is unknown.
  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  # Everything after the carry and ys outputs is treated as residuals and is
  # passed on to the staged call below.
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  # For the staged call an input is flagged linear if it already was, or if
  # it is known; the appended residual inputs are never linear.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  # Every output tracer shares the same equation as its recipe.
  for t in out_tracers: t.recipe = eqn
  return out_tracers
Esempio n. 9
0
 def wrapped(*args, **kwargs):
     """Traces the enclosing `f` on `args` and returns a closed jaxpr.

     `kwargs` are handed to `lu.wrap_init` together with `f`. Depending on the
     enclosing `dynamic` flag, tracing goes through
     `pe.trace_to_jaxpr_dynamic` or through `pe.trace_to_jaxpr` with
     `instantiate=True`.

     Returns:
       A pair `(closed_jaxpr, (in_tree, out_tree))`.

     Raises:
       ValueError: if `jax.config.omnistaging_enabled` is falsy.
     """
     fun = lu.wrap_init(f, kwargs)
     flat_args, in_tree = tree_util.tree_flatten(args)
     flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(fun, in_tree)
     flat_avals = safe_map(get_shaped_aval, flat_args)
     if not jax.config.omnistaging_enabled:
         raise ValueError('Oryx must be used with JAX omnistaging enabled.')

     if not dynamic:
         pvals = []
         for flat_aval in flat_avals:
             pvals.append(pe.PartialVal((flat_aval, jax_core.unit)))
         traced = pe.trace_to_jaxpr(flat_fun, pvals, instantiate=True)
     else:
         traced = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
     jaxpr, _unused_out, consts = traced

     typed_jaxpr = jax_core.ClosedJaxpr(jaxpr, consts)
     trees = (in_tree, out_tree_thunk())
     return typed_jaxpr, trees
Esempio n. 10
0
def _scan_partial_eval(trace, *tracers, **kwargs):
    """Splits a `scan_p` application into a part bound now and a staged part.

    `pe.partial_eval_jaxpr` is re-run until the unknown-ness of the carry
    stops changing, yielding `jaxpr_1` (bound immediately through
    `scan_p.bind`) and `jaxpr_2` (stored in the equation handed to the
    returned tracer).

    Args:
      trace: trace used to create the residuals constant and the output tracer.
      *tracers: exactly three tracers, unpacked as (consts, init, xs).
      **kwargs: must contain exactly "jaxpr", "length" and "forward".

    Returns:
      A single `pe.JaxprTracer` whose partial value pairs `out_pv` with the
      packed `(out_carry, ys)` result of the immediate call.

    Raises:
      FixedPointError: if the carry unknown-ness has not stabilized after
        1000 iterations.
    """
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    # No other parameters are accepted.
    assert not kwargs
    in_pvs, _ = unzip2([t.pval for t in tracers])
    # The three-way unpacking requires exactly three input tracers.
    sc_consts, sc_init, sc_xs = map(pe.unknown, in_pvs)

    # Fixed-point iteration: the carry output feeds the carry input, so the
    # partial evaluation is repeated until the carry component it reports
    # equals the one it was given.
    sc_carry = sc_init
    for i in range(1000):
        second_components = (sc_consts, sc_carry, sc_xs)
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(jaxpr,
                                                         second_components,
                                                         instantiate=(sc_carry,
                                                                      False))
        sc_carry_out, sc_ys = sc_out
        if sc_carry_out == sc_carry:
            break
        else:
            # Not stable yet: combine the previous and the reported carry
            # components and try again.
            sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        # Reached only when the loop ran all 1000 iterations without a break.
        raise FixedPointError

    consts_tracer, init_tracer, xs_tracer = tracers
    # The init tracer is lifted according to the stabilized carry component
    # before the partial values are read again.
    lifted_init_tracer = _lift_tracer(trace, init_tracer, sc_carry)
    lifted_tracers = consts_tracer, lifted_init_tracer, xs_tracer
    # NOTE(review): in_pvs is rebound here but only in_consts is used below.
    in_pvs, in_consts = unzip2([t.pval for t in lifted_tracers])

    # Output aval of the whole scan: the carry as-is, the per-step y promoted
    # with `_promote_aval_rank` and `length`.
    carry_aval, y_aval = jaxpr.out_aval
    ys_aval = _promote_aval_rank(length, y_aval)
    out_aval = core.AbstractTuple((carry_aval, ys_aval))
    out_pv = _put_known_pvs(sc_out, out_aval)

    # Immediate call with jaxpr_1; its second output bundles ys with the
    # residuals that the staged call consumes.
    out_carry, (ys, residuals) = scan_p.bind(*in_consts,
                                             forward=forward,
                                             length=length,
                                             jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))
    d, c, a = lifted_tracers
    # The residuals are paired with the xs tracer in the staged call's inputs.
    new_tracers = (d, c, (a, residuals_tracer))
    # NOTE(review): the meaning of the positional `True, False` flags is not
    # visible here -- confirm against the JaxprEqn definition.
    eqn = core.JaxprEqn(new_tracers, None, scan_p, (), True, False,
                        dict(forward=forward, length=length, jaxpr=jaxpr_2))
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
Esempio n. 11
0
    def default_process_primitive(self, primitive, tracers, params):
        """Partially evaluates `primitive` and records recipes on its outputs.

        When every input tracer's partial value has a `None` first component,
        the primitive is bound directly on the second components. Otherwise
        the inputs are instantiated, the primitive is abstractly evaluated,
        and new `UnzipTracer`s sharing one equation recipe are produced. If
        all inputs are keys and the primitive is `harvest.sow_p` carrying the
        tag from the dynamic context's settings, a `VariableRecipe` is also
        stored on each output tracer.

        Returns:
          The result of `primitive.bind` in the first case; otherwise the
          list of output tracers when `primitive.multiple_results` is set, or
          the single output tracer when it is not.
        """
        in_pvs, in_consts = jax_util.unzip2(tracer.pval for tracer in tracers)
        if not any(pv is not None for pv in in_pvs):
            return primitive.bind(*in_consts, **params)

        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const, tracers)
        for tracer in tracers:
            if not isinstance(tracer, UnzipTracer):
                assert False
        key = all(tracer.is_key() for tracer in tracers)

        in_avals = [tracer.aval for tracer in tracers]
        abstract_out = primitive.abstract_eval(*in_avals, **params)
        out_avals = abstract_out if primitive.multiple_results else [abstract_out]

        out_tracers = []
        for out_aval in out_avals:
            out_pval = pe.PartialVal((out_aval, jax_core.unit))
            out_tracers.append(UnzipTracer(self, out_pval, None, key))

        # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
        eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                                source_info_util.current())  # pytype: disable=wrong-arg-types
        for out_tracer in out_tracers:
            out_tracer.recipe = eqn

        # This is the main difference between UnzipTrace and pe.JaxprTrace:
        # JaxprTrace would simply return out_tracers, whereas UnzipTrace also
        # records a VariableRecipe on the tracers, used once the trace is
        # complete to construct the init/apply Jaxprs.
        if (key and primitive is harvest.sow_p
                and params['tag'] == settings.tag):
            registered = unzip_registry[primitive](tracers, out_tracers,
                                                   **params)
            name, var_in_tracers, var_out_tracers = registered
            variable_recipe = VariableRecipe(name, var_in_tracers,
                                             var_out_tracers)
            for out_tracer in out_tracers:
                out_tracer.variable_recipe = variable_recipe

        return out_tracers if primitive.multiple_results else out_tracers[0]
Esempio n. 12
0
import numpy as onp
from absl.testing import absltest
from absl.testing import parameterized

from jax import api
from jax import core
from jax import numpy as np
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
from jax.util import partial
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla

# Module-level partial values, each pairing an abstract array with core.unit:
# `_` uses an UnshapedArray of float32, `__` a scalar (shape ()) ShapedArray of
# float32.
# NOTE(review): where these two names are consumed is not visible in this
# chunk -- confirm before renaming or removing them.
_ = pe.PartialVal((UnshapedArray(onp.float32), core.unit))
__ = pe.PartialVal((ShapedArray((), onp.float32), core.unit))


def call(f, *args):
    """Apply the `jit`-wrapped version of `f` to `args` and return the result."""
    jitted = jit(f)
    return jitted(*args)


def simple_fun(x, y):
    """Return the sine of the product of `x` and `y`."""
    product = x * y
    return np.sin(product)


def simple_fun_fanout(x, y):
    """Return `sin(x * y)` scaled by `x`, so that `x` is used twice."""
    sine = np.sin(x * y)
    return sine * x

Esempio n. 13
0
def _get_partial_value(object):
    """Build a `PartialVal` from the shape and dtype of `object`.

    The abstract half of the partial value is a `ShapedArray`, an abstract
    value that carries only shape and dtype information; it is paired with
    `j_core.unit`.
    """
    shape = numpy.shape(object)
    dtype = numpy.result_type(object)
    abstract_value = j_abstract_arrays.ShapedArray(shape, dtype)
    return ji_partial_eval.PartialVal((abstract_value, j_core.unit))
Esempio n. 14
0
 def _partialize(flat_inputs):
     """Maps each flat input to a `PartialVal` of its shaped aval and `jc.unit`."""

     def to_partial_val(value):
         shaped_aval = jax.raise_to_shaped(jc.get_aval(value))
         return pe.PartialVal((shaped_aval, jc.unit))

     return map(to_partial_val, flat_inputs)
Esempio n. 15
0
 def pv_like(x):
     """Returns a `PartialVal` pairing the aval of `x` with `jc.unit`."""
     aval = get_aval(x)
     return pe.PartialVal((aval, jc.unit))
Esempio n. 16
0
 def DIABLED_test_print_jaxpr_compound(self):
     """Prints the jaxpr traced from `fun_with_call_closure` on a (2, 3) float32 input."""
     # TODO(dougalm): figure out what jaxpr-tracing api to expose and re-enable
     aval = ShapedArray((2, 3), onp.float32)
     pv = pe.PartialVal((aval, core.unit))
     traced = pe.trace_to_jaxpr(fun_with_call_closure, (pv, ))
     print(traced[0])
Esempio n. 17
0
def _instantiated_trace_to_jaxpr(fun, avals):
    """Traces `fun` on `avals` with `instantiate=True`.

    Returns:
      A triple `(jaxpr, out_avals, consts)`, where `out_avals` are the first
      components of the output partial values.
    """

    def as_unit_pval(aval):
        return pe.PartialVal((aval, unit))

    pvals = map(as_unit_pval, avals)
    traced = pe.trace_to_jaxpr(fun, pvals, instantiate=True)
    jaxpr, out_pvals, consts = traced
    out_avals, _unused_consts = unzip2(out_pvals)
    return (jaxpr, out_avals, consts)
Esempio n. 18
0
def pv_like(x, abstract=True):
    """Converts a JAX value type into a JAX `PartialVal`.

    With a truthy `abstract` the result pairs the shaped aval of `x` with
    `jax_core.unit`; otherwise it pairs `None` with `x` itself.
    """
    if not abstract:
        return pe.PartialVal((None, x))  # pytype: disable=wrong-arg-types
    shaped_aval = get_shaped_aval(x)
    return pe.PartialVal((shaped_aval, jax_core.unit))
Esempio n. 19
0
    def test_simple_trace(self):
        """Runs `check_trace_eval` on sin(x) + cos(x) with a (3, 2) float32 pval."""

        def foo(x):
            sine = np.sin(x)
            return sine + np.cos(x)

        aval = ShapedArray((3, 2), onp.float32)
        pval = pe.PartialVal((aval, core.unit))
        concrete_args = (onp.random.randn(3, 2), )
        check_trace_eval(foo, (pval, ), concrete_args, pval)
Esempio n. 20
0
def scan(f, init, xs):
  """Loop a function over leading array axes, threading a carry through it.

  In short, the type signature is

  .. code-block:: haskell

    scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

  Here [t] stands for the type t with one extra leading axis: for an array
  type t, [t] is that type with an additional leading axis; for a pytree
  (container) type t with array leaves, [t] has the same pytree structure with
  every leaf gaining an additional leading axis.

  If ``a`` and ``b`` are both array types, ``scan`` behaves like this Python
  implementation::

    def scan(f, init, xs):
      carry = init
      ys = []
      for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
      return carry, np.stack(ys)

  In contrast to that Python version, ``a`` and ``b`` may be arbitrary pytree
  types, so several arrays can be scanned over at once and several output
  arrays can be produced.

  Also in contrast to that Python version, ``scan`` is a JAX primitive and is
  lowered to a single XLA While HLO. This helps reduce compilation times of
  jit-compiled functions, because native Python loop constructs inside an
  ``@jit`` function are unrolled and lead to large XLA computations.

  Args:
    f: the Python function to scan, of type ``c -> a -> (c, b)``: its first
      argument is a value of the loop carry and its second argument is a slice
      of ``xs`` along the leading axis; it returns a pair whose first element
      is the new loop carry value and whose second element is a slice of the
      output.
    init: the initial loop carry value of type ``c``; a scalar, an array, or
      any pytree (nested Python tuple/list/dict) thereof.
    xs: the value of type ``[a]`` to scan over along the leading axis, where
      ``[a]`` is an array or any pytree (nested Python tuple/list/dict)
      thereof with consistent leading axis sizes.

  Returns:
    A pair of type ``(c, [b])``: the final loop carry value, and the second
    outputs of ``f`` stacked along the leading axis of the inputs.
  """
  flat_args, in_trees = unzip2(map(pytree_to_jaxtupletree, (init, xs)))
  init, xs = flat_args
  f, out_tree = pytree_fun_to_jaxtupletree_fun(lu.wrap_init(f), in_trees)

  # Abstract the carry as-is and the scanned input with its rank demoted.
  carry_pval = _abstractify(init)
  carry_aval, _unused_carry = carry_pval
  xs_aval, _unused_xs = _abstractify(xs)
  x_aval = _demote_aval_rank(xs_aval)
  x_pval = pe.PartialVal((x_aval, core.unit))

  traced = pe.trace_to_jaxpr(f, (carry_pval, x_pval), instantiate=True)
  jaxpr, pval_out, consts = traced
  pv_out, const_out = pval_out
  assert isinstance(pv_out, core.AbstractValue) and const_out == core.unit

  # The traced function must return a (carry, y) pair.
  returns_pair = isinstance(pv_out, core.AbstractTuple) and len(pv_out) == 2
  if not returns_pair:
    raise TypeError(
        "scanned function must have signature `c -> a -> (c, b)`, but the "
        "output was not a pair: got type {}.".format(pv_out))

  # The carry coming out must have the same abstract value as the one going in.
  carry_aval_out, y_aval = pv_out
  if carry_aval != carry_aval_out:
    raise TypeError(
        "scanned function carry output does not match carry input: "
        "input carry is {} and output carry is {}.".format(
            carry_aval, carry_aval_out))

  # The traced constants become an explicit leading input of the typed jaxpr.
  lifted_jaxpr = pe._closure_convert_jaxpr(jaxpr)
  consts_aval, _unused_consts = _abstractify(core.pack(consts))
  typed_jaxpr = core.TypedJaxpr(
      lifted_jaxpr, (),
      (consts_aval, carry_aval, x_aval),
      core.AbstractTuple((carry_aval, y_aval)))

  num_steps = _leading_dim_size(xs)
  flat_out = scan_p.bind(core.pack(consts), init, xs,
                         forward=True, length=num_steps, jaxpr=typed_jaxpr)
  return build_tree(out_tree(), flat_out)
Esempio n. 21
0
 def pv_like(x):
   """Builds a `PartialVal` from a `ShapedArray` with the shape and dtype of `x`."""
   shape, dtype = onp.shape(x), onp.result_type(x)
   return pe.PartialVal((ShapedArray(shape, dtype), unit))
Esempio n. 22
0
def pv_like(x, abstract=True):
    """Build a `PartialVal` for `x`.

    A truthy `abstract` pairs the shaped aval of `x` with `jax_core.unit`;
    otherwise `None` is paired with `x` itself.
    """
    components = (get_shaped_aval(x), jax_core.unit) if abstract else (None, x)
    return pe.PartialVal(components)  # pytype: disable=wrong-arg-types