Example #1
0
 def testLaxFftAcceptsStringTypes(self):
     """lax.fft should accept the string "FFT" as its transform type.

     The output is compared with NumPy's np.fft.fft result cast to complex64.
     """
     sample = jtu.rand_default(self.rng())((10, ), np.complex64)
     expected = np.fft.fft(sample).astype(np.complex64)
     actual = lax.fft(sample, "FFT", fft_lengths=(10, ))
     self.assertAllClose(expected, actual)
Example #2
0
 def test_dot(self):
   """Shape-check lax.dot for a matrix-matrix and a matrix-vector product."""
   # matrix x matrix: (m, k) . (k, n) -> (m, n)
   self.check(lax.dot,
              ['(m, k)', '(k, n)'],
              '(m, n)',
              dict(m=2, k=3, n=4),
              [(4, 5), (5, 7)],
              ['float_', 'float_'],
              jtu.rand_default(self.rng()))
   # matrix x vector: (m, n) . n -> m
   self.check(lax.dot,
              ['(m, n)', 'n'],
              'm',
              dict(m=2, n=3),
              [(4, 5), (5,)],
              ['float_', 'float_'],
              jtu.rand_default(self.rng()))
Example #3
0
 def test_split(self):
     """Shape-check jnp.split into two halves and at an explicit index."""
     def split_in_two(x):
         return jnp.split(x, 2)

     def split_at_ten(x):
         return jnp.split(x, [10])

     # Even split: a length-2*n input yields two length-n pieces.
     self.check(split_in_two, ['2*n'], ['n', 'n'], dict(n=4), [(8, )],
                ['float_'], jtu.rand_default(self.rng()))
     # Split at index 10: pieces of length 10 and n - 10.
     self.check(split_at_ten, ['n'], ['10', 'n+-10'], dict(n=12), [(12, )],
                ['float_'], jtu.rand_default(self.rng()))
Example #4
0
 def testSparseAttrAccess(self, attr):
   """Reading `attr` off a sparse array is checked via self._CompileAndCheck."""
   rng = jtu.rand_default(self.rng())

   def args_maker():
     return [make_sparse_array(rng, (10,), jnp.float32)]

   def read_attr(arr):
     return getattr(arr, attr)

   self._CompileAndCheck(read_attr, args_maker)
Example #5
0
 def test_transpose(self):
     """Shape-check lax.transpose swapping the first two of three axes."""
     def swap_leading_axes(x):
         return lax.transpose(x, (1, 0, 2))

     self.check(swap_leading_axes, ['(a, b, c)'], 'b, a, c',
                dict(a=2, b=3, c=4), [(3, 4, 5)], ['float_'],
                jtu.rand_default(self.rng()))
Example #6
0
 def test_expit(self):
     """Shape-check expit: a length-n input maps to a length-n output."""
     rng = jtu.rand_default(self.rng())
     self.check(expit, ['n'], 'n', dict(n=3), [(4, )], ['float_'], rng)
Example #7
0
 def testReducePrecision(self, shape, dtype, nmant, nexp, bdims):
     """Batching check for lax.reduce_precision with nexp/nmant bit widths."""
     def reduce_bits(x):
         return lax.reduce_precision(x, exponent_bits=nexp,
                                     mantissa_bits=nmant)

     self._CheckBatching(reduce_bits, 10, bdims, (shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #8
0
    def test_numpy_pad(self):
        """Shape-check jnp.pad appending one constant (5.) after the data."""
        rng = jtu.rand_default(self.rng())

        def numpy_pad(x):
            # (0, 1): pad nothing before, one element after.
            return jnp.pad(x, (0, 1), constant_values=5.)

        self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3, )],
                   ['float_'], rng)
