# Example #1 (0)
def _while_callback_rule(trace, *tracers, cond_jaxpr, body_jaxpr,
                         cond_nconsts, body_nconsts):
  """Callback rule for `while_p`.

  Unwraps the loop's constants and initial carry from `tracers`, re-traces the
  original cond and body jaxprs with `callback_transform` applied (using the
  trace's `callback` and `strip_calls`), binds a fresh `while_p` on the raw
  values, and re-wraps the results with `trace.pure`.
  """
  cond_consts_t, body_consts_t, carry_t = split_list(
      tracers, [cond_nconsts, body_nconsts])
  carry_avals = safe_map(lambda tracer: tracer.aval, carry_t)
  cond_consts, body_consts, carry_init = tree_map(
      lambda tracer: tracer.val, (cond_consts_t, body_consts_t, carry_t))

  run_cond = jaxpr_as_fun(cond_jaxpr)
  run_body = jaxpr_as_fun(body_jaxpr)

  def traced_cond(*carry):
    # The original cond jaxpr takes its constants ahead of the carry.
    return run_cond(*it.chain(cond_consts, carry))

  def traced_body(*carry):
    # Likewise, the body jaxpr takes its constants ahead of the carry.
    return run_body(*it.chain(body_consts, carry))

  wrapped_cond = callback_transform(traced_cond, trace.callback,
                                    strip_calls=trace.strip_calls)  # type: ignore
  wrapped_body = callback_transform(traced_body, trace.callback,
                                    strip_calls=trace.strip_calls)  # type: ignore
  carry_tree = tree_structure(carry_avals)
  carry_avals_tuple = tuple(carry_avals)

  cond_closed, cond_extra, _ = lcf._initial_style_jaxpr(
      wrapped_cond, carry_tree, carry_avals_tuple)
  body_closed, body_extra, _ = lcf._initial_style_jaxpr(
      wrapped_body, carry_tree, carry_avals_tuple)
  results = lcf.while_p.bind(
      *cond_extra, *body_extra, *carry_init,
      cond_nconsts=len(cond_extra),
      body_nconsts=len(body_extra),
      cond_jaxpr=cond_closed,
      body_jaxpr=body_closed)
  return safe_map(trace.pure, results)
# Example #2 (0)
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
    """Reaps the body of a while loop to get the reaps of the final iteration.

    Every sow found in the loop body must use `clobber` mode. The reaped
    values are threaded through the loop carry, starting from zeros, so the
    values written by the last executed iteration survive the loop. They are
    then re-sown outside the loop under the active tag.

    Returns:
      The loop's final carry, wrapped as tracers of `trace`.

    Raises:
      ValueError: if any sow in the body uses a mode other than `clobber`.
    """
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    init_avals = [tracer.aval for tracer in init_tracers]
    cond_const_vals = [tracer.val for tracer in cond_const_tracers]
    body_const_vals = [tracer.val for tracer in body_const_tracers]
    init_vals = [tracer.val for tracer in init_tracers]

    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    for name, meta in body_metadata.items():
        if meta['mode'] != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{name}\' inside of a `while_loop`.')
    reap_avals = {name: meta['aval'] for name, meta in body_metadata.items()}

    run_cond = jax_core.jaxpr_as_fun(cond_jaxpr)
    run_body = jax_core.jaxpr_as_fun(body_jaxpr)
    reap_settings = {
        'tag': settings.tag,
        'allowlist': settings.allowlist,
        'blocklist': settings.blocklist,
        'exclusive': settings.exclusive,
    }

    def reaping_cond(carry, _):
        # The reap slots in the carry are ignored by the condition.
        return run_cond(*(cond_const_vals + carry))

    def reaping_body(carry, _):
        # The previous iteration's reaps are discarded; only fresh ones are kept.
        reaper = call_and_reap(run_body, **reap_settings)
        next_carry, fresh_reaps = reaper(*(body_const_vals + carry))
        return next_carry, fresh_reaps

    flat_in_avals, loop_tree = tree_util.tree_flatten((init_avals, reap_avals))
    cond_closed, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        reaping_cond, loop_tree, tuple(flat_in_avals))
    body_closed, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        reaping_body, loop_tree, tuple(flat_in_avals))

    # The reap slots need concrete initial values; zeros of the right aval.
    zero_reaps = tree_util.tree_map(lambda aval: jnp.zeros(aval.shape, aval.dtype),
                                    reap_avals)
    loop_init = tree_util.tree_leaves((init_vals, zero_reaps))
    raw_out = lax.while_p.bind(*(cond_consts + body_consts + loop_init),
                               cond_nconsts=len(cond_consts),
                               body_nconsts=len(body_consts),
                               cond_jaxpr=cond_closed,
                               body_jaxpr=body_closed)
    wrapped_out = jax_util.safe_map(trace.pure, raw_out)
    final_carry, final_reaps = tree_util.tree_unflatten(out_tree, wrapped_out)
    for name, value in final_reaps.items():
        sow(value, name=name, tag=settings.tag,
            mode=body_metadata[name]['mode'])
    return final_carry
