Ejemplo n.º 1
0
 def start_subtrace(self):
     """Open a nested trace and hand back its ``pe.JaxprTrace``.

     A fresh ``core.MainTrace`` is created at the next level of the
     thread-local trace stack and pushed onto that stack, and
     ``self._count_subtraces`` is incremented by one.
     """
     # TODO: This follows the __enter__ part of core.new_main.
     trace_stack = core.thread_local_state.trace_state.trace_stack
     new_level = trace_stack.next_level()
     names = source_info_util.current_name_stack()
     main_trace = core.MainTrace(new_level, pe.JaxprTrace, name_stack=names)
     trace_stack.push(main_trace)
     self._count_subtraces += 1
     # The sublevel is read only after the new main trace has been pushed.
     sublevel = core.cur_sublevel()
     return pe.JaxprTrace(main_trace, sublevel, name_stack=names)
Ejemplo n.º 2
0
Archivo: ad.py Proyecto: jbampton/jax
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
                  consts, primals_in, cotangents_in):
  """Run the transpose (backward) pass over `jaxpr`.

  The equations of `jaxpr` are visited in reverse order. For each one, the
  cotangents of its outputs are read, the primitive's transpose rule is
  applied, and the resulting cotangents are accumulated onto its inputs.

  NOTE(review): `map` is called here purely for side effects and its results
  are indexed/unpacked, so it is presumably rebound at module level to an
  eager, list-returning variant -- confirm against the file's imports.

  Args:
    jaxpr: the jaxpr whose equations are transposed.
    reduce_axes: axis names; a cotangent whose named shape carries one of
      these axes while the target variable's aval does not is summed over
      those axes with `psum` before being accumulated.
    transform_stack: if truthy, the pass runs inside
      `source_info_util.transform_name_stack('transpose')`; otherwise inside
      a null context.
    consts: values bound to `jaxpr.constvars`.
    primals_in: values bound to `jaxpr.invars`; entries that are undefined
      primals are not recorded in the primal environment.
    cotangents_in: cotangents for `jaxpr.outvars`.

  Returns:
    One cotangent per entry of `jaxpr.invars`; a `Zero(v.aval)` is returned
    for any input that received no cotangent.
  """
  # Short-circuit: nothing to propagate when every incoming cotangent is Zero.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    """Accumulate cotangent `ct` for variable `v`; `prim` is used only in assertion messages."""
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    # Nothing to record for a missing cotangent or a literal operand.
    if ct is None or type(v) is Literal:
      return
    # Zero instances are never stored; read_cotangent supplies them on demand.
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # Requested axes that the cotangent carries but the variable does not.
    axes_to_reduce = tuple(axis_name for axis_name in reduce_axes
                           if axis_name in core.get_aval(ct).named_shape
                           and axis_name not in v.aval.named_shape)
    if axes_to_reduce:
      ct = jax.lax.psum(ct, axis_name=axes_to_reduce)
    # Sum with any cotangent already accumulated for this variable.
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if config.jax_enable_checks:
      # Optional check: joining the variable's aval with the accumulated
      # cotangent's aval must give back the variable's aval (ignoring weak
      # types and named shapes).
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type().strip_named_shape()
      assert v.aval.strip_weak_type().strip_named_shape() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    """Remove and return the cotangent accumulated for `v` (Zero if none)."""
    return ct_env.pop(v, Zero(v.aval))

  def read_primal(v):
    """Return the primal value of `v`, or an UndefinedPrimal if none was recorded."""
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    """Record `val` as the primal of `v` unless it is an undefined primal."""
    if not is_undefined_primal(val):
      primal_env[v] = val

  # Primal environment: variable -> known primal value.
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  # Cotangent environment: variable -> accumulated cotangent.
  ct_env: Dict[Any, Any] = {}
  ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    # Seed the environment with the cotangents of the jaxpr outputs.
    map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    # Walk the equations from last to first.
    for eqn in jaxpr.eqns[::-1]:
      # FIXME: Some invars correspond to tangents
      invals = map(read_primal, eqn.invars)
      # Single-result primitives receive a bare cotangent rather than a list.
      if eqn.primitive.multiple_results:
        cts_in = map(read_cotangent, eqn.outvars)
      else:
        cts_in, = map(read_cotangent, eqn.outvars)
      name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
      with source_info_util.user_context(eqn.source_info.traceback, name_stack=name_stack):
        if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
          # Call/map primitives: the transpose rule receives the inner jaxpr
          # (popped from the params) and the output avals explicitly.
          cts_in_avals = [v.aval for v in eqn.outvars]
          params = dict(eqn.params)
          call_jaxpr = params.pop('call_jaxpr')
          cts_out = get_primitive_transpose(eqn.primitive)(
              params, call_jaxpr, invals, cts_in, cts_in_avals, reduce_axes)
        elif eqn.primitive in reducing_transposes:
          # Rules registered in reducing_transposes take reduce_axes first.
          cts_out = reducing_transposes[eqn.primitive](
              reduce_axes, cts_in, *invals, **eqn.params)
        else:
          cts_out = get_primitive_transpose(eqn.primitive)(
              cts_in, *invals, **eqn.params)
        # A rule may return the Zero class itself; expand it to one Zero
        # instance per equation input.
        cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
        # FIXME: Some invars correspond to primals!
        map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)

  # Whatever remains in ct_env for the jaxpr inputs is the result.
  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
