Ejemplo n.º 1
0
def _reap_cond_rule(trace, *tracers, branches, linear):
    """Reaps every branch of a `cond` and re-sows the collected values.

    Each branch jaxpr is wrapped with `call_and_reap` so that it additionally
    returns its reaped values. The wrapped branches are re-traced with a
    shared set of constants and bound as a new `cond`. The reaps of the taken
    branch are then sown again in the enclosing trace, using the sow mode
    recorded in the first branch's metadata (`_check_branch_metadata` is run
    on all branches first).
    """
    index_tracer = tracers[0]
    ops_tracers = tracers[1:]
    index_val = index_tracer.val
    ops_vals = tuple(t.val for t in ops_tracers)
    ops_avals = tuple(t.aval for t in ops_tracers)

    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    reap_kwargs = {
        'tag': settings.tag,
        'allowlist': settings.allowlist,
        'blocklist': settings.blocklist,
        'exclusive': settings.exclusive,
    }

    branch_metadatas = tuple(
        _get_harvest_metadata(branch_jaxpr, settings, *ops_tracers)
        for branch_jaxpr in branches)
    _check_branch_metadata(branch_metadatas)

    reaped_branches = tuple(
        call_and_reap(jax_core.jaxpr_as_fun(branch_jaxpr), **reap_kwargs)
        for branch_jaxpr in branches)

    in_tree = tree_util.tree_structure(ops_avals)
    new_jaxprs, common_consts, out_trees = (
        lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
            reaped_branches, in_tree, ops_avals, lax.cond_p.name))

    operands = tuple(common_consts) + ops_vals
    out_vals = lax.cond_p.bind(index_val,
                               *operands,
                               branches=tuple(new_jaxprs),
                               linear=(False, ) *
                               len(tuple(common_consts) + linear))

    out_tracers = jax_util.safe_map(trace.pure, out_vals)
    out, reaps = tree_util.tree_unflatten(out_trees[0], out_tracers)
    first_branch_meta = branch_metadatas[0]
    for reap_name, reap_value in reaps.items():
        sow(reap_value,
            name=reap_name,
            tag=settings.tag,
            mode=first_branch_meta[reap_name]['mode'])
    return out
Ejemplo n.º 2
0
 def process_call(self, call_primitive: jax_core.Primitive, f: Any,
                  tracers: List['HarvestTracer'], params: Dict[str, Any]):
     """Routes a call primitive to the active harvest context.

     `nest_p` is handled by the context's `process_nest`. Every other call
     primitive goes to `process_higher_order_primitive` as a non-map primitive.
     """
     ctx = trace_util.get_dynamic_context(self)
     if call_primitive is not nest_p:
         return ctx.process_higher_order_primitive(self, call_primitive, f,
                                                   tracers, params, False)
     return ctx.process_nest(self, f, *tracers, **params)
Ejemplo n.º 3
0
def _plant_cond_rule(trace, *tracers, branches, linear):
    """Plants the same values into every branch of a `cond`.

    Each branch jaxpr is wrapped with `plant` and partially applied to the
    active context's plants. The wrapped branches are re-traced with a shared
    set of constants, and a new `cond` is bound on the result.
    """
    index_tracer = tracers[0]
    ops_tracers = tracers[1:]
    index_val = index_tracer.val
    ops_vals = tuple(t.val for t in ops_tracers)
    ops_avals = tuple(t.aval for t in ops_tracers)

    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    plant_kwargs = {
        'tag': settings.tag,
        'allowlist': settings.allowlist,
        'blocklist': settings.blocklist,
        'exclusive': settings.exclusive,
    }

    branch_metadatas = tuple(
        _get_harvest_metadata(branch_jaxpr, settings, *ops_tracers)
        for branch_jaxpr in branches)
    _check_branch_metadata(branch_metadatas)

    plants = context.plants
    planted_branches = tuple(
        functools.partial(
            plant(jax_core.jaxpr_as_fun(branch_jaxpr), **plant_kwargs),
            plants) for branch_jaxpr in branches)

    in_tree = tree_util.tree_structure(ops_avals)
    new_jaxprs, common_consts, _ = (
        lcf._initial_style_jaxprs_with_common_consts(  # pylint: disable=protected-access
            planted_branches, in_tree, ops_avals, lax.cond_p.name))

    operands = tuple(common_consts) + ops_vals
    out_vals = lax.cond_p.bind(index_val,
                               *operands,
                               branches=tuple(new_jaxprs),
                               linear=(False, ) *
                               len(tuple(common_consts) + linear))
    return jax_util.safe_map(trace.pure, out_vals)
