Example #1
def closure_convert(fun, in_tree, in_avals):
    """Trace ``fun`` to a jaxpr and hoist its float-dtype constants.

    ``fun`` is flattened against ``in_tree`` and traced with ``in_avals``. The
    constants captured by the trace are split in two groups: those with an
    inexact (floating/complex) dtype are returned to the caller so they can be
    passed back in explicitly, the rest stay closed over.

    Args:
      fun: the callable to trace.
      in_tree: pytree structure of the arguments ``(y, t, *args)`` that the
        converted function will be called with.
      in_avals: abstract values of the flattened inputs used for tracing.

    Returns:
      A pair ``(converted_fun, hoisted_consts)``. ``converted_fun`` has the
      signature ``converted_fun(y, t, *hconsts_args)`` where the first
      ``len(hoisted_consts)`` entries of ``hconsts_args`` are the hoisted
      constants and the remainder are the original extra arguments.
    """
    if config.omnistaging_enabled:
        wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                     in_tree)
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(
            wrapped_fun, in_avals)
    else:
        # Legacy (non-omnistaging) tracing path.
        in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
        wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun),
                                                     in_tree)
        with core.initial_style_staging():  # type: ignore
            jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
                wrapped_fun, in_pvals, instantiate=True,
                stage_out=False)  # type: ignore
    # The output tree is only available once tracing has run.
    out_tree = out_tree()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(mattjj): revise this approach
    def is_float(c):
        return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)

    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_consts = len(hoisted_consts)

    def converted_fun(y, t, *hconsts_args):
        # Leading entries are the hoisted constants, the rest are user args.
        hconsts, args = split_list(hconsts_args, [num_consts])
        # Re-interleave closed-over and hoisted constants in the jaxpr's order.
        all_consts = merge(closure_consts, hconsts)
        all_args, in_tree2 = tree_flatten((y, t, *args))
        assert in_tree == in_tree2
        out_flat = core.eval_jaxpr(jaxpr, all_consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, hoisted_consts
Example #2
def _array_xla_shape(aval: AbsArray):
  """Build the XLA array shape for an abstract array.

  Args:
    aval: abstract array whose ``_eltTy`` selects the element dtype and whose
      ``shape`` entries may themselves be abstract values.

  Returns:
    A one-element tuple holding the ``xla.xc.Shape`` array shape.

  Raises:
    NotImplementedError: if ``aval._eltTy`` is neither a ``BaseType`` nor a
      ``BoundedIntTy``.
  """
  if isinstance(aval._eltTy, BaseType):
    dtype = aval._eltTy._dtype
    # A dimension given as a shape-less AbsArray is replaced by the bound of
    # its element type; any other dimension is used as-is.
    shape = [d._eltTy._bound if isinstance(d, AbsArray) and not d.shape
             else d for d in aval.shape]
    return (xla.xc.Shape.array_shape(dtype, shape),)
  elif isinstance(aval._eltTy, BoundedIntTy):
    # BoundedInt dimensions are replaced by their bound; elements are int32.
    shape = [d._bound if isinstance(d, BoundedInt) else d for d in aval.shape]
    return (xla.xc.Shape.array_shape(dtypes.dtype('int32'), shape),)
  else:
    # Previously unreachable (placed after a return), which made unsupported
    # element types silently return None.
    raise NotImplementedError
Example #3
def psum(x, axis_name, *, axis_index_groups=None):
    """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

    If ``x`` is a pytree then the result is equivalent to mapping this function
    to each leaf in the tree.

    Inputs of boolean dtype are converted to integers before the reduction.

    Args:
      x: array(s) with a mapped axis named ``axis_name``.
      axis_name: hashable Python object used to name a pmapped axis (see the
        :func:`jax.pmap` documentation for more details).
      axis_index_groups: optional list of lists containing axis indices (e.g.
        for an axis of size 4, [[0, 1], [2, 3]] would perform psums over the
        first two and last two replicas). Groups must cover all axis indices
        exactly once, and all groups must be the same size.

    Returns:
      Array(s) with the same shape as ``x`` representing the result of an
      all-reduce sum along the axis ``axis_name``.

    For example, with 4 XLA devices available:

    >>> x = np.arange(4)
    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [6 6 6 6]
    >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [ 0.          0.16666667  0.33333334  0.5       ]
    """
    # A single axis name is wrapped so the primitive always gets a sequence.
    if not isinstance(axis_name, (tuple, list)):
        axis_name = (axis_name, )
    leaves, treedef = tree_util.tree_flatten(x)
    # Boolean leaves are summed as int32.
    leaves = [
        lax.convert_element_type(l, np.int32)
        if dtypes.dtype(l) == np.bool_ else l for l in leaves
    ]
    out_flat = psum_p.bind(*leaves,
                           axis_name=axis_name,
                           axis_index_groups=axis_index_groups)
    return tree_util.tree_unflatten(treedef, out_flat)
Example #4
def _ravel_list(lst):
    """Flatten a list of arrays into one 1-D array plus an inverse function.

    Every element is cast to the common result dtype of all elements, raveled,
    and concatenated.

    Args:
      lst: list of array-like values.

    Returns:
      A pair ``(raveled, unravel)``. ``raveled`` is the concatenated 1-D array
      (an empty float32 array when ``lst`` is empty). ``unravel`` takes such a
      1-D array and returns a list of arrays with the original shapes and
      dtypes.
    """
    if not lst:
        return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = [dtypes.dtype(l) for l in lst]
    to_dtype = dtypes.result_type(*from_dtypes)
    sizes, shapes = unzip2((jnp.size(x), jnp.shape(x)) for x in lst)
    # Cumulative sizes give the split points; the last one is the total length.
    indices = np.cumsum(sizes)

    def unravel(arr):
        chunks = jnp.split(arr, indices[:-1])
        with warnings.catch_warnings():
            warnings.simplefilter(
                "ignore")  # ignore complex-to-real cast warning
            return [
                lax.convert_element_type(chunk.reshape(shape), dtype)
                for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)
            ]

    ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
    raveled = jnp.concatenate([ravel(e) for e in lst])
    return raveled, unravel
Example #5
 def testDtypeFromString(self, dtype):
     """Check that ``dtypes.dtype`` applied to ``str(dtype)`` equals ``dtype``."""
     round_tripped = dtypes.dtype(str(dtype))
     self.assertEqual(round_tripped, dtype)
Example #6
 def __init__(self, dtype: DType):
   """Store the result of ``dtypes.dtype(dtype)`` on the instance."""
   normalized = dtypes.dtype(dtype)
   self._dtype = normalized
Example #7
 def is_float(c):
     """Return whether the dtype of ``c`` is a subtype of ``jnp.inexact``."""
     value_dtype = dtypes.dtype(c)
     return dtypes.issubdtype(value_dtype, jnp.inexact)