Example #1
0
 def body_fn(i, loop_carry):
   """Exchange entries ``i`` and ``swaps[i]`` of the permutation in the carry.

   Args:
     i: loop index; also the first position to swap.
     loop_carry: ``(swaps, permutation)`` pair. ``swaps[i]`` names the
       position that position ``i`` is exchanged with.

   Returns:
     ``(swaps, permutation)`` with ``swaps`` unchanged and the two
     entries of ``permutation`` exchanged.
   """
   swaps, perm = loop_carry
   other = swaps[i]
   # Read both entries before writing either so the exchange is not
   # clobbered. ravel turns a scalar entry into a length-1 update slice.
   entry_at_i = np.ravel(perm[i])
   entry_at_other = np.ravel(perm[other])
   perm = lax.dynamic_update_index_in_dim(perm, entry_at_other, i, axis=0)
   perm = lax.dynamic_update_index_in_dim(perm, entry_at_i, other, axis=0)
   return swaps, perm
Example #2
0
  def testDynamicUpdateSlice(self):
    x = np.random.randn(10, 3)
    y = np.random.randn(10)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(0, 0, None))(x, y, 1)
    expected = x.copy()
    expected[:, 1] = y
    self.assertAllClose(ans, expected, check_dtypes=False)

    x = np.random.randn(3)
    idx = np.array([0, 1, 2, 1, 0] * 2)
    ans = vmap(lambda x, y, i: lax.dynamic_update_index_in_dim(x, y, i, axis=0),
               in_axes=(None, 0, 0))(x, y, idx)
    expected = np.broadcast_to(x, (10, 3)).copy()
    expected[np.arange(10), idx] = y
    self.assertAllClose(ans, expected, check_dtypes=False)
Example #3
0
 def push(self, elem: Any) -> Stack:
     """Return a new Stack with `elem` written at slot `self._size`.

     Each leaf of `elem` is written into the matching leaf buffer of
     `self._data` at index `self._size` along axis 0. The returned stack
     has size `self._size + 1`; `self` is not modified.

     NOTE(review): there is no capacity check here. Confirm that callers
     never push once `self._size` reaches the buffer length.
     """
     new_size = self._size + 1
     slot = self._size

     def _write_leaf(buffer, value):
         return lax.dynamic_update_index_in_dim(buffer, value, slot, 0)

     new_data = jax.tree_util.tree_map(_write_leaf, self._data, elem)
     return Stack(new_size, new_data)
Example #4
0
 def body_fun(i, vals):
     """One loop step: evaluate `jaxpr` on the i-th slices and record fields.

     `jaxpr`, `consts`, `bs` and `fields` come from the enclosing scope.

     Args:
       i: loop index, used both to slice `bs` and as the row of each `state`
         buffer to write.
       vals: `(a, state)` pair. `a` is passed to `core.eval_jaxpr`, and
         `state` is a sequence of buffers, one per entry of `fields`.

     Returns:
       `(a_out, state_out)`. `a_out` is the result of `core.eval_jaxpr`.
       `state_out` holds the `state` buffers with row `i` set to the
       selected fields of `a_out`.
     """
     a, state = vals
     # Take element i of each array in bs, dropping that axis (axis 0 by
     # default in lax.dynamic_index_in_dim).
     b_slices = [lax.dynamic_index_in_dim(arr, i, keepdims=False) for arr in bs]
     a_out = core.eval_jaxpr(jaxpr, consts, (), a, core.pack(b_slices))
     # Convert a_out to a tuple once, instead of once per selected field.
     a_out_items = tuple(a_out)
     selected = [a_out_items[j] for j in fields]
     # [None, ...] adds a leading length-1 axis so each field fills row i
     # of its state buffer.
     state_out = [
         lax.dynamic_update_index_in_dim(buf, field[None, ...], i, axis=0)
         for field, buf in zip(selected, state)
     ]
     return a_out, state_out
Example #5
0
 def pop_body(i, carry):
     """Decode one value with `codec_pop` and store it at index `i` of `xs`.

     The decompressed data accumulates in `xs` across loop iterations.

     Args:
       i: index along axis 0 of `xs` where the decoded value is written.
       carry: `(message, xs)` pair threaded through the loop.

     Returns:
       `(message, xs)`, with `message` as returned by `codec_pop` and `xs`
       holding the new value at index `i`.
     """
     message, xs = carry
     message, decoded = codec_pop(message)
     xs = lax.dynamic_update_index_in_dim(xs, decoded, i, axis=0)
     return message, xs
Example #6
0
 def pop_body(i, carry):
     """Pop one value via `codec_pop` and write it into slot `i` of the buffer.

     Args:
       i: index along axis 0 of the buffer to write.
       carry: `(m, xs)` pair threaded through the loop.

     Returns:
       The updated `(m, xs)` pair.
     """
     state, buf = carry
     state, value = codec_pop(state)
     buf = lax.dynamic_update_index_in_dim(buf, value, i, 0)
     return state, buf