Beispiel #1
0
 def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
   """Gradient-check ``lax.scatter_max`` w.r.t. its operand and updates.

   ``idxs`` only contributes its shape and dtype: fresh integer indices are
   sampled with ``jtu.rand_int`` (upper bound ``max(arg_shape)``) and held
   fixed. ``check_grads`` then verifies derivatives up to order 2 in both
   forward and reverse mode with ``atol = rtol = 1e-2``.
   """
   # Both samplers are seeded from self.rng(); the draws below keep a fixed
   # order, since reordering them may change the generated test data.
   operand_sampler = jtu.rand_default(self.rng())
   index_sampler = jtu.rand_int(self.rng(), high=max(arg_shape))
   indices = index_sampler(idxs.shape, idxs.dtype)
   operand = operand_sampler(arg_shape, dtype)
   updates = operand_sampler(update_shape, dtype)

   def scatter_fn(op, upd):
     # Indices and dimension numbers are closed over, so gradients are only
     # taken with respect to the operand and the updates.
     return lax.scatter_max(op, indices, upd, dnums)

   check_grads(scatter_fn, (operand, updates), 2,
               modes=["fwd", "rev"], atol=1e-2, rtol=1e-2)
Beispiel #2
0
 def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                    rng_factory, rng_idx_factory):
   """Gradient-check ``lax.scatter_max`` w.r.t. its operand and updates.

   The value and index samplers are built by the injected ``rng_factory``
   and ``rng_idx_factory`` from ``self.rng()``. ``idxs`` only contributes its
   shape and dtype; freshly sampled indices replace its values and are held
   fixed. ``check_grads`` verifies derivatives up to order 2 in forward and
   reverse mode with ``atol = rtol = 1e-2``.
   """
   # Sampler construction and draws keep a fixed order, because both
   # samplers are seeded from self.rng() and reordering may change the data.
   value_sampler = rng_factory(self.rng())
   index_sampler = rng_idx_factory(self.rng())
   indices = index_sampler(idxs.shape, idxs.dtype)
   operand = value_sampler(arg_shape, dtype)
   updates = value_sampler(update_shape, dtype)

   def scatter_fn(op, upd):
     # Only the operand and updates are differentiated; indices and
     # dimension numbers are captured constants.
     return lax.scatter_max(op, indices, upd, dnums)

   check_grads(scatter_fn, (operand, updates), 2,
               modes=["fwd", "rev"], atol=1e-2, rtol=1e-2)