# Example #3 (0)
def _scan_callback_rule(trace, *tracers, reverse, length, num_consts, num_carry,
                        jaxpr, linear, unroll):
  """Callback rule for `scan_p`.

  Unwraps the scan's constants, initial carry and scanned inputs from
  `tracers`. It re-traces the body jaxpr with `callback_transform` applied,
  binds a fresh `scan_p` on the raw values, and re-wraps the outputs with
  `trace.pure`.

  The original scan constants are closed over by the re-traced body, so they
  reappear in `new_consts`. The number of new constants need not equal
  `num_consts`, so the `linear` flags are rebuilt for the new argument layout
  (new consts + carry + xs).
  """
  const_tracers, carry_tracers, xs_tracers = split_list(tracers, [num_consts, num_carry])
  carry_avals, xs_avals = tree_map(lambda x: x.aval, (carry_tracers, xs_tracers))
  const_vals, carry_vals, xs_vals = tree_map(lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))

  # Per-iteration avals: index off the leading (scanned) axis of each xs.
  x_tracers = [t[0] for t in xs_tracers]
  x_avals = [t.aval for t in x_tracers]

  body_fun = jaxpr_as_fun(jaxpr)

  def new_body(*vals):
    # The scan body jaxpr takes its constants ahead of carry and x. Without
    # them, every argument would be shifted and the arity would not match.
    out = body_fun(*it.chain(const_vals, vals))
    out_carry, y = split_list(out, [num_carry])
    return out_carry, y
  new_body = callback_transform(new_body, trace.callback,
                                strip_calls=trace.strip_calls)  # type: ignore
  in_tree = tree_structure(carry_avals + x_avals)
  new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
      new_body, in_tree, tuple(carry_avals + x_avals))
  # Constants are never linear; the carry/xs flags keep their positions.
  new_linear = (False,) * len(new_consts) + tuple(linear[num_consts:])
  vals = tuple(it.chain(new_consts, carry_vals, xs_vals))
  out_vals = lax.scan_p.bind(*vals, reverse=reverse, length=length,
                             num_consts=len(new_consts), num_carry=num_carry,
                             jaxpr=new_jaxpr, linear=new_linear, unroll=unroll)
  return safe_map(trace.pure, out_vals)
# Example #4 (0)
# File: loops.py  Project: x1489/jax
 def build_output_vals(self, scope, carried_state_names, carried_tree,
                       init_vals, body_closed_jaxpr, body_const_vals):
     """Bind a two-branch `cond_p` choosing between pass-through and the body.

     Branch 0 is a freshly traced pass-through that returns the carried state
     unchanged. Branch 1 is `body_closed_jaxpr`. `self.index` is the branch
     selector. Both branches receive the body constants followed by the
     flattened carried state.
     """
     carried_state = tree_util.tree_unflatten(carried_tree, init_vals)
     flat_args, args_tree = tree_util.tree_flatten(
         (body_const_vals, carried_state))
     arg_avals = safe_map(_BodyTracer.abstractify, flat_args)

     def _pass_through(*args):
         # args == (body_const_vals, carried_state); drop the constants.
         return args[1]

     identity_jaxpr, identity_consts, _ = (
         lax_control_flow._initial_style_jaxpr(_pass_through,
                                               args_tree, tuple(arg_avals)))
     # The pass-through closes over nothing, so it must have no constants.
     assert len(identity_consts) == 0
     operands = [*body_const_vals, *init_vals]
     return lax_control_flow.cond_p.bind(
         self.index,
         *operands,
         branches=(identity_jaxpr, body_closed_jaxpr),
         linear=(False, ) * len(operands))
