Пример #1
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partially evaluate a remat call whose inputs are a mix of known/unknown.

  The jaxpr is split into a known part, which is evaluated immediately, and an
  unknown part, which is staged out as a new ``remat_p`` equation. Outputs of
  the known part that are not known results are passed to the staged equation
  as residuals.

  Args:
    trace: trace used to create output tracers and to instantiate constants.
    *tracers: input tracers; ``is_known()`` on each decides which inputs are
      treated as unknown.
    jaxpr: jaxpr of the remat body; must have no constvars.
    **params: remaining primitive parameters. ``params['policy']`` is read
      here; a falsy value is replaced by a policy that always returns False.

  Returns:
    Whatever ``pe._zip_knowns`` produces from the known output tracers, the
    staged output tracers and the per-output unknown flags.
  """
  assert not jaxpr.constvars

  def _always_false(*_, **__):
    return False

  policy = params['policy']
  if not policy:
    policy = _always_false

  # Split the jaxpr into a known half and an unknown half.
  unknown_in = [not tracer.is_known() for tracer in tracers]
  split = pe._partial_eval_jaxpr_custom(jaxpr, unknown_in, policy)
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = split

  # Prune both halves: every output of the known half is kept, while the
  # unknown half keeps only the outputs selected via out_inst/out_unknowns.
  keep_all_known_outs = [True] * len(jaxpr_known.outvars)
  jaxpr_known, known_in_used = pe.dce_jaxpr(jaxpr_known, keep_all_known_outs)
  _, unknown_outs_used = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, unknown_in_used = pe.dce_jaxpr(jaxpr_unknown,
                                                unknown_outs_used)

  # Evaluate the known half now (hoisted out of the remat primitive).
  _, known_in_vals = unzip2(tracer.pval for tracer in tracers
                            if tracer.pval.is_known())
  _, known_in_vals = partition_list(known_in_used, known_in_vals)
  remaining = iter(core.eval_jaxpr(jaxpr_known, (), *known_in_vals))

  # The leading results are the known outputs, one per known output slot;
  # whatever is left in the iterator afterwards are the residuals.
  out_known_tracers = []
  for is_unknown in out_unknowns:
    if is_unknown:
      continue
    known_pval = pe.PartialVal.known(next(remaining))
    out_known_tracers.append(pe.JaxprTracer(trace, known_pval, None))
  residuals = list(remaining)

  # Inputs of the staged equation: residuals first, then the original inputs,
  # filtered down to the ones the pruned unknown jaxpr still uses.
  res_tracers = [trace.new_instantiated_const(res) for res in residuals]
  instantiated = [trace.instantiate_const(tracer) for tracer in tracers]
  _, staged_inputs = partition_list(unknown_in_used,
                                    res_tracers + instantiated)

  # Unknown outputs all share one recipe that calls remat on the unknown half.
  staged_outputs = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(var.aval), None)
      for var in jaxpr_unknown.outvars
  ]
  staged_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(staged_inputs, staged_outputs, remat_p,
                             staged_params, source_info_util.current())
  for out_tracer in staged_outputs:
    out_tracer.recipe = recipe

  # Zip together known and unknown outputs.
  return pe._zip_knowns(out_known_tracers, staged_outputs, out_unknowns)
Пример #2
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
    """Partially evaluate a remat call whose inputs are a mix of known/unknown.

    The jaxpr is split into a known part, which is evaluated immediately, and
    a staged part, which is emitted as a new ``remat_p`` equation. Residuals
    computed by the known part are fed to the staged equation as its leading
    inputs.

    Args:
        trace: trace used to create output tracers and instantiate constants.
        *tracers: input tracers; ``is_known()`` on each decides which inputs
            are treated as unknown.
        jaxpr: jaxpr of the remat body; must have no constvars.
        **params: remaining primitive parameters. ``params['policy']`` is read
            here; a falsy value falls back to ``nothing_saveable``.

    Returns:
        The result of ``merge_lists(out_unknowns, out_knowns,
        out_jaxpr_tracers)``, i.e. known output values and staged output
        tracers combined according to the per-output unknown flags.
    """
    assert not jaxpr.constvars
    policy = params['policy'] or nothing_saveable
    in_unknowns = [not t.is_known() for t in tracers]
    jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
        pe.partial_eval_jaxpr_custom(
            jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

    # DCE jaxpr_staged, keeping only instantiated outputs which are unknown.
    _, out_inst_unknown = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged,
                                                 out_inst_unknown)
    # The first num_res inputs of the staged jaxpr are the residuals; the
    # rest line up with the original inputs.
    used_res, in_used_staged = split_list(in_used_staged, [num_res])

    # DCE jaxpr_known, keeping all known outputs but discarding dce'd res.
    num_known_outs = len(out_unknowns) - sum(out_unknowns)
    out_used_known = [True] * num_known_outs + used_res
    jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
    # From here on num_res counts only the residuals that survived DCE.
    num_res = sum(used_res)

    # Compute known outputs and residuals (hoisted out of remat primitive).
    _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
    _, in_consts = partition_list(in_used_known, in_consts_)
    out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
    out_knowns, residuals = split_list(out_consts, [len(out_consts) - num_res])

    # Set up unknown outputs with a recipe to call remat.
    res_tracers = map(trace.new_instantiated_const, residuals)
    _, tracers_staged = partition_list(in_used_staged, tracers)
    # Unpack instead of using `+`: this gives the same list when `map`
    # returns lists, and also works when `map` is the lazy built-in (whose
    # result does not support `+`).
    in_jaxpr_tracers = [*res_tracers,
                        *map(trace.instantiate_const, tracers_staged)]
    out_jaxpr_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
        for x in jaxpr_unknown.outvars
    ]
    new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
    recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                               new_params, jaxpr_unknown.effects,
                               source_info_util.current())
    for t in out_jaxpr_tracers:
        t.recipe = recipe

    # Zip together known and unknown outputs.
    return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
Пример #3
0
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
  """Dead-code-eliminate a remat equation given which outputs are used.

  Args:
    used_outputs: one flag per equation output, True where the output is used.
    eqn: the equation to prune; its body is read from ``eqn.params['jaxpr']``.

  Returns:
    A pair ``(used_inputs, new_eqn)``. ``used_inputs`` is the per-input usage
    list reported by ``pe.dce_jaxpr`` for the body. ``new_eqn`` is None when
    no input is used, no output is used and the pruned body has no effects;
    otherwise it is a new equation over the used inputs/outputs only, with the
    pruned body as its ``jaxpr`` parameter.
  """
  pruned_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)

  # Nothing flows in or out and the body has no effects: drop the equation.
  if not (any(used_inputs) or any(used_outputs) or pruned_jaxpr.effects):
    return used_inputs, None

  kept_invars = [var for var, keep in zip(eqn.invars, used_inputs) if keep]
  kept_outvars = [var for var, keep in zip(eqn.outvars, used_outputs) if keep]
  pruned_params = dict(eqn.params, jaxpr=pruned_jaxpr)
  pruned_eqn = pe.new_jaxpr_eqn(
      kept_invars, kept_outvars, eqn.primitive, pruned_params,
      pruned_jaxpr.effects, eqn.source_info)
  return used_inputs, pruned_eqn