コード例 #1
0
ファイル: jax2tex.py プロジェクト: xjdrlabs/google-research
 def read_shaped(v):
     """Return the abstract ``ShapedArray`` for a jaxpr variable or literal.

     Python scalar literals (``bool``/``int``/``float``/``complex``, including
     numpy scalar subclasses of those) are mapped to a rank-0 array typed by
     the literal's Python type. Other literals must expose ``shape`` and
     ``dtype``. Non-literal variables are looked up in ``abstract``, which is
     defined outside this view (NOTE(review): confirm its contents).
     """
     if isinstance(v, core.Literal):
         val = v.val
         # Plain Python scalars have no .shape/.dtype, so they cannot take the
         # array path below; complex literals previously crashed there.
         if isinstance(val, (bool, int, float, complex)):
             return ShapedArray((), type(val))
         return ShapedArray(val.shape, val.dtype)
     return abstract[v]
コード例 #2
0
ファイル: ops.py プロジェクト: exoplanet-dev/exoplanet-core
def _contact_points_abstract_eval(*args):
    """Abstract evaluation rule for the contact-points primitive.

    Every input must be float64, and every input must share the first input's
    shape. The result is three abstract arrays of that shape: two float64
    outputs followed by one int32 output.

    Raises:
        ValueError: on a non-float64 input or a shape mismatch.
    """
    if not all(arg.dtype == np.float64 for arg in args):
        raise ValueError("float64 precision is required")
    ref_shape = args[0].shape
    for other in args[1:]:
        if other.shape != ref_shape:
            raise ValueError("Dimension mismatch")
    out_dtypes = (np.float64, np.float64, np.int32)
    return tuple(ShapedArray(ref_shape, dt) for dt in out_dtypes)
コード例 #3
0
def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for the symmetric/Hermitian eigendecomposition.

  Args:
    operand: abstract input; a ``ShapedArray`` must have shape ``[..., n, n]``.
    lower: not used for shape inference here.

  Returns:
    ``core.AbstractTuple((v, w))``. ``v`` holds the eigenvectors, with shape
    ``[..., n, n]`` and the operand dtype. ``w`` holds the eigenvalues, with
    shape ``[..., n]``. Eigenvalues of a Hermitian matrix are real, so ``w``
    uses the real counterpart of the operand dtype (complex64 -> float32).
    A non-``ShapedArray`` operand is used for both outputs.

  Raises:
    ValueError: if the operand is not a batch of square matrices.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n]")

    import numpy as np  # only needed to derive the real eigenvalue dtype

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    # abs() maps a complex dtype to its real base type and leaves real
    # dtypes unchanged.
    real_dtype = np.abs(np.zeros((), operand.dtype)).dtype
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), real_dtype)
  else:
    v, w = operand, operand
  return core.AbstractTuple((v, w))
コード例 #4
0
ファイル: lax_linalg.py プロジェクト: zhongwen/jax
def eig_abstract_eval(operand):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  The eigenvalues and eigenvectors of a general real matrix are complex in
  general. All three outputs therefore use a complex dtype: the operand dtype
  promoted with complex64 (float32 -> complex64, float64 -> complex128;
  complex dtypes are unchanged).

  Returns:
    ``core.AbstractTuple((w, vl, vr))``. ``w`` has shape ``[..., n]``, and
    ``vl``/``vr`` have shape ``[..., n, n]``. A non-``ShapedArray`` operand is
    used for all three outputs.

  Raises:
    ValueError: if the operand is not a batch of square matrices.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                       "shape [..., n, n], got shape {}".format(operand.shape))

    import numpy as np  # only needed to derive the complex output dtype

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    complex_dtype = np.promote_types(operand.dtype, np.complex64)
    vl = vr = ShapedArray(batch_dims + (n, n), complex_dtype)
    w = ShapedArray(batch_dims + (n,), complex_dtype)
  else:
    w = vl = vr = operand
  return core.AbstractTuple((w, vl, vr))
