Пример #1
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule that splits a remat jaxpr into known/unknown parts.

  The known part is evaluated right away with ``core.eval_jaxpr``; its results
  become the known outputs plus residuals. The unknown part is attached, as an
  equation recipe on ``remat_p``, to freshly created unknown output tracers.

  Args:
    trace: trace used to build ``pe.JaxprTracer`` objects and to instantiate
      constants.
    *tracers: input tracers; each one is classified via ``t.is_known()``.
    jaxpr: the jaxpr to split. It must have no ``constvars``.
    **params: remaining primitive parameters. ``params['policy']`` is read; a
      falsy policy is replaced by one that always returns ``False``.

  Returns:
    Whatever ``pe._zip_knowns`` produces from the known output tracers, the
    unknown output tracers and the ``out_unknowns`` mask.
  """
  assert not jaxpr.constvars

  def _save_nothing(*_, **__):
    return False

  policy = params['policy']
  if not policy:
    policy = _save_nothing

  # Split the jaxpr according to which inputs are unknown.
  in_unknowns = [not tracer.is_known() for tracer in tracers]
  split = pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = split

  # Prune both halves: every output of the known jaxpr is kept, while the
  # unknown jaxpr keeps only the outputs selected through `out_inst`.
  keep_all_known = [True for _ in jaxpr_known.outvars]
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, keep_all_known)
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown,
                                                used_outs_unknown)

  # Evaluate the known jaxpr now (hoisted out of the remat primitive).
  _, known_in_vals = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, known_in_vals = partition_list(in_used_known, known_in_vals)
  const_stream = iter(core.eval_jaxpr(jaxpr_known, (), *known_in_vals))

  # The first results feed the known outputs (one per False entry of
  # `out_unknowns`); whatever is left in the stream afterwards is a residual.
  out_known_tracers = []
  for is_unknown in out_unknowns:
    if is_unknown:
      continue
    known_pval = pe.PartialVal.known(next(const_stream))
    out_known_tracers.append(pe.JaxprTracer(trace, known_pval, None))
  residuals = [*const_stream]

  # Inputs of the staged remat call: residual tracers first, then the
  # instantiated original tracers, pruned to what the unknown jaxpr uses.
  staged_inputs = [trace.new_instantiated_const(res) for res in residuals]
  staged_inputs.extend(trace.instantiate_const(t) for t in tracers)
  _, staged_inputs = partition_list(in_used_unknown, staged_inputs)

  out_jaxpr_tracers = []
  for outvar in jaxpr_unknown.outvars:
    unknown_pval = pe.PartialVal.unknown(outvar.aval)
    out_jaxpr_tracers.append(pe.JaxprTracer(trace, unknown_pval, None))

  new_params = {**params, 'jaxpr': jaxpr_unknown, 'differentiated': True}
  recipe = pe.new_eqn_recipe(staged_inputs, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for out_tracer in out_jaxpr_tracers:
    out_tracer.recipe = recipe

  # Interleave known and unknown outputs back into their original order.
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
Пример #2
0
def _xla_call_partial_eval_custom_params_updater(
        unks_in: Sequence[bool], inst_in: Sequence[bool],
        kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
        num_res: int, params_known: dict,
        params_staged: dict) -> Tuple[dict, dict]:
    # pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
    donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
    new_params_known = dict(params_known, donated_invars=tuple(donated_known))
    # added num_res new inputs to jaxpr_staged, so extend donated_invars
    _, donated_staged_ = partition_list(inst_in,
                                        params_staged['donated_invars'])
    donated_staged = [False] * num_res + donated_staged_
    new_params_staged = dict(params_staged,
                             donated_invars=tuple(donated_staged))
    return new_params_known, new_params_staged
Пример #3
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
    """Partial-evaluation rule that splits a remat jaxpr into known/staged parts.

    The known part is evaluated immediately with ``core.eval_jaxpr``; its
    trailing ``num_res`` results are residuals that are fed, together with the
    instantiated input tracers, into a staged ``remat_p`` equation.

    Args:
        trace: trace used to instantiate constants and to build the unknown
            output tracers.
        *tracers: input tracers; each one is classified via ``t.is_known()``.
        jaxpr: the jaxpr to split. It must have no ``constvars``.
        **params: remaining primitive parameters. ``params['policy']`` is
            read; a falsy policy falls back to ``nothing_saveable``.

    Returns:
        The result of ``merge_lists(out_unknowns, out_knowns,
        out_jaxpr_tracers)``, i.e. known output values and unknown output
        tracers recombined according to the ``out_unknowns`` mask.
    """
    assert not jaxpr.constvars
    policy = params['policy'] or nothing_saveable
    in_unknowns = [not t.is_known() for t in tracers]
    jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

    # DCE jaxpr_staged, keeping only instantiated outputs which are unknown
    _, out_inst_unknown = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged,
                                                 out_inst_unknown)
    # The first `num_res` inputs of jaxpr_staged are the residuals.
    used_res, in_used_staged = split_list(in_used_staged, [num_res])

    # DCE jaxpr_known, keeping all known outputs but discarding dce'd res
    num_known_outs = len(out_unknowns) - sum(out_unknowns)
    out_used_known = [True] * num_known_outs + used_res
    jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
    num_res = sum(used_res)

    # compute known outputs and residuals (hoisted out of remat primitive)
    _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
    _, in_consts = partition_list(in_used_known, in_consts_)
    out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
    out_knowns, residuals = split_list(out_consts, [len(out_consts) - num_res])

    # set up unknown outputs with a recipe to call remat
    res_tracers = map(trace.new_instantiated_const, residuals)
    _, tracers_staged = partition_list(in_used_staged, tracers)
    # Unpack rather than concatenate with `+`: this gives the same list when
    # `map` returns lists, and also works if `map` is the lazy builtin (where
    # `map(...) + map(...)` would raise TypeError).
    in_jaxpr_tracers = [*res_tracers,
                        *map(trace.instantiate_const, tracers_staged)]
    out_jaxpr_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
        for x in jaxpr_unknown.outvars
    ]
    new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
    recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                               new_params, jaxpr_unknown.effects,
                               source_info_util.current())
    for t in out_jaxpr_tracers:
        t.recipe = recipe

    # zip together known and unknown outputs
    return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
Пример #4
0
 def tree_flatten(self):
     """Split this object into array-like leaves and static auxiliary data.

     ``self.locals`` is flattened with ``tree_util.tree_flatten``. Leaves that
     are instances of ``core.Tracer``, ``jnp.ndarray`` or ``np.ndarray`` are
     returned as the first element; everything else travels in the second.

     Returns:
         A pair ``(valid_locals, aux)`` where ``aux`` is the tuple
         ``(is_valid, invalid_locals, locals_tree, filename, code_context,
         source, lineno, offset)``.
     """

     def _is_array_like(leaf):
         return isinstance(leaf, (core.Tracer, jnp.ndarray, np.ndarray))

     leaves, treedef = tree_util.tree_flatten(self.locals)
     keep_mask = [_is_array_like(leaf) for leaf in leaves]
     # partition_list is relied on to return the leaves whose mask entry is
     # False first and those whose entry is True second.
     rejected, kept = util.partition_list(keep_mask, leaves)
     aux_data = (
         keep_mask,
         rejected,
         treedef,
         self.filename,
         self.code_context,
         self.source,
         self.lineno,
         self.offset,
     )
     return kept, aux_data