# Example #5 (0)
# File: loops.py  Project: wayfeng/jax
    def build_output_vals(self, scope, carried_state_names, carried_tree,
                          init_vals, body_closed_jaxpr, body_const_vals):
        """Trace `self.cond_func` and bind a `while_p` around the loop body.

        Args:
          scope: the loop scope whose `_mutable_state` holds the carried values.
          carried_state_names: names of the `scope` fields carried by the loop.
          carried_tree: pytree structure of the carried state.
          init_vals: flat initial values of the carried state.
          body_closed_jaxpr: the already traced loop body.
          body_const_vals: constants consumed by `body_closed_jaxpr`.

        Returns:
          The outputs of the bound `while_p`.

        Raises:
          ValueError: if the conditional function modifies a carried field.
          TypeError: if the conditional function does not return a single
            boolean scalar.
        """
        # Trace the conditional function. cond_func takes 0 arguments, but
        # for lax.while we need a conditional function that takes the
        # carried_state_names. _initial_style_jaxpr will start its own trace and
        # will create tracers for all the carried state. We must put these values
        # in the scope._mutable_state before we trace the conditional
        # function.
        def cond_func_wrapped(*args):
            assert len(args) == len(carried_state_names)
            for ms, init_ms in zip(carried_state_names, args):
                scope._mutable_state[ms] = init_ms
            res = self.cond_func()
            # Conditional function is not allowed to modify the scope state
            for ms, init_ms in zip(carried_state_names, args):
                if not (scope._mutable_state[ms] is init_ms):
                    raise ValueError(
                        f"Conditional function modifies scope.{ms} field.")
            return res

        init_avals = safe_map(_BodyTracer.abstractify, init_vals)
        cond_jaxpr, cond_consts, cond_tree = (
            lax_control_flow._initial_style_jaxpr(cond_func_wrapped,
                                                  carried_tree,
                                                  tuple(init_avals)))
        # TODO: share these checks with lax_control_flow.while
        if not tree_util.treedef_is_leaf(cond_tree):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got pytree {cond_tree}."
            )
        # safe_map returns a list of per-output booleans. A non-empty list is
        # always truthy, so the result must be reduced with all() for this
        # check to ever fire.
        if not all(safe_map(core.typecompat, cond_jaxpr.out_avals,
                            [core.ShapedArray((), np.bool_)])):
            raise TypeError(
                f"cond_fun must return a boolean scalar, but got output type(s) "
                f"{cond_jaxpr.out_avals}.")

        return lax_control_flow.while_p.bind(*cond_consts,
                                             *body_const_vals,
                                             *init_vals,
                                             cond_nconsts=len(cond_consts),
                                             cond_jaxpr=cond_jaxpr,
                                             body_nconsts=len(body_const_vals),
                                             body_jaxpr=body_closed_jaxpr)
