Ejemplo n.º 1
0
 def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory,
                                      indexer):
     """Run `check_grads` on `jnp.asarray(x)[indexer]` for a sampled array.

     Args:
       shape: shape handed to the sampler when drawing the input.
       dtype: dtype handed to the sampler; also selects the tolerance.
       rng_factory: zero-argument callable returning a sampler that is
         invoked as `sampler(shape, dtype)`.
       indexer: index expression applied to the array under test.
     """
     sampler = rng_factory()

     # Only 32-bit float types get an explicit tolerance; for every other
     # width None is forwarded to check_grads.
     if jnp.finfo(dtype).bits == 32:
         tolerance = 1e-2
     else:
         tolerance = None

     operand = sampler(shape, dtype)

     def index_into(x):
         return jnp.asarray(x)[indexer]

     # NOTE(review): the third argument (2) is presumably the derivative
     # order and the two positional tolerances follow it -- confirm against
     # the check_grads signature this file imports.
     check_grads(index_into, (operand,), 2, tolerance, tolerance, eps=1.)
Ejemplo n.º 2
0
    def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory,
                                             indexer):
        """Run `check_grads` on a jitted `x[indexer]` with the indexer as input.

        The indexer is first split by `self._ReplaceSlicesWithTuples` into an
        unpacked form plus a function that rebuilds the real indexer; the
        unpacked form is passed to the jitted function as an argument and
        re-packed inside it.

        Args:
          shape: shape handed to the sampler when drawing the input.
          dtype: dtype handed to the sampler; also selects the tolerance.
          rng_factory: callable taking `self.rng()` and returning a sampler
            that is invoked as `sampler(shape, dtype)`.
          indexer: index expression applied to the array under test.
        """
        sampler = rng_factory(self.rng())

        # Only 32-bit float types get an explicit tolerance; for every other
        # width None is forwarded to check_grads.
        if jnp.finfo(dtype).bits == 32:
            tolerance = 1e-2
        else:
            tolerance = None

        flat_indexer, rebuild_indexer = self._ReplaceSlicesWithTuples(indexer)

        def index_with(flat, operand):
            return operand[rebuild_indexer(flat)]

        jitted_index = npe.jit(index_with)

        operand = sampler(shape, dtype)
        # Bind the unpacked indexer so check_grads only sees the array input.
        bound_index = partial(jitted_index, flat_indexer)
        # NOTE(review): the third argument (2) is presumably the derivative
        # order; the meaning of the three trailing positionals depends on the
        # check_grads signature this file imports -- confirm.
        check_grads(bound_index, (operand,), 2, tolerance, tolerance, tolerance)
Ejemplo n.º 3
0
 def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer):
     """Run `check_grads` on `x[indexer]**2` for a sampled array.

     Args:
       shape: shape handed to the sampler when drawing the input.
       dtype: dtype handed to the sampler; also selects the tolerance.
       rng_factory: zero-argument callable returning a sampler that is
         invoked as `sampler(shape, dtype)`.
       indexer: index expression applied to the array under test.
     """
     sampler = rng_factory()

     # Only 32-bit float types get an explicit tolerance; for every other
     # width None is forwarded to check_grads.
     if jnp.finfo(dtype).bits == 32:
         tolerance = 1e-2
     else:
         tolerance = None

     operand = sampler(shape, dtype)

     def squared_selection(x):
         return x[indexer] ** 2

     # NOTE(review): the third argument (2) is presumably the derivative
     # order; the meaning of the three trailing positionals depends on the
     # check_grads signature this file imports -- confirm.
     check_grads(squared_selection, (operand,), 2, tolerance, tolerance,
                 tolerance)