Example #9
0
 def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
   """Batching check for lax.scatter_add with dimension numbers `dnums`."""
   scatter = partial(lax.scatter_add, dimension_numbers=dnums)
   shapes = [arg_shape, idxs.shape, update_shape]
   arg_dtypes = [dtype, idxs.dtype, dtype]
   # Loosened relative tolerances for the low-precision float types.
   tolerances = {np.float16: 5e-3, dtypes.bfloat16: 3e-2}
   self._CheckBatching(scatter, 5, bdims, shapes, arg_dtypes,
                       jtu.rand_default(self.rng()), rtol=tolerances)
Example #10
0
 def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
     """Batching check for lax.convert_element_type: from_dtype -> to_dtype."""
     def convert(x):
         return lax.convert_element_type(x, to_dtype)

     self._CheckBatching(convert, 10, bdims, (shape, ), (from_dtype, ),
                         jtu.rand_default(self.rng()))
Example #11
0
 def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
   """Batching check for lax.gather, run once with 0 and once with 5.

   The 0/5 value is _CheckBatching's second argument (presumably the batch
   size; confirm against its definition).
   """
   gather = partial(lax.gather, dimension_numbers=dnums,
                    slice_sizes=slice_sizes)
   for size in (0, 5):
     self._CheckBatching(gather, size, bdims, [shape, idxs.shape],
                         [dtype, idxs.dtype], jtu.rand_default(self.rng()))
Example #12
0
 def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
   """Batching check for lax.select driven by a thresholded predicate.

   The predicate operand `c` is drawn as float32 and thresholded with
   `c < 0`. If the random draws include negative values (assumed for
   rand_default; confirm), the mask picks from both `x` and `y`.
   """
   rng = jtu.rand_default(self.rng())

   def op(c, x, y):
     return lax.select(c < 0, x, y)

   # Do not draw `c` as np.bool_: neither False (0) nor True (1) is < 0,
   # so `c < 0` would be identically False and only `y` would be tested.
   self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
                       (np.float32, arg_dtype, arg_dtype), rng)
Example #13
0
 def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
   """Batching check for lax.dot requested at lax.Precision.HIGHEST."""
   dot_highest = partial(lax.dot, precision=lax.Precision.HIGHEST)
   # Per-dtype relative tolerances: looser for float16, tight for float64.
   tolerances = {np.float16: 5e-2, np.float64: 5e-14}
   self._CheckBatching(dot_highest, 5, bdims, (lhs_shape, rhs_shape),
                       (dtype, dtype), jtu.rand_default(self.rng()),
                       rtol=tolerances)