コード例 #5
0
def eigh_abstract_eval(operand, lower):
  """Abstract evaluation rule for the symmetric/Hermitian eigendecomposition.

  Args:
    operand: abstract input; a ``ShapedArray`` must have shape ``[..., n, n]``.
    lower: not used for shape inference here.

  Returns:
    ``(v, w)``. ``v`` holds the eigenvectors, with shape ``[..., n, n]`` and
    the operand dtype. ``w`` holds the eigenvalues, with shape ``[..., n]`` and
    dtype ``lax.lax._complex_basetype(operand.dtype)``. A non-``ShapedArray``
    operand is used for both outputs.

  Raises:
    ValueError: if the operand is not a batch of square matrices.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
      raise ValueError(
        "Argument to symmetric eigendecomposition must have shape [..., n, n], "
        "got shape {}".format(operand.shape))

    batch_dims = operand.shape[:-2]
    n = operand.shape[-1]
    v = ShapedArray(batch_dims + (n, n), operand.dtype)
    w = ShapedArray(batch_dims + (n,), lax.lax._complex_basetype(operand.dtype))
  else:
    v, w = operand, operand
  return v, w
コード例 #6
0
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for the singular value decomposition.

  For an operand of shape ``[..., m, n]`` with ``k = min(m, n)``, returns
  ``(s, u, vt)``:

  - ``s`` has shape ``[..., k]`` and dtype
    ``lax.lax._complex_basetype(operand.dtype)``.
  - ``u`` has shape ``[..., m, m]`` if ``full_matrices``, else ``[..., m, k]``.
  - ``vt`` has shape ``[..., n, n]`` if ``full_matrices``, else ``[..., k, n]``.

  ``compute_uv`` does not affect the result here.

  Raises:
    NotImplementedError: if the operand is not a ``ShapedArray``.
    ValueError: if the operand has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError
  if operand.ndim < 2:
    raise ValueError("Argument to singular value decomposition must have ndims >= 2")

  batch_dims, (m, n) = operand.shape[:-2], operand.shape[-2:]
  k = min(m, n)
  u_cols = m if full_matrices else k
  vt_rows = n if full_matrices else k
  s = ShapedArray(batch_dims + (k,), lax.lax._complex_basetype(operand.dtype))
  u = ShapedArray(batch_dims + (m, u_cols), operand.dtype)
  vt = ShapedArray(batch_dims + (vt_rows, n), operand.dtype)
  return s, u, vt
コード例 #7
0
ファイル: linalg.py プロジェクト: varun-alla/jax
def _lu_abstract_eval(operand):
  """Abstract evaluation rule for the LU decomposition.

  Returns ``(lu, pivot, perm)``. The first output has the same abstract value
  as ``operand``. For an operand of shape ``[..., m, n]``, ``pivot`` has shape
  ``[..., min(m, n)]`` and ``perm`` has shape ``[..., m]``, both int32.
  A non-``ShapedArray`` operand is returned for all three outputs.

  Raises:
    ValueError: if a ``ShapedArray`` operand has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return operand, operand, operand
  if operand.ndim < 2:
    raise ValueError("Argument to LU decomposition must have ndims >= 2")

  batch_dims = operand.shape[:-2]
  rows, cols = operand.shape[-2:]
  pivot = ShapedArray(batch_dims + (min(rows, cols),), jnp.int32)
  perm = ShapedArray(batch_dims + (rows,), jnp.int32)
  return operand, pivot, perm
コード例 #8
0
ファイル: linalg.py プロジェクト: varun-alla/jax
def qr_abstract_eval(operand, full_matrices):
  """Abstract evaluation rule for the QR decomposition.

  For an operand of shape ``[..., m, n]``, let ``k`` be ``m`` when
  ``full_matrices`` is set and ``min(m, n)`` otherwise. The result is
  ``(q, r)``, where ``q`` has shape ``[..., m, k]`` and ``r`` has shape
  ``[..., k, n]``, both in the operand dtype. A non-``ShapedArray`` operand is
  returned for both outputs.

  Raises:
    ValueError: if a ``ShapedArray`` operand has fewer than two dimensions.
  """
  if not isinstance(operand, ShapedArray):
    return operand, operand
  if operand.ndim < 2:
    raise ValueError("Argument to QR decomposition must have ndims >= 2")

  batch_dims = operand.shape[:-2]
  rows, cols = operand.shape[-2:]
  inner = rows if full_matrices else min(rows, cols)
  q = ShapedArray(batch_dims + (rows, inner), operand.dtype)
  r = ShapedArray(batch_dims + (inner, cols), operand.dtype)
  return q, r
コード例 #9
0
def eig_abstract_eval(operand):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  Outputs are complex: complex64 when the operand's float type is 32 bits
  wide, and complex128 otherwise. Either is passed through
  ``xb.canonicalize_dtype``. Returns ``(w, vl, vr)``, where ``w`` has shape
  ``[..., n]`` and ``vl``/``vr`` have shape ``[..., n, n]``.

  Raises:
    NotImplementedError: if the operand is not a ``ShapedArray``.
    ValueError: if the operand is not a batch of square matrices.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError
  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(operand.shape))

  batch_dims, n = operand.shape[:-2], operand.shape[-1]
  if onp.finfo(operand.dtype).bits == 32:
    complex_dtype = onp.complex64
  else:
    complex_dtype = onp.complex128
  complex_dtype = xb.canonicalize_dtype(complex_dtype)
  vl = vr = ShapedArray(batch_dims + (n, n), complex_dtype)
  w = ShapedArray(batch_dims + (n,), complex_dtype)
  return w, vl, vr
