def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
    """Check gradients of ``lax.scatter_max`` up to second order.

    Args:
        arg_shape: Shape of the operand array that is scattered into.
        dtype: Dtype used for both the operand and the updates.
        idxs: Template index array. Only its ``shape`` and ``dtype`` are
            used; its values are replaced by freshly sampled integers.
        update_shape: Shape of the updates array.
        dnums: Scatter dimension numbers forwarded to ``lax.scatter_max``.
    """
    # Build both samplers first, then draw indices, operand and updates in
    # this exact order: if ``self.rng()`` hands back shared random state,
    # the draw order determines the sampled values.
    sample_values = jtu.rand_default(self.rng())
    sample_indices = jtu.rand_int(self.rng(), high=max(arg_shape))

    # Indices are sampled with ``high`` set to the largest operand dimension.
    scatter_indices = sample_indices(idxs.shape, idxs.dtype)

    def scatter_max(operand, updates):
        # Indices and dimension numbers are closed over so that only the
        # operand and the updates are differentiated.
        return lax.scatter_max(operand, scatter_indices, updates, dnums)

    operand = sample_values(arg_shape, dtype)
    updates = sample_values(update_shape, dtype)

    check_grads(
        scatter_max,
        (operand, updates),
        2,
        modes=["fwd", "rev"],
        atol=1e-2,
        rtol=1e-2,
    )
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
    """Check gradients of ``lax.scatter_max`` up to second order.

    Args:
        arg_shape: Shape of the operand array that is scattered into.
        dtype: Dtype used for both the operand and the updates.
        idxs: Template index array. Only its ``shape`` and ``dtype`` are
            used; its values are replaced by freshly sampled integers.
        update_shape: Shape of the updates array.
        dnums: Scatter dimension numbers forwarded to ``lax.scatter_max``.
        rng_factory: Callable taking ``self.rng()`` and returning a sampler
            ``(shape, dtype) -> array`` used for the operand and updates.
        rng_idx_factory: Callable taking ``self.rng()`` and returning a
            sampler ``(shape, dtype) -> array`` used for the indices.
    """
    # Build both samplers first, then draw indices, operand and updates in
    # this exact order: if ``self.rng()`` hands back shared random state,
    # the draw order determines the sampled values.
    sample_values = rng_factory(self.rng())
    sample_indices = rng_idx_factory(self.rng())

    scatter_indices = sample_indices(idxs.shape, idxs.dtype)

    def scatter_max(operand, updates):
        # Indices and dimension numbers are closed over so that only the
        # operand and the updates are differentiated.
        return lax.scatter_max(operand, scatter_indices, updates, dnums)

    operand = sample_values(arg_shape, dtype)
    updates = sample_values(update_shape, dtype)

    check_grads(
        scatter_max,
        (operand, updates),
        2,
        modes=["fwd", "rev"],
        atol=1e-2,
        rtol=1e-2,
    )