# Example #6 (0)
def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                      cond_nconsts, body_nconsts):
    """Injects values into a while loop, overriding values for all iterations.

    Only the loop body is re-traced with the context's plants applied. The
    condition jaxpr and its constants are forwarded to the new `while_p`
    unchanged. Every sow in the body must use `clobber` mode.

    Returns:
      The loop's final carry, wrapped as tracers of `trace`.

    Raises:
      ValueError: if any sow in the body uses a mode other than `clobber`.
    """
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    init_avals = [tracer.aval for tracer in init_tracers]
    cond_const_vals = [tracer.val for tracer in cond_const_tracers]
    body_const_vals = [tracer.val for tracer in body_const_tracers]
    init_vals = [tracer.val for tracer in init_tracers]

    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    for name, meta in body_metadata.items():
        if meta['mode'] != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{name}\' inside of a `while_loop`.')

    run_body = jax_core.jaxpr_as_fun(body_jaxpr)
    plant_settings = {
        'tag': settings.tag,
        'allowlist': settings.allowlist,
        'blocklist': settings.blocklist,
        'exclusive': settings.exclusive,
    }
    plants = context.plants
    body_consts = tuple(body_const_vals)

    def planted_body(*carry):
        # The body jaxpr takes its constants ahead of the carry.
        planted = plant(run_body, **plant_settings)
        return planted(plants, *(body_consts + carry))

    carry_tree = tree_util.tree_structure(init_avals)
    body_closed, body_consts_new, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        planted_body, carry_tree, tuple(init_avals))
    raw_out = lcf.while_p.bind(*(cond_const_vals + body_consts_new + init_vals),
                               cond_nconsts=len(cond_const_vals),
                               body_nconsts=len(body_consts_new),
                               cond_jaxpr=cond_jaxpr,
                               body_jaxpr=body_closed)
    return jax_util.safe_map(trace.pure, raw_out)
# Example #7 (0)
def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                     num_consts, num_carry, linear, unroll):
    """Injects values into a scan according to their sow mode.

    `clobber` plants are closed over by the re-traced body, so the same value
    is used on every iteration. `append` plants are fed to the scan as extra
    scanned inputs, so each iteration sees one slice of the planted value.
    `strict` sows are rejected inside a scan.

    Returns:
      The scan outputs (final carry followed by ys), wrapped as tracers of
      `trace`.

    Raises:
      ValueError: if the scan body contains a `strict` sow.
    """

    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval,
                                               (carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Per-iteration view of each scanned input (index off the scanned axis).
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    metadata = _get_harvest_metadata(
        jaxpr, settings, *(const_tracers + carry_tracers + x_tracers))

    plants = context.plants
    plant_modes = collections.defaultdict(set)
    plant_xs_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        plant_modes[mode].add(name)
        if mode == 'append' and name in plants:
            plant_xs_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)
    clobber_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['clobber']
    }
    append_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['append']
    }

    plant_xs_flat_avals, _ = tree_util.tree_flatten(plant_xs_avals)

    # Body input layout: (carry, (xs, append plants)).
    plant_xs_in_tree = tree_util.tree_structure(
        (carry_avals, (xs_avals, plant_xs_avals)))

    def new_body(carry, x):
        x, plants = x
        all_plants = {**plants, **clobber_plants}
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out = plant(body_fun,
                    tag=settings.tag,
                    allowlist=settings.allowlist,
                    blocklist=settings.blocklist,
                    exclusive=settings.exclusive)(all_plants, *all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        return carry_out, y

    new_body_jaxpr, consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, plant_xs_in_tree,
        tuple(carry_avals + x_avals + plant_xs_flat_avals))
    # Same key set as plant_xs_avals; dict flattening sorts keys, so the leaf
    # order matches plant_xs_flat_avals.
    plant_vals = tree_util.tree_leaves(append_plants)
    # The re-traced body closes over const_vals and the clobber plants, so its
    # constant count need not equal num_consts. Rebuild the linear flags for
    # the bound layout: new consts + carry + xs + append plants.
    new_linear = ((False, ) * len(consts) + tuple(linear[num_consts:]) +
                  (False, ) * len(plant_vals))
    out = lcf.scan_p.bind(*(consts + carry_vals + xs_vals + plant_vals),
                          reverse=reverse,
                          length=length,
                          jaxpr=new_body_jaxpr,
                          num_consts=len(consts),
                          num_carry=num_carry,
                          linear=new_linear,
                          unroll=unroll)
    # Re-wrap the raw outputs so they stay in this trace, as the other
    # harvest rules do.
    return jax_util.safe_map(trace.pure, out)
