Example 1
def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  used = {v for v in jaxpr.outvars if isinstance(v, core.Var)}
  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
  # applications that do not produce used outputs. Must handle side-effecting
  # primitives and nested jaxpr.
  used.update(
      v for eqn in jaxpr.eqns for v in eqn.invars if isinstance(v, core.Var))
  kept_const_idx, new_constvars = util.unzip2(
      (i, v) for i, v in enumerate(jaxpr.constvars) if v in used)
  kept_var_idx, new_invars = util.unzip2(
      (i, v) for i, v in enumerate(jaxpr.invars) if v in used)
  new_jaxpr = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
  return new_jaxpr, set(kept_const_idx), set(kept_var_idx)
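All of the snippets collected on this page revolve around the same unzip2 idiom: build an iterable of pairs, then split it into two parallel tuples. The sketch below is a minimal, self-contained stand-in for jax.util.unzip2 (not the library source), followed by the filter-and-split pattern that _prune_unused_inputs uses above.

# Minimal sketch of the unzip2 idiom used throughout these examples.
# This local unzip2 is an illustrative stand-in for jax.util.unzip2,
# assumed to return a pair of tuples; it is not the library source.
from typing import Iterable, Tuple, TypeVar

A = TypeVar('A')
B = TypeVar('B')

def unzip2(pairs: Iterable[Tuple[A, B]]) -> Tuple[Tuple[A, ...], Tuple[B, ...]]:
  xs, ys = [], []
  for x, y in pairs:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)

# The filter-and-split pattern from _prune_unused_inputs: keep only the
# (index, value) pairs that survive a predicate, then split them apart.
invars = ['a', 'b', 'c', 'd']
used = {'b', 'd'}
kept_idx, kept_vars = unzip2((i, v) for i, v in enumerate(invars) if v in used)
assert kept_idx == (1, 3) and kept_vars == ('b', 'd')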
Example 2
 def post_process_call(self, call_primitive, out_tracers, params):
   vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
   main = self.main
   def todo(vals):
     trace = MaskTrace(main, core.cur_sublevel())
     return map(partial(MaskTracer, trace), vals, shapes)
   return vals, todo
Example 3
File: ad.py Project: jbampton/jax
 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
   assert call_primitive.multiple_results
   primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
   nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
   nz_tangents = [type(t) is not Zero for t in tangents]
   if 'name' in params and not config.jax_experimental_name_stack:
     params = dict(params, name=wrap_name(params['name'], 'jvp'))
   f_jvp = jvp_subtrace(f, self.main)
   f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
   if isinstance(call_primitive, core.MapPrimitive):
     in_axes = params['in_axes']
     tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
     out_axes_thunk = params['out_axes_thunk']
     # The new thunk depends deterministically on the old thunk and the wrapped function.
     # Any caching already has to include the wrapped function as part of the key, so we
     # only use the previous thunk for equality checks.
     # NOTE: This assumes that the output tangents being zero is a deterministic
     #       function of which input tangents were zero.
     @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
     def new_out_axes_thunk():
       out_axes = out_axes_thunk()
       return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
     params = dict(params,
                   in_axes=(*in_axes, *tangent_in_axes),
                   out_axes_thunk=new_out_axes_thunk)
   f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
   update_params = call_param_updaters.get(call_primitive)
   new_params = update_params(params, nz_tangents) if update_params else params
   result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
   primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
   return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
Example 4
File: ad.py Project: jbampton/jax
  def process_custom_transpose(self, prim, call, tracers, **params):
    ps_in, ts_in = unzip2((t.primal, t.tangent) for t in tracers)
    res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves])
    res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves])

    # TODO(frostig): Handle differentiation with respect to residual
    # operands. Calling `call` twice on all operands is invalid, since it
    # isn't linear in the residuals. However, we know that if we
    # write:
    #
    #   jvp_call_res = lambda x: partial(jvp, lambda r: call(r, x))
    #
    # then:
    #
    #   jvp(call, (r, x), (dr, dx)) == jvp_call_res(x)(r, dr) + call(r, dx)
    #
    # In words: a possible strategy is to take the jvp of `call` with
    # respect to residuals, and with linear arguments fixed, then add
    # that to a custom-transpose call to `call` (i.e. what we already
    # do below in the all-linear argument case).

    if any(type(t) is not Zero for t in res_ts_in):
      raise NotImplementedError(
        'JVP of custom transpose with respect to non-symbolic-zero residuals')

    ps_out = prim.bind(call, *ps_in, **params)

    lin_ts_in = map(instantiate_zeros, lin_ts_in)
    ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params)

    return map(partial(JVPTracer, self), ps_out, ts_out)
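The TODO comment above relies on the identity jvp(call, (r, x), (dr, dx)) == jvp_call_res(x)(r, dr) + call(r, dx), which holds when `call` is linear in its non-residual arguments. The snippet below is a hedged check of that identity with jax.jvp on a toy function; `toy_call` is invented for illustration and is not part of the JAX API.

# Hedged illustration of the identity from the comment above, using a toy
# function that is linear in its second argument (so call(r, dx) is exactly
# the x-directional derivative).
import jax
import jax.numpy as jnp

def toy_call(r, x):
  return r * x  # linear in x

r, x = jnp.float32(2.0), jnp.float32(3.0)
dr, dx = jnp.float32(0.5), jnp.float32(0.25)

_, full_tangent = jax.jvp(toy_call, (r, x), (dr, dx))
_, res_tangent = jax.jvp(lambda r_: toy_call(r_, x), (r,), (dr,))  # jvp w.r.t. residual only
assert jnp.allclose(full_tangent, res_tangent + toy_call(r, dx))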
Example 5
def _ravel_list(lst):
  if not lst: return jnp.array([], jnp.float32), lambda _: []
  from_dtypes = [dtypes.dtype(l) for l in lst]
  to_dtype = dtypes.result_type(*from_dtypes)
  sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
  indices = np.cumsum(sizes)

  if all(dt == to_dtype for dt in from_dtypes):
    # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
    # See https://github.com/google/jax/issues/7809.
    del from_dtypes, to_dtype
    def unravel(arr):
      chunks = jnp.split(arr, indices[:-1])
      return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
    raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
    return raveled, unravel

  # When there is more than one distinct input dtype, we perform type
  # conversions and produce a dtype-specific unravel function.
  def unravel(arr):
    arr_dtype = dtypes.dtype(arr)
    if arr_dtype != to_dtype:
      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                      f"but expected dtype {to_dtype}")
    chunks = jnp.split(arr, indices[:-1])
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      return [lax.convert_element_type(chunk.reshape(shape), dtype)
              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]

  ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
  raveled = jnp.concatenate([ravel(e) for e in lst])
  return raveled, unravel
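As a usage sketch for _ravel_list above (assuming the module-level imports the function itself relies on, i.e. numpy as np, jax.numpy as jnp, jax's dtypes and lax, and warnings, are in scope as in the original source file), the returned ravel/unravel pair round-trips a list of arrays back to their original shapes; with a single common input dtype it takes the dtype-polymorphic fast path.

# Usage sketch for _ravel_list above; relies on the original module's imports.
import jax.numpy as jnp

leaves = [jnp.ones((2, 3), jnp.float32), jnp.arange(4, dtype=jnp.float32)]
flat, unravel = _ravel_list(leaves)
assert flat.shape == (10,)                    # 2*3 + 4 elements, concatenated
restored = unravel(flat)
assert [x.shape for x in restored] == [(2, 3), (4,)]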
Example 6
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  assert not jaxpr.constvars
  policy = params['policy'] or (lambda *_, **__: False)
  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)

  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_consts_ = iter(out_consts)
  # form known outputs and collect residual tracers
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  residuals = list(out_consts_)

  # set up unknown outputs with a recipe to call remat
  res_tracers = map(trace.new_instantiated_const, residuals)
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe

  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
Example 7
 def process_call(self, call_primitive, f, tracers, params):
     assert call_primitive.multiple_results
     if config.jax_experimental_name_stack:
         params = dict(params, name=params.get('name', f.__name__))
     else:
         params = dict(params,
                       name=wrap_name(params.get('name', f.__name__),
                                      'vmap'))
     vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
     if all(bdim is not_mapped for bdim in dims):
         return call_primitive.bind(f, *vals, **params)
     else:
         f_, dims_out = batch_subtrace(f, self.main, dims)
         ax_size, = {
             x.shape[d]
             for x, d in zip(vals, dims) if d is not not_mapped
         }
         f_ = _update_annotation(f_, f.in_type, ax_size, self.axis_name,
                                 dims)
         vals_out = call_primitive.bind(f_, *vals, **params)
         src = source_info_util.current()
         return [
             BatchTracer(self, v, d, src)
             for v, d in zip(vals_out, dims_out())
         ]
Example 8
  def to_jaxpr(self, in_dim_tracers, in_tracers, out_dim_tracers, out_tracers):
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_dim_binders, in_binders = map(t2v, in_dim_tracers), map(t2v, in_tracers)
    out_dims, outs = map(t2v, out_dim_tracers), map(t2v, out_tracers)

    # only include constants that are used
    used_vars = ({a for eqn in self.eqns for a in eqn.invars if isinstance(a, Var)} |
                 {a for grp in [out_dims, outs] for a in grp if isinstance(a, Var)})
    constvars, constvals = unzip2(
        (v, c) for v, c in self.constvar_to_val.items() if v in used_vars)
    in_binders = [*constvars, *in_binders]

    # promote some lambda binders to pi binders
    used_shape_vars = ({d for eqn in self.eqns for v in eqn.outvars
                        if isinstance(v.aval, AbsArray)
                        for d in v.aval.shape if isinstance(d, Var)} |
                       {d.name for eqn in self.eqns for v in eqn.outvars
                        if isinstance(v.aval, AbsArray)
                        for d in v.aval.shape if isinstance(d, DimIndexingExpr)})
    lambda_binders = [v not in used_shape_vars for v in in_binders]
    converted_binders, in_binders = partition_list(lambda_binders, in_binders)
    in_dim_binders = in_dim_binders + converted_binders
    out_dims = [v for v in out_dims if v not in in_dim_binders]  # TODO

    jaxpr = DJaxpr(in_dim_binders, in_binders, out_dims, outs, self.eqns)
    typecheck_jaxpr(jaxpr)
    return jaxpr, constvals, lambda_binders
Example 9
  def testOpShardingRoundTrip(self):
    FakeDevice = namedtuple('FakeDevice', ['id'])
    mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
    mesh_axes, mesh_shape = unzip2(mesh_named_shape.items())
    devices = [FakeDevice(i) for i in range(np.prod(list(mesh_shape)))]
    mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))

    dims = 5
    aval = jax.core.ShapedArray((len(devices),) * dims, jnp.float32)
    def roundtrip(spec):
      op_sharding = pjit_lib.get_aval_sharding_proto(aval, spec, mesh)
      parsed_spec = pjit_lib.parse_op_sharding(op_sharding, mesh).partitions
      self.assertEqual(parsed_spec[:len(spec)], spec)
      self.assertEqual(parsed_spec[len(spec):], ((),) * (len(parsed_spec) - len(spec)))

    special_specs = [P()]
    for spec in special_specs:
      roundtrip(spec)

    rng = np.random.default_rng(1)
    for i in range(100):
      spec = [()] * dims
      for axis in rng.permutation(mesh_axes)[:rng.integers(low=1, high=len(mesh_axes) + 1)]:
        spec[rng.choice(dims)] += (axis,)
      roundtrip(P(*spec))
Example 10
    def post_process_map(self, call_primitive, out_tracers, params):
        vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
        main = self.main

        def both_mapped(in_out_axis, d):
            return in_out_axis is not None and d is not not_mapped

        def todo(vals):
            trace = main.with_cur_sublevel()
            return [
                BatchTracer(trace, v,
                            d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
                for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())
            ]

        if call_primitive.map_primitive:

            def out_axes_transform(out_axes):
                return tuple(out_axis + 1 if both_mapped(out_axis, d)
                             and d < out_axis else out_axis
                             for out_axis, d in zip(out_axes, dims))

            todo = (todo, out_axes_transform)
        return vals, todo
Example 11
File: jet.py Project: 0x0is1/jax
def jet_subtrace(main, primals, series):
    trace = JetTrace(main, core.cur_sublevel())
    in_tracers = map(partial(JetTracer, trace), primals, series)
    ans = yield in_tracers, {}
    out_tracers = map(trace.full_raise, ans)
    out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
    yield out_primals, out_terms
Example 12
def cond_error_check(error, index, *ops, branches, linear):
  new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
  new_linear = (False, False, *linear)
  err, code, *outs = control_flow.cond_p.bind(
      index, error.err, error.code, *ops,
      branches=tuple(new_branches), linear=new_linear)
  new_msgs = {k:v for d in it.chain([error.msgs], msgs_) for k, v in d.items()}
  return outs, Error(err, code, new_msgs)
Example 13
def doubling_subtrace(main, heads, tails):
  trace = DoublingTrace(main, core.cur_sublevel())
  in_tracers = [DoublingTracer(trace, h, t) if t is not None else h
                for h, t in zip(heads, tails)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.head, out_tracer.tail)
                for out_tracer in out_tracers])
Example 14
 def new_f(*args, **kwargs):
   axis_names, shape = unzip2(named_shape)
   size = np.prod(shape)
   local_devices = list(jax.local_devices())
   if len(local_devices) < size:
     raise SkipTest(f"Test requires {size} local devices")
   mesh_devices = np.array(local_devices[:size]).reshape(shape)
   with mesh(mesh_devices, axis_names):
     return f(*args, **kwargs)
Example 15
 def process_call(self, call_primitive, f, tracers, params):
   assert call_primitive.multiple_results
   params = dict(params, name=wrap_name(params.get('name', f.__name__), 'mask'))
   vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
   if not any(is_polymorphic(s) for s in shapes):
     return call_primitive.bind(f, *vals, **params)
   else:
     logical_env, padded_env = shape_envs
     env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
     logical_env_vals = tuple(logical_env[k] for k in env_keys)
     # Make padded_env hashable
     padded_env = (env_keys, padded_env_vals)
     f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
     if 'donated_invars' in params:
       params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
                                             params['donated_invars']))
     vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
     return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]
Example 16
    def post_process_custom_jvp_call(self, out_tracers, params):
        vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
        main = self.main

        def todo(vals):
            trace = main.with_cur_sublevel()
            return map(partial(BatchTracer, trace), vals, dims)

        return vals, todo
Example 17
def batch_subtrace(main, in_dims, *in_vals):
  # used in e.g. process_call
  trace = main.with_cur_sublevel()
  in_dims = in_dims() if callable(in_dims) else in_dims
  in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                if dim is not None else x for x, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_dims
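batch_subtrace, like jet_subtrace, doubling_subtrace, jvp_subtrace, and jvp_subtrace_aux elsewhere on this page, follows a two-yield generator protocol: the first yield hands the transformed inputs to the wrapped function and receives its outputs, and the second yield returns the transformed outputs. The sketch below mimics that protocol with a toy driver; it stands in for jax's linear_util machinery and is not the real implementation.

# Standalone sketch of the two-yield protocol used by the *_subtrace generators.
def run_transformation(gen_fn, fun, *args):
  gen = gen_fn(*args)
  call_args, kwargs = next(gen)      # first yield: positional args for the wrapped fn
  outs = fun(*call_args, **kwargs)   # run the wrapped function
  return gen.send(outs)              # second yield: the transformed outputs

def scale_then_bump_subtrace(scale, xs):
  # Mimics the shape of jet_subtrace / batch_subtrace: transform the inputs,
  # hand them off, then transform whatever comes back.
  ans = yield [x * scale for x in xs], {}
  yield [a + 1.0 for a in ans]

result = run_transformation(scale_then_bump_subtrace,
                            lambda *xs: [x ** 2 for x in xs],
                            3.0, [1.0, 2.0])
assert result == [10.0, 37.0]   # (1*3)**2 + 1 and (2*3)**2 + 1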
Example 18
 def process_custom_jvp_call(self, prim, fun, jvp, tracers):
     in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
     fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
     jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
     out_vals = prim.bind(fun, jvp, *in_vals)
     fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
     if not fst:
         assert out_dims == out_dims[:len(out_dims) // 2] * 2
         out_dims = out_dims[:len(out_dims) // 2]
     return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
Example 19
 def post_process_call(self, call_primitive, out_tracers, params):
   primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
   out, treedef = tree_flatten((primals, series))
   del primals, series
   main = self.main
   def todo(x):
     primals, series = tree_unflatten(treedef, x)
     trace = JetTrace(main, core.cur_sublevel())
     return map(partial(JetTracer, trace), primals, series)
   return out, todo
Example 20
def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
  env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
  logical_env_vals = [logical_env[k] for k in env_keys]
  # Make padded_env hashable
  padded_env = (env_keys, padded_env_vals)
  with core.new_main(MaskTrace) as main:
    fun, out_shapes = mask_subtrace(fun, main, polymorphic_shapes, padded_env)
    out_vals = fun.call_wrapped(*(logical_env_vals + in_vals))
    del main
  return out_vals, out_shapes()
Example 21
 def process_call(self, call_primitive, f, tracers, params):
   primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
   primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
   f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
   update_params = call_param_updaters.get(call_primitive)
   new_params = (update_params(params, len(primals_and_series))
                 if update_params else params)
   result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
   primals_out, series_out = tree_unflatten(out_tree_def(), result)
   return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]
Example 22
    def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
        vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
        if all(dim is not_mapped for dim in dims):
            return map_primitive.bind(f, *vals, **params)
        else:
            assert len({
                x.shape[d]
                for x, d in zip(vals, dims) if d is not not_mapped
            }) == 1

            # The logic for the dimension math below is as follows:
            # ╔═════════════╦════════════════════════════════════════╦═══════════╗
            # ║ d / in_axis ║ None                                   ║ int       ║
            # ╠═════════════╬════════════════════════════════════════╩═══════════╣
            # ║ None        ║ No extra axis, so in_axis unaffected               ║
            # ╠═════════════╬════════════════════════════════════════╦═══════════╣
            # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
            # ╚═════════════╩════════════════════════════════════════╩═══════════╝
            # When both d and in_axis are defined then:
            # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
            # - If `d >  in_axis`, we have to decrement `d` (as `in_axis` will get removed).
            def both_mapped(in_out_axis, d):
                return in_out_axis is not None and d is not not_mapped

            new_in_axes = tuple(
                in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
                for d, in_axis in zip(dims, params['in_axes']))
            new_dims = tuple(
                d - 1 if both_mapped(in_axis, d) and in_axis < d else d
                for d, in_axis in zip(dims, params['in_axes']))
            f, dims_out = batch_subtrace(f, self.main, new_dims)
            out_axes_thunk = params['out_axes_thunk']
            # NOTE: This assumes that the choice of the dimensions over which outputs
            #       are batched is entirely dependent on the function and not e.g. on the
            #       data or its shapes.
            @as_hashable_function(closure=out_axes_thunk)
            def new_out_axes_thunk():
                return tuple(
                    out_axis + 1
                    if both_mapped(out_axis, d) and d < out_axis else out_axis
                    for out_axis, d in zip(out_axes_thunk(), dims_out()))

            new_params = dict(params,
                              in_axes=new_in_axes,
                              out_axes_thunk=new_out_axes_thunk)
            vals_out = map_primitive.bind(f, *vals, **new_params)
            dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                        for d, out_axis in zip(dims_out(), out_axes_thunk()))
            src = source_info_util.current()
            return [
                BatchTracer(self, v, d, src)
                for v, d in zip(vals_out, dims_out)
            ]
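The dimension bookkeeping described in the table above can be checked with plain numpy, independent of JAX's batching machinery: when both the batching dim d and in_axis are integers, in_axis shifts past the inserted batch axis (if d <= in_axis) and d drops by one once the mapped axis is removed (if in_axis < d). The snippet below is a standalone sketch of that arithmetic, not the process_map implementation.

import numpy as np

def shift_axes(d, in_axis):
  # The rule from the table above, for the case where both d and in_axis are ints.
  new_in_axis = in_axis + 1 if d <= in_axis else in_axis
  new_d = d - 1 if in_axis < d else d
  return new_in_axis, new_d

M, K, B = 3, 5, 4                      # mapped size, other size, batch size
for in_axis in (0, 1):                 # where the mapped axis sits in the unbatched value
  unbatched = np.arange(M * K).reshape((M, K) if in_axis == 0 else (K, M))
  for d in range(unbatched.ndim + 1):  # where the batch axis was inserted
    batched = np.stack([unbatched] * B, axis=d)
    new_in_axis, new_d = shift_axes(d, in_axis)
    for j in range(M):
      # Slicing out map element j of the batched value at the shifted map axis
      # must match the unbatched slice with the batch axis at its shifted position.
      lhs = np.take(batched, j, axis=new_in_axis)
      rhs = np.stack([np.take(unbatched, j, axis=in_axis)] * B, axis=new_d)
      assert np.array_equal(lhs, rhs)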
Example 23
File: ad.py Project: jbampton/jax
def jvp_subtrace(main, primals, tangents):
  trace = JVPTrace(main, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x
                for x, t in zip(primals, tangents)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.primal, out_tracer.tangent)
                for out_tracer in out_tracers])
Example 24
def _squeeze_lowering(ctx, x, dimensions):
  in_aval, = ctx.avals_in
  out_aval, = ctx.avals_out
  if not out_aval.shape:
    return Idx(x, (unitIdx,))
  idx_names, idx_tys = unzip2((ctx.fresh('i'), FinType(Literal(sz)))
                              for sz in out_aval.shape)
  idx_name = iter(idx_names)
  idxs = [unitIdx if dim in dimensions else Var(next(idx_name))
          for dim in range(in_aval.ndim)]
  return For(tuple(idx_names), tuple(idx_tys), Idx(x, tuple(idxs)))
Example 25
File: ad.py Project: jbampton/jax
 def process_primitive(self, primitive, tracers, params):
   primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
   jvp = primitive_jvps.get(primitive)
   if not jvp:
     msg = f"Differentiation rule for '{primitive}' not implemented"
     raise NotImplementedError(msg)
   primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
   if primitive.multiple_results:
     return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
   else:
     return JVPTracer(self, primal_out, tangent_out)
Example 26
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
    """Test utility for setting up meshes given mesh data from `schedules`."""
    # This is similar to the `with_mesh` function above, but isn't a decorator.
    axis_names, shape = unzip2(named_shape)
    size = prod(shape)
    local_devices = list(jax.local_devices())
    if len(local_devices) < size:
        raise SkipTest(f"Test requires {size} local devices")
    mesh_devices = np.array(local_devices[:size]).reshape(shape)
    with mesh(mesh_devices, axis_names):
        yield
Example 27
File: ad.py Project: jbampton/jax
 def process_custom_jvp_call(self, _, __, f_jvp, tracers):
   primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
   primals_in = map(core.full_lower, primals_in)
   tangents_in = map(instantiate_zeros, tangents_in)
   # Cast float0 to zeros with the primal dtype because custom jvp rules don't
   # currently handle float0s
   tangents_in = map(replace_float0s, primals_in, tangents_in)
   outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
   primals_out, tangents_out = split_list(outs, [len(outs) // 2])
   tangents_out = map(recast_to_float0, primals_out, tangents_out)
   return map(partial(JVPTracer, self), primals_out, tangents_out)
Example 28
def _broadcasting_binop(binop_expr: Expr, ctx, x, y):
  x_aval, y_aval = ctx.avals_in
  out_aval, = ctx.avals_out
  if not out_aval.shape:
    return App(App(binop_expr, x), y)
  idx_names, idx_tys = unzip2((ctx.fresh('i'), FinType(Literal(sz)))
                              for sz in out_aval.shape)
  x_expr = _make_bcast_expr(idx_names, out_aval.shape, x_aval.shape, x)
  y_expr = _make_bcast_expr(idx_names, out_aval.shape, y_aval.shape, y)
  out = For(tuple(idx_names), tuple(idx_tys),
            App(App(binop_expr, x_expr), y_expr))
  return out
Example 29
File: ad.py Project: jbampton/jax
 def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
   primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
   tangents_in = map(instantiate_zeros, tangents_in)
   res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
   out_tree, res_tree = out_trees()
   res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
   avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
   tangents_out = custom_lin_p.bind(
       *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
       out_avals=avals_out)
   tangents_out = map(recast_to_float0, primals_out, tangents_out)
   return map(partial(JVPTracer, self), primals_out, tangents_out)
Example 30
File: ad.py Project: jbampton/jax
def jvp_subtrace_aux(main, primals, tangents):
  trace = JVPTrace(main, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  aux_primals = [core.full_lower(x.primal)
                 if isinstance(x, JVPTracer) and x._trace.level == trace.level
                 else x for x in aux]
  yield (out_primals, out_tangents), aux_primals