Ejemplo n.º 1
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p``.

  Splits the remat body with ``pe._partial_eval_jaxpr_custom`` into two parts.
  The known jaxpr is evaluated right away on the known inputs. The unknown
  jaxpr is staged out as a new ``remat_p`` equation. That equation's inputs
  are the residuals produced by the known part, followed by the instantiated
  input tracers, restricted to those the (DCE'd) unknown jaxpr still uses.

  Args:
    trace: the partial-eval trace that owns the new tracers.
    *tracers: input tracers, each known or unknown.
    jaxpr: the remat body; it must not have constvars.
    **params: remaining remat params. ``params['policy']`` may be falsy, in
      which case a policy that always returns False is used.

  Returns:
    Output tracers for ``jaxpr``. Known ones carry their computed value, and
    unknown ones share a single ``remat_p`` recipe. The two groups are
    combined according to ``out_unknowns`` by ``pe._zip_knowns``.
  """
  assert not jaxpr.constvars
  policy = params['policy'] or (lambda *_, **__: False)

  # Split the body, then DCE each half. Every known output is kept; the
  # staged half keeps only the instantiated outputs that are unknown.
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  keep_known_outs = [True] * len(jaxpr_known.outvars)
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, keep_known_outs)
  _, keep_unknown_outs = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown,
                                                keep_unknown_outs)

  # Run the known half now (hoisted out of the remat primitive). Its results
  # are the known outputs first, then the residuals.
  _, known_vals = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, known_vals = partition_list(in_used_known, known_vals)
  results = iter(core.eval_jaxpr(jaxpr_known, (), *known_vals))
  out_known_tracers = []
  for is_unknown in out_unknowns:
    if not is_unknown:
      out_known_tracers.append(
          pe.JaxprTracer(trace, pe.PartialVal.known(next(results)), None))
  residuals = list(results)

  # Stage the unknown half as a remat_p equation over residuals + inputs,
  # dropping the inputs that DCE found unused.
  staged_inputs = [*map(trace.new_instantiated_const, residuals),
                   *map(trace.instantiate_const, tracers)]
  _, staged_inputs = partition_list(in_used_unknown, staged_inputs)
  out_jaxpr_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
      for v in jaxpr_unknown.outvars]
  recipe = pe.new_eqn_recipe(
      staged_inputs, out_jaxpr_tracers, remat_p,
      dict(params, jaxpr=jaxpr_unknown, differentiated=True),
      source_info_util.current())
  for t in out_jaxpr_tracers:
    t.recipe = recipe

  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
Ejemplo n.º 2
0
 def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                           in_tracers, out_pvs, out_consts, out_keys, name,
                           is_map):
     """Takes a traced function and binds the Jaxpr to output tracers.

     Builds a single ``primitive`` equation recipe. Its inputs are the
     instantiated ``consts``, the instantiated ``env`` and ``in_tracers``, in
     that order. Its outputs are fresh ``UnzipTracer``s built from
     ``out_pvs``, ``out_consts`` and ``out_keys``.

     Args:
       primitive: call or map primitive recorded in the equation.
       params: original primitive params; copied, never mutated.
       jaxpr: traced body. Its constvars are converted to leading invars and
         the result is stored as the ``call_jaxpr`` param.
       consts: values for those lifted constvars.
       env: closed-over values, passed through ``self.instantiate_const``.
       in_tracers: remaining input tracers.
       out_pvs, out_consts, out_keys: per-output partial-value parts and the
         key flag handed to each ``UnzipTracer``.
       name: stored as the ``name`` param.
       is_map: if true, ``params['out_axes_thunk']`` is called (every axis
         must be 0) and replaced by an explicit all-zero ``out_axes`` tuple.

     Returns:
       The list of output ``UnzipTracer``s, all sharing one recipe.
     """
     new_params = dict(params, name=name,
                       call_jaxpr=pe.convert_constvars_jaxpr(jaxpr))
     const_tracers = safe_map(self.new_instantiated_const, consts)
     env_tracers = safe_map(self.instantiate_const, env)
     out_tracers = []
     for out_pv, out_const, out_key in safe_zip(out_pvs, out_consts, out_keys):
         out_tracers.append(
             UnzipTracer(self, pe.PartialVal((out_pv, out_const)), None,
                         out_key))
     if 'donated_invars' in params:
         # Const and env inputs are never donated. For the remaining inputs,
         # keep only the flags of tracers whose value is not known.
         # NOTE(review): plain `zip` silently truncates if
         # `params['donated_invars']` and `in_tracers` differ in length;
         # confirm callers always pass matching sequences.
         num_leading = len(const_tracers) + len(env_tracers)
         kept_flags = tuple(
             donated
             for donated, tracer in zip(params['donated_invars'], in_tracers)
             if not tracer.pval.is_known())
         new_params['donated_invars'] = (False,) * num_leading + kept_flags
     if is_map:
         out_axes = new_params.pop('out_axes_thunk')()
         assert all(axis == 0 for axis in out_axes)
         new_params['out_axes'] = (0,) * len(out_tracers)
     recipe = pe.new_eqn_recipe(tuple(const_tracers + env_tracers +
                                      in_tracers), out_tracers, primitive,
                                new_params, source_info_util.current())  # pytype: disable=wrong-arg-types
     for tracer in out_tracers:
         tracer.recipe = recipe
     return out_tracers
Ejemplo n.º 3
0
 def custom_rule(*tracers, **params):
   """Runs the registered custom partial-eval rule and rewraps its outputs.

   ``key`` comes from the enclosing scope and selects the rule in
   ``pe.custom_partial_eval_rules``. The returned tracers are converted to
   ``UnzipTracer``s, and their equation recipes are rebuilt to point at the
   new tracers.

   Each ``pe.new_eqn_recipe`` call creates a distinct recipe. Outputs that
   shared one recipe therefore get one shared rebuilt recipe, not one recipe
   per output. Separate recipes would presumably make the same equation be
   recorded once per output when the jaxpr is assembled.

   Returns:
     The list of ``UnzipTracer`` outputs, in the rule's output order.
   """
   out_jaxpr_tracers = pe.custom_partial_eval_rules[key](*tracers,
                                                         **params)
   out_tracers = [UnzipTracer(
       out_tracer._trace, out_tracer.pval, out_tracer.recipe,  # pylint: disable=protected-access
       False, None) for out_tracer in out_jaxpr_tracers]
   # Group outputs by the identity of the recipe they inherited, keeping the
   # output order within each group. The recipe is stored alongside its group
   # so it stays alive and its id() stays unique while grouping.
   groups = {}
   for out_tracer in out_tracers:
     recipe = out_tracer.recipe
     groups.setdefault(id(recipe), (recipe, []))[1].append(out_tracer)
   for recipe, group in groups.values():
     new_recipe = pe.new_eqn_recipe(recipe.invars, group, recipe.primitive,
                                    recipe.params, recipe.source_info)  # pytype: disable=wrong-arg-types
     for out_tracer in group:
       out_tracer.recipe = new_recipe
   return out_tracers
Ejemplo n.º 4
0
def _axis_index_bind(*, axis_name):
    """Stages an ``axis_index_p`` equation on the pmap trace of ``axis_name``.

    Looks up the frame for ``axis_name`` in the thread-local dynamic axis env
    and uses that frame's ``pmap_trace``. On it, creates an unknown scalar
    int32 tracer whose recipe is an input-free ``axis_index_p`` equation.

    Returns:
      The new output tracer.
    """
    pmap_trace = pxla._thread_local_state.dynamic_axis_env[axis_name].pmap_trace
    index_aval = ShapedArray((), np.int32)
    index_tracer = pe.JaxprTracer(pmap_trace, pe.PartialVal.unknown(index_aval),
                                  None)
    index_tracer.recipe = pe.new_eqn_recipe(
        [], [index_tracer], axis_index_p, {'axis_name': axis_name},
        source_info_util.current())
    return index_tracer
Ejemplo n.º 5
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
    """Partial-evaluation rule for ``remat_p``.

    ``pe.partial_eval_jaxpr_custom`` splits ``jaxpr`` into a known jaxpr and a
    staged jaxpr. The staged jaxpr is DCE'd to its unknown instantiated
    outputs. The known jaxpr is then DCE'd to keep all known outputs but only
    the residuals the staged jaxpr still consumes. The known jaxpr is
    evaluated immediately. The staged one becomes a ``remat_p`` equation whose
    inputs are those residuals followed by the staged inputs it still uses.

    Args:
      trace: the partial-eval trace that owns the new tracers.
      *tracers: input tracers, each known or unknown.
      jaxpr: the remat body; it must not have constvars.
      **params: remaining remat params. A falsy ``params['policy']`` falls
        back to ``nothing_saveable``.

    Returns:
      One entry per output of ``jaxpr``, merged by ``out_unknowns``: computed
      values for known outputs and unknown ``JaxprTracer``s for the rest.
    """
    assert not jaxpr.constvars
    policy = params['policy'] or nothing_saveable
    in_unknowns = [not t.is_known() for t in tracers]
    jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

    # Staged outputs are the instantiated ones; keep those that are unknown.
    _, staged_out_mask = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, staged_in_used = pe.dce_jaxpr(jaxpr_staged, staged_out_mask)
    res_used, staged_in_used = split_list(staged_in_used, [num_res])

    # Keep every known primal output, but only the residuals still consumed.
    num_known_outs = len(out_unknowns) - sum(out_unknowns)
    jaxpr_known, in_used_known = pe.dce_jaxpr(
        jaxpr_known, [True] * num_known_outs + res_used)
    num_res = sum(res_used)

    # Hoist the known computation out of the remat primitive.
    _, known_in_vals = unzip2(t.pval for t in tracers if t.pval.is_known())
    _, known_in_vals = partition_list(in_used_known, known_in_vals)
    known_results = core.eval_jaxpr(jaxpr_known, (), *known_in_vals)
    out_knowns, residuals = split_list(known_results,
                                       [len(known_results) - num_res])

    # Stage the unknown computation as a remat_p equation.
    _, staged_tracers = partition_list(staged_in_used, tracers)
    in_jaxpr_tracers = (map(trace.new_instantiated_const, residuals) +
                        map(trace.instantiate_const, staged_tracers))
    out_jaxpr_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
        for v in jaxpr_unknown.outvars
    ]
    recipe = pe.new_eqn_recipe(
        in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
        dict(params, jaxpr=jaxpr_unknown, differentiated=True),
        jaxpr_unknown.effects, source_info_util.current())
    for t in out_jaxpr_tracers:
        t.recipe = recipe

    # Re-interleave known values and unknown tracers in output order.
    return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
Ejemplo n.º 6
0
    def default_process_primitive(self, primitive, tracers, params):
        """Partially evaluate primitives and saves variable recipes.

        If every input is known, ``primitive`` is bound eagerly on the known
        values. Otherwise, the inputs are instantiated and abstract evaluation
        gives the output avals. Fresh unknown ``UnzipTracer`` outputs are then
        created, all sharing one equation recipe. If every input is a key
        tracer, ``primitive`` is ``harvest.sow_p`` and ``params['tag']``
        equals the dynamic context's ``settings.tag``, then a
        ``VariableRecipe`` built by ``unzip_registry[primitive]`` is also
        attached to every output.

        Args:
          primitive: the primitive being applied.
          tracers: input tracers.
          params: the primitive's params.

        Returns:
          The eager ``primitive.bind`` result when all inputs are known.
          Otherwise, the list of output tracers for multiple-result
          primitives, or the single output tracer.

        Raises:
          TypeError: if instantiation yields a tracer that is not an
            ``UnzipTracer``.
        """
        pvs, consts = jax_util.unzip2(t.pval for t in tracers)
        if all(pv is None for pv in pvs):
            # Every input is known: evaluate eagerly instead of staging.
            return primitive.bind(*consts, **params)
        settings = trace_util.get_dynamic_context(self).settings
        tracers = safe_map(self.instantiate_const, tracers)
        # Explicit raise rather than `assert False`: asserts are stripped
        # under `python -O`, which would let a foreign tracer slip through.
        for t in tracers:
            if not isinstance(t, UnzipTracer):
                raise TypeError(
                    'default_process_primitive expected UnzipTracer inputs, '
                    f'got {type(t).__name__}.')
        key = all(t.is_key() for t in tracers)
        avals = [t.aval for t in tracers]
        ans = primitive.abstract_eval(*avals, **params)
        if not primitive.multiple_results:
            ans = [ans]
        out_tracers = [
            UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
            for aval in ans
        ]
        # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
        eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                                source_info_util.current())  # pytype: disable=wrong-arg-types
        for t in out_tracers:
            t.recipe = eqn

        is_variable = (key and primitive is harvest.sow_p
                       and params['tag'] == settings.tag)
        # This block is where UnzipTrace mainly differs from pe.JaxprTrace. Where
        # JaxprTrace will just return out_tracers, UnzipTrace will record an
        # additional VariableRecipe into the tracers, which will be used after
        # the trace is complete to construct init/apply Jaxprs.
        if is_variable:
            name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
                tracers, out_tracers, **params)
            variable_recipe = VariableRecipe(name, var_in_tracers,
                                             var_out_tracers)
            for t in out_tracers:
                t.variable_recipe = variable_recipe

        if primitive.multiple_results:
            return out_tracers
        return out_tracers[0]
Ejemplo n.º 7
0
def _dynamic_xla_call_pe(trace, *tracers, jaxpr, num_consts):
  """Partial-evaluation rule for ``dynamic_xla_call_p``.

  The leading ``len(jaxpr.in_dim_binders)`` tracers are the input-dimension
  arguments and must all be known. The remaining tracers are split by
  ``partial_eval_jaxpr`` into two calls. The known call is bound now on the
  known values. The unknown call is staged out as a new
  ``dynamic_xla_call_p`` equation whose inputs are the dimension values, the
  residuals and then the unknown tracers.

  Returns:
    One entry per output of ``jaxpr``, ordered by ``out_unknowns``: outputs
    of the eager call where known, unknown ``JaxprTracer``s otherwise.

  Raises:
    NotImplementedError: if any input-dimension tracer is unknown.
  """
  in_dim_tracers, tracers = split_list(tracers, [len(jaxpr.in_dim_binders)])
  if any(not t.pval.is_known() for t in in_dim_tracers):
    raise NotImplementedError
  in_unknowns = [not t.pval.is_known() for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)

  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  in_dim_vals = [t.pval.get_known() for t in in_dim_tracers]
  outs1_res = dynamic_xla_call_p.bind(*in_dim_vals, *known_vals, jaxpr=jaxpr1,
                                      num_consts=num_consts)
  outs1, res = split_list(outs1_res, [len(jaxpr1.outs) - num_res])

  # new_instantiated_const takes a constant value, as with `res` below.
  # Instantiate the known dimension *values*; passing the tracers themselves
  # would wrap a tracer as a constant.
  in_dim_const_tracers = map(trace.new_instantiated_const, in_dim_vals)
  res_tracers = map(trace.new_instantiated_const, res)
  outs2 = [pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = pe.new_eqn_recipe(in_dim_const_tracers + res_tracers + unknown_tracers,
                          outs2, dynamic_xla_call_p,
                          dict(jaxpr=jaxpr2, num_consts=0), None)
  for t in outs2: t.recipe = eqn
  outs1, outs2 = iter(outs1), iter(outs2)
  return [next(outs2) if uk else next(outs1) for uk in out_unknowns]
Ejemplo n.º 8
0
def _cond_partial_eval(trace, *tracers, branches, linear):
    """Partial-evaluation rule for ``cond_p``.

    If the branch index is unknown, the whole cond is staged out through
    ``trace.default_process_primitive``. Otherwise each branch is first
    partially evaluated to find which outputs are unknown; an output is
    unknown if it is unknown in any branch. Every branch is then split again
    with that common pattern, and the branch residuals are merged. The known
    cond is bound immediately. The unknown cond is staged as a ``cond_p``
    equation on the index, the residuals and the unknown operands.

    Returns:
      The ``default_process_primitive`` result when the index is unknown.
      Otherwise, one entry per output merged by ``out_uks``: known values
      from the eager cond and unknown ``JaxprTracer``s for the rest.
    """
    in_unknowns = [t.pval[0] is not None for t in tracers]
    index_uk, *ops_uk = in_unknowns

    if index_uk:
        # Unknown index: stage out the whole cond.
        # TODO(mattjj): remove this path when old remat is removed
        return trace.default_process_primitive(
            cond_p, tracers, dict(branches=branches, linear=linear))

    # Pass 1: find, per branch, which outputs are unknown; union them.
    per_branch_out_uks = [
        branch_out_uks for _, _, branch_out_uks, _ in (
            pe.partial_eval_jaxpr_nounits(branch, ops_uk, instantiate=False)
            for branch in branches)
    ]
    out_uks = [any(uks) for uks in zip(*per_branch_out_uks)]

    # Pass 2: split every branch using the common unknown-output pattern.
    splits = [
        pe.partial_eval_jaxpr_nounits(branch, ops_uk, instantiate=out_uks)
        for branch in branches
    ]
    branches_known = [known for known, _, _, _ in splits]
    branches_unknown = [unknown for _, unknown, _, _ in splits]
    branch_res_avals = [res_avals for _, _, _, res_avals in splits]

    all_res_avals, res_avals_per_branch = _merge_branch_residuals(
        branch_res_avals)
    num_res = len(all_res_avals)
    num_known_outs = len(out_uks) - sum(out_uks)
    branches_known = _join_cond_outputs(branches_known, all_res_avals,
                                        res_avals_per_branch, num_known_outs)
    branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
        branches_unknown, all_res_avals, res_avals_per_branch)
    assert all(
        all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
        for j in branches_known[1:])

    # Bind the known cond now; its outputs end with the residuals.
    known_in_vals = [t.pval.get_known() for t in tracers if t.pval.is_known()]
    linear_known = tuple(l for l, uk in zip(linear, ops_uk) if not uk)
    known_results = cond_p.bind(*known_in_vals, branches=branches_known,
                                linear=linear_known)
    out_consts, res = split_list(known_results,
                                 [len(known_results) - num_res])

    # Stage the unknown cond on (index, residuals, unknown operands).
    index_tracer = trace.instantiate_const(tracers[0])
    ops_tracers = [trace.instantiate_const(t)
                   for t, uk in zip(tracers[1:], ops_uk) if uk]
    res_tracers = _map(trace.new_instantiated_const, res)
    out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
                   for aval in branches_unknown[0].out_avals]
    linear_unknown = [False] * num_res + [l for l, uk in zip(linear, ops_uk)
                                          if uk]
    params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
    name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
    source = source_info_util.current().replace(name_stack=name_stack)
    # NOTE(review): the staged equation is recorded with core.no_effects even
    # though branches_unknown may carry effects; confirm cond branches cannot
    # have effects in this JAX version.
    eqn = pe.new_eqn_recipe([index_tracer, *res_tracers, *ops_tracers],
                            out_tracers, cond_p, params, core.no_effects,
                            source)
    for t in out_tracers:
        t.recipe = eqn
    return util.merge_lists(out_uks, out_consts, out_tracers)