示例#1
0
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key.

    Args:
      key: a threefry PRNG key (checked with ``_is_threefry_prng_key``).
      bit_width: width of each sample in bits; one of 8, 16, 32 or 64.
      shape: requested output shape. It is converted with
        ``core.as_named_shape``, so it may carry named (mapped) axes besides
        the positional dimensions.

    Returns:
      The sampled bits reshaped to ``shape``. For 8/16/64-bit widths the
      dtype is ``UINT_DTYPES[bit_width]``; for 32 bits it is whatever
      ``threefry_2x32`` returns (presumably uint32 — not visible here).

    Raises:
      TypeError: if ``key`` is not a threefry key or ``bit_width`` is not
        supported.
      ValueError: if a named axis size differs from the size reported by
        ``lax.psum(1, name)``.
    """
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    # Validate each named axis and fold its index into the key so that every
    # position along a mapped axis draws different bits.
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    # A single lax.iota(np.uint32, ...) counter array is capped at the uint32
    # maximum, so larger requests are split into full-size blocks plus a
    # remainder. A symbolic (non-constant) count is treated as one block.
    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        # Each full block and the remainder get their own subkey.
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        # Two 32-bit words per sample: the first half of the stream becomes
        # the high words and the second half the low words.
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        # Each 32-bit word is shifted by 0, bit_width, 2*bit_width, ... and
        # masked, giving 32 // bit_width narrow values per word; the reshape
        # with dimensions (1, 0) interleaves them word by word.
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
示例#2
0
def linspace(start,
             stop,
             num=50,
             endpoint=True,
             retstep=False,
             dtype=None,
             axis=0):
    """Implementation of linspace differentiable w.r.t. start and stop args.

    Mirrors ``numpy.linspace``. It returns ``num`` evenly spaced samples from
    ``start`` towards ``stop``, and ``stop`` is included when ``endpoint`` is
    true. Samples are computed as ``start + iota * delta``, so gradients flow
    to both ``start`` and ``stop``.

    Args:
      start, stop: scalars or arrays; their shapes are broadcast together.
      num: number of samples; must be non-negative.
      endpoint: whether ``stop`` is the last sample.
      retstep: if true, also return the spacing ``delta``.
      dtype: output dtype; defaults to ``np.float32``.
      axis: position of the new sample axis in the output. Negative values
        count from the end of the output shape, as in NumPy.

    Returns:
      The samples, or ``(samples, delta)`` when ``retstep`` is true. If there
      is no interval to divide (``num == 0``, or ``num == 1`` with
      ``endpoint``), ``delta`` is ``stop - start`` instead of NumPy's NaN.
      This keeps ``start + 0 * delta`` finite and differentiable.

    Raises:
      ValueError: if ``num`` is negative.
    """
    lax._check_user_dtype_supported(dtype, "linspace")
    if num < 0:
        raise ValueError(f"Number of samples, {num}, must be non-negative.")
    dtype = np.float32 if dtype is None else dtype
    bounds_shape = list(lax.broadcast_shapes(np.shape(start), np.shape(stop)))
    # The output has one more dimension than the broadcast bounds. Resolve a
    # negative axis against that output rank (axis=-1 -> last position), so
    # bounds_shape and iota_shape agree on where the sample axis goes.
    if axis < 0:
        axis += len(bounds_shape) + 1
    bounds_shape.insert(axis, 1)
    iota_shape = [1] * len(bounds_shape)
    iota_shape[axis] = num
    # Number of intervals between samples. Guarding it avoids the
    # ZeroDivisionError that num == 1 with endpoint=True (and num == 0) hit.
    div = (num - 1) if endpoint else num
    if div > 0:
        delta = (stop - start) / div
    else:
        # At most one sample (``start``) is produced, so the step is unused.
        delta = stop - start
    out = (jnp.reshape(start, bounds_shape) +
           jnp.reshape(lax.iota(dtype, num), iota_shape) *
           jnp.reshape(delta, bounds_shape))
    if retstep:
        return jnp.array(out, dtype=dtype), delta
    else:
        return jnp.array(out, dtype=dtype)
示例#3
0
 def body_fn(i, permutation):
     """Apply the ``i``-th pivot swap to ``permutation`` for every batch entry.

     Exchanges the entries at positions ``i`` and ``swaps[..., i]`` along the
     last axis. ``swaps`` and ``batch_dims`` come from the enclosing scope.
     """
     target = swaps[..., i]
     batch_index = np.ix_(*(lax.iota(np.int32, d) for d in batch_dims))
     target_index = batch_index + (target, )
     value_at_i = permutation[..., i]
     value_at_target = permutation[target_index]
     permutation = ops.index_update(permutation, ops.index[..., i],
                                    value_at_target)
     return ops.index_update(permutation, ops.index[target_index], value_at_i)
示例#4
0
def _lu_jvp_rule(primals, tangents):
    """Forward-mode (JVP) rule for the LU primitive ``lu_p``.

    Args:
      primals: 1-tuple ``(a,)`` holding the matrix (or batch of matrices)
        being factored.
      tangents: 1-tuple ``(a_dot,)`` holding the tangent of ``a``.

    Returns:
      The primal outputs ``(lu, pivots)`` and their tangents
      ``(lu_dot, ad_util.zero)``; ``pivots`` always gets a zero tangent.
    """
    a, = primals
    a_dot, = tangents
    lu, pivots = lu_p.bind(a)

    # Zero tangent in -> zero tangents out.
    # NOTE(review): this path wraps outputs in `core.pack` / `ad.TangentTuple`
    # while the main path below returns plain tuples; confirm which form the
    # JAX version in use expects.
    if a_dot is ad_util.zero:
        return (core.pack(
            (lu, pivots)), ad.TangentTuple((ad_util.zero, ad_util.zero)))

    a_shape = np.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    # Reorder the rows of a_dot by the pivot permutation so they line up with
    # the factored matrix; `iotas` index the leading batch dimensions.
    permutation = lu_pivots_to_permutation(pivots, m)
    batch_dims = a_shape[:-2]
    iotas = np.ix_(*(lax.iota(np.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    # L: strictly-lower part of the first k columns of `lu`, zero-padded on
    # the right to (m, m), plus the identity for a unit diagonal.
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = np._constant_like(lu, 0)
    l = lax.pad(np.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + np.eye(m, m, dtype=dtype)

    # U: upper part of the first k rows of `lu`, zero-padded at the bottom to
    # (n, n), with an identity in the bottom-right (n - k) x (n - k) corner so
    # U is invertible when m < n.
    u_eye = lax.pad(np.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(np.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # la = inv(L) . x, then lau = la . inv(U) = inv(L) . A' . inv(U).
    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    l_dot = np.matmul(l, np.tril(lau, -1))
    u_dot = np.matmul(np.triu(lau), u)
    # Both tangents are (m, n); summing packs them into the same combined
    # layout as `lu` (strict lower = L', upper = U').
    lu_dot = l_dot + u_dot
    return (lu, pivots), (lu_dot, ad_util.zero)
示例#5
0
文件: lax_linalg.py 项目: yotarok/jax
def _lu_pivots_body_fn(i, permutation_and_swaps):
    """Loop body that applies the ``i``-th LU pivot swap.

    Args:
      i: loop index; the position along the last axis of ``permutation``
        being swapped.
      permutation_and_swaps: pair ``(permutation, swaps)``, where
        ``swaps[..., i]`` is the position that entry ``i`` is exchanged with.

    Returns:
      ``(updated_permutation, swaps)``; ``swaps`` is passed through unchanged.
    """
    permutation, swaps = permutation_and_swaps
    partner = swaps[..., i]
    batch_index = np.ix_(*(lax.iota(np.int32, d) for d in swaps.shape[:-1]))
    partner_index = batch_index + (partner, )
    value_at_i = permutation[..., i]
    value_at_partner = permutation[partner_index]
    permutation = ops.index_update(permutation, ops.index[..., i],
                                   value_at_partner)
    permutation = ops.index_update(permutation, ops.index[partner_index],
                                   value_at_i)
    return permutation, swaps
示例#6
0
    def init(key, shape, dtype=np.float32):
        """Sample a time-varying LQR problem.

        Args:
          key: PRNG key; split once per time step.
          shape: ``(horizon, num_state_dims, num_action_dims)``.
          dtype: dtype passed to every matrix initializer.

        Returns:
          ``(pmats, lqrs)``, the per-step outputs of ``generate_lqr`` stacked
          by ``lax.scan`` along a leading axis of length ``horizon``.
        """
        # assume shape is (horizon, num_state_dims, num_action_dims)
        T, n, m = shape

        def generate_lqr(key, _):
            new_key, p_key, a_key, b_key, q_key, r_key = random.split(key, 6)

            pmat = pmat_initializer(p_key, (n, n), dtype=dtype)
            lqr = LQR(
                A=amat_initializer(a_key, (n, n), dtype=dtype),
                B=bmat_initializer(b_key, (n, m), dtype=dtype),
                Q=qmat_initializer(q_key, (n, n), dtype=dtype),
                R=rmat_initializer(r_key, (m, m), dtype=dtype),
            )
            return new_key, (pmat, lqr)

        # lax.iota takes (dtype, size); the former `lax.iota(T)` raised a
        # TypeError. The values only fix the scan length (generate_lqr
        # ignores them).
        return lax.scan(generate_lqr, key, lax.iota(np.int32, T))[1]
示例#7
0
def rollout(x_init, lqr, policy, num_steps):
    """Simulate a linear-quadratic system under ``policy``.

    Args:
      x_init: initial state vector.
      lqr: object with ``A``, ``B``, ``Q`` and ``R`` matrices. If ``lqr.A``
        has 3 dimensions, the matrices are scanned over their leading axis
        (time-varying system). Otherwise the same matrices are used at every
        step.
      policy: callable ``policy(x_t, t) -> u_t``.
      num_steps: number of simulated steps.

    Returns:
      ``(x_final, (xs, us, costs))`` as produced by ``lax.scan``, with
      ``costs[t] = x_t . Q x_t + u_t . R u_t``.
    """
    def timevarying_step(state, step_inputs):
        t, params = step_inputs
        action = policy(state, t)
        next_state = params.A @ state + params.B @ action
        cost = state @ (params.Q @ state) + action @ (params.R @ action)
        return next_state, (state, action, cost)

    timesteps = lax.iota(np.int32, num_steps)
    if lqr.A.ndim == 3:
        return lax.scan(timevarying_step, x_init, (timesteps, lqr))

    def nonvarying_step(state, t):
        return timevarying_step(state, (t, lqr))

    return lax.scan(nonvarying_step, x_init, timesteps)
示例#8
0
    def cell_list_fn(position: Array,
                     capacity_overflow_update: Optional[Tuple[
                         int, bool, Callable[..., CellList]]] = None,
                     extra_capacity: int = 0,
                     **kwargs) -> CellList:
        """Build a ``CellList`` from ``position`` and optional per-particle data.

        ``box_size``, ``minimum_cell_size`` and ``buffer_size_multiplier`` are
        read from the enclosing scope.

        Args:
          position: array of shape (N, dim); dim must be 2 or 3.
          capacity_overflow_update: optional ``(cell_capacity, overflow,
            update_fn)`` to reuse. When None, the capacity is estimated from
            ``position`` (plus ``extra_capacity``), ``overflow`` starts as
            False and ``update_fn`` is this function.
          extra_capacity: slots added to the estimated capacity; ignored when
            ``capacity_overflow_update`` is given.
          **kwargs: per-particle arrays with leading dimension N, binned into
            the cells alongside the positions.

        Returns:
          ``CellList(cell_position, cell_id, cell_kwargs, overflow,
          cell_capacity, update_fn)``. Empty slots hold 0 in
          ``cell_position``, ``N`` in ``cell_id`` and ``10**5`` in each kwarg
          buffer.

        Raises:
          ValueError: if dim is not 2 or 3, or if a kwarg is not an array or
            is not per-particle.
        """
        N = position.shape[0]
        dim = position.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                f'Cell list spatial dimension must be 2 or 3. Found {dim}.')

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        if capacity_overflow_update is None:
            cell_capacity = _estimate_cell_capacity(position, box_size,
                                                    cell_size,
                                                    buffer_size_multiplier)
            cell_capacity += extra_capacity
            overflow = False
            update_fn = cell_list_fn
        else:
            cell_capacity, overflow, update_fn = capacity_overflow_update

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(i32, N)
        # NOTE(schsam): We use the convention that particles that are
        # successfully copied have their true id whereas empty slots have
        # id = N.
        # Then when we copy data back from the grid, copy it to an array of shape
        # [N + 1, output_dimension] and then truncate it to an array of shape
        # [N, output_dimension] which ignores the empty slots.
        cell_position = jnp.zeros((cell_count * cell_capacity, dim),
                                  dtype=position.dtype)
        cell_id = N * jnp.ones((cell_count * cell_capacity, 1), dtype=i32)

        # It might be worth adding an occupied mask. However, that will involve
        # more compute since often we will do a mask for species that will include
        # an occupancy test. It seems easier to design around this empty_data_value
        # for now and revisit the issue if it comes up later.
        empty_kwarg_value = 10**5
        cell_kwargs = {}
        #  pytype: disable=attribute-error
        for k, v in kwargs.items():
            if not util.is_array(v):
                raise ValueError(
                    (f'Data must be specified as an ndarray. Found "{k}" '
                     f'with type {type(v)}.'))
            if v.shape[0] != position.shape[0]:
                raise ValueError(
                    ('Data must be specified per-particle (an ndarray '
                     f'with shape ({N}, ...)). Found "{k}" with '
                     f'shape {v.shape}.'))
            # 1-D per-particle data gets a trailing unit dimension in the grid.
            kwarg_shape = v.shape[1:] if v.ndim > 1 else (1, )
            cell_kwargs[k] = empty_kwarg_value * jnp.ones(
                (cell_count * cell_capacity, ) + kwarg_shape, v.dtype)
        #  pytype: enable=attribute-error
        # Integer cell coordinates of each particle, flattened to one hash.
        indices = jnp.array(position / cell_size, dtype=i32)
        hashes = jnp.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where
        # grid_id is a flat list that repeats 0, ..., cell_capacity - 1. So long
        # as there are no more than cell_capacity particles per cell, each
        # particle is guaranteed to get a cell id that is unique.
        sort_map = jnp.argsort(hashes)
        sorted_position = position[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_kwargs = {}
        for k, v in kwargs.items():
            sorted_kwargs[k] = v[sort_map]

        sorted_cell_id = jnp.mod(lax.iota(i32, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
        sorted_id = jnp.reshape(sorted_id, (N, 1))
        cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
        cell_position = _unflatten_cell_buffer(cell_position, cells_per_side,
                                               dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        for k, v in sorted_kwargs.items():
            if v.ndim == 1:
                v = jnp.reshape(v, v.shape + (1, ))
            cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
            cell_kwargs[k] = _unflatten_cell_buffer(cell_kwargs[k],
                                                    cells_per_side, dim)

        # Count particles per cell hash to detect capacity overflow.
        # NOTE(review): `>=` also flags a cell that is exactly full (which
        # still fits under the id scheme above) — confirm this conservatism
        # is intended.
        occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
        max_occupancy = jnp.max(occupancy)
        overflow = overflow | (max_occupancy >= cell_capacity)

        return CellList(cell_position, cell_id, cell_kwargs, overflow,
                        cell_capacity, update_fn)  # pytype: disable=wrong-arg-count
示例#9
0
                  [RandArg((3, 4, 5), _f32), np.array([1, 2], np.int32)],
                  poly_axes=[0, None]),

    _make_harness("jnp_getitem", "",
                  lambda a, i: a[i],
                  [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
                  poly_axes=[None, 0]),

    # TODO(necula): not supported yet
    # _make_harness("jnp_getitem", "",
    #               lambda a, i: a[i],
    #               [RandArg((3, 4), _f32), np.array([2, 2], np.int32)],
    #               poly_axes=[0, 0]),

    _make_harness("iota", "",
                  lambda x: x + lax.iota(_f32, x.shape[0]),
                  [RandArg((3,), _f32)],
                  poly_axes=[0]),

    _make_harness("jnp_matmul", "0",
                  jnp.matmul,
                  [RandArg((7, 8, 4), _f32), RandArg((7, 4, 5), _f32)],
                  poly_axes=[0, 0],
                  tol=1e-5),

    _make_harness("jnp_matmul", "1",
                  jnp.matmul,
                  [RandArg((7, 8, 4), _f32), RandArg((4, 5), _f32)],
                  poly_axes=[0, None],
                  tol=1e-5),
示例#10
0
  def build_cells(R, **kwargs):
    """Bin particle positions ``R`` and per-particle ``kwargs`` into cells.

    Particles are sorted by cell hash and scattered into flat buffers of
    ``cell_count * cell_capacity`` slots, which ``_unflatten_cell_buffer``
    then reshapes. ``box_size``, ``minimum_cell_size`` and ``cell_capacity``
    are read from the enclosing scope.

    Args:
      R: array of shape (N, dim) with dim in {2, 3}.
      **kwargs: per-particle arrays whose leading dimension is N.

    Returns:
      ``CellList(cell_R, cell_id, cell_kwargs)``. Empty slots hold 0 in
      ``cell_R``, ``N`` in ``cell_id`` and ``10 ** 5`` in each kwarg buffer.

    Raises:
      ValueError: if dim is not 2 or 3, or if a kwarg is not an ndarray or is
        not per-particle.
    """
    N = R.shape[0]
    dim = R.shape[1]

    if dim != 2 and dim != 3:
      # NOTE(schsam): Do we want to check this in compute_fn as well?
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data.
    particle_id = lax.iota(np.int64, N)
    # NOTE(schsam): We use the convention that particles that are successfully
    # copied have their true id whereas empty slots have id = N.
    # Then when we copy data back from the grid, copy it to an array of shape
    # [N + 1, output_dimension] and then truncate it to an array of shape
    # [N, output_dimension] which ignores the empty slots.
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    # It might be worth adding an occupied mask. However, that will involve
    # more compute since often we will do a mask for species that will include
    # an occupancy test. It seems easier to design around this empty_data_value
    # for now and revisit the issue if it comes up later.
    empty_kwarg_value = 10 ** 5
    cell_kwargs = {}
    for k, v in kwargs.items():
      if not isinstance(v, np.ndarray):
        raise ValueError((
          'Data must be specified as an ndarray. Found "{}" with '
          'type {}'.format(k, type(v))))
      if v.shape[0] != R.shape[0]:
        raise ValueError(
          ('Data must be specified per-particle (an ndarray with shape '
           '(R.shape[0], ...)). Found "{}" with shape {}'.format(k, v.shape)))
      # 1-D per-particle data gets a trailing unit dimension in the grid.
      kwarg_shape = v.shape[1:] if v.ndim > 1 else (1,)
      cell_kwargs[k] = empty_kwarg_value * np.ones(
        (cell_count * cell_capacity,) + kwarg_shape, v.dtype)

    # Integer cell coordinates of each particle, flattened to one hash.
    indices = np.array(R / cell_size, dtype=i32)
    hashes = np.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, ..., cell_capacity - 1. So long as there
    # are no more than cell_capacity particles per cell, each particle is
    # guaranteed to get a cell id that is unique.
    sort_map = np.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_kwargs = {}
    for k, v in kwargs.items():
      sorted_kwargs[k] = v[sort_map]

    sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_id = np.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)
    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    for k, v in sorted_kwargs.items():
      if v.ndim == 1:
        v = np.reshape(v, v.shape + (1,))
      cell_kwargs[k] = ops.index_update(cell_kwargs[k], sorted_cell_id, v)
      cell_kwargs[k] = _unflatten_cell_buffer(
        cell_kwargs[k], cells_per_side, dim)

    return CellList(cell_R, cell_id, cell_kwargs)
示例#11
0
 def range_like(x):
     """Return int32 indices ``0 .. len(x) - 1`` for the leading axis of ``x``."""
     length = x.shape[0]
     return lax.iota(np.int32, length)
示例#12
0
 def f_jax(x):
     """Return ``x`` plus float32 indices ``0, 1, ..., x.shape[0] - 1``.

     The expression was previously evaluated and discarded (no ``return``),
     so the function always returned None.
     """
     return x + lax.iota(np.float32, x.shape[0])
示例#13
0
 def testShapeUsesBuiltinInt(self):
     x = lax.iota(np.int32, 3) + 1
     self.assertIsInstance(x.shape[0], int)  # not np.int64
示例#14
0
文件: smap.py 项目: berkonat/jax-md
    def build_cells(R):
        """Build a halo ``Grid``: each particle is copied into its own cell and
        into every neighboring cell yielded by ``_neighboring_cells(dim)``.

        Copies placed in neighboring cells get id ``N``; only the copy in the
        particle's own cell keeps its true id. ``box_size``,
        ``minimum_cell_size``, ``cell_capacity`` and ``species`` are read from
        the enclosing scope.

        Args:
          R: array of shape (N, dim).

        Returns:
          ``Grid(N, dim, cell_count, cell_R, cell_species, cell_id)`` with
          buffers reshaped to ``(cell_count, cell_capacity, ...)``.
        """
        N = R.shape[0]
        dim = R.shape[1]
        neighborhood_tile_count = 3**dim

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        # Without explicit species every particle is species 0.
        if species is None:
            _species = np.zeros((N, ), dtype=i32)
        else:
            _species = species

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create grid data.
        particle_id = lax.iota(np.int64, N)
        # NOTE(schsam): We use the convention that particles that come from the
        # center cell have their true id copied, whereas particles that come from
        # the halo have an id = N. Then when we copy data back from the grid,
        # we copy it to an array of shape [N + 1, output_dimension] and then
        # truncate it to an array of shape [N, output_dimension] which ignores the
        # halo particles.
        mask_id = np.ones((N, ), np.int64) * N
        cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
        # NOTE(schsam): empty_species_index is just supposed to be large enough that
        # we will never run into it. However, there might be a more robust way to do
        # this.
        empty_species_index = i16(1000)
        cell_species = empty_species_index * np.ones(
            (cell_count * cell_capacity, 1), dtype=_species.dtype)
        cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

        indices = np.array(R / cell_size, dtype=i32)

        # Create a copy of particle data for each neighboring cell shifting the hash
        # appropriately.
        # NOTE(review): R/_species are tiled 3**dim times, while hashes and ids
        # are built per offset from _neighboring_cells(dim); the rows line up
        # only if that yields exactly 3**dim offsets — confirm.
        # TODO(schsam): Replace with np.tile() when it gets implemented.
        tiled_R = R
        tiled_species = _species
        for _ in range(neighborhood_tile_count - 1):
            tiled_R = np.concatenate((tiled_R, R), axis=0)
            tiled_species = np.concatenate((tiled_species, _species), axis=0)
        tiled_hash = np.array([], dtype=i32)
        tiled_id = np.array([], dtype=i32)

        for dindex in _neighboring_cells(dim):
            # Shifted cell coordinates, wrapped periodically around the box.
            tiled_indices = np.mod(indices + dindex, cells_per_side)
            tiled_hash = np.concatenate(
                (tiled_hash, np.sum(tiled_indices * hash_multipliers, axis=1)),
                axis=0)

            # Python-level branch: dindex must be concrete (not traced).
            if np.all(dindex == 0):
                tiled_id = np.concatenate((tiled_id, particle_id), axis=0)
            else:
                tiled_id = np.concatenate((tiled_id, mask_id), axis=0)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
        # is a flat list that repeats 0, ..., cell_capacity - 1. So long as there
        # are no more than cell_capacity particles per cell, each particle is
        # guaranteed to get a cell id that is unique.
        sort_map = np.argsort(tiled_hash)
        sorted_R = tiled_R[sort_map]
        sorted_species = tiled_species[sort_map]
        sorted_hash = tiled_hash[sort_map]
        sorted_id = tiled_id[sort_map]

        tiled_size = neighborhood_tile_count * N
        sorted_cell_id = np.mod(lax.iota(np.int64, tiled_size), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        def copy_values_to_cell(cell_value, value, ids):
            # Scatter row r of `value` into row ids[r] of `cell_value`.
            scatter_indices = np.reshape(ids, (tiled_size, 1))
            dnums = lax.ScatterDimensionNumbers(
                update_window_dims=tuple([1]),
                inserted_window_dims=tuple([0]),
                scatter_dims_to_operand_dims=tuple([0]),
            )
            return lax.scatter(cell_value, scatter_indices, value, dnums)

        cell_R = copy_values_to_cell(cell_R, sorted_R, sorted_cell_id)
        sorted_species = np.reshape(sorted_species, (tiled_size, 1))
        cell_species = copy_values_to_cell(cell_species, sorted_species,
                                           sorted_cell_id)
        sorted_id = np.reshape(sorted_id, (tiled_size, 1))
        cell_id = copy_values_to_cell(cell_id, sorted_id, sorted_cell_id)

        cell_R = np.reshape(cell_R, (cell_count, cell_capacity, dim))
        cell_species = np.reshape(cell_species, (cell_count, cell_capacity))
        cell_id = np.reshape(cell_id, (cell_count, cell_capacity))

        return Grid(N, dim, cell_count, cell_R, cell_species, cell_id)
示例#15
0
文件: smap.py 项目: jaxmd/jax-md-1
    def build_cells(R):
        """Sort particles into a flat cell list and reshape it into cells.

        ``box_size``, ``minimum_cell_size``, ``cell_capacity`` and ``species``
        are read from the enclosing scope.

        Args:
          R: array of shape (N, dim) with dim in {2, 3}.

        Returns:
          ``CellList(N, dim, cell_count, cell_R, cell_species, cell_id)``.
          Empty slots hold 0 in ``cell_R``, 1000 in ``cell_species`` and
          ``N`` in ``cell_id``.

        Raises:
          ValueError: if dim is not 2 or 3.
        """
        N = R.shape[0]
        dim = R.shape[1]

        if dim != 2 and dim != 3:
            # NOTE(schsam): Do we want to check this in compute_fn as well?
            raise ValueError(
                'Cell list spatial dimension must be 2 or 3. Found {}'.format(
                    dim))

        _, cell_size, cells_per_side, cell_count = \
            _cell_dimensions(dim, box_size, minimum_cell_size)

        # Without explicit species every particle is species 0.
        if species is None:
            _species = np.zeros((N, ), dtype=i32)
        else:
            _species = species

        hash_multipliers = _compute_hash_constants(dim, cells_per_side)

        # Create cell list data.
        particle_id = lax.iota(np.int64, N)
        # NOTE(schsam): We use the convention that particles that are
        # successfully copied have their true id whereas empty slots have
        # id = N. Then when we copy data back from the grid, copy it to an
        # array of shape [N + 1, output_dimension] and then truncate it to an
        # array of shape [N, output_dimension] which ignores the empty slots.
        cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
        # NOTE(schsam): empty_species_index is just supposed to be large enough that
        # we will never run into it. However, there might be a more robust way to do
        # this.
        empty_species_index = i32(1000)
        cell_species = empty_species_index * np.ones(
            (cell_count * cell_capacity, 1), dtype=_species.dtype)
        cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

        # Integer cell coordinates of each particle, flattened to one hash.
        indices = np.array(R / cell_size, dtype=i32)
        hashes = np.sum(indices * hash_multipliers, axis=1)

        # Copy the particle data into the grid. Here we use a trick to allow us to
        # copy into all cells simultaneously using a single lax.scatter call. To do
        # this we first sort particles by their cell hash. We then assign each
        # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
        # is a flat list that repeats 0, ..., cell_capacity - 1. So long as there
        # are no more than cell_capacity particles per cell, each particle is
        # guaranteed to get a cell id that is unique.
        sort_map = np.argsort(hashes)
        sorted_R = R[sort_map]
        sorted_species = _species[sort_map]
        sorted_hash = hashes[sort_map]
        sorted_id = particle_id[sort_map]

        sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
        sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

        cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
        sorted_species = np.reshape(sorted_species, (N, 1))
        cell_species = ops.index_update(cell_species, sorted_cell_id,
                                        sorted_species)
        sorted_id = np.reshape(sorted_id, (N, 1))
        cell_id = ops.index_update(cell_id, sorted_cell_id, sorted_id)

        cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
        cell_species = _unflatten_cell_buffer(cell_species, cells_per_side,
                                              dim)
        cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

        return CellList(N, dim, cell_count, cell_R, cell_species, cell_id)
示例#16
0
 def padded_add(x):
     """Return ``x`` plus the indices ``0, 1, ..., x.shape[0] - 1``.

     ``lax.iota`` requires ``(dtype, size)``. The former ``lax.iota(size)``
     call raised a TypeError. The indices use ``x.dtype`` so the sum keeps
     the input dtype. Assumes ``x`` is 1-D, because the index vector is
     broadcast against the trailing axis.
     """
     return x + lax.iota(x.dtype, x.shape[0])
示例#17
0
def _threefry_split(key, num) -> jnp.ndarray:
    """Split ``key`` into ``num`` keys.

    Hashes the counters ``0 .. 2 * num - 1`` with ``threefry_2x32`` under
    ``key`` and groups the resulting words into rows of two.

    Returns:
      Array of shape ``(num, 2)``.
    """
    word_count = num * 2
    counters = lax.iota(np.uint32, word_count)
    hashed = threefry_2x32(key, counters)
    return lax.reshape(hashed, (num, 2))
示例#18
0
文件: linalg.py 项目: ahoenselaar/jax
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

    Intermediate function used for jvp and vjp of det.
    This function borrows heavily from jax.numpy.linalg.solve and
    jax.numpy.linalg.slogdet to compute the gradient of the determinant
    in a way that is well defined even for low rank matrices.

    This function handles two different cases:
    * rank(a) == n or n-1
    * rank(a) < n-1

    For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
    Rather than computing det(a)*solve(a, b), which would return NaN, we work
    directly with the LU decomposition. If a = p @ l @ u, then (up to the
    permutation sign, which the code folds into partial_det)
    det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
    If a is rank n-1, then the lower right corner of u will be zero and the
    triangular_solve will fail.
    Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
    Then
    y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii})
          = x_{n} * prod_{i=1...n-1}(u_{ii})
    So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
    we can avoid the triangular_solve failing.
    To correctly compute the rest of y_{i} for i != n, we simply multiply
    x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

    For the second case, a check is done on the matrix to see if `solve`
    returns NaN or Inf, and gives a matrix of zeros as a result, as the
    gradient of the determinant of a matrix with rank less than n-1 is 0.
    This will still return the correct value for rank n-1 matrices, as the check
    is applied *after* the lower right corner of u has been updated.

    Args:
      a: A square matrix or batch of matrices, possibly singular.
      b: A matrix, or batch of matrices, whose last two dimensions equal
        those of a.

    Returns:
      det(a) and cofactor(a)^T*b, aka adjugate(a)*b

    Raises:
      ValueError: if a is not [..., m, m] or b's last two dims differ from a's.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    # 1x1 case: det(a) = a and adjugate(a) = 1, so adjugate(a) @ b = b.
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    # Each pivot that differs from its own row index is one row swap; the
    # swap-count parity gives the permutation sign (+1 or -1).
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[..., -1] is the full determinant det(a) (sign included) and
    # partial_det[..., -2] is det(a) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    # NOTE: the generator variable `b` below only shadows the argument inside
    # the generator's own scope; the outer `b` is not modified.
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    # (after the corner fix above, a non-finite solve against the modified u
    # flags the rank(a) < n-1 case described in the docstring).
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    # Apply the row permutation, then solve against unit-lower l.
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    # Scale rows i != n by det(a); row n stays as-is (see docstring).
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
示例#19
0
文件: smap.py 项目: VardaHagh/jax-md
  def build_cells(R):
    """Sort particles into a flat cell list and reshape it into cells.

    ``box_size``, ``minimum_cell_size``, ``cell_capacity`` and ``species``
    are read from the enclosing scope.

    Args:
      R: array of shape (N, dim) with dim in {2, 3}.

    Returns:
      ``CellList(N, dim, cell_count, cell_R, cell_species, cell_id)``. Empty
      slots hold 0 in ``cell_R``, 1000 in ``cell_species`` and ``N`` in
      ``cell_id``.

    Raises:
      ValueError: if dim is not 2 or 3.
    """
    N = R.shape[0]
    dim = R.shape[1]

    if dim != 2 and dim != 3:
      raise ValueError(
          'Cell list spatial dimension must be 2 or 3. Found {}'.format(dim))

    _, cell_size, cells_per_side, cell_count = \
        _cell_dimensions(dim, box_size, minimum_cell_size)

    # Without explicit species every particle is species 0.
    if species is None:
      _species = np.zeros((N,), dtype=i32)
    else:
      _species = species

    hash_multipliers = _compute_hash_constants(dim, cells_per_side)

    # Create cell list data. Occupied slots get the particle's true id; empty
    # slots keep id N and species 1000.
    particle_id = lax.iota(np.int64, N)
    cell_R = np.zeros((cell_count * cell_capacity, dim), dtype=R.dtype)
    empty_species_index = i32(1000)
    cell_species = empty_species_index * np.ones(
        (cell_count * cell_capacity, 1), dtype=_species.dtype)
    cell_id = N * np.ones((cell_count * cell_capacity, 1), dtype=i32)

    # Integer cell coordinates of each particle, flattened to one hash.
    indices = np.array(R / cell_size, dtype=i32)
    hashes = np.sum(indices * hash_multipliers, axis=1)

    # Copy the particle data into the grid. Here we use a trick to allow us to
    # copy into all cells simultaneously using a single lax.scatter call. To do
    # this we first sort particles by their cell hash. We then assign each
    # particle to have a cell id = hash * cell_capacity + grid_id where grid_id
    # is a flat list that repeats 0, ..., cell_capacity - 1. So long as there
    # are no more than cell_capacity particles per cell, each particle is
    # guaranteed to get a cell id that is unique.
    sort_map = np.argsort(hashes)
    sorted_R = R[sort_map]
    sorted_species = _species[sort_map]
    sorted_hash = hashes[sort_map]
    sorted_id = particle_id[sort_map]

    sorted_cell_id = np.mod(lax.iota(np.int64, N), cell_capacity)
    sorted_cell_id = sorted_hash * cell_capacity + sorted_cell_id

    cell_R = ops.index_update(cell_R, sorted_cell_id, sorted_R)
    sorted_species = np.reshape(sorted_species, (N, 1))
    cell_species = ops.index_update(
        cell_species, sorted_cell_id, sorted_species)
    sorted_id = np.reshape(sorted_id, (N, 1))
    cell_id = ops.index_update(
        cell_id, sorted_cell_id, sorted_id)

    cell_R = _unflatten_cell_buffer(cell_R, cells_per_side, dim)
    cell_species = _unflatten_cell_buffer(cell_species, cells_per_side, dim)
    cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)

    return CellList(N, dim, cell_count, cell_R, cell_species, cell_id)