def _add_sparse(spenv, *argspecs):
  """Sparsify rule for adding two sparse operands (ArgSpec-based variant).

  Args:
    spenv: sparsify environment; ``spenv.push`` stores a buffer and returns a
      reference to it.
    *argspecs: exactly two ArgSpecs ``X`` and ``Y``.

  Returns:
    A 1-tuple containing the ArgSpec of ``X + Y``.

  Raises:
    NotImplementedError: if either operand is dense, the shapes differ, or the
      operands' index/data ranks differ.
  """
  lhs, rhs = argspecs
  if not (lhs.is_sparse() and rhs.is_sparse()):
    raise NotImplementedError("Addition between sparse and dense matrix.")
  if lhs.shape != rhs.shape:
    raise NotImplementedError(
        "Addition between sparse matrices of different shapes.")

  if lhs.indices_ref == rhs.indices_ref:
    # Same indices buffer: add the data and reuse the shared indices ref.
    summed = lax.add(lhs.data(spenv), rhs.data(spenv))
    return (ArgSpec(lhs.shape, spenv.push(summed), lhs.indices_ref),)

  lhs_indices, rhs_indices = lhs.indices(spenv), rhs.indices(spenv)
  lhs_data, rhs_data = lhs.data(spenv), rhs.data(spenv)
  if lhs_indices.ndim != rhs_indices.ndim or lhs_data.ndim != rhs_data.ndim:
    raise NotImplementedError(
        "Addition between sparse matrices with different batch/dense dimensions."
    )

  # Different index buffers: append Y's stored entries after X's. Indices are
  # joined along their last axis, data along axis (indices.ndim - 2).
  rank = lhs_indices.ndim
  merged_indices = lax.concatenate([lhs_indices, rhs_indices],
                                   dimension=rank - 1)
  merged_data = lax.concatenate([lhs_data, rhs_data], dimension=rank - 2)
  # Push order (data first, then indices) matches the original.
  return (ArgSpec(lhs.shape, spenv.push(merged_data),
                  spenv.push(merged_indices)),)
def _add_sparse(spenv, *spvalues):
  """Sparsify rule for adding two sparse values.

  Args:
    spenv: sparsify environment providing ``data``/``indices`` lookups and
      the ``sparse`` constructor.
    *spvalues: exactly two values ``X`` and ``Y``.

  Returns:
    A 1-tuple containing the sparse value of ``X + Y``.

  Raises:
    NotImplementedError: if either operand is dense, the shapes differ, or the
      operands' index/data ranks differ.
  """
  lhs, rhs = spvalues
  if not (lhs.is_sparse() and rhs.is_sparse()):
    raise NotImplementedError("Addition between sparse and dense array.")
  if lhs.shape != rhs.shape:
    raise NotImplementedError(
        "Addition between sparse matrices of different shapes.")

  if lhs.indices_ref == rhs.indices_ref:
    # Shared indices buffer: only the data needs adding; the result keeps the
    # shared indices and their sortedness/uniqueness flags.
    summed = lax.add(spenv.data(lhs), spenv.data(rhs))
    if config.jax_enable_checks:
      assert lhs.indices_sorted == rhs.indices_sorted
      assert lhs.unique_indices == rhs.unique_indices
    result = spenv.sparse(lhs.shape, summed,
                          indices_ref=lhs.indices_ref,
                          indices_sorted=lhs.indices_sorted,
                          unique_indices=lhs.unique_indices)
    return (result,)

  lhs_indices, rhs_indices = spenv.indices(lhs), spenv.indices(rhs)
  lhs_data, rhs_data = spenv.data(lhs), spenv.data(rhs)
  if lhs_indices.ndim != rhs_indices.ndim or lhs_data.ndim != rhs_data.ndim:
    raise NotImplementedError(
        "Addition between sparse matrices with different batch/dense dimensions."
    )

  # Different index buffers: append Y's stored entries after X's; both
  # buffers are joined along axis (indices.ndim - 2).
  join_axis = lhs_indices.ndim - 2
  merged_indices = lax.concatenate([lhs_indices, rhs_indices],
                                   dimension=join_axis)
  merged_data = lax.concatenate([lhs_data, rhs_data], dimension=join_axis)
  return (spenv.sparse(lhs.shape, merged_data, merged_indices),)
