示例#1
0
文件: ode.py 项目: tudorcebere/jax
def closure_convert(fun, in_tree, in_avals):
    """Trace ``fun`` to a jaxpr and hoist its inexact-dtype constants.

    ``fun`` is flattened against ``in_tree`` and traced on ``in_avals``. The
    constants captured by the resulting jaxpr are split in two: those with an
    inexact (floating-point or complex) dtype are returned so the caller can
    pass them explicitly, and all remaining constants stay closed over.

    Args:
      fun: callable taking ``(y, t, *args)``; the flattened structure of that
        argument tuple must match ``in_tree``.
      in_tree: pytree definition of the inputs to ``fun``.
      in_avals: abstract values of the flattened inputs used for tracing.

    Returns:
      A pair ``(converted_fun, hoisted_consts)``. ``converted_fun`` is called
      as ``converted_fun(y, t, *hoisted_consts, *args)`` and evaluates the
      traced jaxpr. ``hoisted_consts`` is the list of inexact-dtype constants
      that must be supplied as its leading extra arguments.
    """
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    if config.omnistaging_enabled:
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    else:
        # Legacy (pre-omnistaging) path: trace with every input unknown.
        in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
        with core.initial_style_staging():  # type: ignore
            jaxpr, _, consts = pe.trace_to_jaxpr(
                flat_fun, in_pvals, instantiate=True,
                stage_out=False)  # type: ignore
    # The output tree is only known once tracing has run the function.
    out_tree = out_tree_thunk()

    # Only constants we may differentiate with respect to should become
    # explicit arguments; an inexact dtype is used as the proxy for that.
    # TODO(mattjj): revise this approach
    def is_float(c):
        return dtypes.issubdtype(dtypes.dtype(c), jnp.inexact)

    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_hoisted = len(hoisted_consts)

    def converted_fun(y, t, *hconsts_and_args):
        hoisted, rest = split_list(hconsts_and_args, [num_hoisted])
        # Re-interleave both const groups into the jaxpr's original order.
        all_consts = merge(closure_consts, hoisted)
        flat_args, args_tree = tree_flatten((y, t, *rest))
        assert in_tree == args_tree
        flat_out = core.eval_jaxpr(jaxpr, all_consts, *flat_args)
        return tree_unflatten(out_tree, flat_out)

    return converted_fun, hoisted_consts
示例#2
0
def _array_xla_shape(aval: AbsArray):
  """Return the XLA shape used to represent the abstract array ``aval``.

  Dimensions are resolved before the shape is built:

  * ``BaseType`` element type: a dimension that is itself a rank-0
    ``AbsArray`` is replaced by the ``_bound`` of its element type.
  * ``BoundedIntTy`` element type: a ``BoundedInt`` dimension is replaced
    by its ``_bound``.

  ``_bound`` is presumably the static upper bound of a dynamically sized
  dimension (NOTE(review): confirm).

  Args:
    aval: the abstract array to lower.

  Returns:
    A 1-tuple holding an ``xla.xc.Shape`` array shape. ``BoundedIntTy``
    elements are represented with dtype ``int32``.

  Raises:
    NotImplementedError: if ``aval``'s element type is neither a
      ``BaseType`` nor a ``BoundedIntTy``.
  """
  elt_ty = aval._eltTy
  if isinstance(elt_ty, BaseType):
    dtype = elt_ty._dtype
    shape = [d._eltTy._bound if isinstance(d, AbsArray) and not d.shape
             else d for d in aval.shape]
    return (xla.xc.Shape.array_shape(dtype, shape),)
  if isinstance(elt_ty, BoundedIntTy):
    shape = [d._bound if isinstance(d, BoundedInt) else d for d in aval.shape]
    return (xla.xc.Shape.array_shape(dtypes.dtype('int32'), shape),)
  # Name the offending element type so the failure is diagnosable.
  raise NotImplementedError(
      f"_array_xla_shape: unsupported element type {elt_ty!r}")