Example #14
0
 def testArgminmax(self, op, shape, dtype, dim, bdims):
     """Batching check for `op(operand, dim, np.int32)`.

     `op` is presumably lax.argmin or lax.argmax (from the test name).
     """
     def reduce_along_dim(operand):
         return op(operand, dim, np.int32)

     self._CheckBatching(reduce_along_dim, 5, bdims, (shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #15
0
 def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
     """Batching check for lax.clamp; min, operand and max all use `dtype`."""
     self._CheckBatching(lax.clamp, 10, bdims,
                         [min_shape, operand_shape, max_shape],
                         [dtype, dtype, dtype],
                         jtu.rand_default(self.rng()))
Example #16
0
 def test_mean(self):
     """Shape-check a mean: sum of x divided by its length from shape_as_value."""
     def mean(x):
         total = jnp.sum(x)
         return total / shape_as_value(x.shape)[0]

     self.check(mean, ['n'], '', {'n': 3}, [(4, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #17
0
 def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
     """Batching check for lax.broadcast with leading `broadcast_sizes`."""
     def broadcast(x):
         return lax.broadcast(x, broadcast_sizes)

     self._CheckBatching(broadcast, 5, bdims, (shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #18
0
 def test_indexing(self):
     """Shape-check scalar indexing with the first (0) and last (-1) index."""
     def take(index):
         return lambda x: x[index]

     for index in (0, -1):
         self.check(take(index), ['n'], '', {'n': 2}, [(3, )], ['float_'],
                    jtu.rand_default(self.rng()))
Example #19
0
 def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
     """Batching check for lax.broadcast_in_dim to `outshape`."""
     def broadcast_to_out(x):
         return lax.broadcast_in_dim(x, outshape, dimensions)

     self._CheckBatching(broadcast_to_out, 5, bdims, (inshape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #20
0
 def test_sum_2d(self):
     """Shape-check jnp.sum on an (m, n) input with a scalar ('') output."""
     rng = jtu.rand_default(self.rng())
     env = dict(m=2, n=3)
     self.check(jnp.sum, ['(m, n)'], '', env, [(3, 4)], ['float_'], rng)
Example #21
0
 def testSqueeze(self, arg_shape, dimensions, bdims):
     """Batching check for lax.squeeze of `dimensions` on a float32 input."""
     def squeeze(x):
         return lax.squeeze(x, dimensions)

     self._CheckBatching(squeeze, 10, bdims, (arg_shape, ), (np.float32, ),
                         jtu.rand_default(self.rng()))
Example #22
0
 def test_where(self):
     """Shape-check jnp.where that keeps negative entries and zeroes the rest."""
     def keep_negative(x):
         return jnp.where(x < 0, x, 0. * x)

     self.check(keep_negative, ['n'], 'n', {'n': 2}, [(3, )], ['float_'],
                jtu.rand_default(self.rng()))
Example #23
0
 def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
     """Batching check for lax.reshape to `out_shape` with `dimensions`."""
     def reshape(x):
         return lax.reshape(x, out_shape, dimensions=dimensions)

     self._CheckBatching(reshape, 10, bdims, (arg_shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #24
0
 def test_reduce(self, operator):
     """Shape-check `operator` on an (m+1, n+1) input with a scalar ('') output."""
     env = {'m': 3, 'n': 4}
     self.check(operator, ['(m+1, n+1)'], '', env, [(4, 5)], ['float_'],
                jtu.rand_default(self.rng()))
Example #25
0
 def testSlice(self, shape, dtype, starts, limits, strides, bdims):
     """Batching check for lax.slice with the given starts/limits/strides."""
     def take_slice(x):
         return lax.slice(x, starts, limits, strides)

     self._CheckBatching(take_slice, 5, bdims, (shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #26
0
 def testSparseMatvec(self, shape, dtype):
   """`matvec` of a sparse array and a dense vector, via _CompileAndCheck.

   The dense vector's shape is `shape[-1:]`, matching the sparse operand's
   last axis.
   """
   rng = jtu.rand_default(self.rng())

   def args_maker():
     matrix = make_sparse_array(rng, shape, dtype)
     vector = rng(shape[-1:], dtype)
     return [matrix, vector]

   self._CompileAndCheck(matvec, args_maker)
Example #27
0
 def testTranspose(self, shape, dtype, perm, bdims):
     """Batching check for lax.transpose with permutation `perm`."""
     def permute(x):
         return lax.transpose(x, perm)

     self._CheckBatching(permute, 5, bdims, (shape, ), (dtype, ),
                         jtu.rand_default(self.rng()))
Example #28
0
 def testFftshift(self, shape, dtype, axes):
     """jnp.fft.fftshift over `axes` should agree with np.fft.fftshift."""
     rng = jtu.rand_default(self.rng())

     def args_maker():
         return (rng(shape, dtype), )

     def np_fn(arg):
         return np.fft.fftshift(arg, axes=axes)

     def jnp_fn(arg):
         return jnp.fft.fftshift(arg, axes=axes)

     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
Example #29
0
 def test_concatenate(self):
   """Shape-check lax.concatenate of three vectors along axis 0."""
   def concat_three(x, y, z):
     return lax.concatenate([x, y, z], 0)

   self.check(concat_three, ['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
              [(4,), (3,), (4,)], ['float_'] * 3,
              jtu.rand_default(self.rng()))