Example #1
  # Defined on a jtu.JaxTestCase subclass; scipy.special is imported as
  # osp_special and jax.scipy.special as lsp_special (cf. Example #2 below).
  def testLogSumExp(self, shapes, dtype, axis,
                    keepdims, return_sign, use_b):
    # TODO(b/133842870): use inf/nan inputs on CPU too once exp(nan)
    # returns NaN there (same issue noted in Example #2).
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)
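
For orientation, a minimal sketch of the logsumexp semantics the test exercises
(SciPy only; b and return_sign are the same parameters passed above):

import numpy as np
import scipy.special as osp_special

x = np.array([1000.0, 999.0])
# Naive np.log(np.sum(np.exp(x))) overflows; logsumexp computes it stably as
# max(x) + log(sum(exp(x - max(x)))).
print(osp_special.logsumexp(x))  # ~1000.3133
# With weights b, log|sum(b * exp(x))| is computed; return_sign=True also
# returns the sign of the (possibly negative) weighted sum.
val, sign = osp_special.logsumexp(x, b=np.array([1.0, -1.0]), return_sign=True)
print(val, sign)  # ~999.5413, 1.0
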
Example #2
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
                 jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
             # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
             "rng": jtu.rand_some_inf_and_nan()
                    if jtu.device_under_test() != "cpu" else jtu.rand_default(),
             "shape": shape, "dtype": dtype,
             "axis": axis, "keepdims": keepdims}
            for shape in all_shapes for dtype in float_dtypes
            for axis in range(-len(shape), len(shape))
            for keepdims in [False, True]))
    @jtu.skip_on_flag("jax_xla_backend", "xrt")
    def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
        # TODO(mattjj): test autodiff
        def scipy_fun(array_to_reduce):
            return osp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        def lax_fun(array_to_reduce):
            return lsp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
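
    # Note: the decorator above expands to one named test per (shape, dtype,
    # axis, keepdims) combination; jtu.cases_from_list materializes the dicts
    # that parameterized.named_parameters consumes.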

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {"testcase_name": jtu.format_test_name_suffix(
                     rec.test_name, shapes, dtypes),
                 "rng": rec.rng,
                 "shapes": shapes, "dtypes": dtypes,
                 "test_autodiff": rec.test_autodiff,
                 "scipy_op": getattr(osp_special, rec.name),
                 "lax_op": getattr(lsp_special, rec.name)}
                for shapes in CombosWithReplacement(all_shapes, rec.nargs)
                for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                            test_autodiff):
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

        if test_autodiff:
            jtu.check_grads(lax_op,
                            args,
                            order=1,
                            atol=1e-3,
                            rtol=3e-3,
                            eps=1e-3)
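
    # jtu.check_grads above compares first-order forward- and reverse-mode
    # autodiff of lax_op against finite differences with step eps, within the
    # given atol/rtol.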

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {"testcase_name": "_inshape={}_d={}".format(
                 jtu.format_shape_dtype_string(shape, dtype), d),
             "rng": jtu.rand_positive(),
             "shape": shape, "dtype": dtype, "d": d}
            for shape in all_shapes for dtype in float_dtypes
            for d in [1, 2, 5]))
    def testMultigammaln(self, rng, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
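
    # args_maker shifts positive samples by (d - 1)/2 to stay inside
    # multigammaln's domain, a > (d - 1)/2; the function computes
    #   log Gamma_d(a) = d*(d-1)/4 * log(pi) + sum_{j=1..d} gammaln(a + (1-j)/2).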

    def testIssue980(self):
        # Regression test for JAX issue #980: expit should saturate to
        # exactly 0 for very negative inputs.
        x = onp.full((4,), -1e20, dtype=onp.float32)
        self.assertAllClose(onp.zeros((4, ), dtype=onp.float32),
                            lsp_special.expit(x),
                            check_dtypes=True)
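
A minimal sketch of the saturation behavior this regression test pins down,
assuming only NumPy (a stable sigmoid never exponentiates a positive value,
so expit(-1e20) is exactly 0 rather than NaN or an overflow):

import numpy as np

def expit_stable(x):
    # Branch on the sign so np.exp only ever sees non-positive arguments.
    pos = x >= 0
    z = np.exp(np.where(pos, -x, x))
    return np.where(pos, 1.0 / (1.0 + z), z / (1.0 + z))

print(expit_stable(np.full((4,), -1e20, dtype=np.float32)))  # [0. 0. 0. 0.]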