Example 1
    def test_broadcast_in_dim(self):
        x = np.zeros((7, 1))
        lax.broadcast_in_dim(x,
                             shape=(3, x.shape[0], 4),
                             broadcast_dimensions=(1, 2))

        @shapecheck(['(n, 1)'], '(3, n, 4)')
        def broadcast_in_dim(x):
            return lax.broadcast_in_dim(x,
                                        shape=(3, x.shape[0], 4),
                                        broadcast_dimensions=(1, 2))
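
The test above relies on the core rule of lax.broadcast_in_dim: broadcast_dimensions maps each input dimension to an output dimension, and size-1 input dimensions are stretched to the output size. A minimal sketch with illustrative shapes (not part of the test):

import numpy as np
from jax import lax

x = np.zeros((7, 1))
# Input dim 0 (size 7) maps to output dim 1; input dim 1 (size 1) maps to
# output dim 2 and is stretched to size 4. Output dim 0 is a brand-new axis.
y = lax.broadcast_in_dim(x, shape=(3, 7, 4), broadcast_dimensions=(1, 2))
assert y.shape == (3, 7, 4)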
Example 2
def repeat(a, repeats, axis=None):
    if not isscalar(repeats):
        raise NotImplementedError(
            "np.repeat implementation only supports scalar repeats")
    if axis is None or isscalar(a):
        a = ravel(a)
        axis = 0
    a_shape = list(shape(a))
    num_dims = len(a_shape)
    if axis < 0:
        axis = axis + num_dims

    if axis < 0 or axis >= num_dims:
        raise ValueError(
            "axis {} is out of bounds for array of dimension {}".format(
                axis, num_dims))

    # Broadcasts to [..., X, repeats, ...] and reshapes to [..., X * repeats, ...]
    broadcast_shape = list(a_shape)
    broadcast_shape.insert(axis + 1, repeats)
    broadcast_dims = onp.concatenate(
        (onp.arange(0, axis + 1), onp.arange(axis + 2, num_dims + 1)))
    a_shape[axis] *= repeats
    return lax.reshape(
        lax.broadcast_in_dim(a, broadcast_shape, broadcast_dims), a_shape)
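
A minimal sketch of the broadcast-and-reshape trick used by repeat above, with illustrative values that are not part of the original source:

import jax.numpy as jnp
from jax import lax

a = jnp.array([[1, 2], [3, 4]])                 # shape (2, 2)
b = lax.broadcast_in_dim(a, (2, 2, 3), (0, 1))  # insert a size-3 axis after axis 1
out = lax.reshape(b, (2, 6))                    # [[1,1,1,2,2,2],[3,3,3,4,4,4]]
assert out.shape == (2, 6)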
Example 3
 def tri(n, m, k=0):
   # Tie in the key to avoid the mask becoming a constant.
   # This way XLA can construct the mask during computation and fuse it
   # with the attention ops.
   x = jnp.arange(n, dtype=jnp.int32)
   y = jnp.arange(m, dtype=jnp.int32)
   mask = lax.ge(
       (lax.broadcast_in_dim(x, shape=(n, m), broadcast_dimensions=(0,))) + k,
       lax.broadcast(y, [n]))
   return mask
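
An illustrative check, with assumed values rather than anything from the source: the mask built from broadcast_in_dim and lax.ge should agree with numpy's np.tri for the same n, m, k.

import numpy as onp
import jax.numpy as jnp
from jax import lax

n, m, k = 4, 5, 1
x = jnp.arange(n, dtype=jnp.int32)
y = jnp.arange(m, dtype=jnp.int32)
mask = lax.ge(lax.broadcast_in_dim(x, (n, m), (0,)) + k, lax.broadcast(y, [n]))
assert onp.array_equal(onp.asarray(mask), onp.tri(n, m, k).astype(bool))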
Example 4
def _broadcast_in_dim_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
                        cts_in: ShapedArray) -> np.ndarray:
    inval = invals[idx]
    j = np.eye(inval.size, dtype=inval.dtype)
    j = j.reshape(inval.shape * 2)
    j = lax.broadcast_in_dim(
        j,
        cts_in.shape + inval.shape,
        broadcast_dimensions=eqn.params['broadcast_dimensions'] +
        tuple(range(cts_in.ndim, cts_in.ndim + inval.ndim)))
    return j
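
A concrete sketch of the Jacobian construction above, for an operand of shape (2,) broadcast to (3, 2) with broadcast_dimensions=(1,); the shapes and names below are assumptions for illustration, not part of the original code.

import numpy as np
from jax import lax

inval_shape, out_shape, bcast_dims = (2,), (3, 2), (1,)
jac = np.eye(2).reshape(inval_shape * 2)       # identity Jacobian of the input
jac = lax.broadcast_in_dim(
    jac, out_shape + inval_shape,
    broadcast_dimensions=bcast_dims + (2,))    # output dims first, then input dims
# jac[i, a, b] == 1 exactly when output element (i, a) is a copy of input element b.
assert np.asarray(jac).shape == (3, 2, 2)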
Example 5
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
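
An illustrative call (values assumed, not from the source): the function compares the expanded indices against an arange broadcast into the class axis, yielding one row of length num_classes per index.

import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1])
onehots = jax.nn.one_hot(labels, num_classes=3)
assert onehots.shape == (3, 3)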
Example 6
def broadcast_in_dim_dependency_rule(outstart, outcount, operand, shape,
                                     broadcast_dimensions):
    if not is_ones(outcount):
        raise NotImplementedError
    outshape = outcount.shape
    is_broadcast = np.not_equal(np.shape(operand),
                                np.take(shape, broadcast_dimensions))
    instart = np.where(is_broadcast, 0, np.take(outstart,
                                                broadcast_dimensions))
    inshape = np.where(is_broadcast, 1, np.take(outshape,
                                                broadcast_dimensions))
    incount = np.full(inshape, prod(shape) // prod(operand.shape))
    return ([(instart, inshape)], [incount],
            lambda inslice: lax.broadcast_in_dim(inslice, outshape,
                                                 broadcast_dimensions))
Example 7
    def merged_func(*func_args):
        typed_jaxpr, out_avals = jax.make_jaxpr(f,
                                                return_shape=True)(*func_args)
        out_tree = jax.tree_structure(out_avals)
        jaxpr, consts = typed_jaxpr.jaxpr, typed_jaxpr.literals

        # Mapping from variable -> value
        env = dict()
        read = functools.partial(read_env, env)
        write = functools.partial(write_env, env)

        # Bind args and consts to environment
        flat_args = jax.tree_flatten(func_args)[0]
        write(jax.core.unitvar, jax.core.unit)
        jax_util.safe_map(write, jaxpr.invars, flat_args)
        jax_util.safe_map(write, jaxpr.constvars, consts)

        # Loop through equations and evaluate primitives using `bind`
        broadcasts_outputs = dict()
        for eqn in clean_jaxpr_eqns(jaxpr):
            # We ignore broadcasting of constants
            if (eqn.primitive.name == "broadcast_in_dim" and not all(
                    isinstance(v, jax_core.Literal) for v in eqn.invars)):
                if eqn.invars[0] in broadcasts_outputs:
                    x, dims = broadcasts_outputs[eqn.invars[0]]
                    kept_dims = eqn.params["broadcast_dimensions"]
                    kept_dims = [kept_dims[d] for d in dims]
                    y = lax.broadcast_in_dim(x, eqn.params["shape"], kept_dims)
                    jax_util.safe_map(write, eqn.outvars, [y])
                    broadcasts_outputs[eqn.outvars[0]] = (x, kept_dims)
                else:
                    inputs = jax_util.safe_map(read, eqn.invars)
                    evaluate_eqn(eqn, inputs, write)
                    broadcasts_outputs[eqn.outvars[0]] = (
                        inputs[0], eqn.params["broadcast_dimensions"])
            else:
                evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
        return jax.tree_unflatten(out_tree,
                                  jax_util.safe_map(read, jaxpr.outvars))
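
An illustrative check of the merging logic above, with assumed shapes: composing two broadcast_in_dim calls equals a single call whose broadcast_dimensions are the outer dimensions indexed by the inner ones, which is what kept_dims computes.

import numpy as np
from jax import lax

x = np.arange(3.0)
inner = lax.broadcast_in_dim(x, (3, 4), (0,))           # inner dims (0,)
outer = lax.broadcast_in_dim(inner, (2, 3, 4), (1, 2))  # outer dims (1, 2)
merged = lax.broadcast_in_dim(x, (2, 3, 4), (1,))       # kept_dims: (1, 2)[0] == 1
assert np.array_equal(np.asarray(outer), np.asarray(merged))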
Example 8
def logpdf(x, alpha):
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = lax._const(x, 1)
  if x.shape[0] != alpha.shape[0]:
    x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
  normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
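
An illustrative call, assuming this logpdf is the one exposed as jax.scipy.stats.dirichlet.logpdf; x holds the simplex coordinates along axis 0.

import jax.numpy as jnp
from jax.scipy.stats import dirichlet

x = jnp.array([0.2, 0.3, 0.5])
alpha = jnp.array([1.0, 2.0, 3.0])
lp = dirichlet.logpdf(x, alpha)   # scalar log-density; -inf off the simplex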
Example 9
def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias):
    input_shape = x.shape
    assert len(input_shape) == len(output_shape)
    assert len(input_shape) == len(scale)
    assert len(input_shape) == len(translate)
    spatial_dims = np.nonzero(
        np.not_equal(input_shape, output_shape) | np.not_equal(scale, 1)
        | np.not_equal(translate, 0))[0]
    if len(spatial_dims) == 0:
        return x
    output_spatial_shape = tuple(np.array(output_shape)[spatial_dims])
    indices = []
    contractions = []
    slice_shape = list(input_shape)
    in_indices = list(range(len(output_shape) + len(spatial_dims)))
    out_indices = list(range(len(output_shape)))
    for i, d in enumerate(spatial_dims):
        m = input_shape[d]
        n = output_shape[d]
        starts, weights = _compute_spans(m,
                                         n,
                                         scale[d],
                                         translate[d],
                                         kernel,
                                         antialias=antialias)
        starts = lax.broadcast_in_dim(starts, output_spatial_shape + (1, ),
                                      (i, ))
        slice_shape[d] = weights.shape[1]
        indices.append(starts.astype(np.int32))
        contractions.append(weights.astype(x.dtype))
        contractions.append([len(output_shape) + i, d])
        out_indices[d] = len(output_shape) + i
    index = lax.concatenate(indices, len(output_spatial_shape))
    dnums = lax.GatherDimensionNumbers(offset_dims=tuple(
        range(len(output_shape))),
                                       collapsed_slice_dims=(),
                                       start_index_map=tuple(spatial_dims))
    out = lax.gather(x, index, dnums, slice_shape)
    contractions.append(out_indices)
    return jnp.einsum(out,
                      in_indices,
                      *contractions,
                      precision=lax.Precision.HIGHEST)
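
An illustrative usage, under the assumption that jax.image.resize routes through this scaling logic: upscale a (4, 4) image to (8, 8) with a linear kernel.

import jax
import jax.numpy as jnp

img = jnp.arange(16.0).reshape(4, 4)
out = jax.image.resize(img, shape=(8, 8), method="linear")
assert out.shape == (8, 8)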
Example 10
def scatter_in_bounds(operand, indices, updates, dnums):
    # Ref: see clamping code used in scatter_translation_rule
    slice_sizes = []
    pos = 0
    for i in range(len(operand.shape)):
        if i in dnums.inserted_window_dims:
            slice_sizes.append(1)
        else:
            slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
            pos += 1

    upper_bound = np.array([
        operand.shape[i] - slice_sizes[i]
        for i in dnums.scatter_dims_to_operand_dims
    ], np.int64)
    upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
    upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
                                       (len(indices.shape) - 1, ))

    lower_in_bounds = jnp.all(jnp.greater_equal(indices, 0))
    upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
    return jnp.logical_and(lower_in_bounds, upper_in_bounds)
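
A minimal sketch with illustrative shapes (not from the source): the broadcast above places one upper bound per indexed operand dimension along the last axis of indices, so it can be compared elementwise with every index vector.

import numpy as np
from jax import lax

upper_bound = np.array([7, 5], dtype=np.int32)  # one bound per indexed operand dim
indices_shape = (4, 3, 2)                       # last axis holds the index vectors
ub = lax.broadcast_in_dim(upper_bound, indices_shape, (len(indices_shape) - 1,))
assert ub.shape == indices_shape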
Example 11
def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
                             reduction_dimension, recall_target, is_max_k,
                             reduction_input_size_override):
    prototype_arg, new_bdim = next(
        (a, b) for a, b in zip(batched_args, batch_dims) if b is not None)
    new_args = []
    for arg, bdim in zip(batched_args, batch_dims):
        if bdim is None:
            dims = np.delete(np.arange(prototype_arg.ndim), new_bdim)
            new_args.append(
                lax.broadcast_in_dim(arg, prototype_arg.shape, dims))
        else:
            new_args.append(batching.moveaxis(arg, bdim, new_bdim))
    new_reduction_dim = reduction_dimension + (new_bdim <= reduction_dimension)
    bdims = (new_bdim, ) * len(new_args)
    return (approx_top_k_p.bind(
        *new_args,
        k=k,
        reduction_dimension=new_reduction_dim,
        recall_target=recall_target,
        is_max_k=is_max_k,
        reduction_input_size_override=reduction_input_size_override), bdims)
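
An illustrative usage, hedged: the public entry points for this primitive are jax.lax.approx_max_k and jax.lax.approx_min_k, and under vmap the rule above broadcasts any unbatched operands before rebinding.

import jax
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(3, 4)
vals, idxs = jax.vmap(lambda row: jax.lax.approx_max_k(row, 2))(data)
assert vals.shape == (3, 2) and idxs.shape == (3, 2)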
Example 12
def _one_hot(x: Array, num_classes: int, *, dtype: Any,
             axis: Union[int, AxisName]) -> Array:
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
Example 13
 def broadcast_in_dim(x):
     return lax.broadcast_in_dim(x,
                                 shape=(3, x.shape[0], 4),
                                 broadcast_dimensions=(1, 2))
Example 14
def broadcast_dims(for_idxs, idxs, x):
  shape = [i.size for i in for_idxs]
  idxs_used = get_stack_idxs_used(for_idxs, idxs)
  bcast_dims = [i for i, b in enumerate(idxs_used) if b]
  return lax.broadcast_in_dim(x, shape, bcast_dims)
Example 15
 def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory):
   rng = rng_factory(self.rng())
   operand = rng(inshape, dtype)
   broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
   check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)
Example 16
 def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
     rng = jtu.rand_default(self.rng())
     op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
     self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng)
Example 17
      'a{1.0 \\over a + b}')),
 # EX 6
 Jax2TexExample(lambda W, x: W @ x, (S((3, 2)), S((2, ))),
                'f_{i} &= \\sum_{j}W_{ij}x_{j}'),
 # EX 7
 Jax2TexExample(lambda W, x: W @ x, (S((3, 2)), S((2, 3))),
                'f_{ij} &= \\sum_{k}W_{ik}x_{kj}'),
 # EX 8
 Jax2TexExample(lambda W, x: (W + W) @ (x * x), (S((3, 2)), S((2, 3))),
                ('f_{ij} &= \\sum_{k}\\left(W_{ik} + W_{ik}\\right)'
                 'x_{kj}x_{kj}')),
 # EX 9
 Jax2TexExample(grad(lambda W, x: (W + W) @ (x * x)), (S((2, )), S((2, ))),
                'f_{i} &= 1.0x_{i}x_{i} + 1.0x_{i}x_{i}'),
 # EX 10
 Jax2TexExample(lambda x: lax.broadcast_in_dim(x, (2, 3), (1, )),
                (S((3, )), ), 'f_{ij} &= x_{j}'),
 # EX 11
 # pylint: disable=unnecessary-lambda
 Jax2TexExample(lambda c, x, y: np.where(c, x, y), (BS((3, )), S(
     (3, )), S((3, ))), ('f_{i} &= \\mathbbm 1_{c_{i}}x_{i} + \\left(1 - '
                         '\\mathbbm 1_{c_{i}}\\right)y_{i}')),
 # EX 12
 Jax2TexExample(lambda c, x, y: np.where(c, x, y), (BS(()), S(
     (3, )), S((3, ))),
                ('f_{i} &= \\mathbbm 1_{c}x_{i} + \\left(1 - \\mathbbm '
                 '1_{c}\\right)y_{i}')),
 # EX 13
 Jax2TexExample(lambda x: np.transpose(x), (S((3, 2)), ),
                ('f_{ij} &= x_{ji}')),
 # EX 14
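
An illustrative check of EX 10 above (names assumed for the sketch): broadcasting x of shape (3,) into output dimension 1 of a (2, 3) result gives f[i, j] == x[j] for every i, matching the generated f_{ij} &= x_{j}.

import numpy as np
from jax import lax

x = np.arange(3.0)
f = lax.broadcast_in_dim(x, (2, 3), (1,))
assert np.array_equal(np.asarray(f), np.stack([x, x]))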
Example 18
def test_broadcast_in_dim(inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory(np.random)
    args = [rng(inshape, dtype)]
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    tu.check_lazy_fun(op, *args)
Example 19
 def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
     rng = jtu.rand_default(self.rng())
     raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
     op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
     self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng)
Example 20
 def fn(x):
   return lax.broadcast_in_dim(x, (2, 3), (1,))