def body_fn(i, loop_carry):
    """Exchange the entries at positions ``i`` and ``swaps[i]`` of the permutation.

    Args:
        i: Index of the current step; also the first position to exchange.
        loop_carry: Pair ``(swaps, permutation)``. ``swaps[i]`` names the
            position whose entry trades places with the entry at ``i``.

    Returns:
        ``(swaps, permutation)`` with ``swaps`` passed through untouched and
        ``permutation`` updated along axis 0.
    """
    swaps, permutation = loop_carry
    target = swaps[i]

    # Read both entries before writing either one, so the second write still
    # sees the value that originally lived at position ``i``.
    entry_at_i = np.ravel(permutation[i])
    entry_at_target = np.ravel(permutation[target])

    permutation = lax.dynamic_update_index_in_dim(
        permutation, entry_at_target, i, axis=0)
    permutation = lax.dynamic_update_index_in_dim(
        permutation, entry_at_i, target, axis=0)
    return swaps, permutation
def testDynamicUpdateSlice(self): x = np.random.randn(10, 3) y = np.random.randn(10) ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(0, 0, None))(x, y, 1) expected = x.copy() expected[:, 1] = y self.assertAllClose(ans, expected, check_dtypes=False) x = np.random.randn(3) idx = np.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0), in_axes=(None, 0, 0))(x, y, idx) expected = np.broadcast_to(x, (10, 3)).copy() expected[np.arange(10), idx] = y self.assertAllClose(ans, expected, check_dtypes=False)
def push(self, elem: Any) -> Stack:
    """Return a new ``Stack`` with `elem` stored at the current top slot.

    Every leaf of ``self._data`` is paired with the matching leaf of `elem`
    and the latter is written at index ``self._size`` along axis 0. The
    returned stack records a size one larger than this one; ``self`` is not
    assigned to.
    """
    top = self._size
    grown_size = top + 1

    def write_at_top(buffer, value):
        # Place `value` in the slot just past the last occupied index.
        return lax.dynamic_update_index_in_dim(buffer, value, top, 0)

    new_data = jax.tree_util.tree_map(write_at_top, self._data, elem)
    return Stack(grown_size, new_data)
def body_fun(i, vals):
    """Run one step: evaluate the jaxpr on slice ``i`` and record chosen outputs.

    Args:
        i: Step index; used both to slice each entry of ``bs`` and as the
            write position into every state buffer.
        vals: Pair ``(a, state)`` holding the carried value and the list of
            buffers that accumulate per-step outputs.

    Returns:
        ``(a_out, state_out)`` where ``a_out`` is the jaxpr result and
        ``state_out`` is ``state`` with the selected fields of ``a_out``
        written at index ``i`` along axis 0.

    ``bs``, ``jaxpr``, ``consts`` and ``fields`` come from the enclosing scope.
    """
    a, state = vals

    # Pick the i-th element of every input in ``bs`` (dropping the indexed axis).
    b_slices = [
        lax.dynamic_index_in_dim(operand, i, keepdims=False) for operand in bs
    ]
    a_out = core.eval_jaxpr(jaxpr, consts, (), a, core.pack(b_slices))

    # Gather the requested fields of the result, then write each one (with a
    # leading length-1 axis) into slot ``i`` of its matching state buffer.
    chosen = [tuple(a_out)[j] for j in fields]
    state_out = [
        lax.dynamic_update_index_in_dim(buffer, value[None, ...], i, axis=0)
        for value, buffer in zip(chosen, state)
    ]
    return a_out, state_out
def pop_body(i, carry):
    """Pop one value from the message and store it at index ``i`` of ``xs``.

    Args:
        i: Position along axis 0 of the output buffer to write to.
        carry: Pair ``(message, xs)``; ``xs`` is the buffer in which the
            decompressed data is accumulated step by step.

    Returns:
        ``(message, xs)`` with the message as returned by ``codec_pop`` and
        ``xs`` updated at index ``i``.
    """
    message, decoded = carry
    popped = codec_pop(message)
    message, symbol = popped
    # The decompressed data is assembled one slot at a time in this buffer.
    decoded = lax.dynamic_update_index_in_dim(decoded, symbol, i, axis=0)
    return message, decoded
def pop_body(i, carry):
    """Advance the loop by one pop, writing the popped value into slot ``i``.

    ``carry`` is ``(m, xs)``. ``codec_pop`` (from the enclosing scope) is
    applied to ``m`` and yields the next ``m`` plus one value, which is placed
    at index ``i`` along axis 0 of ``xs``. The updated pair is returned.
    """
    state, out = carry
    state, item = codec_pop(state)
    out = lax.dynamic_update_index_in_dim(out, item, i, axis=0)
    return state, out