Example #1
File: nn_test.py Project: gtr8/jax
class NNInitializersTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_{}_{}".format(
           rec.name,
           jtu.format_shape_dtype_string(shape, dtype)),
       "initializer": rec.initializer(),
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializer(self, initializer, shape, dtype):
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_{}_{}".format(
           rec.name,
           jtu.format_shape_dtype_string(shape, dtype)),
       "initializer_provider": rec.initializer,
       "shape": shape, "dtype": dtype}
      for rec in INITIALIZER_RECS
      for shape in rec.shapes
      for dtype in rec.dtypes))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

  def testVarianceScalingMultiAxis(self):
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
      scale=1.0, mode='fan_avg', distribution='truncated_normal',
      in_axis=(0, 1), out_axis=(-2, -1))
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))

  def testVarianceScalingBatchAxis(self):
    rng = random.PRNGKey(0)
    shape = (2, 3, 4, 5)
    initializer = nn.initializers.variance_scaling(
      scale=1.0, mode='fan_avg', distribution='truncated_normal',
      in_axis=0, out_axis=(2, 3), batch_axis=1)
    val = initializer(rng, shape)

    self.assertEqual(shape, jnp.shape(val))
Example #2
class LaxBackedScipyFftTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.fft implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(
                testcase_name=
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}_n={n}_axis={axis}_norm={norm}",
                shape=shape,
                dtype=dtype,
                n=n,
                axis=axis,
                norm=norm) for dtype in real_dtypes
            for shape in [(10, ), (2, 5)] for n in [None, 1, 7, 13, 20]
            for axis in [-1, 0] for norm in [None, 'ortho']))
    @jtu.skip_on_devices("rocm")
    def testDct(self, shape, dtype, n, axis, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_fn = lambda a: jsp_fft.dct(a, n=n, axis=axis, norm=norm)
        np_fn = lambda a: osp_fft.dct(a, n=n, axis=axis, norm=norm)
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(
                testcase_name=
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}_axes={axes}_s={s}_norm={norm}",
                shape=shape,
                dtype=dtype,
                s=s,
                axes=axes,
                norm=norm) for dtype in real_dtypes
            for shape in [(10, ), (10, 10), (9, ), (2, 3, 4), (2, 3, 4, 5)]
            for axes in _get_dctn_test_axes(shape)
            for s in _get_dctn_test_s(shape, axes)
            for norm in [None, 'ortho']))
    @jtu.skip_on_devices("rocm")
    def testDctn(self, shape, dtype, s, axes, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_fn = lambda a: jsp_fft.dctn(a, s=s, axes=axes, norm=norm)
        np_fn = lambda a: osp_fft.dctn(a, shape=s, axes=axes, norm=norm)
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4)
Example #3
class NNInitializersTest(jtu.JaxTestCase):
    def setUp(self):
        super().setUp()
        config.update("jax_numpy_rank_promotion", "raise")

    def tearDown(self):
        super().tearDown()
        config.update("jax_numpy_rank_promotion", "allow")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(rec.name,
                            jtu.format_shape_dtype_string(shape, dtype)),
            "initializer":
            rec.initializer(),
            "shape":
            shape,
            "dtype":
            dtype
        } for rec in INITIALIZER_RECS for shape in rec.shapes
                            for dtype in rec.dtypes))
    def testInitializer(self, initializer, shape, dtype):
        rng = random.PRNGKey(0)
        val = initializer(rng, shape, dtype)

        self.assertEqual(shape, jnp.shape(val))
        self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_{}".format(rec.name,
                            jtu.format_shape_dtype_string(shape, dtype)),
            "initializer_provider":
            rec.initializer,
            "shape":
            shape,
            "dtype":
            dtype
        } for rec in INITIALIZER_RECS for shape in rec.shapes
                            for dtype in rec.dtypes))
    def testInitializerProvider(self, initializer_provider, shape, dtype):
        rng = random.PRNGKey(0)
        initializer = initializer_provider(dtype=dtype)
        val = initializer(rng, shape)

        self.assertEqual(shape, jnp.shape(val))
        self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))
Example #4
def genNamedParametersNArgs(n):
  return parameterized.named_parameters(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
          "shapes": shapes, "dtypes": dtypes}
        for shapes in itertools.combinations_with_replacement(all_shapes, n)
        for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
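A minimal sketch (not from the original project) of how a factory like genNamedParametersNArgs is typically applied; the test class and the function under test (jnp.hypot) are assumptions used purely for illustration, and all_shapes is the same module-level list the helper itself references:

class IllustrativeBinaryOpTest(jtu.JaxTestCase):  # hypothetical class, for illustration only

  @genNamedParametersNArgs(2)
  def testHypot(self, shapes, dtypes):
    # Each generated case supplies one shape and one dtype per argument.
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
    self._CheckAgainstNumpy(np.hypot, jnp.hypot, args_maker, check_dtypes=False)
    self._CompileAndCheck(jnp.hypot, args_maker)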
Example #5
class CustomErrorsTest(jtu.JaxTestCase):
  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_{errorclass}", "errorclass": errorclass}
     for errorclass in dir(jax.errors)
     if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError']))
  def testErrorsURL(self, errorclass):
    class FakeTracer(core.Tracer):
      aval = None
    ErrorClass = getattr(jax.errors, errorclass)
    err = ErrorClass(FakeTracer(None))

    self.assertIn(f'https://jax.readthedocs.io/en/latest/errors.html#jax.errors.{errorclass}', str(err))
Example #6
class LaxBackedScipyInterpolateTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.interpolate implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_spaces={spaces}_method={method}",
            "spaces": spaces,
            "method": method
        } for spaces in (((0., 10., 10), ), ((-15., 20., 12), (3., 4., 24)))
                            for method in ("linear", "nearest")))
    def testRegularGridInterpolator(self, spaces, method):
        rng = jtu.rand_default(self.rng())
        scipy_fun = lambda init_args, call_args: sp_interp.RegularGridInterpolator(
            *init_args[:2], method, False, *init_args[2:])(*call_args)
        lax_fun = lambda init_args, call_args: jsp_interp.RegularGridInterpolator(
            *init_args[:2], method, False, *init_args[2:])(*call_args)

        def args_maker():
            points = tuple(map(lambda x: np.linspace(*x), spaces))
            values = rng(reduce(operator.add, tuple(map(np.shape, points))),
                         float)
            fill_value = np.nan

            init_args = (points, values, fill_value)
            n_validation_points = 50
            valid_points = tuple(
                map(
                    lambda x: np.linspace(x[0] - 0.2 *
                                          (x[1] - x[0]), x[1] + 0.2 *
                                          (x[1] - x[0]), n_validation_points),
                    spaces))
            valid_points = np.squeeze(np.stack(valid_points, axis=1))
            call_args = (valid_points, )
            return init_args, call_args

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
Example #7
class CudaArrayInterfaceTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if jtu.device_under_test() != "gpu":
      self.skipTest("__cuda_array_interface__ is only supported on GPU")

  @parameterized.named_parameters(jtu.cases_from_list(
     {"testcase_name": "_{}".format(
        jtu.format_shape_dtype_string(shape, dtype)),
     "shape": shape, "dtype": dtype}
     for shape in all_shapes
     for dtype in dlpack_dtypes))
  @unittest.skipIf(not cupy, "Test requires CuPy")
  def testJaxToCuPy(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    x = rng(shape, dtype)
    y = jnp.array(x)
    z = cupy.asarray(y)
    self.assertEqual(y.__cuda_array_interface__["data"][0],
                     z.__cuda_array_interface__["data"][0])
    self.assertAllClose(x, cupy.asnumpy(z))
Example #8
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims,
                return_sign, use_b),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "shapes":
            shapes,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "return_sign":
            return_sign,
            "use_b":
            use_b
        } for shape_group in compatible_shapes for dtype in float_dtypes +
                            complex_dtypes + int_dtypes
                            for use_b in [False, True]
                            for shapes in itertools.product(
                                *((shape_group,
                                   shape_group) if use_b else (shape_group, )))
                            for axis in range(
                                -max(len(shape) for shape in shapes),
                                max(len(shape) for shape in shapes))
                            for keepdims in [False, True]
                            for return_sign in [False, True]))
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in .*")
    @jax.numpy_rank_promotion(
        'allow')  # This test explicitly exercises implicit rank promotion.
    def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
        if jtu.device_under_test() != "cpu":
            rng = jtu.rand_some_inf_and_nan(self.rng())
        else:
            rng = jtu.rand_default(self.rng())
        # TODO(mattjj): test autodiff
        if use_b:

            def scipy_fun(array_to_reduce, scale_array):
                res = osp_special.logsumexp(array_to_reduce,
                                            axis,
                                            keepdims=keepdims,
                                            return_sign=return_sign,
                                            b=scale_array)
                if dtype == np.int32:
                    res = tree_map(lambda x: x.astype('float32'), res)
                return res

            def lax_fun(array_to_reduce, scale_array):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
        else:

            def scipy_fun(array_to_reduce):
                res = osp_special.logsumexp(array_to_reduce,
                                            axis,
                                            keepdims=keepdims,
                                            return_sign=return_sign)
                if dtype == np.int32:
                    res = tree_map(lambda x: x.astype('float32'), res)
                return res

            def lax_fun(array_to_reduce):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            args_maker = lambda: [rng(shapes[0], dtype)]
        tol = {np.float32: 1E-6, np.float64: 1E-14}
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    def testLogSumExpZeros(self):
        # Regression test for https://github.com/google/jax/issues/5370
        scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
        lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
        args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker)

    def testLogSumExpOnes(self):
        # Regression test for https://github.com/google/jax/issues/7390
        args_maker = lambda: [np.ones(4, dtype='float32')]
        with jax.debug_infs(True):
            self._CheckAgainstNumpy(osp_special.logsumexp,
                                    lsp_special.logsumexp, args_maker)
            self._CompileAndCheck(lsp_special.logsumexp, args_maker)

    def testLogSumExpNans(self):
        # Regression test for https://github.com/google/jax/issues/7634
        with jax.debug_nans(True):
            with jax.disable_jit():
                result = lsp_special.logsumexp(1.0)
                self.assertEqual(result, 1.0)

                result = lsp_special.logsumexp(1.0, b=1.0)
                self.assertEqual(result, 1.0)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "nondiff_argnums":
                    rec.nondiff_argnums,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in itertools.combinations_with_replacement(
                    all_shapes, rec.nargs)
                for dtypes in (itertools.combinations_with_replacement(
                    rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                               itertools.product(*rec.dtypes)))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    @jax.numpy_rank_promotion(
        'allow')  # This test explicitly exercises implicit rank promotion.
    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    @jax.numpy_rank_promotion('raise')
    def testMultigammaln(self, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = jtu.rand_positive(self.rng())
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                tol={
                                    np.float32: 1e-3,
                                    np.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={
                                  np.float32: 3e-07,
                                  np.float64: 4e-15
                              })

    def testIssue980(self):
        x = np.full((4, ), -1e20, dtype=np.float32)
        self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                            lsp_special.expit(x))

    @jax.numpy_rank_promotion('raise')
    def testIssue3758(self):
        x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
        q = np.array([1., 40., 30.], dtype=np.float32)
        self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                            lsp_special.zeta(x, q))

    def testXlogyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(jax.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(jax.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_lmax={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                 l_max),
            "l_max":
            l_max,
            "shape":
            shape,
            "dtype":
            dtype
        } for l_max in [1, 2, 3, 6] for shape in [(5, ), (10, )]
                            for dtype in float_dtypes))
    def testLpmn(self, l_max, shape, dtype):
        rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
        args_maker = lambda: [rng(shape, dtype)]

        lax_fun = partial(lsp_special.lpmn, l_max, l_max)

        def scipy_fun(z, m=l_max, n=l_max):
            # scipy only supports scalar inputs for z, so we must loop here.
            vals, derivs = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
            return np.dstack(vals), np.dstack(derivs)

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                rtol=1e-5,
                                atol=3e-3,
                                check_dtypes=False)
        self._CompileAndCheck(lax_fun, args_maker, rtol=1E-5, atol=3e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}_lmax={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                 l_max),
            "l_max":
            l_max,
            "shape":
            shape,
            "dtype":
            dtype
        } for l_max in [3, 4, 6, 32] for shape in [(2, ), (3, ), (4, ), (64, )]
                            for dtype in float_dtypes))
    def testNormalizedLpmnValues(self, l_max, shape, dtype):
        rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
        args_maker = lambda: [rng(shape, dtype)]

        # Note: we test only the normalized values, not the derivatives.
        lax_fun = partial(lsp_special.lpmn_values,
                          l_max,
                          l_max,
                          is_normalized=True)

        def scipy_fun(z, m=l_max, n=l_max):
            # scipy only supports scalar inputs for z, so we must loop here.
            vals, _ = zip(*(osp_special.lpmn(m, n, zi) for zi in z))
            a = np.dstack(vals)

            # apply the normalization
            num_m, num_l, _ = a.shape
            a_normalized = np.zeros_like(a)
            for m in range(num_m):
                for l in range(num_l):
                    c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
                    c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
                    c2 = np.sqrt(c0 / c1)
                    a_normalized[m, l] = c2 * a[m, l]
            return a_normalized

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                rtol=1e-5,
                                atol=1e-5,
                                check_dtypes=False)
        self._CompileAndCheck(lax_fun, args_maker, rtol=1E-6, atol=1E-6)

    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testSphHarmAccuracy(self):
        m = jnp.arange(-3, 3)[:, None]
        n = jnp.arange(3, 6)
        n_max = 5
        theta = 0.0
        phi = jnp.pi

        expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

        actual = osp_special.sph_harm(m, n, theta, phi)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testSphHarmOrderZeroDegreeZero(self):
        """Tests the spherical harmonics of order zero and degree zero."""
        theta = jnp.array([0.3])
        phi = jnp.array([2.3])
        n_max = 0

        expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
        actual = jnp.real(
            lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi,
                                 n_max))

        self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

    @jtu.skip_on_devices("rocm")  # rtol and atol needs to be adjusted for ROCm
    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testSphHarmOrderZeroDegreeOne(self):
        """Tests the spherical harmonics of order one and degree zero."""
        theta = jnp.array([2.0])
        phi = jnp.array([3.1])
        n_max = 1

        expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
        actual = jnp.real(
            lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi,
                                 n_max))

        self.assertAllClose(actual, expected, rtol=2e-7, atol=6e-8)

    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testSphHarmOrderOneDegreeOne(self):
        """Tests the spherical harmonics of order one and degree one."""
        theta = jnp.array([2.0])
        phi = jnp.array([2.5])
        n_max = 1

        expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) * jnp.sin(phi) *
                    jnp.exp(1j * theta))
        actual = lsp_special.sph_harm(jnp.array([1]), jnp.array([1]), theta,
                                      phi, n_max)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            f'_maxdegree={l_max}_inputsize={num_z}_dtype={dtype.__name__}',
            'l_max': l_max,
            'num_z': num_z,
            'dtype': dtype
        } for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
                            for dtype in jtu.dtypes.all_integer))
    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
        """Tests against JIT compatibility and Numpy."""
        n_max = l_max
        shape = (num_z, )
        rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

        lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

        def args_maker():
            m = rng(shape, dtype)
            n = abs(m)
            theta = jnp.linspace(-4.0, 5.0, num_z)
            phi = jnp.linspace(-2.0, 1.0, num_z)
            return m, n, theta, phi

        with self.subTest('Test JIT compatibility'):
            self._CompileAndCheck(lsp_special_fn, args_maker)

        with self.subTest('Test against numpy.'):
            self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn,
                                    args_maker)

    @jax.numpy_dtype_promotion(
        'standard')  # This test explicitly exercises dtype promotion
    def testSphHarmCornerCaseWithWrongNmax(self):
        """Tests the corner case where `n_max` is not the maximum value of `n`."""
        m = jnp.array([2])
        n = jnp.array([10])
        n_clipped = jnp.array([6])
        n_max = 6
        theta = jnp.array([0.9])
        phi = jnp.array([0.2])

        expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

        actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name':
                '_shape={}'
                '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
                '_max_sv={}_method={}_side={}'
                '_nonzero_condition_number={}_seed={}'.format(
                    jtu.format_shape_dtype_string(
                        shape,
                        jnp.dtype(dtype).name).replace(" ", ""), n_zero_sv,
                    degeneracy, geometric_spectrum, max_sv, method, side,
                    nonzero_condition_number, seed),
                'n_zero_sv':
                n_zero_sv,
                'degeneracy':
                degeneracy,
                'geometric_spectrum':
                geometric_spectrum,
                'max_sv':
                max_sv,
                'shape':
                shape,
                'method':
                method,
                'side':
                side,
                'nonzero_condition_number':
                nonzero_condition_number,
                'dtype':
                dtype,
                'seed':
                seed
            } for n_zero_sv in n_zero_svs for degeneracy in degeneracies
            for geometric_spectrum in geometric_spectra for max_sv in max_svs
            for shape in polar_shapes for method in methods for side in sides
            for nonzero_condition_number in nonzero_condition_numbers
            for dtype in jtu.dtypes.inexact for seed in seeds))
    @jtu.skip_on_devices("gpu")  # Fails on A100.
    def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                  shape, method, side, nonzero_condition_number, dtype, seed):
        """ Tests jax.scipy.linalg.polar."""
        if jtu.device_under_test() != "cpu":
            if jnp.dtype(dtype).name in ("bfloat16", "float16"):
                raise unittest.SkipTest("Skip half precision off CPU.")

        m, n = shape
        if (method == "qdwh" and ((side == "left" and m >= n) or
                                  (side == "right" and m < n))):
            raise unittest.SkipTest("method=qdwh does not support these sizes")

        matrix, _ = _initialize_polar_test(self.rng(), shape, n_zero_sv,
                                           degeneracy, geometric_spectrum,
                                           max_sv, nonzero_condition_number,
                                           dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError,
                              jsp.linalg.polar,
                              matrix,
                              method=method,
                              side=side)
            return

        unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
        if shape[0] >= shape[1]:
            should_be_eye = np.matmul(unitary.conj().T, unitary)
        else:
            should_be_eye = np.matmul(unitary, unitary.conj().T)
        tol = 500 * float(jnp.finfo(matrix.dtype).eps)
        eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
        with self.subTest('Test unitarity.'):
            self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

        with self.subTest('Test Hermiticity.'):
            self.assertAllClose(posdef,
                                posdef.conj().T,
                                atol=tol * jnp.linalg.norm(posdef))

        ev, _ = np.linalg.eigh(posdef)
        ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
        negative_ev = jnp.sum(ev < 0.)
        with self.subTest('Test positive definiteness.'):
            self.assertEqual(negative_ev, 0)

        if side == "right":
            recon = jnp.matmul(unitary,
                               posdef,
                               precision=lax.Precision.HIGHEST)
        elif side == "left":
            recon = jnp.matmul(posdef,
                               unitary,
                               precision=lax.Precision.HIGHEST)
        with self.subTest('Test reconstruction.'):
            self.assertAllClose(matrix,
                                recon,
                                atol=tol * jnp.linalg.norm(matrix))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_linear_size={}_dtype={}_termination_size={}'.format(
                linear_size,
                jnp.dtype(dtype).name, termination_size),
            'linear_size':
            linear_size,
            'dtype':
            dtype,
            'termination_size':
            termination_size
        } for linear_size in linear_sizes for dtype in jtu.dtypes.floating +
                            jtu.dtypes.complex
                            for termination_size in [1, 19]))
    def test_spectral_dac_eigh(self, linear_size, dtype, termination_size):
        if jtu.device_under_test() != "tpu" and termination_size != 1:
            raise unittest.SkipTest(
                "Termination sizes greater than 1 only work on TPU")

        rng = self.rng()
        H = rng.randn(linear_size, linear_size)
        H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.lax.eigh.eigh, H)
            return
        evs, V = jax._src.lax.eigh.eigh(H, termination_size=termination_size)
        ev_exp, eV_exp = jnp.linalg.eigh(H)
        HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
        vV = evs.astype(V.dtype)[None, :] * V
        eps = jnp.finfo(H.dtype).eps
        atol = jnp.linalg.norm(H) * eps
        self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
        self.assertAllClose(
            HV,
            vV,
            atol=atol *
            (80 if jnp.issubdtype(dtype, jnp.complexfloating) else 30))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_{jtu.format_shape_dtype_string((n_obs, n_codes, *n_feats), dtype)}",
            "n_obs": n_obs,
            "n_codes": n_codes,
            "n_feats": n_feats,
            "dtype": dtype
        } for n_obs in [1, 3, 5] for n_codes in [1, 2, 4]
                            for n_feats in [()] + [(i, ) for i in range(1, 3)]
                            for dtype in float_dtypes + int_dtypes)
    )  # scipy doesn't support complex
    def test_vq(self, n_obs, n_codes, n_feats, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [
            rng((n_obs, *n_feats), dtype),
            rng((n_codes, *n_feats), dtype)
        ]
        self._CheckAgainstNumpy(osp_cluster.vq.vq,
                                lsp_cluster.vq.vq,
                                args_maker,
                                check_dtypes=False)
        self._CompileAndCheck(lsp_cluster.vq.vq, args_maker)
Example #9
class SvdTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint:disable=g-complex-comprehension
                'testcase_name': '_m={}_by_n={}_log_cond={}'.format(
                    m, n, log_cond),
                'm': m,
                'n': n,
                'log_cond': log_cond
            } for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
            for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)))
    @jtu.skip_on_devices("rocm")  # will be fixed on rocm-5.1
    def testSvdWithRectangularInput(self, m, n, log_cond):
        """Tests SVD with rectangular input."""
        with jax.default_matmul_precision('float32'):
            a = np.random.uniform(low=0.3, high=0.9,
                                  size=(m, n)).astype(_SVD_TEST_DTYPE)
            u, s, v = jnp.linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = jnp.linspace(cond, 1, min(m, n))
            a = (u * s) @ v
            a = a + 1j * a

            osp_linalg_fn = functools.partial(osp_linalg.svd,
                                              full_matrices=False)
            actual_u, actual_s, actual_v = svd.svd(a)

            k = min(m, n)
            if m > n:
                unitary_u = jnp.abs(actual_u.T.conj() @ actual_u)
                unitary_v = jnp.abs(actual_v.T.conj() @ actual_v)
            else:
                unitary_u = jnp.abs(actual_u @ actual_u.T.conj())
                unitary_v = jnp.abs(actual_v @ actual_v.T.conj())

            _, expected_s, _ = osp_linalg_fn(a)

            args_maker = lambda: [a]

            with self.subTest('Test JIT compatibility'):
                self._CompileAndCheck(svd.svd, args_maker)

            with self.subTest('Test unitary u.'):
                self.assertAllClose(np.eye(k),
                                    unitary_u,
                                    rtol=_SVD_RTOL,
                                    atol=2E-3)

            with self.subTest('Test unitary v.'):
                self.assertAllClose(np.eye(k),
                                    unitary_v,
                                    rtol=_SVD_RTOL,
                                    atol=2E-3)

            with self.subTest('Test s.'):
                self.assertAllClose(expected_s,
                                    jnp.real(actual_s),
                                    rtol=_SVD_RTOL,
                                    atol=1E-6)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': '_m={}_by_n={}'.format(m, n),
            'm': m,
            'n': n
        } for m, n in zip([50, 6], [3, 60])))
    def testSvdWithSkinnyTallInput(self, m, n):
        """Tests SVD with skinny and tall input."""
        # Generates a skinny and tall input
        with jax.default_matmul_precision('float32'):
            np.random.seed(1235)
            a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
            u, s, v = svd.svd(a, is_hermitian=False)

            relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)

            np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint:disable=g-complex-comprehension
                'testcase_name': '_m={}_r={}_log_cond={}'.format(
                    m, r, log_cond),
                'm': m,
                'r': r,
                'log_cond': log_cond
            } for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
            for log_cond in np.linspace(1, 3, 3)))
    @jtu.skip_on_devices("rocm")  # will be fixed on rocm-5.1
    def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
        """Tests SVD with rank-deficient input."""
        with jax.default_matmul_precision('float32'):
            a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

            # Generates a rank-deficient input.
            u, s, v = jnp.linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = jnp.linspace(cond, 1, m)
            s = s.at[r:m].set(jnp.zeros((m - r, )))
            a = (u * s) @ v

            with jax.default_matmul_precision('float32'):
                u, s, v = svd.svd(a, is_hermitian=False)
            diff = np.linalg.norm(a - (u * s) @ v)

            np.testing.assert_almost_equal(diff, 1E-4, decimal=2)
Example #10
class TestPromotionTables(jtu.JaxTestCase):
    @parameterized.named_parameters({
        "testcase_name": f"_jaxtype={jaxtype}",
        "jaxtype": jaxtype
    } for jaxtype in dtypes._jax_types + dtypes._weak_types)
    def testJaxTypeFromType(self, jaxtype):
        self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(jaxtype)),
                      jaxtype)

    @parameterized.named_parameters({
        "testcase_name": f"_jaxtype={jaxtype}",
        "jaxtype": jaxtype
    } for jaxtype in dtypes._jax_types + dtypes._weak_types)
    def testJaxTypeFromVal(self, jaxtype):
        try:
            val = jaxtype(0)
        except TypeError:
            val = jaxtype.type(0)
        self.assertIs(dtypes._jax_type(*dtypes._dtype_and_weaktype(val)),
                      jaxtype)

    @parameterized.named_parameters({
        "testcase_name": f"_dtype={dtype}",
        "dtype": dtype
    } for dtype in dtypes._jax_types)
    def testJaxTypeWeak(self, dtype):
        jax_type = dtypes._jax_type(dtype, weak_type=True)
        if dtypes.issubdtype(jax_type, np.complexfloating):
            self.assertIs(jax_type, complex)
        elif dtypes.issubdtype(jax_type, np.floating):
            self.assertIs(jax_type, float)
        elif dtypes.issubdtype(jax_type, np.integer):
            self.assertIs(jax_type, int)
        else:
            self.assertIs(jax_type, np.dtype(bool))

    def testResultTypeNone(self):
        # This matches the behavior of np.result_type(None) => np.float_
        self.assertEqual(dtypes.result_type(None),
                         dtypes.canonicalize_dtype(dtypes.float_))

    def testResultTypeWeakFlag(self):
        float_ = dtypes.canonicalize_dtype(dtypes.float_)
        x_weak = jnp.array(1.)
        x_strong = x_weak.astype(float_)
        self.assertEqual(dtypes.result_type(x_weak), float_)
        self.assertEqual(
            dtypes.result_type(x_weak, return_weak_type_flag=True),
            (float_, True))
        self.assertEqual(dtypes.result_type(x_strong), float_)
        self.assertEqual(
            dtypes.result_type(x_strong, return_weak_type_flag=True),
            (float_, False))

    @jtu.ignore_warning(category=UserWarning,
                        message="Explicitly requested dtype.*")
    @jax.numpy_dtype_promotion('standard')
    def testObservedPromotionTable(self):
        """Test that the weak & strong dtype promotion table does not change over time."""
        # Note: * here refers to weakly-typed values
        typecodes = \
            ['b1','u1','u2','u4','u8','i1','i2','i4','i8','bf','f2','f4','f8','c4','c8','i*','f*','c*']
        if config.x64_enabled:
            expected = [
                [
                    'b1', 'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'i*', 'f*', 'c*'
                ],
                [
                    'u1', 'u1', 'u2', 'u4', 'u8', 'i2', 'i2', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'u1', 'f*', 'c*'
                ],
                [
                    'u2', 'u2', 'u2', 'u4', 'u8', 'i4', 'i4', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'u2', 'f*', 'c*'
                ],
                [
                    'u4', 'u4', 'u4', 'u4', 'u8', 'i8', 'i8', 'i8', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'u4', 'f*', 'c*'
                ],
                [
                    'u8', 'u8', 'u8', 'u8', 'u8', 'f*', 'f*', 'f*', 'f*', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'u8', 'f*', 'c*'
                ],
                [
                    'i1', 'i2', 'i4', 'i8', 'f*', 'i1', 'i2', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'i1', 'f*', 'c*'
                ],
                [
                    'i2', 'i2', 'i4', 'i8', 'f*', 'i2', 'i2', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'i2', 'f*', 'c*'
                ],
                [
                    'i4', 'i4', 'i4', 'i8', 'f*', 'i4', 'i4', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'i4', 'f*', 'c*'
                ],
                [
                    'i8', 'i8', 'i8', 'i8', 'f*', 'i8', 'i8', 'i8', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'i8', 'f*', 'c*'
                ],
                [
                    'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf',
                    'f4', 'f4', 'f8', 'c4', 'c8', 'bf', 'bf', 'c4'
                ],
                [
                    'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f4',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'f2', 'f2', 'c4'
                ],
                [
                    'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4',
                    'f4', 'f4', 'f8', 'c4', 'c8', 'f4', 'f4', 'c4'
                ],
                [
                    'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8', 'f8',
                    'f8', 'f8', 'f8', 'c8', 'c8', 'f8', 'f8', 'c8'
                ],
                [
                    'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4',
                    'c4', 'c4', 'c8', 'c4', 'c8', 'c4', 'c4', 'c4'
                ],
                [
                    'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8',
                    'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8', 'c8'
                ],
                [
                    'i*', 'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'i*', 'f*', 'c*'
                ],
                [
                    'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'bf',
                    'f2', 'f4', 'f8', 'c4', 'c8', 'f*', 'f*', 'c*'
                ],
                [
                    'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c4',
                    'c4', 'c4', 'c8', 'c4', 'c8', 'c*', 'c*', 'c*'
                ],
            ]
        else:
            expected = [
                [
                    'b1', 'u1', 'u2', 'u4', 'u4', 'i1', 'i2', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'i*', 'f*', 'c*'
                ],
                [
                    'u1', 'u1', 'u2', 'u4', 'u4', 'i2', 'i2', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'u1', 'f*', 'c*'
                ],
                [
                    'u2', 'u2', 'u2', 'u4', 'u4', 'i4', 'i4', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'u2', 'f*', 'c*'
                ],
                [
                    'u4', 'u4', 'u4', 'u4', 'u4', 'i4', 'i4', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'u4', 'f*', 'c*'
                ],
                [
                    'u4', 'u4', 'u4', 'u4', 'u4', 'i4', 'i4', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'u4', 'f*', 'c*'
                ],
                [
                    'i1', 'i2', 'i4', 'i4', 'i4', 'i1', 'i2', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'i1', 'f*', 'c*'
                ],
                [
                    'i2', 'i2', 'i4', 'i4', 'i4', 'i2', 'i2', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'i2', 'f*', 'c*'
                ],
                [
                    'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'i4', 'f*', 'c*'
                ],
                [
                    'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'i4', 'f*', 'c*'
                ],
                [
                    'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf', 'bf',
                    'f4', 'f4', 'f4', 'c4', 'c4', 'bf', 'bf', 'c4'
                ],
                [
                    'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f2', 'f4',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'f2', 'f2', 'c4'
                ],
                [
                    'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4',
                    'f4', 'f4', 'f4', 'c4', 'c4', 'f4', 'f4', 'c4'
                ],
                [
                    'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4', 'f4',
                    'f4', 'f4', 'f4', 'c4', 'c4', 'f4', 'f4', 'c4'
                ],
                [
                    'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4',
                    'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4'
                ],
                [
                    'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4',
                    'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4', 'c4'
                ],
                [
                    'i*', 'u1', 'u2', 'u4', 'u4', 'i1', 'i2', 'i4', 'i4', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'i*', 'f*', 'c*'
                ],
                [
                    'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'f*', 'bf',
                    'f2', 'f4', 'f4', 'c4', 'c4', 'f*', 'f*', 'c*'
                ],
                [
                    'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c*', 'c4',
                    'c4', 'c4', 'c4', 'c4', 'c4', 'c*', 'c*', 'c*'
                ],
            ]
        typecode_to_dtype = {
            'b1': jnp.bool_,
            'u1': jnp.uint8,
            'u2': jnp.uint16,
            'u4': jnp.uint32,
            'u8': jnp.uint64,
            'i1': jnp.int8,
            'i2': jnp.int16,
            'i4': jnp.int32,
            'i8': jnp.int64,
            'bf': jnp.bfloat16,
            'f2': jnp.float16,
            'f4': jnp.float32,
            'f8': jnp.float64,
            'c4': jnp.complex64,
            'c8': jnp.complex128,
            'i*': jnp.int64,
            'f*': jnp.float64,
            'c*': jnp.complex128,
        }
        dtype_to_typecode = {
            jnp.dtype(v): k
            for k, v in typecode_to_dtype.items() if not k.endswith('*')
        }

        def typecode_to_val(typecode):
            weak_type = typecode.endswith('*')
            dtype = typecode_to_dtype[typecode]
            val = dtype(0)
            if weak_type:
                val = val.item()
            return val

        def val_to_typecode(val):
            dtype = dtypes.result_type(val)
            weak_type = dtypes.is_weakly_typed(val)
            typecode = dtype_to_typecode[dtype]
            if weak_type:
                typecode = typecode[:-1] + '*'
            return typecode

        vals = [typecode_to_val(t) for t in typecodes]
        table = [[val_to_typecode(v1 + v2) for v1 in vals] for v2 in vals]

        def show_differences(expected, actual):
            diffs = ""
            for i, t1 in enumerate(typecodes):
                for j, t2 in enumerate(typecodes):
                    if expected[i][j] != actual[i][j]:
                        diffs += f"\n{t1}, {t2} -> want {expected[i][j]}, got {actual[i][j]}"
            return diffs

        self.assertEqual(table, expected, show_differences(expected, table))

    @parameterized.named_parameters({
        "testcase_name":
        "_xtype={}_ytype={}_xfun={}_yfun={}".format(xtype.__name__,
                                                    ytype.__name__,
                                                    xfun.__name__,
                                                    yfun.__name__),
        "xtype":
        xtype,
        "ytype":
        ytype,
        "xfun":
        xfun,
        "yfun":
        yfun
    } for xtype, ytype in itertools.product(
        [int, float, jnp.int16, jnp.int32, jnp.float16, jnp.float32], repeat=2)
                                    for xfun, yfun in itertools.product(
                                        [identity, abs, jnp.array], repeat=2))
    @jax.numpy_dtype_promotion('standard')
    def testBinaryPromotionJitInvariance(self, xtype, ytype, xfun, yfun):
        """Test jit invariance of simple binary promotion rules with and without weak types."""
        f = lambda x, y: xfun(x) + yfun(y)
        args_maker = lambda: [xtype(1), ytype(1)]
        self._CompileAndCheck(f, args_maker, check_dtypes=True)

    @parameterized.named_parameters({
        "testcase_name": f"_dtype={dtype}_weak_type={weak_type}",
        "dtype": dtype,
        "weak_type": weak_type
    } for dtype in all_dtypes for weak_type in [True, False])
    def testUnaryPromotion(self, dtype, weak_type):
        # Regression test for https://github.com/google/jax/issues/6051
        x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
        if weak_type:
            expected = dtypes.canonicalize_dtype(
                dtypes._default_types['f' if x.dtype ==
                                      'bfloat16' else x.dtype.kind])
        else:
            expected = x.dtype
        self.assertEqual(dtypes.result_type(x), expected)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_dtype={dtype}_weak_type={weak_type}_promotion={promotion}",
            "dtype": dtype,
            "weak_type": weak_type,
            "promotion": promotion
        } for dtype in all_dtypes for weak_type in [True, False]
                            for promotion in ['standard', 'strict']))
    def testBinaryNonPromotion(self, dtype, weak_type, promotion):
        # Regression test for https://github.com/google/jax/issues/6051
        x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
        with jax.numpy_dtype_promotion(promotion):
            y = (x + x)

        if promotion == 'standard' or not weak_type or dtype == dtypes.bool_:
            expected_dtype = dtype
        elif dtypes.issubdtype(dtype, np.complexfloating):
            expected_dtype = dtypes.complex_
        elif dtypes.issubdtype(dtype, np.floating):
            expected_dtype = dtypes.float_
        else:
            expected_dtype = dtypes.int_

        # No boolean weak types.
        expected_weak_type = weak_type and dtype != bool
        expected_dtype = dtypes.canonicalize_dtype(expected_dtype)

        self.assertEqual(y.dtype, expected_dtype)
        self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)

    @parameterized.named_parameters({
        "testcase_name": f"_dtype={dtype}_weak_type={weak_type}",
        "dtype": dtype,
        "weak_type": weak_type
    } for dtype in all_dtypes for weak_type in [True, False])
    def testDeviceArrayRepr(self, dtype, weak_type):
        val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
        rep = repr(val)
        self.assertStartsWith(rep, 'DeviceArray(')
        if weak_type:
            self.assertEndsWith(rep,
                                f"dtype={val.dtype.name}, weak_type=True)")
        else:
            self.assertEndsWith(rep, f"dtype={val.dtype.name})")
Example #11
class VectorizeTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_leftshape={}_rightshape={}".format(left_shape, right_shape),
            "left_shape":
            left_shape,
            "right_shape":
            right_shape,
            "result_shape":
            result_shape
        } for left_shape, right_shape, result_shape in [
            ((2, 3), (3, 4), (2, 4)),
            ((2, 3), (1, 3, 4), (1, 2, 4)),
            ((5, 2, 3), (1, 3, 4), (5, 2, 4)),
            ((6, 5, 2, 3), (3, 4), (6, 5, 2, 4)),
        ]))
    def test_matmat(self, left_shape, right_shape, result_shape):
        matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
        self.assertEqual(
            matmat(jnp.zeros(left_shape), jnp.zeros(right_shape)).shape,
            result_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_leftshape={}_rightshape={}".format(left_shape, right_shape),
            "left_shape":
            left_shape,
            "right_shape":
            right_shape,
            "result_shape":
            result_shape
        } for left_shape, right_shape, result_shape in [
            ((2, 3), (3, ), (2, )),
            ((2, 3), (1, 3), (1, 2)),
            ((4, 2, 3), (1, 3), (4, 2)),
            ((5, 4, 2, 3), (1, 3), (5, 4, 2)),
        ]))
    def test_matvec(self, left_shape, right_shape, result_shape):
        matvec = jnp.vectorize(jnp.dot, signature='(n,m),(m)->(n)')
        self.assertEqual(
            matvec(jnp.zeros(left_shape), jnp.zeros(right_shape)).shape,
            result_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_leftshape={}_rightshape={}".format(left_shape, right_shape),
            "left_shape":
            left_shape,
            "right_shape":
            right_shape,
            "result_shape":
            result_shape
        } for left_shape, right_shape, result_shape in [
            ((3, ), (3, ), ()),
            ((2, 3), (3, ), (2, )),
            ((4, 2, 3), (3, ), (4, 2)),
        ]))
    def test_vecmat(self, left_shape, right_shape, result_shape):
        vecvec = jnp.vectorize(jnp.dot, signature='(m),(m)->()')
        self.assertEqual(
            vecvec(jnp.zeros(left_shape), jnp.zeros(right_shape)).shape,
            result_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape,
            "result_shape": result_shape
        } for shape, result_shape in [
            ((3,), ()),
            ((2, 3), (2,)),
            ((1, 2, 3), (1, 2)),
        ]))
    def test_magnitude(self, shape, result_shape):
        size = 1
        for x in shape:
            size *= x
        inputs = jnp.arange(size).reshape(shape)

        @partial(jnp.vectorize, signature='(n)->()')
        def magnitude(x):
            return jnp.dot(x, x)

        self.assertEqual(magnitude(inputs).shape, result_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape,
            "result_shape": result_shape
        } for shape, result_shape in [
            ((3, ), ()),
            ((2, 3), (2, )),
            ((1, 2, 3, 4), (1, 2, 3)),
        ]))
    def test_mean(self, shape, result_shape):
        mean = jnp.vectorize(jnp.mean, signature='(n)->()')
        self.assertEqual(mean(jnp.zeros(shape)).shape, result_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape,
            "result_shape": result_shape
        } for shape, result_shape in [
            ((), (2,)),
            ((3,), (3, 2)),
        ]))
    def test_stack_plus_minus(self, shape, result_shape):
        @partial(jnp.vectorize, signature='()->(n)')
        def stack_plus_minus(x):
            return jnp.stack([x, -x])

        self.assertEqual(
            stack_plus_minus(jnp.zeros(shape)).shape, result_shape)

    def test_center(self):
        @partial(jnp.vectorize, signature='(n)->(),(n)')
        def center(array):
            bias = jnp.mean(array)
            debiased = array - bias
            return bias, debiased

        b, a = center(jnp.arange(3))
        self.assertEqual(a.shape, (3, ))
        self.assertEqual(b.shape, ())
        self.assertAllClose(1.0, b, check_dtypes=False)

        b, a = center(jnp.arange(6).reshape(2, 3))
        self.assertEqual(a.shape, (2, 3))
        self.assertEqual(b.shape, (2, ))
        self.assertAllClose(jnp.array([1.0, 4.0]), b, check_dtypes=False)

    def test_exclude_first(self):
        @partial(jnp.vectorize, excluded={0})
        def f(x, y):
            assert x == 'foo'
            assert y.ndim == 0
            return y

        x = jnp.arange(3)
        self.assertAllClose(x, f('foo', x))
        self.assertAllClose(x, jax.jit(f, static_argnums=0)('foo', x))

    def test_exclude_second(self):
        @partial(jnp.vectorize, excluded={1})
        def f(x, y):
            assert x.ndim == 0
            assert y == 'foo'
            return x

        x = jnp.arange(3)
        self.assertAllClose(x, f(x, 'foo'))
        self.assertAllClose(x, jax.jit(f, static_argnums=1)(x, 'foo'))

    def test_exclude_errors(self):
        with self.assertRaisesRegex(TypeError,
                                    "jax.numpy.vectorize can only exclude"):
            jnp.vectorize(lambda x: x, excluded={'foo'})

        with self.assertRaisesRegex(
                ValueError, r"excluded=\{-1\} contains negative numbers"):
            jnp.vectorize(lambda x: x, excluded={-1})

        f = jnp.vectorize(lambda x: x, excluded={1})
        with self.assertRaisesRegex(
                ValueError, r"excluded=\{1\} is invalid for 1 argument\(s\)"):
            f(1.0)

    def test_bad_inputs(self):
        matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
        with self.assertRaisesRegex(TypeError,
                                    "wrong number of positional arguments"):
            matmat(jnp.zeros((3, 2)))
        with self.assertRaisesRegex(
                ValueError,
                r"input with shape \(2,\) does not have enough dimensions"):
            matmat(jnp.zeros((2, )), jnp.zeros((2, 2)))
        with self.assertRaisesRegex(
                ValueError, r"inconsistent size for core dimension 'm'"):
            matmat(jnp.zeros((2, 3)), jnp.zeros((4, 5)))

    def test_wrong_output_type(self):
        f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k),()')
        with self.assertRaisesRegex(TypeError, "output must be a tuple"):
            f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))

    def test_wrong_num_outputs(self):
        f = jnp.vectorize(lambda *args: args, signature='(),()->(),(),()')
        with self.assertRaisesRegex(TypeError,
                                    "wrong number of output arguments"):
            f(1, 2)

    def test_wrong_output_shape(self):
        f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n)')
        with self.assertRaisesRegex(ValueError,
                                    r"output shape \(2, 2\) does not match"):
            f(jnp.zeros((2, 2)), jnp.zeros((2, 2)))

    def test_inconsistent_output_size(self):
        f = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,n)')
        with self.assertRaisesRegex(
                ValueError, r"inconsistent size for core dimension 'n'"):
            f(jnp.zeros((2, 3)), jnp.zeros((3, 4)))
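As a quick illustration of the API exercised by this class: jnp.vectorize broadcasts over any leading batch dimensions while the signature pins the core dimensions. A minimal sketch, mirroring one of the matmat cases above (illustrative only, not part of the test class):

# Minimal sketch of jnp.vectorize broadcasting with a gufunc signature.
import jax.numpy as jnp

matmat = jnp.vectorize(jnp.dot, signature='(n,m),(m,k)->(n,k)')
lhs = jnp.zeros((6, 5, 2, 3))   # two leading batch dims broadcast ...
rhs = jnp.zeros((3, 4))         # ... against an unbatched right operand
print(matmat(lhs, rhs).shape)   # (6, 5, 2, 4)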
Example #12
0
class MaskingTest(jtu.JaxTestCase):
  def test_sum(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def padded_sum(x):
      return jnp.sum(x)

    ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=3))
    expected = 8
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = padded_sum([jnp.array([3, 1, 4, 1, 5])], dict(n=4))
    expected = 9
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_sum_vmap(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def padded_sum(x):
      return jnp.sum(x)

    ans = vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5)))
    expected = np.array([0, 1, 2, 3, 4])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def check(self, fun, in_shapes, out_shape, logical_env, padded_in_shapes,
            dtypes, rng, rtol=None, atol=None):
    shapecheck(in_shapes, out_shape)(fun)
    masked_fun = mask(fun, in_shapes, out_shape)
    padded_args = [rng(shape, dtype)
                   for shape, dtype in zip(padded_in_shapes, dtypes)]
    padded_outs, outs_tree = tree_flatten(masked_fun(padded_args, logical_env))

    out_specs, _ = tree_flatten(out_shape)
    out_specs = map(parse_spec, out_specs)
    out_specs = map(finalize_spec, out_specs, map(np.shape, padded_outs))
    logical_out_shapes = [eval_poly_shape(s, logical_env)
                          for s in out_specs]
    logical_out_slices = [tuple(map(slice, s)) for s in logical_out_shapes]
    logical_outs = [o[s] for o, s in zip(padded_outs, logical_out_slices)]

    in_specs = map(parse_spec, in_shapes)
    in_specs = map(finalize_spec, in_specs, padded_in_shapes)
    logical_in_shapes = [eval_poly_shape(s, logical_env)
                         for s in in_specs]
    logical_in_slices = [tuple(map(slice, s)) for s in logical_in_shapes]
    logical_args = [a[s] for a, s in zip(padded_args, logical_in_slices)]
    logical_outs_expected, logical_outs_tree = tree_flatten(fun(*logical_args))
    assert outs_tree == logical_outs_tree
    self.assertAllClose(logical_outs, logical_outs_expected, check_dtypes=True,
                        atol=atol, rtol=rtol)

    # Check that abstract evaluation works
    padded_outs_jit, _ = tree_flatten(jit(masked_fun)(padded_args, logical_env))
    self.assertAllClose(padded_outs_jit, padded_outs, check_dtypes=True,
                        atol=atol, rtol=rtol)

  def test_add(self):
    self.check(lax.add, ['n', ''], 'n', {'n': 3}, [(4,), ()], ['float_', 'float_'],
               jtu.rand_default(self.rng()))
    addvecs = mask(lax.add, in_shapes=['n', 'n'], out_shape='n')

    x = jnp.array([3, 1, 4, 1, 5, 9])
    y = jnp.array([2, 6, 5, 3, 5, 8])
    ans = addvecs([x, y], dict(n=3))
    expected = np.array([5, 7, 9])
    self.assertAllClose(ans[:3], expected, check_dtypes=False)

    thunk = lambda: addvecs([jnp.arange(5), jnp.arange(6)], dict(n=3))
    self.assertRaisesRegex(ShapeError, "", thunk)

  def test_scan(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def cumsum(arr):
      out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
      return out

    n = np.uint8(3)  # Test non-default integer type for dynamic length.
    ans = cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=n))
    expected = 16
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_scan_vmap(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def cumsum(arr):
      out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
      return out

    ans = vmap(cumsum)([jnp.arange(6).reshape(2, 3)], dict(n=jnp.array([1, 2])))
    expected = np.array([0, 7])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_scan_jit(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def cumsum(arr):
      out, _ = lax.scan(lambda c, x: (c + x, ()), 0, arr)
      return out

    @jit
    def jit_cumsum(args, shape_env):
      assert python_should_be_executing
      return cumsum(args, shape_env)

    python_should_be_executing = True
    ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=3))
    expected = 16
    self.assertAllClose(ans, expected, check_dtypes=False)

    python_should_be_executing = False
    ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=4))
    expected = 17
    self.assertAllClose(ans, expected, check_dtypes=False)

    python_should_be_executing = False
    ans = jit_cumsum([jnp.array([5, 2, 9, 1, 4])], dict(n=1))
    expected = 5
    self.assertAllClose(ans, expected, check_dtypes=False)

  # TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
  @unittest.skip("Shapecheck fails")
  def test_mean(self):
    self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '',
               {'n': 3}, [(4,)], ['float_'],
               jtu.rand_default(self.rng()))

  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_arithmetic(self):
    @partial(mask, in_shapes=['(n, m)', 'm'], out_shape='(n, m)')
    def times(x, y):
      return x * y

    # TODO(shoyer): enable this check when broadcast_in_dim supports masking
    with self.assertRaisesRegex(
        NotImplementedError,
        'Masking rule for broadcast_in_dim not implemented yet.'):
      times([jnp.array([[1, 2], [3, 4], [5, 6]]), jnp.array([1, 2])],
            dict(n=4, m=5))
      # expected = np.array([[1, 2, 3], [8, 10, 12]])
      # self.assertAllClose(ans, expected, check_dtypes=False)

  def test_stack(self):
    @partial(mask, in_shapes=['n','n'], out_shape='(2, n)')
    def stack(x, y):
      return jnp.stack([x, y], 0)

    # TODO(shoyer): enable this check when broadcast_in_dim supports masking
    with self.assertRaisesRegex(
        NotImplementedError,
        'Masking rule for broadcast_in_dim not implemented yet.'):
      stack([jnp.array([1, 2, 3]), jnp.array([4, 5, 6])], dict(n=10))
      # expected = np.array([[1, 2, 3], [4, 5, 6]])
      # self.assertAllClose(ans, expected, check_dtypes=False)

  def test_monomorphic(self):
    @partial(mask, in_shapes=['(_, n)'], out_shape='')
    def padded_sum(x):
      return jnp.sum(x)

    ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
    expected = 8
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_monomorphic2(self):
    @partial(mask, in_shapes=['(_, n)'], out_shape='n')
    def padded_sum(x):
      return jnp.sum(x, axis=0)

    ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=2))
    expected = jnp.array([8, 10])
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_monomorphic3(self):
    @partial(mask, in_shapes=['(_, n)'], out_shape='_')
    def padded_sum(x):
      return jnp.sum(x, axis=1)

    ans = padded_sum([jnp.array([[3, 4], [5, 6]])], dict(n=1))
    expected = jnp.array([3, 5])
    self.assertAllClose(ans, expected, check_dtypes=False)

    @shapecheck(['(2*n, n)'], '_, n')
    def identity(x):
      return x

  def test_rnn(self):
    n = 3

    @partial(mask, in_shapes=['(_, _)', '(t, _)'], out_shape='_')
    def rnn(W, xs):
      def step(h, x):
        new_h = jnp.dot(W, h) + jnp.dot(W, x)
        return new_h, ()
      predicted, _ = lax.scan(step, jnp.zeros(n), xs)
      return predicted

    rng = self.rng()
    W = jnp.eye(n)
    xs = rng.randn(10, n).astype(jnp.float_)
    ans = rnn([W, xs], dict(t=4))
    expected = xs[:4].sum(0)
    self.assertAllClose(ans, expected, check_dtypes=False)

  def test_rnn_grad(self):
    n = 3

    @partial(mask, in_shapes=['(_, _)', '(t, _)', '_'], out_shape='')
    def rnn(W, xs, target):
      def step(h, x):
        new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
        return new_h, ()
      predicted, _ = lax.scan(step, jnp.zeros(n), xs)
      return jnp.sum((predicted - target)**2)

    rng = self.rng()
    W = rng.randn(n, n).astype(jnp.float_)
    xs = rng.randn(10, n).astype(jnp.float_)
    y = rng.randn(n).astype(jnp.float_)

    ans = grad(lambda W: rnn([W, xs, y], dict(t=4)))(W)

    def rnn_reference(W, xs, target):
      h = jnp.zeros(n)
      for x in xs:
        h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
      predicted = h
      return jnp.sum((predicted - target)**2)

    expected = grad(lambda W: rnn_reference(W, xs[:4], y))(W)

    self.assertAllClose(ans, expected, check_dtypes=False,
                        rtol={np.float64: 1e-14, np.float32: 1e-5})

  def test_ragged_batched_rnn(self):
    n = 3

    @partial(mask, in_shapes=('(_, _)', '(t, _)', '_'), out_shape='')
    def rnn(W, xs, target):
      def step(h, x):
        new_h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
        return new_h, ()
      predicted, _ = lax.scan(step, jnp.zeros(n), xs)
      return jnp.sum((predicted - target)**2)

    rng = self.rng()
    W = rng.randn(n, n).astype(jnp.float_)
    seqs = rng.randn(3, 10, n).astype(jnp.float_)
    ts = jnp.array([2, 5, 4])
    ys = rng.randn(3, n)

    ans = grad(lambda W: vmap(rnn, ((None, 0, 0), 0))((W, seqs, ys), dict(t=ts)).sum())(W)

    def rnn_reference(W, seqs, targets):
      total_loss = jnp.array(0.0)
      for xs, target in zip(seqs, targets):
        h = jnp.zeros(n)
        for x in xs:
          h = jnp.tanh(jnp.dot(W, h) + jnp.dot(W, x))
        predicted = h
        total_loss = total_loss + jnp.sum((predicted - target)**2)
      return total_loss

    seqs_ = [xs[:t] for xs, t in zip(seqs, ts)]
    expected = grad(lambda W: rnn_reference(W, seqs_, ys).sum())(W)

    self.assertAllClose(
      ans, expected, check_dtypes=False,
      rtol=0.1 if jtu.device_under_test() == "tpu" else 1e-5)

  def test_concatenate(self):
    self.check(lambda x, y, z: lax.concatenate([x, y, z], 0),
               ['n', 'm', 'n'], 'm + 2 * n', {'n': 2, 'm': 3},
               [(4,), (3,), (4,)], ['float_', 'float_', 'float_'],
               jtu.rand_default(self.rng()))

  def test_dot(self):
    self.check(lax.dot, ['(m, k)', '(k, n)'], '(m, n)',
               dict(m=2, k=3, n=4), [(4, 5), (5, 7)], ['float_', 'float_'],
               jtu.rand_default(self.rng()))
    self.check(lax.dot, ['(m, n)', 'n'], 'm', dict(m=2, n=3), [(4, 5), (5,)],
               ['float_', 'float_'], jtu.rand_default(self.rng()))

  # TODO(mattjj,j-towns): fix test failure and reenable.
  @jtu.skip_on_devices("tpu")
  def test_jit(self):
    @partial(mask, in_shapes=['n'], out_shape='2*n')
    @jit
    def duplicate(x):
      assert python_should_be_executing
      return lax.concatenate([x, x], 0)

    python_should_be_executing = True
    out = duplicate([jnp.arange(3)], dict(n=2))
    assert np.all(np.array([0, 1, 0, 1]) == out[:4])

    python_should_be_executing = False
    out = duplicate([jnp.arange(3)], dict(n=2))
    assert np.all(np.array([0, 1, 0, 1]) == out[:4])

  @unittest.skip("broken by omnistaging")  # TODO(mattjj): update
  def test_jit2(self):
    # Trigger MaskTrace.post_process_call
    def fun(x):
      @jit
      def concat(y):
        return lax.concatenate([x, y], 0)
      return concat(jnp.array([1., 2., 3.], dtype='float32'))

    self.check(fun, ['n'], '(n+3,)', {'n': 2}, [(3,)], ['float32'],
               jtu.rand_default(self.rng()))

  @parameterized.named_parameters({
      'testcase_name': "padding_config={}_shapes={}".format(padding_config,
                                                            shape),
      'padding_config': padding_config,
      'shape': shape} for padding_config, shape in (
          (((1, 2, 0),), (2,)),
          (((1, 2, 0), (3, 4, 0)), (1, 2)),
          (((0, 0, 0), (0, 0, 0)), (1, 2)),
          (((1, 2, 3),), (2,)),
          (((1, 2, 1), (3, 4, 2)), (3, 2)),
          (((-1, 2, 0),), (2,)),
          (((-1, -2, 0), (1, 2, 0)), (4, 2)),
          (((-1, 2, 0), (1, 2, 2)), (4, 2)),
          (((-1, -2, 2),), (5,)),
          (((-1, -2, 1), (1, 2, 2)), (4, 2))))
  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_pad(self, padding_config, shape):
    def pad(x):
      return lax.pad(x, jnp.array(1., x.dtype), padding_config)

    if len(shape) == 1:
      padding_config_, = padding_config
      linear_coeff = padding_config_[2] + 1
      const_coeff = sum(padding_config_[:2]) - padding_config_[2]
      out_shape = str(linear_coeff) + ' * h + ' + str(const_coeff)
      self.check(pad, ['h'], out_shape, dict(h=shape[0]),
                 [tuple(np.add(shape, 1))], ['float_'],
                 jtu.rand_default(self.rng()))


  # TODO(mattjj,j-towns): fix test failure and reenable.
  @jtu.skip_on_devices("tpu")
  @unittest.skip("broken by omnistaging")  # TODO(mattjj): update
  def test_numpy_pad(self):
    def numpy_pad(x):
      return jnp.pad(x, (0, 1), constant_values=5.)

    self.check(numpy_pad, ['n'], 'n + 1', dict(n=2), [(3,)], ['float_'],
               jtu.rand_default(self.rng()))

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name': "padding={}_lhs_dilation={}_"
       "dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
           padding, lhs_dilation, dimension_numbers, lhs_perm,
           rhs_perm, out_perm),
      'padding': padding, 'lhs_dilation': lhs_dilation,
      'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
      'rhs_perm': rhs_perm, 'out_perm': out_perm}
    for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
    for lhs_dilation in (None, (1, 2))
    for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
            (("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
            (("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
            (("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
    )
    # String padding is not implemented for transposed convolution, see
    # conv_general_dilated implementation:
    if (lhs_dilation is None or not isinstance(padding, str))))
  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_conv(
          self, padding, lhs_dilation, dimension_numbers, lhs_perm,
          rhs_perm, out_perm):
    def conv(lhs, rhs):
      return lax.conv_general_dilated(
        lhs, rhs, (1, 1), padding, lhs_dilation=lhs_dilation,
        dimension_numbers=dimension_numbers)

    template =  '({}, {}, {}, {})'
    lhs_shape = template.format(*np.take(['n', 'c', 'h', 'w'], lhs_perm))
    rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
    if padding == 'VALID':
      out_shape = template.format(
        *np.take(['n', 'o', 'h+-1', 'w+-2'], out_perm))
    elif lhs_dilation:
      out_shape = template.format(
        *np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
    else:
      out_shape = template.format(
        *np.take(['n', 'o', 'h', 'w'], out_perm))

    logical_env = dict(n=3, c=2, h=4, w=5, o=6)

    self.check(conv, [lhs_shape, rhs_shape], out_shape,
               logical_env, [tuple(np.take([4, 3, 6, 7], lhs_perm)),
                             tuple(np.take([7, 3, 2, 3], rhs_perm))],
               ['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4,
               atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {'testcase_name': "padding={}_lhs_dilation={}_"
       "dimension_numbers={}_lhs_perm={}_rhs_perm={}_out_perm={}".format(
           padding, lhs_dilation, dimension_numbers, lhs_perm,
           rhs_perm, out_perm),
      'padding': padding, 'lhs_dilation': lhs_dilation,
      'dimension_numbers': dimension_numbers, 'lhs_perm': lhs_perm,
      'rhs_perm': rhs_perm, 'out_perm': out_perm}
    for padding in ['SAME', 'VALID', ((0, 1), (2, 0))]
    for lhs_dilation in (None, (1, 2))
    for dimension_numbers, (lhs_perm, rhs_perm, out_perm) in (
            (("NCHW", "OIHW", "NCHW"), ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))),
            (("NHWC", "HWIO", "NHWC"), ((0, 2, 3, 1), (2, 3, 1, 0), (0, 2, 3, 1))),
            (("NCHW", "HWIO", "NHWC"), ((0, 1, 2, 3), (2, 3, 1, 0), (0, 2, 3, 1)))
    )
    # String padding is not implemented for transposed convolution, see
    # conv_general_dilated implementation:
    if (lhs_dilation is None or not isinstance(padding, str))))
  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_conv_strided(
          self, padding, lhs_dilation, dimension_numbers, lhs_perm,
          rhs_perm, out_perm):
    def conv(lhs, rhs):
      return lax.conv_general_dilated(
        lhs, rhs, (2, 1), padding, lhs_dilation=lhs_dilation,
        dimension_numbers=dimension_numbers)

    template =  '({}, {}, {}, {})'
    rhs_shape = template.format(*np.take(['o', 'c', '2', '3'], rhs_perm))
    if padding == 'VALID':
      lhs_shape = template.format(*np.take(['n', 'c', '2*h+1', 'w'], lhs_perm))
      lhs_shape_padded = tuple(np.take([4, 3, 5, 7], lhs_perm))
      out_shape = template.format(*np.take(['n', 'o', 'h', 'w+-2'], out_perm))
    elif lhs_dilation:
      lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
      lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
      out_shape = template.format(*np.take(['n', 'o', 'h', '2*w+-1'], out_perm))
    else:
      lhs_shape = template.format(*np.take(['n', 'c', '2*h', 'w'], lhs_perm))
      lhs_shape_padded = tuple(np.take([4, 3, 6, 7], lhs_perm))
      out_shape = template.format(*np.take(['n', 'o', 'h', 'w'], out_perm))

    logical_env = dict(n=3, c=2, h=4, w=5, o=6)

    self.check(conv, [lhs_shape, rhs_shape], out_shape,
               logical_env, [lhs_shape_padded,
                             tuple(np.take([7, 3, 2, 3], rhs_perm))],
               ['float_', 'float_'], jtu.rand_default(self.rng()), rtol=1e-4,
               atol=1e-4)

  @unittest.skip("requires gather support")
  def test_indexing(self):
    self.check(lambda x: x[0], ['n'], '', {'n': 2}, [(3,)], ['float_'],
               jtu.rand_default(self.rng()))
    self.check(lambda x: x[-1], ['n'], '', {'n': 2}, [(3,)], ['float_'],
               jtu.rand_default(self.rng()))

  @unittest.skip("requires gather support")
  def test_slicing(self):
    self.check(lambda x: x[1:], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
    self.check(lambda x: x[:-1], ['n'], 'n+-1', {'n': 2}, [(3,)], ['float_'])
    self.check(lambda x: x[..., -1], ['(n,3)'], 'n', {'n': 2}, [(3, 4)], ['float_'])

  def test_rev(self):
    @shapecheck(['n'], 'n')
    def rev1(x):
      return lax.rev(x, (0,))

    @shapecheck(['(m, n)'], '(m, n)')
    def rev2(x):
      return lax.rev(x, (1,))

  @unittest.skip("TODO")
  def test_rev_by_indexing(self):

    @shapecheck(['n'], 'n+-1')
    def rev1(x):
      return x[:0:-1]

    @shapecheck(['n'], 'n+-1')
    def rev2(x):
      return x[-2::-1]

    # TODO implement masking for rev_p:
    # self.check(lambda x: x[:0:-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')
    # self.check(lambda x: x[-2::-1], ['n'], dict(n=jnp.array([2, 3])), 'n+-1')

  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_lax_slice(self):
    self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1',
               {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))
    # TODO self.check(lambda x: lax.slice(x, (x.shape[0] // 2,), (x.shape[0],)),
    #  ['2*n'], 'n', {'n': 2}, [(6,)], ['float_'], jtu.rand_default(self.rng()))
    self.check(lambda x: lax.slice(x, (0,), (x.shape[0],), (x.shape[0],)),
               ['n'], '1', {'n': 2}, [(5,)], ['float_'],
               jtu.rand_default(self.rng()))

  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_reshape(self):
    self.check(lambda x: jnp.reshape(x, (x.shape[1], 2, 4, 1)),
               ['1, n, 4, 2'], 'n, 2, 4, 1', dict(n=2), [(1, 3, 4, 2)],
               ['float_'], jtu.rand_default(self.rng()))

    self.check(lambda x: jnp.reshape(x, (x.shape[0] * 2,)),
               ['n, 2'], '2 * n', dict(n=2), [(3, 2)],
               ['float_'], jtu.rand_default(self.rng()))

    self.check(lambda x: jnp.reshape(x, (x.shape[0] // 2, 2)),
               ['2 * n'], 'n, 2', dict(n=2), [(6,)],
               ['float_'], jtu.rand_default(self.rng()))

    self.check(lambda x: jnp.reshape(x, (x.shape[0] * 4, 2)),
               ['n, 2, 4'], '4 * n, 2', dict(n=2), [(3, 2, 4)],
               ['float_'], jtu.rand_default(self.rng()))

    self.check(lambda x: jnp.reshape(x, ((x.shape[0] - 1) // 4 + 1, 2, 4)),
               ['4 * n + 4, 2'], 'n + 1, 2, 4', dict(n=2), [(12, 2)],
               ['float_'], jtu.rand_default(self.rng()))

    msg = "Reshape on padded dimensions causing fragmentation is not supported."
    with self.assertRaisesRegex(NotImplementedError, msg):
      self.check(lambda x: jnp.reshape(x, np.prod(x.shape)),
                 ['a, b'], 'a*b', dict(a=2, b=3), [(3, 4)],
                 ['float_'], jtu.rand_default(self.rng()))

    with self.assertRaisesRegex(NotImplementedError, msg):
      self.check(lambda x: jnp.reshape(x, (x.shape[1], x.shape[0])),
                 ['a, b'], 'b, a', dict(a=2, b=3), [(3, 4)],
                 ['float_'], jtu.rand_default(self.rng()))

    with self.assertRaisesRegex(NotImplementedError, msg):
      self.check(lambda x: jnp.reshape(x, (x.shape[1] * 2,)),
                 ['2, n'], '2 * n', dict(n=2), [(2, 3)],
                 ['float_'], jtu.rand_default(self.rng()))

    self.check(lambda x: jnp.reshape(x, (x.shape[0], -1)),
               ['n, 3, 1, 2'], 'n, 6', dict(n=1), [(2, 3, 1, 2)],
               ['float_'], jtu.rand_default(self.rng()))

  def test_transpose(self):
    self.check(lambda x: lax.transpose(x, (1, 0, 2)),
               ['(a, b, c)'], 'b, a, c', dict(a=2, b=3, c=4), [(3, 4, 5)],
               ['float_'], jtu.rand_default(self.rng()))

  def test_sum_2d(self):
    self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'],
               jtu.rand_default(self.rng()))

  @unittest.skip("custom_jvp doesn't work with masking yet")
  def test_expit(self):
    self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'],
               jtu.rand_default(self.rng()))

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_{}".format(dtype), "dtype": np.dtype(dtype).name}
    for dtype in [np.float32, np.float64]))
  @unittest.skip("not yet implemented")
  def test_uniform(self, dtype):
    # TODO needs fix for https://github.com/google/jax/issues/2155
    pass

  @unittest.skip("not yet implemented")
  def test_broadcast_in_dim(self):
    pass

  def test_destructure(self):
    def d(key):
      key1, key2 = key
      return key1

    self.check(d, ['2'], '', {}, [(2,)], ['int_'], jtu.rand_int(self.rng(), 0, 10))

  # TODO(mattjj,j-towns): fix test failure and reenable.
  @jtu.skip_on_devices("tpu")
  def test_where(self):
    self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n',
               {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))

  @unittest.skip("Failing after fixing Poly unsoundness #4878")
  def test_split(self):
    self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
               [(8,)], ['float_'], jtu.rand_default(self.rng()))
    self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12),
               [(12,)], ['float_'], jtu.rand_default(self.rng()))

  @parameterized.named_parameters(jtu.cases_from_list([{
    'testcase_name': "operator={}".format(operator.__name__), 'operator': operator}
    for operator in [jnp.sum, jnp.prod, jnp.max, jnp.min]]))
  def test_reduce(self, operator):
    self.check(operator, ['(m+1, n+1)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
               jtu.rand_default(self.rng()))

  def test_output_shape_error(self):
    def thunk():
      shapecheck(['n'], 'n+-1')(lambda x: x)

    message = "Output shapes should be (n + -1,) but are (n,)."
    self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)

    def thunk():
      shapecheck(['n'], ['7*n', 'n'])(lambda x: (x, x))

    message = "Output shapes should be [(7 n,), (n,)] but are ((n,), (n,))."
    self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)

  def test_output_tree_error(self):
    def thunk():
      shapecheck(['n'], ('n', 'n'))(lambda x: [x, x])

    message = "Output shapes should be ((n,), (n,)) but are [(n,), (n,)]."
    self.assertRaisesWithLiteralMatch(ShapeError, message, thunk)

  def test_unsupported_op(self):
    p = core.Primitive('unsupported_op')
    p.def_abstract_eval(lambda x: x)
    p.def_impl(lambda x: x)

    def thunk():
      mask(p.bind, ['n'], 'n')([np.arange(3)], {'n': 2})

    message = "Masking rule for unsupported_op not implemented yet."
    self.assertRaisesWithLiteralMatch(NotImplementedError, message, thunk)

  @unittest.skip("not yet implemented")
  def test_nesting(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def padded_sum(x):
      return jnp.sum(x)

    batched_sum = vmap(padded_sum)

    @partial(mask, in_shapes=['(m, _)', 'm'], out_shape='')
    def fun(x, ns):
      return batched_sum([x], dict(n=ns)).sum()

    x = jnp.array([[3, 1, 4, 1],
                  [5, 9, 2, 6],
                  [5, 3, 5, 8]])
    ns = jnp.array([2, 3, 2])
    ans = fun([x, ns], dict(m=2))
    expected = 3+1 + 5+9+2
    self.assertAllClose(ans, expected, check_dtypes=False)


  def test_slice_oob_indexing(self):
    # https://github.com/google/jax/issues/2245
    self.assertAllClose(jnp.ones(5), jnp.ones(5)[:10])
    self.assertAllClose(jnp.ones(5), jnp.ones(5)[-10:])

  def test_jaxpr_doesnt_include_trivial_operations(self):
    @partial(mask, in_shapes=['n'], out_shape='')
    def foo(x):
      return np.sum(x)

    padded_x = np.array([0, 1, 2, 3, 999, 999])

    jaxpr = make_jaxpr(foo)([padded_x], dict(n=3))
    self.assertNotIn('mul', str(jaxpr))
    self.assertNotIn('add', str(jaxpr))

  def test_return_shape_to_user(self):
    @partial(mask, in_shapes=['n'])
    def foo(x):
      return [x, np.sum(x)]

    out, out_shape = foo([np.arange(5)], dict(n=2))
    self.assertIsInstance(out_shape, list)
    self.assertLen(out_shape, 2)
    a, b = out_shape
    self.assertEqual(a.shape, (2,))
    self.assertEqual(b.shape, ())
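The pattern used throughout this class: mask wraps a shape-polymorphic function, the wrapped call takes padded arrays plus a logical-size environment, and only the first n entries contribute to the result. A minimal sketch of that pattern, mirroring test_sum and test_sum_vmap above (it assumes the same imports as the test module; the experimental masking API has since been removed from JAX):

# Minimal sketch of the masking pattern exercised above. `mask` and `vmap`
# are assumed imported as in the test module (masking was an experimental
# API that has since been removed).
from functools import partial
import jax.numpy as jnp

@partial(mask, in_shapes=['n'], out_shape='')
def padded_sum(x):
  return jnp.sum(x)

# Single call: pad to length 5, treat only the first 3 entries as real.
padded_sum([jnp.array([3, 1, 4, 999, 999])], dict(n=3))       # == 8

# Ragged batch: one padded buffer, a different logical length per row.
vmap(padded_sum)([jnp.ones((5, 10))], dict(n=jnp.arange(5)))  # [0., 1., 2., 3., 4.]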
Example #13
0
class TestBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_func={func_and_init[0].__name__}_maxiter={maxiter}",
     "maxiter": maxiter, "func_and_init": func_and_init}
    for maxiter in [None]
    for func_and_init in [(rosenbrock, np.zeros(2)),
                          (himmelblau, np.zeros(2)),
                          (matyas, np.ones(2) * 6.),
                          (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):
    # Note, cannot compare step for step with scipy BFGS because our line search is _slightly_ different.

    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='BFGS',
          options=dict(maxiter=maxiter, gtol=1e-6),
      )
      return result.x

    jax_res = min_op(x0)
    scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_fixes4594(self):
    n = 2
    A = jnp.eye(n) * 1e4
    def f(x):
      return jnp.mean((A @ x) ** 2)
    results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
    self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)

  @jtu.skip_on_flag('jax_enable_x64', False)
  def test_zakharov(self):
    def zakharov_fn(x):
      ii = jnp.arange(1, len(x) + 1, step=1)
      answer = zakharovFromIndices(x=x, ii=ii)
      return answer

    x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
    eval_func = jax.jit(zakharov_fn)
    jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
    self.assertLess(jax_res.fun, 1e-6)

  def test_minimize_bad_initial_values(self):
    # This test runs deliberately "bad" initial values to test that handling
    # of failed line search, etc. is the same across implementations
    initial_value = jnp.array([92, 0.001])
    opt_fn = himmelblau(jnp)
    jax_res = jax.scipy.optimize.minimize(
        fun=opt_fn,
        x0=initial_value,
        method='BFGS',
    ).x
    scipy_res = scipy.optimize.minimize(
        fun=opt_fn,
        jac=jax.grad(opt_fn),
        method='BFGS',
        x0=initial_value
    ).x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)


  def test_args_must_be_tuple(self):
    A = jnp.eye(2) * 1e4
    def f(x):
      return jnp.mean((A @ x) ** 2)
    with self.assertRaisesRegex(TypeError, "args .* must be a tuple"):
      jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
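For reference, the call pattern these tests exercise is jax.scipy.optimize.minimize with method='BFGS', which returns an OptimizeResults structure. A minimal sketch on a simple quadratic (illustrative only, not one of the benchmark functions used above):

# Minimal sketch of the BFGS entry point exercised above (illustrative only).
import jax
import jax.numpy as jnp

def quadratic(x):
  return jnp.sum((x - 3.0) ** 2)   # minimum at x = [3., 3.]

result = jax.scipy.optimize.minimize(quadratic, jnp.zeros(2), method='BFGS')
print(result.x)        # approximately [3., 3.]
print(result.success)  # True once the gradient norm is small enough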
Example #14
0
class SvdTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint:disable=g-complex-comprehension
                'testcase_name':
                '_m={}_by_n={}_log_cond={}_full_matrices={}'.format(
                    m, n, log_cond, full_matrices),
                'm': m,
                'n': n,
                'log_cond': log_cond,
                'full_matrices': full_matrices
            } for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
            for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)
            for full_matrices in [True, False]))
    def testSvdWithRectangularInput(self, m, n, log_cond, full_matrices):
        """Tests SVD with rectangular input."""
        with jax.default_matmul_precision('float32'):
            a = np.random.uniform(low=0.3, high=0.9,
                                  size=(m, n)).astype(_SVD_TEST_DTYPE)
            u, s, v = osp_linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = jnp.linspace(cond, 1, min(m, n))
            a = (u * s) @ v
            a = a.astype(complex) * (1 + 1j)

            osp_linalg_fn = functools.partial(osp_linalg.svd,
                                              full_matrices=full_matrices)
            actual_u, actual_s, actual_v = svd.svd(a,
                                                   full_matrices=full_matrices)

            k = min(m, n)
            if m > n:
                unitary_u = jnp.real(actual_u.T.conj() @ actual_u)
                unitary_v = jnp.real(actual_v.T.conj() @ actual_v)
                unitary_u_size = m if full_matrices else k
                unitary_v_size = k
            else:
                unitary_u = jnp.real(actual_u @ actual_u.T.conj())
                unitary_v = jnp.real(actual_v @ actual_v.T.conj())
                unitary_u_size = k
                unitary_v_size = n if full_matrices else k

            _, expected_s, _ = osp_linalg_fn(a)

            svd_fn = lambda a: svd.svd(a, full_matrices=full_matrices)
            args_maker = lambda: [a]

            with self.subTest('Test JIT compatibility'):
                self._CompileAndCheck(svd_fn, args_maker)

            with self.subTest('Test unitary u.'):
                self.assertAllClose(np.eye(unitary_u_size),
                                    unitary_u,
                                    rtol=_SVD_RTOL,
                                    atol=2E-3)

            with self.subTest('Test unitary v.'):
                self.assertAllClose(np.eye(unitary_v_size),
                                    unitary_v,
                                    rtol=_SVD_RTOL,
                                    atol=2E-3)

            with self.subTest('Test s.'):
                self.assertAllClose(expected_s,
                                    jnp.real(actual_s),
                                    rtol=_SVD_RTOL,
                                    atol=1E-6)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name': f'_m={m}_by_n={n}',
            'm': m,
            'n': n
        } for m, n in zip([50, 6], [3, 60])))
    def testSvdWithSkinnyTallInput(self, m, n):
        """Tests SVD with skinny and tall input."""
        # Generates a skinny and tall input
        with jax.default_matmul_precision('float32'):
            np.random.seed(1235)
            a = np.random.randn(m, n).astype(_SVD_TEST_DTYPE)
            u, s, v = svd.svd(a, full_matrices=False, hermitian=False)

            relative_diff = np.linalg.norm(a - (u * s) @ v) / np.linalg.norm(a)

            np.testing.assert_almost_equal(relative_diff, 1E-6, decimal=6)

    @parameterized.named_parameters(
        jtu.cases_from_list({  # pylint:disable=g-complex-comprehension
            'testcase_name': f'_m={m}_r={r}_log_cond={log_cond}',
            'm': m,
            'r': r,
            'log_cond': log_cond
        } for m, r in zip([8, 8, 8, 10], [3, 5, 7, 9])
                            for log_cond in np.linspace(1, 3, 3)))
    def testSvdWithOnRankDeficientInput(self, m, r, log_cond):
        """Tests SVD with rank-deficient input."""
        with jax.default_matmul_precision('float32'):
            a = jnp.triu(jnp.ones((m, m))).astype(_SVD_TEST_DTYPE)

            # Generates a rank-deficient input.
            u, s, v = jnp.linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = jnp.linspace(cond, 1, m)
            s = s.at[r:m].set(jnp.zeros((m - r, )))
            a = (u * s) @ v

            with jax.default_matmul_precision('float32'):
                u, s, v = svd.svd(a, full_matrices=False, hermitian=False)
            diff = np.linalg.norm(a - (u * s) @ v)

            np.testing.assert_almost_equal(diff, 1E-4, decimal=2)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {  # pylint:disable=g-complex-comprehension
                'testcase_name':
                '_m={}_by_n={}_log_cond={}_full_matrices={}'.format(
                    m, n, log_cond, full_matrices),
                'm': m,
                'n': n,
                'log_cond': log_cond,
                'full_matrices': full_matrices
            } for m, n in zip([2, 8, 10, 20], [4, 6, 10, 18])
            for log_cond in np.linspace(1, _MAX_LOG_CONDITION_NUM, 4)
            for full_matrices in [True, False]))
    def testSingularValues(self, m, n, log_cond, full_matrices):
        """Tests singular values."""
        with jax.default_matmul_precision('float32'):
            a = np.random.uniform(low=0.3, high=0.9,
                                  size=(m, n)).astype(_SVD_TEST_DTYPE)
            u, s, v = osp_linalg.svd(a, full_matrices=False)
            cond = 10**log_cond
            s = np.linspace(cond, 1, min(m, n))
            a = (u * s) @ v
            a = a + 1j * a

            # Only computes singular values.
            compute_uv = False

            osp_linalg_fn = functools.partial(osp_linalg.svd,
                                              full_matrices=full_matrices,
                                              compute_uv=compute_uv)
            actual_s = svd.svd(a,
                               full_matrices=full_matrices,
                               compute_uv=compute_uv)

            expected_s = osp_linalg_fn(a)

            svd_fn = lambda a: svd.svd(a, full_matrices=full_matrices)
            args_maker = lambda: [a]

            with self.subTest('Test JIT compatibility'):
                self._CompileAndCheck(svd_fn, args_maker)

            with self.subTest('Test s.'):
                self.assertAllClose(expected_s,
                                    actual_s,
                                    rtol=_SVD_RTOL,
                                    atol=1E-6)

            with self.subTest('Test non-increasing order.'):
                # Computes `actual_diff[i] = s[i+1] - s[i]`.
                actual_diff = jnp.diff(actual_s, append=0)
                np.testing.assert_array_less(actual_diff,
                                             np.zeros_like(actual_diff))

    @parameterized.named_parameters([
        {
            'testcase_name': f'_m={m}_by_n={n}_full_matrices={full_matrices}_'  # pylint:disable=g-complex-comprehension
            f'compute_uv={compute_uv}_dtype={dtype}',
            'm': m,
            'n': n,
            'full_matrices': full_matrices,  # pylint:disable=undefined-variable
            'compute_uv': compute_uv,
            'dtype': dtype
        }  # pylint:disable=undefined-variable
        for m, n in zip([2, 4, 8], [4, 4, 6])
        for full_matrices in [True, False] for compute_uv in [True, False]
        for dtype in jtu.dtypes.floating + jtu.dtypes.complex
    ])
    def testSvdOnZero(self, m, n, full_matrices, compute_uv, dtype):
        """Tests SVD on matrix of all zeros."""
        osp_fun = functools.partial(osp_linalg.svd,
                                    full_matrices=full_matrices,
                                    compute_uv=compute_uv)
        lax_fun = functools.partial(svd.svd,
                                    full_matrices=full_matrices,
                                    compute_uv=compute_uv)
        args_maker_svd = lambda: [jnp.zeros((m, n), dtype=dtype)]
        self._CheckAgainstNumpy(osp_fun, lax_fun, args_maker_svd)
        self._CompileAndCheck(lax_fun, args_maker_svd)

    @parameterized.named_parameters([{
        'testcase_name': f'_m={m}_by_n={n}_r={r}_c={c}_dtype={dtype}',
        'm': m,
        'n': n,
        'r': r,
        'c': c,
        'dtype': dtype
    } for m, n, r, c in zip([2, 4, 8], [4, 4, 6], [1, 0, 1], [1, 0, 1])
                                     for dtype in jtu.dtypes.floating])
    def testSvdOnTinyElement(self, m, n, r, c, dtype):
        """Tests SVD on matrix of zeros and close-to-zero entries."""
        a = jnp.zeros((m, n), dtype=dtype)
        tiny_element = jnp.finfo(a).tiny
        a = a.at[r, c].set(tiny_element)

        @jax.jit
        def lax_fun(a):
            return svd.svd(a,
                           full_matrices=False,
                           compute_uv=False,
                           hermitian=False)

        actual_s = lax_fun(a)

        k = min(m, n)
        expected_s = np.zeros((k, ), dtype=dtype)
        expected_s[0] = tiny_element

        self.assertAllClose(expected_s,
                            jnp.real(actual_s),
                            rtol=_SVD_RTOL,
                            atol=1E-6)
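The core check in most of these tests is that (u * s) @ v reconstructs the input matrix. A minimal sketch of that reconstruction check using the public jnp.linalg.svd (the tests above drive the lower-level svd module directly; the shapes here are illustrative):

# Minimal sketch of the SVD reconstruction check used above, via the
# public jnp.linalg.svd rather than the svd module under test.
import numpy as np
import jax.numpy as jnp

a = jnp.asarray(np.random.randn(50, 3).astype(np.float32))
u, s, vh = jnp.linalg.svd(a, full_matrices=False)
relative_diff = jnp.linalg.norm(a - (u * s) @ vh) / jnp.linalg.norm(a)
print(relative_diff)   # close to float32 machine precision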
Example #15
0
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @genNamedParametersNArgs(3)
    def testPoissonLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.logpmf
        lax_fun = lsp_stats.poisson.logpmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

    @genNamedParametersNArgs(3)
    def testPoissonPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.pmf
        lax_fun = lsp_stats.poisson.pmf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testPoissonCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.poisson.cdf
        lax_fun = lsp_stats.poisson.cdf

        def args_maker():
            k, mu, loc = map(rng, shapes, dtypes)
            # clipping to ensure that rate parameter is strictly positive
            mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
            return [k, mu, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testBernoulliLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.bernoulli.logpmf
        lax_fun = lsp_stats.bernoulli.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = np.floor(x)
            p = expit(logit)
            loc = np.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testGeomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.geom.logpmf
        lax_fun = lsp_stats.geom.logpmf

        def args_maker():
            x, logit, loc = map(rng, shapes, dtypes)
            x = np.floor(x)
            p = expit(logit)
            loc = np.floor(loc)
            return [x, p, loc]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(5)
    def testBetaLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.beta.logpdf
        lax_fun = lsp_stats.beta.logpdf

        def args_maker():
            x, a, b, loc, scale = map(rng, shapes, dtypes)
            return [x, a, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={
                                  np.float32: 2e-3,
                                  np.float64: 1e-4
                              })

    def testBetaLogPdfZero(self):
        # Regression test for https://github.com/google/jax/issues/7645
        a = b = 1.
        x = np.array([0., 1.])
        self.assertAllClose(osp_stats.beta.pdf(x, a, b),
                            lsp_stats.beta.pdf(x, a, b),
                            atol=1E-6)

    @genNamedParametersNArgs(3)
    def testCauchyLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.cauchy.logpdf
        lax_fun = lsp_stats.cauchy.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
            "shapes": [x_shape, alpha_shape],
            "dtypes":
            dtypes
        } for x_shape in one_and_two_dim_shapes for alpha_shape in [(
            x_shape[0], ), (
                x_shape[0] +
                1, )] for dtypes in itertools.combinations_with_replacement(
                    jtu.dtypes.floating, 2)))
    def testDirichletLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())

        def _normalize(x, alpha):
            x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
            return (x / x_norm).astype(x.dtype), alpha

        def lax_fun(x, alpha):
            return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

        def scipy_fun(x, alpha):
            # scipy validates the x normalization using float64 arithmetic, so we must
            # cast x to float64 before normalization to ensure this passes.
            x, alpha = _normalize(x.astype('float64'), alpha)

            result = osp_stats.dirichlet.logpdf(x, alpha)
            # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays
            # of a consistent rank. This check ensures the results have the same shape.
            return result if x.ndim == 1 else np.atleast_1d(result)

        def args_maker():
            # Don't normalize here, because we want normalization to happen at 64-bit
            # precision in the scipy version.
            x, alpha = map(rng, shapes, dtypes)
            return x, alpha

        tol = {np.float32: 1E-3, np.float64: 1e-5}
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)

    @genNamedParametersNArgs(3)
    def testExponLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.expon.logpdf
        lax_fun = lsp_stats.expon.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testGammaLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.gamma.logpdf
        lax_fun = lsp_stats.gamma.logpdf

        def args_maker():
            x, a, loc, scale = map(rng, shapes, dtypes)
            return [x, a, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    def testGammaLogPdfZero(self):
        # Regression test for https://github.com/google/jax/issues/7256
        self.assertAllClose(osp_stats.gamma.pdf(0.0, 1.0),
                            lsp_stats.gamma.pdf(0.0, 1.0),
                            atol=1E-6)

    @genNamedParametersNArgs(4)
    def testNBinomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.nbinom.logpmf
        lax_fun = lsp_stats.nbinom.logpmf

        def args_maker():
            k, n, logit, loc = map(rng, shapes, dtypes)
            k = np.floor(np.abs(k))
            n = np.ceil(np.abs(n))
            p = expit(logit)
            loc = np.floor(loc)
            return [k, n, p, loc]

        tol = {np.float32: 1e-6, np.float64: 1e-8}
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    @genNamedParametersNArgs(3)
    def testLaplaceLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.laplace.logpdf
        lax_fun = lsp_stats.laplace.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testLaplaceCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.laplace.cdf
        lax_fun = lsp_stats.laplace.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # ensure that scale is not too low
            scale = np.clip(scale, a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol={
                                    np.float32: 1e-5,
                                    np.float64: 1e-6
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.cdf
        lax_fun = lsp_stats.logistic.cdf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticLogpdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.logpdf
        lax_fun = lsp_stats.logistic.logpdf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticPpf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.ppf
        lax_fun = lsp_stats.logistic.ppf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(1)
    def testLogisticSf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.logistic.sf
        lax_fun = lsp_stats.logistic.sf

        def args_maker():
            return list(map(rng, shapes, dtypes))

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.logpdf
        lax_fun = lsp_stats.norm.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormLogCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.logcdf
        lax_fun = lsp_stats.norm.logcdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormCdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.cdf
        lax_fun = lsp_stats.norm.cdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-6)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(3)
    def testNormPpf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.norm.ppf
        lax_fun = lsp_stats.norm.ppf

        def args_maker():
            q, loc, scale = map(rng, shapes, dtypes)
            # ensure probability is between 0 and 1:
            q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [q, loc, scale]

        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)

    @genNamedParametersNArgs(4)
    def testParetoLogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.pareto.logpdf
        lax_fun = lsp_stats.pareto.logpdf

        def args_maker():
            x, b, loc, scale = map(rng, shapes, dtypes)
            return [x, b, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testTLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.t.logpdf
        lax_fun = lsp_stats.t.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            # clipping to ensure that scale is not too low
            scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-3)
        self._CompileAndCheck(lax_fun,
                              args_maker,
                              rtol={np.float64: 1e-14},
                              atol={np.float64: 1e-14})

    @genNamedParametersNArgs(3)
    def testUniformLogPdf(self, shapes, dtypes):
        rng = jtu.rand_default(self.rng())
        scipy_fun = osp_stats.uniform.logpdf
        lax_fun = lsp_stats.uniform.logpdf

        def args_maker():
            x, loc, scale = map(rng, shapes, dtypes)
            return [x, loc, np.abs(scale)]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(4)
    def testChi2LogPdf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        scipy_fun = osp_stats.chi2.logpdf
        lax_fun = lsp_stats.chi2.logpdf

        def args_maker():
            x, df, loc, scale = map(rng, shapes, dtypes)
            return [x, df, loc, scale]

        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker)

    @genNamedParametersNArgs(5)
    def testBetaBinomLogPmf(self, shapes, dtypes):
        rng = jtu.rand_positive(self.rng())
        lax_fun = lsp_stats.betabinom.logpmf

        def args_maker():
            k, n, a, b, loc = map(rng, shapes, dtypes)
            k = np.floor(k)
            n = np.ceil(n)
            a = np.clip(a, a_min=0.1, a_max=None)
            b = np.clip(b, a_min=0.1, a_max=None)
            loc = np.floor(loc)
            return [k, n, a, b, loc]

        if scipy_version >= (1, 4):
            scipy_fun = osp_stats.betabinom.logpmf
            self._CheckAgainstNumpy(scipy_fun,
                                    lax_fun,
                                    args_maker,
                                    check_dtypes=False,
                                    tol=5e-4)
        self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

    def testIssue972(self):
        self.assertAllClose(np.ones((4, ), np.float32),
                            lsp_stats.norm.cdf(
                                np.full((4, ), np.inf, np.float32)),
                            check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_x={}_mean={}_cov={}".format(
                jtu.format_shape_dtype_string(x_shape, x_dtype),
                jtu.format_shape_dtype_string(mean_shape, mean_dtype)
                if mean_shape is not None else None,
                jtu.format_shape_dtype_string(cov_shape, cov_dtype)
                if cov_shape is not None else None),
            "x_shape":
            x_shape,
            "x_dtype":
            x_dtype,
            "mean_shape":
            mean_shape,
            "mean_dtype":
            mean_dtype,
            "cov_shape":
            cov_shape,
            "cov_dtype":
            cov_dtype
        } for x_shape, mean_shape, cov_shape in [
            # # These test cases cover default values for mean/cov, but we don't
            # # support those yet (and they seem not very valuable).
            # [(), None, None],
            # [(), (), None],
            # [(2,), None, None],
            # [(2,), (), None],
            # [(2,), (2,), None],
            # [(3, 2), (3, 2,), None],
            # [(5, 3, 2), (5, 3, 2,), None],
            [(), (), ()],
            [(3, ), (), ()],
            [(3, ), (3, ), ()],
            [(3, ), (3, ), (3, 3)],
            [(3, 4), (4, ), (4, 4)],
            [(2, 3, 4), (4, ), (4, 4)],
        ] for x_dtype, mean_dtype, cov_dtype in
                            itertools.combinations_with_replacement(
                                jtu.dtypes.floating, 3)
                            if (mean_shape is not None
                                or mean_dtype == np.float32) and
                            (cov_shape is not None or cov_dtype == np.float32))
    )
    def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                     mean_dtype, cov_shape, cov_dtype):
        rng = jtu.rand_default(self.rng())

        def args_maker():
            args = [rng(x_shape, x_dtype)]
            if mean_shape is not None:
                args.append(5 * rng(mean_shape, mean_dtype))
            if cov_shape is not None:
                if cov_shape == ():
                    args.append(0.1 + rng(cov_shape, cov_dtype)**2)
                else:
                    factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
                    factor = rng(factor_shape, cov_dtype)
                    args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
            return args

        self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                                lsp_stats.multivariate_normal.logpdf,
                                args_maker,
                                tol=1e-3)
        self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf,
                              args_maker,
                              rtol=1e-4,
                              atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_x={}_mean={}_cov={}".format(
                jtu.format_shape_dtype_string(x_shape, x_dtype),
                jtu.format_shape_dtype_string(mean_shape, mean_dtype)
                if mean_shape is not None else None,
                jtu.format_shape_dtype_string(cov_shape, cov_dtype)
                if cov_shape is not None else None),
            "x_shape":
            x_shape,
            "x_dtype":
            x_dtype,
            "mean_shape":
            mean_shape,
            "mean_dtype":
            mean_dtype,
            "cov_shape":
            cov_shape,
            "cov_dtype":
            cov_dtype
        } for x_shape, mean_shape, cov_shape in [
            # These test cases are where scipy flattens things, which has
            # different batch semantics than some might expect, so we manually
            # vectorize scipy's outputs for the sake of testing.
            [(5, 3, 2), (5, 3, 2), (5, 3, 2, 2)],
            [(2, ), (5, 3, 2), (5, 3, 2, 2)],
            [(5, 3, 2), (2, ), (5, 3, 2, 2)],
            [(5, 3, 2), (5, 3, 2), (2, 2)],
            [(1, 3, 2), (3, 2), (5, 1, 2, 2)],
            [(5, 3, 2), (1, 2), (2, 2)],
        ] for x_dtype, mean_dtype, cov_dtype in
                            itertools.combinations_with_replacement(
                                jtu.dtypes.floating, 3)
                            if (mean_shape is not None
                                or mean_dtype == np.float32) and
                            (cov_shape is not None or cov_dtype == np.float32))
    )
    def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype,
                                                mean_shape, mean_dtype,
                                                cov_shape, cov_dtype):
        rng = jtu.rand_default(self.rng())

        def args_maker():
            args = [rng(x_shape, x_dtype)]
            if mean_shape is not None:
                args.append(5 * rng(mean_shape, mean_dtype))
            if cov_shape is not None:
                if cov_shape == ():
                    args.append(0.1 + rng(cov_shape, cov_dtype)**2)
                else:
                    factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
                    factor = rng(factor_shape, cov_dtype)
                    args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
            return args

        osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf,
                               signature="(n),(n),(n,n)->()")

        self._CheckAgainstNumpy(osp_fun,
                                lsp_stats.multivariate_normal.logpdf,
                                args_maker,
                                tol=1e-3)
        self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf,
                              args_maker,
                              rtol=1e-4,
                              atol=1e-4)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__),
            "ndim":
            ndim,
            "nbatch":
            nbatch,
            "dtype":
            dtype
        } for ndim in [2, 3] for nbatch in [1, 3, 5]
                            for dtype in jtu.dtypes.floating))
    def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
        # Regression test for #5570
        rng = jtu.rand_default(self.rng())
        x = rng((nbatch, ndim), dtype)
        mean = 5 * rng((nbatch, ndim), dtype)
        factor = rng((nbatch, ndim, 2 * ndim), dtype)
        cov = factor @ factor.transpose(0, 2, 1)

        result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
        result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
        self.assertArraysEqual(result1, result2)
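
Note: as a minimal standalone sketch of the behavior the batch test above asserts (illustrative only, not part of the test file), jax.scipy.stats.multivariate_normal.logpdf accepts stacked means and covariances directly and agrees with an explicit jax.vmap over the leading batch axis:

# Illustrative sketch; shapes are arbitrary and the identity covariance is just for brevity.
import jax
import jax.numpy as jnp
from jax.scipy import stats as lsp_stats

x = jnp.arange(15.).reshape(5, 3)               # batch of 5 points in R^3
mean = jnp.zeros((5, 3))
cov = jnp.broadcast_to(jnp.eye(3), (5, 3, 3))   # one covariance per batch element

direct = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
mapped = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
assert jnp.allclose(direct, mapped)             # batched call matches per-element vmap
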
Example #16
0
class CheckifyTransformTests(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_jit={jit}",
            "jit": jit
        } for jit in [False, True]))
    @jtu.skip_on_devices("tpu")
    def test_jit_nan(self, jit):
        def f(x1, x2):
            y1 = jnp.sin(x1)
            y2 = jnp.sin(x2)
            return y1 + y2

        f = jax.jit(f) if jit else f
        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        err, _ = checked_f(3., 4.)
        self.assertIs(err.get(), None)

        err, _ = checked_f(3., jnp.inf)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_jit={jit}",
            "jit": jit
        } for jit in [False, True]))
    def test_jit_oob(self, jit):
        def f(x, i):
            y = jnp.sin(x)
            z = y[i]
            w = jnp.cos(z)
            return w

        f = jax.jit(f) if jit else f
        checked_f = checkify.checkify(f, errors=checkify.index_checks)

        err, _ = checked_f(jnp.arange(3), 2)
        self.assertIs(err.get(), None)

        err, _ = checked_f(jnp.arange(3), 5)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")

    @parameterized.named_parameters(
        {
            "testcase_name": f"_updatefn={update_fn}",
            "update_fn": update_fn
        } for update_fn in
        ["set", "add", "multiply", "divide", "power", "min", "max", "get"])
    def test_jit_oob_update(self, update_fn):
        def f(x, i):
            return getattr(x.at[i], update_fn)(1)

        f = jax.jit(f)
        checked_f = checkify.checkify(f, errors=checkify.index_checks)

        err, _ = checked_f(jnp.arange(3), 2)
        self.assertIs(err.get(), None)

        err, _ = checked_f(jnp.arange(3), 3)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_jit={jit}",
            "jit": jit
        } for jit in [False, True]))
    def test_jit_div_errors(self, jit):
        def f(x, y):
            return x / y

        f = jax.jit(f) if jit else f
        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        err, _ = checked_f(jnp.ones((3, )), jnp.ones((3, )))
        self.assertIs(err.get(), None)

        err, _ = checked_f(jnp.ones((3, )), jnp.array([1., 0., 1.]))
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")

        err, _ = checked_f(jnp.array([1, jnp.inf, 1]),
                           jnp.array([1, jnp.inf, 1]))
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive div")

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_jit={jit}",
            "jit": jit
        } for jit in [False, True]))
    @jtu.skip_on_devices("tpu")
    def test_jit_multi(self, jit):
        def f(x, i):
            y = x[i]
            z = jnp.cos(y)
            return z

        f = jax.jit(f) if jit else f
        checked_f = checkify.checkify(f, errors=checkify.automatic_checks)

        # no error
        err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
        self.assertIs(err.get(), None)

        # oob error
        err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")

        # nan error
        err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive cos")

    def test_numpy_indexing_oobs(self):
        def raises_oob(fn, idx, *expected_strs):
            err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx)
            error_txt = err.get()
            self.assertIsNotNone(error_txt)
            self.assertStartsWith(error_txt, "out-of-bounds indexing")
            for s in expected_strs:
                self.assertIn(s, error_txt)

        x = jnp.ones((2, 3, 7))
        axis0_msg = "axis 0 with size 2"
        axis1_msg = "axis 1 with size 3"
        axis2_msg = "axis 2 with size 7"

        single_idx = lambda x, i: x[i]
        raises_oob(single_idx, 5, "index 5", axis0_msg)
        raises_oob(single_idx, -5, "index -3", axis0_msg)
        raises_oob(single_idx, (0, 100), "index 100", axis1_msg)
        raises_oob(single_idx, (0, 5, 100), "index 5", axis1_msg)
        raises_oob(single_idx, (0, 0, 100), "index 100", axis2_msg)
        raises_oob(single_idx, ((1, 20), (1, 4)), "index 20", axis0_msg)
        raises_oob(single_idx, ((1, 20), (3, 4)), "index 3", axis1_msg)
        raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg)
        raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg)

        multi_idx = lambda x, i: x[i[0], :, i[1]]
        raises_oob(multi_idx, (0, 9), "index 9", axis2_msg)
        # TODO(lenamartens): numpy reports index -5 here, need to normalize?
        raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg)
        raises_oob(multi_idx, (5, -9), "index 5", axis0_msg)
        raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_jit={jit}",
            "jit": jit
        } for jit in [False, True]))
    def test_jit_ordering(self, jit):
        def f(x, i):
            y = x[i]
            z = jnp.sin(x)
            return y * z

        f = jax.jit(f) if jit else f
        checked_f = checkify.checkify(f, errors=checkify.automatic_checks)

        # both oob and nan error, but oob happens first
        err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")

    @jtu.skip_on_devices("tpu")
    def test_pmap_basic(self):
        if len(jax.devices()) < 2:
            raise unittest.SkipTest("requires at least 2 devices")

        @jax.pmap
        def f(x1, x2):
            y1 = jnp.sin(x1)
            y2 = jnp.sin(x2)
            return y1 + y2

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        xs = jnp.array([0., 2.])
        err, _ = checked_f(xs, xs)
        self.assertIs(err.get(), None)

        ys = jnp.array([3., jnp.inf])
        err, _ = checked_f(xs, ys)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

    @jtu.skip_on_devices("tpu")
    def test_cond_basic(self):
        @jax.jit
        def f(x):
            return lax.cond(x > 0, lambda: jnp.sin(x), lambda: x)

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        err, _ = checked_f(3.)
        self.assertIs(err.get(), None)

        err, _ = checked_f(jnp.inf)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

        err, _ = checked_f(-jnp.inf)
        self.assertIs(err.get(), None)

    @jtu.skip_on_devices("tpu")
    def test_scan_map(self):
        def scan_body(_, x):
            return None, jnp.sin(x)

        @jax.jit
        def f(xs):
            return lax.scan(scan_body, None, xs)

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        xs = jnp.array([0., 2.])
        err, (_, ch_outs) = checked_f(xs)
        _, outs = f(xs)
        self.assertIs(err.get(), None)
        self.assertArraysEqual(ch_outs, outs)

        xs = jnp.array([3., jnp.inf])
        err, (_, ch_outs) = checked_f(xs)
        _, outs = f(xs)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")
        self.assertArraysEqual(ch_outs, outs)

    @jtu.skip_on_devices("tpu")
    def test_scan_carry(self):
        def scan_body(carry, x):
            carry = carry - 1.
            possible_nan = jnp.sin(1. / carry)
            return carry, x + possible_nan

        @jax.jit
        def f(carry, xs):
            return lax.scan(scan_body, carry, xs)

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        carry, xs = 3., jnp.ones((2, ))
        err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
        out_carry, outs = f(carry, xs)
        self.assertIs(err.get(), None)
        self.assertArraysEqual(ch_outs, outs)
        self.assertArraysEqual(ch_out_carry, out_carry)

        # error happens on first iteration
        carry, xs = 1., jnp.ones((2, ))
        err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
        out_carry, outs = f(carry, xs)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")
        self.assertArraysEqual(ch_outs, outs)
        self.assertArraysEqual(ch_out_carry, out_carry)

        # error happens on second iteration
        carry, xs = 2., jnp.ones((4, ))
        err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
        out_carry, outs = f(carry, xs)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")
        self.assertArraysEqual(ch_outs, outs)
        self.assertArraysEqual(ch_out_carry, out_carry)

    @jtu.skip_on_devices("tpu")
    def test_while_loop_body_error(self):
        def while_cond(val):
            i, _ = val
            return i < 2

        def while_body(val):
            i, x = val
            possible_nan = jnp.sin(1. / i)
            return i + 1., x + possible_nan

        @jax.jit
        def f(init_val):
            return lax.while_loop(while_cond, while_body, (init_val, 0.))

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        init_val = 1.
        err, ch_out = checked_f(init_val)
        out = f(init_val)
        self.assertIs(err.get(), None)
        self.assertArraysEqual(ch_out, out)

        init_val = 0.
        err, ch_out = checked_f(init_val)
        out = f(init_val)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")
        self.assertArraysEqual(ch_out, out)

    @jtu.skip_on_devices("tpu")
    def test_while_loop_cond_error(self):
        def while_cond(val):
            _ = jnp.sin(1. / val)
            return val < 2.

        def while_body(val):
            return val + 1.

        @jax.jit
        def f(init_val):
            return lax.while_loop(while_cond, while_body, init_val)

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        init_val = 1.
        err, ch_out = checked_f(init_val)
        out = f(init_val)
        self.assertIs(err.get(), None)
        self.assertArraysEqual(ch_out, out)

        init_val = 0.
        err, ch_out = checked_f(init_val)
        out = f(init_val)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")
        self.assertArraysEqual(ch_out, out)

    @jtu.skip_on_devices("tpu")
    def test_while_loop_cond_error_and_false(self):
        # Tests if an error is generated when cond returns False.
        def while_cond(val):
            possible_nan = jnp.sin(1. / val)
            return jnp.logical_not(jnp.isnan(possible_nan))

        @jax.jit
        def f(init_val):
            return lax.while_loop(while_cond, lambda val: val - 1, init_val)

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        # error on first cond
        init_val = 0.
        err, _ = checked_f(init_val)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")

        # error on second cond
        init_val = 1.
        err, _ = checked_f(init_val)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")

    @jtu.skip_on_devices("tpu")
    def test_while_loop_body_and_cond_error(self):
        def while_cond(val):
            i, cond_val, _ = val
            _ = jnp.sin(cond_val)
            return i < 2

        def while_body(val):
            i, cond_val, body_val = val
            possible_nan = jnp.cos(body_val)
            return i + 1., cond_val, possible_nan

        @jax.jit
        def f(cond_val, body_val):
            return lax.while_loop(while_cond, while_body,
                                  (0., cond_val, body_val))

        checked_f = checkify.checkify(f, errors=checkify.float_checks)

        cond_val = jnp.inf
        body_val = 1.
        err, _ = checked_f(cond_val, body_val)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

        cond_val = 1.
        body_val = jnp.inf
        err, _ = checked_f(cond_val, body_val)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive cos")

        cond_val = jnp.inf
        body_val = jnp.inf
        err, _ = checked_f(cond_val, body_val)
        self.assertIsNotNone(err.get())
        # first error which occurs is in cond
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

    def test_empty_enabled_errors(self):
        def multi_errors(x):
            x = x / 0  # DIV
            x = jnp.sin(x)  # NAN
            x = x[500]  # OOB
            checkify.check(x < 0, "must be negative!")  # ASSERT
            return x

        x = jnp.ones((2, ))
        err, _ = checkify.checkify(multi_errors, errors=set())(x)
        self.assertIsNone(err.get())

    @parameterized.named_parameters(
        ("assert", checkify.user_checks, "must be negative!"),
        ("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
        ("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
        ("oob", checkify.index_checks, "out-of-bounds indexing"),
        ("automatic_checks", checkify.automatic_checks, "divided by zero"),
    )
    @jtu.skip_on_devices("tpu")
    def test_enabled_errors(self, error_set, expected_error):
        def multi_errors(x):
            checkify.check(jnp.all(x < 0), "must be negative!")  # ASSERT
            x = x / 0  # DIV
            x = jnp.sin(x)  # NAN
            x = x[500]  # OOB
            return x

        x = jnp.ones((2, ))
        err, _ = checkify.checkify(multi_errors, errors=error_set)(x)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), expected_error)

    @jtu.skip_on_devices("tpu")
    def test_post_process_call(self):
        @partial(checkify.checkify, errors=checkify.float_checks)
        def g(x):
            @jax.jit
            def f(y):
                return jnp.sin(x * y)

            return f(jnp.inf)

        err, _ = g(2.)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

    @jtu.skip_on_devices("tpu")
    def test_post_process_map(self):
        @partial(checkify.checkify, errors=checkify.float_checks)
        def g(x):
            @jax.pmap
            def f(y):
                return jnp.sin(x * y)

            return f(jnp.array([jnp.inf]))[0]

        err, _ = g(2.)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')

    @jtu.skip_on_devices("tpu")
    def test_custom_jvp(self):
        @jax.custom_jvp
        def sin(x):
            return jnp.sin(x)

        @sin.defjvp
        def sin_jvp(primals, tangents):
            (x, ), (xdot, ) = primals, tangents
            return sin(x), jnp.cos(x) * xdot

        f = checkify.checkify(sin, errors=checkify.float_checks)

        err, y = f(3.)
        self.assertIsNone(err.get())
        err, y = f(jnp.inf)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')

        # When we hit the custom jvp rule with jvp-of-checkify, no checks are added.
        (err, y), (errdot, ydot) = jax.jvp(f, (3., ), (1., ))  # doesn't crash
        self.assertIsNone(err.get())  # no error
        self.assertEmpty(err.msgs)  # and no checks were added!
        self.assertEmpty(errdot.msgs)
        y_expected, ydot_expected = jax.jvp(jnp.sin, (3., ), (1., ))
        self.assertAllClose(y, y_expected)
        self.assertAllClose(ydot, ydot_expected)

        # Grad-of-checkify doesn't crash either.
        x_bar = jax.grad(lambda x: f(x)[1])(3.)
        self.assertAllClose(x_bar, jnp.cos(3.))

        # Checkify-of-jvp adds checks (unlike jvp-of-checkify above).
        g = checkify.checkify(lambda x, xdot: jax.jvp(sin, (x, ), (xdot, )),
                              errors=checkify.float_checks)
        err, (y, ydot) = g(3., 1.)  # doesn't crash
        self.assertIsNone(err.get())  # no error
        self.assertNotEmpty(err.msgs)  # but checks were added!
        self.assertAllClose(y, jnp.sin(3.))
        self.assertAllClose(ydot, jnp.cos(3.))
        err, _ = g(jnp.inf, 1.)
        self.assertIsNotNone(err.get())  # yes error
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')

    @jtu.skip_on_devices("tpu")
    def test_custom_vjp(self):
        @jax.custom_vjp
        def sin(x):
            return jnp.sin(x)

        def sin_fwd(x):
            return jnp.sin(x), 2. * x

        def sin_bwd(x2, g):
            return jnp.cos(x2 / 2.) * g,

        sin.defvjp(sin_fwd, sin_bwd)

        f = checkify.checkify(sin, errors=checkify.float_checks)

        # no differentiation, no error
        err, y = f(3.)
        self.assertIsNone(err.get())

        # no differentiation, yes error
        err, y = f(jnp.inf)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), 'nan generated by primitive sin')

        # When we hit the custom vjp rule with vjp-of-checkify, no checks are added.
        (err, y), f_vjp = jax.vjp(f, 3.)
        self.assertIsNone(err.get())  # no error
        self.assertEmpty(err.msgs)  # and no checks were added!

        # Checkify-of-vjp adds checks (unlike vjp-of-checkify above).
        err, y = checkify.checkify(jax.grad(sin),
                                   errors=checkify.float_checks)(3.)
        self.assertIsNone(err.get())  # no error
        self.assertNotEmpty(err.msgs)  # but checks were added!
        err, y = checkify.checkify(jax.grad(sin),
                                   errors=checkify.float_checks)(jnp.inf)
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "nan generated by primitive sin")

    def test_scan_consts(self):
        def f(xs):
            def scan_body(carry, _):
                # closes over xs
                return carry + 1, xs[carry]

            return lax.scan(scan_body, 1, xs)[1]

        checked_f = checkify.checkify(f, errors=checkify.index_checks)
        err, _ = checked_f(jnp.ones((7, 3)))
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")

    def test_scan_consts2(self):
        def f(xs):
            def scan_body(carry, _):
                # add more consts!
                _ = xs[carry], xs[carry], jnp.sin(np.arange(11.))
                return carry + 1, xs[carry]

            return lax.scan(scan_body, 1, xs)[1]

        checked_f = checkify.checkify(f, errors=checkify.index_checks)
        err, _ = checked_f(jnp.ones((7, 3)))
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "out-of-bounds indexing")

    def test_while_consts(self):
        def f(xs):
            def while_cond(carry):
                i, _ = carry
                _ = xs[i], jnp.sin(np.arange(11.))
                return i > -1

            def while_body(carry):
                i, _ = carry
                x = xs[i]
                return i - 1, x / i

            return lax.while_loop(while_cond, while_body,
                                  (0, jnp.zeros_like(xs[0])))

        checked_f = checkify.checkify(f, errors=checkify.float_checks)
        err, _ = checked_f(jnp.ones((7, 3)))
        self.assertIsNotNone(err.get())
        self.assertStartsWith(err.get(), "divided by zero")

    def test_multiple_payloads(self):
        def f(x):
            _ = x[5]
            _ = x[6]

        err, _ = checkify.checkify(f, errors=checkify.index_checks)(jnp.ones(
            (2, )))
        self.assertIsNotNone(err.get())
        self.assertIn("index 5", err.get())

    def test_nd_payloads(self):
        cf = checkify.checkify(lambda x, i: x[i], errors=checkify.index_checks)
        errs, _ = jax.vmap(cf)(jnp.ones((3, 2)), jnp.array([5, 0, 100]))
        self.assertIsNotNone(errs.get())
        self.assertIn("index 5", errs.get())
        self.assertIn("index 100", errs.get())

    def test_mapped_error_one_payload(self):
        def f(x, i):
            x = x[i]
            return x / 0

        cf = checkify.checkify(f, errors=checkify.automatic_checks)
        errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100]))
        self.assertIsNotNone(errs.get())
        self.assertIn("divided by zero", errs.get())
        self.assertIn("index 100", errs.get())
Example #17
0
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

    # This test runs for all primitive harnesses. For each primitive "xxx" the
    # test will be called "test_prim_xxx_..." and the custom parameters for
    # the test are defined in the class method "jax2tf_limitations.Jax2TfLimitation.xxx".
    # See more details in the comment at top of file and in Jax2TfLimitation class.
    # If you want to run this test for only one harness, add parameter
    # `one_containing="foo"` to parameterized below.
    @primitive_harness.parameterized(
        primitive_harness.all_harnesses,
        include_jax_unimpl=False,
        #one_containing="cumprod_dtype_by_fun_shape=float16[8,9]_axis=0_reverse=False"
    )
    @jtu.ignore_warning(category=UserWarning,
                        message="Using reduced precision for gradient.*")
    def test_prim(self, harness: primitive_harness.Harness):
        limitations = Jax2TfLimitation.limitations_for_harness(harness)
        device = jtu.device_under_test()
        limitations = tuple(
            filter(lambda l: l.filter(device=device, dtype=harness.dtype),
                   limitations))
        func_jax = harness.dyn_fun
        args = harness.dyn_args_maker(self.rng())
        enable_xla = harness.params.get("enable_xla", True)
        associative_scan_reductions = harness.params.get(
            "associative_scan_reductions", False)
        with jax.jax2tf_associative_scan_reductions(
                associative_scan_reductions):
            self.ConvertAndCompare(func_jax,
                                   *args,
                                   limitations=limitations,
                                   enable_xla=enable_xla)

    def test_primitive_coverage(self):
        """Fail if there are JAX primitives that are not implemented."""
        # Harvest primitives from XLA translation tables
        all_primitives = (set(xla._translations)
                          | set(xla._backend_specific_translations["cpu"])
                          | set(xla._backend_specific_translations["gpu"])
                          | set(xla._backend_specific_translations["tpu"])
                          | set(mlir._lowerings)
                          | set(mlir._platform_specific_lowerings["cpu"])
                          | set(mlir._platform_specific_lowerings["gpu"])
                          | set(mlir._platform_specific_lowerings["tpu"]))

        tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(
            jax.experimental.jax2tf.jax2tf.tf_impl_with_avals)
        tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl)

        all_primitives = tuple(sorted(all_primitives, key=str))
        for p in all_primitives:
            if p.name == "axis_index":
                continue
            # TODO: Remove once tensorflow is 2.10.0 everywhere.
            if p.name == "optimization_barrier":
                continue
            if p.name in tf_not_yet_impl:
                self.assertNotIn(
                    p, tf_impl
                )  # Should not be in both tf_impl and tf_not_yet_impl
            else:
                self.assertIn(p, tf_impl)

    def test_generate_limitations_doc(self):
        """Generates primitives_with_limited_support.md.

        See the doc for instructions.
        """

        harnesses = [
            h for h in primitive_harness.all_harnesses
            if h.filter(h, include_jax_unimpl=True)
        ]
        print(f"Found {len(harnesses)} test harnesses that work in JAX")

        def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
            return (h.group_name, l.description, l.devices,
                    tuple(np.dtype(d).name for d in l.dtypes), l.modes)

        unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                            Jax2TfLimitation]] = {}
        for h in harnesses:
            for l in h.jax_unimplemented:
                if l.enabled:
                    # Fake a Jax2TFLimitation from the Limitation
                    tfl = Jax2TfLimitation(
                        description="Not implemented in JAX: " + l.description,
                        devices=l.devices,
                        dtypes=l.dtypes,
                        expect_tf_error=False,
                        skip_tf_run=True)
                    unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
        for h in harnesses:
            for l in Jax2TfLimitation.limitations_for_harness(h):
                unique_limitations[hash(unique_hash(h, l))] = (h, l)

        print(f"Found {len(unique_limitations)} unique limitations")
        tf_error_table = [
            """
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |"""
        ]
        tf_numerical_discrepancies_table = list(tf_error_table)  # a copy
        for h, l in sorted(unique_limitations.values(),
                           key=lambda pair: unique_hash(*pair)):
            devices = ", ".join(sorted(l.devices))
            modes = ", ".join(sorted(l.modes))
            description = l.description
            if l.skip_comparison:
                description = "Numeric comparison disabled: " + description
            if l.expect_tf_error:
                description = "TF error: " + description
            if l.skip_tf_run:
                description = "TF test skipped: " + description

            if l.skip_tf_run or l.expect_tf_error:
                to_table = tf_error_table
            elif l.skip_comparison or l.custom_assert:
                to_table = tf_numerical_discrepancies_table
            else:
                continue

            to_table.append(
                f"| {h.group_name} | {description} | "
                f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | {devices} | {modes} |"
            )

        if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
            raise unittest.SkipTest(
                "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the documentation"
            )
        # The CPU has more supported types and harnesses, so generate the doc on CPU
        self.assertEqual("cpu", jtu.device_under_test())
        self.assertTrue(
            config.x64_enabled,
            "Documentation generation must be run with JAX_ENABLE_X64=1")

        with open(
                os.path.join(
                    os.path.dirname(__file__),
                    "../g3doc/primitives_with_limited_support.md.template")
        ) as f:
            template = f.read()
        output_file = os.path.join(
            os.path.dirname(__file__),
            "../g3doc/primitives_with_limited_support.md")

        with open(output_file, "w") as f:
            f.write(template.replace("{{generation_date}}", str(datetime.date.today())) \
                    .replace("{{tf_error_table}}", "\n".join(tf_error_table)) \
                    .replace("{{tf_numerical_discrepancies_table}}", "\n".join(tf_numerical_discrepancies_table)) \
                    )

    # The rest of the tests check special cases

    @parameterized.named_parameters(
        dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in [
            jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less,
            jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal,
            jnp.not_equal, jnp.maximum, jnp.minimum
        ])
    def test_type_promotion(self, f_jax=jnp.add):
        # We only test a few types here, as tensorflow does not support many
        # types like uint* or bool in binary ops.
        types = [dtypes.bfloat16, np.int32, np.int64, np.float32]
        for x_dtype in types:
            for y_dtype in types:
                x = np.array([1, 2], dtype=x_dtype)
                y = np.array([3, 4], dtype=y_dtype)
                self.ConvertAndCompare(f_jax, x, y)

    def test_integer_div(self):
        x = jnp.array([-4, -3, -1, 0, 1, 3, 6])
        y = np.int32(3)
        self.ConvertAndCompare(jnp.floor_divide, x, y)
        expected = jnp.floor_divide(x, y)
        # Try it with TF 1 as well (#5831)
        with tf.compat.v1.Session() as sess:
            tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y))
            self.assertAllClose(expected, tf1_res)

    def test_boolean_gather(self):
        values = np.array([[True, True], [False, True], [False, False]],
                          dtype=np.bool_)
        indices = np.array([0, 1], dtype=np.int32)
        for axis in [0, 1]:
            f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
            self.ConvertAndCompare(f_jax, values, indices)

    def test_gather_rank_change(self):
        params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
        indices = jnp.array([[1, 1, 2], [0, 1, 0]])
        f_jax = jax.jit(lambda i: params[i])
        self.ConvertAndCompare(f_jax, indices)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in REDUCE))
    def test_reduce_ops_with_numerical_input(self, f_jax):
        values = np.array([1, 2, 3], dtype=np.float32)
        self.ConvertAndCompare(f_jax, values)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{op}", op=op)
            for op in ("add", "max", "min", "multiply", "set")))
    def test_scatter_static(self, op):
        values = np.ones((5, 6), dtype=np.float32)
        update = np.float32(6.)
        f_jax = jax.jit(lambda v, u: getattr(v.at[::2, 3:], op)(u))
        self.ConvertAndCompare(f_jax, values, update)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
            for f_jax in REDUCE))
    def test_reduce_ops_with_boolean_input(self, f_jax):
        values = np.array([True, False, True], dtype=np.bool_)
        self.ConvertAndCompare(f_jax, values)
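
Note: the ConvertAndCompare calls above ultimately rest on jax2tf.convert, which turns a JAX callable into a TF-callable function. A minimal standalone sketch, assuming TensorFlow is installed, mirroring the floor_divide round-trip tested earlier:

# Illustrative sketch, not part of the test file.
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

x = jnp.array([-4, -3, -1, 0, 1, 3, 6])
y = np.int32(3)

tf_floor_divide = jax2tf.convert(jnp.floor_divide)   # TF-callable version of the JAX function
np.testing.assert_allclose(np.asarray(jnp.floor_divide(x, y)),
                           tf_floor_divide(x, y).numpy())
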
Example #18
0
class StaxTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape
        } for shape in [(2, 3), (5, )]))
    def testRandnInitShape(self, shape):
        key = random.PRNGKey(0)
        out = stax.randn()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(shape),
            "shape": shape
        } for shape in [(2, 3), (2, 3, 4)]))
    def testGlorotInitShape(self, shape):
        key = random.PRNGKey(0)
        out = stax.glorot()(key, shape)
        self.assertEqual(out.shape, shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3)]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 10, 11, 1)]))
    def testConvShape(self, channels, filter_shape, padding, strides,
                      input_shape):
        init_fun, apply_fun = stax.Conv(channels,
                                        filter_shape,
                                        strides=strides,
                                        padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, 1), (2, 3), (3, 3)]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (2, 1), (2, 2)]
                            for input_shape in [(2, 10, 11, 1)]))
    def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
        init_fun, apply_fun = stax.ConvTranspose(
            channels,
            filter_shape,  # 2D
            strides=strides,
            padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
            .format(channels, filter_shape, padding, strides, input_shape),
            "channels":
            channels,
            "filter_shape":
            filter_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape
        } for channels in [2, 3] for filter_shape in [(1, ), (2, ), (3, )]
                            for padding in ["SAME", "VALID"]
                            for strides in [None, (1, ), (2, )]
                            for input_shape in [(2, 10, 1)]))
    def testConv1DTransposeShape(self, channels, filter_shape, padding,
                                 strides, input_shape):
        init_fun, apply_fun = stax.Conv1DTranspose(channels,
                                                   filter_shape,
                                                   strides=strides,
                                                   padding=padding)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_out_dim={}_input_shape={}".format(out_dim, input_shape),
            "out_dim":
            out_dim,
            "input_shape":
            input_shape
        } for out_dim in [3, 4] for input_shape in [(2, 3), (3, 4)]))
    def testDenseShape(self, out_dim, input_shape):
        init_fun, apply_fun = stax.Dense(out_dim)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_input_shape={}_nonlinear={}".format(input_shape, nonlinear),
                "input_shape":
                input_shape,
                "nonlinear":
                nonlinear
            } for input_shape in [(2, 3), (2, 3, 4)]
            for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
    def testNonlinearShape(self, input_shape, nonlinear):
        init_fun, apply_fun = getattr(stax, nonlinear)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_window_shape={}_padding={}_strides={}_input_shape={}"
            "_maxpool={}_spec={}".format(window_shape, padding, strides,
                                         input_shape, max_pool, spec),
            "window_shape":
            window_shape,
            "padding":
            padding,
            "strides":
            strides,
            "input_shape":
            input_shape,
            "max_pool":
            max_pool,
            "spec":
            spec
        } for window_shape in [(1, 1), (2, 3)] for padding in ["VALID"]
                            for strides in [None, (2, 1)]
                            for input_shape in [(2, 5, 6, 4)]
                            for max_pool in [False, True]
                            for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
    def testPoolingShape(self, window_shape, padding, strides, input_shape,
                         max_pool, spec):
        layer = stax.MaxPool if max_pool else stax.AvgPool
        init_fun, apply_fun = layer(window_shape,
                                    padding=padding,
                                    strides=strides,
                                    spec=spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_shape={}".format(input_shape),
            "input_shape": input_shape
        } for input_shape in [(2, 3), (2, 3, 4)]))
    def testFlattenShape(self, input_shape):
        init_fun, apply_fun = stax.Flatten
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}_spec={}".format(
                    input_shape, i),
                "input_shape": input_shape,
                "spec": spec
            } for input_shape in [(2, 5, 6, 1)]
            for i, spec in enumerate([[stax.Conv(3, (
                2, 2))], [stax.Conv(3, (2, 2)), stax.Flatten,
                          stax.Dense(4)]])))
    def testSerialComposeLayersShape(self, input_shape, spec):
        init_fun, apply_fun = stax.serial(*spec)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testDropoutShape(self, input_shape):
        init_fun, apply_fun = stax.Dropout(0.9)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_input_shape={}".format(input_shape),
                "input_shape": input_shape
            } for input_shape in [(3, 4), (2, 5, 6, 1)]))
    def testFanInSum(self, input_shape):
        init_fun, apply_fun = stax.FanInSum
        _CheckShapeAgreement(self, init_fun, apply_fun,
                             [input_shape, input_shape])

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_inshapes={}_axis={}".format(
                    input_shapes, axis),
                "input_shapes": input_shapes,
                "axis": axis
            } for input_shapes, axis in [
                ([(2, 3), (2, 1)], 1),
                ([(2, 3), (2, 1)], -1),
                ([(1, 2, 4), (1, 1, 4)], 1),
            ]))
    def testFanInConcat(self, input_shapes, axis):
        init_fun, apply_fun = stax.FanInConcat(axis)
        _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

    def testIssue182(self):
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.Softmax
        input_shape = (10, 3)
        inputs = np.arange(30.).astype("float32").reshape(input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        assert out_shape == out.shape
        assert np.allclose(np.sum(np.asarray(out), -1), 1.)

    def testBatchNormNoScaleOrCenter(self):
        key = random.PRNGKey(0)
        axes = (0, 1, 2)
        init_fun, apply_fun = stax.BatchNorm(axis=axes,
                                             center=False,
                                             scale=False)
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(np.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)
        means = np.mean(out, axis=(0, 1, 2))
        std_devs = np.std(out, axis=(0, 1, 2))
        assert np.allclose(means, np.zeros_like(means), atol=1e-4)
        assert np.allclose(std_devs, np.ones_like(std_devs), atol=1e-4)

    def testBatchNormShapeNHWC(self):
        key = random.PRNGKey(0)
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(np.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (7, ))
        self.assertEqual(gamma.shape, (7, ))
        self.assertEqual(out_shape, out.shape)

    def testBatchNormShapeNCHW(self):
        key = random.PRNGKey(0)
        # Regression test for https://github.com/google/jax/issues/461
        init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
        input_shape = (4, 5, 6, 7)
        inputs = random_inputs(np.random.RandomState(0), input_shape)

        out_shape, params = init_fun(key, input_shape)
        out = apply_fun(params, inputs)

        self.assertEqual(out_shape, input_shape)
        beta, gamma = params
        self.assertEqual(beta.shape, (5, ))
        self.assertEqual(gamma.shape, (5, ))
        self.assertEqual(out_shape, out.shape)
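The parameter shapes asserted in the two BatchNorm tests above follow one rule: beta and gamma keep only the input dimensions that are not listed in `axis`. A minimal sketch of that rule (the helper below is illustrative only, not part of the test file):

def batchnorm_param_shape(input_shape, axis):
    # Keep the dimensions that are not normalized over; stax.BatchNorm gives
    # beta and gamma this shape.
    return tuple(d for i, d in enumerate(input_shape) if i not in axis)

assert batchnorm_param_shape((4, 5, 6, 7), axis=(0, 1, 2)) == (7,)  # NHWC case above
assert batchnorm_param_shape((4, 5, 6, 7), axis=(0, 2, 3)) == (5,)  # NCHW case above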
Example #19
0
class ImageTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method, "antialias": antialias}
       for dtype in float_dtypes
       for target_shape, image_shape in itertools.combinations_with_replacement(
        [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
       for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"]
       for antialias in [False, True]))
  @unittest.skipIf(not tf, "Test requires TensorFlow")
  def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method,
                                  antialias):
    # TODO(phawkins): debug this. There is a small mismatch between TF and JAX
    # for some cases of non-antialiased bicubic downscaling; we would expect
    # exact equality.
    if method == "bicubic" and any(x < y for x, y in
                                   zip(target_shape, image_shape)):
      raise unittest.SkipTest("non-antialiased bicubic downscaling mismatch")
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    def tf_fn(x):
      out = tf.image.resize(
        x.astype(np.float64), tf.constant(target_shape[1:-1]),
        method=method, antialias=antialias).numpy().astype(dtype)
      return out
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
    self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True,
                            tol={np.float16: 2e-2, jnp.bfloat16: 1e-1,
                                 np.float32: 1e-4, np.float64: 1e-4})


  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method}
       for dtype in [np.float32]
       for target_shape, image_shape in itertools.combinations_with_replacement(
        [[3, 2], [6, 4], [33, 17], [50, 39]], 2)
       for method in ["nearest", "bilinear", "lanczos3", "bicubic"]))
  @unittest.skipIf(not PIL_Image, "Test requires PIL")
  def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
    rng = jtu.rand_uniform(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    def pil_fn(x):
      pil_methods = {
        "nearest": PIL_Image.NEAREST,
        "bilinear": PIL_Image.BILINEAR,
        "bicubic": PIL_Image.BICUBIC,
        "lanczos3": PIL_Image.LANCZOS,
      }
      img = PIL_Image.fromarray(x.astype(np.float32))
      out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]),
                       dtype=dtype)
      return out
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=True)
    self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method),
        "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape,
        "method": method}
       for dtype in inexact_dtypes
       for image_shape, target_shape in [
         ([3, 1, 2], [6, 1, 4]),
         ([1, 3, 2, 1], [1, 6, 4, 1]),
       ]
       for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]))
  def testResizeUp(self, dtype, image_shape, target_shape, method):
    data = [64, 32, 32, 64, 50, 100]
    expected_data = {}
    expected_data["nearest"] = [
        64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0, 64.0,
        32.0, 32.0, 64.0, 64.0, 50.0, 50.0, 100.0, 100.0, 50.0, 50.0, 100.0,
        100.0
    ]
    expected_data["linear"] = [
        64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0, 56.0,
        36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0, 62.5,
        87.5, 100.0
    ]
    expected_data["lanczos3"] = [
        75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454, 31.964,
        35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769, 67.1244, 85.5045,
        35.7939, 56.4713, 83.5243, 104.2017, 44.8138, 65.1949, 91.8603, 112.2413
    ]
    expected_data["lanczos5"] = [
        77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593, 28.9709,
        35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299, 67.7223, 89.658,
        32.1213, 56.784, 83.984, 108.6467, 44.5802, 66.183, 90.0082, 111.6109
    ]
    expected_data["cubic"] = [
        70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789, 35.4981,
        36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151, 65.8032, 77.731,
        41.6492, 55.823, 83.9288, 98.1026, 47.0363, 62.2744, 92.4903, 107.7284
    ]
    x = np.array(data, dtype=dtype).reshape(image_shape)
    output = image.resize(x, target_shape, method)
    expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=1e-04)

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method, "antialias": antialias}
       for dtype in [np.float32]
       for target_shape, image_shape in itertools.combinations_with_replacement(
        [[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2)
       for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"]
       for antialias in [False, True]))
  def testResizeGradients(self, dtype, image_shape, target_shape, method,
                           antialias):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: (rng(image_shape, dtype),)
    jax_fn = partial(image.resize, shape=target_shape, method=method,
                     antialias=antialias)
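    # check_grads compares forward- and reverse-mode derivatives (up to second
    # order) against numerical differences.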
    jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_shape={}_target={}_method={}_antialias={}".format(
          jtu.format_shape_dtype_string(image_shape, dtype),
          jtu.format_shape_dtype_string(target_shape, dtype), method,
          antialias),
        "dtype": dtype, "image_shape": image_shape,
        "target_shape": target_shape,
        "method": method, "antialias": antialias}
       for dtype in [np.float32]
       for image_shape, target_shape in [
         ([1], [0]),
         ([5, 5], [5, 0]),
         ([5, 5], [0, 1]),
         ([5, 5], [0, 0])
       ]
       for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]
       for antialias in [False, True]))
  def testResizeEmpty(self, dtype, image_shape, target_shape, method, antialias):
    # Regression test for https://github.com/google/jax/issues/7586
    image = np.ones(image_shape, dtype)
    out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias)
    self.assertArraysEqual(out, jnp.zeros(target_shape, dtype))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_target={}_method={}".format(
         jtu.format_shape_dtype_string(image_shape, dtype),
         jtu.format_shape_dtype_string(target_shape, dtype), method),
       "dtype": dtype, "image_shape": image_shape,
       "target_shape": target_shape,
       "scale": scale, "translation": translation, "method": method}
      for dtype in inexact_dtypes
      for image_shape, target_shape, scale, translation in [
        ([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]),
        ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0], [0.0, 1.0, -1.0, 0.0])]
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]))
  def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale,
                              translation, method):
    data = [64, 32, 32, 64, 50, 100]
    # Note zeros occur in the output because the sampling location is outside
    # the boundaries of the input image.
    expected_data = {}
    expected_data["linear"] = [
        0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0, 44.0,
        52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625, 91.0, 0.0
    ]
    expected_data["lanczos3"] = [
        0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454,
        31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244, 85.5045,
        0.0, 56.4713, 83.5243, 104.2017, 0.0
    ]
    expected_data["lanczos5"] = [
        0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369, 39.5593,
        28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299, 67.7223, 89.658,
        0.0, 56.784, 83.984, 108.6467, 0.0
    ]
    expected_data["cubic"] = [
        0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386, 41.4789,
        35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151, 65.8032, 77.731,
        0.0, 55.823, 83.9288, 98.1026, 0.0
    ]
    x = np.array(data, dtype=dtype).reshape(image_shape)
    # Should we test different float types here?
    scale_a = jnp.array(scale, dtype=jnp.float32)
    translation_a = jnp.array(translation, dtype=jnp.float32)
    output = image.scale_and_translate(x, target_shape, range(len(image_shape)),
                                       scale_a, translation_a,
                                       method)

    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_method={}_antialias={}".format(
         jtu.dtype_str(dtype), method, antialias),
       "dtype": dtype, "method": method, "antialias": antialias}
      for dtype in inexact_dtypes
      for method in ["linear", "lanczos3", "lanczos5", "cubic"]
      for antialias in [True, False]))
  def testScaleAndTranslateDown(self, dtype, method, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]
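    # As in the upsampling test above, zeros appear in the expected output where
    # the scaled/translated sampling locations fall outside the input image.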
    if antialias:
      expected_data = {}
      expected_data["linear"] = [
          43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
      ]
      expected_data["lanczos3"] = [
          43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0
      ]
      expected_data["lanczos5"] = [
          43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0
      ]
      expected_data["cubic"] = [
          42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0
      ]
    else:
      expected_data = {}
      expected_data["linear"] = [
          43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0
      ]
      expected_data["lanczos3"] = [
          44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0
      ]
      expected_data["lanczos5"] = [
          44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0
      ]
      expected_data["cubic"] = [
          43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0
      ]
    x = np.array(data, dtype=dtype).reshape(image_shape)

    expected = np.array(
        expected_data[method], dtype=dtype).reshape(target_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    output = image.scale_and_translate(
        x, target_shape, (0,1,2,3),
        scale_a, translation_a, method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)

    # Test that running over just the subset of dimensions with non-trivial
    # scale and translation produces the same result.
    output = image.scale_and_translate(
        x, target_shape, (1,2),
        scale_a[1:3], translation_a[1:3], method, antialias=antialias)
    self.assertAllClose(output, expected, atol=2e-03)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateJITs(self, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]
    if antialias:
      expected_data = [
          43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0
      ]
    else:
      expected_data = [43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0]
    x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)

    expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def jit_fn(in_array, s, t):
      return jax.image.scale_and_translate(
          in_array, target_shape, (0, 1, 2, 3), s, t,
          "linear", antialias, precision=jax.lax.Precision.HIGHEST)

    output = jax.jit(jit_fn)(x, scale_a, translation_a)
    self.assertAllClose(output, expected, atol=2e-03)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "antialias={}".format(antialias),
       "antialias": antialias}
      for antialias in [True, False]))
  def testScaleAndTranslateGradFinite(self, antialias):
    image_shape = [1, 6, 7, 1]
    target_shape = [1, 3, 3, 1]

    data = [
        51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92,
        41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89,
        71, 32, 23, 23, 35, 93
    ]

    x = jnp.array(data, dtype=jnp.float32).reshape(image_shape)
    scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32)
    translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32)

    def scale_fn(s):
      return jnp.sum(jax.image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias,
        precision=jax.lax.Precision.HIGHEST))

    scale_out = jax.grad(scale_fn)(scale_a)
    self.assertTrue(jnp.all(jnp.isfinite(scale_out)))

    def translate_fn(t):
      return jnp.sum(jax.image.scale_and_translate(
        x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias,
        precision=jax.lax.Precision.HIGHEST))

    translate_out = jax.grad(translate_fn)(translation_a)
    self.assertTrue(jnp.all(jnp.isfinite(translate_out)))


  def testResizeWithUnusualShapes(self):
    x = jnp.ones((3, 4))
    # Array shapes are accepted
    self.assertEqual((10, 17),
                     jax.image.resize(x, jnp.array((10, 17)), "nearest").shape)
    with self.assertRaises(TypeError):
      # Fractional shapes are disallowed
      jax.image.resize(x, [10.5, 17], "bicubic")
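For orientation, the API exercised by this test class can also be used standalone; a minimal sketch (only jax and jax.numpy are assumed to be available):

import jax
import jax.numpy as jnp

x = jnp.arange(6., dtype=jnp.float32).reshape(2, 3)
# Upsample by 2x along both dimensions with the "linear" kernel, as in the tests above.
y = jax.image.resize(x, shape=(4, 6), method="linear")
assert y.shape == (4, 6)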
Example #20
0
class SparsifyTest(jtu.JaxTestCase):
    @classmethod
    def sparsify(cls, f):
        return sparsify(f, use_tracer=False)

    def testTracerIsInstanceCheck(self):
        @self.sparsify
        def f(x):
            self.assertNotIsInstance(x, SparseTracer)

        f(jnp.arange(5))

    def assertBcooIdentical(self, x, y):
        self.assertIsInstance(x, BCOO)
        self.assertIsInstance(y, BCOO)
        self.assertEqual(x.shape, y.shape)
        self.assertArraysEqual(x.data, y.data)
        self.assertArraysEqual(x.indices, y.indices)

    def testArgSpec(self):
        X = jnp.arange(5)
        X_BCOO = BCOO.fromdense(X)

        args = (X, X_BCOO, X_BCOO)

        # Independent index
        spenv = SparseEnv()
        argspecs = arrays_to_argspecs(spenv, args)
        self.assertEqual(len(argspecs), len(args))
        self.assertEqual(spenv.size(), 5)
        self.assertEqual(
            argspecs,
            (ArgSpec(X.shape, 0, None), ArgSpec(X.shape, 1, 2), ArgSpec(X.shape, 3, 4)))

        args_out = argspecs_to_arrays(spenv, argspecs)
        self.assertEqual(len(args_out), len(args))
        self.assertArraysEqual(args[0], args_out[0])
        self.assertBcooIdentical(args[1], args_out[1])
        self.assertBcooIdentical(args[2], args_out[2])

        # Shared index
        argspecs = (ArgSpec(X.shape, 0, None),
                    ArgSpec(X.shape, 1, 2),
                    ArgSpec(X.shape, 3, 2))
        spenv = SparseEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

        args_out = argspecs_to_arrays(spenv, argspecs)
        self.assertEqual(len(args_out), len(args))
        self.assertArraysEqual(args[0], args_out[0])
        self.assertBcooIdentical(args[1], args_out[1])
        self.assertBcooIdentical(args[2], args_out[2])

    def testUnitHandling(self):
        x = BCOO.fromdense(jnp.arange(5))
        f = jit(lambda x, y: x)
        result = self.sparsify(jit(f))(x, core.unit)
        self.assertBcooIdentical(result, x)

    def testDropvar(self):
        def inner(x):
            return x * 2, x * 3

        def f(x):
            _, y = jit(inner)(x)
            return y * 4

        x_dense = jnp.arange(5)
        x_sparse = BCOO.fromdense(x_dense)
        self.assertArraysEqual(
            self.sparsify(f)(x_sparse).todense(), f(x_dense))

    def testPytreeInput(self):
        f = self.sparsify(lambda x: x)
        args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
        out = f(args)
        self.assertLen(out, 2)
        self.assertArraysEqual(args[0], out[0])
        self.assertBcooIdentical(args[1], out[1])

    def testSparsify(self):
        M_dense = jnp.arange(24).reshape(4, 6)
        M_sparse = BCOO.fromdense(M_dense)
        v = jnp.arange(M_dense.shape[0])

        @self.sparsify
        def func(x, v):
            return -jnp.sin(jnp.pi * x).T @ (v + 1)

        result_dense = func(M_dense, v)
        result_sparse = func(M_sparse, v)

        self.assertAllClose(result_sparse, result_dense)

    def testSparsifyWithConsts(self):
        M_dense = jnp.arange(24).reshape(4, 6)
        M_sparse = BCOO.fromdense(M_dense)

        @self.sparsify
        def func(x):
            return jit(lambda x: jnp.sum(x, 1))(x)

        result_dense = func(M_dense)
        result_sparse = func(M_sparse)

        self.assertAllClose(result_sparse.todense(), result_dense)

    def testSparseMatmul(self):
        X = jnp.arange(16).reshape(4, 4)
        Xsp = BCOO.fromdense(X)
        Y = jnp.ones(4)
        Ysp = BCOO.fromdense(Y)

        # dot_general
        result_sparse = self.sparsify(operator.matmul)(Xsp, Y)
        result_dense = operator.matmul(X, Y)
        self.assertAllClose(result_sparse, result_dense)

        # rdot_general
        result_sparse = self.sparsify(operator.matmul)(Y, Xsp)
        result_dense = operator.matmul(Y, X)
        self.assertAllClose(result_sparse, result_dense)

        # spdot_general
        result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
        result_dense = operator.matmul(X, Y)
        self.assertAllClose(result_sparse.todense(), result_dense)

    def testSparseAdd(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Distinct indices
        out = self.sparsify(operator.add)(x, y)
        self.assertEqual(out.nse, 8)  # uses concatenation.
        self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

        # Shared indices – requires lower level call
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]
        spenv = SparseEnv([x.indices, x.data, y.data])

        result = sparsify_raw(operator.add)(spenv, *argspecs)
        args_out, _ = result
        out, = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() + y.todense())

    def testSparseMul(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Scalar multiplication
        out = self.sparsify(operator.mul)(x, 2.5)
        self.assertArraysEqual(out.todense(), x.todense() * 2.5)

        # Shared indices – requires lower level call
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]
        spenv = SparseEnv([x.indices, x.data, y.data])

        result = sparsify_raw(operator.mul)(spenv, *argspecs)
        args_out, _ = result
        out, = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() * y.todense())

    def testSparseSubtract(self):
        x = BCOO.fromdense(3 * jnp.arange(5))
        y = BCOO.fromdense(jnp.arange(5))

        # Distinct indices
        out = self.sparsify(operator.sub)(x, y)
        self.assertEqual(out.nse, 8)  # uses concatenation.
        self.assertArraysEqual(out.todense(), 2 * jnp.arange(5))

        # Shared indices – requires lower level call
        argspecs = [ArgSpec(x.shape, 1, 0), ArgSpec(y.shape, 2, 0)]
        spenv = SparseEnv([x.indices, x.data, y.data])

        result = sparsify_raw(operator.sub)(spenv, *argspecs)
        args_out, _ = result
        out, = argspecs_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() - y.todense())

    def testSparseSum(self):
        x = jnp.arange(20).reshape(4, 5)
        xsp = BCOO.fromdense(x)

        def f(x):
            return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

        result_dense = f(x)
        result_sparse = self.sparsify(f)(xsp)

        assert len(result_dense) == len(result_sparse)

        for res_dense, res_sparse in zip(result_dense, result_sparse):
            if isinstance(res_sparse, BCOO):
                res_sparse = res_sparse.todense()
            self.assertArraysAllClose(res_dense, res_sparse)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
                 jtu.format_shape_dtype_string(shape, np.float32), dimensions,
                 n_batch, n_dense),
             "shape": shape,
             "dimensions": dimensions,
             "n_batch": n_batch,
             "n_dense": n_dense}
            for shape, dimensions in [
                [(1,), (0,)],
                [(1,), (-1,)],
                [(2, 1, 4), (1,)],
                [(2, 1, 3, 1), (1,)],
                [(2, 1, 3, 1), (1, 3)],
                [(2, 1, 3, 1), (3,)],
            ]
            for n_batch in range(len(shape) + 1)
            for n_dense in range(len(shape) - n_batch + 1)))
    def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
        rng = jtu.rand_default(self.rng())

        M_dense = rng(shape, np.float32)
        M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
        func = self.sparsify(partial(lax.squeeze, dimensions=dimensions))

        result_dense = func(M_dense)
        result_sparse = func(M_sparse).todense()

        self.assertAllClose(result_sparse, result_dense)

    def testSparseWhileLoop(self):
        def cond_fun(params):
            i, A = params
            return i < 5

        def body_fun(params):
            i, A = params
            return i + 1, 2 * A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A))

        A = jnp.arange(4)
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 2)
        self.assertEqual(len(out_sparse), 2)
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

    def testSparseWhileLoopDuplicateIndices(self):
        def cond_fun(params):
            i, A, B = params
            return i < 5

        def body_fun(params):
            i, A, B = params
            # TODO(jakevdp): track shared indices through while loop & use this
            #   version of the test, which requires shared indices in order for
            #   the nse of the result to remain the same.
            # return i + 1, A, A + B

            # This version is fine without shared indices, and tests that we're
            # flattening non-shared indices consistently.
            return i + 1, B, A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A, A))

        A = jnp.arange(4).reshape((2, 2))
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 3)
        self.assertEqual(len(out_sparse), 3)
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
        self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

    def testSparsifyDenseXlaCall(self):
        # Test handling of dense xla_call within jaxpr interpreter.
        out = self.sparsify(jit(lambda x: x + 1))(0.0)
        self.assertEqual(out, 1.0)

    def testSparsifySparseXlaCall(self):
        # Test sparse lowering of XLA call
        def func(M):
            return 2 * M

        M = jnp.arange(6).reshape(2, 3)
        Msp = BCOO.fromdense(M)

        out_dense = func(M)
        out_sparse = self.sparsify(jit(func))(Msp)
        self.assertArraysEqual(out_dense, out_sparse.todense())

    def testSparseForiLoop(self):
        def func(M, x):
            body_fun = lambda i, val: (M @ val) / M.shape[1]
            return lax.fori_loop(0, 2, body_fun, x)

        x = jnp.arange(5.0)
        M = jnp.arange(25).reshape(5, 5)
        M_bcoo = BCOO.fromdense(M)

        result_dense = func(M, x)
        result_sparse = self.sparsify(func)(M_bcoo, x)

        self.assertArraysAllClose(result_dense, result_sparse)

    def testSparseCondSimple(self):
        def func(x):
            return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

        x = jnp.arange(5.0)
        result_dense = func(x)

        x_bcoo = BCOO.fromdense(x)
        result_sparse = self.sparsify(func)(x_bcoo)

        self.assertArraysAllClose(result_dense, result_sparse.todense())

    def testSparseCondMismatchError(self):
        @self.sparsify
        def func(x, y):
            return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

        x = jnp.arange(5.0)
        y = jnp.arange(5.0)

        x_bcoo = BCOO.fromdense(x)
        y_bcoo = BCOO.fromdense(y)

        func(x, y)  # No error
        func(x_bcoo, y_bcoo)  # No error

        with self.assertRaisesRegex(
                TypeError, "sparsified true_fun and false_fun output.*"):
            func(x_bcoo, y)

    def testToDense(self):
        M = jnp.arange(4)
        Msp = BCOO.fromdense(M)

        @self.sparsify
        def func(M):
            return todense(M) + 1

        self.assertArraysEqual(func(M), M + 1)
        self.assertArraysEqual(func(Msp), M + 1)
        self.assertArraysEqual(jit(func)(M), M + 1)
        self.assertArraysEqual(jit(func)(Msp), M + 1)

    def testWeakTypes(self):
        # Regression test for https://github.com/google/jax/issues/8267
        M = jnp.arange(12, dtype='int32').reshape(3, 4)
        Msp = BCOO.fromdense(M)
        self.assertArraysEqual(
            operator.mul(2, M),
            self.sparsify(operator.mul)(2, Msp).todense(),
            check_dtypes=True,
        )
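As a quick reference for the sparse utilities exercised above: BCOO arrays round-trip through fromdense/todense, and sparsify lets a dense-style function accept BCOO inputs. A minimal sketch (the public import path for sparsify is assumed to be jax.experimental.sparse):

import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.arange(6.).reshape(2, 3)
Msp = sparse.BCOO.fromdense(M)   # .data holds the nonzero values, .indices their coordinates
assert (Msp.todense() == M).all()

f = sparse.sparsify(lambda x: 2 * x)   # written as if x were dense
assert (f(Msp).todense() == 2 * M).all()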
Example #21
0
class Jax2TfTest(tf_test_util.JaxToTfTestCase):

  def test_basics(self):
    f_jax = lambda x: jnp.sin(jnp.cos(x))
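    # ConvertAndCompare (from the JaxToTfTestCase base class) converts f_jax with
    # jax2tf.convert, evaluates both versions on the argument, and checks agreement.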
    _, res_tf = self.ConvertAndCompare(f_jax, 0.7)

  def test_input_output_naming(self):
    @jax2tf.convert
    def f(xs, y):
      return [jnp.add(x, y) for x in xs]

    @tf.function(autograph=False)
    def u(xs, y):
      xs = tf.nest.map_structure(tf.convert_to_tensor, xs)
      with tf.GradientTape() as tape:
        tf.nest.map_structure(tape.watch, xs)
        y = f(xs, y)
        tape.gradient(y, xs)
        return y

    cf = u.get_concrete_function([1., 2., 3.], 4.)
    g = cf.graph
    g.get_operation_by_name("jax2tf_arg_0")
    g.get_operation_by_name("jax2tf_arg_1")
    g.get_operation_by_name("jax2tf_arg_2")
    g.get_operation_by_name("jax2tf_arg_3")
    g.get_operation_by_name("jax2tf_out")
    g.get_operation_by_name("jax2tf_out_1")
    g.get_operation_by_name("jax2tf_out_2")
    with self.assertRaises(KeyError):
      g.get_operation_by_name("jax2tf_arg_4")
    with self.assertRaises(KeyError):
      g.get_operation_by_name("jax2tf_out_3")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_0")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_2")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_3")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_out")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_out_1")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_out_2")
    g.get_operation_by_name("jax2tf_vjp/jax2tf_out_3")

  def test_pytrees(self):
    # Take and return pytrees
    def f_jax(x: Tuple[float, Dict[str, float]]) -> Tuple[float, Dict[str, float]]:
      x_a, x_dict = x
      return x_a * 2., {k: v * 3. for k, v in x_dict.items()}

    x = (.7, {"a": .8, "b": .9})
    self.ConvertAndCompare(f_jax, x)

  def test_variable_input(self):
    f_jax = lambda x: jnp.sin(jnp.cos(x))
    f_tf = jax2tf.convert(f_jax)
    v = tf.Variable(0.7, dtype=jax2tf.dtype_of_val(0.7))
    self.assertIsInstance(f_tf(v), tf.Tensor)
    self.assertAllClose(f_jax(0.7), f_tf(v))

  def test_jit(self):
    f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
    self.ConvertAndCompare(f_jax, 0.7)

  def test_nested_jit(self):
    f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x)))
    f_tf = jax2tf.convert(f_jax)
    np.testing.assert_allclose(f_jax(0.7), f_tf(0.7))

  def test_nested_jit_is_compiled(self):
    # Check that nested jax.jit calls are compiled with tf.function(jit_compile=True).
    # We do this by looking for the _XlaMustCompile attribute in the function graph.
    def has_xla_must_compile(f_tf, x):
      f_conc = tf.function(f_tf, autograph=True).get_concrete_function(tf.convert_to_tensor(x))
      for n in f_conc.graph._nodes_by_id.values():
        try:
          n.get_attr("_XlaMustCompile")
          return True
        except ValueError:
          continue
      return False

    x = np.array(0.7)
    f_no_jit = lambda x: x
    self.assertFalse(has_xla_must_compile(jax2tf.convert(f_no_jit), x))
    f_jit = lambda x: jax.jit(jnp.sin)(x)
    # TODO(b/207464757): TF compilation is disabled
    self.assertFalse(has_xla_must_compile(jax2tf.convert(f_jit), x))

  def test_converts_jax_arrays(self):
    f_tf = tf.function(lambda x: x)
    self.assertEqual(f_tf(jnp.zeros([])).numpy(), 0.)
    self.assertEqual(f_tf(jnp.ones([])).numpy(), 1.)
    f_tf = tf.function(lambda x: x + x)
    self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.)

    # Test with ShardedDeviceArray.
    n = jax.local_device_count()
    mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n]))
    f_tf = tf.function(lambda x: x)
    self.assertAllClose(f_tf(mk_sharded(jnp.zeros)).numpy(),
                        jnp.zeros([n]))
    self.assertAllClose(f_tf(mk_sharded(jnp.ones)).numpy(),
                        jnp.ones([n]))

  @jtu.skip_on_devices("gpu")
  def test_bfloat16_passed_by_tf(self):
    f_jax = lambda a, b: a + b
    f_tf = tf.function(jax2tf.convert(f_jax),
                       input_signature=[tf.TensorSpec([512, 512], tf.bfloat16),
                                        tf.TensorSpec([512, 512], tf.bfloat16)])
    self.assertIsNotNone(f_tf.get_concrete_function())

  @jtu.skip_on_devices("gpu")
  def test_bfloat16_returned_by_jax(self):
    f_jax = lambda a, b: (a + b).astype(jnp.bfloat16)
    f_tf = jax2tf.convert(f_jax)
    self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16)

  @jtu.skip_on_devices("gpu")
  def test_bfloat16_tf_grad(self):
    f_jax = lambda a, b: a + b

    def _tf_grad(a, b):
      with tf.GradientTape() as tape:
        tape.watch(a)
        result = jax2tf.convert(f_jax)(a, b)
      return result, tape.gradient(result, a)

    f_tf = tf.function(_tf_grad,
                       input_signature=[tf.TensorSpec([512, 512], tf.bfloat16),
                                        tf.TensorSpec([512, 512], tf.bfloat16)])

    self.assertIsNotNone(f_tf.get_concrete_function())

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_dtype={dtype.__name__}_function={with_function}",
         dtype=dtype,
         with_function=with_function)
    for dtype in [np.int64, np.float64]
    for with_function in [True, False]))
  def test_converts_64bit(self, dtype=np.int64, with_function=False):
    if not config.jax_enable_x64:
      self.skipTest("requires x64 mode")
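    # 2**33 does not fit in 32-bit integers, so this exercises genuine 64-bit
    # values flowing through the conversion.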
    big_const = np.full((5,), 2 ** 33, dtype=dtype)
    self.ConvertAndCompare(jnp.sin, big_const)
    f_conv = jax2tf.convert(jnp.sin)
    if with_function:
      f_conv = tf.function(f_conv)
    # We also check that we can pass a tf.Variable or a tf.Tensor into the
    # converted function
    self.assertAllClose(jnp.sin(big_const),
                        f_conv(tf.Variable(big_const)))
    self.assertAllClose(jnp.sin(big_const),
                        f_conv(tf.constant(big_const)))

  def test_64bit_behavior_enable_x64(self):
    if not config.jax_enable_x64:
      self.skipTest("requires x64 mode")

    # JAX and TF have different default float types if JAX_ENABLE_X64=1
    self.assertEqual(tf.math.sin(0.7).dtype, tf.float32)
    self.assertEqual(jnp.sin(0.7).dtype, jnp.float64)

    # jax2tf.convert has the same behavior as JAX
    self.assertEqual(jax2tf.convert(jnp.sin)(0.7).dtype, tf.float64)

  def test_64bit_behavior_not_enable_x64(self):
    if config.jax_enable_x64:
      self.skipTest("requires not x64 mode")

    # JAX and TF have the same default float types when JAX_ENABLE_X64 is not set
    self.assertEqual(tf.math.sin(0.7).dtype, tf.float32)
    self.assertEqual(jnp.sin(0.7).dtype, jnp.float32)

    # Except that JAX forces values to 32-bit
    self.assertEqual(jnp.sin(np.float64(0.7)).dtype, jnp.float32)

    # jax2tf.convert has the same behavior as JAX
    self.assertEqual(jax2tf.convert(jnp.sin)(0.7).dtype, tf.float32)
    self.assertEqual(jax2tf.convert(jnp.sin)(np.float64(0.7)).dtype, tf.float32)

  def test_function(self):
    f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
    self.ConvertAndCompare(f_jax, 0.7)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_gradients_disabled(self, with_function=False):
    f_tf = jax2tf.convert(jnp.tan, with_gradient=False)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    x = tf.ones([])

    # With tf.function the error is raised when we evaluate f_tf(x); in
    # eager mode it is raised when we evaluate tape.gradient(y, x)
    with self.assertRaisesRegex(LookupError,
                                "Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
      with tf.GradientTape() as tape:
        tape.watch(x)
        y = f_tf(x)
        _ = tape.gradient(y, x)

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"function={with_function}",
         with_function=with_function)
    for with_function in [False, True]))
  def test_gradients(self, with_function=True):
    def f(x, y):
      return x * x, x * y
    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    default_float_type = jax2tf.dtype_of_val(4.)
    x = tf.Variable(4., dtype=default_float_type)
    y = tf.Variable(5., dtype=default_float_type)
    with tf.GradientTape(persistent=True) as tape:
      u, v = f_tf(x, y)

    self.assertAllClose(2. * 4., tape.gradient(u, x))
    self.assertAllClose(0., tape.gradient(u, y))
    self.assertAllClose(5., tape.gradient(v, x))
    self.assertAllClose(4., tape.gradient(v, y))

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"function={with_function}",
         with_function=with_function)
    for with_function in [False, True]))
  def test_gradients_pytree(self, with_function=True):
    def f(xy: Tuple[float, float]) -> Dict[str, float]:
      x, y = xy
      return dict(one=x * x, two=x * y)

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    default_float_dtype = jax2tf.dtype_of_val(4.)
    x = tf.Variable(4., dtype=default_float_dtype)
    y = tf.Variable(5., dtype=default_float_dtype)
    with tf.GradientTape(persistent=True) as tape:
      uv = f_tf((x, y))

    self.assertAllClose(2. * 4., tape.gradient(uv["one"], x))
    self.assertAllClose(0., tape.gradient(uv["one"], y))
    self.assertAllClose(5., tape.gradient(uv["two"], x))
    self.assertAllClose(4., tape.gradient(uv["two"], y))

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"function={with_function}",
         with_function=with_function)
    for with_function in [False, True]))
  def test_gradients_with_ordered_dict_input(self, with_function=True):
    def f(inputs):
      out = 0.0
      for v in inputs.values():
        out += jnp.sum(v)
      return out

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    default_float_type = jax2tf.dtype_of_val(4.)
    x = tf.Variable([4.], dtype=default_float_type)
    y = tf.Variable([4., 5.], dtype=default_float_type)
    inputs = OrderedDict()
    inputs['r'] = x
    inputs['d'] = y
    with tf.GradientTape(persistent=True) as tape:
      u = f_tf(inputs)

    self.assertAllClose(np.array([1.]), tape.gradient(u, x).numpy())
    self.assertAllClose(np.array([1., 1.]), tape.gradient(u, y).numpy())

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_gradients_with_custom_jvp(self, with_function=True):
    """Check gradients, for a function with custom JVP."""
    @jax.custom_jvp
    def f(x):
      return x * x

    @f.defjvp
    def f_jvp(primals, tangents):
      # 3 * x * x_t
      x, = primals
      x_dot, = tangents
      primal_out = f(x)
      tangent_out = 3. * x * x_dot
      return primal_out, tangent_out

    self.assertAllClose(4. * 4., f(4.))
    self.assertAllClose(3. * 4., jax.grad(f)(4.))

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    self.assertAllClose(4. * 4., f_tf(4.))
    x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f_tf(x)

    self.assertAllClose(4. * 4., y)
    self.assertAllClose(3. * 4., tape.gradient(y, x))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_gradients_with_custom_vjp(self, with_function=True):
    """Check gradients, for a function with custom VJP."""
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), 3. * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)

    self.assertAllClose(4. * 4., f(4.))
    self.assertAllClose(3. * 4., jax.grad(f)(4.))

    f_tf = jax2tf.convert(f, with_gradient=True)
    if with_function:
      f_tf = tf.function(f_tf, autograph=False)
    self.assertAllClose(4. * 4., f_tf(4.))
    x = tf.Variable(4., dtype=jax2tf.dtype_of_val(4.))
    with tf.GradientTape() as tape:
      tape.watch(x)
      y = f_tf(x)

    self.assertAllClose(4. * 4., y)
    self.assertAllClose(3. * 4., tape.gradient(y, x))

  def test_gradient_with_float0_intermediate(self):
    # Gradient over integer-argument functions
    def f(x, y):  # x is an int, y is a float
      return 2 * x + y

    def g(x):  # x: f32
      return 2. * f(3 * x.astype("int32"), x * 4.)

    x = 2.
    grad_g = jax.grad(g)
    self.ConvertAndCompare(grad_g, x)

  def test_gradient_with_float0_result(self):
    # Gradient over integer-argument functions, with float0 result
    def f(x, y):  # x is an int, y is a float
      return 2 * x + y

    def g(x):  # x: i32
      return jnp.sum(2. * f(3 * x, 4. * jnp.array(x, jnp.dtype("float32"))))

    grad_g = jax.grad(g, allow_int=True)
    x = 2
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32),
                        d_dx_tf.numpy())

    shape = (3, 4)
    x = np.ones(shape, dtype=np.int32)
    d_dx_jax = grad_g(x)
    d_dx_tf = jax2tf.convert(grad_g)(x)
    self.assertEqual(d_dx_jax.dtype, dtypes.float0)
    self.assertAllClose(jnp.zeros(np.shape(d_dx_jax), np.int32),
                        d_dx_tf.numpy())

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_gradients_unused_argument_readme(self, with_function=True):
    # x1 and x3 are not used. x3 has integer type.
    def fn(x0, x1, x2, x3):
      return x0 * 0. + x2 * 2.

    xs = [tf.Variable(x) for x in [10., 11., 12., 13]]
    with tf.GradientTape(persistent=True) as tape:
      res = fn(*xs)

    g_tf_native = tape.gradient(res, xs)
    self.assertAllClose(g_tf_native[0].numpy(), np.float32(0.))
    self.assertIsNone(g_tf_native[1])
    self.assertAllClose(g_tf_native[2].numpy(), np.float32(2.))
    self.assertIsNone(g_tf_native[3])

    g_tf_native_0 = tape.gradient(res, xs,
                                  unconnected_gradients=tf.UnconnectedGradients.ZERO)
    self.assertAllClose(g_tf_native_0[0].numpy(), np.float32(0.))
    self.assertAllClose(g_tf_native_0[1].numpy(), np.float32(0.))
    self.assertAllClose(g_tf_native_0[2].numpy(), np.float32(2.))
    self.assertAllClose(g_tf_native_0[3].numpy(), np.int32(0))

    # Now with jax2tf.convert
    with tf.GradientTape(persistent=True) as tape:
      conv_fn = jax2tf.convert(fn, with_gradient=True)
      if with_function:
        conv_fn = tf.function(conv_fn, autograph=False)
      res = conv_fn(*xs)

    g_jax2tf = tape.gradient(res, xs)
    # Returns: 0., 0., 2., None
    # Note that the gradient for x1 is 0.
    self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
    self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
    self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
    self.assertIsNone(g_jax2tf[3])

    g_jax2tf = tape.gradient(res, xs,
                               unconnected_gradients=tf.UnconnectedGradients.ZERO)
    self.assertAllClose(g_jax2tf[0].numpy(), np.float32(0.))
    self.assertAllClose(g_jax2tf[1].numpy(), np.float32(0.))
    self.assertAllClose(g_jax2tf[2].numpy(), np.float32(2.))
    self.assertAllClose(g_jax2tf[3].numpy(), np.int32(0))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_gradients_int_argument(self, with_function=True):
    # Regression test for https://github.com/google/jax/issues/6975.
    # An expanded version of test_gradients_unused_argument_readme.
    state = dict(
        float_used=np.array([0.7, 0.9], dtype=np.float32),
        float_passthrough=np.float16(1.),
        float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
        int_used=np.int16(5),
        int_passthrough=np.int8(7),
        int_unused=np.array([1, 2, 3], dtype=np.uint32),
        bool_used=np.array([True, False, False, True], dtype=np.bool_),
        bool_passthrough=np.array([True, False, False, True, False], dtype=np.bool_),
        bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
    )
    def jax_f(state):
      res = dict(state,
                 float_used=2. * state["float_used"],
                 int_used=3 * state["int_used"],
                 bool_used=(state["bool_used"] == state["bool_used"]))
      del res["float_unused"]
      del res["int_unused"]
      del res["bool_unused"]
      return res

    args = (state,)
    res_jax = jax_f(*args)
    # Native JAX AD
    vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
    grad_jax, = vjp_jax_fun(*args_vjp)

    def compare_with_overrides(*, what, expected, **expected_overrides):
      what_keys = set(what.keys())
      expected_keys = set(expected.keys())
      self.assertEqual(what_keys, expected_keys)
      for k, w in what.items():
        e = expected[k]
        if k in expected_overrides:
          if expected_overrides[k] == "ZERO":
            e = np.zeros_like(w)
          elif expected_overrides[k] == "ZERO_INT32":
            e = np.zeros(np.shape(w), dtype=np.int32)
          elif expected_overrides[k] == "ONE":
            e = np.ones_like(w)
          else:
            e = expected_overrides[k]

        if e is None:
          self.assertIsNone(w, msg=k)
        else:
          self.assertIsNotNone(w, msg=k)
        w = w.numpy() if isinstance(w, tf.Tensor) else w
        e = e.numpy() if isinstance(e, tf.Tensor) else e
        try:
          self.assertAllClose(e, w, err_msg=k)
        except:
          print(f"Failed at {k}")
          raise


    # compare_with_overrides(g_jax, {},
    #   bool_passthrough=np.zeros(state["bool_passthrough"].shape, dtype=dtypes.float0),
    #   bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0),
    #   bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0),
    #   float_passthrough=np.ones_like(state["float_passthrough"]),
    #   float_unused=np.zeros_like(state["float_unused"]),
    #   float_used=np.ones_like(state["float_used"]) * np.array(2., dtype=state["float_used"].dtype),
    #   int_passthrough=np.zeros(state["int_passthrough"].shape, dtype=dtypes.float0),
    #   int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0),
    #   int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0))


    # Now native TF gradients, only to test how native TF AD works
    _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
        jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
    compare_with_overrides(what=grad_tf_0,
                           expected=grad_jax,
                           float_unused="ZERO",
                           bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO",
                           int_used="ZERO", int_passthrough="ONE", int_unused="ZERO")

    _, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
        jax_f, args,
        unconnected_gradients=tf.UnconnectedGradients.NONE)
    compare_with_overrides(what=grad_tf_None,
                           expected=grad_tf_0,
                           float_unused=None, int_used=None, int_unused=None,
                           bool_used=None, bool_unused=None)

    f_tf_jax = jax2tf.convert(jax_f)
    if with_function:
      f_tf_jax = tf.function(f_tf_jax, autograph=False)

    _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
    # Same results as TF native AD with tf.UnconnectedGradients.ZERO
    compare_with_overrides(what=grad_tf_jax_0,
                           expected=grad_tf_0,
                           int_passthrough="ZERO", bool_passthrough="ZERO")

    _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
        f_tf_jax, args,
        unconnected_gradients=tf.UnconnectedGradients.NONE)
    compare_with_overrides(what=grad_tf_jax_None,
                           expected=grad_tf_0,
                           int_used=None, int_passthrough=None, int_unused=None,
                           bool_unused=None, bool_used=None, bool_passthrough=None)

    # Now convert the JAX gradient function
    tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
    grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
    compare_with_overrides(what=grad_tf_vjp_jax,
                           expected=grad_tf_0,
                           bool_passthrough="ZERO_INT32",
                           bool_unused="ZERO_INT32", bool_used="ZERO_INT32",
                           int_passthrough="ZERO_INT32", int_unused="ZERO_INT32",
                           int_used="ZERO_INT32")

  def test_readme_gradient_int(self):
    x = np.array(2, dtype=np.int16)

    def f_jax(x):  # x: int16
      return x.astype(np.float32) * 2.

    print(jax.grad(f_jax, allow_int=True)(x))
    # returns a special `float0`: array((b'',), dtype=[('float0', 'V')])

    print(jax2tf.convert(jax.grad(f_jax, allow_int=True))(x))
    # returns a 0 with same shape as x, but with dtype int32

    def f_tf(x):  # x: int16
      return tf.cast(x, tf.float32) * 2.

    xv = tf.Variable(x)
    with tf.GradientTape(persistent=True) as tape:
      print(tape.gradient(f_tf(xv), xv))
      # returns None
      print(tape.gradient(f_tf(xv), xv,
                          unconnected_gradients=tf.UnconnectedGradients.ZERO))
      # returns 0 with the same shape and dtype as x


  def test_convert_argument_non_callable_error(self):
    with self.assertRaisesRegex(TypeError, "Expected a callable value"):
      jax2tf.convert(5.)

  def test_convert_argument_non_tensor_error(self):
    with self.assertRaisesRegex(TypeError,
                                "Argument.*should be NumPy array"):
      jax2tf.convert(lambda x: x)(lambda y: y)

  def test_argument_eager_tensor(self):
    x = jax2tf.convert(jnp.sin)(1.)
    jax2tf.convert(jnp.cos)(x)  # No error

  def test_checkpoint_wrapper_types(self):
    m = tf.Module()
    m.a = [tf.Module(), tf.Module()]
    m.b = (tf.Module(), tf.Module())
    m.c = {'a': tf.Module(), 'b': tf.Module()}
    self.assertNotEqual(type(m.a), list)
    self.assertNotEqual(type(m.b), tuple)
    self.assertNotEqual(type(m.c), dict)
    self.assertLen(jax.tree_leaves(m.a), 2)
    self.assertLen(jax.tree_leaves(m.b), 2)
    self.assertLen(jax.tree_leaves(m.c), 2)

  def test_custom_jvp(self):
    """Conversion of function with custom JVP"""

    @jax.custom_jvp
    def f(x):
      return x * x

    @f.defjvp
    def f_jvp(primals, tangents):
      x, = primals
      x_dot, = tangents
      primal_out = f(x)
      tangent_out = 3. * x * x_dot
      return primal_out, tangent_out

    arg = 0.7
    self.TransformConvertAndCompare(f, arg, None)
    self.TransformConvertAndCompare(f, arg, "jvp")
    self.TransformConvertAndCompare(f, arg, "vmap")
    self.TransformConvertAndCompare(f, arg, "jvp_vmap")
    self.TransformConvertAndCompare(f, arg, "grad")
    self.TransformConvertAndCompare(f, arg, "grad_vmap")

  def test_custom_vjp(self):
    """Conversion of function with custom VJP"""

    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
    def f_fwd(x):
      return f(x), 3. * x

    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)
    arg = 0.7
    self.TransformConvertAndCompare(f, arg, None)
    self.TransformConvertAndCompare(f, arg, "vmap")
    self.TransformConvertAndCompare(f, arg, "grad")
    self.TransformConvertAndCompare(f, arg, "grad_vmap")

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{flavor}", flavor=flavor)
    for flavor in ["old", "new"]))
  def test_remat(self, flavor="old"):
    def f(x1):
      x2 = jnp.sin(x1)
      x3 = jnp.sin(x2)
      x4 = jnp.sin(x3)
      return x4
    remat_f = jax.remat(f) if flavor == "old" else ad_checkpoint.checkpoint(f)

    # The computation of grad_f computes "sin" 5 times: 3 in the forward pass and
    # 2 more to rematerialize "x2" and "x3" in the backward pass.
    arg = np.array(3.)
    # Check that we have a Sin under a conditional
    f_tf = tf.function(jax2tf.convert(jax.grad(remat_f)), autograph=False)
    f_tf_graph = f_tf.get_concrete_function(arg).graph.as_graph_def()
    if flavor == "old":
      raise unittest.SkipTest("TODO: CSE widget not yet implemented for old-style remat")
    if jax.config.jax_remat_opt_barrier:
      self.assertRegex(
          str(f_tf_graph), r"remat_checkpoint_/XlaOptimizationBarrier")
    elif config.jax_experimental_name_stack:
      self.assertRegex(str(f_tf_graph),
                       r'transpose/jax2tf_f_/jvp/checkpoint/remat_checkpoint_/cond/branch_1_fun/Sin')
    else:
      self.assertRegex(str(f_tf_graph),
                       r'remat_checkpoint_/switch_case/indexed_case/Sin')

  def test_remat_free_var(self):
    def f(x):
      y = 2 * x

      @ad_checkpoint.checkpoint
      def g():
        return y

      return g()
    arg = 3.
    self.TransformConvertAndCompare(f, arg, None)
    self.TransformConvertAndCompare(f, arg, "grad")

  def test_convert_nullary_func(self):
    # Even nullary functions are converted to TF (as opposed to constant-folded
    # in JAX prior to conversion).
    def f_jax():
      return jnp.sin(1.)
    f_tf = tf.function(jax2tf.convert(f_jax), autograph=False)
    f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def()
    self.assertIn('op: "Sin"', str(f_tf_graph))

  def test_convert_of_nested_independent_jit(self):
    def func(x):
      def inner1(y):
        return x + y
      # The nested jit has no data dependency on x
      return jax.jit(inner1)(1.)

    jax2tf.convert(func)(2.)

  def test_convert_of_nested_dependent_jit(self):
    def func(x):
      def inner1(y):
        return x + y
      # The nested jit has a data dependency on x
      return jax.jit(inner1)(x)

    jax2tf.convert(func)(2.)  # No error

  def test_nested_convert_error(self):
    def outer(y):
      return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args
    with self.assertRaisesRegex(
        ValueError, "convert must be used outside all JAX transformations"):
      jax2tf.convert(outer)(np.ones((4, )))

  def test_nested_convert_error_non_tracer(self):
    """The inner convert takes non-tracer arguments"""
    def outer(y):
      sin_1 = jax2tf.convert(jnp.sin)(1.)  # Inner convert takes non-tracer arg
      return y + sin_1

    with self.assertRaisesRegex(
        ValueError, "convert must be used outside all JAX transformations"):
      jax2tf.convert(outer)(2.)


  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{transform}", transform=transform)
    for transform in ["jit", "jvp", "grad", "vmap"]))
  def test_convert_under_transform_error(self, transform="vmap"):
    def outer(y):
      return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args

    with self.assertRaisesRegex(
        ValueError, "convert must be used outside all JAX transformations"):
      self.TransformConvertAndCompare(outer, np.ones((4,)), transform)

  @parameterized.named_parameters(jtu.cases_from_list(
    dict(testcase_name=f"_{transform}", transform=transform)
    for transform in ["jit", "jvp", "grad", "vmap"]))
  def test_convert_under_transform_error_non_tracer(self, transform="vmap"):
    def outer(y):
      sin_1 = jax2tf.convert(jnp.sin)(1.)  # Inner convert takes non-tracer arg
      return y + sin_1

    with self.assertRaisesRegex(
        ValueError, "convert must be used outside all JAX transformations"):
      self.TransformConvertAndCompare(outer, np.ones((4,)), transform)

  def test_name_scope(self):
    @tf.function(autograph=False)
    def run():
      @jax.named_call
      def my_test_function(x):
        return x * x

      def caller(x):
        return my_test_function(jnp.sin(x))

      out = jax2tf.convert(caller, with_gradient=False)(2.)
      # When we use `with_gradient=False` the raw output of `caller` is passed
      # through a `tf.raw_ops.PreventGradient` and a `tf.identity`, clobbering
      # the name scope of the `mul` op. We need to get the grandparent of the
      # `out` tensor to see the name scope of the result of the `mul`.
      grandparent_op = out.op.inputs[0].op.inputs[0]
      self.assertIn("my_test_function", grandparent_op.name)
      return out
    run()

  def test_bfloat16_constant(self):
    # Re: https://github.com/google/jax/issues/3942
    def jax_fn_scalar(x):
      x = x.astype(jnp.bfloat16)
      x *= 2.
      return x

    def jax_fn_array(x):
      x = x.astype(jnp.bfloat16)
      x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16)
      return x

    tf_fn_scalar = jax2tf.convert(jax_fn_scalar)
    self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750))

    tf_fn_array = jax2tf.convert(jax_fn_array)
    self.assertAllClose(
        tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5],
                                                   jnp.bfloat16))

  def test_shared_constants(self):
    # Check that the constants are shared properly in converted functions
    # See https://github.com/google/jax/issues/7992.
    const = np.ones((16, 16))
    def f(x):
      return x + const + const + const + const

    f_tf_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
    self.assertEqual(f_tf_nr_consts, 1)

  def test_shared_constants_under_cond(self):
    # Check that the constants are shared properly in converted functions
    # See https://github.com/google/jax/issues/7992.
    const = np.arange(256, dtype=np.float32)
    x = np.ones((256,), dtype=np.float32)
    def f1(x):
      return lax.cond(x[0] >= 0., lambda x: x + const, lambda x: x * const, x) + const
    def f2(x):
      return f1(x) + const  # The extra const should not cost anything
    f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), x)
    f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), x)
    self.assertEqual(f1_nr_consts, f2_nr_consts)

  def test_shared_constants_under_scan(self):
    # See https://github.com/google/jax/issues/7992.
    const = np.arange(256, dtype=np.float32)
    xs = np.ones((8, 256), dtype=np.float32)
    def f1(xs):
      res, _ = lax.scan(lambda carry, x: (carry + x + const, None),
                        np.zeros((256,), dtype=np.float32), xs)
      return res

    def f2(xs):
      return f1(xs) + const  # The extra const should not be saved

    f1_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f1), xs)
    f2_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f2), xs)
    self.assertEqual(f1_nr_consts, f2_nr_consts)

  def test_shared_constants_under_jit(self):
    # We do not share constants under jit.
    const = np.ones((16, 16))
    @jax.jit
    def g_jit(x):
      return x * const
    def f(x):
      return g_jit(x) + const + const

    f_tf_graph_nr_consts = self.CountLargeTfConstants(jax2tf.convert(f), const)
    # TODO(b/207464757): TF compilation is disabled
    self.assertEqual(f_tf_graph_nr_consts, 1)

  def test_weak_types(self):
    mul = jax.jit(jnp.multiply)
    # The value `2` here should be weakly typed, and should not lead to
    # promotion.
    tf_fn = jax2tf.convert(lambda x: mul(x, 2.))
    self.assertAllClose(tf_fn(tf.constant(1.375, tf.bfloat16)).numpy(),
                        jnp.bfloat16(2.750))
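
  def test_weak_type_scalar_sketch(self):
    # Illustrative sketch (not part of the original test suite): a Python
    # scalar is weakly typed in JAX, so multiplying a bfloat16 array by 2.
    # keeps the bfloat16 dtype instead of promoting to float32.
    x = jnp.ones((3,), dtype=jnp.bfloat16)
    self.assertEqual(x.dtype, (x * 2.).dtype)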

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_kwargs(self, with_function=True):
    # Re: https://github.com/google/jax/issues/6791
    def f_jax(*, x):
      return jnp.sum(x)
    f_tf = jax2tf.convert(f_jax)
    if with_function:
      f_tf = tf.function(f_tf)
    self.assertAllClose(
      f_tf(x=np.zeros(3, dtype=np.float32)),  # Call with kwargs.
      np.zeros((), dtype=np.float32))

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"function={with_function}",
           with_function=with_function)
      for with_function in [False, True]))
  def test_grad_kwargs(self, with_function=False):
    # Re: https://github.com/google/jax/issues/6791
    x = (np.zeros(3, dtype=np.float32),
         np.zeros(4, dtype=np.float32))
    def f_jax(*, x=(1., 2.)):
      return jnp.sum(x[0]) + 2. * jnp.sum(x[1])
    f_tf = jax2tf.convert(f_jax)
    if with_function:
      f_tf = tf.function(f_tf)
    xv = tf.nest.map_structure(tf.Variable, x)
    with tf.GradientTape() as tape:
      res = f_tf(x=xv)
    grad_tf = tape.gradient(res, xv)
    self.assertAllClose((np.full_like(x[0], fill_value=1.),
                         np.full_like(x[1], fill_value=2.)),
                        (grad_tf[0].numpy(), grad_tf[1].numpy()))
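
  def test_grad_of_converted_function_sketch(self):
    # Illustrative sketch (not part of the original test suite): with the
    # default with_gradient=True, tf.GradientTape can differentiate through a
    # converted function, here d/dx sin(x) = cos(x).
    f_tf = jax2tf.convert(jnp.sin)
    x = tf.Variable(1.5, dtype=tf.float32)
    with tf.GradientTape() as tape:
      y = f_tf(x)
    self.assertAllClose(tape.gradient(y, x).numpy(), np.cos(1.5),
                        check_dtypes=False)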


  def test_enable_xla(self):
    # Tests that enable_xla flag is properly scoped to a conversion.
    def fun(x):
      # lax.reduce is unlikely to ever be convertible with enable_xla=False
      return lax.reduce(x, np.float32(0), lambda v, acc: v + acc, dimensions=(0, 1))

    tf_fun_with_xla = jax2tf.convert(fun, enable_xla=True)
    tf_fun_without_xla = jax2tf.convert(fun, enable_xla=False)
    x = np.ones((2, 3), dtype=np.float32)

    self.assertAllClose(fun(x), tf_fun_with_xla(x))
    with self.assertRaisesRegex(NotImplementedError,
                                "Call to reduce cannot be converted with enable_xla=False"):
      tf_fun_without_xla(x)

    # Now convert in the reverse order (there were previously bugs in the
    # management of the enable_xla global state).
    tf_fun2_without_xla = jax2tf.convert(lambda x: fun(x), enable_xla=False)
    tf_fun2_with_xla = jax2tf.convert(lambda x: fun(x), enable_xla=True)

    with self.assertRaisesRegex(NotImplementedError,
                                "Call to reduce cannot be converted with enable_xla=False"):
      tf_fun2_without_xla(x)
    self.assertAllClose(fun(x), tf_fun2_with_xla(x))

  def test_device_array_arg(self):
    self.ConvertAndCompare(jnp.sin, jnp.zeros((2, 3), jnp.float32))

  def test_randint(self):
    def randint():
      return jax.random.randint(
          jax.random.PRNGKey(42), shape=(), minval=0, maxval=1)

    self.ConvertAndCompare(randint)

  def test_op_metadata_simple(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    # A simple example
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_simple(x):
      return jnp.sin(x)

    x = np.ones((2, 3), np.float32)
    self.CheckOpMetadata(
        f_simple, x,
        [tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 2,
                                      op_name="jax2tf(f_simple)/sin",
                                      op_type="sin")
         ]
    )

  def test_op_metadata_sub_jit(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    # Calling a jitted-function
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_callee(x):
      return jnp.cos(x)
    def f_caller(x):
      y = jnp.tanh(x)
      z = jax.jit(f_callee)(y)
      return jnp.sin(z)

    x = np.ones((2, 3), np.float32)

    self.CheckOpMetadata(
        f_caller, x,
        [tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 4,
                                      op_name="jax2tf(f_caller)/tanh",
                                      op_type="tanh"),
         tf_test_util.OpMetadataGraph(tf_type="Cos",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 2,
                                      op_name="jax2tf(f_caller)/jit(f_callee)/cos",
                                      op_type="cos"),
         tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 6,
                                      op_name="jax2tf(f_caller)/sin",
                                      op_type="sin"),
         ]
    )

  def test_op_metadata_named(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    # Calling a jax.named_call
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_callee(x):
      return jnp.cos(x)
    def f_caller(x):
      y = jnp.tanh(x)
      z = jax.named_call(f_callee, name="callee")(y)
      return jnp.sin(z)

    x = np.ones((2, 3), np.float32)

    self.CheckOpMetadata(
        f_caller, x,
        [tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 4,
                                      op_name="jax2tf(f_caller)/tanh",
                                      op_type="tanh"),
         tf_test_util.OpMetadataGraph(tf_type="Cos",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 2,
                                      op_name="jax2tf(f_caller)/named(callee)/cos",
                                      op_type="cos"),
         tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 6,
                                      op_name="jax2tf(f_caller)/sin",
                                      op_type="sin"),
         ]
    )

  def test_op_metadata_while_and_cond(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    # An example with while and cond
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_while_cond(x):
      def body_fun(i_acc):
        i, acc = i_acc
        return (i + 1,
                (jnp.cos(acc) +
                 lax.cond(jnp.mod(i, 2) == 0,
                          lambda acc: jnp.sin(acc),
                          lambda acc: acc,
                          acc)))

      _, acc = lax.while_loop(
          lambda i_acc: i_acc[0] <= 5,
          body_fun, (0, x))
      return acc

    x = np.ones((2, 3), np.float32)
    self.CheckOpMetadata(
        f_while_cond, x,
        [tf_test_util.OpMetadataGraph(tf_type="Cos",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 5,
                                      op_name="jax2tf(f_while_cond)/while/body/cos",
                                      op_type="cos"),
         tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 7,
                                      op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
                                      op_type="sin"),
         tf_test_util.OpMetadataGraph(tf_type="FloorMod",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 6,
                                      op_name="jax2tf(f_while_cond)/while/body/rem",
                                      op_type="rem"),
         ]
    )

  def test_op_metadata_batched_while(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    # An example with a batched (vmapped) while loop
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    @jax.vmap
    def f_while(x):
      def body_fun(carry):
        new_carry = jnp.sin(carry)  # We look for "sin" in the graph
        return new_carry

      _, carry = lax.while_loop(
          lambda carry: jnp.all(carry <= x),  # We look for "le" in the graph
          body_fun, x)
      return carry

    shape = (3, 2)
    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    jax_comp = jax.xla_computation(f_while)(x)
    backend = jax._src.lib.xla_bridge.get_backend()
    modules = backend.compile(jax_comp).hlo_modules()
    jax_opt_hlo = modules[0].to_string()
    print(f"JAX OPT HLO = {jax_opt_hlo}")

    self.CheckOpMetadata(
        f_while, x,
        [tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 4,
                                      op_name="jax2tf(f_while)/while/body/sin",
                                      op_type="sin"),
         tf_test_util.OpMetadataGraph(tf_type="LessEqual",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 8,
                                      op_name="jax2tf(f_while)/while/body_pred/le",
                                      op_type="le"),
         ]
    )

  def test_op_metadata_disabled(self):
    self.skipTest("include_xla_op_metadata not yet enabled")
    def f_simple(x):
      return jnp.sin(x)

    x = np.ones((2, 3), np.float32)
    self.CheckOpMetadata(
        f_simple, x,
        [],
        include_xla_op_metadata=False
    )
Example #22
0
class FftTest(jtu.JaxTestCase):
    def testNotImplemented(self):
        for name in jnp.fft._NOT_IMPLEMENTED:
            func = getattr(jnp.fft, name)
            with self.assertRaises(NotImplementedError):
                func()

    def testLaxFftAcceptsStringTypes(self):
        rng = jtu.rand_default(self.rng())
        x = rng((10, ), np.complex64)
        self.assertAllClose(
            np.fft.fft(x).astype(np.complex64),
            lax.fft(x, "FFT", fft_lengths=(10, )))

    @parameterized.parameters((np.float32, ), (np.float64, ))
    def testLaxIrfftDoesNotMutateInputs(self, dtype):
        if dtype == np.float64 and not config.x64_enabled:
            raise self.skipTest("float64 requires jax_enable_x64=true")
        x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1 + 1j)
        y = np.asarray(jnp.fft.irfft2(x))
        z = np.asarray(jnp.fft.irfft2(x))
        self.assertAllClose(y, z)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inverse={}_real={}_shape={}_axes={}_s={}_norm={}".format(
                    inverse, real, jtu.format_shape_dtype_string(shape, dtype),
                    axes, s, norm),
                "axes":
                axes,
                "shape":
                shape,
                "dtype":
                dtype,
                "inverse":
                inverse,
                "real":
                real,
                "s":
                s,
                "norm":
                norm
            } for inverse in [False, True] for real in [False, True]
            for dtype in (real_dtypes if real and not inverse else all_dtypes)
            for shape in [(10, ), (10, 10), (9, ), (2, 3, 4), (2, 3, 4, 5)]
            for axes in _get_fftn_test_axes(shape)
            for s in _get_fftn_test_s(shape, axes) for norm in FFT_NORMS))
    def testFftn(self, inverse, real, shape, dtype, axes, s, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_op = _get_fftn_func(jnp.fft, inverse, real)
        np_op = _get_fftn_func(np.fft, inverse, real)
        jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
        np_fn = lambda a: (np_op(a, axes=axes, norm=norm)
                           if axes is None or axes else a)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if (config.x64_enabled and dtype
                in (float_dtypes if real and not inverse else inexact_dtypes)):
            # TODO(skye): can we be more precise?
            tol = 0.15
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

        # check dtypes
        dtype = jnp_fn(rng(shape, dtype)).dtype
        expected_dtype = jnp.promote_types(
            float if inverse and real else complex, dtype)
        self.assertEqual(dtype, expected_dtype)

    def testIrfftTranspose(self):
        # regression test for https://github.com/google/jax/issues/6223
        def build_matrix(linear_func, size):
            return jax.vmap(linear_func)(jnp.eye(size, size))

        def func(x):
            return jnp.fft.irfft(
                jnp.concatenate([jnp.zeros(1), x[:2] + 1j * x[2:]]))

        def func_transpose(x):
            return jax.linear_transpose(func, x)(x)[0]

        matrix = build_matrix(func, 4)
        matrix2 = build_matrix(func_transpose, 4).T
        self.assertAllClose(matrix, matrix2)
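
    def testLinearTransposeSketch(self):
        # Illustrative sketch (not part of the original test suite): for a
        # linear map f(x) = A @ x, jax.linear_transpose(f, x) applies A.T,
        # which is the property the irfft transpose test above relies on.
        A = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
        f = lambda x: A @ x
        f_transpose = jax.linear_transpose(f, jnp.zeros((2,)))
        y = jnp.array([1., -1., 2.])
        self.assertAllClose(f_transpose(y)[0], A.T @ y)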

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_inverse={}_real={}".format(inverse, real),
                "inverse": inverse,
                "real": real
            } for inverse in [False, True] for real in [False, True]))
    def testFftnErrors(self, inverse, real):
        rng = jtu.rand_default(self.rng())
        name = 'fftn'
        if real:
            name = 'r' + name
        if inverse:
            name = 'i' + name
        func = _get_fftn_func(jnp.fft, inverse, real)
        self.assertRaisesRegex(
            ValueError, "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. "
            "Got axes None with input rank 4.".format(name),
            lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None))
        self.assertRaisesRegex(
            ValueError,
            "jax.numpy.fft.{} does not support repeated axes. Got axes \\[1, 1\\]."
            .format(name),
            lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3]))

    def testFftEmpty(self):
        out = jnp.fft.fft(jnp.zeros((0, ), jnp.complex64)).block_until_ready()
        self.assertArraysEqual(jnp.zeros((0, ), jnp.complex64), out)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inverse={}_real={}_hermitian={}_shape={}_n={}_axis={}".
                format(inverse, real, hermitian,
                       jtu.format_shape_dtype_string(shape, dtype), n, axis),
                "axis":
                axis,
                "shape":
                shape,
                "dtype":
                dtype,
                "inverse":
                inverse,
                "real":
                real,
                "hermitian":
                hermitian,
                "n":
                n
            } for inverse in [False, True] for real in [False, True]
            for hermitian in [False, True]
            for dtype in (real_dtypes if (real and not inverse) or (
                hermitian and inverse) else all_dtypes) for shape in [(10, )]
            for n in [None, 1, 7, 13, 20] for axis in [-1, 0]))
    def testFft(self, inverse, real, hermitian, shape, dtype, n, axis):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        name = 'fft'
        if real:
            name = 'r' + name
        elif hermitian:
            name = 'h' + name
        if inverse:
            name = 'i' + name
        jnp_op = getattr(jnp.fft, name)
        np_op = getattr(np.fft, name)
        jnp_fn = lambda a: jnp_op(a, n=n, axis=axis)
        np_fn = lambda a: np_op(a, n=n, axis=axis)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_op, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inverse={}_real={}_hermitian={}".format(inverse, real,
                                                      hermitian),
            "inverse":
            inverse,
            "real":
            real,
            "hermitian":
            hermitian
        } for inverse in [False, True] for real in [False, True]
                            for hermitian in [False, True]))
    def testFftErrors(self, inverse, real, hermitian):
        rng = jtu.rand_default(self.rng())
        name = 'fft'
        if real:
            name = 'r' + name
        elif hermitian:
            name = 'h' + name
        if inverse:
            name = 'i' + name
        func = getattr(jnp.fft, name)

        self.assertRaisesRegex(
            ValueError,
            f"jax.numpy.fft.{name} does not support multiple axes. "
            f"Please use jax.numpy.fft.{name}n. Got axis = \\[1, 1\\].",
            lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1]))
        self.assertRaisesRegex(
            ValueError,
            f"jax.numpy.fft.{name} does not support multiple axes. "
            f"Please use jax.numpy.fft.{name}n. Got axis = \\(1, 1\\).",
            lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1)))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2]))
        self.assertRaises(
            ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3]))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inverse={}_real={}_shape={}_axes={}_norm={}".format(
                    inverse, real, jtu.format_shape_dtype_string(shape, dtype),
                    axes, norm),
                "axes":
                axes,
                "shape":
                shape,
                "dtype":
                dtype,
                "inverse":
                inverse,
                "real":
                real,
                "norm":
                norm
            } for inverse in [False, True] for real in [False, True]
            for dtype in (real_dtypes if real and not inverse else all_dtypes)
            for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)]
            for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)]
            for norm in FFT_NORMS))
    def testFft2(self, inverse, real, shape, dtype, axes, norm):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        name = 'fft2'
        if real:
            name = 'r' + name
        if inverse:
            name = 'i' + name
        jnp_op = getattr(jnp.fft, name)
        np_op = getattr(np.fft, name)
        jnp_fn = lambda a: jnp_op(a, axes=axes, norm=norm)
        np_fn = lambda a: (np_op(a, axes=axes, norm=norm)
                           if axes is None or axes else a)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_op, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_inverse={}_real={}".format(inverse, real),
                "inverse": inverse,
                "real": real
            } for inverse in [False, True] for real in [False, True]))
    def testFft2Errors(self, inverse, real):
        rng = jtu.rand_default(self.rng())
        name = 'fft2'
        if real:
            name = 'r' + name
        if inverse:
            name = 'i' + name
        func = getattr(jnp.fft, name)

        self.assertRaisesRegex(
            ValueError, "jax.numpy.fft.{} only supports 2 axes. "
            "Got axes = \\[0\\].".format(name),
            lambda: func(rng([2, 3], dtype=np.float64), axes=[0]))
        self.assertRaisesRegex(
            ValueError, "jax.numpy.fft.{} only supports 2 axes. "
            "Got axes = \\(0, 1, 2\\).".format(name),
            lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2)))
        self.assertRaises(
            ValueError,
            lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3]))
        self.assertRaises(
            ValueError,
            lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4]))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_size={}_d={}".format(
                jtu.format_shape_dtype_string([size], dtype), d),
            "dtype":
            dtype,
            "size":
            size,
            "d":
            d
        } for dtype in all_dtypes for size in [9, 10, 101, 102]
                            for d in [0.1, 2.]))
    def testFftfreq(self, size, d, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng([size], dtype), )
        jnp_op = jnp.fft.fftfreq
        np_op = np.fft.fftfreq
        jnp_fn = lambda a: jnp_op(size, d=d)
        np_fn = lambda a: np_op(size, d=d)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if dtype in inexact_dtypes:
            tol = 0.15  # TODO(skye): can we be more precise?
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_n={}".format(n),
            "n": n
        } for n in [[0, 1, 2]]))
    def testFftfreqErrors(self, n):
        name = 'fftfreq'
        func = jnp.fft.fftfreq
        self.assertRaisesRegex(
            ValueError,
            "The n argument of jax.numpy.fft.{} only takes an int. "
            "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n))
        self.assertRaisesRegex(
            ValueError,
            "The d argument of jax.numpy.fft.{} only takes a single value. "
            "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_size={}_d={}".format(
                jtu.format_shape_dtype_string([size], dtype), d),
            "dtype":
            dtype,
            "size":
            size,
            "d":
            d
        } for dtype in all_dtypes for size in [9, 10, 101, 102]
                            for d in [0.1, 2.]))
    def testRfftfreq(self, size, d, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng([size], dtype), )
        jnp_op = jnp.fft.rfftfreq
        np_op = np.fft.rfftfreq
        jnp_fn = lambda a: jnp_op(size, d=d)
        np_fn = lambda a: np_op(size, d=d)
        # Numpy promotes to complex128 aggressively.
        self._CheckAgainstNumpy(np_fn,
                                jnp_fn,
                                args_maker,
                                check_dtypes=False,
                                tol=1e-4)
        self._CompileAndCheck(jnp_fn, args_maker)
        # Test gradient for differentiable types.
        if dtype in inexact_dtypes:
            tol = 0.15  # TODO(skye): can we be more precise?
            jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_n={}".format(n),
            "n": n
        } for n in [[0, 1, 2]]))
    def testRfftfreqErrors(self, n):
        name = 'rfftfreq'
        func = jnp.fft.rfftfreq
        self.assertRaisesRegex(
            ValueError,
            "The n argument of jax.numpy.fft.{} only takes an int. "
            "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n))
        self.assertRaisesRegex(
            ValueError,
            "The d argument of jax.numpy.fft.{} only takes a single value. "
            "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "dtype={}_axes={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), axes),
                "dtype":
                dtype,
                "shape":
                shape,
                "axes":
                axes
            } for dtype in all_dtypes
            for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
            for axes in _get_fftn_test_axes(shape)))
    def testFftshift(self, shape, dtype, axes):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes)
        np_fn = lambda arg: np.fft.fftshift(arg, axes=axes)
        self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "dtype={}_axes={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), axes),
                "dtype":
                dtype,
                "shape":
                shape,
                "axes":
                axes
            } for dtype in all_dtypes
            for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]]
            for axes in _get_fftn_test_axes(shape)))
    def testIfftshift(self, shape, dtype, axes):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: (rng(shape, dtype), )
        jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
        np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
        self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
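
    def testFftshiftIfftshiftInverseSketch(self):
        # Illustrative sketch (not part of the original test suite): ifftshift
        # is the inverse of fftshift, so composing the two gives back the input.
        x = jnp.arange(10.0)
        self.assertArraysEqual(jnp.fft.ifftshift(jnp.fft.fftshift(x)), x)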
Example #23
0
class HigherOrderPrimitiveTest(jtu.JaxTestCase):
    def test_core_call_primitive_inherits_effects(self):
        def f(x):
            @lu.wrap_init
            def f_(x):
                effect_p.bind(effect='foo')
                effect_p.bind(effect='bar')
                return [x]

            return core.call(f_, x)[0]

        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            jax.make_jaxpr(f)(2.)

    def test_xla_call_primitive_inherits_effects(self):
        @jax.jit
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            jax.make_jaxpr(f)(2.)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_{flavor}", flavor=flavor)
            for flavor in ["old", "new"]))
    def test_remat_call_primitive_inherits_effects(self, flavor):
        remat = jax.remat if flavor == "old" else ad_checkpoint.checkpoint

        @remat
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            jax.make_jaxpr(f)(2.)

    def test_custom_jvp_primitive_inherits_effects(self):
        @jax.custom_jvp
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        f.defjvp(lambda x, t: (x, t))
        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            jax.make_jaxpr(f)(2.)

    def test_custom_vjp_primitive_inherits_effects(self):
        @jax.custom_vjp
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        f.defvjp(fwd=lambda x: (x, ()), bwd=lambda _, g: g)
        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            jax.make_jaxpr(f)(2.)

    def test_pmap_inherits_effects(self):
        @jax.pmap
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        with self.assertRaisesRegex(
                ValueError,
                "Ordered effects not supported for map primitives: {'foo'}"):
            jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))

    def test_xmap_inherits_effects(self):
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
        jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
        self.assertSetEqual(jaxpr.effects, {"foo", "bar"})

    def test_pjit_inherits_effects(self):
        def f(x):
            effect_p.bind(effect='foo')
            effect_p.bind(effect='bar')
            return x

        f = pjit.pjit(f,
                      in_axis_resources=pjit.PartitionSpec('x'),
                      out_axis_resources=pjit.PartitionSpec('x'))
        with self.assertRaisesRegex(NotImplementedError,
                                    'Effects not supported'):
            with maps.Mesh(np.array(jax.devices()), ['x']):
                jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
Example #24
0
class NdimageTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_{}_coordinates={}_order={}_mode={}_cval={}_impl={}_round={}".
                format(
                    jtu.format_shape_dtype_string(shape, dtype),
                    jtu.format_shape_dtype_string(coords_shape, coords_dtype),
                    order, mode, cval, impl, round_),
                "rng_factory":
                rng_factory,
                "shape":
                shape,
                "coords_shape":
                coords_shape,
                "dtype":
                dtype,
                "coords_dtype":
                coords_dtype,
                "order":
                order,
                "mode":
                mode,
                "cval":
                cval,
                "impl":
                impl,
                "round_":
                round_
            } for shape in [(5, ), (3, 4), (3, 4, 5)]
            for coords_shape in [(7, ), (2, 3, 4)]
            for dtype in float_dtypes + int_dtypes
            for coords_dtype in float_dtypes for order in [0, 1]
            for mode in ['wrap', 'constant', 'nearest', 'mirror', 'reflect']
            for cval in ([0, -1] if mode == 'constant' else [0])
            for impl, rng_factory in [
                ("original", partial(jtu.rand_uniform, low=0, high=1)),
                ("fixed", partial(jtu.rand_uniform, low=-0.75, high=1.75)),
            ] for round_ in [True, False]))
    def testMapCoordinates(self, shape, dtype, coords_shape, coords_dtype,
                           order, mode, cval, impl, round_, rng_factory):
        def args_maker():
            x = np.arange(prod(shape), dtype=dtype).reshape(shape)
            coords = [(size - 1) * rng(coords_shape, coords_dtype)
                      for size in shape]
            if round_:
                coords = [c.round().astype(int) for c in coords]
            return x, coords

        rng = rng_factory(self.rng())
        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(
            x, c, order=order, mode=mode, cval=cval)
        impl_fun = (osp_ndimage.map_coordinates
                    if impl == "original" else _fixed_ref_map_coordinates)
        osp_op = lambda x, c: impl_fun(x, c, order=order, mode=mode, cval=cval)

        with jtu.strict_promotion_if_dtypes_match(
            [dtype, int if round_ else coords_dtype]):
            if dtype in float_dtypes:
                epsilon = max(
                    dtypes.finfo(dtypes.canonicalize_dtype(d)).eps
                    for d in [dtype, coords_dtype])
                self._CheckAgainstNumpy(osp_op,
                                        lsp_op,
                                        args_maker,
                                        tol=100 * epsilon)
            else:
                self._CheckAgainstNumpy(osp_op, lsp_op, args_maker, tol=0)

    def testMapCoordinatesErrors(self):
        x = np.arange(5.0)
        c = [np.linspace(0, 5, num=3)]
        with self.assertRaisesRegex(NotImplementedError, 'requires order<=1'):
            lsp_ndimage.map_coordinates(x, c, order=2)
        with self.assertRaisesRegex(NotImplementedError,
                                    'does not yet support mode'):
            lsp_ndimage.map_coordinates(x, c, order=1, mode='grid-wrap')
        with self.assertRaisesRegex(ValueError, 'sequence of length'):
            lsp_ndimage.map_coordinates(x, [c, c], order=1)

    def testMapCoordinateDocstring(self):
        self.assertIn("Only nearest neighbor",
                      lsp_ndimage.map_coordinates.__doc__)
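
    def testMapCoordinatesLinearInterpSketch(self):
        # Illustrative sketch (not part of the original test suite): with
        # order=1, map_coordinates does linear interpolation, so sampling the
        # linear signal y = 2*x at half-integer coordinates returns
        # 2 * coordinate exactly.
        x = jnp.arange(5.0) * 2.0
        coords = [jnp.array([0.5, 1.5, 2.5])]
        out = lsp_ndimage.map_coordinates(x, coords, order=1)
        self.assertAllClose(out, jnp.array([1.0, 3.0, 5.0]))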

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_{np.dtype(dtype)}_order={order}",
            "dtype": dtype,
            "order": order
        } for dtype in float_dtypes + int_dtypes for order in [0, 1]))
    def testMapCoordinatesRoundHalf(self, dtype, order):
        x = np.arange(-3, 3, dtype=dtype)
        c = np.array([[.5, 1.5, 2.5, 3.5]])

        def args_maker():
            return x, c

        lsp_op = lambda x, c: lsp_ndimage.map_coordinates(x, c, order=order)
        osp_op = lambda x, c: osp_ndimage.map_coordinates(x, c, order=order)

        with jtu.strict_promotion_if_dtypes_match([dtype, c.dtype]):
            self._CheckAgainstNumpy(osp_op, lsp_op, args_maker)

    def testContinuousGradients(self):
        # regression test for https://github.com/google/jax/issues/3024

        def loss(delta):
            x = np.arange(100.0)
            border = 10
            indices = np.arange(x.size, dtype=x.dtype) + delta
            # linear interpolation of the linear function y=x should be exact
            shifted = lsp_ndimage.map_coordinates(x, [indices], order=1)
            return ((x - shifted)**2)[border:-border].mean()

        # analytical gradient of (x - (x - delta)) ** 2 is 2 * delta
        self.assertAllClose(grad(loss)(0.5), 1.0, check_dtypes=False)
        self.assertAllClose(grad(loss)(1.0), 2.0, check_dtypes=False)
Example #25
0
class TestLBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_func={func_and_init[0].__name__}_maxiter={maxiter}",
     "maxiter": maxiter, "func_and_init": func_and_init}
    for maxiter in [None]
    for func_and_init in [(rosenbrock, np.zeros(2)),
                          (himmelblau, np.zeros(2)),
                          (matyas, np.ones(2) * 6.),
                          (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):

    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(maxiter=maxiter, gtol=1e-7),
      )
      return result.x

    jax_res = min_op(x0)

    # Note that without bounds, L-BFGS-B is just L-BFGS
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

    if func.__name__ == 'matyas':
      # scipy performs badly for Matyas, compare to true minimum instead
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
      return

    if func.__name__ == 'eggholder':
      # L-BFGS performs poorly for the eggholder function.
      # Neither scipy nor jax finds the true minimum, so we can only loosely
      # compare their (incorrect) results using a high atol.
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
      return

    self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)

  def test_minimize_complex_sphere(self):
    z0 = jnp.array([1., 2. - 3.j, 4., -5.j])

    def f(z):
      return jnp.real(jnp.dot(jnp.conj(z - z0), z - z0))

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          f,
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(jnp.zeros_like(z0))

    self.assertAllClose(jax_res, z0)

  def test_complex_rosenbrock(self):
    complex_dim = 5

    f_re = rosenbrock(jnp)
    init_re = jnp.zeros((2 * complex_dim,))
    expect_re = jnp.ones((2 * complex_dim,))

    def f(z):
      x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
      return f_re(x_re)

    init = init_re[:complex_dim] + 1.j * init_re[complex_dim:]
    expect = expect_re[:complex_dim] + 1.j * expect_re[complex_dim:]

    @jit
    def min_op(z0):
      result = jax.scipy.optimize.minimize(
          f,
          z0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(init)
    self.assertAllClose(jax_res, expect, atol=2e-5)
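
  def test_minimize_quadratic_sketch(self):
    # Illustrative sketch (not part of the original test suite): the public
    # BFGS method should recover the exact minimizer of a simple quadratic.
    target = jnp.array([1., -2.])
    def f(x):
      return jnp.sum((x - target) ** 2)
    result = jax.scipy.optimize.minimize(f, jnp.zeros(2), method='BFGS')
    self.assertAllClose(result.x, target, atol=1e-5)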
Example #26
0
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  @genNamedParametersNArgs(3)
  def testPoissonLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.logpmf
    lax_fun = lsp_stats.poisson.logpmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
      loc = np.floor(loc)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
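
  def testPoissonLogPmfClosedFormSketch(self):
    # Illustrative sketch (not part of the original test suite): for the
    # Poisson distribution logpmf(k=0, mu) is exactly -mu, which gives a
    # simple closed-form check against the LAX-backed implementation.
    self.assertAllClose(lsp_stats.poisson.logpmf(0., 3.), -3.,
                        check_dtypes=False)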

  @genNamedParametersNArgs(3)
  def testPoissonPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.pmf
    lax_fun = lsp_stats.poisson.pmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
      loc = np.floor(loc)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testPoissonCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.cdf
    lax_fun = lsp_stats.poisson.cdf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
      return [k, mu, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testBernoulliLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testGeomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.geom.logpmf
    lax_fun = lsp_stats.geom.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker,
                            rtol={np.float32: 2e-3, np.float64: 1e-4})

  def testBetaLogPdfZero(self):
    # Regression test for https://github.com/google/jax/issues/7645
    a = b = 1.
    x = np.array([0., 1.])
    self.assertAllClose(
      osp_stats.beta.pdf(x, a, b), lsp_stats.beta.pdf(x, a, b), atol=1E-6)

  @genNamedParametersNArgs(3)
  def testCauchyLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(
    jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
        "shapes": [x_shape, alpha_shape], "dtypes": dtypes}
      for x_shape in one_and_two_dim_shapes
      for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
      for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2)
  ))
  def testDirichletLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())

    def _normalize(x, alpha):
      x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
      return (x / x_norm).astype(x.dtype), alpha

    def lax_fun(x, alpha):
      return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

    def scipy_fun(x, alpha):
      # scipy validates the x normalization using float64 arithmetic, so we must
      # cast x to float64 before normalization to ensure this passes.
      x, alpha = _normalize(x.astype('float64'), alpha)

      result = osp_stats.dirichlet.logpdf(x, alpha)
      # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays
      # of a consistent rank. This check ensures the results have the same shape.
      return result if x.ndim == 1 else np.atleast_1d(result)

    def args_maker():
      # Don't normalize here, because we want normalization to happen at 64-bit
      # precision in the scipy version.
      x, alpha = map(rng, shapes, dtypes)
      return x, alpha

    tol = {np.float32: 1E-3, np.float64: 1e-5}

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=tol)
      self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)

  @genNamedParametersNArgs(3)
  def testExponLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testGammaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  def testGammaLogPdfZero(self):
    # Regression test for https://github.com/google/jax/issues/7256
    self.assertAllClose(
      osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6)

  @genNamedParametersNArgs(2)
  def testGenNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.gennorm.logpdf
    lax_fun = lsp_stats.gennorm.logpdf

    def args_maker():
      x, p = map(rng, shapes, dtypes)
      return [x, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4, rtol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(2)
  def testGenNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.gennorm.cdf
    lax_fun = lsp_stats.gennorm.cdf

    def args_maker():
      x, p = map(rng, shapes, dtypes)
      return [x, p]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4, rtol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testNBinomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.nbinom.logpmf
    lax_fun = lsp_stats.nbinom.logpmf

    def args_maker():
      k, n, logit, loc = map(rng, shapes, dtypes)
      k = np.floor(np.abs(k))
      n = np.ceil(np.abs(n))
      p = expit(logit)
      loc = np.floor(loc)
      return [k, n, p, loc]

    tol = {np.float32: 1e-6, np.float64: 1e-8}

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  @genNamedParametersNArgs(3)
  def testLaplaceLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testLaplaceCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol={np.float32: 1e-5, np.float64: 1e-6})
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.cdf
    lax_fun = lsp_stats.logistic.cdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-6)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticLogpdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.logpdf
    lax_fun = lsp_stats.logistic.logpdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  def testLogisticLogpdfOverflow(self):
    # Regression test for https://github.com/google/jax/issues/10219
    self.assertAllClose(
      np.array([-100, -100], np.float32),
      lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)),
      check_dtypes=False)
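
  def testLogisticLogpdfClosedFormSketch(self):
    # Illustrative sketch (not part of the original test suite): the logistic
    # pdf at x = 0 is 1/4, so logpdf(0) should equal -log(4).
    self.assertAllClose(lsp_stats.logistic.logpdf(0.), -np.log(4.),
                        check_dtypes=False)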

  @genNamedParametersNArgs(1)
  def testLogisticPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.ppf
    lax_fun = lsp_stats.logistic.ppf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.sf
    lax_fun = lsp_stats.logistic.sf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)
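
  def testNormLogPdfClosedFormSketch(self):
    # Illustrative sketch (not part of the original test suite): at x = loc the
    # normal logpdf reduces to -log(scale) - 0.5 * log(2 * pi).
    self.assertAllClose(lsp_stats.norm.logpdf(0., loc=0., scale=2.),
                        -np.log(2.) - 0.5 * np.log(2 * np.pi),
                        check_dtypes=False)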


  @genNamedParametersNArgs(3)
  def testNormLogCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-6)
      self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testNormPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.ppf
    lax_fun = lsp_stats.norm.ppf

    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [q, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)


  @genNamedParametersNArgs(4)
  def testParetoLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(4)
  def testTLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
      return [x, df, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-3)
      self._CompileAndCheck(lax_fun, args_maker,
                            rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})


  @genNamedParametersNArgs(3)
  def testUniformLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, np.abs(scale)]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=1e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testChi2LogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.chi2.logpdf
    lax_fun = lsp_stats.chi2.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      return [x, df, loc, scale]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaBinomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
      k, n, a, b, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      n = np.ceil(n)
      a = np.clip(a, a_min=0.1, a_max=None).astype(a.dtype)
      b = np.clip(b, a_min=0.1, a_max=None).astype(b.dtype)
      loc = np.floor(loc)
      return [k, n, a, b, loc]

    with jtu.strict_promotion_if_dtypes_match(dtypes):
      scipy_fun = osp_stats.betabinom.logpmf
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
      self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

  def testIssue972(self):
    self.assertAllClose(
      np.ones((4,), np.float32),
      lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
      check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype}
      for x_shape, mean_shape, cov_shape in [
          # # These test cases cover default values for mean/cov, but we don't
          # # support those yet (and they seem not very valuable).
          # [(), None, None],
          # [(), (), None],
          # [(2,), None, None],
          # [(2,), (), None],
          # [(2,), (2,), None],
          # [(3, 2), (3, 2,), None],
          # [(5, 3, 2), (5, 3, 2,), None],

          [(), (), ()],
          [(3,), (), ()],
          [(3,), (3,), ()],
          [(3,), (3,), (3, 3)],
          [(3, 4), (4,), (4, 4)],
          [(2, 3, 4), (4,), (4, 4)],
      ]
      for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
      if (mean_shape is not None or mean_dtype == np.float32)
      and (cov_shape is not None or cov_dtype == np.float32)))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return [a.astype(x_dtype) for a in args]

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                            lsp_stats.multivariate_normal.logpdf,
                            args_maker, tol=1e-3, check_dtypes=False)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          rtol=1e-4, atol=1e-4)


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype}
      for x_shape, mean_shape, cov_shape in [
          # These test cases are where scipy flattens things, which has
          # different batch semantics than some might expect, so we manually
          # vectorize scipy's outputs for the sake of testing.
          [(5, 3, 2), (5, 3, 2), (5, 3, 2, 2)],
          [(2,), (5, 3, 2), (5, 3, 2, 2)],
          [(5, 3, 2), (2,), (5, 3, 2, 2)],
          [(5, 3, 2), (5, 3, 2,), (2, 2)],
          [(1, 3, 2), (3, 2,), (5, 1, 2, 2)],
          [(5, 3, 2), (1, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
      if (mean_shape is not None or mean_dtype == np.float32)
      and (cov_shape is not None or cov_dtype == np.float32)))
  def testMultivariateNormalLogpdfBroadcasted(self, x_shape, x_dtype, mean_shape,
                                              mean_dtype, cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return [a.astype(x_dtype) for a in args]

    osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf,
                           signature="(n),(n),(n,n)->()")

    self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf,
                            args_maker, tol=1e-3, check_dtypes=False)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          rtol=1e-4, atol=1e-4)
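    # Shape sketch of the comparison above: with x of shape (5, 3, 2), mean of
    # shape (5, 3, 2) and cov of shape (5, 3, 2, 2), the gufunc-style wrapper
    # np.vectorize(osp_stats.multivariate_normal.logpdf, signature="(n),(n),(n,n)->()")
    # returns a (5, 3) batch of log-densities, matching the broadcasting of
    # lsp_stats.multivariate_normal.logpdf, whereas plain scipy would flatten
    # the batch dimensions.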


  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_ndim={ndim}_nbatch={nbatch}_dtype={dtype.__name__}",
       "ndim": ndim, "nbatch": nbatch, "dtype": dtype}
      for ndim in [2, 3]
      for nbatch in [1, 3, 5]
      for dtype in jtu.dtypes.floating))
  def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
    # Regression test for #5570
    rng = jtu.rand_default(self.rng())
    x = rng((nbatch, ndim), dtype)
    mean = 5 * rng((nbatch, ndim), dtype)
    factor = rng((nbatch, ndim, 2 * ndim), dtype)
    cov = factor @ factor.transpose(0, 2, 1)

    result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
    result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
    self.assertArraysEqual(result1, result2, check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
        "_inshape={}_outsize={}_weights={}_method={}_func={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outsize, weights, method, func),
       "dtype": dtype,
       "inshape": inshape,
       "outsize": outsize,
       "weights": weights,
       "method": method,
       "func": func}
      for inshape in [(50,), (3, 50), (2, 12)]
      for dtype in jtu.dtypes.floating
      for outsize in [None, 10]
      for weights in [False, True]
      for method in [None, "scott", "silverman", 1.5, "callable"]
      for func in [None, "evaluate", "logpdf", "pdf"]))
  def testKde(self, inshape, dtype, outsize, weights, method, func):
    if method == "callable":
      method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))

    def scipy_fun(dataset, points, w):
      w = np.abs(w) if weights else None
      kde = osp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
      if func is None:
        result = kde(points)
      else:
        result = getattr(kde, func)(points)
      # Note: the scipy implementation _always_ returns float64
      return result.astype(dtype)

    def lax_fun(dataset, points, w):
      w = jax.numpy.abs(w) if weights else None
      kde = lsp_stats.gaussian_kde(dataset, bw_method=method, weights=w)
      if func is None:
        result = kde(points)
      else:
        result = getattr(kde, func)(points)
      return result

    if outsize is None:
      outshape = inshape
    else:
      outshape = inshape[:-1] + (outsize,)
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
    self._CheckAgainstNumpy(
        scipy_fun, lax_fun, args_maker, tol={
            np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3,
            np.float64: 1e-14
        })
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
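    # Minimal usage sketch (shapes only, assuming the scipy-style convention
    # that both implementations above follow): a dataset of d-dimensional
    # points is passed as a (d, n) array and evaluation points as (d, m),
    # yielding an (m,) array of density values, e.g.
    #
    #   kde = lsp_stats.gaussian_kde(rng((3, 50), np.float32))
    #   densities = kde(rng((3, 10), np.float32))   # shape (10,)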

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (3, 15), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeIntegrateGaussian(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_gaussian(mean, covariance).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.integrate_gaussian(mean, covariance)

    # Construct a random mean and positive definite covariance matrix
    rng = jtu.rand_default(self.rng())
    ndim = shape[0] if len(shape) > 1 else 1
    mean = rng(ndim, dtype)
    L = rng((ndim, ndim), dtype)
    L[np.triu_indices(ndim, 1)] = 0.0
    L[np.diag_indices(ndim)] = np.exp(np.diag(L)) + 0.01
    covariance = L @ L.T
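    # L is lower triangular with a strictly positive diagonal (exp(.) + 0.01),
    # so L @ L.T is symmetric positive definite by construction.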

    args_maker = lambda: [
      rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (12,)]
      for dtype in jtu.dtypes.floating))
  def testKdeIntegrateBox1d(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_box_1d(-0.5, 1.5).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.integrate_box_1d(-0.5, 1.5)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (3, 15), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeIntegrateKde(self, shape, dtype):
    def scipy_fun(dataset, weights):
      kde = osp_stats.gaussian_kde(dataset, weights=np.abs(weights))
      other = osp_stats.gaussian_kde(
        dataset[..., :-3] + 0.1, weights=np.abs(weights[:-3]))
      # Note: the scipy implementation _always_ returns float64
      return kde.integrate_kde(other).astype(dtype)

    def lax_fun(dataset, weights):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      other = lsp_stats.gaussian_kde(
        dataset[..., :-3] + 0.1, weights=jax.numpy.abs(weights[:-3]))
      return kde.integrate_kde(other)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      rng(shape, dtype), rng(shape[-1:], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(
        lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (3, 15), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeResampleShape(self, shape, dtype):
    def resample(key, dataset, weights, *, shape):
      kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
      return kde.resample(key, shape=shape)

    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [
      jax.random.PRNGKey(0), rng(shape, dtype), rng(shape[-1:], dtype)]

    ndim = shape[0] if len(shape) > 1 else 1

    args = args_maker()
    func = partial(resample, shape=())
    self._CompileAndCheck(
      func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
    result = func(*args)
    assert result.shape == (ndim,)

    func = partial(resample, shape=(4,))
    self._CompileAndCheck(
      func, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
    result = func(*args)
    assert result.shape == (ndim, 4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [shape], [dtype]),
       "dtype": dtype,
       "shape": shape}
      for shape in [(15,), (1, 12)]
      for dtype in jtu.dtypes.floating))
  def testKdeResample1d(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    dataset = rng(shape, dtype)
    weights = jax.numpy.abs(rng(shape[-1:], dtype))
    kde = lsp_stats.gaussian_kde(dataset, weights=weights)
    samples = jax.numpy.squeeze(kde.resample(jax.random.PRNGKey(5), shape=(1000,)))

    def cdf(x):
      result = jax.vmap(partial(kde.integrate_box_1d, -np.inf))(x)
      # Manually casting to numpy in order to avoid type promotion error
      return np.array(result)

    self.assertGreater(osp_stats.kstest(samples, cdf).pvalue, 0.01)
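    # Samples drawn from the KDE should follow the KDE's own distribution, so
    # a Kolmogorov-Smirnov test against its CDF should not reject at the 1% level.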

  def testKdePyTree(self):
    @jax.jit
    def evaluate_kde(kde, x):
      return kde.evaluate(x)

    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    dataset = rng((3, 15), dtype)
    x = rng((3, 12), dtype)
    kde = lsp_stats.gaussian_kde(dataset)
    leaves, treedef = tree_util.tree_flatten(kde)
    kde2 = tree_util.tree_unflatten(treedef, leaves)
    tree_util.tree_map(lambda a, b: self.assertAllClose(a, b), kde, kde2)
    self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
Example #27
0
class LaxBackedScipySignalTests(jtu.JaxTestCase):
    """Tests for LAX-backed scipy.stats implementations"""
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_op={}_xshape={}_yshape={}_mode={}".format(
                    op, jtu.format_shape_dtype_string(xshape, dtype),
                    jtu.format_shape_dtype_string(yshape, dtype), mode),
                "xshape":
                xshape,
                "yshape":
                yshape,
                "dtype":
                dtype,
                "mode":
                mode,
                "jsp_op":
                getattr(jsp_signal, op),
                "osp_op":
                getattr(osp_signal, op)
            } for mode in ['full', 'same', 'valid']
            for op in ['convolve', 'correlate'] for dtype in default_dtypes
            for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
            for xshape in shapeset for yshape in shapeset))
    def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {
            np.float16: 1e-2,
            np.float32: 1e-2,
            np.float64: 1e-12,
            np.complex64: 1e-2,
            np.complex128: 1e-12
        }
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
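        # Minimal usage sketch (hedged, not part of the parameterized test):
        # both libraries accept the same `mode` values, and the JAX version
        # additionally takes a matmul precision, e.g.
        #
        #   out = jsp_signal.convolve(x, y, mode='same',
        #                             precision=lax.Precision.HIGHEST)
        #   ref = osp_signal.convolve(x, y, mode='same')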

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "op={}_xshape={}_yshape={}_mode={}".format(
                op, jtu.format_shape_dtype_string(xshape, dtype),
                jtu.format_shape_dtype_string(yshape, dtype), mode),
            "xshape":
            xshape,
            "yshape":
            yshape,
            "dtype":
            dtype,
            "mode":
            mode,
            "jsp_op":
            getattr(jsp_signal, op),
            "osp_op":
            getattr(osp_signal, op)
        } for mode in ['full', 'same', 'valid']
                            for op in ['convolve2d', 'correlate2d']
                            for dtype in default_dtypes
                            for xshape in twodim_shapes
                            for yshape in twodim_shapes))
    def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
        osp_fun = partial(osp_op, mode=mode)
        jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
        tol = {
            np.float16: 1e-2,
            np.float32: 1e-2,
            np.float64: 1e-12,
            np.complex64: 1e-2,
            np.complex128: 1e-12
        }
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                check_dtypes=False,
                                tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_axis={}_type={}_bp={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "type":
            type,
            "bp":
            bp
        } for shape in [(5, ), (4, 5), (3, 4, 5)]
                            for dtype in jtu.dtypes.floating +
                            jtu.dtypes.integer for axis in [0, -1]
                            for type in ['constant', 'linear']
                            for bp in [0, [0, 2]]))
    def testDetrend(self, shape, dtype, axis, type, bp):
        signal = np.random.normal(loc=2, size=shape)

        if type == 'constant':
            trend = np.ones_like(signal)
        elif type == 'linear':
            trend = np.linspace(-np.pi, np.pi, shape[0])
            if len(shape) == 1:
                trend = np.broadcast_to(trend, shape)
            elif len(shape) == 2:
                trend = np.broadcast_to(trend[:, None], shape)
            elif len(shape) == 3:
                trend = np.broadcast_to(trend[:, None, None], shape)

        args_maker = lambda: [signal, trend]

        def osp_fun(signal, trend):
            return osp_signal.detrend(
                signal + trend, axis=axis, type=type, bp=bp) - trend

        def jsp_fun(signal, trend):
            return jsp_signal.detrend(
                signal + trend, axis=axis, type=type, bp=bp) - trend

        if jtu.device_under_test() == 'tpu':
            tol = {np.float32: 3e-2, np.float64: 1e-12}
        else:
            tol = {np.float32: 1e-5, np.float64: 1e-12}

        self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
            f"_fs={fs}_window={window}_boundary={boundary}_detrend={detrend}"
            f"_padded={padded}_nperseg={nperseg}_noverlap={noverlap}"
            f"_axis={timeaxis}_nfft={nfft}",
            "shape":
            shape,
            "dtype":
            dtype,
            "fs":
            fs,
            "window":
            window,
            "nperseg":
            nperseg,
            "noverlap":
            noverlap,
            "nfft":
            nfft,
            "detrend":
            detrend,
            "boundary":
            boundary,
            "padded":
            padded,
            "timeaxis":
            timeaxis
        } for shape, nperseg, noverlap, timeaxis in stft_test_shapes
                        for dtype in default_dtypes for fs in [1.0, 16000.0]
                        for window in
                        ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
                        for nfft in
                        [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
                        for detrend in ['constant', 'linear', False]
                        for boundary in [None, 'even', 'odd', 'zeros']
                        for padded in [True, False]))
    def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                             noverlap, nfft, detrend, boundary, padded,
                             timeaxis):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            return

        osp_fun = partial(osp_signal.stft,
                          fs=fs,
                          window=window,
                          nfft=nfft,
                          boundary=boundary,
                          padded=padded,
                          detrend=detrend,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          axis=timeaxis,
                          return_onesided=not is_complex)
        jsp_fun = partial(jsp_signal.stft,
                          fs=fs,
                          window=window,
                          nfft=nfft,
                          boundary=boundary,
                          padded=padded,
                          detrend=detrend,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          axis=timeaxis,
                          return_onesided=not is_complex)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    # Tests with `average == 'median'` are excluded from `testCsd*`
    # due to the issue:
    #   https://github.com/scipy/scipy/issues/15601
    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_xshape={jtu.format_shape_dtype_string(xshape, dtype)}"
                f"_yshape={jtu.format_shape_dtype_string(yshape, dtype)}"
                f"_average={average}_scaling={scaling}_nfft={nfft}"
                f"_fs={fs}_window={window}_detrend={detrend}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_axis={timeaxis}",
                "xshape":
                xshape,
                "yshape":
                yshape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            }
            for xshape, yshape, nperseg, noverlap, timeaxis in csd_test_shapes
            for dtype in default_dtypes for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for scaling in ['density', 'spectrum'] for average in ['mean']))
    def testCsdAgainstNumpy(self, *, xshape, yshape, dtype, fs, window,
                            nperseg, noverlap, nfft, detrend, scaling,
                            timeaxis, average):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            raise unittest.SkipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        osp_fun = partial(osp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        jsp_fun = partial(jsp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_average={average}_scaling={scaling}_nfft={nfft}"
                f"_fs={fs}_window={window}_detrend={detrend}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_axis={timeaxis}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            } for shape, unused_yshape, nperseg, noverlap, timeaxis in
            csd_test_shapes for dtype in default_dtypes
            for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for scaling in ['density', 'spectrum'] for average in ['mean']))
    def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window,
                                         nperseg, noverlap, nfft, detrend,
                                         scaling, timeaxis, average):
        is_complex = np.dtype(dtype).kind == 'c'
        if is_complex and detrend is not None:
            raise unittest.SkipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        def osp_fun(x, y):
            # When x and y are the identical array, the jsp version behaves as
            # if independent copies were passed, so copy y for the scipy reference.
            freqs, Pxy = osp_signal.csd(x,
                                        y.copy(),
                                        fs=fs,
                                        window=window,
                                        nperseg=nperseg,
                                        noverlap=noverlap,
                                        nfft=nfft,
                                        detrend=detrend,
                                        return_onesided=not is_complex,
                                        scaling=scaling,
                                        axis=timeaxis,
                                        average=average)
            return freqs, Pxy

        jsp_fun = partial(jsp_signal.csd,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=not is_complex,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)

        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)] * 2

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_fs={fs}_window={window}"
                f"_nperseg={nperseg}_noverlap={noverlap}_nfft={nfft}"
                f"_detrend={detrend}_return_onesided={return_onesided}"
                f"_scaling={scaling}_axis={timeaxis}_average={average}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "detrend":
                detrend,
                "return_onesided":
                return_onesided,
                "scaling":
                scaling,
                "timeaxis":
                timeaxis,
                "average":
                average
            } for shape, nperseg, noverlap, timeaxis in welch_test_shapes
            for dtype in default_dtypes for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for detrend in ['constant', 'linear', False]
            for return_onesided in [True, False]
            for scaling in ['density', 'spectrum']
            for average in ['mean', 'median']))
    def testWelchAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, detrend, return_onesided,
                              scaling, timeaxis, average):
        if np.dtype(dtype).kind == 'c':
            return_onesided = False
            if detrend is not None:
                raise unittest.SkipTest(
                    "Complex signal is not supported in lax-backed `signal.detrend`."
                )

        osp_fun = partial(osp_signal.welch,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=return_onesided,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        jsp_fun = partial(jsp_signal.welch,
                          fs=fs,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          detrend=detrend,
                          return_onesided=return_onesided,
                          scaling=scaling,
                          axis=timeaxis,
                          average=average)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
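        # Minimal usage sketch (hedged; the return convention follows scipy's
        # welch): both implementations return a (frequencies, power spectral
        # density) pair, e.g.
        #
        #   f_ref, Pxx_ref = osp_signal.welch(x, fs=1.0, nperseg=256)
        #   f_jax, Pxx_jax = jsp_signal.welch(x, fs=1.0, nperseg=256)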

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_nperseg={nperseg}_noverlap={noverlap}"
                f"_use_nperseg={use_nperseg}_use_overlap={use_noverlap}"
                f"_axis={timeaxis}",
                "shape":
                shape,
                "dtype":
                dtype,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "use_nperseg":
                use_nperseg,
                "use_noverlap":
                use_noverlap,
                "timeaxis":
                timeaxis
            } for shape, nperseg, noverlap, timeaxis in welch_test_shapes
            for use_nperseg in [False, True] for use_noverlap in [False, True]
            for dtype in jtu.dtypes.floating + jtu.dtypes.integer))
    def testWelchWithDefaultStepArgsAgainstNumpy(self, *, shape, dtype,
                                                 nperseg, noverlap,
                                                 use_nperseg, use_noverlap,
                                                 timeaxis):
        kwargs = {'axis': timeaxis}

        if use_nperseg:
            kwargs['nperseg'] = nperseg
        else:
            kwargs['window'] = osp_signal.get_window('hann', nperseg)
        if use_noverlap:
            kwargs['noverlap'] = noverlap

        osp_fun = partial(osp_signal.welch, **kwargs)
        jsp_fun = partial(jsp_signal.welch, **kwargs)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                f"_shape={jtu.format_shape_dtype_string(shape, dtype)}"
                f"_fs={fs}_window={window}_boundary={boundary}"
                f"_nperseg={nperseg}_noverlap={noverlap}_onesided={onesided}"
                f"_timeaxis={timeaxis}_freqaxis{freqaxis}_nfft={nfft}",
                "shape":
                shape,
                "dtype":
                dtype,
                "fs":
                fs,
                "window":
                window,
                "nperseg":
                nperseg,
                "noverlap":
                noverlap,
                "nfft":
                nfft,
                "onesided":
                onesided,
                "boundary":
                boundary,
                "timeaxis":
                timeaxis,
                "freqaxis":
                freqaxis
            } for shape, nperseg, noverlap, timeaxis, freqaxis in
            istft_test_shapes for dtype in default_dtypes
            for fs in [1.0, 16000.0]
            for window in ['boxcar', 'triang', 'blackman', 'hamming', 'hann']
            for nfft in [None, nperseg,
                         int(nperseg * 1.5), nperseg * 2]
            for onesided in [False, True] for boundary in [False, True]))
    def testIstftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                              noverlap, nfft, onesided, boundary, timeaxis,
                              freqaxis):
        if not onesided:
            new_freq_len = (shape[freqaxis] - 1) * 2
            shape = shape[:freqaxis] + (new_freq_len, ) + shape[freqaxis + 1:]

        def osp_fun(x, fs):
            # Ignore UserWarning in osp so we can also test over ill-posed cases.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                result = osp_signal.istft(x,
                                          fs=fs,
                                          window=window,
                                          nperseg=nperseg,
                                          noverlap=noverlap,
                                          nfft=nfft,
                                          input_onesided=onesided,
                                          boundary=boundary,
                                          time_axis=timeaxis,
                                          freq_axis=freqaxis)
            return result

        jsp_fun = partial(jsp_signal.istft,
                          window=window,
                          nperseg=nperseg,
                          noverlap=noverlap,
                          nfft=nfft,
                          input_onesided=onesided,
                          boundary=boundary,
                          time_axis=timeaxis,
                          freq_axis=freqaxis)

        tol = {
            np.float32: 1e-4,
            np.float64: 1e-6,
            np.complex64: 1e-4,
            np.complex128: 1e-6
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        rng_fs = jtu.rand_uniform(self.rng(), 1.0, 16000.0)
        args_maker = lambda: [rng(shape, dtype), rng_fs((), np.float64)]

        # The dtype of the output signal differs across scipy versions (and
        # hence across test environments), so the dtype check is disabled.
        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol,
                                check_dtypes=False)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
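        # Note (an assumption inherited from scipy's documentation, not checked
        # here): istft is an approximate inverse of stft only for window /
        # nperseg / noverlap combinations satisfying the constant-overlap-add
        # (COLA) constraint; this test only checks agreement with scipy.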
Example #28
0
class TestLineSearch(jtu.JaxTestCase):
    # -- scalar functions; must have dphi(0.) < 0

    def assert_wolfe(self, s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""):
        """
    Check that strong Wolfe conditions apply
    """
        phi1 = phi(s)
        phi0 = phi(0)
        derphi0 = derphi(0)
        derphi1 = derphi(s)
        msg = "s = %s; phi(0) = %s; phi(s) = %s; phi'(0) = %s; phi'(s) = %s; %s" % (
            s, phi0, phi1, derphi0, derphi1, err_msg)

        self.assertTrue(phi1 <= phi0 + c1 * s * derphi0,
                        "Wolfe 1 failed: " + msg)
        self.assertTrue(
            abs(derphi1) <= abs(c2 * derphi0), "Wolfe 2 failed: " + msg)
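        # The two assertions encode the strong Wolfe conditions: sufficient
        # decrease (Armijo), phi(s) <= phi(0) + c1 * s * phi'(0), and the
        # strong curvature condition, |phi'(s)| <= c2 * |phi'(0)|.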

    def assert_line_wolfe(self, x, p, s, f, fprime, **kw):
        self.assert_wolfe(s,
                          phi=lambda sp: f(x + p * sp),
                          derphi=lambda sp: jnp.dot(fprime(x + p * sp), p),
                          **kw)

    def _scalar_func_1(self, s):
        p = -s - s**3 + s**4
        dp = -1 - 3 * s**2 + 4 * s**3
        return p, dp

    def _scalar_func_2(self, s):
        p = jnp.exp(-4 * s) + s**2
        dp = -4 * jnp.exp(-4 * s) + 2 * s
        return p, dp

    def _scalar_func_3(self, s):
        p = -jnp.sin(10 * s)
        dp = -10 * jnp.cos(10 * s)
        return p, dp

    # -- n-d functions

    def _line_func_1(self, x):
        f = jnp.dot(x, x)
        df = 2 * x
        return f, df

    def _line_func_2(self, x):
        f = jnp.dot(x, jnp.dot(self.A, x)) + 1
        df = jnp.dot(self.A + self.A.T, x)
        return f, df

    # -- Generic scalar searches

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_name={}".format(name),
            "name": name
        } for name in ['_scalar_func_1', '_scalar_func_2', '_scalar_func_3']))
    def test_scalar_search_wolfe2(self, name):
        def bind_index(func, idx):
            # Remember Python's closure semantics!
            return lambda *a, **kw: func(*a, **kw)[idx]

        value = getattr(self, name)
        phi = bind_index(value, 0)
        derphi = bind_index(value, 1)
        for old_phi0 in self.rng().randn(3):
            res = line_search(phi, 0., 1.)
            s, phi1, derphi1 = res.a_k, res.f_k, res.g_k
            self.assertAllClose(phi1, phi(s), check_dtypes=False, atol=1e-6)
            if derphi1 is not None:
                self.assertAllClose(derphi1,
                                    derphi(s),
                                    check_dtypes=False,
                                    atol=1e-6)
            self.assert_wolfe(s,
                              phi,
                              derphi,
                              err_msg="%s %g" % (name, old_phi0))

    # -- Generic line searches

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_name={}".format(name),
            "name": name
        } for name in ['_line_func_1', '_line_func_2']))
    def test_line_search_wolfe2(self, name):
        def bind_index(func, idx):
            # Remember Python's closure semantics!
            return lambda *a, **kw: func(*a, **kw)[idx]

        value = getattr(self, name)
        f = bind_index(value, 0)
        fprime = bind_index(value, 1)

        k = 0
        N = 20
        rng = self.rng()
        # sets A in one of the line funcs
        self.A = self.rng().randn(N, N)
        while k < 9:
            x = rng.randn(N)
            p = rng.randn(N)
            if jnp.dot(p, fprime(x)) >= 0:
                # always pick a descent pk
                continue
            k += 1

            f0 = f(x)
            g0 = fprime(x)
            self.fcount = 0
            res = line_search(f, x, p, old_fval=f0, gfk=g0)
            s = res.a_k
            fv = res.f_k
            gv = res.g_k
            self.assertAllClose(fv,
                                f(x + s * p),
                                check_dtypes=False,
                                atol=1e-5)
            if gv is not None:
                self.assertAllClose(gv,
                                    fprime(x + s * p),
                                    check_dtypes=False,
                                    atol=1e-5)

    def test_line_search_wolfe2_bounds(self):
        # See gh-7475

        # For this f and p, starting at a point on axis 0, the strong Wolfe
        # condition 2 is met if and only if the step length s satisfies
        # |x + s| <= c2 * |x|
        f = lambda x: jnp.dot(x, x)
        fp = lambda x: 2 * x
        p = jnp.array([1, 0])

        # Smallest s satisfying strong Wolfe conditions for these arguments is 30
        x = -60 * p
        c2 = 0.5
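        # With f(x) = dot(x, x), x = -60 * p and p = [1, 0]: phi'(s) = 2 * (s - 60)
        # and phi'(0) = -120, so the curvature condition |phi'(s)| <= c2 * |phi'(0)|
        # becomes |s - 60| <= 30, i.e. 30 <= s <= 90.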

        res = line_search(f, x, p, c2=c2)
        s = res.a_k
        # s, _, _, _, _, _ = ls.line_search_wolfe2(f, fp, x, p, amax=30, c2=c2)
        self.assert_line_wolfe(x, p, s, f, fp)
        self.assertTrue(s >= 30.)

        res = line_search(f, x, p, c2=c2, maxiter=5)
        self.assertTrue(res.failed)
        # s=30 will only be tried on the 6th iteration, so this won't converge

    def test_line_search(self):
        def f(x):
            return jnp.cos(jnp.sum(jnp.exp(-x))**2)

        # assert not line_search(jax.value_and_grad(f), np.ones(2), np.array([-0.5, -0.25])).failed
        xk = jnp.ones(2)
        pk = jnp.array([-0.5, -0.25])
        res = line_search(f, xk, pk, maxiter=100)

        scipy_res = line_search_wolfe2(f, grad(f), xk, pk)

        self.assertAllClose(scipy_res[0],
                            res.a_k,
                            atol=1e-5,
                            check_dtypes=False)
        self.assertAllClose(scipy_res[3],
                            res.f_k,
                            atol=1e-5,
                            check_dtypes=False)
Example #29
0
class CustomObjectTest(jtu.JaxTestCase):
    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_compile={}_primitive={}".format(compile, primitive),
            "compile":
            compile,
            "primitive":
            primitive
        } for primitive in [True, False] for compile in [True, False]))
    def testSparseIdentity(self, compile, primitive):
        f = identity if primitive else (lambda x: x)
        f = jit(f) if compile else f
        rng = jtu.rand_default(self.rng())
        M = make_sparse_array(rng, (10, ), jnp.float32)
        M2 = f(M)

        jaxpr = make_jaxpr(f)(M).jaxpr
        core.check_jaxpr(jaxpr)

        self.assertEqual(M.dtype, M2.dtype)
        self.assertEqual(M.index_dtype, M2.index_dtype)
        self.assertAllClose(M.data, M2.data)
        self.assertAllClose(M.indices, M2.indices)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_compile={}".format(compile),
            "compile": compile
        } for compile in [True, False]))
    def testSparseSplit(self, compile):
        f = jit(split) if compile else split
        rng = jtu.rand_default(self.rng())
        M = make_sparse_array(rng, (10, ), jnp.float32)
        M2, M3 = f(M)

        jaxpr = make_jaxpr(f)(M).jaxpr
        core.check_jaxpr(jaxpr)

        for MM in M2, M3:
            self.assertEqual(M.dtype, MM.dtype)
            self.assertEqual(M.index_dtype, MM.index_dtype)
            self.assertArraysEqual(M.data, MM.data)
            self.assertArraysEqual(M.indices, MM.indices)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_compile={}_primitive={}".format(compile, primitive),
            "compile":
            compile,
            "primitive":
            primitive
        } for primitive in [True, False] for compile in [True, False]))
    def testSparseLaxLoop(self, compile, primitive):
        rng = jtu.rand_default(self.rng())
        f = identity if primitive else (lambda x: x)
        f = jit(f) if compile else f
        body_fun = lambda _, A: f(A)
        M = make_sparse_array(rng, (10, ), jnp.float32)
        lax.fori_loop(0, 10, body_fun, M)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": "_attr={}".format(attr),
            "attr": attr
        } for attr in ["data", "indices"]))
    def testSparseAttrAccess(self, attr):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [make_sparse_array(rng, (10, ), jnp.float32)]
        f = lambda x: getattr(x, attr)
        self._CompileAndCheck(f, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
            "shape":
            shape,
            "dtype":
            dtype
        } for shape in [(3, 3), (2, 6), (6, 2)]
                            for dtype in jtu.dtypes.floating))
    def testSparseMatvec(self, shape, dtype):
        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [
            make_sparse_array(rng, shape, dtype),
            rng(shape[-1:], dtype)
        ]
        self._CompileAndCheck(matvec, args_maker)

    def testLowerToNothing(self):
        empty = Empty(AbstractEmpty())
        jaxpr = make_jaxpr(jit(lambda e: e))(empty).jaxpr
        core.check_jaxpr(jaxpr)

        # cannot return a unit, because CompileAndCheck assumes array output.
        testfunc = lambda e: None
        args_maker = lambda: [empty]
        self._CompileAndCheck(testfunc, args_maker)

    def testConstantHandler(self):
        def make_const_array():
            data = np.arange(3.0)
            indices = np.arange(3)[:, None]
            shape = (5, )
            aval = AbstractSparseArray(shape, data.dtype, indices.dtype,
                                       len(indices))
            return SparseArray(aval, data, indices)

        out1 = make_const_array()
        out2 = jit(make_const_array)()
        self.assertArraysEqual(out1.data, out2.data)
        self.assertArraysEqual(out1.indices, out2.indices)
Example #30
0
class SparsifyTest(jtu.JaxTestCase):
    @classmethod
    def sparsify(cls, f):
        return sparsify(f, use_tracer=False)

    def testNotImplementedMessages(self):
        x = BCOO.fromdense(jnp.arange(5.0))
        # Test a densifying primitive
        with self.assertRaisesRegex(
                NotImplementedError,
                r"^sparse rule for cos is not implemented because it would result in dense output\."
        ):
            self.sparsify(lax.cos)(x)

        # Test a generic not implemented primitive.
        with self.assertRaisesRegex(
                NotImplementedError,
                r"^sparse rule for complex is not implemented\.$"):
            self.sparsify(lax.complex)(x, x)

    def testTracerIsInstanceCheck(self):
        @self.sparsify
        def f(x):
            self.assertNotIsInstance(x, SparseTracer)

        f(jnp.arange(5))

    def assertBcooIdentical(self, x, y):
        self.assertIsInstance(x, BCOO)
        self.assertIsInstance(y, BCOO)
        self.assertEqual(x.shape, y.shape)
        self.assertArraysEqual(x.data, y.data)
        self.assertArraysEqual(x.indices, y.indices)

    def testSparsifyValue(self):
        X = jnp.arange(5)
        X_BCOO = BCOO.fromdense(X)

        args = (X, X_BCOO, X_BCOO)

        # Independent index
        spenv = SparsifyEnv()
        spvalues = arrays_to_spvalues(spenv, args)
        self.assertEqual(len(spvalues), len(args))
        self.assertLen(spenv._buffers, 5)
        self.assertEqual(
            spvalues,
            (SparsifyValue(X.shape, 0, None), SparsifyValue(
                X.shape, 1, 2), SparsifyValue(X.shape, 3, 4)))

        args_out = spvalues_to_arrays(spenv, spvalues)
        self.assertEqual(len(args_out), len(args))
        self.assertArraysEqual(args[0], args_out[0])
        self.assertBcooIdentical(args[1], args_out[1])
        self.assertBcooIdentical(args[2], args_out[2])

        # Shared index
        spvalues = (SparsifyValue(X.shape, 0, None),
                    SparsifyValue(X.shape, 1, 2), SparsifyValue(X.shape, 3, 2))
        spenv = SparsifyEnv([X, X_BCOO.data, X_BCOO.indices, X_BCOO.data])

        args_out = spvalues_to_arrays(spenv, spvalues)
        self.assertEqual(len(args_out), len(args))
        self.assertArraysEqual(args[0], args_out[0])
        self.assertBcooIdentical(args[1], args_out[1])
        self.assertBcooIdentical(args[2], args_out[2])

    def testDropvar(self):
        def inner(x):
            return x * 2, x * 3

        def f(x):
            _, y = jit(inner)(x)
            return y * 4

        x_dense = jnp.arange(5)
        x_sparse = BCOO.fromdense(x_dense)
        self.assertArraysEqual(
            self.sparsify(f)(x_sparse).todense(), f(x_dense))

    def testPytreeInput(self):
        f = self.sparsify(lambda x: x)
        args = (jnp.arange(4), BCOO.fromdense(jnp.arange(4)))
        out = f(args)
        self.assertLen(out, 2)
        self.assertArraysEqual(args[0], out[0])
        self.assertBcooIdentical(args[1], out[1])

    def testSparsify(self):
        M_dense = jnp.arange(24).reshape(4, 6)
        M_sparse = BCOO.fromdense(M_dense)
        v = jnp.arange(M_dense.shape[0])

        @self.sparsify
        def func(x, v):
            return -jnp.sin(jnp.pi * x).T @ (v + 1)

        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message="bcoo_dot_general GPU lowering requires matrices "
                        "with sorted indices*"):
            result_sparse = func(M_sparse, v)
        result_dense = func(M_dense, v)
        self.assertAllClose(result_sparse, result_dense)

    def testSparsifyWithConsts(self):
        M_dense = jnp.arange(24).reshape(4, 6)
        M_sparse = BCOO.fromdense(M_dense)

        @self.sparsify
        def func(x):
            return jit(lambda x: jnp.sum(x, 1))(x)

        result_dense = func(M_dense)
        result_sparse = func(M_sparse)

        self.assertAllClose(result_sparse.todense(), result_dense)

    def testSparseMatmul(self):
        X = jnp.arange(16).reshape(4, 4)
        Xsp = BCOO.fromdense(X)
        Y = jnp.ones(4)
        Ysp = BCOO.fromdense(Y)

        func = self.sparsify(operator.matmul)

        # dot_general
        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message="bcoo_dot_general GPU lowering requires matrices "
                        "with sorted indices*"):
            result_sparse = func(Xsp, Y)
        result_dense = operator.matmul(X, Y)
        self.assertAllClose(result_sparse, result_dense)

        # rdot_general
        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message="bcoo_dot_general GPU lowering requires matrices "
                        "with sorted indices*"):
            result_sparse = func(Y, Xsp)
        result_dense = operator.matmul(Y, X)
        self.assertAllClose(result_sparse, result_dense)

        # spdot_general
        result_sparse = self.sparsify(operator.matmul)(Xsp, Ysp)
        result_dense = operator.matmul(X, Y)
        self.assertAllClose(result_sparse.todense(), result_dense)

    def testSparseAdd(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Distinct indices
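        # x and y each store 4 nonzero entries, so concatenating their index
        # sets yields nse = 4 + 4 = 8.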
        out = self.sparsify(operator.add)(x, y)
        self.assertEqual(out.nse, 8)  # uses concatenation.
        self.assertArraysEqual(out.todense(), 3 * jnp.arange(5))

        # Shared indices – requires lower level call
        spenv = SparsifyEnv([x.indices, x.data, y.data])
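        # Both spvalues below reference indices buffer 0, so the addition rule
        # can act on a shared sparsity pattern instead of concatenating.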
        spvalues = [
            spenv.sparse(x.shape, data_ref=1, indices_ref=0),
            spenv.sparse(y.shape, data_ref=2, indices_ref=0)
        ]
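        # sparsify_raw operates directly on the environment and spvalues; the
        # second element of its result (unused here) is presumably the output
        # tree structure.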

        result = sparsify_raw(operator.add)(spenv, *spvalues)
        args_out, _ = result
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() + y.todense())

    def testSparseMul(self):
        x = BCOO.fromdense(jnp.arange(5))
        y = BCOO.fromdense(2 * jnp.arange(5))

        # Scalar multiplication
        out = self.sparsify(operator.mul)(x, 2.5)
        self.assertArraysEqual(out.todense(), x.todense() * 2.5)

        # Shared indices – requires lower level call
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        spvalues = [
            spenv.sparse(x.shape, data_ref=1, indices_ref=0),
            spenv.sparse(y.shape, data_ref=2, indices_ref=0)
        ]

        result = sparsify_raw(operator.mul)(spenv, *spvalues)
        args_out, _ = result
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() * y.todense())

    def testSparseSubtract(self):
        x = BCOO.fromdense(3 * jnp.arange(5))
        y = BCOO.fromdense(jnp.arange(5))

        # Distinct indices
        out = self.sparsify(operator.sub)(x, y)
        self.assertEqual(out.nse, 8)  # uses concatenation.
        self.assertArraysEqual(out.todense(), 2 * jnp.arange(5))

        # Shared indices – requires lower level call
        spenv = SparsifyEnv([x.indices, x.data, y.data])
        spvalues = [
            spenv.sparse(x.shape, data_ref=1, indices_ref=0),
            spenv.sparse(y.shape, data_ref=2, indices_ref=0)
        ]

        result = sparsify_raw(operator.sub)(spenv, *spvalues)
        args_out, _ = result
        out, = spvalues_to_arrays(spenv, args_out)

        self.assertAllClose(out.todense(), x.todense() - y.todense())

    def testSparseSum(self):
        x = jnp.arange(20).reshape(4, 5)
        xsp = BCOO.fromdense(x)

        def f(x):
            return x.sum(), x.sum(0), x.sum(1), x.sum((0, 1))

        result_dense = f(x)
        result_sparse = self.sparsify(f)(xsp)

        self.assertEqual(len(result_dense), len(result_sparse))

        for res_dense, res_sparse in zip(result_dense, result_sparse):
            if isinstance(res_sparse, BCOO):
                res_sparse = res_sparse.todense()
            self.assertArraysAllClose(res_dense, res_sparse)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_shape={}_dimensions={}_nbatch={}_ndense={}".format(
                jtu.format_shape_dtype_string(shape, np.float32),
                dimensions, n_batch, n_dense),
             "shape": shape,
             "dimensions": dimensions,
             "n_batch": n_batch,
             "n_dense": n_dense}
            for shape, dimensions in [
                [(1,), (0,)],
                [(1,), (-1,)],
                [(2, 1, 4), (1,)],
                [(2, 1, 3, 1), (1,)],
                [(2, 1, 3, 1), (1, 3)],
                [(2, 1, 3, 1), (3,)],
            ]
            for n_batch in range(len(shape) + 1)
            for n_dense in range(len(shape) - n_batch + 1)))
    def testSparseSqueeze(self, shape, dimensions, n_batch, n_dense):
        rng = jtu.rand_default(self.rng())

        M_dense = rng(shape, np.float32)
        M_sparse = BCOO.fromdense(M_dense, n_batch=n_batch, n_dense=n_dense)
        func = self.sparsify(partial(lax.squeeze, dimensions=dimensions))

        result_dense = func(M_dense)
        result_sparse = func(M_sparse).todense()

        self.assertAllClose(result_sparse, result_dense)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name": f"_shapes={shapes}_func={func}_nbatch={n_batch}",
            "shapes": shapes,
            "func": func,
            "n_batch": n_batch
        } for shapes, func, n_batch in [
            ([(4, ), (4, )], "concatenate", 0),
            ([(4, ), (4, )], "stack", 0),
            ([(4, ), (4, )], "hstack", 0),
            ([(4, ), (4, )], "vstack", 0),
            ([(4, ), (4, )], "concatenate", 1),
            ([(4, ), (4, )], "stack", 1),
            ([(4, ), (4, )], "hstack", 1),
            ([(4, ), (4, )], "vstack", 1),
            ([(2, 4), (2, 4)], "stack", 0),
            ([(2, 4), (3, 4)], "vstack", 0),
            ([(2, 4), (2, 5)], "hstack", 0),
            ([(2, 4), (3, 4)], "vstack", 1),
            ([(2, 4), (2, 5)], "hstack", 1),
            ([(2, 4), (3, 4)], "vstack", 2),
            ([(2, 4), (2, 5)], "hstack", 2),
            ([(2, 4), (4, ), (3, 4)], "vstack", 0),
            ([(1, 4), (4, ), (1, 4)], "vstack", 0),
        ]))
    def testSparseConcatenate(self, shapes, func, n_batch):
        f = self.sparsify(getattr(jnp, func))
        rng = jtu.rand_some_zero(self.rng())
        arrs = [rng(shape, 'int32') for shape in shapes]
        sparrs = [BCOO.fromdense(arr, n_batch=n_batch) for arr in arrs]
        self.assertArraysEqual(f(arrs), f(sparrs).todense())

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}",
            "shape": shape,
            "new_shape": new_shape,
            "n_batch": n_batch,
            "n_dense": n_dense
        } for shape, new_shape, n_batch, n_dense in [
            [(6, ), (2, 3), 0, 0],
            [(1, 4), (2, 2), 0, 0],
            [(12, 2), (2, 3, 4), 0, 0],
            [(1, 3, 2), (2, 3), 0, 0],
            [(1, 6), (2, 3, 1), 0, 0],
            [(2, 3, 4), (3, 8), 0, 0],
            [(2, 3, 4), (1, 2, 12), 1, 0],
            [(2, 3, 4), (6, 2, 2), 2, 0],
        ]))
    def testSparseReshapeMethod(self, shape, new_shape, n_batch, n_dense):
        rng = jtu.rand_some_zero(self.rng())
        arr = rng(shape, 'int32')
        arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

        arr2 = arr.reshape(new_shape)
        arr2_sparse = arr_sparse.reshape(new_shape)

        self.assertArraysEqual(arr2, arr2_sparse.todense())

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_{shape}->{new_shape}_n_batch={n_batch}_n_dense={n_dense}_dimensions={dimensions}",
            "shape": shape,
            "new_shape": new_shape,
            "n_batch": n_batch,
            "n_dense": n_dense,
            "dimensions": dimensions
        } for shape, new_shape, n_batch, n_dense, dimensions in [
            [(2, 3, 4), (24, ), 0, 0, None],
            [(2, 3, 4), (24, ), 0, 0, (0, 1, 2)],
            [(2, 3, 4), (24, ), 0, 0, (0, 2, 1)],
            [(2, 3, 4), (24, ), 0, 0, (1, 0, 2)],
            [(2, 3, 4), (24, ), 0, 0, (1, 2, 0)],
            [(2, 3, 4), (24, ), 0, 0, (2, 0, 1)],
            [(2, 3, 4), (24, ), 0, 0, (2, 1, 0)],
            [(4, 2, 3), (2, 2, 6), 1, 0, (0, 1, 2)],
            [(4, 2, 3), (2, 2, 6), 1, 0, (0, 2, 1)],
            [(2, 3, 4), (6, 4), 2, 0, (0, 1, 2)],
            [(2, 3, 4), (6, 4), 2, 0, (1, 0, 2)],
        ]))
    def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch,
                                        n_dense, dimensions):
        rng = jtu.rand_some_zero(self.rng())
        arr = rng(shape, 'int32')
        arr_sparse = BCOO.fromdense(arr, n_batch=n_batch, n_dense=n_dense)

        f = self.sparsify(
            lambda x: lax.reshape(x, new_shape, dimensions=dimensions))

        arr2 = f(arr)
        arr2_sparse = f(arr_sparse)

        self.assertArraysEqual(arr2, arr2_sparse.todense())

    def testSparseWhileLoop(self):
        def cond_fun(params):
            i, A = params
            return i < 5

        def body_fun(params):
            i, A = params
            return i + 1, 2 * A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A))

        A = jnp.arange(4)
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 2)
        self.assertEqual(len(out_sparse), 2)
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())

    def testSparseWhileLoopDuplicateIndices(self):
        def cond_fun(params):
            i, A, B = params
            return i < 5

        def body_fun(params):
            i, A, B = params
            # TODO(jakevdp): track shared indices through while loop & use this
            #   version of the test, which requires shared indices in order for
            #   the nse of the result to remain the same.
            # return i + 1, A, A + B

            # This version is fine without shared indices, and tests that we're
            # flattening non-shared indices consistently.
            return i + 1, B, A

        def f(A):
            return lax.while_loop(cond_fun, body_fun, (0, A, A))

        A = jnp.arange(4).reshape((2, 2))
        out_dense = f(A)

        Asp = BCOO.fromdense(A)
        out_sparse = self.sparsify(f)(Asp)

        self.assertEqual(len(out_dense), 3)
        self.assertEqual(len(out_sparse), 3)
        self.assertArraysEqual(out_dense[0], out_sparse[0])
        self.assertArraysEqual(out_dense[1], out_sparse[1].todense())
        self.assertArraysEqual(out_dense[2], out_sparse[2].todense())

    def testSparsifyDenseXlaCall(self):
        # Test handling of dense xla_call within jaxpr interpreter.
        out = self.sparsify(jit(lambda x: x + 1))(0.0)
        self.assertEqual(out, 1.0)

    def testSparsifySparseXlaCall(self):
        # Test sparse lowering of XLA call
        def func(M):
            return 2 * M

        M = jnp.arange(6).reshape(2, 3)
        Msp = BCOO.fromdense(M)

        out_dense = func(M)
        out_sparse = self.sparsify(jit(func))(Msp)
        self.assertArraysEqual(out_dense, out_sparse.todense())

    def testSparseForiLoop(self):
        def func(M, x):
            body_fun = lambda i, val: (M @ val) / M.shape[1]
            return lax.fori_loop(0, 2, body_fun, x)

        x = jnp.arange(5.0)
        M = jnp.arange(25).reshape(5, 5)
        M_bcoo = BCOO.fromdense(M)

        result_dense = func(M, x)
        with jtu.ignore_warning(
                category=CuSparseEfficiencyWarning,
                message="bcoo_dot_general GPU lowering requires matrices "
                        "with sorted indices*"):
            result_sparse = self.sparsify(func)(M_bcoo, x)

        self.assertArraysAllClose(result_dense, result_sparse)

    def testSparseCondSimple(self):
        def func(x):
            return lax.cond(False, lambda x: x, lambda x: 2 * x, x)

        x = jnp.arange(5.0)
        result_dense = func(x)

        x_bcoo = BCOO.fromdense(x)
        result_sparse = self.sparsify(func)(x_bcoo)

        self.assertArraysAllClose(result_dense, result_sparse.todense())

    def testSparseCondMismatchError(self):
        @self.sparsify
        def func(x, y):
            return lax.cond(False, lambda x: x[0], lambda x: x[1], (x, y))

        x = jnp.arange(5.0)
        y = jnp.arange(5.0)

        x_bcoo = BCOO.fromdense(x)
        y_bcoo = BCOO.fromdense(y)

        func(x, y)  # No error
        func(x_bcoo, y_bcoo)  # No error

        with self.assertRaisesRegex(
                TypeError, "sparsified true_fun and false_fun output.*"):
            func(x_bcoo, y)

    def testToDense(self):
        M = jnp.arange(4)
        Msp = BCOO.fromdense(M)

        @self.sparsify
        def func(M):
            return todense(M) + 1

        self.assertArraysEqual(func(M), M + 1)
        self.assertArraysEqual(func(Msp), M + 1)
        self.assertArraysEqual(jit(func)(M), M + 1)
        self.assertArraysEqual(jit(func)(Msp), M + 1)

    def testWeakTypes(self):
        # Regression test for https://github.com/google/jax/issues/8267
        M = jnp.arange(12, dtype='int32').reshape(3, 4)
        Msp = BCOO.fromdense(M)
        self.assertArraysEqual(
            operator.mul(2, M),
            self.sparsify(operator.mul)(2, Msp).todense(),
            check_dtypes=True,
        )