def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
  """Check 2nd-order fwd/rev gradients of lax.concatenate along ``dim``.

  Builds ``num_arrs`` operands whose sizes along ``dim`` cycle through 3, 1, 4;
  every other dimension comes from ``base_shape``.
  """
  rng = rng_factory(self.rng())
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), num_arrs)
  shapes = [base_shape[:dim] + (size,) + base_shape[dim + 1:] for size in sizes]
  operands = tuple(rng(shape, dtype) for shape in shapes)

  def concatenate(*args):
    return lax.concatenate(args, dim)

  check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
def test_concatenate(self):
  """Shape-polymorphic concatenate: shapes (n,), (m,), (n,) give (m + 2*n,)."""
  def concat3(x, y, z):
    return lax.concatenate([x, y, z], 0)

  input_specs = ['n', 'm', 'n']
  expected_spec = 'm + 2 * n'
  dims = dict(n=2, m=3)
  concrete_shapes = [(4,), (3,), (4,)]
  dtypes = ['float_'] * 3
  self.check(concat3, input_specs, expected_spec, dims, concrete_shapes,
             dtypes, jtu.rand_default(self.rng()))
def _dct_ortho_norm(out, axis): factor = lax.concatenate([ lax.full((1, ), 4, out.dtype), lax.full((out.shape[axis] - 1, ), 2, out.dtype) ], 0) factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis]) return out / lax.sqrt(factor * out.shape[axis])
def concatenate_dependency_rule(outstart, outcount, *operands, dimension):
  """Map an output box of a concatenate back to boxes in each operand.

  Args:
    outstart: start coordinates of the requested output box.
    outcount: counts for the box; only all-ones counts are supported.
    *operands: the concatenated operands (only ``.shape`` is read).
    dimension: the concatenation axis.

  Returns:
    ``(inboxes, incounts, rebuild)``. For each operand, ``inboxes[i]`` is
    ``(instart, inshape)`` in that operand's coordinates, or ``None`` if the
    operand does not overlap the box. ``incounts[i]`` is ``Ones(inshape)`` or
    ``None``. ``rebuild`` concatenates the non-``None`` input slices along
    ``dimension``.

  Raises:
    NotImplementedError: if ``outcount`` is not all ones.
  """
  if not is_ones(outcount):
    raise NotImplementedError
  dim = dimension
  outstart, outshape = list(outstart), list(outcount.shape)
  # The box along the concat axis is the half-open range
  # [dimstart, dimstart + dimshape).
  dimstart, dimshape = outstart[dim], outshape[dim]
  position = 0  # offset of the current operand along the concat axis
  inboxes = []
  incounts = []
  for operand in operands:
    shape = operand.shape
    if dimstart < position + shape[dim] and position < dimstart + dimshape:
      # Overlap, in this operand's local coordinates along ``dim``:
      # [max(0, dimstart - position), min(dimstart + dimshape - position, size)).
      local_start = max(0, dimstart - position)
      local_stop = min(dimstart + dimshape - position, shape[dim])
      instart = outstart[:dim] + [local_start] + outstart[dim + 1:]
      inshape = outshape[:dim] + [local_stop - local_start] + outshape[dim + 1:]
      inboxes.append((instart, inshape))
      incounts.append(Ones(inshape))
    else:
      inboxes.append(None)
      incounts.append(None)
    position += shape[dim]
  return inboxes, incounts, lambda *inslices: lax.concatenate(
      [x for x in inslices if x is not None], dimension)
