def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory, indexer):
    """Check gradients of indexing an array with an advanced integer indexer.

    A random array of the requested ``shape`` and ``dtype`` is drawn from the
    generator built by ``rng_factory``; the function
    ``x -> jnp.asarray(x)[indexer]`` is then handed to ``check_grads``.
    """
    sampler = rng_factory()

    # Dtypes whose finfo reports 32 bits get an explicit 1e-2 tolerance; for
    # any other width None is forwarded to check_grads unchanged.
    if jnp.finfo(dtype).bits == 32:
        tolerance = 1e-2
    else:
        tolerance = None

    operand = sampler(shape, dtype)

    def gather(x):
        # Convert first so the indexing goes through the jnp array type.
        return jnp.asarray(x)[indexer]

    # NOTE(review): the positional 2 is presumably the derivative order and
    # the two tolerances atol/rtol -- confirm against check_grads' signature.
    check_grads(gather, (operand,), 2, tolerance, tolerance, eps=1.)
def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory, indexer):
    """Check gradients of indexing inside a jitted function.

    The indexer is split by ``self._ReplaceSlicesWithTuples`` into an unpacked
    form plus a function that packs it back. The jitted function receives the
    unpacked form as an argument, re-packs it, and indexes the array with the
    result. ``check_grads`` is run with respect to the array argument only.
    """
    sampler = rng_factory(self.rng())

    # Dtypes whose finfo reports 32 bits get an explicit 1e-2 tolerance; for
    # any other width None is forwarded to check_grads unchanged.
    if jnp.finfo(dtype).bits == 32:
        tolerance = 1e-2
    else:
        tolerance = None

    unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer)

    @npe.jit
    def fun(raw_index, x):
        # Rebuild the full indexer from its unpacked form before indexing.
        return x[pack_indexer(raw_index)]

    operand = sampler(shape, dtype)

    # Bind the unpacked indexer so only the array is differentiated.
    # NOTE(review): the positional 2 is presumably the derivative order and
    # the three tolerances atol/rtol/eps -- confirm against check_grads.
    bound_fun = partial(fun, unpacked_indexer)
    check_grads(bound_fun, (operand,), 2, tolerance, tolerance, tolerance)
def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer):
    """Check gradients of the square of a statically indexed array.

    A random array of the requested ``shape`` and ``dtype`` is drawn from the
    generator built by ``rng_factory``; the function ``x -> x[indexer] ** 2``
    is then handed to ``check_grads``.
    """
    sampler = rng_factory()

    # Dtypes whose finfo reports 32 bits get an explicit 1e-2 tolerance; for
    # any other width None is forwarded to check_grads unchanged.
    if jnp.finfo(dtype).bits == 32:
        tolerance = 1e-2
    else:
        tolerance = None

    operand = sampler(shape, dtype)

    def squared_selection(x):
        # Squaring makes the function nonlinear in the selected entries.
        return x[indexer]**2

    # NOTE(review): the positional 2 is presumably the derivative order and
    # the three tolerances atol/rtol/eps -- confirm against check_grads.
    check_grads(squared_selection, (operand,), 2, tolerance, tolerance, tolerance)