Ejemplo n.º 4
0
 def process_primitive(
     self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
     params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
   """Dispatches `primitive` to a context-specific rule when one exists.

   Falls back to `default_process_primitive` when the active context returns
   no (falsy) custom rule for `primitive`.
   """
   rule = trace_util.get_dynamic_context(self).get_custom_rule(primitive)
   if not rule:
     return self.default_process_primitive(primitive, tracers, params)
   return rule(self, *tracers, **params)
Ejemplo n.º 5
0
    def process_higher_order_primitive(self, primitive, f, tracers, params,
                                       is_map):
        """Harvests through a call or map primitive by re-binding it.

        The plants visible to this call are flattened and passed as extra
        leading arguments to the re-bound primitive, whose function is wrapped
        with `harvest_eval`. The primitive's outputs are unflattened into
        regular outputs and reaps. The regular outputs are returned as tracers.
        Every reap is sown again in this trace under the active tag with
        `mode='strict'`; for `nest_p`, all reaps are sown as a single value
        named after the call's scope.

        Args:
          primitive: the call or map primitive being processed.
          f: the wrapped function called by `primitive`.
          tracers: input tracers of the call.
          params: parameters of `primitive`. This dict is not modified.
          is_map: whether `primitive` is a map primitive. If so, its `in_axes`
            and `out_axes_thunk` params are extended to cover the plants and
            reaps.

        Returns:
          A list of output tracers.
        """
        # Work on a copy so that popping 'name' does not mutate the caller's
        # params dict.
        params = dict(params)
        name = params.pop('name', f.__name__)
        tracers = safe_map(self.instantiate_const, tracers)
        vals = [t.val for t in tracers]
        context = trace_util.get_dynamic_context(self)
        active_tag = context.settings.tag
        plants = context.plants
        if primitive is nest_p:
            # A `nest` call only receives the plants registered under its scope.
            plants = plants.get(params['scope'], {})
        if is_map:
            # TODO(sharadmv): figure out if invars are mapped or unmapped
            out_axes_thunk = params['out_axes_thunk']

            # `jax_util.as_hashable_function` takes its hashing key through the
            # `closure` keyword; there is no `key` keyword.
            @jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                assert all(out_axis == 0 for out_axis in out_axes)
                # `out_tree` is bound below by `harvest_eval`. This assumes the
                # thunk is only invoked after `f` has been traced.
                return (0, ) * out_tree().num_leaves

            new_params = dict(
                params,
                in_axes=(0, ) * len(tree_util.tree_leaves(plants)) +
                params['in_axes'],
                out_axes_thunk=new_out_axes_thunk)
        else:
            new_params = dict(params)
        all_args, all_tree = tree_util.tree_flatten((plants, vals))
        num_plants = len(all_args) - len(vals)
        if 'donated_invars' in params:
            # Planted values are prepended to the arguments and never donated.
            new_params['donated_invars'] = ((False, ) * num_plants +
                                            params['donated_invars'])
        f, out_tree = harvest_eval(f, self, context.settings, all_tree)
        out_flat = primitive.bind(f,
                                  *all_args,
                                  **new_params,
                                  name=jax_util.wrap_name(name, 'harvest'))
        out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
        out_tracers = safe_map(self.pure, out)
        reap_tracers = tree_util.tree_map(self.pure, reaps)
        if primitive is nest_p and reap_tracers:
            # Reaps from inside a `nest` are sown as one value named by scope.
            flat_tracers, tree = tree_util.tree_flatten(reap_tracers)
            self.handle_sow(*flat_tracers,
                            name=params['scope'],
                            tag=active_tag,
                            mode='strict',
                            tree=tree)
        else:
            for reap_name, reap_tracer in reap_tracers.items():
                flat_tracers, tree = tree_util.tree_flatten(reap_tracer)
                self.handle_sow(*flat_tracers,
                                name=reap_name,
                                tag=active_tag,
                                mode='strict',
                                tree=tree)
        return out_tracers
Ejemplo n.º 6
0
def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                     cond_nconsts, body_nconsts):
    """Reaps a `while_loop` body, keeping only the final iteration's values.

    Reaped values are threaded through the loop carry, which is initialised
    with zeros of the avals recorded in the body metadata. After the loop, the
    carried reaps are sown again in the enclosing trace. Only `clobber` sows
    are permitted inside the body.

    Raises:
      ValueError: if a sow inside the body uses a mode other than `clobber`.
    """
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    init_avals = [t.aval for t in init_tracers]
    cond_const_vals = [t.val for t in cond_const_tracers]
    body_const_vals = [t.val for t in body_const_tracers]
    init_vals = [t.val for t in init_tracers]

    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    for sow_name, sow_meta in body_metadata.items():
        if sow_meta['mode'] != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{sow_name}\' inside of a `while_loop`.'
            )
    reap_avals = {
        sow_name: sow_meta['aval']
        for sow_name, sow_meta in body_metadata.items()
    }

    cond_fun = jax_core.jaxpr_as_fun(cond_jaxpr)
    body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
    reap_kwargs = {
        'tag': settings.tag,
        'allowlist': settings.allowlist,
        'blocklist': settings.blocklist,
        'exclusive': settings.exclusive,
    }

    def new_cond(loop_carry, _):
        # The reaps carried in the second slot do not affect the predicate.
        return cond_fun(*(cond_const_vals + loop_carry))

    def new_body(loop_carry, _):
        body_out, body_reaps = call_and_reap(
            body_fun, **reap_kwargs)(*(body_const_vals + loop_carry))
        return (body_out, body_reaps)

    flat_in_avals, in_tree = tree_util.tree_flatten((init_avals, reap_avals))
    flat_in_avals = tuple(flat_in_avals)
    new_cond_jaxpr, cond_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_cond, in_tree, flat_in_avals)
    new_body_jaxpr, body_consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, in_tree, flat_in_avals)

    zero_reaps = {
        sow_name: jnp.zeros(aval.shape, aval.dtype)
        for sow_name, aval in reap_avals.items()
    }
    loop_inputs = tree_util.tree_leaves((init_vals, zero_reaps))
    out_vals = lax.while_p.bind(*(cond_consts + body_consts + loop_inputs),
                                cond_nconsts=len(cond_consts),
                                body_nconsts=len(body_consts),
                                cond_jaxpr=new_cond_jaxpr,
                                body_jaxpr=new_body_jaxpr)

    out_tracers = jax_util.safe_map(trace.pure, out_vals)
    out, reaps = tree_util.tree_unflatten(out_tree, out_tracers)
    for sow_name, reap_value in reaps.items():
        sow(reap_value,
            name=sow_name,
            tag=settings.tag,
            mode=body_metadata[sow_name]['mode'])
    return out