Ejemplo n.º 3
0
def _cond_partial_eval(trace, *tracers, branches, linear):
    """Partial-evaluation rule for ``cond_p``.

    ``tracers`` holds the branch index followed by the operands. When the
    index is unknown the whole cond is staged out through
    ``trace.default_process_primitive``. Otherwise each branch is split into
    a known and an unknown jaxpr; the known cond is bound right away
    (producing the known outputs plus residuals) and the unknown cond is
    recorded as an equation recipe on freshly created output tracers.

    Returns the known outputs and the unknown output tracers merged
    according to the per-output unknown mask.
    """
    unknown_flags = [tracer.pval[0] is not None for tracer in tracers]
    index_unknown, *operand_unknown = unknown_flags

    if index_unknown:
        # When the branch index is unknown, we stage out the whole cond.
        # TODO(mattjj): remove this path when old remat is removed
        return trace.default_process_primitive(
            cond_p, tracers, dict(branches=branches, linear=linear))

    # First pass: find out which outputs each branch leaves unknown; an
    # output is unknown overall as soon as any branch leaves it unknown.
    per_branch_out_unknown = [
        pe.partial_eval_jaxpr_nounits(jaxpr, operand_unknown,
                                      instantiate=False)[2]
        for jaxpr in branches
    ]
    out_unknown = [any(flags) for flags in zip(*per_branch_out_unknown)]

    # Second pass: split every branch again, now forcing the joined mask so
    # that all branches agree on which outputs are unknown.
    known_jaxprs, unknown_jaxprs, residual_avals = [], [], []
    for jaxpr in branches:
        known_part, unknown_part, _, res_avals = pe.partial_eval_jaxpr_nounits(
            jaxpr, operand_unknown, instantiate=out_unknown)
        known_jaxprs.append(known_part)
        unknown_jaxprs.append(unknown_part)
        residual_avals.append(res_avals)

    merged_res_avals, per_branch_res_avals = _merge_branch_residuals(
        residual_avals)
    residual_count = len(merged_res_avals)
    known_out_count = len(out_unknown) - sum(out_unknown)

    # Give all branches a common residual signature.
    known_jaxprs = _join_cond_outputs(
        known_jaxprs, merged_res_avals, per_branch_res_avals, known_out_count)
    unknown_jaxprs = _join_cond_pe_staged_jaxpr_inputs(
        unknown_jaxprs, merged_res_avals, per_branch_res_avals)
    assert all(
        all(_map(core.typematch, other.out_avals, known_jaxprs[0].out_avals))
        for other in known_jaxprs[1:])

    # Evaluate the known cond now; its trailing outputs are the residuals.
    known_inputs = [
        tracer.pval.get_known() for tracer in tracers
        if tracer.pval.is_known()
    ]
    known_linear = [
        flag for flag, is_unknown in zip(linear, operand_unknown)
        if not is_unknown
    ]
    known_results = cond_p.bind(*known_inputs,
                                branches=known_jaxprs,
                                linear=tuple(known_linear))
    known_outs, residuals = split_list(
        known_results, [len(known_results) - residual_count])

    # Build the tracers feeding the staged-out (unknown) cond.
    index_tracer = trace.instantiate_const(tracers[0])
    operand_tracers = [
        trace.instantiate_const(tracer)
        for is_unknown, tracer in zip(operand_unknown, tracers[1:])
        if is_unknown
    ]
    residual_tracers = _map(trace.new_instantiated_const, residuals)
    unknown_outs = [
        pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
        for aval in unknown_jaxprs[0].out_avals
    ]

    # Residual inputs are marked non-linear; unknown operands keep their flag.
    unknown_linear = [False] * residual_count + [
        flag for flag, is_unknown in zip(linear, operand_unknown)
        if is_unknown
    ]
    staged_params = dict(branches=unknown_jaxprs,
                         linear=tuple(unknown_linear))

    # Name stack relative to the trace's own name stack.
    relative_stack = source_info_util.current_name_stack()[
        len(trace.name_stack):]
    source_info = source_info_util.current().replace(name_stack=relative_stack)

    recipe = pe.new_eqn_recipe(
        [index_tracer, *residual_tracers, *operand_tracers],
        unknown_outs, cond_p, staged_params, core.no_effects, source_info)
    for out_tracer in unknown_outs:
        out_tracer.recipe = recipe
    return util.merge_lists(out_unknown, known_outs, unknown_outs)