コード例 #10
0
def svd_abstract_eval(operand, full_matrices, compute_uv):
  """Abstract evaluation rule for the singular value decomposition.

  For a ``ShapedArray`` operand of shape ``[..., m, n]`` with
  ``k = min(m, n)``:

  - ``s`` has shape ``[..., k]``.
  - ``u`` has shape ``[..., m, m]`` if ``full_matrices``, else ``[..., m, k]``.
  - ``vt`` has shape ``[..., n, n]`` if ``full_matrices``, else ``[..., k, n]``.

  Singular values are real, so ``s`` uses the real counterpart of the operand
  dtype (complex64 -> float32); ``u`` and ``vt`` keep the operand dtype.
  ``compute_uv`` does not affect the result here. A non-``ShapedArray``
  operand is used for all three outputs.

  Returns:
    ``core.AbstractTuple((s, u, vt))``.

  Raises:
    ValueError: if the operand has fewer than two dimensions.
  """
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    import numpy as np  # only needed to derive the real dtype of `s`

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = min(m, n)
    # abs() maps a complex dtype to its real base type and leaves real
    # dtypes unchanged.
    real_dtype = np.abs(np.zeros((), operand.dtype)).dtype
    s = ShapedArray(batch_dims + (k,), real_dtype)
    u = ShapedArray(batch_dims + (m, m if full_matrices else k), operand.dtype)
    vt = ShapedArray(batch_dims + (n if full_matrices else k, n), operand.dtype)
  else:
    s = operand
    u = operand
    vt = operand
  return core.AbstractTuple((s, u, vt))
コード例 #11
0
ファイル: ops.py プロジェクト: exoplanet-dev/exoplanet-core
def _kepler_abstract_eval(M, ecc):
    """Abstract evaluation rule for the Kepler solver primitive.

    ``M`` and ``ecc`` must both be float64 and have identical shapes. The
    result is a pair of float64 avals with that shape; both entries are the
    same ``ShapedArray`` object.

    Raises:
        ValueError: on a non-float64 input or a shape mismatch.
    """
    for arr in (M, ecc):
        if arr.dtype != np.float64:
            raise ValueError("float64 precision is required")
    if M.shape != ecc.shape:
        raise ValueError("Dimension mismatch")
    result = ShapedArray(M.shape, np.float64)
    return (result,) * 2
コード例 #12
0
def _psum_translation_rule(c, *args, replica_groups=None, platform=None):
  """XLA translation rule for ``psum`` over any number of operands.

  Args:
    c: XLA builder that the ops are emitted into.
    *args: XLA operands to be summed across replicas.
    replica_groups: replica grouping, converted to protos with
      ``xc.make_replica_groups``.
    platform: target platform name. On "cpu" and "tpu" the work is delegated
      to ``_notuple_psum_translation_rule``.

  Returns:
    An XLA tuple holding the summed operands in the original argument order.
  """
  if platform in ("cpu", "tpu"):
    return _notuple_psum_translation_rule(c, *args, replica_groups=replica_groups)

  # XLA's tuple all-reduce doesn't support different dtypes in the same
  # all-reduce. Instead, we perform one all-reduce for each argument dtype.
  # Each dtype maps to (positions in `args`, operands with that dtype).
  args_by_type = collections.defaultdict(lambda: ([], []))
  for i, arg in enumerate(args):
    indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
    indices.append(i)
    dtype_args.append(arg)

  # The outputs, in the original argument order.
  out = [None] * len(args)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  # NOTE(review): numpy dtype `<` reflects safe-castability, not a total order,
  # so this sort does not strictly order dtypes; confirm the resulting
  # emission order is what is intended.
  for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
    is_complex = dtypes.issubdtype(dtype, onp.complexfloating)
    # Number of original operands of this dtype (before any real/imag split).
    n = len(dtype_args)
    if is_complex:
      # Reduce real and imaginary parts as separate real operands:
      # tuple positions [0, n) hold real parts and [n, 2n) imaginary parts.
      dtype_args = ([xops.Real(x) for x in dtype_args] +
                    [xops.Imag(x) for x in dtype_args])
    # Scalar add reducer at the (possibly real-valued) element dtype.
    scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
    computation = xla.primitive_subcomputation(lax.add_p, scalar, scalar)
    all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                replica_groups_protos, None, None)
    if is_complex:
      # Recombine each reduced real/imag pair into a complex value.
      xs = [xops.Complex(xops.GetTupleElement(all_reduce, i),
                         xops.GetTupleElement(all_reduce, n + i)) for i in range(n)]
    else:
      xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
    # Scatter results back to their original argument positions.
    for i, x in zip(indices, xs):
      out[i] = x
  return xops.Tuple(c, out)