def testConcatenate(self): R = lambda *shape: np.random.RandomState(0).randn(*shape).astype(np.float32) fun = lambda *args: lax.concatenate(args, dimension=0) x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3) ans = vmap(fun, in_axes=(0, 1, None))(x, y, z) expected_ans = np.concatenate([x, np.swapaxes(y, 0, 1), np.broadcast_to(z, (10, 4, 3))], 1) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda *args: lax.concatenate(args, dimension=1) x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10) ans = vmap(fun, in_axes=(0, None, 2))(x, y, z) expected_ans = np.concatenate([x, np.broadcast_to(y, (10, 2, 3)), np.moveaxis(z, 2, 0)], 2) self.assertAllClose(ans, expected_ans, check_dtypes=False)
def threefry_seed(seed: int) -> jnp.ndarray: """Create a single raw threefry PRNG key given an integer seed. Args: seed: a 64- or 32-bit integer used as the value of the key. Returns: The PRNG key contents, modeled as an array of shape (2,) and dtype uint32. The key is constructed from a 64-bit seed by effectively bit-casting to a pair of uint32 values (or from a 32-bit seed by first padding out with zeros). """ # Avoid overflowerror in X32 mode by first converting ints to int64. # This breaks JIT invariance for large ints, but supports the common # use-case of instantiating with Python hashes in X32 mode. if isinstance(seed, int): seed_arr = jnp.asarray(np.int64(seed)) else: seed_arr = jnp.asarray(seed) if seed_arr.shape: raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.") if not np.issubdtype(seed_arr.dtype, np.integer): raise TypeError(f"PRNG key seed must be an integer; got {seed!r}") convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1]) k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32))) k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF))) return lax.concatenate([k1, k2], 0)
def scan_reference(f, init, xs):
  """Plain-Python reference implementation of a scan.

  Threads a carry through ``f`` over the elements of ``xs`` and stacks the
  per-step outputs along a new leading axis.

  Args:
    f: step function ``(carry, x) -> (carry, y)``.
    init: initial carry.
    xs: iterable of per-step inputs. Presumably must be non-empty, since the
      outputs are joined with ``lax.concatenate``.

  Returns:
    ``(final_carry, stacked_ys)``.
  """
  carry = init
  stacked = []
  for x in xs:
    carry, y = f(carry, x)
    # Give each output a leading length-1 axis so they can be concatenated.
    stacked.append(lax.reshape(y, (1,) + onp.shape(y)))
  return carry, lax.concatenate(stacked, 0)
