def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key."""
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(f"The shape of axis {name} was specified as {size}, "
                       f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism.
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(
        subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # This is essentially bits.view(dtype)[:size].
    bits = lax.bitwise_and(
        np.uint32(np.iinfo(dtype).max),
        lax.shift_right_logical(
            lax.broadcast(bits, (1,)),
            lax.mul(
                np.uint32(bit_width),
                lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
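# A hedged usage sketch: the helpers above (_is_threefry_prng_key,
# threefry_2x32, threefry_fold_in, ...) are private and assumed to be in
# scope; the public API built on this same machinery is jax.random.
import jax

key = jax.random.PRNGKey(0)                   # uint32[2] threefry key
bits = threefry_random_bits(key, 32, (2, 3))  # uint32 array of shape (2, 3)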
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
             axis=0):
  """Implementation of linspace differentiable w.r.t. start and stop args."""
  lax._check_user_dtype_supported(dtype, "linspace")
  dtype = np.float32 if dtype is None else dtype
  bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
  axis = len(bounds_shape) if axis == -1 else axis
  bounds_shape.insert(axis, 1)
  iota_shape = [1] * len(bounds_shape)
  iota_shape[axis] = num
  delta = (stop - start) / num
  if endpoint:
    delta *= num / (num - 1)
  out = (jnp.reshape(start, bounds_shape) +
         jnp.reshape(lax.iota(dtype, num), iota_shape) *
         jnp.reshape(delta, bounds_shape))
  if retstep:
    return jnp.array(out, dtype=dtype), delta
  else:
    return jnp.array(out, dtype=dtype)
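# A short usage sketch: because the output is start + iota * delta, the
# result is differentiable w.r.t. both endpoints (assumes jax and jnp are
# imported).
vals = linspace(0., 1., num=5)  # ~[0., 0.25, 0.5, 0.75, 1.]
d_start = jax.grad(lambda s: jnp.sum(linspace(s, 1., num=5)))(0.)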
def body_fn(i, permutation):
  j = swaps[..., i]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
  x = permutation[..., i]
  y = permutation[iotas + (j,)]
  permutation = ops.index_update(permutation, ops.index[..., i], y)
  return ops.index_update(permutation, ops.index[iotas + (j,)], x)
def _lu_jvp_rule(primals, tangents):
  a, = primals
  a_dot, = tangents
  lu, pivots = lu_p.bind(a)

  if a_dot is ad_util.zero:
    return (core.pack((lu, pivots)),
            ad.TangentTuple((ad_util.zero, ad_util.zero)))

  a_shape = np.shape(a)
  m, n = a_shape[-2:]
  dtype = lax.dtype(a)
  k = min(m, n)

  permutation = lu_pivots_to_permutation(pivots, m)
  batch_dims = a_shape[:-2]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1,)))
  x = a_dot[iotas[:-1] + (permutation, slice(None))]

  # Differentiation of Matrix Functionals Using Triangular Factorization
  # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
  #
  #     LU = A
  # ==> L'U + LU' = A'
  # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
  # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
  #     U' = triu(inv(L) . A' . inv(U)) . U

  ndims = len(a_shape)
  l_padding = [(0, 0, 0)] * ndims
  l_padding[-1] = (0, m - k, 0)
  zero = np._constant_like(lu, 0)
  l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
  l = l + np.eye(m, m, dtype=dtype)

  u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                  ((k, 0, 0), (k, 0, 0)))
  u_padding = [(0, 0, 0)] * ndims
  u_padding[-2] = (0, n - k, 0)
  u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

  la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                        unit_diagonal=True)
  lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                         lower=False)

  l_dot = np.matmul(l, np.tril(lau, -1))
  u_dot = np.matmul(np.triu(lau), u)
  lu_dot = l_dot + u_dot
  return (lu, pivots), (lu_dot, ad_util.zero)
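# A hedged sanity check of the De Hoog identity: differentiating an LU
# factorization with jax.jvp should satisfy p @ (L'U + LU') == A' when the
# pivots are locally constant (assumes jax.scipy.linalg.lu is available).
import jax
import jax.numpy as jnp

a = jnp.array([[4., 3.], [6., 3.]])
a_dot = jnp.eye(2)
(p, l, u), (_, l_dot, u_dot) = jax.jvp(jax.scipy.linalg.lu, (a,), (a_dot,))
assert jnp.allclose(p @ (l_dot @ u + l @ u_dot), a_dot, atol=1e-5)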
def _lu_pivots_body_fn(i, permutation_and_swaps):
  permutation, swaps = permutation_and_swaps
  batch_dims = swaps.shape[:-1]
  j = swaps[..., i]
  iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims))
  x = permutation[..., i]
  y = permutation[iotas + (j,)]
  permutation = ops.index_update(permutation, ops.index[..., i], y)
  return ops.index_update(permutation, ops.index[iotas + (j,)], x), swaps
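# A minimal driver sketch for the body above (assumes the legacy jax.ops
# index_update API that this code targets): fold over the swap positions to
# turn LAPACK-style pivots into a full permutation.
swaps = jnp.array([2, 2, 2], dtype=jnp.int32)   # pivot indices of a 3x3 LU
permutation = jnp.arange(3, dtype=jnp.int32)
permutation, _ = lax.fori_loop(0, 3, _lu_pivots_body_fn, (permutation, swaps))
# permutation is now [2, 0, 1]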
def init(key, shape, dtype=np.float32):
  # Assume shape is (horizon, num_state_dims, num_action_dims).
  T, n, m = shape

  def generate_lqr(key, _):
    new_key, p_key, a_key, b_key, q_key, r_key = random.split(key, 6)
    pmat = pmat_initializer(p_key, (n, n), dtype=dtype)
    lqr = LQR(
        A=amat_initializer(a_key, (n, n), dtype=dtype),
        B=bmat_initializer(b_key, (n, m), dtype=dtype),
        Q=qmat_initializer(q_key, (n, n), dtype=dtype),
        R=rmat_initializer(r_key, (m, m), dtype=dtype),
    )
    return new_key, (pmat, lqr)

  # lax.iota takes a dtype as its first argument; the original call omitted it.
  return lax.scan(generate_lqr, key, lax.iota(np.int32, T))[1]
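# A hedged usage sketch: the initializers (pmat_initializer etc.) and the
# LQR container are assumed to be closed over by the enclosing factory, so
# only the key and shape are supplied here.
key = random.PRNGKey(0)
pmats, lqrs = init(key, shape=(10, 4, 2))  # 10 steps, 4 states, 2 actions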
def rollout(x_init, lqr, policy, num_steps):
  def timevarying_step(x_t, inputs):
    t, lqr = inputs
    u_t = policy(x_t, t)
    x_tp1 = lqr.A @ x_t + lqr.B @ u_t
    c_t = x_t @ (lqr.Q @ x_t) + u_t @ (lqr.R @ u_t)
    return x_tp1, (x_t, u_t, c_t)

  def nonvarying_step(x_t, t):
    return timevarying_step(x_t, (t, lqr))

  step_rollout = nonvarying_step
  values = lax.iota(np.int32, num_steps)
  if lqr.A.ndim == 3:
    step_rollout = timevarying_step
    values = (values, lqr)

  return lax.scan(step_rollout, x_init, values)
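# A usage sketch with hypothetical shapes: n states, m controls, and a
# made-up constant linear feedback gain K; `lqr` is assumed to hold
# (A, B, Q, R) matrices, 3-D A meaning time-varying dynamics.
n, m, num_steps = 4, 2, 50
K = -0.1 * jnp.ones((m, n))
policy = lambda x, t: K @ x
x_init = jnp.ones((n,))
x_final, (xs, us, cs) = rollout(x_init, lqr, policy, num_steps)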
def cell_list_fn(position: Array,
                 capacity_overflow_update: Optional[
                     Tuple[int, bool, Callable[..., CellList]]] = None,
                 extra_capacity: int = 0,
                 **kwargs) -> CellList:
  N = position.shape[0]
  dim = position.shape[1]

  if dim != 2 and dim != 3:
    # NOTE(schsam): Do we want to check this in compute_fn as well?
    raise ValueError(
        f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

  _, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)

  if capacity_overflow_update is None:
    cell_capacity = _estimate_cell_capacity(position, box_size, cell_size,
                                            buffer_size_multiplier)
    cell_capacity += extra_capacity
    overflow = False
    update_fn = cell_list_fn
  else:
    cell_capacity, overflow, update_fn = capacity_overflow_update

  hash_multipliers = _compute_hash_constants(dim, cells_per_side)

  # Create cell list data.
  particle_id = lax.iota(i32, N)
  # NOTE(schsam): We use the convention that particles that are successfully
  # copied have their true id, whereas empty slots have id = N. Then when we
  # copy data back from the grid, we copy it to an array of shape
  # [N + 1, output_dimension] and then truncate it to an array of shape
  # [N, output_dimension], which ignores the empty slots.

  cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                            dtype=position.dtype)
  cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

  # It might be worth adding an occupied mask. However, that will involve
  # more compute since often we will do a mask for species that will include
  # an occupancy test. It seems easier to design around this empty_data_value
  # for now and revisit the issue if it comes up later.
  empty_kwarg_value = 10 ** 5
  cell_kwargs = {}
  # pytype: disable=attribute-error
  for k, v in kwargs.items():
    if not util.is_array(v):
      raise ValueError(f'Data must be specified as an ndarray. Found "{k}" '
                       f'with type {type(v)}.')
    if v.shape[0] != position.shape[0]:
      raise ValueError('Data must be specified per-particle (an ndarray '
                       f'with shape ({N}, ...)). Found "{k}" with '
                       f'shape {v.shape}.')
    kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
    cell_kwargs[k] = empty_kwarg_value * jnp.ones(
        (cell_count * cell_capacity,) + kwarg_shape, v.dtype)
  # pytype: enable=attribute-error

  indices = jnp.array(position / cell_size, dtype=i32)
  hashes = jnp.sum(indices * hash_multipliers, axis=1)

  # Copy the particle data into the grid. Here we use a trick to allow us to
  # copy into all cells simultaneously using a single lax.scatter call. To do
  # this we first sort particles by their cell hash. We then assign each
  # particle to have a cell id = hash * cell_capacity + grid_id, where
  # grid_id is a flat list that repeats 0, ..., cell_capacity. So long as
  # there are fewer than cell_capacity particles per cell, each particle is
  # guaranteed to get a cell id that is unique.
  sort_map = jnp.argsort(hashes)
  sorted_position = position[sort_map]
  sorted_hash = hashes[sort_map]
  sorted_id = particle_id[sort_map]

  sorted_kwargs = {}
  for k, v in kwargs.items():
    sorted_kwargs[k] = v[sort_map]

  sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
  sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

  cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
  sorted_id = jnp.reshape(sorted_id, (N, 1))
  cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
  cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim)
  cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

  for k, v in sorted_kwargs.items():
    if v.ndim == 1:
      v = jnp.reshape(v, v.shape + (1,))
    cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
    cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k], cells_per_side,
                                            dim)

  occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
  max_occupancy = jnp.max(occupancy)
  overflow = overflow | (max_occupancy >= cell_capacity)

  return CellList(cell_position, cell_id, cell_kwargs, overflow,
                  cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
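# A standalone numpy illustration of the hash * capacity + slot trick used
# above (plain numpy, imported as onp to avoid clashing with jax.numpy):
# after sorting by hash, particles sharing a cell occupy consecutive
# positions, so slot indices that cycle with period cell_capacity never
# collide while per-cell occupancy stays <= cell_capacity.
import numpy as onp

hashes = onp.array([1, 0, 1, 1, 0])
cell_capacity = 4
sorted_hashes = hashes[onp.argsort(hashes)]       # [0, 0, 1, 1, 1]
slots = onp.arange(len(hashes)) % cell_capacity   # [0, 1, 2, 3, 0]
flat_ids = sorted_hashes * cell_capacity + slots  # [0, 1, 6, 7, 4] -- unique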
                  [RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
                  poly_axes=[0, None]),
    _make_harness("jnp_getitem", "",
                  lambda a, i: a[i],
                  [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
                  poly_axes=[None, 0]),
    # TODO(necula): not supported yet
    # _make_harness("jnp_getitem", "",
    #               lambda a, i: a[i],
    #               [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
    #               poly_axes=[0, 0]),
    _make_harness("iota", "",
                  lambda x: x + lax.iota(_f32, x.shape[0]),
                  [RandArg((3,), _f32)],
                  poly_axes=[0]),
    _make_harness("jnp_matmul", "0",
                  jnp.matmul,
                  [RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
                  poly_axes=[0, 0],
                  tol=1e-5),
    _make_harness("jnp_matmul", "1",
                  jnp.matmul,
                  [RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
                  poly_axes=[0, None],
                  tol=1e-5),
def build_cells(R, **kwargs):
  N = R.shape[0]
  dim = R.shape[1]

  if dim != 2 and dim != 3:
    # NOTE(schsam): Do we want to check this in compute_fn as well?
    raise ValueError(
        'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

  neighborhood_tile_count = 3 ** dim

  _, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)

  hash_multipliers = _compute_hash_constants(dim, cells_per_side)

  # Create cell list data.
  particle_id = lax.iota(np.int64, N)
  # NOTE(schsam): We use the convention that particles that are successfully
  # copied have their true id, whereas empty slots have id = N. Then when we
  # copy data back from the grid, we copy it to an array of shape
  # [N + 1, output_dimension] and then truncate it to an array of shape
  # [N, output_dimension] which ignores the empty slots.
  mask_id = np.ones((N,), np.int64) * N
  cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
  cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

  # It might be worth adding an occupied mask. However, that will involve
  # more compute since often we will do a mask for species that will include
  # an occupancy test. It seems easier to design around this empty_data_value
  # for now and revisit the issue if it comes up later.
  empty_kwarg_value = 10 ** 5
  cell_kwargs = {}
  for k, v in kwargs.items():
    if not isinstance(v, np.ndarray):
      raise ValueError('Data must be specified as an ndarray. Found "{}" '
                       'with type {}'.format(k, type(v)))
    if v.shape[0] != R.shape[0]:
      raise ValueError('Data must be specified per-particle (an ndarray with '
                       'shape (R.shape[0], ...)). Found "{}" with '
                       'shape {}'.format(k, v.shape))
    kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
    cell_kwargs[k] = empty_kwarg_value * np.ones(
        (cell_count * cell_capacity,) + kwarg_shape, v.dtype)

  indices = np.array(R / cell_size, dtype=i32)
  hashes = np.sum(indices * hash_multipliers, axis=1)

  # Copy the particle data into the grid. Here we use a trick to allow us to
  # copy into all cells simultaneously using a single lax.scatter call. To do
  # this we first sort particles by their cell hash. We then assign each
  # particle to have a cell id = hash * cell_capacity + grid_id, where
  # grid_id is a flat list that repeats 0, ..., cell_capacity. So long as
  # there are fewer than cell_capacity particles per cell, each particle is
  # guaranteed to get a cell id that is unique.
  sort_map = np.argsort(hashes)
  sorted_R = R[sort_map]
  sorted_hash = hashes[sort_map]
  sorted_id = particle_id[sort_map]

  sorted_kwargs = {}
  for k, v in kwargs.items():
    sorted_kwargs[k] = v[sort_map]

  sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
  sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

  cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
  sorted_id = np.reshape(sorted_id, (N, 1))
  cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)
  cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
  cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

  for k, v in sorted_kwargs.items():
    if v.ndim == 1:
      v = np.reshape(v, v.shape + (1,))
    cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
    cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k], cells_per_side,
                                            dim)

  return CellList(cell_R, cell_id, cell_kwargs)
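# A hedged usage sketch: box_size, minimum_cell_size, and cell_capacity are
# assumed to be provided by an enclosing factory closure, and `velocity` is
# a hypothetical per-particle kwarg carried into the cell buffers alongside
# the positions R.
cells = build_cells(R, velocity=V)  # R: (N, dim) positions, V: (N, dim)
# Kwarg buffers share the [*cells_per_side, cell_capacity, ...] layout of
# the position buffer; empty slots are marked with id = N.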
def range_like(x):
  return lax.iota(np.int32, x.shape[0])
def f_jax(x):
  # The original body discarded its result; a return is needed for the
  # function to be useful.
  return x + lax.iota(np.float32, x.shape[0])
def testShapeUsesBuiltinInt(self):
  x = lax.iota(np.int32, 3) + 1
  self.assertIsInstance(x.shape[0], int)  # not np.int64
def build_cells(R):
  N = R.shape[0]
  dim = R.shape[1]
  neighborhood_tile_count = 3 ** dim

  _, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)

  if species is None:
    _species = np.zeros((N,), dtype=i32)
  else:
    _species = species

  hash_multipliers = _compute_hash_constants(dim, cells_per_side)

  # Create grid data.
  particle_id = lax.iota(np.int64, N)
  # NOTE(schsam): We use the convention that particles that come from the
  # center cell have their true id copied, whereas particles that come from
  # the halo have an id = N. Then when we copy data back from the grid,
  # we copy it to an array of shape [N + 1, output_dimension] and then
  # truncate it to an array of shape [N, output_dimension] which ignores the
  # halo particles.
  mask_id = np.ones((N,), np.int64) * N
  cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
  # NOTE(schsam): empty_species_index is just supposed to be large enough
  # that we will never run into it. However, there might be a more robust
  # way to do this.
  empty_species_index = i16(1000)
  cell_species = empty_species_index * np.ones(
      (cell_count * cell_capacity, 1), dtype=_species.dtype)
  cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

  indices = np.array(R / cell_size, dtype=i32)

  # Create a copy of particle data for each neighboring cell, shifting the
  # hash appropriately.
  # TODO(schsam): Replace with np.tile() when it gets implemented.
  tiled_R = R
  tiled_species = _species
  for _ in range(neighborhood_tile_count - 1):
    tiled_R = np.concatenate((tiled_R, R), axis=0)
    tiled_species = np.concatenate((tiled_species, _species), axis=0)

  tiled_hash = np.array([], dtype=i32)
  tiled_id = np.array([], dtype=i32)

  for dindex in _neighboring_cells(dim):
    tiled_indices = np.mod(indices + dindex, cells_per_side)
    tiled_hash = np.concatenate(
        (tiled_hash, np.sum(tiled_indices * hash_multipliers, axis=1)),
        axis=0)

    if np.all(dindex == 0):
      tiled_id = np.concatenate((tiled_id, particle_id), axis=0)
    else:
      tiled_id = np.concatenate((tiled_id, mask_id), axis=0)

  # Copy the particle data into the grid. Here we use a trick to allow us to
  # copy into all cells simultaneously using a single lax.scatter call. To do
  # this we first sort particles by their cell hash. We then assign each
  # particle to have a cell id = hash * cell_capacity + grid_id, where
  # grid_id is a flat list that repeats 0, ..., cell_capacity. So long as
  # there are fewer than cell_capacity particles per cell, each particle is
  # guaranteed to get a cell id that is unique.
  sort_map = np.argsort(tiled_hash)

  sorted_R = tiled_R[sort_map]
  sorted_species = tiled_species[sort_map]
  sorted_hash = tiled_hash[sort_map]
  sorted_id = tiled_id[sort_map]

  tiled_size = neighborhood_tile_count * N
  sorted_cell_id = np.mod(lax.iota(np.int64, tiled_size), cell_capacity)
  sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

  def copy_values_to_cell(cell_value, value, ids):
    scatter_indices = np.reshape(ids, (tiled_size, 1))
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=tuple([1]),
        inserted_window_dims=tuple([0]),
        scatter_dims_to_operand_dims=tuple([0]),
    )
    return lax.scatter(cell_value, scatter_indices, value, dnums)

  cell_R = copy_values_to_cell(cell_R, sorted_R, sorted_cell_id)
  sorted_species = np.reshape(sorted_species, (tiled_size, 1))
  cell_species = copy_values_to_cell(cell_species, sorted_species,
                                     sorted_cell_id)
  sorted_id = np.reshape(sorted_id, (tiled_size, 1))
  cell_id = copy_values_to_cell(cell_id, sorted_id, sorted_cell_id)

  cell_R = np.reshape(cell_R, (cell_count, cell_capacity, dim))
  cell_species = np.reshape(cell_species, (cell_count, cell_capacity))
  cell_id = np.reshape(cell_id, (cell_count, cell_capacity))

  return Grid(N, dim, cell_count, cell_R, cell_species, cell_id)
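# A minimal standalone illustration of the lax.scatter pattern inside
# copy_values_to_cell (plain jax.numpy; the shapes here are made up):
# row i of `vals` lands at row idx[i] of the buffer.
import jax.numpy as jnp
from jax import lax

buf = jnp.zeros((6, 3))
vals = jnp.ones((2, 3))
idx = jnp.array([[4], [1]], dtype=jnp.int32)
dnums = lax.ScatterDimensionNumbers(update_window_dims=(1,),
                                    inserted_window_dims=(0,),
                                    scatter_dims_to_operand_dims=(0,))
buf = lax.scatter(buf, idx, vals, dnums)  # rows 1 and 4 are now all ones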
def build_cells(R):
  N = R.shape[0]
  dim = R.shape[1]

  if dim != 2 and dim != 3:
    # NOTE(schsam): Do we want to check this in compute_fn as well?
    raise ValueError(
        'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

  neighborhood_tile_count = 3 ** dim

  _, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)

  if species is None:
    _species = np.zeros((N,), dtype=i32)
  else:
    _species = species

  hash_multipliers = _compute_hash_constants(dim, cells_per_side)

  # Create cell list data.
  particle_id = lax.iota(np.int64, N)
  # NOTE(schsam): We use the convention that particles that are successfully
  # copied have their true id, whereas empty slots have id = N. Then when we
  # copy data back from the grid, we copy it to an array of shape
  # [N + 1, output_dimension] and then truncate it to an array of shape
  # [N, output_dimension] which ignores the empty slots.
  mask_id = np.ones((N,), np.int64) * N
  cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
  # NOTE(schsam): empty_species_index is just supposed to be large enough
  # that we will never run into it. However, there might be a more robust
  # way to do this.
  empty_species_index = i32(1000)
  cell_species = empty_species_index * np.ones(
      (cell_count * cell_capacity, 1), dtype=_species.dtype)
  cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

  indices = np.array(R / cell_size, dtype=i32)
  hashes = np.sum(indices * hash_multipliers, axis=1)

  # Copy the particle data into the grid. Here we use a trick to allow us to
  # copy into all cells simultaneously using a single lax.scatter call. To do
  # this we first sort particles by their cell hash. We then assign each
  # particle to have a cell id = hash * cell_capacity + grid_id, where
  # grid_id is a flat list that repeats 0, ..., cell_capacity. So long as
  # there are fewer than cell_capacity particles per cell, each particle is
  # guaranteed to get a cell id that is unique.
  sort_map = np.argsort(hashes)
  sorted_R = R[sort_map]
  sorted_species = _species[sort_map]
  sorted_hash = hashes[sort_map]
  sorted_id = particle_id[sort_map]

  sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
  sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

  cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
  sorted_species = np.reshape(sorted_species, (N, 1))
  cell_species = ops.index_update(cell_species, sorted_cell_id,
                                  sorted_species)
  sorted_id = np.reshape(sorted_id, (N, 1))
  cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)

  cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
  cell_species = _unflatten_cell_buffer(cell_species, cells_per_side, dim)
  cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

  return CellList(N, dim, cell_count, cell_R, cell_species, cell_id)
def padded_add(x):
  # lax.iota takes the dtype as its first argument; the original call
  # omitted it.
  return x + lax.iota(np.float32, x.shape[0])
def _threefry_split(key, num) -> jnp.ndarray:
  counts = lax.iota(np.uint32, num * 2)
  return lax.reshape(threefry_2x32(key, counts), (num, 2))
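# Usage sketch (hedged; threefry_2x32 must be in scope): split one threefry
# key into independent subkeys, each a uint32[2] row.
key = jnp.zeros(2, dtype=jnp.uint32)
subkeys = _threefry_split(key, 3)  # shape (3, 2), dtype uint32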
def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
    det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b). Then
    y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
          = x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with
  prod_{i=1...n-1}(u_{ii})^-1 we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the
  check is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  if a_shape[-1] == 1:
    return a[..., 0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
  # filter out any matrices that are not full rank
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim * (1,) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter

  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
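# A hedged numeric check for the nonsingular case, where the result should
# match (det(a), det(a) * solve(a, b)); with b = I this recovers the
# adjugate of a.
a = jnp.array([[2., 1.], [1., 3.]])
b = jnp.eye(2)
det, adj_b = _cofactor_solve(a, b)
assert jnp.allclose(det, jnp.linalg.det(a))
assert jnp.allclose(adj_b, det * jnp.linalg.solve(a, b))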
def build_cells(R):
  N = R.shape[0]
  dim = R.shape[1]

  if dim != 2 and dim != 3:
    raise ValueError(
        'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

  neighborhood_tile_count = 3 ** dim

  _, cell_size, cells_per_side, cell_count = \
      _cell_dimensions(dim, box_size, minimum_cell_size)

  if species is None:
    _species = np.zeros((N,), dtype=i32)
  else:
    _species = species

  hash_multipliers = _compute_hash_constants(dim, cells_per_side)

  # Create cell list data.
  particle_id = lax.iota(np.int64, N)
  mask_id = np.ones((N,), np.int64) * N
  cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
  empty_species_index = i32(1000)
  cell_species = empty_species_index * np.ones(
      (cell_count * cell_capacity, 1), dtype=_species.dtype)
  cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

  indices = np.array(R / cell_size, dtype=i32)
  hashes = np.sum(indices * hash_multipliers, axis=1)

  # Copy the particle data into the grid. Here we use a trick to allow us to
  # copy into all cells simultaneously using a single lax.scatter call. To do
  # this we first sort particles by their cell hash. We then assign each
  # particle to have a cell id = hash * cell_capacity + grid_id, where
  # grid_id is a flat list that repeats 0, ..., cell_capacity. So long as
  # there are fewer than cell_capacity particles per cell, each particle is
  # guaranteed to get a cell id that is unique.
  sort_map = np.argsort(hashes)
  sorted_R = R[sort_map]
  sorted_species = _species[sort_map]
  sorted_hash = hashes[sort_map]
  sorted_id = particle_id[sort_map]

  sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
  sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

  cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
  sorted_species = np.reshape(sorted_species, (N, 1))
  cell_species = ops.index_update(cell_species, sorted_cell_id,
                                  sorted_species)
  sorted_id = np.reshape(sorted_id, (N, 1))
  cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)

  cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
  cell_species = _unflatten_cell_buffer(cell_species, cells_per_side, dim)
  cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

  return CellList(N, dim, cell_count, cell_R, cell_species, cell_id)