コード例 #13
0
ファイル: ops.py プロジェクト: exoplanet-dev/exoplanet-core
def _quad_solution_vector_abstract_eval(b, r):
    """Abstract evaluation rule for the quadratic solution-vector primitive.

    ``b`` and ``r`` must both be float64 and have identical shapes. The result
    is three float64 avals of shape ``b.shape + (3,)``; all three entries are
    the same ``ShapedArray`` object.

    Raises:
        ValueError: on a non-float64 input or a shape mismatch.
    """
    if any(x.dtype != np.float64 for x in (b, r)):
        raise ValueError("float64 precision is required")
    if b.shape != r.shape:
        raise ValueError("Dimension mismatch")
    vector_aval = ShapedArray((*b.shape, 3), np.float64)
    return (vector_aval,) * 3
コード例 #14
0
ファイル: lax_parallel.py プロジェクト: wig-l/jax
def _allreduce_translation_rule(prim, c, val, replica_groups, backend=None):
    """Lower an all-reduce of ``val`` that uses ``prim`` as the reducer.

    A scalar reducer subcomputation for ``prim`` is built at ``val``'s element
    dtype (for ``backend``) and passed to the builder's ``AllReduce`` along
    with ``replica_groups``.
    """
    elem_aval = ShapedArray((), c.GetShape(val).numpy_dtype())
    reducer = xla.primitive_subcomputation(
        prim, elem_aval, elem_aval, backend=backend)
    return c.AllReduce(val, reducer, replica_groups=replica_groups)
コード例 #15
0
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
    """Lower an all-reduce of ``val`` over ``axis_name`` using ``prim``.

    Replica groups come from ``_replica_groups`` for the named axis (and the
    optional ``axis_index_groups``). The reducer is a scalar subcomputation of
    ``prim`` at ``val``'s element dtype. ``platform`` is not used here.
    """
    groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    elem_aval = ShapedArray((), c.get_shape(val).numpy_dtype())
    reducer = xla.primitive_subcomputation(prim, elem_aval, elem_aval)
    group_protos = xc.make_replica_groups(groups)
    return xops.AllReduce(val, reducer, group_protos, None, None)
コード例 #16
0
ファイル: lax_fft.py プロジェクト: yonistack/jax
def fft_abstract_eval(x, fft_type, fft_lengths):
    """Abstract evaluation rule for FFT: output keeps ``x``'s shape and dtype.

    ``fft_type`` and ``fft_lengths`` do not affect the result here.

    Raises:
        TypeError: if ``x`` is not complex.
        NotImplementedError: if ``x`` is complex but not complex64.
    """
    if dtypes.issubdtype(x.dtype, onp.complexfloating):
        if x.dtype == onp.complex64:
            return ShapedArray(x.shape, x.dtype)
        msg = "FFT is only implemented for complex64 types, got {}."
        raise NotImplementedError(msg.format(x.dtype.name))
    raise TypeError("FFT requires complex inputs, got {}.".format(
        x.dtype.name))
コード例 #17
0
        def layer_abstract_eval(*avals):
            """Abstractly evaluate this layer's init-then-apply on ``avals``.

            A ``ShapedArray((2,), 'uint32')`` aval (presumably the PRNG key)
            is prepended to ``avals``. Both are then traced through a function
            that builds parameters with ``init_fun`` and immediately feeds
            them to ``apply_fun``.
            """
            def _init_then_apply(rng, *layer_inputs):
                return apply_fun(init_fun(rng, *layer_inputs), *layer_inputs)

            key_aval = ShapedArray((2, ), 'uint32')
            return pe.abstract_eval_fun(_init_then_apply, key_aval, *avals)
コード例 #18
0
ファイル: smap.py プロジェクト: berkonat/jax-md
def _grid_trace_shape(fn, *args, **kwargs):
    """Trace ``fn`` abstractly and return the shape of its output.

    Positional arguments that are ``np.ndarray`` instances are replaced by
    ``ShapedArray`` avals with the same shape and dtype. All other positional
    arguments, and every keyword argument, are forwarded unchanged.
    """
    abstract_args = [
        ShapedArray(tuple(a.shape), a.dtype) if isinstance(a, np.ndarray) else a
        for a in args
    ]
    return pe.abstract_eval_fun(fn, *abstract_args, **kwargs).shape
コード例 #19
0
def while_loop(cond_fun, body_fun, init_val):
    """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The type signature in brief is

    .. code-block:: haskell

      while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation::

      def while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
          val = body_fun(val)
        return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single XLA While HLO. That makes it useful for reducing compilation
    times for jit-compiled functions, since native Python loop constructs in an
    ``@jit`` function are unrolled, leading to large XLA computations.

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    Args:
      cond_fun: function of type ``a -> Bool``.
      body_fun: function of type ``a -> a``.
      init_val: value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns:
      The output from the final iteration of body_fun, of type ``a``.

    Raises:
      TypeError: if ``cond_fun`` does not return a boolean scalar, or if the
        pytree structure returned by ``body_fun`` differs from ``init_val``'s.
    """
    flat_init, in_tree = tree_flatten((init_val, ))
    in_avals = tuple(_map(_abstractify, flat_init))

    # Trace both user functions once against the abstract loop carry.
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
        cond_fun, in_tree, in_avals)
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
        body_fun, in_tree, in_avals)

    # cond_fun must produce exactly one boolean scalar.
    if not treedef_is_leaf(cond_tree):
        raise TypeError(
            "cond_fun must return a boolean scalar, but got pytree {}.".format(
                cond_tree))
    bool_scalar = [ShapedArray((), onp.bool_)]
    if cond_jaxpr.out_avals != bool_scalar:
        raise TypeError(
            "cond_fun must return a boolean scalar, but got output type(s) {}."
            .format(cond_jaxpr.out_avals))

    # body_fun must preserve the carry's pytree structure.
    carry_trees = treedef_children(in_tree)
    if carry_trees != [body_tree]:
        raise TypeError(
            "body_fun output pytree structure must match init_val, got {} and {}."
            .format(body_tree, carry_trees[0]))

    operands = [*cond_consts, *body_consts, *flat_init]
    outs = while_p.bind(*operands,
                        cond_nconsts=len(cond_consts),
                        cond_jaxpr=cond_jaxpr,
                        body_nconsts=len(body_consts),
                        body_jaxpr=body_jaxpr)
    return tree_unflatten(body_tree, outs)
コード例 #20
0
ファイル: lax_fft.py プロジェクト: yotarok/jax
def fft_abstract_eval(x, fft_type, fft_lengths):
    """Abstract evaluation rule for FFT.

    Assumes ``fft_lengths`` is a tuple covering the trailing transformed axes.

    - RFFT: the trailing axes become ``fft_lengths``, except the last, which
      becomes ``fft_lengths[-1] // 2 + 1``. The dtype is
      ``_complex_dtype(x.dtype)``.
    - IRFFT: the trailing axes become ``fft_lengths``. The dtype is
      ``_real_dtype(x.dtype)``.
    - Any other FFT type: shape and dtype are unchanged.
    """
    if fft_type == xla_client.FftType.RFFT:
        leading = x.shape[:-len(fft_lengths)]
        last = fft_lengths[-1] // 2 + 1
        out_shape = leading + fft_lengths[:-1] + (last, )
        out_dtype = _complex_dtype(x.dtype)
    elif fft_type == xla_client.FftType.IRFFT:
        out_shape = x.shape[:-len(fft_lengths)] + fft_lengths
        out_dtype = _real_dtype(x.dtype)
    else:
        out_shape, out_dtype = x.shape, x.dtype
    return ShapedArray(out_shape, out_dtype)
コード例 #21
0
ファイル: lax_parallel.py プロジェクト: trevorcai/jax
def _axis_index_bind(*, axis_name):
    """Bind ``axis_index`` for ``axis_name`` inside its pmap trace.

    Looks up the frame for ``axis_name`` in
    ``pxla._thread_local_state.dynamic_axis_env``. Returns a new unknown
    ``JaxprTracer`` of abstract type ``ShapedArray((), int32)`` on that frame's
    pmap trace. The tracer's recipe is an ``axis_index_p`` equation with no
    inputs.
    """
    axis_frame = pxla._thread_local_state.dynamic_axis_env[axis_name]
    tracer = pe.JaxprTracer(
        axis_frame.pmap_trace,
        pe.PartialVal.unknown(ShapedArray((), np.int32)),
        None)
    tracer.recipe = pe.new_eqn_recipe(
        [], [tracer], axis_index_p, dict(axis_name=axis_name),
        source_info_util.current())
    return tracer
コード例 #22
0
ファイル: linalg.py プロジェクト: varun-alla/jax
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
  """Abstract evaluation rule for the nonsymmetric eigendecomposition.

  The output dtype is complex: complex64 when the operand's float type is
  32 bits wide, and complex128 otherwise. Either is passed through
  ``dtypes.canonicalize_dtype``. The result tuple always starts with the
  eigenvalues ``w`` (shape ``[..., n]``). It then contains the left
  eigenvector aval if ``compute_left_eigenvectors`` and the right one if
  ``compute_right_eigenvectors``, each of shape ``[..., n, n]``.

  Raises:
    NotImplementedError: if the operand is not a ``ShapedArray``.
    ValueError: if the operand is not a batch of square matrices.
  """
  if not isinstance(operand, ShapedArray):
    raise NotImplementedError
  if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
    raise ValueError("Argument to nonsymmetric eigendecomposition must have "
                     "shape [..., n, n], got shape {}".format(operand.shape))

  batch_dims, n = operand.shape[:-2], operand.shape[-1]
  single_precision = dtypes.finfo(operand.dtype).bits == 32
  out_dtype = dtypes.canonicalize_dtype(
      np.complex64 if single_precision else np.complex128)
  eigvals = ShapedArray(batch_dims + (n,), out_dtype)
  eigvecs = ShapedArray(batch_dims + (n, n), out_dtype)
  requested = (compute_left_eigenvectors, compute_right_eigenvectors)
  return (eigvals,) + tuple(eigvecs for flag in requested if flag)
コード例 #23
0
def _abstract_eval(spec, *args):
    """Validate ``args`` against ``spec`` and build the abstract outputs.

    ``spec["get_dims"]`` receives the shape of every argument and returns the
    named dimensions (anything ``dict()`` accepts). Each entry of
    ``spec["inputs"]`` is paired positionally with ``args``. An entry's
    ``"shape"`` is a Python expression evaluated with those dimensions in
    scope, and its optional ``"dtype"`` defaults to float64. The result holds
    one ``ShapedArray`` per entry of ``spec["outputs"]``, followed by one per
    entry of ``spec["extra_outputs"]``.

    NOTE(review): shapes are computed with ``eval``; this is only safe while
    ``spec`` is trusted, developer-written data.

    Raises:
        ValueError: if an input's dtype or shape does not match ``spec``.
    """
    dims = spec["get_dims"](*(arg.shape for arg in args))

    def expected_shape(entry):
        return eval(entry["shape"], dict(dims))

    def expected_dtype(entry):
        return entry.get("dtype", np.float64)

    for entry, arg in zip(spec["inputs"], args):
        want_dtype = expected_dtype(entry)
        if arg.dtype != want_dtype:
            raise ValueError(
                f"Invalid dtype for '{entry['name']}'; "
                f"expected {want_dtype}, got {arg.dtype}")
        want_shape = expected_shape(entry)
        if arg.shape != want_shape:
            raise ValueError(f"Invalid shape for '{entry['name']}'; "
                             f"expected {want_shape}, got {arg.shape}")

    output_entries = spec["outputs"] + spec["extra_outputs"]
    return tuple(ShapedArray(expected_shape(entry), expected_dtype(entry))
                 for entry in output_entries)
コード例 #24
0
 def abstract_call(*inputs):
     """Trace ``self._init_and_apply`` abstractly on a key aval plus ``inputs``.

     A ``ShapedArray((2,), 'uint32')`` is prepended to ``inputs`` and the
     combined pytree is flattened. Each leaf is wrapped as
     ``PartialVal((aval, jc.unit))``, and the flattened function is traced with
     ``trace_to_jaxpr(..., instantiate=True)``. As a side effect, the output
     tree store returned by ``jax.flatten_fun_nokwargs`` is saved on
     ``self._cached_out_tree``.

     Returns:
       The first component of each traced output partial value.
     """
     rng_aval = ShapedArray((2, ), 'uint32')
     flat_args, tree_with_rng = jax.tree_flatten((rng_aval, ) + inputs)
     flat_fun, self._cached_out_tree = jax.flatten_fun_nokwargs(
         self._init_and_apply, tree_with_rng)
     partial_args = [PartialVal((aval, jc.unit)) for aval in flat_args]
     _, partial_outs, _ = trace_to_jaxpr(flat_fun, partial_args,
                                         instantiate=True)
     out_avals, _ = unzip2(partial_outs)
     return out_avals