Ejemplo n.º 7
0
 def default_process_primitive(
     self, primitive: jax_core.Primitive, tracers: List['HarvestTracer'],
     params: Dict[str, Any]) -> Union['HarvestTracer', List['HarvestTracer']]:
   """Binds `primitive` on the unwrapped values and re-wraps the results.

   `sow_p` is special-cased: it is handed to the active context's
   `process_sow`, and that call's outputs are wrapped as a list of tracers.
   Any other primitive yields a list of tracers if it has multiple results,
   or a single tracer otherwise.
   """
   context = trace_util.get_dynamic_context(self)
   vals = [t.val for t in tracers]
   if primitive is sow_p:
     return jax_util.safe_map(self.pure, context.process_sow(*vals, **params))
   result = primitive.bind(*vals, **params)
   if primitive.multiple_results:
     return jax_util.safe_map(self.pure, result)
   return self.pure(result)
Ejemplo n.º 8
0
def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr,
                      cond_nconsts, body_nconsts):
    """Plants values into a `while_loop` body for all iterations.

    The body is re-traced with `plant` applied, so every iteration sees the
    same planted values. The condition jaxpr is reused unchanged. Only
    `clobber` sows are allowed inside the body.

    Raises:
      ValueError: if a sow inside the body uses a mode other than `clobber`.
    """
    cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list(
        tracers, [cond_nconsts, body_nconsts])
    init_avals = [t.aval for t in init_tracers]
    cond_const_vals = [t.val for t in cond_const_tracers]
    body_const_vals = [t.val for t in body_const_tracers]
    init_vals = [t.val for t in init_tracers]

    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    body_metadata = _get_harvest_metadata(body_jaxpr, settings,
                                          *(body_const_tracers + init_tracers))
    for sow_name, sow_meta in body_metadata.items():
        if sow_meta['mode'] != 'clobber':
            raise ValueError(
                f'Must use clobber mode for \'{sow_name}\' inside of a `while_loop`.'
            )

    body_fun = jax_core.jaxpr_as_fun(body_jaxpr)
    plant_kwargs = {
        'tag': settings.tag,
        'allowlist': settings.allowlist,
        'blocklist': settings.blocklist,
        'exclusive': settings.exclusive,
    }
    plants = context.plants

    def new_body(*loop_carry):
        planted_body = plant(body_fun, **plant_kwargs)
        return planted_body(plants, *body_const_vals, *loop_carry)

    init_tree = tree_util.tree_structure(init_avals)
    new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, init_tree, tuple(init_avals))
    loop_operands = cond_const_vals + new_body_consts + init_vals
    out_vals = lcf.while_p.bind(*loop_operands,
                                cond_nconsts=len(cond_const_vals),
                                body_nconsts=len(new_body_consts),
                                cond_jaxpr=cond_jaxpr,
                                body_jaxpr=new_body_jaxpr)
    return jax_util.safe_map(trace.pure, out_vals)
Ejemplo n.º 9
0
    def default_process_primitive(self, primitive, tracers, params):
        """Partially evaluates `primitive`, attaching recipes to its outputs.

        If every input is known, the primitive is bound directly on the known
        constants. Otherwise the inputs are instantiated, abstract outputs are
        built from `primitive.abstract_eval`, and every output tracer receives
        the resulting equation recipe. When all inputs are keys and the
        primitive is a `harvest.sow_p` carrying the active tag, a
        `VariableRecipe` is attached as well. After tracing, that recipe is
        used to construct the init/apply jaxprs.
        """
        pvs, consts = jax_util.unzip2(t.pval for t in tracers)
        if all(pv is None for pv in pvs):
            return primitive.bind(*consts, **params)

        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const, tracers)
        assert all(isinstance(t, UnzipTracer) for t in tracers)
        all_keys = all(t.is_key() for t in tracers)

        abstract_out = primitive.abstract_eval(*[t.aval for t in tracers],
                                               **params)
        out_avals = (abstract_out
                     if primitive.multiple_results else [abstract_out])
        out_tracers = [
            UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None,
                        all_keys) for aval in out_avals
        ]
        # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
        recipe = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                                   source_info_util.current())  # pytype: disable=wrong-arg-types
        for out_tracer in out_tracers:
            out_tracer.recipe = recipe

        # Unlike pe.JaxprTrace, a keyed sow of the active tag also records a
        # VariableRecipe, used after tracing to build the init/apply jaxprs.
        if (all_keys and primitive is harvest.sow_p
                and params['tag'] == settings.tag):
            var_name, var_in_tracers, var_out_tracers = unzip_registry[
                primitive](tracers, out_tracers, **params)
            variable_recipe = VariableRecipe(var_name, var_in_tracers,
                                             var_out_tracers)
            for out_tracer in out_tracers:
                out_tracer.variable_recipe = variable_recipe

        return out_tracers if primitive.multiple_results else out_tracers[0]
Ejemplo n.º 10
0
 def process_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                    params, is_map):
     """Plants values into a call/map primitive by re-binding it.

     The plants visible to the call are flattened and prepended to the
     arguments of the re-bound primitive, whose function is wrapped with
     `plant_eval`. For `nest_p`, only the plants registered under the call's
     `scope` are used. The resulting plant count is used consistently when
     extending `in_axes` and `donated_invars`.

     Args:
       trace: the active trace whose tracers are being processed.
       call_primitive: the call or map primitive being processed.
       f: the wrapped function called by `call_primitive`.
       tracers: input tracers of the call.
       params: parameters of `call_primitive`. This dict is not modified.
       is_map: unused; map-ness is detected from the presence of `in_axes`.

     Returns:
       A list of output tracers.
     """
     del is_map
     # Work on a copy so that popping 'name' does not mutate the caller's dict.
     params = dict(params)
     # NOTE(review): this plant path labels the call 'reap'; confirm whether
     # 'plant' was intended.
     name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
     context = trace_util.get_dynamic_context(trace)
     vals = [t.val for t in tracers]
     plants = context.plants
     if call_primitive is nest_p:
         # Scope first, so the plant count below matches what is passed in.
         plants = plants.get(params['scope'], {})
     num_plants = len(tree_util.tree_leaves(plants))
     if 'in_axes' in params:
         # TODO(b/199459308): figure out if invars are mapped or unmapped
         params['in_axes'] = (0, ) * num_plants + params['in_axes']
     if 'donated_invars' in params:
         # Planted values are prepended to the arguments and never donated.
         params['donated_invars'] = ((False, ) * num_plants +
                                     params['donated_invars'])
     all_vals, all_tree = tree_util.tree_flatten((plants, vals))
     f = plant_eval(f, trace, self.settings, all_tree)
     out_vals = call_primitive.bind(f, *all_vals, name=name, **params)
     return jax_util.safe_map(trace.pure, out_vals)
Ejemplo n.º 11
0
def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                     num_consts, num_carry, linear, unroll):
    """Injects values into a scan according to their sow mode.

    `clobber` plants are closed over by the new scan body, so every iteration
    receives the same planted value. `append` plants are passed to the scan as
    extra scanned-over inputs, so each iteration receives its own slice.

    Args:
      trace: the active `HarvestTrace`.
      *tracers: the scan operands, laid out as [consts | carry | xs].
      length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
        parameters of the original `scan_p` equation.

    Returns:
      The scan outputs, wrapped as tracers of `trace`.

    Raises:
      ValueError: if a sow inside the scan body uses `strict` mode.
    """

    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval,
                                               (carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Per-iteration slices of the scanned inputs, used to probe the body.
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    metadata = _get_harvest_metadata(
        jaxpr, settings, *(const_tracers + carry_tracers + x_tracers))

    plants = context.plants
    plant_modes = collections.defaultdict(set)
    plant_xs_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        plant_modes[mode].add(name)
        if mode == 'append' and name in plants:
            plant_xs_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)
    clobber_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['clobber']
    }
    append_plants = {
        name: value
        for name, value in plants.items() if name in plant_modes['append']
    }

    plant_xs_flat_avals, _ = tree_util.tree_flatten(plant_xs_avals)

    plant_xs_in_tree = tree_util.tree_structure(
        (carry_avals, (xs_avals, plant_xs_avals)))

    def new_body(carry, x):
        x, x_plants = x
        all_plants = {**x_plants, **clobber_plants}
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out = plant(body_fun,
                    tag=settings.tag,
                    allowlist=settings.allowlist,
                    blocklist=settings.blocklist,
                    exclusive=settings.exclusive)(all_plants, *all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        return carry_out, y

    new_body_jaxpr, consts, _ = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, plant_xs_in_tree,
        tuple(carry_avals + x_avals + plant_xs_flat_avals))
    plant_vals = tree_util.tree_leaves(append_plants)
    # `linear` needs exactly one flag per operand: [new consts | carry | xs |
    # append plants]. The new consts come from whatever `new_body` closes over
    # (e.g. `const_vals`, `clobber_plants`) and do not line up with the original
    # const flags, so they are conservatively marked non-linear.
    new_linear = ((False, ) * len(consts) + tuple(linear[num_consts:]) +
                  (False, ) * len(plant_vals))
    out = lcf.scan_p.bind(*(consts + carry_vals + xs_vals + plant_vals),
                          reverse=reverse,
                          length=length,
                          jaxpr=new_body_jaxpr,
                          num_consts=len(consts),
                          num_carry=num_carry,
                          linear=new_linear,
                          unroll=unroll)
    # Wrap as tracers of this trace, like the other harvest rules do.
    return jax_util.safe_map(trace.pure, out)
Ejemplo n.º 12
0
def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                    num_consts, num_carry, linear, unroll):
    """Reaps the body of a scan to pull out `clobber` and `append` sows.

    `clobber` reaps are threaded through the scan carry (initialised with
    zeros), so the last iteration's value survives. `append` reaps are emitted
    as extra scan outputs, so they are stacked across iterations. Both kinds
    are sown again in the enclosing trace after the scan.

    Args:
      trace: the active `HarvestTrace`.
      *tracers: the scan operands, laid out as [consts | carry | xs].
      length, reverse, jaxpr, num_consts, num_carry, linear, unroll: the
        parameters of the original `scan_p` equation.

    Returns:
      The original scan outputs (carry followed by ys) as tracers.

    Raises:
      ValueError: if a sow inside the scan body uses `strict` mode.
    """

    const_tracers, carry_tracers, xs_tracers = jax_util.split_list(
        tracers, [num_consts, num_carry])
    _, carry_avals, xs_avals = tree_util.tree_map(
        lambda x: x.aval, (const_tracers, carry_tracers, xs_tracers))
    const_vals, carry_vals, xs_vals = tree_util.tree_map(
        lambda x: x.val, (const_tracers, carry_tracers, xs_tracers))
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    # Per-iteration slices of the scanned inputs, used to probe the body.
    x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers]
    x_avals = [t.aval for t in x_tracers]
    x_vals = [t.val for t in x_tracers]
    metadata = _get_harvest_metadata(jaxpr, settings,
                                     *(const_vals + carry_vals + x_vals))

    reap_modes = collections.defaultdict(set)
    reap_carry_avals = {}
    for name, meta in metadata.items():
        mode = meta['mode']
        aval = meta['aval']
        if mode == 'strict':
            raise ValueError(
                f'Cannot use strict mode for \'{name}\' inside `scan`.')
        reap_modes[mode].add(name)
        if mode == 'clobber':
            reap_carry_avals[name] = aval
    body_fun = jax_core.jaxpr_as_fun(jaxpr)

    reap_carry_flat_avals, _ = tree_util.tree_flatten(reap_carry_avals)

    reap_carry_in_tree = tree_util.tree_structure(
        ((carry_avals, reap_carry_avals), xs_avals))

    def new_body(carry, x):
        # The incoming clobber reaps are only placeholders and are ignored.
        carry, _ = carry
        all_values = const_vals + tree_util.tree_leaves((carry, x))
        out, reaps = call_and_reap(body_fun,
                                   tag=settings.tag,
                                   allowlist=settings.allowlist,
                                   blocklist=settings.blocklist,
                                   exclusive=settings.exclusive)(*all_values)
        carry_out, y = jax_util.split_list(out, [num_carry])
        carry_reaps = {
            name: val
            for name, val in reaps.items() if name in reap_modes['clobber']
        }
        xs_reaps = {
            name: val
            for name, val in reaps.items() if name in reap_modes['append']
        }
        return (carry_out, carry_reaps), (y, xs_reaps)

    new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr(  # pylint: disable=protected-access
        new_body, reap_carry_in_tree,
        tuple(carry_avals + reap_carry_flat_avals + x_avals))
    dummy_reap_carry_vals = tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, x.dtype), reap_carry_flat_avals)
    # `linear` needs one flag per operand: [new consts | carry | dummy reap
    # carry | xs]. The new consts come from `_initial_style_jaxpr` and do not
    # line up with the original const flags, so they are marked non-linear.
    # The carry and xs flags are kept, with the dummy flags placed after the
    # carry.
    carry_linear = tuple(linear[num_consts:num_consts + num_carry])
    xs_linear = tuple(linear[num_consts + num_carry:])
    new_linear = ((False, ) * len(consts) + carry_linear +
                  (False, ) * len(dummy_reap_carry_vals) + xs_linear)
    out = lax.scan_p.bind(
        *(consts + carry_vals + dummy_reap_carry_vals + xs_vals),
        reverse=reverse,
        length=length,
        jaxpr=new_body_jaxpr,
        num_consts=len(consts),
        num_carry=len(carry_vals + dummy_reap_carry_vals),
        linear=new_linear,
        unroll=unroll)
    (carry_out,
     carry_reaps), (ys, ys_reaps) = tree_util.tree_unflatten(out_tree, out)
    (carry_out, carry_reaps), (ys, ys_reaps) = tree_util.tree_map(
        trace.pure, ((carry_out, carry_reaps), (ys, ys_reaps)))
    for k, v in {**carry_reaps, **ys_reaps}.items():
        sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k)
    return carry_out + ys
Ejemplo n.º 13
0
 def process_map(self, call_primitive: jax_core.Primitive, f: Any,
                 tracers: List['HarvestTracer'], params: Dict[str, Any]):
     """Delegates a map primitive to the active context's higher-order handler.

     Identical to the non-nest path of `process_call`, except the handler is
     told that the primitive is a map.
     """
     handler = trace_util.get_dynamic_context(
         self).process_higher_order_primitive
     return handler(self, call_primitive, f, tracers, params, True)
Ejemplo n.º 14
0
    def handle_call_primitive(self, call_primitive, f, tracers, params,
                              is_map):
        """Handler for call_primitives, like jit or layer_call.

        When an UnzipTracer hits a call primitive, either there is a variable
        inside the call primitive, in which case the input function needs to
        be unzipped into two, or there are no variables in the function, in
        which case the call_primitive is recorded in the trace as-is.

        We use `unzip_eval`, whose auxiliary output reports whether the unzip
        succeeded. If it did, two new Jaxprs are recorded into the trace (one
        for init, one for apply). Otherwise, only the Jaxpr corresponding to
        the function call is recorded.

        Args:
          call_primitive: a call primitive like xla_call
          f: a jax.linear_util wrapped function to be called
          tracers: inputs to the function
          params: parameters of the primitives
          is_map: whether or not the primitive is a map primitive (e.g. xla_pmap)

        Returns:
          A list of output tracers

        Raises:
          NotImplementedError: if `call_primitive` has an entry in
            `pe.call_partial_eval_rules`.
        """
        name = params.get('name', f.__name__)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const_abstracted, tracers)
        # A registered custom rule takes over handling of this primitive.
        if call_primitive in current_custom_rules():
            return current_custom_rules()[call_primitive](self, f, *tracers,
                                                          **params)
        # Primitives with their own partial-eval rule are not supported here.
        if call_primitive in pe.call_partial_eval_rules:
            raise NotImplementedError
        in_pvals = [t.pval for t in tracers]
        if is_map:
            # Unknown inputs with a mapped axis are given their `mapped_aval`.
            # NOTE(review): `in_pvals` is not read after this branch (`pvs`
            # below comes from `tracers`), so this adjustment currently has no
            # effect — confirm whether `pvs` was meant to use it.
            unknown = pe.PartialVal.unknown
            in_pvals = [
                pval if pval.is_known() or in_axis is None else unknown(
                    mapped_aval(params['axis_size'], in_axis, pval[0]))
                for pval, in_axis in zip(in_pvals, params['in_axes'])
            ]
            out_axes_thunk = params['out_axes_thunk']

            # Every output is mapped along axis 0. The output count comes from
            # `aux()`, which is bound below, so this thunk must only run after
            # `fun` has been traced.
            @jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                assert all(out_axis == 0 for out_axis in out_axes)
                _, num_outputs, _ = aux()
                return (0, ) * num_outputs

            new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        else:
            new_params = params
        pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
        keys = tuple(t.is_key() for t in tracers)
        # Second field: whether `call_primitive` is listed in `block_registry`.
        new_settings = UnzipSettings(settings.tag, call_primitive
                                     in block_registry)
        fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
        out_flat = call_primitive.bind(fun, *in_consts, **new_params)
        success, _, results = aux()
        if not success:
            # No unzip happened: record the call as a single equation.
            out_pvs, out_keys, jaxpr, env = results
            out_pv_consts, consts = jax_util.split_list(
                out_flat, [len(out_pvs)])
            out_tracers = self._bound_output_tracers(call_primitive,
                                                     new_params, jaxpr, consts,
                                                     env, tracers, out_pvs,
                                                     out_pv_consts, out_keys,
                                                     name, is_map)
            return out_tracers
        # Unzip succeeded: `results` holds (pvs/counts, jaxprs, envs,
        # variable info) for the init and apply halves.
        init_name = jax_util.wrap_name(name, 'init')
        apply_name = jax_util.wrap_name(name, 'apply')
        init_pvs, num_init_consts, apply_pvs = results[0]
        init_jaxpr, apply_jaxpr = results[1]
        init_env, apply_env = results[2]
        variable_names, variable_tree, apply_keys = results[3]

        key_tracers = [t for t in tracers if t.is_key()]
        abstract_tracers = [t for t in tracers if not t.is_key()]
        # `out_flat` is [init pv consts | init consts | apply pv consts |
        # apply consts].
        all_init_consts, all_apply_consts = jax_util.split_list(
            out_flat, [len(init_pvs) + num_init_consts])
        init_pv_consts, init_consts = jax_util.split_list(
            all_init_consts, [len(init_pvs)])
        apply_pv_consts, apply_consts = jax_util.split_list(
            all_apply_consts, [len(apply_pvs)])

        # The init half is fed only by key tracers, and all its outputs are
        # marked as keys.
        variable_tracers = self._bound_output_tracers(
            call_primitive, new_params, init_jaxpr, init_consts, init_env,
            key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
            init_name, is_map)

        unflat_variables = tree_util.tree_unflatten(variable_tree,
                                                    variable_tracers)
        # Sow the created variables in strict mode: as one dict named by
        # scope for `nest_p`, otherwise one sow per variable name.
        if call_primitive is harvest.nest_p:
            variable_dict = harvest.sow(dict(
                safe_zip(variable_names, unflat_variables)),
                                        tag=settings.tag,
                                        name=new_params['scope'],
                                        mode='strict')
            unflat_variables = tuple(variable_dict[name]
                                     for name in variable_names)
        else:
            unflat_variables = [
                harvest.sow(  # pylint: disable=g-complex-comprehension
                    unflat_variable,
                    tag=settings.tag,
                    name=name,
                    mode='strict') for unflat_variable, name in safe_zip(
                        unflat_variables, variable_names)
            ]
        variable_tracers = tree_util.tree_leaves(unflat_variables)

        # The apply half takes the (sown) variables followed by the non-key
        # inputs.
        out_tracers = self._bound_output_tracers(
            call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
            variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
            apply_keys, apply_name, is_map)
        return out_tracers
Ejemplo n.º 15
0
def _scan_harvest_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr,
                       num_consts, num_carry, linear, unroll):
    """Collects and injects values into/from the scan body.

    `clobber` plants are closed over by the body, and `append` plants are
    scanned over as extra xs. The body is run under `harvest`, and its reaps
    are emitted as extra scan outputs. After the scan, `append` reaps are
    concatenated, `clobber` reaps keep their last element, and every reap is
    sown again in strict mode.

    Raises:
      ValueError: if any sow of the active tag inside the body uses strict
        mode.
    """
    context = trace_util.get_dynamic_context(trace)
    settings = context.settings
    values = [t.val for t in tracers]
    consts, init, xs = jax_util.split_list(values, [num_consts, num_carry])

    active_sows = _find_sows(jaxpr, settings.tag)
    active_modes = [params['mode'] for params in active_sows]
    if any(mode == 'strict' for mode in active_modes):
        raise ValueError('Cannot use strict mode in a scan.')
    active_names = [params['name'] for params in active_sows]
    sow_modes = {name: mode for name, mode in zip(active_names, active_modes)}
    # Clobber plants: same value every iteration (closed over by `body`).
    carry_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'clobber'
    }
    # Append plants: passed as scanned inputs, one slice per iteration.
    xs_plants = {
        name: context.plants[name]
        for name in active_names
        if name in context.plants and sow_modes[name] == 'append'
    }

    def jaxpr_fun(carry, x):
        body_out = jax_core.eval_jaxpr(jaxpr.jaxpr, jaxpr.literals,
                                       *(consts + carry + x))
        carry, y = jax_util.split_list(body_out, [num_carry])
        return carry, y

    harvest_body = harvest(jaxpr_fun,
                           tag=settings.tag,
                           allowlist=settings.allowlist,
                           blocklist=settings.blocklist)

    def body(carry, x):
        x_plants, x_vals = x
        (carry, y), reaps = harvest_body({
            **carry_plants,
            **x_plants
        }, carry, x_vals)
        return carry, (y, reaps)

    # Per-iteration avals of the scanned inputs (append plants first, then xs):
    # the leading (scan) dimension is dropped from each shape.
    xs_flat = tree_util.tree_leaves((xs_plants, xs))
    x_avals = []
    for x in xs_flat:
        x_aval = jax_core.get_aval(x)
        if x_aval is jax_core.abstract_unit:
            x_avals.append(x_aval)
        else:
            x_shape, x_dtype = masking.padded_shape_as_value(
                x.shape[1:]), x.dtype
            x_avals.append(abstract_arrays.ShapedArray(x_shape, x_dtype))
    x_avals = tuple(x_avals)
    init_avals = tuple(
        abstract_arrays.raise_to_shaped(jax_core.get_aval(a)) for a in init)
    in_flat, in_tree = tree_util.tree_flatten((init, (xs_plants, xs)))
    body_jaxpr, new_consts, out_tree = (
        jax.lax.lax_control_flow._initial_style_jaxpr(  # pylint: disable=protected-access
            body, in_tree, init_avals + x_avals))
    new_values = list(new_consts) + in_flat
    num_xs_plants = len(new_values) - len(init) - len(xs) - len(new_consts)
    # One `linear` flag per operand: [new consts | carry | xs plants | xs];
    # the added operands are marked non-linear.
    remaining_linear = linear[num_consts:]
    new_linear = ((False, ) * len(new_consts) + remaining_linear[:len(init)] +
                  (False, ) * num_xs_plants + remaining_linear[len(init):])
    assert len(new_linear) == len(new_values)

    outs = lax.scan_p.bind(*new_values,
                           length=length,
                           reverse=reverse,
                           jaxpr=body_jaxpr,
                           num_consts=len(new_consts),
                           num_carry=num_carry,
                           linear=new_linear,
                           unroll=unroll)
    outs = safe_map(trace.pure, outs)
    carry, (ys, reaps) = tree_util.tree_unflatten(out_tree, outs)
    out_reaps = {}
    for k, val in reaps.items():
        mode = sow_modes.get(k, 'strict')
        if mode == 'append':
            # NOTE(review): `val` holds tracers produced by `trace.pure`;
            # `np.concatenate` likely cannot convert tracers to arrays —
            # confirm, or consider `jnp.concatenate`.
            val = tree_util.tree_map(np.concatenate, val)
        elif mode == 'clobber':
            # Keep only the value reaped on the last iteration.
            val = tree_util.tree_map(lambda x: x[-1], val)
        out_reaps[k] = sow(val, tag=settings.tag, name=k, mode='strict')
    # Ties the sown reaps to the outputs (`prim.tie_in`); presumably this keeps
    # the sows from being dropped as unused — verify.
    (carry, ys) = prim.tie_in(out_reaps, (carry, ys))
    return carry + ys
Ejemplo n.º 16
0
 def handle_sow(self, *tracers, name, tag, mode, tree):
     """Forwards a sow to the active context and re-wraps its outputs.

     The tracers are unwrapped to their values and passed as a single list,
     together with the sow's metadata, to the context's `handle_sow`. Each
     returned value is wrapped back into a tracer of this trace.
     """
     raw_values = [tracer.val for tracer in tracers]
     ctx = trace_util.get_dynamic_context(self)
     results = ctx.handle_sow(raw_values,
                              name=name,
                              tag=tag,
                              mode=mode,
                              tree=tree)
     return safe_map(self.pure, results)