def test_concatenate(dim, base_shape, dtype, num_arrs, rng_factory):
  """Lazy-evaluation check of lax.concatenate over ``num_arrs`` operands.

  Operand sizes along ``dim`` cycle through 3, 1, 4; all other dimensions come
  from ``base_shape``.
  """
  rng = rng_factory(np.random)
  sizes = itertools.islice(itertools.cycle([3, 1, 4]), num_arrs)
  shapes = [base_shape[:dim] + (size,) + base_shape[dim + 1:] for size in sizes]
  operands = [rng(shape, dtype) for shape in shapes]

  def concat_along_dim(*arrays):
    return lax.concatenate(arrays, dim)

  tu.check_lazy_fun(concat_along_dim, *operands)
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Args:
    key: a threefry PRNG key (validated with ``_is_threefry_prng_key``).
    bit_width: one of 8, 16, 32 or 64.
    shape: target shape; converted with ``core.as_named_shape``, so it may
      carry named (mapped) axes in addition to positional ones.

  Returns:
    An array of dtype ``UINT_DTYPES[bit_width]`` reshaped to ``shape``.

  Raises:
    TypeError: if ``key`` is not a threefry key or ``bit_width`` is invalid.
    ValueError: if a named axis size in ``shape`` disagrees with the size
      reported by ``lax.psum(1, name)``.
  """
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  # For each named axis, check its declared size against the real one, then
  # fold the axis index into the key so each position along that axis gets a
  # different key.
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(f"The shape of axis {name} was specified as {size}, "
                       f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  # (max_count is the number of 32-bit words to generate.)
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  # Split the word count into full blocks of uint32-max words plus a
  # remainder. Presumably this is because a uint32 iota counter per key cannot
  # address more words than that — confirm. For non-constant (symbolic) sizes,
  # a single block is assumed.
  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    # One subkey per full block plus one for the remainder.
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(
        subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    # The first half of the 32-bit words becomes the high halves and the
    # second half the low halves of the 64-bit results.
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    # Each 32-bit word is shifted by 0, bit_width, 2*bit_width, ... and masked,
    # giving a (32 // bit_width, max_count) array of sub-words. The reshape
    # with dimensions (1, 0) then flattens it word by word.
    bits = lax.bitwise_and(
      np.uint32(np.iinfo(dtype).max),
      lax.shift_right_logical(
        lax.broadcast(bits, (1, )),
        lax.mul(
          np.uint32(bit_width),
          lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0)
        )
      )
    )
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
def block_diag(*arrs):
  """Create a block-diagonal matrix from the given arrays.

  Each argument is promoted to at least 2-D and placed along the diagonal of a
  zero-filled result. With no arguments the result is an empty ``(1, 0)``
  array.

  Args:
    *arrs: arrays with at most 2 dimensions; dtypes are promoted to a common
      type first.

  Returns:
    A 2-D array whose row and column counts are the sums of the (2-D) input
    row and column counts.

  Raises:
    ValueError: if any argument has more than 2 dimensions.
  """
  if len(arrs) == 0:
    arrs = [jnp.zeros((1, 0))]
  arrs = jnp._promote_dtypes(*arrs)
  bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2]
  if bad_shapes:
    first_bad = bad_shapes[0]
    # Report the offending rank, not the whole array; the array may be huge or
    # an abstract tracer.
    raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at "
                     "most 2 dimensions, got {} at argument {}."
                     .format(jnp.ndim(arrs[first_bad]), first_bad))
  arrs = [jnp.atleast_2d(a) for a in arrs]

  acc = arrs[0]
  zero = acc.dtype.type(0)
  for a in arrs[1:]:
    _, c = a.shape
    # Shift the new block right past all existing columns, widen the
    # accumulator by the new block's width, then stack the block underneath.
    a = lax.pad(a, zero, ((0, 0, 0), (acc.shape[-1], 0, 0)))
    acc = lax.pad(acc, zero, ((0, 0, 0), (0, c, 0)))
    acc = lax.concatenate([acc, a], dimension=0)
  return acc
def _irfft_transpose(t, fft_lengths):
  """Transpose rule for IRFFT.

  Args:
    t: real-valued cotangent of the IRFFT output.
    fft_lengths: FFT lengths used by the forward IRFFT.

  Returns:
    The complex cotangent for the IRFFT input, conjugated per JAX's
    complex-gradient convention.
  """
  # The transpose of IRFFT is the RFFT of the cotangent times a scaling
  # factor and a mask. The mask scales the cotangent for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these components
  # are de-duplicated in the RFFT.
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  n = x.shape[-1]
  is_odd = fft_lengths[-1] % 2
  full = partial(lax.full_like, t, dtype=t.dtype)
  # The DC bin (and, for even lengths, the Nyquist bin) are not duplicated.
  mask = lax.concatenate(
      [full(1.0, shape=(1,)),
       full(2.0, shape=(n - 2 + is_odd,)),
       full(1.0, shape=(1 - is_odd,))],
      dimension=0)
  scale = 1 / prod(fft_lengths)
  # Broadcast the 1-D mask explicitly against the trailing axis of x.
  out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)
def _scale_and_translate(x, output_shape, scale, translate, kernel, antialias):
  """Resample ``x`` to ``output_shape`` by a per-axis scale and translation.

  Only axes where the shape changes, the scale is not 1, or the translation
  is not 0 are resampled. If there are none, ``x`` is returned unchanged.
  For each resampled axis, ``_compute_spans`` yields window start offsets and
  kernel weights. A single ``lax.gather`` extracts the windows, and an einsum
  contracts each window axis against its weights.

  Args:
    x: input array.
    output_shape: desired output shape; same rank as ``x``.
    scale: per-axis scale factors; length ``x.ndim``.
    translate: per-axis translations; length ``x.ndim``.
    kernel: interpolation kernel passed through to ``_compute_spans``.
    antialias: passed through to ``_compute_spans``.

  Returns:
    The resampled array (einsum with ``lax.Precision.HIGHEST``).
  """
  input_shape = x.shape
  assert len(input_shape) == len(output_shape)
  assert len(input_shape) == len(scale)
  assert len(input_shape) == len(translate)
  # Axes that actually need resampling.
  spatial_dims = np.nonzero(
      np.not_equal(input_shape, output_shape) | np.not_equal(scale, 1)
      | np.not_equal(translate, 0))[0]
  if len(spatial_dims) == 0:
    return x
  output_spatial_shape = tuple(np.array(output_shape)[spatial_dims])
  indices = []
  contractions = []
  # Gather slice size: full extent on untouched axes, window span on
  # resampled axes (filled in below).
  slice_shape = list(input_shape)
  # einsum labels for the gather output: labels 0..ndim-1 for the slice
  # (offset) dims, then ndim + i for the i-th resampled axis's output position.
  in_indices = list(range(len(output_shape) + len(spatial_dims)))
  out_indices = list(range(len(output_shape)))
  for i, d in enumerate(spatial_dims):
    m = input_shape[d]
    n = output_shape[d]
    # NOTE(review): presumably ``starts`` has length n and ``weights`` has
    # shape (n, span); the code relies on weights.shape[1] being the span.
    starts, weights = _compute_spans(m, n, scale[d], translate[d], kernel,
                                     antialias=antialias)
    # Place this axis's starts along grid axis i, with a trailing length-1
    # index-vector dim.
    starts = lax.broadcast_in_dim(starts, output_spatial_shape + (1, ), (i, ))
    slice_shape[d] = weights.shape[1]
    indices.append(starts.astype(np.int32))
    contractions.append(weights.astype(x.dtype))
    # Weights are labelled (output position, window offset on axis d), so the
    # einsum sums over label d.
    contractions.append([len(output_shape) + i, d])
    # In the result, axis d is labelled by the output position.
    out_indices[d] = len(output_shape) + i
  # Stack per-axis starts into one index array whose last dim is the index
  # vector.
  index = lax.concatenate(indices, len(output_spatial_shape))
  dnums = lax.GatherDimensionNumbers(offset_dims=tuple(
      range(len(output_shape))),
                                     collapsed_slice_dims=(),
                                     start_index_map=tuple(spatial_dims))
  out = lax.gather(x, index, dnums, slice_shape)
  contractions.append(out_indices)
  return jnp.einsum(out, in_indices, *contractions,
                    precision=lax.Precision.HIGHEST)
def _irfft_transpose(t, fft_lengths):
  """Transpose rule for IRFFT, returning the conjugated RFFT-based cotangent.

  Args:
    t: real-valued cotangent of the IRFFT output.
    fft_lengths: FFT lengths used by the forward IRFFT.

  Returns:
    The complex cotangent for the IRFFT input, conjugated per JAX's
    complex-gradient convention.
  """
  # IRFFT's transpose is RFFT of the cotangent, scaled by 1 / prod(fft_lengths)
  # and by a mask that doubles the Hermitian-symmetric bins, because RFFT keeps
  # only one copy of each such pair.
  spectrum = fft(t, xla_client.FftType.RFFT, fft_lengths)
  num_bins = spectrum.shape[-1]
  odd = fft_lengths[-1] % 2
  const = partial(lax.full_like, t, dtype=t.dtype)
  pieces = [
      const(1.0, shape=(1,)),
      const(2.0, shape=(num_bins - 2 + odd,)),
      const(1.0, shape=(1 - odd,)),
  ]
  mask = lax.concatenate(pieces, dimension=0)
  normalizer = 1 / prod(fft_lengths)
  out = normalizer * lax.expand_dims(mask, range(spectrum.ndim - 1)) * spectrum
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # Use JAX's convention for complex gradients
  # https://github.com/google/jax/issues/6223#issuecomment-807740707
  return lax.conj(out)
def _concatenate_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
                   cts_in: ShapedArray) -> np.ndarray:
  """Jacobian of a concatenate equation with respect to input ``idx``.

  Builds one block per operand: an identity block for operand ``idx`` and
  zeros for every other operand. The blocks are concatenated along the
  equation's ``dimension`` and reshaped to ``cts_in.shape + invals[idx].shape``.
  """
  axis = eqn.params['dimension']
  target = invals[idx]
  blocks = []
  for i, other in enumerate(invals):
    # Operand i's shape: its own size along the concat axis, the target's
    # sizes elsewhere.
    block_shape = tuple(other.shape[k] if k == axis else target.shape[k]
                        for k in range(target.ndim))
    if i == idx:
      block = np.eye(target.size, dtype=target.dtype)
    else:
      block = np.zeros((onp.prod(block_shape), target.size), target.dtype)
    blocks.append(block.reshape(block_shape + target.shape))
  jacobian = lax.concatenate(blocks, axis)
  return jacobian.reshape(cts_in.shape + target.shape)
def concat(y):
  """Concatenate the enclosing-scope ``x`` with ``y`` along axis 0."""
  operands = [x, y]
  return lax.concatenate(operands, dimension=0)
def duplicate(x):
  """Return ``x`` concatenated with itself along axis 0.

  First asserts ``python_should_be_executing``, which presumably is a test
  flag tracking whether this Python body is expected to run (be traced) —
  confirm against the caller.
  """
  assert python_should_be_executing
  doubled = lax.concatenate([x, x], dimension=0)
  return doubled
def cat(x, y, z):
  """Return ``[x, y, x]`` concatenated along axis 0.

  Note: ``z`` is accepted but unused, and ``x`` appears twice. Kept as-is
  because this may be deliberate in the calling test.
  """
  operands = (x, y, x)
  return lax.concatenate(list(operands), dimension=0)
def _dct_interleave(x, axis): v0 = lax.slice_in_dim(x, None, None, 2, axis) v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis, )) return lax.concatenate([v0, v1], axis)
def concatenate(arrays, axis=0):
  """Join a sequence of arrays along an existing axis (like numpy.concatenate).

  Args:
    arrays: non-empty sequence of array-likes; dtypes are promoted to a common
      type before joining.
    axis: axis to join along; negative values count from the end.

  Returns:
    The concatenated array.

  Raises:
    ValueError: if ``arrays`` is empty, the first array is zero-dimensional,
      or ``axis`` is out of range for its rank.
  """
  # len() rather than truthiness, so an ndarray of arrays is not ambiguous.
  if len(arrays) == 0:
    raise ValueError("Need at least one array to concatenate.")
  rank = ndim(arrays[0])
  if rank == 0:
    # Previously this surfaced as a ZeroDivisionError from ``axis % 0``.
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  if not -rank <= axis < rank:
    # Reject out-of-range axes, as numpy does, instead of silently wrapping.
    raise ValueError(
        "axis {} is out of bounds for array of dimension {}".format(axis, rank))
  return lax.concatenate(_promote_dtypes(*arrays), axis % rank)
def _rewriting_take(arr, idx, axis=0):
  """A function like numpy.take that handles boxes and rewrites to LAX.

  Dispatches on the kind of ``idx``:

  * empty tuple, ``Ellipsis`` or a full ``slice(None)``: ``arr`` unchanged.
  * ``None``: inserts a new leading axis.
  * scalar integer: indexes ``axis`` and drops it. Concrete indices use
    ``lax.index_in_dim``; abstract (ShapedArray) ones are wrapped with ``mod``
    and use ``lax.dynamic_index_in_dim``.
  * ``slice`` with static bounds: ``lax.slice_in_dim``, plus ``lax.rev`` when
    ``_static_idx`` reports ``needs_rev``. Non-static bounds raise IndexError.
  * tuple of scalar-like elements: recurses once per element, then reshapes
    to insert the ``None`` axes.
  * advanced integer indexing, with or without accompanying slices/None.

  Args:
    arr: array to index.
    idx: the index expression.
    axis: axis that a non-tuple ``idx`` applies to (used by the tuple
      recursion).

  Returns:
    The indexed array.

  Raises:
    IndexError: for non-static slice bounds or unsupported index kinds.
  """
  # Handle special indexers: (), Ellipsis, slice(None), and None.
  # An explicit empty-tuple check replaces the old ``idx is ()`` identity test,
  # which relied on a CPython implementation detail and triggers a
  # SyntaxWarning on Python 3.8+.
  if (isinstance(idx, tuple) and len(idx) == 0) or idx is Ellipsis or _is_slice_none(idx):
    return arr
  elif idx is None:
    return expand_dims(arr, 0)

  # Handle int index
  _int = lambda aval: not aval.shape and onp.issubdtype(aval.dtype, onp.integer)
  try:
    abstract_idx = core.get_aval(idx)
  except TypeError:
    abstract_idx = None

  # ConcreteArray is checked first, ahead of the more general ShapedArray.
  if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
    return lax.index_in_dim(arr, idx, axis, False)
  elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
    # Wrap negative traced indices into range.
    idx = mod(idx, arr.shape[axis])
    return lax.dynamic_index_in_dim(arr, idx, axis, False)

  # Handle slice index (only static, otherwise an error is raised)
  elif isinstance(idx, slice):
    if not _all(elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
      msg = ("Array slice indices must have static start/stop/step to be used "
             "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
      raise IndexError(msg)
    else:
      start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
      result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
      return lax.rev(result, [axis]) if needs_rev else result

  # Handle non-advanced tuple indices by recursing once
  elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
    canonical_idx = _canonicalize_tuple_index(arr, idx)
    result, axis = arr, 0
    for elt in (elt for elt in canonical_idx if elt is not None):
      result = _rewriting_take(result, elt, axis=axis)
      axis += isinstance(elt, slice)  # advance axis index if not eliminated
    unexpanded_shape_itr = iter(result.shape)
    result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                         for elt in canonical_idx if not isinstance(elt, int))
    return lax.reshape(result, result_shape)

  # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
  # or a tuple with at least one sequence object).
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  # https://gist.github.com/seberg/976373b6a2b7c4188591

  # Handle integer array indexing *without* ellipsis/slices/nones
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
  if _is_advanced_int_indexer_without_slices(idx):
    if isinstance(idx, list):
      if _any(_shape(e) for e in idx):
        # At least one sequence element in the index list means broadcasting.
        idx = broadcast_arrays(*idx)
      else:
        # The index list is a flat list of integers.
        idx = [lax.concatenate([lax.reshape(e, (1,)) for e in idx], 0)]
    else:
      # The indexer is just a single integer array.
      idx = [idx]

    flat_idx = tuple(mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
    out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
    return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

  # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
  elif _is_advanced_int_indexer(idx):
    canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
    # First apply the basic (non-advanced) part of the index.
    idx_noadvanced = [slice(None) if _is_int(e) else e for e in canonical_idx]
    arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

    advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int(e))
    idx_advanced, axes = zip(*advanced_pairs)
    idx_advanced = broadcast_arrays(*idx_advanced)

    flat_idx = tuple(mod(ravel(x), arr_sliced.shape[i])
                     for i, x in zip(axes, idx_advanced))
    out = lax.index_take(arr_sliced, flat_idx, axes)
    shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
    out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

    # The advanced dims are at the front here. When the advanced axes were
    # contiguous, move them back to where those axes were.
    axes_are_contiguous = onp.all(onp.diff(axes) == 1)
    if axes_are_contiguous:
      start = axes[0]
      naxes = idx_advanced[0].ndim
      out = moveaxis(out, list(range(naxes)), list(range(start, start + naxes)))
    return out

  msg = "Indexing mode not yet supported. Open a feature request!\n{}"
  raise IndexError(msg.format(idx))
def concat(x):
  """Return ``x`` concatenated with itself along axis 0."""
  pair = [x] * 2
  return lax.concatenate(pair, dimension=0)