# Example #8 (0)
def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                    num_consts, num_carry, linear, unroll):
    """Reaps the body of a scan to pull out `clobber` and `append` sows.

    `clobber` reaps are threaded through the scan carry, initialized with
    zeros, so the value left after the last executed iteration is reaped.
    `append` reaps are emitted as extra per-iteration outputs that the scan
    stacks. All reaps are re-sown outside the scan with their original mode.

    Returns:
      The final carry followed by the ys, as tracers of `trace`.

    Raises:
      ValueError: if the scan body contains a `strict` sow.
    """

    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    carry_avals, xs_avals = tree_util.tree_map(
        lambda x: x.aval, (carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Per-iteration view of each scanned input (index off the scanned axis).
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    x_vals = [t.val for t in x_tracers]
    metadata = _get_harvest_metadata(jaxpr, settings,
                                     *(const_vals + carry_vals + x_vals))

    reap_modes = collections.defaultdict(set)
    reap_carry_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        reap_modes[mode].add(name)
        if mode == 'clobber':
            reap_carry_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)

    reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals)

    # Body input layout: ((carry, clobber reaps), xs).
    reap_carry_in_tree = tree_util.tree_structure(
        ((carry_avals, reap_carry_avals), xs_avals))

    def new_body(carry, x):
        # The previous iteration's clobber reaps are discarded.
        carry, _ = carry
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out, reaps = call_and_reap(body_fun,
                                   tag=settings.tag,
                                   allowlist=settings.allowlist,
                                   blocklist=settings.blocklist,
                                   exclusive=settings.exclusive)(*all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        carry_reaps = {
            name: val
            for name, val in reaps.items() if name in reap_modes['clobber']
        }
        xs_reaps = {
            name: val
            for name, val in reaps.items() if name in reap_modes['append']
        }
        return (carry_out, carry_reaps), (y, xs_reaps)

    new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, reap_carry_in_tree,
        tuple(carry_avals + reap_carry_flat_avals + x_avals))
    dummy_reap_carry_vals = tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals)
    # The bound arguments are new consts + carry + dummy reap carry + xs. The
    # linear flags must follow that layout. The dummy flags go after the
    # original carry (not after the consts), and the original flags are
    # sliced by num_consts, since the re-traced body may have a different
    # number of constants.
    carry_end = num_consts + num_carry
    new_linear = ((False, ) * len(consts) +
                  tuple(linear[num_consts:carry_end]) +
                  (False, ) * len(dummy_reap_carry_vals) +
                  tuple(linear[carry_end:]))
    out = lax.scan_p.bind(
        *(consts + carry_vals + dummy_reap_carry_vals + xs_vals),
        reverse=reverse,
        length=length,
        jaxpr=new_body_jaxpr,
        num_consts=len(consts),
        num_carry=len(carry_vals + dummy_reap_carry_vals),
        linear=new_linear,
        unroll=unroll)
    (carry_out,
     carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(out_tree, out)
    (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map(
        trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps)))
    for k, v in {**carry_reaps, **ys_reaps}.items():
        sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k)
    return carry_out + ys
# Example #9 (0)
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body.

    `clobber` plants are closed over by the body, so the same value is used
    on every iteration. `append` plants are passed as extra scanned inputs.
    Reaps from every iteration come back as stacked scan outputs and are then
    collapsed:
      * `append` reaps are passed through `np.concatenate`;
      * `clobber` reaps keep the value written by the last executed
        iteration.
    The collapsed values are re-sown outside the scan in `strict` mode.

    Returns:
      The final carry followed by the ys, as tracers of `trace`.

    Raises:
      ValueError: if any active sow in the body uses `strict` mode.
    """
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    def jaxpr_fun(carry, x):
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist,
                           mode=settings.mode)

    def body(carry, x):
        # Each scanned element is (per-iteration append plants, x values).
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            # Drop the leading (scanned) axis to get the per-iteration aval.
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    # Bound argument layout: new consts + init + xs-plant leaves + xs.
    new_values = list(new_consts) + in_flat
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    remaining_linear = linear[num_consts:]
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    # With reverse=True the scan runs from index length-1 down to 0 while
    # still stacking outputs in input order. The last *executed* iteration
    # therefore sits at index 0, not -1.
    last_index = 0 if reverse else -1
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        if mode == 'append':
            # NOTE(review): np.concatenate merges the stacked axis with the
            # next one; this presumes each per-iteration append reap carries
            # its own leading axis — confirm.
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            val = tree_util.tree_map(lambda x: x[last_index], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    # Tie the carry/ys to the re-sown reaps.
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys