Exemplo n.º 1
0
def _nest_translation_rule(*args, backend, name, call_jaxpr, scope, **_):
    """XLA translation rule for `nest`, delegating to the xla_call rule.

    The call name is passed through `jax_util.wrap_name` with a
    `nest[<scope>]` tag, and every input is marked as not donated. Any
    additional keyword parameters are accepted and ignored.
    """
    nest_name = jax_util.wrap_name(name, f'nest[{scope}]')
    no_donation = tuple(False for _arg in args)
    rule = xla._xla_call_translation_rule  # pylint: disable=protected-access
    return rule(*args,
                name=nest_name,
                backend=backend,
                call_jaxpr=call_jaxpr,
                donated_invars=no_donation)
Exemplo n.º 2
0
def _nest_lowering(ctx, *args, name, call_jaxpr, scope, **_):
    """MLIR lowering for `nest`, delegating to the generic xla_call lowering.

    The call name is passed through `jax_util.wrap_name` with a
    `nest[<scope>]` tag, and no input is marked as donated. Any additional
    keyword parameters are accepted and ignored.
    """
    nest_name = jax_util.wrap_name(name, f'nest[{scope}]')
    no_donation = tuple(False for _arg in args)
    lower = mlir._xla_call_lower  # pylint: disable=protected-access
    return lower(ctx, *args,
                 name=nest_name,
                 call_jaxpr=call_jaxpr,
                 donated_invars=no_donation)
Exemplo n.º 3
0
    def process_higher_order_primitive(self, primitive, f, tracers, params,
                                       is_map):
        """Binds a call/map primitive under harvest, threading plants and reaps.

        The wrapped function `f` is transformed with `harvest_eval` so that the
        current context's plants are passed as extra leading inputs. The
        primitive's flat outputs are unflattened into `(out, reaps)`. The reaps
        are then re-sown at this level with `handle_sow` in 'strict' mode. For
        `nest_p` they are sown as one tree under the nest's scope; otherwise
        they are sown one entry per reap name.

        Args:
          primitive: the call or map primitive being bound.
          f: the wrapped function the primitive calls.
          tracers: input tracers to the primitive.
          params: parameters of the primitive. This dict is not mutated.
          is_map: whether `primitive` is a map primitive (i.e. its params carry
            `in_axes` and `out_axes_thunk`).

        Returns:
          A list of output tracers for the non-reaped outputs of `f`.
        """
        # Copy before popping so the caller's params dict is left untouched.
        params = dict(params)
        name = params.pop('name', f.__name__)
        tracers = safe_map(self.instantiate_const, tracers)
        vals = [t.val for t in tracers]
        context = trace_util.get_dynamic_context(self)
        active_tag = context.settings.tag
        plants = context.plants
        if primitive is nest_p:
            # A nest only sees the plants registered under its own scope.
            plants = plants.get(params['scope'], {})
        if is_map:
            # TODO(sharadmv): figure out if invars are mapped or unmapped
            out_axes_thunk = params['out_axes_thunk']

            # `as_hashable_function` takes the hashing key as `closure`.
            @jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                assert all(out_axis == 0 for out_axis in out_axes)
                # `out_tree` is assigned below by `harvest_eval`. The closure
                # resolves it at call time, which assumes the primitive only
                # calls this thunk after `f` has been traced.
                return (0, ) * out_tree().num_leaves

            new_params = dict(
                params,
                in_axes=(0, ) * len(tree_util.tree_leaves(plants)) +
                params['in_axes'],
                out_axes_thunk=new_out_axes_thunk)
        else:
            new_params = dict(params)
        all_args, all_tree = tree_util.tree_flatten((plants, vals))
        num_plants = len(all_args) - len(vals)
        if 'donated_invars' in params:
            # Plant inputs are never donated.
            new_params['donated_invars'] = ((False, ) * num_plants +
                                            params['donated_invars'])
        f, out_tree = harvest_eval(f, self, context.settings, all_tree)
        out_flat = primitive.bind(f,
                                  *all_args,
                                  **new_params,
                                  name=jax_util.wrap_name(name, 'harvest'))
        out, reaps = tree_util.tree_unflatten(out_tree(), out_flat)
        out_tracers = safe_map(self.pure, out)
        reap_tracers = tree_util.tree_map(self.pure, reaps)
        if primitive is nest_p and reap_tracers:
            # Re-sow everything reaped inside the nest as one tree under its
            # scope.
            flat_tracers, tree = tree_util.tree_flatten(reap_tracers)
            self.handle_sow(*flat_tracers,
                            name=params['scope'],
                            tag=active_tag,
                            mode='strict',
                            tree=tree)
        else:
            for reap_name, reap_tracer in reap_tracers.items():
                flat_tracers, tree = tree_util.tree_flatten(reap_tracer)
                self.handle_sow(*flat_tracers,
                                name=reap_name,
                                tag=active_tag,
                                mode='strict',
                                tree=tree)
        return out_tracers
Exemplo n.º 4
0
 def process_call(self, call_primitive, f, tracers, params):
   """Binds `call_primitive` on the heads and flattened tails of `tracers`.

   The function is wrapped with `doubling_subtrace` and `screen_nones` before
   binding. The call is renamed with a 'doubledouble' tag and no input is
   donated. The flat results are unflattened into (heads, tails) and rewrapped
   as DoublingTracers.
   """
   assert call_primitive.multiple_results
   heads, tails = unzip2((tracer.head, tracer.tail) for tracer in tracers)
   flat_tails, tails_treedef = tree_flatten(tails)
   num_heads = len(heads)
   doubled_f, get_out_treedef = screen_nones(
       doubling_subtrace(f, self.main), num_heads, tails_treedef)
   call_name = wrap_name(params.get('name', f.__name__), 'doubledouble')
   no_donation = (False,) * (num_heads + len(flat_tails))
   call_params = dict(params)
   call_params['name'] = call_name
   call_params['donated_invars'] = no_donation
   flat_out = call_primitive.bind(doubled_f, *heads, *flat_tails,
                                  **call_params)
   out_heads, out_tails = tree_unflatten(get_out_treedef(), flat_out)
   return [DoublingTracer(self, out_head, out_tail)
           for out_head, out_tail in zip(out_heads, out_tails)]
Exemplo n.º 5
0
 def process_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                    params, is_map):
     """Binds a call/map primitive with the current plants as extra inputs.

     The plants from the dynamic context are flattened together with the
     input values and passed as leading arguments. The wrapped function is
     transformed with `plant_eval` so it can consume them. For `nest_p`, only
     the plants under the nest's `scope` are passed. If the primitive has
     `in_axes` or `donated_invars`, they are padded for the plant inputs
     (axis 0, not donated).

     Args:
       trace: the trace the tracers belong to.
       call_primitive: the call or map primitive being bound.
       f: the wrapped function the primitive calls.
       tracers: input tracers to the primitive.
       params: parameters of the primitive. This dict is not mutated.
       is_map: unused; map-ness is detected via the presence of `in_axes`.

     Returns:
       A list of output tracers produced by `trace.pure`.
     """
     del is_map
     # Copy before popping so the caller's params dict is left untouched.
     params = dict(params)
     name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
     context = trace_util.get_dynamic_context(trace)
     vals = [t.val for t in tracers]
     plants = context.plants
     if call_primitive is nest_p:
         # Select the nest's scope first (independently of any other param)
         # so the padding below matches the plants that are actually passed.
         plants = plants.get(params['scope'], {})
     num_plants = len(tree_util.tree_leaves(plants))
     if 'in_axes' in params:
         # TODO(b/199459308): figure out if invars are mapped or unmapped
         params['in_axes'] = (0, ) * num_plants + params['in_axes']
     if 'donated_invars' in params:
         # Plant inputs are never donated.
         params['donated_invars'] = ((False, ) * num_plants +
                                     params['donated_invars'])
     all_vals, all_tree = tree_util.tree_flatten((plants, vals))
     f = plant_eval(f, trace, self.settings, all_tree)
     out_vals = call_primitive.bind(f, *all_vals, name=name, **params)
     return jax_util.safe_map(trace.pure, out_vals)
Exemplo n.º 6
0
    def reap_higher_order_primitive(self, trace, call_primitive, f, tracers,
                                    params, is_map):
        """Wraps the inner function with a reap trace.

        `f` is transformed with `reap_eval`. Its flat outputs are unflattened,
        using the tree returned by the auxiliary thunk, into the regular
        outputs and the reaped values. For map primitives, `out_axes_thunk` is
        replaced so that every flat output (including reaps) is mapped along
        axis 0.

        Args:
          trace: the trace the tracers belong to.
          call_primitive: the call or map primitive being bound.
          f: the wrapped function the primitive calls.
          tracers: input tracers to the primitive.
          params: parameters of the primitive. This dict is not mutated.
          is_map: whether `call_primitive` is a map primitive.

        Returns:
          A tuple `(out_tracers, reap_tracers, metadata)`. `out_tracers` is a
          list of tracers for the regular outputs. `reap_tracers` is the reaped
          pytree mapped through `trace.pure`. `metadata` is the second value
          returned by the `reap_eval` auxiliary thunk.
        """
        # Copy before popping so the caller's params dict is left untouched.
        params = dict(params)
        name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap')
        vals = [t.val for t in tracers]
        f, aux = reap_eval(f, trace, self.settings)

        if is_map:
            out_axes_thunk = params['out_axes_thunk']

            @jax_util.as_hashable_function(closure=('harvest', out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                assert all(out_axis == 0 for out_axis in out_axes)
                out_tree, _ = aux()
                return (0, ) * out_tree.num_leaves

            params['out_axes_thunk'] = new_out_axes_thunk
        out_flat = call_primitive.bind(f, *vals, name=name, **params)
        out_tree, metadata = aux()
        out_vals, reaps = tree_util.tree_unflatten(out_tree, out_flat)
        out_tracers = jax_util.safe_map(trace.pure, out_vals)
        reap_tracers = tree_util.tree_map(trace.pure, reaps)
        return out_tracers, reap_tracers, metadata
Exemplo n.º 7
0
    def handle_call_primitive(self, call_primitive, f, tracers, params,
                              is_map):
        """Handler for call_primitives, like jit or layer_call.

        When an UnzipTracer hits a call primitive, there is either a variable
        inside of the call primitive, in which case the input function needs
        to be unzipped into two, or there are no variables in the function, so
        the call_primitive is recorded in the trace as-is.

        We use `unzip_eval`, whose auxiliary output reports whether the unzip
        succeeded. If it did, we record two new Jaxprs into the trace (one for
        init, one for apply). Otherwise, we just record the Jaxpr
        corresponding to the function call.

        Primitives that have a rule in `current_custom_rules()` are dispatched
        to that rule directly. Primitives listed in
        `pe.call_partial_eval_rules` raise `NotImplementedError`.

        Args:
          call_primitive: a call primitive like xla_call
          f: a jax.linear_util wrapped function to be called
          tracers: inputs to the function
          params: parameters of the primitives
          is_map: whether or not the primitive is a map primitive (e.g.
            xla_pmap)

        Returns:
          A list of output tracers
        """
        name = params.get('name', f.__name__)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const_abstracted, tracers)
        # A registered custom rule takes precedence over the generic path.
        if call_primitive in current_custom_rules():
            return current_custom_rules()[call_primitive](self, f, *tracers,
                                                          **params)
        if call_primitive in pe.call_partial_eval_rules:
            raise NotImplementedError
        # NOTE(review): `in_pvals` (including the mapped avals computed below
        # for map primitives) is never read after this point. `pvs` is built
        # from `t.pval` directly further down. Confirm whether the mapped
        # pvals were meant to be passed to `unzip_eval`.
        in_pvals = [t.pval for t in tracers]
        if is_map:
            unknown = pe.PartialVal.unknown
            in_pvals = [
                pval if pval.is_known() or in_axis is None else unknown(
                    mapped_aval(params['axis_size'], in_axis, pval[0]))
                for pval, in_axis in zip(in_pvals, params['in_axes'])
            ]
            out_axes_thunk = params['out_axes_thunk']

            # `aux` is assigned further down by `unzip_eval`. The closure
            # resolves it when the thunk is called, and every flat output is
            # mapped along axis 0.
            @jax_util.as_hashable_function(closure=('unzip', out_axes_thunk))
            def new_out_axes_thunk():
                out_axes = out_axes_thunk()
                assert all(out_axis == 0 for out_axis in out_axes)
                _, num_outputs, _ = aux()
                return (0, ) * num_outputs

            new_params = dict(params, out_axes_thunk=new_out_axes_thunk)
        else:
            new_params = params
        pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
        # Per-input flags marking which inputs are key tracers.
        keys = tuple(t.is_key() for t in tracers)
        # The second field records whether this primitive is in
        # `block_registry`.
        new_settings = UnzipSettings(settings.tag, call_primitive
                                     in block_registry)
        fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
        out_flat = call_primitive.bind(fun, *in_consts, **new_params)
        success, _, results = aux()
        if not success:
            # No unzip: `out_flat` holds the known output values followed by
            # the jaxpr's constants. Rebuild the outputs from the single jaxpr.
            out_pvs, out_keys, jaxpr, env = results
            out_pv_consts, consts = jax_util.split_list(
                out_flat, [len(out_pvs)])
            out_tracers = self._bound_output_tracers(call_primitive,
                                                     new_params, jaxpr, consts,
                                                     env, tracers, out_pvs,
                                                     out_pv_consts, out_keys,
                                                     name, is_map)
            return out_tracers
        # Unzip succeeded: `results` carries separate init/apply pieces.
        init_name = jax_util.wrap_name(name, 'init')
        apply_name = jax_util.wrap_name(name, 'apply')
        init_pvs, num_init_consts, apply_pvs = results[0]
        init_jaxpr, apply_jaxpr = results[1]
        init_env, apply_env = results[2]
        variable_names, variable_tree, apply_keys = results[3]

        key_tracers = [t for t in tracers if t.is_key()]
        abstract_tracers = [t for t in tracers if not t.is_key()]
        # `out_flat` layout: [init pv consts, init consts, apply pv consts,
        # apply consts].
        all_init_consts, all_apply_consts = jax_util.split_list(
            out_flat, [len(init_pvs) + num_init_consts])
        init_pv_consts, init_consts = jax_util.split_list(
            all_init_consts, [len(init_pvs)])
        apply_pv_consts, apply_consts = jax_util.split_list(
            all_apply_consts, [len(apply_pvs)])

        # Run the init jaxpr on the key inputs only. All of its outputs (the
        # variables) are flagged as keys.
        variable_tracers = self._bound_output_tracers(
            call_primitive, new_params, init_jaxpr, init_consts, init_env,
            key_tracers, init_pvs, init_pv_consts, [True] * len(init_pvs),
            init_name, is_map)

        unflat_variables = tree_util.tree_unflatten(variable_tree,
                                                    variable_tracers)
        # Sow the variables in 'strict' mode. For a nest they are sown as one
        # dict under the nest's scope; otherwise one sow per variable name.
        if call_primitive is harvest.nest_p:
            variable_dict = harvest.sow(dict(
                safe_zip(variable_names, unflat_variables)),
                                        tag=settings.tag,
                                        name=new_params['scope'],
                                        mode='strict')
            unflat_variables = tuple(variable_dict[name]
                                     for name in variable_names)
        else:
            unflat_variables = [
                harvest.sow(  # pylint: disable=g-complex-comprehension
                    unflat_variable,
                    tag=settings.tag,
                    name=name,
                    mode='strict') for unflat_variable, name in safe_zip(
                        unflat_variables, variable_names)
            ]
        variable_tracers = tree_util.tree_leaves(unflat_variables)

        # The apply jaxpr takes the (sown) variables followed by the non-key
        # inputs.
        out_tracers = self._bound_output_tracers(
            call_primitive, new_params, apply_jaxpr, apply_consts, apply_env,
            variable_tracers + abstract_tracers, apply_pvs, apply_pv_consts,
            apply_keys, apply_name, is_map)
        return out_tracers