コード例 #25
0
ファイル: energy.py プロジェクト: berkonat/jax-md
def _canonicalize_displacement_or_metric(displacement_or_metric):
    """Checks whether or not a displacement or metric was provided.

    ``displacement_or_metric`` is abstractly evaluated on two
    ``f32[1, dim]`` position sets, for ``dim`` in 0..3. At the first ``dim``
    that does not raise ``ValueError``:

    - if the output has rank 2, the function is returned unchanged;
    - otherwise it is wrapped with ``space.metric``.

    NOTE(review): this presumably assumes pairwise functions, where a metric
    yields rank 2 and a displacement rank 3 — confirm against callers.

    Raises:
        ValueError: if evaluation fails for every trial dimension.
    """
    for dim in range(4):
        try:
            R = ShapedArray((1, dim), f32)
            dR_or_dr = pe.abstract_eval_fun(displacement_or_metric, R, R, t=0)
            if len(dR_or_dr.shape) == 2:
                return displacement_or_metric
            else:
                return space.metric(displacement_or_metric)
        except ValueError:
            continue
    # The trial dimensions stop at 3, so report that bound (and keep the
    # adjacent string literals separated by a space).
    raise ValueError(
        'Canonicalize displacement not implemented for spatial dimension larger '
        'than 3.')
コード例 #26
0
ファイル: lax_control_flow.py プロジェクト: jonasrauber/jax
def _masked_scan_jaxpr(jaxpr, num_consts, num_carry):
  """Wrap a scan body ``jaxpr`` so its carry stops updating at a dynamic length.

  The returned typed jaxpr takes ``[dynamic_length, *consts, i, *carry, *xs]``
  and returns ``[i + 1, *new_carry, *ys]``. Each carry element takes the body's
  new value while ``i < dynamic_length`` and otherwise keeps its previous
  value (via ``lax.select``). ``ys`` are passed through from the body.
  ``dynamic_length`` and ``i`` are typed as int64 scalars.
  """
  body = core.jaxpr_as_fun(jaxpr)

  @lu.wrap_init
  def masked(*args):
    [dynamic_length], consts, [i], carry, xs = split_list(
        args, [1, num_consts, 1, num_carry])
    new_carry, ys = split_list(body(*consts, *carry, *xs), [num_carry])
    # The comparison is emitted once per carry element, as in the traced form.
    kept_carry = [lax.select(i < dynamic_length, updated, previous)
                  for updated, previous in zip(new_carry, carry)]
    return [i + 1, *kept_carry, *ys]

  index_aval = ShapedArray((), onp.int64)
  const_avals, carry_avals, x_avals = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  in_avals = [index_aval, *const_avals, index_aval, *carry_avals, *x_avals]
  return _make_typed_jaxpr(masked, in_avals)
コード例 #27
0
ファイル: space.py プロジェクト: zizai/jax-md
def canonicalize_displacement_or_metric(displacement_or_metric):
  """Checks whether or not a displacement or metric was provided.

  ``displacement_or_metric`` is abstractly evaluated (``eval_shape``) on two
  ``f32[dim]`` positions, for ``dim`` in 1, 2, 3. Trial dimensions whose
  evaluation raises ``TypeError`` or ``ValueError`` are skipped. At the first
  dimension that succeeds:

  - a scalar output means the function is treated as a metric and returned
    unchanged;
  - any other output means it is wrapped with ``metric``.

  Raises:
    ValueError: if evaluation fails for every trial dimension.
  """
  for dim in range(1, 4):
    R = ShapedArray((dim,), f32)
    try:
      dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
    except (TypeError, ValueError):
      # The function does not support this spatial dimension; try the next.
      continue
    if len(dR_or_dr.shape) == 0:
      return displacement_or_metric
    return metric(displacement_or_metric)
  # The trial dimensions stop at 3; keep the adjacent literals space-separated.
  raise ValueError(
    'Canonicalize displacement not implemented for spatial dimension larger '
    'than 3.')
