Пример #1
0
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap_p`` application into the current dynamic jaxpr frame.

  The body ``f`` is traced on the input avals with their mapped axes removed
  (``_delete_aval_axes``), inside an axis environment extended with
  ``params['global_axis_sizes']``. Output avals are rebuilt by re-inserting
  the mapped axes at their local sizes, derived from the axis resource counts.
  Constants closed over by the body become leading operands of the new
  equation, with an empty ``AxisNamePos`` in_axes entry and no donation.
  The equation is appended to ``self.frame.eqns``.

  Returns:
    One ``DynamicJaxprTracer`` per output of the staged equation.
  """
  from jax.interpreters.partial_eval import (
    trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
    convert_constvars_jaxpr, new_jaxpr_eqn)
  assert primitive is xmap_p
  sizes = params['global_axis_sizes']
  in_axes = params['in_axes']
  # Per-shard input avals: strip the named axes each operand is mapped over.
  body_in_avals = [_delete_aval_axes(t.aval, axes)
                   for t, axes in zip(tracers, in_axes)]
  with core.extend_axis_env_nd(sizes.items()):
    body_jaxpr, body_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)
  out_axes = params['out_axes_thunk']()
  resource_counts = _get_axis_resource_count(params['axis_resources'],
                                             params['resource_env'])
  # Outputs are re-expanded with the *local* size of every named axis.
  local_sizes = {name: resource_counts[name].to_local(size)
                 for name, size in sizes.items()}
  out_avals = [_insert_aval_axes(aval, axes, local_sizes)
               for aval, axes in zip(body_out_avals, out_axes)]
  _check_out_avals_vs_out_axes(out_avals, out_axes, sizes)
  info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, info) for aval in out_avals]
  # Keep this order: instantiate_const/makevar may allocate new frame vars.
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)
  n_consts = len(consts)
  new_params = dict(params)
  del new_params['out_axes_thunk']
  new_params.update(
      in_axes=(AxisNamePos(user_repr='{}'),) * n_consts + in_axes,
      out_axes=out_axes,
      donated_invars=(False,) * n_consts + params['donated_invars'],
      call_jaxpr=convert_constvars_jaxpr(body_jaxpr))
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, info)
  self.frame.eqns.append(eqn)
  return out_tracers
Пример #2
0
def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
  """Stage an ``xmap_p`` application into the current dynamic jaxpr frame.

  The body ``f`` is traced on the input avals with their mapped axes removed,
  inside an axis environment extended with ``params['axis_sizes']``. Output
  avals are rebuilt by re-inserting the mapped axes at those same sizes.
  Constants closed over by the body become leading operands, with ``None``
  as their in_axes entry. If ``call_param_updaters`` has an entry for the
  primitive, it is applied to the new params. The equation is appended to
  ``self.frame.eqns``.

  Returns:
    One ``DynamicJaxprTracer`` per output of the staged equation.
  """
  from jax.interpreters.partial_eval import (
    trace_to_subjaxpr_dynamic, DynamicJaxprTracer, source_info_util,
    convert_constvars_jaxpr, call_param_updaters, new_jaxpr_eqn)
  assert primitive is xmap_p
  sizes = params['axis_sizes']
  in_axes = params['in_axes']
  # Per-shard input avals: strip the named axes each operand is mapped over.
  body_in_avals = [_delete_aval_axes(t.aval, axes)
                   for t, axes in zip(tracers, in_axes)]
  with core.extend_axis_env_nd(sizes.items()):
    body_jaxpr, body_out_avals, consts = trace_to_subjaxpr_dynamic(
        f, self.main, body_in_avals)
  out_axes = params['out_axes_thunk']()
  out_avals = [_insert_aval_axes(aval, axes, sizes)
               for aval, axes in zip(body_out_avals, out_axes)]
  info = source_info_util.current()
  out_tracers = [DynamicJaxprTracer(self, aval, info) for aval in out_avals]
  # Keep this order: instantiate_const/makevar may allocate new frame vars.
  invars = map(self.getvar, tracers)
  constvars = map(self.getvar, map(self.instantiate_const, consts))
  outvars = map(self.makevar, out_tracers)
  new_params = dict(params)
  del new_params['out_axes_thunk']
  new_params.update(
      in_axes=(None,) * len(consts) + in_axes,
      out_axes=out_axes,
      call_jaxpr=convert_constvars_jaxpr(body_jaxpr))
  updater = call_param_updaters.get(primitive)
  if updater:
    # One flag per original (non-const) operand.
    new_params = updater(new_params, [True] * len(tracers))
  eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                      new_params, info)
  self.frame.eqns.append(eqn)
  return out_tracers
Пример #3
0
def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
              ) -> Tuple[List[bool], Optional[core.JaxprEqn]]:
  """Dead-code-eliminate a remat equation.

  Args:
    used_outputs: one flag per ``eqn.outvars``; True if that output is needed.
    eqn: the equation whose ``params['jaxpr']`` is pruned with
      ``pe.dce_jaxpr``.

  Returns:
    ``(used_inputs, new_eqn)``. ``used_inputs`` flags which ``eqn.invars`` the
    pruned body still reads. ``new_eqn`` is None when the equation can be
    dropped entirely: no inputs used, no outputs used, and the pruned body
    has no effects. Otherwise it keeps only the used invars/outvars and
    carries the pruned body and its effects.
  """
  pruned_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
  keep_eqn = any(used_inputs) or any(used_outputs) or pruned_jaxpr.effects
  if not keep_eqn:
    return used_inputs, None
  kept_invars = [v for v, keep in zip(eqn.invars, used_inputs) if keep]
  kept_outvars = [v for v, keep in zip(eqn.outvars, used_outputs) if keep]
  new_eqn = pe.new_jaxpr_eqn(
      kept_invars, kept_outvars, eqn.primitive,
      dict(eqn.params, jaxpr=pruned_jaxpr), pruned_jaxpr.effects,
      eqn.source_info)
  return used_inputs, new_eqn
Пример #4
0
def _broadcast_staging_rule(trace, tracers, params):
  """Stage ``broadcast_p`` with a new leading dimension ``d``.

  ``tracers`` is ``(x, d)``. Only the case where ``d`` is not a known
  constant is handled: the output is an ``AbsArray`` of shape
  ``(d, *x.shape)`` with ``x``'s element dtype, and an equation reading
  ``x`` and ``d`` is appended to ``trace.frame.eqns``.

  Raises:
    NotImplementedError: if ``trace.get_const(d)`` yields a constant.
  """
  x, d = tracers
  if trace.get_const(d) is not None:
    raise NotImplementedError  # TODO
  aval = x.aval
  # AbsArray keeps its dtype on the element type; other avals expose .dtype.
  if isinstance(aval, AbsArray):
    dtype = aval._eltTy._dtype
  else:
    dtype = aval.dtype
  out_tracer = pe.DynamicJaxprTracer(
      trace, AbsArray((d, *x.shape), BaseType(dtype)), None)
  operands = [trace.getvar(x), trace.getvar(d)]
  results = [trace.makevar(out_tracer)]
  trace.frame.eqns.append(
      pe.new_jaxpr_eqn(operands, results, broadcast_p, {}, None))
  return out_tracer
Пример #5
0
def _iota_staging_rule(trace, tracers, params):
  """Stage ``iota_p`` for a single size operand.

  There are two cases:

  * Known size: if ``trace.get_const`` returns a value, it must be a plain
    ``int``. The output is a ``ShapedArray((n,), int32)`` and the size is
    passed as the ``size`` param, with no operands.
  * Unknown size: the size tracer must have an ``AbsArray`` aval and becomes
    the equation's operand. The output's last dimension refers to the
    tracer, wrapped in a ``DimIndexingExpr`` over its own axes when the size
    tracer is itself non-scalar.

  Raises:
    NotImplementedError: known size that is not an ``int``.
    TypeError: unknown size whose aval is not an ``AbsArray``.
  """
  (size_tracer,) = tracers
  int32 = np.dtype('int32')
  static_size = trace.get_const(size_tracer)
  if static_size is None:
    size_aval = size_tracer.aval
    if not isinstance(size_aval, AbsArray):
      raise TypeError
    if not size_aval.shape:
      out_aval = AbsArray((size_tracer,), BaseType(int32))
    else:
      # A batch of sizes: the trailing dim is indexed by the leading axes.
      axes = tuple(range(len(size_aval.shape)))
      last_dim = DimIndexingExpr(size_tracer, axes)
      out_aval = AbsArray((*size_aval.shape, last_dim), BaseType(int32))
    out_tracer = pe.DynamicJaxprTracer(trace, out_aval, None)
    outvar = trace.makevar(out_tracer)
    eqn = pe.new_jaxpr_eqn([trace.getvar(size_tracer)], [outvar], iota_p, {},
                           None)
  else:
    if type(static_size) is not int:
      raise NotImplementedError  # TODO batched version?
    out_tracer = pe.DynamicJaxprTracer(
        trace, core.ShapedArray((static_size,), int32), None)
    outvar = trace.makevar(out_tracer)
    eqn = pe.new_jaxpr_eqn([], [outvar], iota_p, dict(size=static_size), None)
  trace.frame.eqns.append(eqn)
  return out_tracer
Пример #6
0
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partially evaluate a ``scan_p`` application.

  The loop body ``jaxpr`` is split with ``pe.partial_eval_jaxpr`` into a
  known part (``jaxpr_1``) and an unknown part (``jaxpr_2``). The known scan
  is bound immediately on the known inputs. An unknown ``scan_p`` equation
  that consumes the residuals produced by the known scan becomes the recipe
  of the returned output tracers.

  Args:
    trace: the partial-evaluation trace.
    *tracers: operands laid out as ``[*consts, *init_carry, *xs]``.
    **kwargs: scan params ``forward``, ``length``, ``num_consts``,
      ``num_carry``, ``jaxpr`` and ``linear``.

  Returns:
    One ``pe.JaxprTracer`` per scan output (carry outputs, then ys).

  Raises:
    FixedPointError: if carry unknown-ness does not stabilize within 1000
      iterations.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs, ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_ys = len(jaxpr.out_avals) - num_carry

  # A tracer is unknown when its partial value has an abstract part
  # (pval[0]); otherwise its known constant is pval[1].
  unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])

  # Iterate until the carry's unknown flags agree between loop input and loop
  # output: a carry that becomes unknown after one step is unknown throughout.
  # Carry outputs are instantiated so the flags can be compared directly.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Known scan operands: real constants for known inputs, units otherwise.
  in_consts = [core.unit if uk else t.pval[1] for uk, t in zip(unknowns, tracers)]
  # Unknown scan operands: instantiated tracers for unknown inputs (including
  # carries that only became unknown during the fixed point), units otherwise.
  new_tracers = [trace.instantiate_const(t) if uk else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts, forward=forward, length=length, jaxpr=jaxpr_1,
      num_consts=num_consts, num_carry=num_carry, linear=linear_1)
  # jaxpr_1 emits [*carry_out, *ys, *residuals].
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]
  # jaxpr_2 takes the residuals after the regular operands.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), dict(forward=forward, length=length, jaxpr=jaxpr_2,
                                  num_consts=num_consts, num_carry=num_carry,
                                  linear=linear_2))
  for t in out_tracers: t.recipe = eqn
  return out_tracers
Пример #7
0
def partial_eval_jaxpr(jaxpr, in_unknowns):
  """Split a ``DJaxpr`` into a known part and an unknown part.

  An equation is unknown if any of its inputs is unknown, and then all of its
  outputs are unknown. Every known value read by an unknown equation is
  forwarded from the known jaxpr to the unknown one as a residual.

  Args:
    jaxpr: the ``DJaxpr`` to split.
    in_unknowns: one bool per ``jaxpr.in_binders``; True marks an unknown
      input.

  Returns:
    ``(jaxpr1, jaxpr2, out_unknowns, num_res)``:

    * ``jaxpr1`` runs the known equations and outputs the known outputs
      followed by the ``num_res`` residuals.
    * ``jaxpr2`` binds the residuals, then the unknown inputs, and runs the
      unknown equations.
    * ``out_unknowns`` flags which of ``jaxpr.outs`` are unknown.

  Raises:
    NotImplementedError: if an equation operand or output is a literal.
  """
  env: Dict[Var, bool] = {}
  res = []
  res_seen = set()

  def read(v):
    if type(v) is core.Literal:
      raise NotImplementedError  # TODO
    return env[v]

  def new_res(v):
    # Forward each known var at most once, even if several unknown equations
    # (or several operands of one equation) read it. Duplicates would bind
    # the same Var more than once in jaxpr2's in_binders.
    if v not in res_seen:
      res_seen.add(v)
      res.append(v)
    return v

  # Explicit loops: these are side effects and must not depend on whether
  # `map` is eager.
  for unk, v in zip(in_unknowns, jaxpr.in_binders):
    env[v] = unk

  eqns1, eqns2 = [], []
  for eqn in jaxpr.eqns:
    unks = [read(v) for v in eqn.invars]
    eqn_unknown = any(unks)
    if eqn_unknown:
      invars = [v if unk else new_res(v) for unk, v in zip(unks, eqn.invars)]
      eqns2.append(pe.new_jaxpr_eqn(invars, eqn.outvars, eqn.primitive,
                                    eqn.params, None))
    else:
      eqns1.append(eqn)
    for v in eqn.outvars:
      env[v] = eqn_unknown
  out_unknowns = [read(v) for v in jaxpr.outs]
  out_dim_unknowns = [read(v) for v in jaxpr.out_dims]  # when linearizing, all known

  invars1, invars2 = partition_list(in_unknowns, jaxpr.in_binders)
  outvars1, outvars2 = partition_list(out_unknowns, jaxpr.outs)
  out_dims1, out_dims2 = partition_list(out_dim_unknowns, jaxpr.out_dims)

  outvars1 = outvars1 + res
  invars2 = res + invars2

  # TODO forward the correct residuals here (all dimvars used in types)
  in_dimvars2 = out_dims1 + jaxpr.in_dim_binders

  jaxpr1 = DJaxpr(jaxpr.in_dim_binders, invars1, out_dims1, outvars1, eqns1)
  jaxpr2 = DJaxpr(in_dimvars2,          invars2, out_dims2, outvars2, eqns2)

  return jaxpr1, jaxpr2, out_unknowns, len(res)
Пример #8
0
def _nonzero_staging_rule(trace, tracers, params):
  """Stage ``nonzero_p`` on ``tracers[0]`` (and any further operands).

  Two outputs are created:

  * A dimension tracer of shape ``aval.shape[:-1]``, typed
    ``BoundedIntTy(bound)``. ``bound`` is the last input dim, or its
    ``_bound`` when that dim is not an ``int``.
  * An int32 value tracer whose last dimension is the dimension tracer. When
    the input has more than one axis, that dimension is indexed by the
    leading axes via ``DimIndexingExpr``.

  Both outputs become outvars of the staged equation, but only the value
  tracer is returned.

  Raises:
    NotImplementedError: for an ``AbsArray`` input whose element type is not
      a ``BaseType``.
  """
  aval = tracers[0].aval
  if isinstance(aval, AbsArray) and not isinstance(aval._eltTy, BaseType):
    raise NotImplementedError
  last_dim = aval.shape[-1]
  bound = last_dim if isinstance(last_dim, int) else last_dim._bound
  lead_shape = aval.shape[:-1]
  dim_tracer = pe.DynamicJaxprTracer(
      trace, AbsArray(lead_shape, BoundedIntTy(bound)), None)
  if len(aval.shape) == 1:
    val_shape = (dim_tracer,)
  else:
    ragged = DimIndexingExpr(dim_tracer, tuple(range(len(lead_shape))))
    val_shape = (*lead_shape, ragged)
  val_tracer = pe.DynamicJaxprTracer(
      trace, AbsArray(val_shape, BaseType(np.dtype('int32'))), None)
  invars = map(trace.getvar, tracers)
  outvars = map(trace.makevar, [dim_tracer, val_tracer])
  trace.frame.eqns.append(
      pe.new_jaxpr_eqn(invars, outvars, nonzero_p, {}, None))
  return val_tracer