def closure_convert(fun, in_tree, in_avals):
  """Trace ``fun`` to a jaxpr and hoist its float-dtype constants into arguments.

  Args:
    fun: the Python callable to trace; it is called with flattened inputs
      matching ``in_tree``.
    in_tree: treedef describing the inputs of ``fun``.
    in_avals: abstract values for the flattened inputs.

  Returns:
    A pair ``(converted_fun, hoisted_consts)``. ``converted_fun`` has signature
    ``converted_fun(y, t, *hoisted_consts, *args)``: the first
    ``len(hoisted_consts)`` extra positional arguments replace the hoisted
    constants. The flattened ``(y, t, *args)`` must have treedef ``in_tree``.
  """
  # Both tracing paths start from the same flattened wrapper.
  wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
  if config.omnistaging_enabled:
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  else:
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    with core.initial_style_staging():  # type: ignore
      jaxpr, _, consts = pe.trace_to_jaxpr(
          wrapped_fun, in_pvals, instantiate=True, stage_out=False)  # type: ignore
  # out_tree is a thunk that is only populated once tracing has run.
  out_tree = out_tree()

  # We only want to closure convert for constants with respect to which we're
  # differentiating. As a proxy for that, we hoist consts with float dtype.
  # TODO(mattjj): revise this approach
  def _has_inexact_dtype(c):
    return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)

  (closure_consts, hoisted_consts), merge = partition_list(
      _has_inexact_dtype, consts)
  num_consts = len(hoisted_consts)

  def converted_fun(y, t, *hconsts_args):
    # Leading extra args stand in for the hoisted constants, in order.
    hoisted, args = split_list(hconsts_args, [num_consts])
    all_consts = merge(closure_consts, hoisted)
    all_args, in_tree2 = tree_flatten((y, t, *args))
    assert in_tree == in_tree2
    out_flat = core.eval_jaxpr(jaxpr, all_consts, *all_args)
    return tree_unflatten(out_tree, out_flat)

  return converted_fun, hoisted_consts
def _array_xla_shape(aval: AbsArray):
  """Return a one-element tuple holding the XLA array shape for ``aval``.

  * ``BaseType`` element type: the shape uses that type's ``_dtype``. Any
    dimension that is itself a rank-0 ``AbsArray`` is replaced by its element
    type's ``_bound``.
  * ``BoundedIntTy`` element type: the shape is an int32 array. Any
    ``BoundedInt`` dimension is replaced by its ``_bound``.

  Raises:
    NotImplementedError: for any other element type.
  """
  elt_ty = aval._eltTy
  if isinstance(elt_ty, BaseType):
    base_dtype = elt_ty._dtype
    dims = []
    for d in aval.shape:
      if isinstance(d, AbsArray) and not d.shape:
        dims.append(d._eltTy._bound)
      else:
        dims.append(d)
    return (xla.xc.Shape.array_shape(base_dtype, dims),)
  if isinstance(elt_ty, BoundedIntTy):
    dims = [d._bound if isinstance(d, BoundedInt) else d for d in aval.shape]
    return (xla.xc.Shape.array_shape(dtypes.dtype('int32'), dims),)
  raise NotImplementedError
def psum(x, axis_name, *, axis_index_groups=None):
  """All-reduce sum of ``x`` across the pmapped axis (or axes) ``axis_name``.

  A pytree ``x`` is handled leaf by leaf, exactly as if this function were
  mapped over each leaf. Boolean inputs are converted to integers before the
  reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details). A tuple or list names
      several axes at once.
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` holding the all-reduce sum along
    the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [6 6 6 6]
  >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [ 0. 0.16666667 0.33333334 0.5 ]
  """
  axes = axis_name if isinstance(axis_name, (tuple, list)) else (axis_name,)
  _validate_axis_index_groups(axis_index_groups)
  flat_inputs, treedef = tree_util.tree_flatten(x)

  def _bool_to_int(leaf):
    # Booleans are summed as int32 rather than reduced directly.
    if dtypes.dtype(leaf) == np.bool_:
      return lax.convert_element_type(leaf, np.int32)
    return leaf

  flat_inputs = [_bool_to_int(leaf) for leaf in flat_inputs]
  flat_outputs = psum_p.bind(*flat_inputs, axis_name=axes,
                             axis_index_groups=axis_index_groups)
  return tree_util.tree_unflatten(treedef, flat_outputs)
def _ravel_list(lst):
  """Concatenate a list of arrays into one flat array, plus an inverse.

  Every element is cast to the common ``dtypes.result_type`` of the inputs,
  flattened, and concatenated. The returned ``unravel`` splits a flat array
  back into pieces with the original shapes and casts each piece back to its
  original dtype. An empty list yields an empty float32 array and an unravel
  function that always returns ``[]``.
  """
  if not lst:
    return jnp.array([], jnp.float32), lambda _: []

  leaf_dtypes = [dtypes.dtype(leaf) for leaf in lst]
  common_dtype = dtypes.result_type(*leaf_dtypes)
  leaf_sizes = [jnp.size(leaf) for leaf in lst]
  leaf_shapes = [jnp.shape(leaf) for leaf in lst]
  # Running totals mark where each piece ends in the flat array.
  boundaries = np.cumsum(leaf_sizes)

  def unravel(arr):
    pieces = jnp.split(arr, boundaries[:-1])
    with warnings.catch_warnings():
      # ignore complex-to-real cast warning
      warnings.simplefilter("ignore")
      restored = [
          lax.convert_element_type(piece.reshape(shape), orig_dtype)
          for piece, shape, orig_dtype in zip(pieces, leaf_shapes, leaf_dtypes)
      ]
      return restored

  def _flatten_one(leaf):
    return jnp.ravel(lax.convert_element_type(leaf, common_dtype))

  flat = jnp.concatenate([_flatten_one(leaf) for leaf in lst])
  return flat, unravel
def testDtypeFromString(self, dtype):
  """``dtypes.dtype`` parses ``str(dtype)`` back into an equal dtype."""
  dtype_name = str(dtype)
  parsed = dtypes.dtype(dtype_name)
  self.assertEqual(parsed, dtype)
def __init__(self, dtype: DType):
  # Store the element dtype after normalizing it through dtypes.dtype.
  normalized = dtypes.dtype(dtype)
  self._dtype = normalized
def is_float(c):
  """Return True when ``c`` has an inexact (floating or complex) dtype."""
  c_dtype = dtypes.dtype(c)
  return dtypes.issubdtype(c_dtype, jnp.inexact)