Example #1
def eig_abstract_eval(operand, *, compute_left_eigenvectors,
                      compute_right_eigenvectors):
    if isinstance(operand, ShapedArray):
        if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
            raise ValueError(
                "Argument to nonsymmetric eigendecomposition must have "
                "shape [..., n, n], got shape {}".format(operand.shape))

        batch_dims = operand.shape[:-2]
        n = operand.shape[-1]
        dtype = np.complex64 if dtypes.finfo(
            operand.dtype).bits == 32 else np.complex128
        dtype = dtypes.canonicalize_dtype(dtype)
        vl = vr = ShapedArray(batch_dims + (n, n), dtype)
        w = ShapedArray(batch_dims + (n, ), dtype)
    else:
        raise NotImplementedError

    output = [w]
    if compute_left_eigenvectors:
        output.append(vl)
    if compute_right_eigenvectors:
        output.append(vr)

    return tuple(output)
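
A rule like this is typically attached to a primitive with def_abstract_eval so that tracing can derive output avals without running the decomposition. A minimal registration sketch, assuming a hypothetical my_eig_p primitive (not JAX's actual eig_p setup):

from jax import core

# Hypothetical primitive, used only to illustrate how the rule is registered.
my_eig_p = core.Primitive('my_eig')
my_eig_p.multiple_results = True  # the rule above returns a tuple of avals
my_eig_p.def_abstract_eval(eig_abstract_eval)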
Example #2
File: linalg.py  Project: tataudat/jax
def _lu_abstract_eval(operand):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to LU decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    pivot = ShapedArray(batch_dims + (min(m, n),), jnp.int32)
    perm = ShapedArray(batch_dims + (m,), jnp.int32)
  else:
    pivot = operand
    perm = operand
  return operand, pivot, perm
Example #3
File: linalg.py  Project: tataudat/jax
def qr_abstract_eval(operand, full_matrices):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to QR decomposition must have ndims >= 2")
    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    k = m if full_matrices else min(m, n)
    q = ShapedArray(batch_dims + (m, k), operand.dtype)
    r = ShapedArray(batch_dims + (k, n), operand.dtype)
  else:
    q = operand
    r = operand
  return q, r
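
An illustrative call showing the resulting shapes (the input aval is made up; assumes ShapedArray is importable as in the other snippets):

import numpy as np
from jax.core import ShapedArray

a = ShapedArray((4, 5, 3), np.float32)  # a batch of four 5x3 matrices
q, r = qr_abstract_eval(a, full_matrices=False)
# q is ShapedArray((4, 5, 3), float32); r is ShapedArray((4, 3, 3), float32)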
Example #4
def eigh_abstract_eval(operand, lower):
    if isinstance(operand, ShapedArray):
        if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
            raise ValueError(
                "Argument to symmetric eigendecomposition must have shape [..., n, n],"
                "got shape {}".format(operand.shape))

        batch_dims = operand.shape[:-2]
        n = operand.shape[-1]
        v = ShapedArray(batch_dims + (n, n), operand.dtype)
        w = ShapedArray(batch_dims + (n, ),
                        lax_internal._complex_basetype(operand.dtype))
    else:
        v, w = operand, operand
    return v, w
Example #5
 def all_reduce(x):
     replica_groups_protos = xc.make_replica_groups(
         _replica_groups(axis_env, axis_name, axis_index_groups))
     scalar = ShapedArray((), c.get_shape(x).numpy_dtype())
     computation = xla.primitive_subcomputation(prim, scalar, scalar)
     return xops.AllReduce(x, computation, replica_groups_protos, None,
                           None)
Example #6
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis,
                              axis_index_groups):
    input_aval = raise_to_shaped(x)
    shape = list(input_aval.shape)
    size = shape.pop(split_axis)
    shape.insert(concat_axis, size)
    return ShapedArray(tuple(shape), input_aval.dtype, weak_type=False)
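
An illustrative call (made-up aval and axis name; assumes raise_to_shaped is in scope as in the snippet): the split axis is removed and its size re-inserted at the concat position.

import numpy as np
from jax.core import ShapedArray

aval = _all_to_all_abstract_eval(
    ShapedArray((8, 3, 5), np.float32), axis_name='i',
    split_axis=0, concat_axis=1, axis_index_groups=None)
# aval is ShapedArray((3, 8, 5), float32)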
Example #7
def omnistaging_disabler() -> None:
  global axis_index

  psum_p.bind = partial(core.Primitive.bind, psum_p)  # type: ignore
  psum_p.def_impl(partial(pxla.apply_parallel_primitive, psum_p))  # type: ignore
  pxla.parallel_pure_rules[psum_p] = lambda *args, shape: (x * prod(shape) for x in args)  # type: ignore

  def _axis_index_bind(*, axis_name):
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer

  def _axis_index_translation_rule(c, nreps, sizes, axis_name):
    div = xb.constant(c, np.array(nreps // prod(sizes), dtype=np.uint32))
    mod = xb.constant(c, np.array(sizes[-1], dtype=np.uint32))
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

  axis_index_p.def_custom_bind(_axis_index_bind)
  axis_index_p.def_abstract_eval(
      lambda *args, **params: ShapedArray((), np.int32))
  xla.translations[axis_index_p] = _axis_index_translation_rule
Example #8
def _allreduce_translation_rule(prim, c, val, *, axis_name, axis_index_groups,
                                axis_env, platform):
  replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
  dtype = c.get_shape(val).numpy_dtype()
  scalar = ShapedArray((), dtype)
  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  replica_groups_protos = xc.make_replica_groups(replica_groups)
  return xops.AllReduce(val, computation, replica_groups_protos, None, None)
Example #9
File: linalg.py  Project: tataudat/jax
def svd_abstract_eval(operand, full_matrices, compute_uv):
  if isinstance(operand, ShapedArray):
    if operand.ndim < 2:
      raise ValueError("Argument to singular value decomposition must have ndims >= 2")

    batch_dims = operand.shape[:-2]
    m = operand.shape[-2]
    n = operand.shape[-1]
    s = ShapedArray(batch_dims + (min(m, n),),
                    lax_internal._complex_basetype(operand.dtype))
    if compute_uv:
      u = ShapedArray(batch_dims + (m, m if full_matrices else min(m, n)), operand.dtype)
      vt = ShapedArray(batch_dims + (n if full_matrices else min(m, n), n), operand.dtype)
      return s, u, vt
    else:
      return s,
  else:
    raise NotImplementedError
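
An illustrative call showing the resulting avals (made-up input; assumes the module's lax_internal helper is in scope as in the snippet):

import numpy as np
from jax.core import ShapedArray

a = ShapedArray((6, 4), np.float32)
s, u, vt = svd_abstract_eval(a, full_matrices=False, compute_uv=True)
# s: ShapedArray((4,), float32), u: ShapedArray((6, 4), float32),
# vt: ShapedArray((4, 4), float32)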
Example #10
def get_structure(eqn: Optional[JaxprEqn], invals: List[Union[ShapedArray,
                                                              AbstractValue]],
                  idx: int, _s_rules: bool) -> Structure:
    if any(i is AbstractValue for i in invals):
        raise TypeError(invals)

    if eqn is None:
        # Identity function
        primitive = None
        cts_in = invals[0]
        assert idx == 0

    else:
        if len(eqn.outvars) != 1:
            raise NotImplementedError(eqn)
        cts_in = eqn.outvars[0].aval

        primitive = eqn.primitive
        assert len(invals) == len(eqn.invars)
        assert 0 <= idx < len(eqn.invars)

    if not isinstance(cts_in, ShapedArray):
        raise TypeError(cts_in)

    if primitive in STRUCTURE_RULES and _s_rules:
        structure = STRUCTURE_RULES[primitive](eqn, idx, invals, cts_in)

    else:
        # No simplification rule found.
        structure = Structure()

    # TODO(romann): can we avoid special-casing `reshape`s?
    if primitive == lax.reshape_p:
        cts_in = ShapedArray(invals[idx].shape, invals[idx].dtype)

    # Check that number of trace output and input axes match.
    assert len(structure.in_trace) == len(structure.out_trace)

    # Check that input and output traced sizes are the same.
    out_trace_size = utils.size_at(cts_in, structure.out_trace)
    in_trace_size = utils.size_at(invals[idx], structure.in_trace)
    assert in_trace_size == out_trace_size

    # Check that number of input/output diagonal axes match.
    assert len(structure.out_diagonal) == len(structure.in_diagonal)

    # Check that for each output diagonal axis, every input axis is either
    # `None` or has the matching size; the axis for `invals[idx]` must not be
    # `None`.
    for out_d, in_d in zip(structure.out_diagonal, structure.in_diagonal):
        assert len(in_d) == len(invals)
        assert in_d[idx] is not None
        for ix, i in enumerate(in_d):
            if i is not None:
                assert invals[ix].shape[i] == cts_in.shape[out_d]

    return structure
Example #11
File: fft.py  Project: tomcharnock/jax
def fft_abstract_eval(x, fft_type, fft_lengths):
    if fft_type == xla_client.FftType.RFFT:
        shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1] +
                 (fft_lengths[-1] // 2 + 1, ))
        dtype = _complex_dtype(x.dtype)
    elif fft_type == xla_client.FftType.IRFFT:
        shape = x.shape[:-len(fft_lengths)] + fft_lengths
        dtype = _real_dtype(x.dtype)
    else:
        shape = x.shape
        dtype = x.dtype
    return ShapedArray(shape, dtype)
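
An illustrative call for the RFFT case (made-up input; assumes xla_client and the module's _complex_dtype helper are in scope as in fft.py):

import numpy as np
from jax.core import ShapedArray
from jax.lib import xla_client

x = ShapedArray((8, 16), np.float32)
out = fft_abstract_eval(x, xla_client.FftType.RFFT, fft_lengths=(16,))
# out is ShapedArray((8, 9), complex64): the last axis becomes 16 // 2 + 1 = 9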
Example #12
def _reduce_window_sum_translation_rule(ctx, avals_in, avals_out, operand, *,
                                        window_dimensions, window_strides,
                                        padding, base_dilation,
                                        window_dilation):
    operand_aval, = avals_in
    scalar = ShapedArray((), operand_aval.dtype)
    return [
        xops.ReduceWindowWithGeneralPadding(
            operand,
            xla.pyval_to_ir_constant(ctx.builder,
                                     np.array(0, operand_aval.dtype)),
            xla.primitive_subcomputation(ctx.platform, ctx.axis_env, lax.add_p,
                                         scalar, scalar), window_dimensions,
            window_strides, base_dilation, window_dilation, padding)
    ]
Example #13
  def _axis_index_bind(*, axis_name):
    dynamic_axis_env = pxla._thread_local_state.dynamic_axis_env
    frame = dynamic_axis_env[axis_name]
    sizes = dynamic_axis_env.sizes[:dynamic_axis_env.index(frame)+1]
    nreps = dynamic_axis_env.nreps
    trace = frame.pmap_trace

    out_aval = ShapedArray((), np.int32)
    out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
    eqn = pe.new_eqn_recipe([], [out_tracer], axis_index_p,
                            dict(nreps=nreps, sizes=sizes, axis_name=axis_name),
                            source_info_util.current())
    out_tracer.recipe = eqn

    return out_tracer
Example #14
def _allreduce_translation_rule(prim, c, *args, axis_name, axis_index_groups,
                                axis_env, platform):
    if platform in ("cpu", "tpu"):
        return _notuple_allreduce_translation_rule(
            prim,
            c,
            *args,
            axis_name=axis_name,
            axis_index_groups=axis_index_groups,
            axis_env=axis_env,
            platform=platform)

    # XLA's tuple all-reduce doesn't support different dtypes in the same
    # all-reduce, so we perform one all-reduce for each argument dtype.
    args_by_type = collections.defaultdict(lambda: ([], []))
    for i, arg in enumerate(args):
        indices, dtype_args = args_by_type[c.get_shape(arg).numpy_dtype()]
        indices.append(i)
        dtype_args.append(arg)

    # The outputs, in the original argument order.
    out = [None] * len(args)
    replica_groups = _replica_groups(axis_env, axis_name, axis_index_groups)
    replica_groups_protos = xc.make_replica_groups(replica_groups)
    for dtype, (indices, dtype_args) in sorted(args_by_type.items()):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        n = len(dtype_args)
        if is_complex and prim is lax.add_p:
            # TODO(b/141575627): we handle complex-dtype sum-reduction directly as a
            # special case because it's not currently handled by XLA:GPU
            dtype_args = ([xops.Real(x) for x in dtype_args] +
                          [xops.Imag(x) for x in dtype_args])
        scalar = ShapedArray((), c.get_shape(dtype_args[0]).numpy_dtype())
        computation = xla.primitive_subcomputation(prim, scalar, scalar)
        all_reduce = xops.AllReduce(xops.Tuple(c, dtype_args), computation,
                                    replica_groups_protos, None, None)
        if is_complex and prim is lax.add_p:
            xs = [
                xops.Complex(xops.GetTupleElement(all_reduce, i),
                             xops.GetTupleElement(all_reduce, n + i))
                for i in range(n)
            ]
        else:
            xs = [xops.GetTupleElement(all_reduce, i) for i in range(n)]
        for i, x in zip(indices, xs):
            out[i] = x
    return xops.Tuple(c, out)
Example #15
def _reduce_window_abstract_eval_rule(
    *avals, jaxpr, consts, window_dimensions, window_strides, padding,
    base_dilation, window_dilation):
  operand_avals, init_val_avals = util.split_list(avals, [len(avals) // 2])
  if any(o.dtype != iv.dtype for o, iv in zip(operand_avals, init_val_avals)):
    msg = ("reduce_window got inconsistent dtypes for operands and init_values:"
           " got operand dtypes {} and init_value dtypes {}.")
    raise TypeError(msg.format([o.dtype for o in operand_avals],
                               [iv.dtype for iv in init_val_avals]))
  if any(len(v.shape) != 0 for v in init_val_avals):
    msg = ("reduce_window expected init_values to be scalars but init_values "
           "have shapes {}.")
    raise TypeError(msg.format([v.shape for v in init_val_avals]))
  out_shape = _common_reduce_window_shape_rule(
    operand_avals[0], window_dimensions, window_strides, padding,
    base_dilation, window_dilation)
  return tuple(ShapedArray(out_shape, op.dtype) for op in operand_avals)
Example #16
def _select_and_scatter_add_translation(ctx, avals_in, avals_out, source,
                                        operand, *, select_prim,
                                        window_dimensions, window_strides,
                                        padding, expand_padding):
    source_aval, operand_aval = avals_in
    c = ctx.builder
    dtype = operand_aval.dtype
    scalar = ShapedArray((), dtype)
    select = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                          select_prim, scalar, scalar)
    scatter = xla.primitive_subcomputation(
        ctx.platform, ctx.axis_env,
        lax.or_p if dtype == np.bool_ else lax.add_p, scalar, scalar)
    zero = xla.pyval_to_ir_constant(c, np.array(0, dtype))
    # TODO(b/161704903): remove this workaround when XLA:CPU bug is fixed.
    expand_padding = (expand_padding
                      and not all(lo == 0 and hi == 0 for (lo, hi) in padding))
    if expand_padding:
        original_padding = padding
        identity = (lax._get_max_identity
                    if select_prim is lax.ge_p else lax._get_min_identity)
        pads = [(lo, hi, 0) for (lo, hi) in padding]
        operand = xops.Pad(operand,
                           xla.pyval_to_ir_constant(c, identity(dtype)),
                           xc.make_padding_config(pads))
        padding = [(0, 0) for _ in padding]
    output = xops.SelectAndScatterWithGeneralPadding(operand, select,
                                                     window_dimensions,
                                                     window_strides, padding,
                                                     source, zero, scatter)
    if expand_padding:
        start_indices = [lo for (lo, hi) in original_padding]
        stop_indices = [
            lo + d
            for ((lo, hi), d) in zip(original_padding, operand_aval.shape)
        ]
        output = xops.Slice(output, start_indices, stop_indices,
                            [1] * len(start_indices))
    return [output]
Example #17
def sharded_aval(aval: core.ShapedArray,
                 sharding: Optional[xc.OpSharding]) -> core.ShapedArray:
    """Returns the new aval sharded based on sharding proto."""
    if sharding is None:
        return aval

    if (sharding.type == xc.OpSharding.Type.REPLICATED
            or sharding.type == xc.OpSharding.Type.MANUAL):
        return aval

    sharded_shape = []
    tile_rank = len(sharding.tile_assignment_dimensions)
    if sharding.replicate_on_last_tile_dim:
        tile_rank -= 1
    if sharding.last_tile_dims:
        tile_rank -= len(sharding.last_tile_dims)
    if tile_rank == 0:
        return aval

    for i in range(tile_rank):
        partitions = sharding.tile_assignment_dimensions[i]
        assert partitions > 0
        sharded_shape.append((aval.shape[i] + partitions - 1) // partitions)
    return aval.update(tuple(sharded_shape))
Example #18
  div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
                                dtype=np.uint32))
  mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
  unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))

def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
  assert not vals and not mapped
  idx = axis_index(axis_name)  # type: ignore
  return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True

axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule  # type: ignore
axis_index_p.def_abstract_eval(
    lambda *args, **params: ShapedArray((), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)

# Axis index doesn't get any arguments, so that the default bind would have no
# way to call into a data-dependency based trace such as vmap. Each trace that
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  inner_size = 1
  index = 0
  for name in reversed(axis_name):
    frame = core.axis_frame(name)
    if frame.main_trace is not None:
      trace = frame.main_trace.with_cur_sublevel()
Example #19
def _get_invals(idx, *xs):
    return [
        ShapedArray(x.shape, x.dtype) if idx == i else x
        for i, x in enumerate(xs)
    ]
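
Illustrative usage with hypothetical inputs (assumes ShapedArray is in scope as in the other snippets): only the idx-th argument is replaced by its abstract ShapedArray, while the other arguments are passed through unchanged.

import jax.numpy as jnp

x, y = jnp.ones((2, 3)), jnp.ones((4,))
invals = _get_invals(0, x, y)
# invals[0] is ShapedArray((2, 3), float32); invals[1] is still the concrete y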
Example #20
    def _test_primitive(self, primitive: Optional[Primitive], shapes, dtype,
                        params):
        xs = _get_inputs(shapes, dtype)
        n = len(xs)
        eqn, f = _get_f_and_eqn(params, primitive, *xs)

        out = f(*xs)
        cts_in = ShapedArray(out.shape, out.dtype)

        argnums = tuple(range(n))
        js_fwd = jax.jacfwd(f, argnums)(*xs)
        js_rev = jax.jacrev(f, argnums)(*xs)

        for idx in range(n):
            if primitive == lax.conv_general_dilated_p and idx == 0:
                raise absltest.SkipTest(
                    'Jacobian of CNN wrt inputs not implemented.')

            if primitive == lax.div_p and idx == 1:
                raise absltest.SkipTest(
                    'Division is linear only in the first arg.')

            invals = _get_invals(idx, *xs)
            j_fwd, j_rev = js_fwd[idx], js_rev[idx]

            if primitive in rules.JACOBIAN_RULES:
                j_rule = rules.JACOBIAN_RULES[primitive](eqn, idx, invals,
                                                         cts_in)
            else:
                warnings.warn(
                    f'Jacobian rule for {primitive} at position {idx} not '
                    f'found.')
                j_rule = None

            with self.subTest(f'Jacobian ({idx})'):
                self._compare_jacobians(j_fwd, j_rev, j_rule, primitive)

            structure = rules.STRUCTURE_RULES[primitive](eqn, idx, invals,
                                                         cts_in)

            j = j_fwd if j_rule is None else j_rule

            if primitive == lax.reshape_p:
                out_ndim = xs[0].ndim
                j = j.transpose(
                    tuple(xs[0].ndim + i
                          for i in onp.argsort(structure.in_trace)) +
                    tuple(i for i in onp.argsort(structure.in_trace)))
                j = j.reshape(xs[0].shape +
                              tuple(xs[0].shape[i]
                                    for i in onp.argsort(structure.in_trace)))

            else:
                out_ndim = out.ndim

            with self.subTest(f'Diagonal axes ({idx})'):
                for i, o in zip(structure.in_diagonal, structure.out_diagonal):
                    self._assert_is_diagonal(j=j,
                                             axis1=out_ndim + i[idx],
                                             axis2=o,
                                             constant_diagonal=False)

            with self.subTest(f'Constant diagonal axes ({idx})'):
                for i, o in zip(structure.in_trace, structure.out_trace):
                    self._assert_is_diagonal(j=j,
                                             axis1=out_ndim + i,
                                             axis2=o,
                                             constant_diagonal=True)

            with self.subTest(f'Input broadcast axes ({idx})'):
                for i in structure.in_broadcast:
                    self._assert_constant(j=j, axis=i)

            with self.subTest(f'Output broadcast axes ({idx})'):
                for i in structure.out_broadcast:
                    self._assert_constant(j=j, axis=i)
Example #21
import numpy as np

from jax import core
from jax import lax
from jax import numpy as jnp
from jax import test_util as jtu
from jax.abstract_arrays import make_shaped_array
from jax.api import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.util import partial
from jax.interpreters import partial_eval as pe

from jax.config import config
config.parse_flags_with_absl()

_ = pe.PartialVal.unknown(UnshapedArray(np.float32))
__ = pe.PartialVal.unknown(ShapedArray((), np.float32))


def call(f, *args):
    return jit(f)(*args)


def simple_fun(x, y):
    return jnp.sin(x * y)


def simple_fun_fanout(x, y):
    return jnp.sin(x * y) * x


def fun_with_call(x):
Example #22
 def aval(self):
   return ShapedArray(self.polymorphic_shape,
                      dtypes.canonicalize_dtype(self.dtype))
Example #23
def _make_abstract_python_scalar(typ, val):
    return ShapedArray((),
                       dtypes._scalar_type_to_dtype(typ, val),
                       weak_type=True)
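
For illustration (the exact dtype depends on the canonicalization config): a plain Python scalar maps to a rank-0, weakly typed aval.

aval = _make_abstract_python_scalar(int, 3)
# e.g. ShapedArray((), int32, weak_type=True) when 64-bit types are disabled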
Example #24
def _array_aval_from_xla_shape(xla_shape):
    # This function instantiates the assumption that we can map from XLA array
    # types to JAX array types.
    # TODO(mattjj): remove assumption can map XLA array types to JAX array types
    assert not xla_shape.is_tuple()
    return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
Example #25
def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size):
  x_aval = raise_to_shaped(x)
  new_shape = list(x_aval.shape)
  new_shape.insert(all_gather_dimension, axis_size)
  return ShapedArray(new_shape, x_aval.dtype)
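
An illustrative call (made-up aval and axis name; assumes the snippet's imports): the gathered axis of size axis_size is inserted at all_gather_dimension.

import numpy as np
from jax.core import ShapedArray

aval = _all_gather_abstract_eval(
    ShapedArray((8, 128), np.float32), all_gather_dimension=0,
    axis_name='i', axis_index_groups=None, axis_size=4)
# aval is ShapedArray((4, 8, 128), float32)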
Example #26
    div = xb.constant(c, np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
                                  dtype=np.uint32))
    mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
    unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
    return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))


def _axis_index_soft_pmap_rule(vals, mapped, chunk_size, *, axis_name):
    assert not vals and not mapped
    idx = axis_index(axis_name)  # type: ignore
    return idx * chunk_size + np.arange(chunk_size, dtype=np.int32), True


axis_index_p = core.Primitive('axis_index')
xla.parallel_translations[axis_index_p] = _axis_index_translation_rule
pxla.soft_pmap_rules[axis_index_p] = _axis_index_soft_pmap_rule  # type: ignore
axis_index_p.def_abstract_eval(lambda *args, **params: ShapedArray(
    (), np.int32))
pxla.multi_host_supported_collectives.add(axis_index_p)


# Axis index doesn't get any arguments, so that the default bind would have no
# way to call into a data-dependency based trace such as vmap. Each trace that
# wants to bind an axis name has to additionally implement `process_axis_index`
# and put its main trace on the axis env stack.
def _axis_index_bind(*, axis_name):
    if not isinstance(axis_name, (tuple, list)):
        axis_name = (axis_name, )
    inner_size = 1
    index = 0
    for name in reversed(axis_name):
        frame = core.axis_frame(name)
        if frame.main_trace is not None:
Example #27
 def aval(self):
     return ShapedArray(self.polymorphic_shape, self.dtype)