コード例 #28
0
ファイル: pmap_test.py プロジェクト: ziyadedher/jax
    def testReshardInput(self):
        """pmap accepts a ShardedDeviceArray whose sharding differs from its own.

        Builds a (6, 4) ShardedDeviceArray from four (3, 2) shards laid out
        2x2 with replication factor 2. Then checks that pmap over it produces
        the expected values on six device buffers.
        """
        if xla_bridge.device_count() < 6:
            raise SkipTest("testReshardInput requires 6 devices")
        # Deliberately use a sharding that does not match what the subsequent
        # pmap expects, so the input has to be resharded.
        shard_shape = (3, 2)
        shard = np.arange(np.prod(shard_shape)).reshape(shard_shape)
        target_devices = xla_bridge.devices()[:4]
        buffers = [xla.device_put(shard, dev) for dev in target_devices]
        sharded = pxla.ShardedDeviceArray(
            ShapedArray((6, 4), shard.dtype),
            pxla.ShardingSpec(shards_per_axis=(2, 2),
                              is_axis_materialized=(True, True),
                              replication_factor=2),
            buffers)

        result = pmap(lambda x: x + 1)(sharded)
        self.assertAllClose(result, sharded + 1, check_dtypes=True)
        self.assertEqual(len(result.device_buffers), 6)
コード例 #29
0
ファイル: partition.py プロジェクト: ruofei7/jax-md
def _displacement_or_metric_to_metric_sq(displacement_or_metric):
  """Return a squared-distance function built from a displacement or metric.

  ``displacement_or_metric`` is abstractly evaluated (``eval_shape``) on two
  ``f32[dim]`` positions, for ``dim`` in 1, 2, 3. Trial dimensions whose
  evaluation raises ``TypeError`` or ``ValueError`` are skipped. At the first
  dimension that succeeds:

  - a scalar output means it is treated as a metric, and the returned function
    squares its result;
  - any other output means the returned function applies
    ``space.square_distance`` to its result.

  Raises:
    ValueError: if evaluation fails for every trial dimension.
  """
  for dim in range(1, 4):
    R = ShapedArray((dim,), f32)
    try:
      dR_or_dr = eval_shape(displacement_or_metric, R, R, t=0)
    except (TypeError, ValueError):
      # The function does not support this spatial dimension; try the next.
      continue
    if len(dR_or_dr.shape) == 0:
      return lambda Ra, Rb, **kwargs: \
        displacement_or_metric(Ra, Rb, **kwargs) ** 2
    return lambda Ra, Rb, **kwargs: space.square_distance(
      displacement_or_metric(Ra, Rb, **kwargs))
  # The trial dimensions stop at 3; keep the adjacent literals space-separated.
  raise ValueError(
    'Canonicalize displacement not implemented for spatial dimension larger '
    'than 3.')
コード例 #30
0
def omnistaging_disabler() -> None:
    """Reinstall the psum / axis_index rules used when omnistaging is disabled.

    Mutates global primitive state. It rebinds ``psum_p.bind`` and installs
    ``psum_p``'s impl and parallel pure rule. For ``axis_index_p`` it
    registers a custom bind, an abstract eval and an XLA translation rule.
    Returns None.
    """
    # NOTE(review): `axis_index` is declared global but not assigned in the
    # visible code — confirm whether it is rebound elsewhere or can be dropped.
    global axis_index

    psum_p.bind = partial(core.Primitive.bind, psum_p)
    psum_p.def_impl(partial(pxla.apply_parallel_primitive,
                            psum_p))  # type: ignore
    # Pure rule: a generator yielding each argument scaled by prod(shape).
    pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (
        x * prod(shape) for x in args)  # type: ignore

    def _axis_index_bind(*, axis_name):
        """Return an int32 scalar tracer for ``axis_index`` on ``axis_name``.

        The equation records ``nreps`` and the sizes of the axis frames up to
        and including ``axis_name``'s frame, for use by the translation rule.
        """
        dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
        frame = dynamic_axis_env[axis_name]
        # Sizes of all axis frames up to and including this one.
        sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame) + 1]
        nreps = dynamic_axis_env.nreps
        trace = frame.pmap_trace

        out_aval = ShapedArray((), np.int32)
        out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval),
                                    None)
        eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                                dict(nreps=nreps,
                                     sizes=sizes,
                                     axis_name=axis_name),
                                source_info_util.current())
        out_tracer.recipe = eqn

        return out_tracer

    def _axis_index_translation_rule(c, nreps, sizes, axis_name):
        """Compute (replica_id // (nreps // prod(sizes))) % sizes[-1] as int32."""
        # Arithmetic is done in uint32 to match ReplicaId, then converted.
        div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
        mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
        unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
        return xops.ConvertElementType(unsigned_index,
                                       xb.dtype_to_etype(np.int32))

    axis_index_p.def_custom_bind(_axis_index_bind)
    # axis_index always produces an int32 scalar.
    axis_index_p.def_abstract_eval(lambda *args, **params: ShapedArray(
        (), np.int32))
    xla.translations[axis_index_p] = _axis_index_translation_rule