예제 #1
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule splitting a remat call into known and staged parts.

  The known part of `jaxpr` is evaluated right away with `core.eval_jaxpr`;
  the rest is staged out as a new `remat_p` equation. The candidate inputs of
  that equation are the residuals of the known computation followed by the
  instantiated input tracers, filtered by what DCE reports as used.

  Args:
    trace: trace used to build tracers and to instantiate constants.
    *tracers: input tracers; `t.is_known()` decides which inputs are known.
    jaxpr: the jaxpr to split; it must not have constvars.
    **params: remaining primitive params. `params['policy']` is forwarded to
      partial evaluation; a falsy policy is replaced by a callable that
      always returns False.

  Returns:
    Whatever `pe._zip_knowns` produces from the known-output tracers, the
    staged-output tracers and the per-output unknown flags.
  """
  assert not jaxpr.constvars
  policy = params['policy']
  if not policy:
    policy = lambda *_, **__: False

  # Split the jaxpr into a known half and an unknown (staged) half.
  unknown_ins = [not t.is_known() for t in tracers]
  (jaxpr_known, jaxpr_unknown, out_unknowns, out_inst,
   _) = pe._partial_eval_jaxpr_custom(jaxpr, unknown_ins, policy)

  # DCE the known half, keeping every one of its outputs.
  keep_all_known = [True] * len(jaxpr_known.outvars)
  jaxpr_known, known_ins_used = pe.dce_jaxpr(jaxpr_known, keep_all_known)

  # DCE the unknown half, keeping the outputs selected via `out_inst`.
  _, unknown_outs_used = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, unknown_ins_used = pe.dce_jaxpr(jaxpr_unknown,
                                                 unknown_outs_used)

  # Evaluate the known half now; residuals get hoisted out of the remat call.
  _, all_known_vals = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, known_vals = partition_list(known_ins_used, all_known_vals)
  eval_outs = iter(core.eval_jaxpr(jaxpr_known, (), *known_vals))

  # The leading outputs (one per known output flag) are the known results;
  # whatever is left in the iterator afterwards are the residuals.
  out_known_tracers = []
  for is_unknown in out_unknowns:
    if is_unknown:
      continue
    known_pval = pe.PartialVal.known(next(eval_outs))
    out_known_tracers.append(pe.JaxprTracer(trace, known_pval, None))
  residuals = list(eval_outs)

  # Build the inputs of the staged remat call: residuals first, then inputs.
  res_tracers = map(trace.new_instantiated_const, residuals)
  candidate_ins = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, staged_ins = partition_list(unknown_ins_used, candidate_ins)

  # Unknown outputs share a single recipe that calls remat on the unknown half.
  staged_outs = []
  for var in jaxpr_unknown.outvars:
    staged_outs.append(
        pe.JaxprTracer(trace, pe.PartialVal.unknown(var.aval), None))
  staged_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(staged_ins, staged_outs, remat_p, staged_params,
                             source_info_util.current())
  for out in staged_outs:
    out.recipe = recipe

  # Interleave known and unknown outputs back into the original order.
  return pe._zip_knowns(out_known_tracers, staged_outs, out_unknowns)
예제 #2
0
def _axis_index_bind(*, axis_name):
    """Stage out an `axis_index_p` equation for `axis_name`.

    Looks up the frame registered under `axis_name` in the thread-local
    dynamic axis environment and creates, on that frame's `pmap_trace`, an
    unknown tracer with a shape-`()` int32 abstract value. The tracer's recipe
    is an `axis_index_p` equation that takes no inputs.

    Args:
      axis_name: key used to look up the frame in the dynamic axis env.

    Returns:
      The newly created `pe.JaxprTracer`.
    """
    axis_env = pxla._thread_local_state.dynamic_axis_env
    pmap_trace = axis_env[axis_name].pmap_trace

    index_aval = ShapedArray((), np.int32)
    index_tracer = pe.JaxprTracer(
        pmap_trace, pe.PartialVal.unknown(index_aval), None)
    # The equation has no inputs; the tracer is its only output.
    index_tracer.recipe = pe.new_eqn_recipe(
        [], [index_tracer], axis_index_p, {'axis_name': axis_name},
        source_info_util.current())
    return index_tracer
예제 #3
0
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial-evaluation rule for `scan_p` with flat arguments.

  Splits the scan body `jaxpr` into a part that is bound immediately on the
  known inputs (`jaxpr_1`) and a part that is staged out as a new `scan_p`
  equation (`jaxpr_2`). The staged equation also receives the residuals
  produced by the first scan.

  Args:
    trace: trace used to create tracers and instantiate constants.
    *tracers: flat input tracers, laid out as consts, then the initial
      carry, then xs (split using `num_consts` and `num_carry`). A tracer is
      treated as unknown when the first component of its `pval` is not None.
    **kwargs: scan params; exactly "forward", "length", "num_consts",
      "num_carry", "jaxpr" and "linear" are read via `split_dict`.

  Returns:
    A list of `pe.JaxprTracer`s, one per scan output (carry outputs followed
    by stacked ys), all sharing the staged scan equation as their recipe.

  Raises:
    FixedPointError: if the carry unknown-ness does not stabilise within 1000
      partial-evaluation rounds.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_ys = len(jaxpr.out_avals) - num_carry

  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # The carry is both an input and an output of the body, so the set of
  # unknown carry components has to be a fixed point: re-run partial
  # evaluation, feeding the unknown-ness of the carry outputs back in as the
  # unknown-ness of the carry inputs, until it stops changing. The iteration
  # cap only guards against a non-terminating loop.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Inputs of the known scan: unknown positions are filled with `core.unit`.
  in_consts = [core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)]
  # Inputs of the staged scan: known positions are filled with a unit literal.
  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  # Output avals of the whole scan: carry avals as-is, per-iteration y avals
  # promoted with `_promote_aval_rank(length, ...)`.
  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  # Run the known scan now; its trailing outputs are the residuals.
  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]

  # Stage the second scan; residuals are appended as extra, non-linear inputs.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  for t in out_tracers: t.recipe = eqn
  return out_tracers
예제 #4
0
def remat_partial_eval(trace, *tracers, jaxpr, **params):
    """Partial-evaluation rule splitting a remat call into known and staged parts.

    The known part of `jaxpr` is evaluated immediately with `core.eval_jaxpr`;
    its trailing outputs are residuals. The staged part becomes a new
    `remat_p` equation fed by those residuals plus the instantiated input
    tracers that survive DCE.

    Args:
      trace: trace used to build tracers and to instantiate constants.
      *tracers: input tracers; `t.is_known()` decides which inputs are known.
      jaxpr: the jaxpr to split; it must not have constvars.
      **params: remaining primitive params. `params['policy']` is forwarded
        to partial evaluation, falling back to `nothing_saveable` when falsy.

    Returns:
      The known outputs and the staged output tracers merged with
      `merge_lists` according to the per-output unknown flags.
    """
    assert not jaxpr.constvars
    policy = params['policy']
    if not policy:
        policy = nothing_saveable

    unknown_ins = [not t.is_known() for t in tracers]
    inst_all_ins = [True] * len(unknown_ins)
    (jaxpr_known, jaxpr_staged, out_unknowns, out_inst,
     num_res) = pe.partial_eval_jaxpr_custom(jaxpr, unknown_ins, inst_all_ins,
                                             False, False, policy)

    # DCE the staged jaxpr, keeping only instantiated outputs that are
    # unknown. Its first `num_res` inputs are the residuals.
    _, staged_outs_used = partition_list(out_inst, out_unknowns)
    jaxpr_unknown, staged_ins_used = pe.dce_jaxpr(jaxpr_staged,
                                                  staged_outs_used)
    used_res, staged_ins_used = split_list(staged_ins_used, [num_res])

    # DCE the known jaxpr: keep every known output, drop residuals that the
    # staged jaxpr no longer consumes.
    num_known_outs = len(out_unknowns) - sum(out_unknowns)
    known_outs_used = [True] * num_known_outs + used_res
    jaxpr_known, known_ins_used = pe.dce_jaxpr(jaxpr_known, known_outs_used)
    num_res = sum(used_res)

    # Evaluate the known half now; residuals are hoisted out of the remat call.
    _, all_known_vals = unzip2(t.pval for t in tracers if t.pval.is_known())
    _, known_vals = partition_list(known_ins_used, all_known_vals)
    eval_outs = core.eval_jaxpr(jaxpr_known, (), *known_vals)
    out_knowns, residuals = split_list(eval_outs, [len(eval_outs) - num_res])

    # Inputs of the staged remat call: residuals first, then surviving inputs.
    res_tracers = map(trace.new_instantiated_const, residuals)
    _, surviving_tracers = partition_list(staged_ins_used, tracers)
    surviving_tracers = map(trace.instantiate_const, surviving_tracers)
    staged_ins = res_tracers + surviving_tracers

    # Unknown outputs share one recipe that calls remat on the unknown half.
    staged_outs = []
    for var in jaxpr_unknown.outvars:
        staged_outs.append(
            pe.JaxprTracer(trace, pe.PartialVal.unknown(var.aval), None))
    staged_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
    recipe = pe.new_eqn_recipe(staged_ins, staged_outs, remat_p,
                               staged_params, jaxpr_unknown.effects,
                               source_info_util.current())
    for out in staged_outs:
        out.recipe = recipe

    # Interleave known and unknown outputs back into the original order.
    return merge_lists(out_unknowns, out_knowns, staged_outs)
예제 #5
0
def _scan_partial_eval(trace, *tracers, **kwargs):
    """Partial-evaluation rule for `scan_p` with (consts, init, xs) arguments.

    Splits the scan body `jaxpr` into `jaxpr_1`, bound immediately, and
    `jaxpr_2`, recorded as a staged `scan_p` equation that additionally
    receives the residuals of the first scan packed alongside xs.

    Args:
      trace: trace used to create tracers and instantiate constants.
      *tracers: exactly three tracers: consts, initial carry and xs.
      **kwargs: must contain exactly 'jaxpr', 'length' and 'forward'.

    Returns:
      A single `pe.JaxprTracer` whose partial value pairs the output pv with
      the packed (carry, ys) result of the first scan, and whose recipe is
      the staged scan equation.

    Raises:
      FixedPointError: if the carry's second component does not stabilise
        within 1000 partial-evaluation rounds.
    """
    jaxpr = kwargs.pop('jaxpr')
    length = kwargs.pop('length')
    forward = kwargs.pop('forward')
    assert not kwargs

    input_pvs, _ = unzip2([t.pval for t in tracers])
    sc_consts, sc_init, sc_xs = map(pe.unknown, input_pvs)

    # Iterate partial evaluation until the carry component coming out of the
    # body matches the one going in, joining the two after each mismatch.
    sc_carry = sc_init
    for _ in range(1000):
        jaxpr_1, jaxpr_2, sc_out = pe.partial_eval_jaxpr(
            jaxpr, (sc_consts, sc_carry, sc_xs),
            instantiate=(sc_carry, False))
        sc_carry_out, _ = sc_out
        if sc_carry_out == sc_carry:
            break
        sc_carry = _binary_lattice_join(sc_carry, sc_carry_out)
    else:
        raise FixedPointError

    # Lift the initial carry to the fixed-point structure found above.
    consts_tracer, init_tracer, xs_tracer = tracers
    lifted_init = _lift_tracer(trace, init_tracer, sc_carry)
    _, known_inputs = unzip2(
        [t.pval for t in (consts_tracer, lifted_init, xs_tracer)])

    # Output abstract value: the carry as-is, y promoted by the scan length.
    carry_aval, y_aval = jaxpr.out_aval
    stacked_y_aval = _promote_aval_rank(length, y_aval)
    out_pv = _put_known_pvs(
        sc_out, core.AbstractTuple((carry_aval, stacked_y_aval)))

    # Run the first scan now; it returns the carry and (ys, residuals).
    out_carry, (ys, residuals) = scan_p.bind(
        *known_inputs, forward=forward, length=length, jaxpr=jaxpr_1)
    out_const = core.pack((out_carry, ys))
    residuals_tracer = trace.new_instantiated_const(core.pack(residuals))

    # The staged scan sees the residuals paired with the original xs.
    staged_inputs = (consts_tracer, lifted_init,
                     (xs_tracer, residuals_tracer))
    staged_params = dict(forward=forward, length=length, jaxpr=jaxpr_2)
    eqn = core.JaxprEqn(staged_inputs, None, scan_p, (), True, False,
                        staged_params)
    return pe.JaxprTracer(trace, pe.PartialVal((out_pv, out_const)), eqn)
예제 #6
0
def _dynamic_xla_call_pe(trace, *tracers, jaxpr, num_consts):
  """Partial-evaluation rule for `dynamic_xla_call_p`.

  Binds the primitive immediately on the known inputs using `jaxpr1`, and
  stages out a second `dynamic_xla_call_p` equation for `jaxpr2` that takes
  the dimension tracers, the residuals and the unknown input tracers.

  Args:
    trace: trace used to create tracers and instantiate constants.
    *tracers: the dimension tracers (one per `jaxpr.in_dim_binders` entry)
      followed by the remaining input tracers.
    jaxpr: the jaxpr to partially evaluate.
    num_consts: forwarded unchanged to the known call.

  Returns:
    A list with one entry per output: a staged tracer where the output is
    unknown, otherwise the corresponding output of the known call.

  Raises:
    NotImplementedError: if any dimension tracer is not known.
  """
  num_dims = len(jaxpr.in_dim_binders)
  in_dim_tracers, tracers = split_list(tracers, [num_dims])
  for dim_tracer in in_dim_tracers:
    if not dim_tracer.pval.is_known():
      raise NotImplementedError

  in_unknowns = [not t.pval.is_known() for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)

  # Run the known half now; its trailing `num_res` outputs are residuals.
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  in_dim_vals = [t.pval.get_known() for t in in_dim_tracers]
  known_call_outs = dynamic_xla_call_p.bind(
      *in_dim_vals, *known_vals, jaxpr=jaxpr1, num_consts=num_consts)
  outs1, res = split_list(known_call_outs, [len(jaxpr1.outs) - num_res])

  # NOTE(review): `new_instantiated_const` is applied to the dimension
  # *tracers* here, while residuals below are raw values. Confirm that
  # wrapping tracers this way is intended.
  in_dim_tracers = map(trace.new_instantiated_const, in_dim_tracers)
  res_tracers = map(trace.new_instantiated_const, res)

  # Stage the unknown half; all its outputs share a single recipe.
  outs2 = []
  for var in jaxpr2.outs:
    outs2.append(pe.JaxprTracer(trace, pe.PartialVal.unknown(var.aval), None))
  staged_ins = in_dim_tracers + res_tracers + unknown_tracers
  eqn = pe.new_eqn_recipe(staged_ins, outs2, dynamic_xla_call_p,
                          dict(jaxpr=jaxpr2, num_consts=0), None)
  for out in outs2:
    out.recipe = eqn

  # Interleave known values and staged tracers in the original output order.
  known_iter, staged_iter = iter(outs1), iter(outs2)
  merged = []
  for uk in out_unknowns:
    merged.append(next(staged_iter) if uk else next(known_iter))
  return merged
예제 #7
0
def _cond_partial_eval(trace, *tracers, branches, linear):
    """Partial-evaluation rule for `cond_p`.

    When the branch index is unknown the whole cond is staged out via
    `trace.default_process_primitive`. Otherwise every branch is split into
    a known and an unknown jaxpr; the known branches are bound immediately
    (also producing residuals) and the unknown branches are staged out as a
    new `cond_p` equation taking the index, the residuals and the unknown
    operands.

    Args:
      trace: trace used to create tracers and instantiate constants.
      *tracers: the branch-index tracer followed by the operand tracers. A
        tracer is treated as unknown when the first component of its `pval`
        is not None.
      branches: the branch jaxprs of the cond.
      linear: per-operand flags, filtered alongside the operands.

    Returns:
      The known outputs and the staged output tracers merged with
      `util.merge_lists` according to the per-output unknown flags (or the
      result of `default_process_primitive` on the unknown-index path).
    """
    in_unknowns = [t.pval[0] is not None for t in tracers]
    index_uk, *ops_uk = in_unknowns

    if index_uk:
        # When the branch index is unknown, we stage out the whole cond.
        # TODO(mattjj): remove this path when old remat is removed
        whole_params = dict(branches=branches, linear=linear)
        return trace.default_process_primitive(cond_p, tracers, whole_params)

    # First pass: an output is unknown if it is unknown in any branch.
    per_branch_out_uks = []
    for branch in branches:
        _, _, branch_out_uks, _ = pe.partial_eval_jaxpr_nounits(
            branch, ops_uk, instantiate=False)
        per_branch_out_uks.append(branch_out_uks)
    out_uks = [any(uks) for uks in zip(*per_branch_out_uks)]

    # Second pass: split each branch, instantiating the joined unknown
    # outputs so that all branches agree on which outputs are unknown.
    branches_known = []
    branches_unknown = []
    branch_res_avals = []
    for branch in branches:
        known_jaxpr, unknown_jaxpr, _, res_avals = \
            pe.partial_eval_jaxpr_nounits(branch, ops_uk, instantiate=out_uks)
        branches_known.append(known_jaxpr)
        branches_unknown.append(unknown_jaxpr)
        branch_res_avals.append(res_avals)

    # Merge the per-branch residuals and rewrite the branch jaxprs with them.
    all_res_avals, res_avals_per_branch = _merge_branch_residuals(
        branch_res_avals)
    num_res = len(all_res_avals)
    num_known_outs = len(out_uks) - sum(out_uks)

    branches_known = _join_cond_outputs(branches_known, all_res_avals,
                                        res_avals_per_branch, num_known_outs)
    branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
        branches_unknown, all_res_avals, res_avals_per_branch)
    reference_avals = branches_known[0].out_avals
    assert all(
        all(_map(core.typematch, j.out_avals, reference_avals))
        for j in branches_known[1:])

    # Bind the known cond now; its trailing `num_res` outputs are residuals.
    known_vals = [t.pval.get_known() for t in tracers if t.pval.is_known()]
    linear_known = tuple(
        lin for lin, uk in zip(linear, ops_uk) if not uk)
    known_results = cond_p.bind(*known_vals,
                                branches=branches_known,
                                linear=linear_known)
    out_consts, res = split_list(known_results,
                                 [len(known_results) - num_res])

    # Inputs of the staged cond: index, then residuals, then unknown operands.
    index_tracer = trace.instantiate_const(tracers[0])
    ops_tracers = [
        trace.instantiate_const(t)
        for uk, t in zip(ops_uk, tracers[1:]) if uk
    ]
    res_tracers = _map(trace.new_instantiated_const, res)
    out_tracers = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
        for aval in branches_unknown[0].out_avals
    ]

    # Residual inputs are never linear; unknown operands keep their flag.
    linear_unknown = [False] * num_res
    linear_unknown += [lin for lin, uk in zip(linear, ops_uk) if uk]
    staged_params = dict(branches=branches_unknown,
                         linear=tuple(linear_unknown))

    # Drop the leading `len(trace.name_stack)` entries of the current name
    # stack before attaching it to the staged equation's source info.
    full_stack = source_info_util.current_name_stack()
    name_stack = full_stack[len(trace.name_stack):]
    source = source_info_util.current().replace(name_stack=name_stack)

    eqn = pe.new_eqn_recipe([index_tracer] + res_tracers + ops_tracers,
                            out_tracers, cond_p, staged_params,
                            core.no_effects, source)
    for out in out_tracers:
        out.recipe = eqn
    return util.merge_lists(out_uks, out_consts, out_tracers)