示例#3
0
def psum(x, axis_name, *, axis_index_groups=None):
    """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

    If ``x`` is a pytree, the reduction is applied to every leaf of the tree.
    Leaves of boolean dtype are cast to ``int32`` before the reduction.

    Args:
      x: array(s) with a mapped axis named ``axis_name``.
      axis_name: hashable Python object used to name a pmapped axis (see the
        :func:`jax.pmap` documentation for more details). A tuple or list is
        passed through unchanged as a collection of axis names.
      axis_index_groups: optional list of lists containing axis indices
        (e.g. for an axis of size 4, ``[[0, 1], [2, 3]]`` would perform psums
        over the first two and last two replicas). Groups must cover all
        axis indices exactly once, and all groups must be the same size.

    Returns:
      Array(s) with the same shape as ``x`` representing the result of an
      all-reduce sum along the axis ``axis_name``.

    For example, with 4 XLA devices available:

    >>> x = np.arange(4)
    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [6 6 6 6]
    >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [ 0.          0.16666667  0.33333334  0.5       ]
    """
    if isinstance(axis_name, (tuple, list)):
        axis_names = axis_name
    else:
        axis_names = (axis_name,)
    _validate_axis_index_groups(axis_index_groups)
    leaves, treedef = tree_util.tree_flatten(x)

    def to_summable(leaf):
        # Booleans are reduced as int32 values, as documented above.
        if dtypes.dtype(leaf) == np.bool_:
            return lax.convert_element_type(leaf, np.int32)
        return leaf

    summed = psum_p.bind(*map(to_summable, leaves),
                         axis_name=axis_names,
                         axis_index_groups=axis_index_groups)
    return tree_util.tree_unflatten(treedef, summed)
示例#4
0
def _ravel_list(lst):
    """Concatenate a list of arrays into one flat 1-D array.

    Every input is cast to the common dtype given by ``dtypes.result_type``,
    raveled, and concatenated in list order.

    Args:
      lst: list of array-likes.

    Returns:
      A pair ``(raveled, unravel)``. ``raveled`` is the concatenated 1-D
      array. ``unravel`` splits a 1-D array with the same layout back into a
      list of arrays, restoring each input's original shape and dtype. An
      empty ``lst`` yields an empty float32 array and an ``unravel`` that
      returns ``[]``.
    """
    if not lst:
        return jnp.array([], jnp.float32), lambda _: []

    orig_dtypes = [dtypes.dtype(leaf) for leaf in lst]
    common_dtype = dtypes.result_type(*orig_dtypes)
    sizes, shapes = unzip2((jnp.size(leaf), jnp.shape(leaf)) for leaf in lst)
    # Running offsets of each chunk end; the last one (the total length)
    # is not a split point.
    split_points = np.cumsum(sizes)[:-1]

    def unravel(arr):
        pieces = jnp.split(arr, split_points)
        with warnings.catch_warnings():
            # Casting back may go complex -> real; silence that warning.
            warnings.simplefilter("ignore")
            restored = []
            for piece, shape, dtype in zip(pieces, shapes, orig_dtypes):
                reshaped = piece.reshape(shape)
                restored.append(lax.convert_element_type(reshaped, dtype))
            return restored

    flat_pieces = [
        jnp.ravel(lax.convert_element_type(leaf, common_dtype))
        for leaf in lst
    ]
    return jnp.concatenate(flat_pieces), unravel
示例#5
0
 def testDtypeFromString(self, dtype):
     """``dtypes.dtype`` should round-trip ``dtype`` through its string name."""
     dtype_name = str(dtype)
     self.assertEqual(dtypes.dtype(dtype_name), dtype)
示例#6
0
 def __init__(self, dtype: DType):
   """Store ``dtype`` after normalizing it with ``dtypes.dtype``."""
   normalized = dtypes.dtype(dtype)
   self._dtype = normalized
示例#7
0
 def is_float(c):
     """Return whether ``c``'s dtype is inexact (floating point or complex)."""
     c_dtype = dtypes.dtype(c)
     return dtypes.issubdtype(c_dtype, jnp.inexact)