Code example #1
File: scipy_stats_test.py  Project: npapernot/jax
def genNamedParametersNArgs(n, rng):
    return parameterized.named_parameters(
        jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
            "rng": rng, "shapes": shapes, "dtypes": dtypes}
          for shapes in CombosWithReplacement(all_shapes, n)
          for dtypes in CombosWithReplacement(float_dtypes, n)))
Code example #2
File: scipy_stats_test.py  Project: zhaowilliam/jax
def genNamedParametersNArgs(n):
    return parameterized.named_parameters(
        jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
            "shapes": shapes, "dtypes": dtypes}
          for shapes in itertools.combinations_with_replacement(all_shapes, n)
          for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
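Both variants above build a parameterized.named_parameters decorator from every combination (with replacement) of n shapes and n dtypes; the second drops the shared rng argument and reads the float dtypes from jtu.dtypes.floating instead of a module-level list. As a rough, self-contained sketch of how such a factory is applied to a test method, the following uses only absl and itertools; all_shapes, float_dtypes, and _format_case here are illustrative stand-ins for the jtu fixtures, not the real JAX test harness.

import itertools

from absl.testing import absltest, parameterized

# Illustrative stand-ins for the jtu-provided fixtures (not JAX's own values).
all_shapes = [(), (4,), (3, 4)]
float_dtypes = ["float32", "float64"]


def _format_case(shapes, dtypes):
  # Rough analogue of jtu.format_test_name_suffix (hypothetical helper).
  return "_" + "_".join(
      "{}_{}".format(dtype, "x".join(str(d) for d in shape) or "scalar")
      for shape, dtype in zip(shapes, dtypes))


def gen_named_parameters_n_args(n):
  # Same idea as genNamedParametersNArgs: one named case per combination
  # (with replacement) of n shapes and n dtypes.
  return parameterized.named_parameters(
      [{"testcase_name": _format_case(shapes, dtypes),
        "shapes": shapes, "dtypes": dtypes}
       for shapes in itertools.combinations_with_replacement(all_shapes, n)
       for dtypes in itertools.combinations_with_replacement(float_dtypes, n)])


class ExampleTest(parameterized.TestCase):

  @gen_named_parameters_n_args(2)
  def test_combination_is_passed_through(self, shapes, dtypes):
    self.assertEqual(len(shapes), 2)
    self.assertEqual(len(dtypes), 2)


if __name__ == "__main__":
  absltest.main()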
Code example #3
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.test_name, shapes, dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "test_autodiff": rec.test_autodiff,
       "scipy_op": getattr(osp_special, rec.name),
       "lax_op": getattr(lsp_special, rec.name)}
      for rec in JAX_SPECIAL_FUNCTION_RECORDS
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    # TODO(mattjj): unskip this test combination when real() on tpu is improved
    if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
        and not shapes[0]):
      raise absltest.unittest.SkipTest("real() on scalar not supported on tpu")

    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3)
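The class above follows the usual jtu flow: an args_maker builds random inputs, _CheckAgainstNumpy compares the LAX-backed function with the SciPy reference, and _CompileAndCheck runs the jit-compiled version on the same inputs. Below is a minimal stand-alone sketch of the same logsumexp comparison, written against the public jax.scipy.special and jax.jit APIs rather than the test harness; the shape, axis, and tolerances are illustrative choices, not the test's own.

import numpy as np
import scipy.special as osp_special

import jax
import jax.numpy as jnp
from jax.scipy import special as jsp_special


def check_logsumexp_against_scipy(shape=(3, 4), axis=1, keepdims=False, seed=0):
  """Compare jax.scipy.special.logsumexp with SciPy, eagerly and under jit."""
  x = np.random.RandomState(seed).randn(*shape).astype(np.float32)

  expected = osp_special.logsumexp(x, axis=axis, keepdims=keepdims)
  actual = jsp_special.logsumexp(jnp.asarray(x), axis=axis, keepdims=keepdims)
  np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

  # The compiled version should agree with the reference as well.
  compiled = jax.jit(
      lambda a: jsp_special.logsumexp(a, axis=axis, keepdims=keepdims))
  np.testing.assert_allclose(compiled(x), expected, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
  check_logsumexp_against_scipy()
  print("logsumexp matches SciPy")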
Code example #4
File: lax_vmap_test.py  Project: x1489/jax
class LaxVmapTest(jtu.JaxTestCase):
    def _CheckBatching(self,
                       op,
                       bdim_size,
                       bdims,
                       shapes,
                       dtypes,
                       rng,
                       rtol=None,
                       atol=None):
        batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
        args = [
            rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)
        ]
        args_slice = args_slicer(args, bdims)
        ans = api.vmap(op, bdims)(*args)
        if bdim_size == 0:
            args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
            out = op(*args)
            expected = np.zeros((0, ) + out.shape, out.dtype)
        else:
            expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
        self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    "{}_bdims={}".format(
                        jtu.format_test_name_suffix(
                            rec.op, shapes, itertools.repeat(dtype)), bdims),
                    "op_name":
                    rec.op,
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtype":
                    dtype,
                    "bdims":
                    bdims,
                    "tol":
                    rec.tol
                } for shape_group in compatible_shapes
                for shapes in itertools.combinations_with_replacement(
                    shape_group, rec.nargs) for bdims in all_bdims(*shapes)
                for dtype in rec.dtypes) for rec in LAX_OPS))
    def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
        rng = rng_factory(self.rng())
        op = getattr(lax, op_name)
        self._CheckBatching(op,
                            10,
                            bdims,
                            shapes, [dtype] * len(shapes),
                            rng,
                            atol=tol,
                            rtol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
            "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
            "_lhs_bdim={}_rhs_bdim={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), strides,
                padding, lhs_dil, rhs_dil, ",".join(dim_nums),
                feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "strides":
            strides,
            "padding":
            padding,
            "lhs_dil":
            lhs_dil,
            "rhs_dil":
            rhs_dil,
            "dimension_numbers":
            dim_nums,
            "perms":
            perms,
            "lhs_bdim":
            lhs_bdim,
            "rhs_bdim":
            rhs_bdim,
            "feature_group_count":
            feature_group_count,
            "batch_group_count":
            batch_group_count,
        } for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
                            for lhs_shape, rhs_shape, all_strides, all_pads,
                            lhs_dils, rhs_dils in [
                                (
                                    (b * batch_group_count,
                                     i * feature_group_count, 6,
                                     7),  # lhs_shape
                                    (j * batch_group_count *
                                     feature_group_count, i, 1,
                                     2),  # rhs_shape
                                    [(1, 1), (1, 2), (2, 1)],  # strides
                                    [((0, 0), (0, 0)), ((1, 0), (0, 1)),
                                     ((0, -1), (0, 0))],  # pads
                                    [(1, 1), (2, 1)],  # lhs_dils
                                    [(1, 1), (2, 2)])  # rhs_dils
                                for b, i, j in itertools.product([1, 2],
                                                                 repeat=3)
                            ] for strides in all_strides
                            for rhs_dil in rhs_dils for lhs_dil in lhs_dils
                            for dtype in [np.float32] for padding in all_pads
                            for dim_nums, perms in [(("NCHW", "OIHW", "NCHW"),
                                                     ([0, 1, 2, 3],
                                                      [0, 1, 2, 3])),
                                                    (("NHWC", "HWIO", "NHWC"),
                                                     ([0, 2, 3, 1],
                                                      [2, 3, 1, 0])),
                                                    (("NHWC", "OIHW", "NCHW"),
                                                     ([0, 2, 3, 1],
                                                      [0, 1, 2, 3]))]
                            for lhs_bdim in itertools.chain(
                                [cast(Optional[int], None)],
                                range(len(lhs_shape) + 1))
                            for rhs_bdim in itertools.chain(
                                [cast(Optional[int], None)],
                                range(len(rhs_shape) + 1))
                            if (lhs_bdim, rhs_bdim) != (None, None)))
    def testConvGeneralDilatedBatching(self, lhs_shape, rhs_shape, dtype,
                                       strides, padding, lhs_dil, rhs_dil,
                                       dimension_numbers, perms,
                                       feature_group_count, batch_group_count,
                                       lhs_bdim, rhs_bdim):
        rng = jtu.rand_default(self.rng())
        tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

        # permute shapes to match dim_spec, scale by feature_group_count
        lhs_perm, rhs_perm = perms
        lhs_shape = list(np.take(lhs_shape, lhs_perm))
        rhs_shape = list(np.take(rhs_shape, rhs_perm))

        conv = partial(lax.conv_general_dilated,
                       window_strides=strides,
                       padding=padding,
                       lhs_dilation=lhs_dil,
                       rhs_dilation=rhs_dil,
                       dimension_numbers=dimension_numbers,
                       feature_group_count=feature_group_count,
                       batch_group_count=batch_group_count,
                       precision=lax.Precision.HIGHEST)
        self._CheckBatching(conv,
                            5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                            (dtype, dtype),
                            rng,
                            rtol=tol,
                            atol=tol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
                shape, from_dtype, to_dtype, bdims),
            "shape":
            shape,
            "from_dtype":
            from_dtype,
            "to_dtype":
            to_dtype,
            "bdims":
            bdims
        } for from_dtype, to_dtype in itertools.product(
            [np.float32, np.int32, "float32", "int32"], repeat=2)
                            for shape in [(2, 3)]
                            for bdims in all_bdims(shape)))
    def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.convert_element_type(x, to_dtype)
        self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
                shape, from_dtype, to_dtype, bdims),
            "shape":
            shape,
            "from_dtype":
            from_dtype,
            "to_dtype":
            to_dtype,
            "bdims":
            bdims
        } for from_dtype, to_dtype in itertools.product(
            [np.float32, np.int32, "float32", "int32"], repeat=2)
                            for shape in [(2, 3)]
                            for bdims in all_bdims(shape)))
    def testBitcastElementType(
        self,
        shape,
        from_dtype,
        to_dtype,
        bdims,
    ):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.bitcast_convert_type(x, to_dtype)
        self._CheckBatching(op, 10, bdims, (shape, ), (from_dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_min_shape={}_operand_shape={}_max_shape={}_bdims={}".format(
                    jtu.format_shape_dtype_string(min_shape, dtype),
                    jtu.format_shape_dtype_string(operand_shape, dtype),
                    jtu.format_shape_dtype_string(max_shape, dtype), bdims),
                "min_shape":
                min_shape,
                "operand_shape":
                operand_shape,
                "max_shape":
                max_shape,
                "dtype":
                dtype,
                "bdims":
                bdims
            } for min_shape, operand_shape, max_shape in [
                [(), (2, 3), ()],
                [(2, 3), (2, 3), ()],
                [(), (2, 3), (2, 3)],
                [(2, 3), (2, 3), (2, 3)],
            ] for dtype in default_dtypes
            for bdims in all_bdims(min_shape, operand_shape, max_shape)))
    def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
        rng = jtu.rand_default(self.rng())
        raise SkipTest(
            "batching rule for clamp not implemented")  # TODO(mattj)
        shapes = [min_shape, operand_shape, max_shape]
        self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_bdims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype), bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "bdims":
            bdims
        } for lhs_shape in [(3, ), (4, 3)] for rhs_shape in [(3, ), (3, 6)]
                            for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes))
    def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
        rng = jtu.rand_default(self.rng())
        op = partial(lax.dot, precision=lax.Precision.HIGHEST)
        self._CheckBatching(op,
                            5,
                            bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                            rng,
                            rtol={
                                np.float16: 5e-2,
                                np.float64: 5e-14
                            })

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
            .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                    jtu.format_shape_dtype_string(rhs_shape, dtype),
                    lhs_contracting, rhs_contracting, bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "lhs_contracting":
            lhs_contracting,
            "rhs_contracting":
            rhs_contracting,
            "bdims":
            bdims
        } for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
            [(5, ), (5, ), [0], [0]],
            [(5, 7), (5, ), [0], [0]],
            [(7, 5), (5, ), [1], [0]],
            [(3, 5), (2, 5), [1], [1]],
            [(5, 3), (5, 2), [0], [0]],
            [(5, 3, 2), (5, 2, 4), [0], [0]],
            [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]],
            [(5, 3, 2), (3, 5, 2, 4), [0, 2], [1, 2]],
            [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
            [(3, 2), (2, 4), [1], [0]],
        ] for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes))
    def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                   lhs_contracting, rhs_contracting, bdims):
        rng = jtu.rand_small(self.rng())
        dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}".format(
                jtu.format_shape_dtype_string(lhs_shape, dtype),
                jtu.format_shape_dtype_string(rhs_shape, dtype),
                dimension_numbers, bdims),
            "lhs_shape":
            lhs_shape,
            "rhs_shape":
            rhs_shape,
            "dtype":
            dtype,
            "dimension_numbers":
            dimension_numbers,
            "bdims":
            bdims
        } for lhs_shape, rhs_shape, dimension_numbers in [
            ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
            ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
            ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
        ] for bdims in all_bdims(lhs_shape, rhs_shape)
                            for dtype in default_dtypes))
    def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                       dimension_numbers, bdims):
        rng = jtu.rand_small(self.rng())
        dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
        self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape),
                            (dtype, dtype), rng)

        # Checks that batching didn't introduce any transposes or broadcasts.
        jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                    np.zeros(rhs_shape, dtype))
        for eqn in jtu.iter_eqns(jaxpr.jaxpr):
            self.assertFalse(eqn.primitive in ["transpose", "broadcast"])

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
                shape,
                np.dtype(dtype).name, broadcast_sizes, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "broadcast_sizes":
            broadcast_sizes,
            "bdims":
            bdims
        } for shape in [(), (2, 3)] for dtype in default_dtypes
                            for broadcast_sizes in [(), (2, ), (1, 2)]
                            for bdims in all_bdims(shape)))
    def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.broadcast(x, broadcast_sizes)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
                jtu.format_shape_dtype_string(inshape, dtype), outshape,
                broadcast_dimensions, bdims),
            "inshape":
            inshape,
            "dtype":
            dtype,
            "outshape":
            outshape,
            "dimensions":
            broadcast_dimensions,
            "bdims":
            bdims
        } for inshape, outshape, broadcast_dimensions in [
            ([2], [2, 2], [0]),
            ([2], [2, 2], [1]),
            ([2], [2, 3], [0]),
            ([], [2, 3], []),
        ] for dtype in default_dtypes for bdims in all_bdims(inshape)))
    def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
        rng = jtu.rand_default(self.rng())
        raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
        op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
        self._CheckBatching(op, 5, bdims, (inshape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_dimensions={}_bdims={}".format(
                jtu.format_shape_dtype_string(arg_shape, np.float32),
                dimensions, bdims),
            "arg_shape":
            arg_shape,
            "dimensions":
            dimensions,
            "bdims":
            bdims
        } for arg_shape, dimensions in [
            [(1, ), (0, )],
            [(1, ), (-1, )],
            [(2, 1, 4), (1, )],
            [(2, 1, 4), (-2, )],
            [(2, 1, 3, 1), (1, )],
            [(2, 1, 3, 1), (1, 3)],
            [(2, 1, 3, 1), (3, )],
            [(2, 1, 3, 1), (1, -1)],
        ] for bdims in all_bdims(arg_shape)))
    def testSqueeze(self, arg_shape, dimensions, bdims):
        dtype = np.float32
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.squeeze(x, dimensions)
        self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_inshape={}_outshape={}_dims={}_bdims={}".format(
                    jtu.format_shape_dtype_string(arg_shape, dtype),
                    jtu.format_shape_dtype_string(out_shape, dtype),
                    dimensions, bdims),
                "arg_shape":
                arg_shape,
                "out_shape":
                out_shape,
                "dtype":
                dtype,
                "dimensions":
                dimensions,
                "bdims":
                bdims
            } for dtype in default_dtypes
            for arg_shape, dimensions, out_shape in [
                [(3, 4), None, (12,)],
                [(2, 1, 4), None, (8,)],
                [(2, 2, 4), None, (2, 8)],
                [(2, 2, 4), (0, 1, 2), (2, 8)],
                [(2, 2, 4), (1, 0, 2), (8, 2)],
                [(2, 2, 4), (2, 1, 0), (4, 2, 2)]]
            for bdims in all_bdims(arg_shape)))
    def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
        self._CheckBatching(op, 10, bdims, (arg_shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_pads={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "pads":
            pads,
            "bdims":
            bdims
        } for shape in [(2, 3)] for bdims in all_bdims(shape)
                            for dtype in default_dtypes
                            for pads in [[(1, 2, 1), (0, 1, 0)]]))
    def testPad(self, shape, dtype, pads, bdims):
        rng = jtu.rand_small(self.rng())
        fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_predshape={}_argshapes={}_bdims={}".format(
                    jtu.format_shape_dtype_string(pred_shape, np.bool_),
                    jtu.format_shape_dtype_string(arg_shape, arg_dtype),
                    bdims),
                "pred_shape":
                pred_shape,
                "arg_shape":
                arg_shape,
                "arg_dtype":
                arg_dtype,
                "bdims":
                bdims
            }
            for arg_shape in [(), (3, ), (2, 3)]
            for pred_shape in ([(), arg_shape] if arg_shape else [()])
            for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
            for arg_dtype in default_dtypes))
    def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda c, x, y: lax.select(c < 0, x, y)
        self._CheckBatching(op, 5, bdims, (
            pred_shape,
            arg_shape,
            arg_shape,
        ), (np.bool_, arg_dtype, arg_dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}".
            format(jtu.format_shape_dtype_string(shape, dtype), start_indices,
                   limit_indices, strides, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "starts":
            start_indices,
            "limits":
            limit_indices,
            "strides":
            strides,
            "bdims":
            bdims
        } for shape, start_indices, limit_indices, strides in [
            [(3, ), (1, ), (2, ), None],
            [(7, ), (4, ), (7, ), None],
            [(5, ), (1, ), (5, ), (2, )],
            [(8, ), (1, ), (6, ), (2, )],
            [(5, 3), (1, 1), (3, 2), None],
            [(5, 3), (1, 1), (3, 1), None],
            [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
            [(5, 3), (1, 1), (2, 1), (1, 1)],
            [(5, 3), (1, 1), (5, 3), (2, 1)],
        ] for bdims in all_bdims(shape) for dtype in default_dtypes))
    def testSlice(self, shape, dtype, starts, limits, strides, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.slice(x, starts, limits, strides)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_perm={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "perm":
            perm,
            "bdims":
            bdims
        } for shape, perm in [
            [(3, 4), (1, 0)],
            [(3, 4), (0, 1)],
            [(3, 4, 5), (2, 1, 0)],
            [(3, 4, 5), (1, 0, 2)],
        ] for bdims in all_bdims(shape) for dtype in default_dtypes))
    def testTranspose(self, shape, dtype, perm, bdims):
        rng = jtu.rand_default(self.rng())
        op = lambda x: lax.transpose(x, perm)
        self._CheckBatching(op, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_inshape={}_reducedims={}_initval={}_bdims={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,
                init_val, bdims),
            "op":
            op,
            "init_val":
            init_val,
            "shape":
            shape,
            "dtype":
            dtype,
            "dims":
            dims,
            "bdims":
            bdims
        } for init_val, op, dtypes in [
            (0, lax.add, default_dtypes),
            (1, lax.mul, default_dtypes),
            (0, lax.max, all_dtypes),  # non-monoidal
            (-np.inf, lax.max, float_dtypes),
            (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
            (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
            (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
            (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
            (np.inf, lax.min, float_dtypes),
            (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
            (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
            (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
            (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
        ] for dtype in dtypes for shape, dims in [
            [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
            [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]]
                            for bdims in all_bdims(shape)))
    def testReduce(self, op, init_val, shape, dtype, dims, bdims):
        rng = jtu.rand_small(self.rng())
        init_val = np.asarray(init_val, dtype=dtype)
        fun = lambda operand: lax.reduce(operand, init_val, op, dims)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_inshape={}_reducedims={}_bdims={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
                bdims),
            "op":
            op,
            "shape":
            shape,
            "dtype":
            dtype,
            "dim":
            dim,
            "bdims":
            bdims
        } for op in [lax.argmin, lax.argmax] for dtype in default_dtypes
                            for shape in [(3, 4, 5)]
                            for dim in range(len(shape))
                            for bdims in all_bdims(shape)))
    def testArgminmax(self, op, shape, dtype, dim, bdims):
        rng = jtu.rand_default(self.rng())
        fun = lambda operand: op(operand, dim, np.int32)
        self._CheckBatching(fun, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                ("_op={}_shape={}_dims={}_strides={}_padding={}"
                 "_basedilation={}_windowdilation={}").format(
                     op.__name__, jtu.format_shape_dtype_string(shape, dtype),
                     dims, strides, padding, base_dilation, window_dilation),
                "op":
                op,
                "init_val":
                init_val,
                "dtype":
                dtype,
                "shape":
                shape,
                "dims":
                dims,
                "strides":
                strides,
                "padding":
                padding,
                "base_dilation":
                base_dilation,
                "window_dilation":
                window_dilation
            } for init_val, op, dtypes in [
                (0, lax.add, [np.float32]),
                (-np.inf, lax.max, [np.float32]),
                (np.inf, lax.min, [np.float32]),
            ] for shape, dims, strides, padding, base_dilation, window_dilation
            in (itertools.chain(
                itertools.product(
                    [(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1), (1, 2)],
                    ["VALID", "SAME", [(0, 3), (1, 2)]],
                    [(1, 1), (2, 3)], [(1, 1), (1, 2)]),
                itertools.product(
                    [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                    [(1, 2, 2, 1), (1, 1, 1, 1)],
                    ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
                    [(1, 1, 1, 1), (2, 1, 3, 2)], [(1, 1, 1, 1), (1, 2, 2, 1)])))
            for dtype in dtypes))
    def testReduceWindow(self, op, init_val, dtype, shape, dims, strides,
                         padding, base_dilation, window_dilation):
        rng = jtu.rand_small(self.rng())
        init_val = np.asarray(init_val, dtype=dtype)

        def fun(operand):
            return lax.reduce_window(operand, init_val, op, dims, strides,
                                     padding, base_dilation, window_dilation)

        for bdims in all_bdims(shape):
            self._CheckBatching(fun, 3, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_op={}_shape={}_axis={}_bdims={}_reverse={}".format(
                op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
                bdims, reverse),
            "op":
            op,
            "shape":
            shape,
            "dtype":
            dtype,
            "bdims":
            bdims,
            "axis":
            axis,
            "reverse":
            reverse
        } for op, types in [
            (lax.cumsum, [np.float32, np.float64]),
            (lax.cumprod, [np.float32, np.float64]),
        ] for dtype in types for shape in [[10], [3, 4, 5]]
                            for axis in range(len(shape))
                            for bdims in all_bdims(shape)
                            for reverse in [False, True]))
    def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
        rng_factory = (jtu.rand_default if dtypes.issubdtype(
            dtype, np.integer) else jtu.rand_small)
        rng = rng_factory(self.rng())
        self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                            (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_dtype={}_padding={}".format(np.dtype(dtype).name, padding),
            "dtype":
            dtype,
            "padding":
            padding
        } for dtype in float_dtypes for padding in ["VALID", "SAME"]))
    @jtu.skip_on_flag("jax_skip_slow_tests", True)
    @jtu.ignore_warning(message="Using reduced precision for gradient.*")
    def testSelectAndGatherAdd(self, dtype, padding):
        if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
            raise SkipTest(
                "bfloat16 _select_and_gather_add doesn't work on tpu")
        rng = jtu.rand_small(self.rng())
        all_configs = itertools.chain(
            itertools.product([(4, 6)], [(2, 1), (1, 2)], [(1, 1), (2, 1),
                                                           (1, 2)]),
            itertools.product([(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
                              [(1, 2, 2, 1), (1, 1, 1, 1)]))

        def fun(operand, tangents):
            pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
            ones = (1, ) * len(operand.shape)
            return lax._select_and_gather_add(operand, tangents, lax.ge_p,
                                              dims, strides, pads, ones, ones)

        for shape, dims, strides in all_configs:
            for bdims in all_bdims(shape, shape):
                self._CheckBatching(fun, 3, bdims, (shape, shape),
                                    (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
            f"_padding={padding}_dims={dims}_strides={strides}",
            "dtype":
            dtype,
            "padding":
            padding,
            "shape":
            shape,
            "dims":
            dims,
            "strides":
            strides
        } for dtype in float_dtypes for padding in ["VALID", "SAME"]
                            for shape in [(3, 2, 4, 6)]
                            for dims in [(1, 1, 2, 1)]
                            for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
    def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
        rng = jtu.rand_small(self.rng())

        pads = lax.padtype_to_pads(shape, dims, strides, padding)

        def fun(operand, cotangents):
            return lax._select_and_scatter_add(operand, cotangents, lax.ge_p,
                                               dims, strides, pads)

        ones = (1, ) * len(shape)
        cotangent_shape = api.eval_shape(
            lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                                 pads, ones, ones),
            np.ones(shape, dtype)).shape

        for bdims in all_bdims(cotangent_shape, shape):
            self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                                (dtype, dtype), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_bdims={}_fft_ndims={}".format(shape, bdims, fft_ndims),
            "shape":
            shape,
            "bdims":
            bdims,
            "fft_ndims":
            fft_ndims
        } for shape in [(5, ), (3, 4, 5), (2, 3, 4, 5)]
                            for bdims in all_bdims(shape)
                            for fft_ndims in range(0,
                                                   min(3, len(shape)) + 1)))
    @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
    def testFft(self, fft_ndims, shape, bdims):
        rng = jtu.rand_default(self.rng())
        ndims = len(shape)
        axes = range(ndims - fft_ndims, ndims)
        fft_lengths = [shape[axis] for axis in axes]
        op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
        self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
                slice_sizes, bdims),
            "shape":
            shape,
            "dtype":
            dtype,
            "idxs":
            idxs,
            "dnums":
            dnums,
            "slice_sizes":
            slice_sizes,
            "bdims":
            bdims
        } for dtype in all_dtypes for shape, idxs, dnums, slice_sizes in [
            ((5, ), np.array([[0], [2]]),
             lax.GatherDimensionNumbers(offset_dims=(),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, )),
            ((10, ), np.array([[0], [0], [0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(),
                                        start_index_map=(0, )), (2, )),
            ((
                10,
                5,
            ), np.array([[0], [2], [1]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, )), (1, 3)),
            ((10, 5), np.array([[0, 2], [1, 0]]),
             lax.GatherDimensionNumbers(offset_dims=(1, ),
                                        collapsed_slice_dims=(0, ),
                                        start_index_map=(0, 1)), (1, 3)),
        ] for bdims in all_bdims(shape, idxs.shape)))
    def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
        fun = partial(lax.gather,
                      dimension_numbers=dnums,
                      slice_sizes=slice_sizes)
        self._CheckBatching(fun, 0, bdims, [shape, idxs.shape],
                            [dtype, idxs.dtype], jtu.rand_default(self.rng()))
        self._CheckBatching(fun, 5, bdims, [shape, idxs.shape],
                            [dtype, idxs.dtype], jtu.rand_default(self.rng()))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
                    jtu.format_shape_dtype_string(arg_shape, dtype), idxs,
                    update_shape, dnums, bdims),
                "arg_shape":
                arg_shape,
                "dtype":
                dtype,
                "idxs":
                idxs,
                "update_shape":
                update_shape,
                "dnums":
                dnums,
                "bdims":
                bdims
            } for dtype in float_dtypes
            for arg_shape, idxs, update_shape, dnums in [
                ((5, ), np.array([[0], [2]]), (2, ),
                 lax.ScatterDimensionNumbers(update_window_dims=(),
                                             inserted_window_dims=(0, ),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
                ((10, ), np.array([[0], [0], [0]]), (3, 2),
                 lax.ScatterDimensionNumbers(update_window_dims=(1, ),
                                             inserted_window_dims=(),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
                ((
                    10,
                    5,
                ), np.array([[0], [2], [1]]), (3, 3),
                 lax.ScatterDimensionNumbers(update_window_dims=(1, ),
                                             inserted_window_dims=(0, ),
                                             scatter_dims_to_operand_dims=(
                                                 0, ))),
            ] for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
    def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums,
                       bdims):
        fun = partial(lax.scatter_add, dimension_numbers=dnums)
        self._CheckBatching(fun,
                            5,
                            bdims, [arg_shape, idxs.shape, update_shape],
                            [dtype, idxs.dtype, dtype],
                            jtu.rand_default(self.rng()),
                            rtol={np.float16: 5e-3})

    def testShapeUsesBuiltinInt(self):
        x = lax.iota(np.int32, 3) + 1
        self.assertIsInstance(x.shape[0], int)  # not np.int64

    def testBroadcastShapesReturnsPythonInts(self):
        shape1, shape2 = (1, 2, 3), (2, 3)
        out_shape = lax.broadcast_shapes(shape1, shape2)
        self.assertTrue(all(type(s) is int for s in out_shape))

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name":
                "_shape={}_k={}_bdims={}".format(
                    jtu.format_shape_dtype_string(shape, dtype), k, bdims),
                "shape":
                shape,
                "dtype":
                dtype,
                "k":
                k,
                "bdims":
                bdims,
                "rng_factory":
                rng_factory
            } for shape in [(4, ), (3, 5, 3)] for k in [1, 3]
            for bdims in all_bdims(shape)
            # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
            # The top_k indices for integer arrays with identical entries won't match between
            # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
            # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
            # values a bfloat16 can represent exactly to avoid ties.
            for dtype, rng_factory in itertools.chain(
                unsafe_zip(default_dtypes, itertools.repeat(
                    jtu.rand_unique_int)))))
    def testTopK(self, shape, dtype, k, bdims, rng_factory):
        rng = rng_factory(self.rng())
        # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
        op1 = lambda x: lax.top_k(x, k=k)[0]
        self._CheckBatching(op1, 5, bdims, (shape, ), (dtype, ), rng)
        op2 = lambda x: lax.top_k(x, k=k)[1]
        self._CheckBatching(op2, 5, bdims, (shape, ), (dtype, ), rng)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shape={}_dimension={}_arity={}_bdims={}_isstable={}".format(
                jtu.format_shape_dtype_string(shape, np.float32), dimension,
                arity, bdims, is_stable),
            "shape":
            shape,
            "dimension":
            dimension,
            "arity":
            arity,
            "bdims":
            bdims,
            "is_stable":
            is_stable
        } for shape in [(2, 3)] for dimension in [0, 1] for arity in range(3)
                            for bdims in all_bdims(*((shape, ) * arity))
                            for is_stable in [False, True]))
    def testSort(self, shape, dimension, arity, bdims, is_stable):
        rng = jtu.rand_default(self.rng())
        if arity == 1:
            fun = partial(lax.sort, dimension=dimension)
            self._CheckBatching(fun, 5, bdims, (shape, ) * arity,
                                (np.float32, ) * arity, rng)
        else:
            for i in range(arity):
                fun = lambda *args, i=i: lax.sort(
                    args, dimension=dimension, is_stable=is_stable)[i]
                self._CheckBatching(fun, 5, bdims, (shape, ) * arity,
                                    (np.float32, ) * arity, rng)
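Every test in LaxVmapTest funnels through _CheckBatching, which compares jax.vmap of an op over the chosen batch dimensions against slicing the batched inputs by hand and stacking the per-example results. The sketch below is a self-contained version of that equivalence using only the public jax API; the op, shapes, and batch dimensions at the bottom are arbitrary illustrative choices rather than the test's own cases.

import numpy as np

import jax
from jax import lax


def check_batching(op, bdim_size, bdims, shapes, seed=0):
  """vmap over bdims should match applying op to each batch slice and stacking."""
  rng = np.random.RandomState(seed)
  batched_shapes = [
      shape if bdim is None else shape[:bdim] + (bdim_size,) + shape[bdim:]
      for shape, bdim in zip(shapes, bdims)]
  args = [rng.randn(*s).astype(np.float32) for s in batched_shapes]

  ans = jax.vmap(op, in_axes=bdims)(*args)

  def slice_arg(arg, bdim, i):
    return arg if bdim is None else np.take(arg, i, axis=bdim)

  expected = np.stack([
      op(*[slice_arg(a, b, i) for a, b in zip(args, bdims)])
      for i in range(bdim_size)])
  np.testing.assert_allclose(ans, expected, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
  # Batch the first operand of lax.add along axis 0 and broadcast the second.
  check_batching(lax.add, bdim_size=10, bdims=(0, None), shapes=((3, 4), (3, 4)))
  print("vmap agrees with the manual batch loop")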
Code example #5
class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                      dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS)))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
        for shapes in filter(
          _shapes_are_broadcast_compatible,
          CombosWithReplacement(rec.shapes, rec.nargs))
        for dtypes in filter(
          _dtypes_are_compatible_for_bitwise_ops,
          CombosWithReplacement(rec.dtypes, rec.nargs)))
      for rec in JAX_BITWISE_OP_RECORDS))
  def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    if not FLAGS.jax_enable_x64 and any(
        onp.iinfo(dtype).bits == 64 for dtype in dtypes):
      self.skipTest("x64 types are disabled by jax_enable_x64")
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis,
          "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for out_dtype in [None] + rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_NO_DTYPE_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducerNoDtype(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "shape": shape, "dtype": dtype, "axis": axis}
      for shape in all_shapes for dtype in all_dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])))
  def testCountNonzero(self, shape, dtype, axis):
    rng = jtu.rand_some_zero()
    onp_fun = lambda x: onp.count_nonzero(x, axis)
    lnp_fun = lambda x: lnp.count_nonzero(x, axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          axes),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "axes": axes, "rng": rng}
      for rng in [jtu.rand_default()]
      for lhs_shape, rhs_shape, axes in [
          [(2, 3, 4), (3, 4, 5, 6), 2],
          [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
          [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
          [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
      ]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)))
  def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    lnp_fun = lambda a, b: lnp.tensordot(a, b, axes)
    onp_fun = lambda a, b: onp.tensordot(a, b, axes)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": jtu.rand_default()}
      # TODO(phawkins): support integer dtypes too.
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)
      for lhs_shape, rhs_shape in [
        (l, r) for l, r in CombosWithReplacement(all_shapes, 2)
        if len(jtu._dims_of_shape(l)) == 0
        or len(jtu._dims_of_shape(r)) == 0
        or l[-1] == r[-1]]))
  def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs)
    lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)]))
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for decimals in [0, 1, -2]))
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape))))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    wrapped_axis = axis % len(base_shape)
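    # vary the size along the concatenation axis (cycling through 3, 1, 4) so each input differs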
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
       "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
       "rng": jtu.rand_default()}
      for repeats in [0, 1, 2]
      for dtype in default_dtypes
      for shape in all_shapes
      for axis in [None] + list(range(-len(shape), len(shape)))))
  def testRepeat(self, axis, shape, dtype, repeats, rng):
    onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis)
    lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis)

    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
          onp.dtype(dtype).name, m, n, k),
       "m": m, "n": n, "k": k, "dtype": dtype, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for n in [0, 4]
      for m in [None, 0, 1, 3, 4]
      for k in list(range(-4, 4))))
  def testTri(self, m, n, k, dtype, rng):
    onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype)
    lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype)
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_k={}".format(
          op, jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "op": op, "k": k,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 1]
      for op in ["tril", "triu"]
      for k in list(range(-3, 3))))
  def testTriLU(self, dtype, shape, op, k, rng):
    onp_fun = lambda arg: getattr(onp, op)(arg, k=k)
    lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "k": k, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
      for k in list(range(-4, 4))))
  def testDiag(self, shape, dtype, k, rng):
    onp_fun = lambda arg: onp.diag(arg, k)
    lnp_fun = lambda arg: lnp.diag(arg, k)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
          jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
       "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1,
       "axis2": axis2, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
      for offset in list(range(-4, 4))))
  def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
    onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
    lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_n={}".format(onp.dtype(dtype).name, n),
       "dtype": dtype, "n": n}
      for dtype in default_dtypes
      for n in list(range(4))))
  def testIdentity(self, n, dtype):
    onp_fun = lambda: onp.identity(n, dtype)
    lnp_fun = lambda: lnp.identity(n, dtype)
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype_{}_offset={}_axis1={}_axis2={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          out_dtype, offset, axis1, axis2),
       "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset,
       "axis1": axis1, "axis2": axis2, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for out_dtype in [None] + default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
      for offset in list(range(-4, 4))))
  def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng):
    onp_fun = lambda arg: onp.trace(arg, offset, axis1, axis2, out_dtype)
    lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
        [onp.float32],
        [onp.float32, onp.float32],
        [onp.float32, onp.int32, onp.float32],
        [onp.float32, onp.int64, onp.float32],
        [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()]))
  def testStack(self, shape, dtypes, rng):
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    self._CheckAgainstNumpy(onp.stack, lnp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outdtype={}".format(
          jtu.format_shape_dtype_string(shape, fill_value_dtype),
          onp.dtype(out_dtype).name),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in array_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
          jtu.format_shape_dtype_string(shape, in_dtype),
          onp.dtype(fill_value_dtype).name,
          onp.dtype(out_dtype).name),
       "shape": shape, "in_dtype": in_dtype,
       "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype,
       "rng": jtu.rand_default()}
      for shape in array_shapes
      for in_dtype in default_dtypes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype)
    lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes))
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
          ((), (1, 1, 1)),
          ((7, 0), (0, 42, 101)),
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ]))
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: lnp.reshape(x, out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape))))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes))
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes))
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_arg{}".format(i), "arg": arg}
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ])))
  def testArray(self, arg):
    args_maker = lambda: [arg]
    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testArrayAsarrayMethod(self):
    class arraylike(object):
      def __asarray__(self, dtype=None):
        return 3.
    a = arraylike()
    ans = lnp.array(a)
    assert ans == 3.

  def testAllClose(self):
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3
    self.skipTest("test needs jax3 update")

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)


  def testNonArrayErrorMessage(self):
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))

  def testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    self.skipTest("test needs jax3 update")
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
      for shape in [(3,), (2, 3)]
      for dtype in default_dtypes
      for axis in range(len(shape))
      for rng in [jtu.rand_default()]))
  def testFlip(self, shape, dtype, axis, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.flip(x, axis)
    onp_op = lambda x: onp.flip(x, axis)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_k={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, axes),
       "rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes}
      for shape, axes in [
          [(2, 3), (0, 1)],
          [(2, 3), (1, 0)],
          [(4, 3, 2), (0, 2)],
          [(4, 3, 2), (2, 1)],
      ]
      for k in range(-3, 4)
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testRot90(self, shape, dtype, k, axes, rng):
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.rot90(x, k, axes)
    onp_op = lambda x: onp.rot90(x, k, axes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  # TODO(mattjj): test infix operator overrides

  def testRavel(self):
    # TODO(mattjj): support this method-based syntax?
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)

  def testAstype(self):
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    op = lambda x: x.astype(lnp.int32)
    self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  # TODO(mattjj): test other ndarray-like method overrides

  def testOnpMean(self):
    # from https://github.com/google/jax/issues/125
    x = lax.add(lnp.eye(3), 0.)
    ans = onp.mean(x)
    self.assertAllClose(ans, onp.array(1./3), check_dtypes=False)

  # TODO(mattjj): more exhaustive arange tests
  def testArangeOnFloats(self):
    # from https://github.com/google/jax/issues/145
    expected = onp.arange(0.0, 1.0, 0.1)
    ans = lnp.arange(0.0, 1.0, 0.1)
    self.assertAllClose(expected, ans, check_dtypes=True)
Code example #6
0
File: scipy_stats_test.py  Project: zhaowilliam/jax
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations"""

  @genNamedParametersNArgs(3)
  def testPoissonLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.logpmf
    lax_fun = lsp_stats.poisson.logpmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testPoissonPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.pmf
    lax_fun = lsp_stats.poisson.pmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testPoissonCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.cdf
    lax_fun = lsp_stats.poisson.cdf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testBernoulliLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testGeomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.geom.logpmf
    lax_fun = lsp_stats.geom.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker,
                          rtol={np.float32: 2e-3, np.float64: 1e-4})

  @genNamedParametersNArgs(3)
  def testCauchyLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(
    jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
        "shapes": [x_shape, alpha_shape], "dtypes": dtypes}
      for x_shape in one_and_two_dim_shapes
      for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
      for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2)
  ))
  def testDirichletLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())

    def _normalize(x, alpha):
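      # Normalize x along its leading axis so it lies on the probability simplex;
      # when alpha has one more entry than x, keep the sums just below 1 so the
      # last simplex component is implicit (the scipy convention).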
      x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
      return (x / x_norm).astype(x.dtype), alpha

    def lax_fun(x, alpha):
      return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

    def scipy_fun(x, alpha):
      # scipy validates the x normalization using float64 arithmetic, so we must
      # cast x to float64 before normalization to ensure this passes.
      x, alpha = _normalize(x.astype('float64'), alpha)

      result = osp_stats.dirichlet.logpdf(x, alpha)
      # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays
      # of a consistent rank. This check ensures the results have the same shape.
      return result if x.ndim == 1 else np.atleast_1d(result)

    def args_maker():
      # Don't normalize here, because we want normalization to happen at 64-bit
      # precision in the scipy version.
      x, alpha = map(rng, shapes, dtypes)
      return x, alpha

    tol = {np.float32: 1E-3, np.float64: 1e-5}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)

  @genNamedParametersNArgs(3)
  def testExponLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testGammaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testLaplaceLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testLaplaceCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol={np.float32: 1e-5, np.float64: 1e-6})
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.cdf
    lax_fun = lsp_stats.logistic.cdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticLogpdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.logpdf
    lax_fun = lsp_stats.logistic.logpdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.ppf
    lax_fun = lsp_stats.logistic.ppf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.sf
    lax_fun = lsp_stats.logistic.sf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testNormLogCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(3)
  def testNormPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.ppf
    lax_fun = lsp_stats.norm.ppf

    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [q, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)


  @genNamedParametersNArgs(4)
  def testParetoLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)


  @genNamedParametersNArgs(4)
  def testTLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker,
                          rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})


  @genNamedParametersNArgs(3)
  def testUniformLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, np.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testChi2LogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.chi2.logpdf
    lax_fun = lsp_stats.chi2.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaBinomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
      k, n, a, b, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      n = np.ceil(n)
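      # clipping to ensure that the shape parameters are strictly positive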
      a = np.clip(a, a_min=0.1, a_max=None)
      b = np.clip(b, a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, n, a, b, loc]

    if scipy_version >= (1, 4):
      scipy_fun = osp_stats.betabinom.logpmf
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                              tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

  def testIssue972(self):
    self.assertAllClose(
      np.ones((4,), np.float32),
      lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
      check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype}
      for x_shape, mean_shape, cov_shape in [
          # # These test cases cover default values for mean/cov, but we don't
          # # support those yet (and they seem not very valuable).
          # [(), None, None],
          # [(), (), None],
          # [(2,), None, None],
          # [(2,), (), None],
          # [(2,), (2,), None],
          # [(3, 2), (3, 2,), None],
          # [(5, 3, 2), (5, 3, 2,), None],

          [(), (), ()],
          [(3,), (), ()],
          [(3,), (3,), ()],
          [(3,), (3,), (3, 3)],
          [(3, 4), (4,), (4, 4)],

          # # These test cases are where scipy flattens things, which has
          # # different batch semantics than some might expect
          # [(5, 3, 2), (5, 3, 2,), ()],
          # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
      if (mean_shape is not None or mean_dtype == np.float32)
      and (cov_shape is not None or cov_dtype == np.float32)))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())
    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
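          # build a positive semi-definite covariance as factor @ factor.T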
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return args

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                            lsp_stats.multivariate_normal.logpdf,
                            args_maker, tol=1e-3)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          rtol=1e-4, atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__),
       "ndim": ndim, "nbatch": nbatch, "dtype": dtype}
      for ndim in [2, 3]
      for nbatch in [1, 3, 5]
      for dtype in jtu.dtypes.floating))
  def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
    # Regression test for #5570
    rng = jtu.rand_default(self.rng())
    x = rng((nbatch, ndim), dtype)
    mean = 5 * rng((nbatch, ndim), dtype)
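    # batched positive semi-definite covariances, built as factor @ factor.T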
    factor = rng((nbatch, ndim, 2 * ndim), dtype)
    cov = factor @ factor.transpose(0, 2, 1)

    result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
    result2 = api.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
    self.assertArraysEqual(result1, result2)
Code example #7
0
class LaxAutodiffTest(jtu.JaxTestCase):
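  # The gradient tests below rely on jax.test_util.check_grads, which compares
  # forward- and reverse-mode derivatives against numerical differences.
  # A minimal sketch of that check (illustrative only, not from the original
  # file; assumes `jax` is importable):
  #
  #   import jax.numpy as jnp
  #   from jax.test_util import check_grads
  #   check_grads(jnp.tanh, (0.5,), order=2, modes=["fwd", "rev"])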

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = jtu.join_tolerance(1e-1, tol) if num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          float_dtypes + complex_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    convert_element_type = jtu.ignore_warning(category=onp.ComplexWarning)(
      convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    rng = rng_factory(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "rng_factory": rng_factory,}
       for lhs_shape, rhs_shape, all_strides in itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)])])
       for strides in all_strides
       for dtype in float_dtypes
       for padding in ["VALID", "SAME"]
       for rng_factory in [jtu.rand_small]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory}
       for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in
       itertools.chain(
           [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
             [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
             [(1, 1), (2, 1)], [(1, 1)])
            for b, i, j in itertools.product([2, 3], repeat=3)],
           [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
             [(1,), (2,)], [(1,), (2,)])])
       for strides in all_strides
       for rhs_dil in rhs_dils
       for lhs_dil in lhs_dils
       for dtype in float_dtypes
       for padding in all_pads
       for rng_factory in [jtu.rand_small]))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                     padding, lhs_dil, rhs_dil, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums,
       "perms": perms, "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)])
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_float_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
        ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
      for rng_factory in [jtu.rand_default]
  ))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count, batch_group_count,
                                 rng_factory):
    if dtype == onp.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    rng = rng_factory(self.rng())
    tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(onp.take(lhs_shape, lhs_perm))
    rhs_shape = list(onp.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "rng_factory": jtu.rand_default}
      for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    tol = {onp.float16: 1e-1, onp.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers, rng_factory):
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "precision=HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, onp.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
       "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]
      for rng_factory in [jtu.rand_default]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
    rng = rng_factory(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions, "rng_factory": rng_factory}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng_factory": rng_factory, "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = onp.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (onp.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={onp.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, onp.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype,
       "rng_factory": rng_factory}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory):
    rng = rng_factory(self.rng())
    pred = rng(pred_shape, onp.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides, "rng_factory": rng_factory}
      for shape, start_indices, limit_indices, strides in [
        [(3,), (1,), (2,), None],
        [(7,), (4,), (7,), None],
        [(5,), (1,), (5,), (2,)],
        [(8,), (1,), (6,), (2,)],
        [(5, 3), (1, 1), (3, 2), None],
        [(5, 3), (1, 1), (3, 1), None],
        [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
        [(5, 3), (1, 1), (2, 1), (1, 1)],
        [(5, 3), (1, 1), (5, 3), (2, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices, "rng_factory": rng_factory}
      for shape, start_indices, size_indices in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices,
                           rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape, "rng_factory": rng_factory}
      for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = onp.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm, "rng_factory": rng_factory}
      for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testTransposeGrad(self, shape, dtype, perm, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, inexact_dtypes, jtu.rand_default),
          (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
          (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
      ]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
           onp.float64: 1e-3, onp.complex64: 1e-1}
    operand = rng(shape, dtype)
    init_val = onp.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),
       "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, grad_float_dtypes, jtu.rand_small),
          (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
          (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
      ]
      for dtype in dtypes
      for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
    rng = rng_factory(self.rng())
    init_val = onp.asarray(init_val, dtype=dtype)

    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]

      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1
    else:
      all_configs = itertools.chain(
          itertools.product(
              [(4, 6)],  # shapes
              [(2, 1), (1, 2)],  # window_dimensions
              [(1, 1), (2, 1), (1, 2)]  # strides
          ),
          itertools.product(
              [(3, 2, 4, 6)],  # shapes
              [(1, 1, 2, 1), (2, 1, 2, 1)],  # window_dimensions
              [(1, 2, 2, 1), (1, 1, 1, 1)]),  # strides
      )
      gradient_order = 3

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding)

    for shape, dims, strides in all_configs:
      operand = rng(shape, dtype)
      if op is lax.add:
        eps = 1.
        tol = None
      else:
        # this test can fail if there are duplicates in operand
        self.assertEqual(onp.unique(operand).size, operand.size,
                         msg="test requires operand elements to be unique.")
        eps = 1e-2
        tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2}
      check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                  eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
       "op": op, "shape": shape, "dtype": dtype,
       "axis": axis, "rng_factory": rng_factory}
      for op, types in [
          (lax.cumsum, [onp.float32, onp.float64]),
          (lax.cumprod, [onp.float32, onp.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for rng_factory in [
          jtu.rand_default if dtypes.issubdtype(dtype, onp.integer)
          else jtu.rand_small]))
  def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
    rng = rng_factory(self.rng())
    check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)


  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "axis": axis,
       "is_stable": is_stable}
      for dtype in [onp.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis, is_stable),
       "rng_factory": rng_factory, "shape": shape,
       "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
       "is_stable": is_stable}
      for key_dtype in [onp.float32]
      for val_dtype in [onp.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]
      for rng_factory in [jtu.rand_default]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable,
                         rng_factory):
    rng = rng_factory(self.rng())
    # This test relies on the property that wherever keys are tied, values are
    # too, since we don't guarantee the same ordering of values with equal keys.
    # To avoid that case, we generate unique keys (globally in the key array).
    def args_maker():
      flat_keys = onp.arange(onp.prod(shape, dtype=int), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "rng_factory": rng_factory, "shape": shape, "dtype": dtype, "k": k}
      for dtype in [onp.float32,]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]
      for rng_factory in [jtu.rand_default]))
  def testTopKGrad(self, shape, dtype, k, rng_factory):
    flat_values = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
       "rng_factory": rng_factory}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
          [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
      ]
      for rng_factory in [jtu.rand_default]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
    rng = rng_factory(self.rng())
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for shape, idxs, dnums, slice_sizes, max_idx in [
          ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1,), 5),
          ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
            (2,), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
            (1, 3), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                     rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums, "rng_factory": rng_factory,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
      ]
      # Scatters with conflicting indices are not deterministic on GPU, so we
      # use indices that do not collide.
      for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]
      for rng_factory in [jtu.rand_default]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  def testScatterGradSymbolicZeroUpdate(self):
    # https://github.com/google/jax/issues/1901
    def f(x):
      n = x.shape[0]
      y = onp.arange(n, dtype=x.dtype)
      return jax.ops.index_update(x, onp.diag_indices(n), y)
    rng = jtu.rand_default(self.rng())
    check_grads(f, (rng((5, 5), onp.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_factory": rng_factory, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
          ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max(arg_shape))]
      for rng_factory in [jtu.rand_default]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                     rng_factory, rng_idx_factory):
    rng = rng_factory(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    expected = onp.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with core.skipping_checks():
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)

  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = onp.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

  def testHigherOrderGradientOfReciprocal(self):
    # Regression test for https://github.com/google/jax/issues/3136
    def inv(x):
      # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
      return 1 / x
    grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
    self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
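The tests above all funnel through the same check_grads helper. Below is a minimal standalone sketch of that pattern, not part of the original file, assuming the jax.test_util.check_grads helper and the jax.lax API used by these tests; it checks forward- and reverse-mode derivatives of lax.reshape against numerical differences up to second order.

# Hedged sketch: verify fwd/rev gradients of a lax op numerically, as the tests above do.
import numpy as onp
from jax import lax
from jax.test_util import check_grads

operand = onp.arange(6.0, dtype=onp.float32).reshape(2, 3)
reshape = lambda x: lax.reshape(x, (3, 2))
check_grads(reshape, (operand,), order=2, modes=["fwd", "rev"], eps=1.)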
Code example #8
0
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_axis={}_keepdims={}".format(
                jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "rng_factory":
            jtu.rand_some_inf_and_nan
            if jtu.device_under_test() != "cpu" else jtu.rand_default,
            "shape":
            shape,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims
        } for shape in all_shapes for dtype in float_dtypes
                            for axis in range(-len(shape), len(shape))
                            for keepdims in [False, True]))
    @jtu.skip_on_flag("jax_xla_backend", "xrt")
    def testLogSumExp(self, rng_factory, shape, dtype, axis, keepdims):
        rng = rng_factory(self.rng())

        # TODO(mattjj): test autodiff
        def scipy_fun(array_to_reduce):
            return osp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        def lax_fun(array_to_reduce):
            return lsp_special.logsumexp(array_to_reduce,
                                         axis,
                                         keepdims=keepdims)

        args_maker = lambda: [rng(shape, dtype)]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True)
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in CombosWithReplacement(all_shapes, rec.nargs)
                for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff):
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5)

        if test_autodiff:
            jtu.check_grads(lax_op,
                            args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "rng_factory":
            jtu.rand_positive,
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, rng_factory, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = rng_factory(self.rng())
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                check_dtypes=True,
                                tol={
                                    onp.float32: 1e-3,
                                    onp.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

    def testIssue980(self):
        x = onp.full((4, ), -1e20, dtype=onp.float32)
        self.assertAllClose(onp.zeros((4, ), dtype=onp.float32),
                            lsp_special.expit(x),
                            check_dtypes=True)

    def testXlogyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(api.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)
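As a point of reference, here is a minimal standalone comparison of jax.scipy.special.logsumexp against scipy.special.logsumexp, which is the check that testLogSumExp above automates across shapes, dtypes, axes, and keepdims; the input values here are illustrative only.

import numpy as onp
import scipy.special as osp_special
import jax.scipy.special as lsp_special

x = onp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=onp.float32)
print(osp_special.logsumexp(x, axis=1, keepdims=True))
print(lsp_special.logsumexp(x, axis=1, keepdims=True))  # should agree to roughly 1e-6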
Code example #9
0
File: lax_scipy_test.py Project: zhaowilliam/jax
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims,
                return_sign, use_b),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "shapes":
            shapes,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "return_sign":
            return_sign,
            "use_b":
            use_b
        } for shape_group in compatible_shapes
                            for dtype in float_dtypes + int_dtypes
                            for use_b in [False, True]
                            for shapes in itertools.product(
                                *((shape_group,
                                   shape_group) if use_b else (shape_group, )))
                            for axis in range(
                                -max(len(shape) for shape in shapes),
                                max(len(shape) for shape in shapes))
                            for keepdims in [False, True]
                            for return_sign in [False, True]))
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in .*")
    def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
        if jtu.device_under_test() != "cpu":
            rng = jtu.rand_some_inf_and_nan(self.rng())
        else:
            rng = jtu.rand_default(self.rng())
        # TODO(mattjj): test autodiff
        if use_b:

            def scipy_fun(array_to_reduce, scale_array):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            def lax_fun(array_to_reduce, scale_array):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
        else:

            def scipy_fun(array_to_reduce):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            def lax_fun(array_to_reduce):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            args_maker = lambda: [rng(shapes[0], dtype)]
        tol = {np.float32: 1E-6, np.float64: 1E-14}
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    def testLogSumExpZeros(self):
        # Regression test for https://github.com/google/jax/issues/5370
        scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
        lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
        args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "nondiff_argnums":
                    rec.nondiff_argnums,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in itertools.combinations_with_replacement(
                    all_shapes, rec.nargs)
                for dtypes in (itertools.combinations_with_replacement(
                    rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                               itertools.product(*rec.dtypes)))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = jtu.rand_positive(self.rng())
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                tol={
                                    np.float32: 1e-3,
                                    np.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    def testIssue980(self):
        x = np.full((4, ), -1e20, dtype=np.float32)
        self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                            lsp_special.expit(x))

    def testIssue3758(self):
        x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
        q = np.array([1., 40., 30.], dtype=np.float32)
        self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                            lsp_special.zeta(x, q))

    def testXlogyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(api.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)
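The last four tests pin down an edge case worth calling out: xlogy(x, y) = x*log(y) and xlog1py(x, y) = x*log1p(y) are defined to be zero when x is zero, and their gradients with respect to y at that point are zero as well. A small hedged sketch of that behaviour (illustrative, not from the original file):

import functools
import jax
import jax.scipy.special as lsp_special

print(lsp_special.xlogy(0., 0.))                                  # 0.0
print(jax.grad(functools.partial(lsp_special.xlogy, 0.))(0.))     # 0.0
print(lsp_special.xlog1py(0., -1.))                               # 0.0
print(jax.grad(functools.partial(lsp_special.xlog1py, 0.))(-1.))  # 0.0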
Code example #10
0
class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                    dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape)))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)])
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for decimals in [0, 1, -2])
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape)))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    wrapped_axis = axis % len(base_shape)
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
        [onp.float32],
        [onp.float32, onp.float32],
        [onp.float32, onp.int32, onp.float32],
        [onp.float32, onp.int64, onp.float32],
        [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()])
  def testStack(self, shape, dtypes, rng):
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    self._CheckAgainstNumpy(lnp.stack, onp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape=[{}]_indtype={}_outdtype={}".format(
          "_".join(str(d) for d in shape),
          onp.dtype(fill_value_dtype).name, onp.dtype(out_dtype).name),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in all_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes)
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes)
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ])
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: lnp.reshape(x, out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape)))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes)
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes)
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_arg{}".format(i), "arg": arg}
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ]))
  def testArray(self, arg):
    args_maker = lambda: [arg]
    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testAllClose(self):
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def DISABLED_testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)


  def testNonArrayErrorMessage(self):
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))

  def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  # TODO(mattjj): test infix operator overrides

  def DISABLED_testRavel(self):
    # TODO(mattjj): support this method-based syntax?
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
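For orientation, a minimal sketch of what _CheckAgainstNumpy and _CompileAndCheck boil down to, written with a plain assert instead of the jtu helpers; jnp below is jax.numpy (imported as lnp in the snippet above), and the a_min/a_max keywords mirror testClipStaticBounds.

import numpy as onp
import jax
import jax.numpy as jnp

x = onp.linspace(-2.0, 2.0, 12, dtype=onp.float32).reshape(3, 4)
expected = onp.clip(x, a_min=-1, a_max=1)                       # NumPy reference
actual = jax.jit(lambda v: jnp.clip(v, a_min=-1, a_max=1))(x)   # compiled JAX version
assert onp.allclose(expected, actual)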
Code example #11
0
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.test_name, shapes, dtypes),
         "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
         "test_autodiff": rec.test_autodiff,
         "scipy_op": getattr(osp_special, rec.name),
         "lax_op": getattr(lsp_special, rec.name)}
        for shapes in CombosWithReplacement(all_shapes, rec.nargs)
        for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng, shape, dtype, d):
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    x = onp.full((4,), -1e20, dtype=onp.float32)
    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
                        lsp_special.expit(x), check_dtypes=True)
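As a standalone illustration (plain SciPy, not part of the snippet above) of the saturation behavior testIssue980 checks:

import numpy as np
from scipy.special import expit

# expit(x) = 1 / (1 + exp(-x)); for very negative x the stable formulation
# returns exactly 0.0 instead of overflowing or producing nan.
print(expit(np.full((4,), -1e20, dtype=np.float32)))  # [0. 0. 0. 0.]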
Code Example #12
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims,
                return_sign, use_b),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "shapes":
            shapes,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "return_sign":
            return_sign,
            "use_b":
            use_b
        } for shape_group in compatible_shapes for dtype in float_dtypes +
                            complex_dtypes + int_dtypes
                            for use_b in [False, True]
                            for shapes in itertools.product(
                                *((shape_group,
                                   shape_group) if use_b else (shape_group, )))
                            for axis in range(
                                -max(len(shape) for shape in shapes),
                                max(len(shape) for shape in shapes))
                            for keepdims in [False, True]
                            for return_sign in [False, True]))
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in .*")
    def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
        if jtu.device_under_test() != "cpu":
            rng = jtu.rand_some_inf_and_nan(self.rng())
        else:
            rng = jtu.rand_default(self.rng())
        # TODO(mattjj): test autodiff
        if use_b:

            def scipy_fun(array_to_reduce, scale_array):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            def lax_fun(array_to_reduce, scale_array):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
        else:

            def scipy_fun(array_to_reduce):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            def lax_fun(array_to_reduce):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            args_maker = lambda: [rng(shapes[0], dtype)]
        tol = {np.float32: 1E-6, np.float64: 1E-14}
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    def testLogSumExpZeros(self):
        # Regression test for https://github.com/google/jax/issues/5370
        scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
        lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
        args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
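        # With b = [1, 0] the weighted reduction must ignore the second term:
        # logsumexp([-1000, -2], b=[1, 0]) is -1000, whereas a naive
        # log(sum(b * exp(a))) underflows exp(-1000) to 0 and yields -inf.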
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "nondiff_argnums":
                    rec.nondiff_argnums,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in itertools.combinations_with_replacement(
                    all_shapes, rec.nargs)
                for dtypes in (itertools.combinations_with_replacement(
                    rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                               itertools.product(*rec.dtypes)))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = jtu.rand_positive(self.rng())
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                tol={
                                    np.float32: 1e-3,
                                    np.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    def testIssue980(self):
        x = np.full((4, ), -1e20, dtype=np.float32)
        self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                            lsp_special.expit(x))

    def testIssue3758(self):
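        # Hurwitz zeta: zeta(x, q) = sum_{k >= 0} (q + k)**(-x).  For very
        # large x the k = 0 term dominates, so zeta(x, 1) tends to 1 while
        # zeta(x, q) tends to 0 for q > 1, which is the saturation checked here.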
        x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
        q = np.array([1., 40., 30.], dtype=np.float32)
        self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                            lsp_special.zeta(x, q))

    def testXlogyShouldReturnZero(self):
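        # xlogy(x, y) computes x * log(y) with the convention that x == 0
        # yields 0 even where log(y) is -inf, so there is no nan at the origin.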
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(api.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_maxdegree={}_inputsize={}".format(
                    l_max, num_z),
                "l_max": l_max,
                "num_z": num_z
            } for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
    def testLpmn(self, l_max, num_z):
        # Points on which the associated Legendre functions are evaluated.
        z = np.linspace(-0.2, 0.9, num_z)
        actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max,
                                                               n=l_max,
                                                               z=z)

        # The expected results are obtained from scipy.
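        # osp_special.lpmn(m, n, z) takes a scalar z and returns the pair
        # (values, derivatives), each of shape (m + 1, n + 1), so the scipy
        # reference is assembled one z at a time along the last axis.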
        expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
        expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

        for i in range(num_z):
            val, derivative = osp_special.lpmn(l_max, l_max, z[i])
            expected_p_vals[:, :, i] = val
            expected_p_derivatives[:, :, i] = derivative

        with self.subTest('Test values.'):
            self.assertAllClose(actual_p_vals,
                                expected_p_vals,
                                rtol=1e-6,
                                atol=3.2e-6)

        with self.subTest('Test derivatives.'):
            self.assertAllClose(actual_p_derivatives,
                                expected_p_derivatives,
                                rtol=1e-6,
                                atol=8.4e-4)

        with self.subTest('Test JIT compatibility'):
            args_maker = lambda: [z]
            lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
            self._CompileAndCheck(lsp_special_fn, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_maxdegree={}_inputsize={}".format(
                    l_max, num_z),
                "l_max": l_max,
                "num_z": num_z
            } for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
    def testNormalizedLpmnValues(self, l_max, num_z):
        # Points on which the associated Legendre functions are evaluated.
        z = np.linspace(-0.2, 0.9, num_z)
        is_normalized = True
        actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

        # The expected results are obtained from scipy.
        expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
        for i in range(num_z):
            expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]

        def apply_normalization(a):
            """Applies normalization to the associated Legendre functions."""
            num_m, num_l, _ = a.shape
            a_normalized = np.zeros_like(a)
            for m in range(num_m):
                for l in range(num_l):
                    c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
                    c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
                    c2 = np.sqrt(c0 / c1)
                    a_normalized[m, l] = c2 * a[m, l]
            return a_normalized

        # The results from scipy are not normalized and the comparison requires
        # normalizing the results.
        expected_p_vals_normalized = apply_normalization(expected_p_vals)

        with self.subTest('Test accuracy.'):
            self.assertAllClose(actual_p_vals,
                                expected_p_vals_normalized,
                                rtol=1e-6,
                                atol=3.2e-6)

        with self.subTest('Test JIT compatibility'):
            args_maker = lambda: [z]
            lsp_special_fn = lambda z: lsp_special.lpmn_values(
                l_max, l_max, z, is_normalized)
            self._CompileAndCheck(lsp_special_fn, args_maker)

    def testSphHarmAccuracy(self):
        m = jnp.arange(-3, 3)[:, None]
        n = jnp.arange(3, 6)
        n_max = 5
        theta = 0.0
        phi = jnp.pi
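        # scipy's convention: sph_harm(m, n, theta, phi) with theta the
        # azimuthal and phi the polar angle; the jax version additionally takes
        # n_max, the maximum degree, which must be known statically.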

        expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

        actual = osp_special.sph_harm(m, n, theta, phi)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

    def testSphHarmOrderZeroDegreeZero(self):
        """Tests the spherical harmonics of order zero and degree zero."""
        theta = jnp.array([0.3])
        phi = jnp.array([2.3])
        n_max = 0

        expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
        actual = jnp.real(
            lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi,
                                 n_max))

        self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

    def testSphHarmOrderZeroDegreeOne(self):
        """Tests the spherical harmonics of order one and degree zero."""
        theta = jnp.array([2.0])
        phi = jnp.array([3.1])
        n_max = 1

        expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
        actual = jnp.real(
            lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi,
                                 n_max))

        self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)

    def testSphHarmOrderOneDegreeOne(self):
        """Tests the spherical harmonics of order one and degree one."""
        theta = jnp.array([2.0])
        phi = jnp.array([2.5])
        n_max = 1

        expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) * jnp.sin(phi) *
                    jnp.exp(1j * theta))
        actual = lsp_special.sph_harm(jnp.array([1]), jnp.array([1]), theta,
                                      phi, n_max)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_maxdegree={}_inputsize={}_dtype={}'.format(l_max, num_z, dtype),
            'l_max':
            l_max,
            'num_z':
            num_z,
            'dtype':
            dtype
        } for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
                            for dtype in jtu.dtypes.all_integer))
    def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
        """Tests against JIT compatibility and Numpy."""
        n_max = l_max
        shape = (num_z, )
        rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

        lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

        def args_maker():
            m = rng(shape, dtype)
            n = abs(m)
            theta = jnp.linspace(-4.0, 5.0, num_z)
            phi = jnp.linspace(-2.0, 1.0, num_z)
            return m, n, theta, phi

        with self.subTest('Test JIT compatibility'):
            self._CompileAndCheck(lsp_special_fn, args_maker)

        with self.subTest('Test against numpy.'):
            self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn,
                                    args_maker)

    def testSphHarmCornerCaseWithWrongNmax(self):
        """Tests the corner case where `n_max` is not the maximum value of `n`."""
        m = jnp.array([2])
        n = jnp.array([10])
        n_clipped = jnp.array([6])
        n_max = 6
        theta = jnp.array([0.9])
        phi = jnp.array([0.2])

        expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

        actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

        self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                'testcase_name':
                '_shape={}'
                '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
                '_max_sv={}_method={}_side={}'
                '_nonzero_condition_number={}_seed={}'.format(
                    jtu.format_shape_dtype_string(
                        shape,
                        jnp.dtype(dtype).name).replace(" ", ""), n_zero_sv,
                    degeneracy, geometric_spectrum, max_sv, method, side,
                    nonzero_condition_number, seed),
                'n_zero_sv':
                n_zero_sv,
                'degeneracy':
                degeneracy,
                'geometric_spectrum':
                geometric_spectrum,
                'max_sv':
                max_sv,
                'shape':
                shape,
                'method':
                method,
                'side':
                side,
                'nonzero_condition_number':
                nonzero_condition_number,
                'dtype':
                dtype,
                'seed':
                seed
            } for n_zero_sv in n_zero_svs for degeneracy in degeneracies
            for geometric_spectrum in geometric_spectra for max_sv in max_svs
            for shape in polar_shapes for method in methods for side in sides
            for nonzero_condition_number in nonzero_condition_numbers
            for dtype in jtu.dtypes.floating for seed in seeds))
    def testPolar(self, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
                  shape, method, side, nonzero_condition_number, dtype, seed):
        """ Tests jax.scipy.linalg.polar."""
        if jtu.device_under_test() != "cpu":
            if jnp.dtype(dtype).name in ("bfloat16", "float16"):
                raise unittest.SkipTest("Skip half precision off CPU.")
            if method == "svd":
                raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

        np.random.seed(seed)
        matrix, _ = _initialize_polar_test(shape, n_zero_sv, degeneracy,
                                           geometric_spectrum, max_sv,
                                           nonzero_condition_number, dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError,
                              jsp.linalg.polar,
                              matrix,
                              method=method,
                              side=side)
            return

        unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
        if shape[0] >= shape[1]:
            should_be_eye = np.matmul(unitary.conj().T, unitary)
        else:
            should_be_eye = np.matmul(unitary, unitary.conj().T)
        tol = 10 * jnp.finfo(matrix.dtype).eps
        eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
        with self.subTest('Test unitarity.'):
            self.assertAllClose(eye_mat, should_be_eye, atol=tol * min(shape))

        with self.subTest('Test Hermiticity.'):
            self.assertAllClose(posdef,
                                posdef.conj().T,
                                atol=tol * jnp.linalg.norm(posdef))

        ev, _ = np.linalg.eigh(posdef)
        ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
        negative_ev = jnp.sum(ev < 0.)
        with self.subTest('Test positive definiteness.'):
            assert negative_ev == 0.

        if side == "right":
            recon = jnp.matmul(unitary,
                               posdef,
                               precision=lax.Precision.HIGHEST)
        elif side == "left":
            recon = jnp.matmul(posdef,
                               unitary,
                               precision=lax.Precision.HIGHEST)
        with self.subTest('Test reconstruction.'):
            self.assertAllClose(matrix,
                                recon,
                                atol=tol * jnp.linalg.norm(matrix))

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_linear_size_={}_seed={}_dtype={}'.format(linear_size, seed,
                                                       jnp.dtype(dtype).name),
            'linear_size':
            linear_size,
            'seed':
            seed,
            'dtype':
            dtype
        } for linear_size in linear_sizes for seed in seeds
                            for dtype in jtu.dtypes.floating))
    def test_spectral_dac_eigh(self, linear_size, seed, dtype):
        if jtu.device_under_test() != "cpu":
            raise unittest.SkipTest("Skip eigh off CPU for now.")
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        np.random.seed(seed)
        H = np.random.randn(linear_size, linear_size)
        H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.eigh, H)
            return
        evs, V = jax._src.scipy.eigh.eigh(H)
        ev_exp, eV_exp = jnp.linalg.eigh(H)
        HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
        vV = evs * V
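        # The eigenvalue equation should hold columnwise: H @ V equals V scaled
        # by evs, and the eigenvalues match jnp.linalg.eigh, up to a tolerance
        # proportional to ||H|| * eps.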
        eps = jnp.finfo(H.dtype).eps
        atol = jnp.linalg.norm(H) * eps
        self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
        self.assertAllClose(HV, vV, atol=30 * atol)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            'testcase_name':
            '_linear_size_={}_seed={}_dtype={}'.format(linear_size, seed,
                                                       jnp.dtype(dtype).name),
            'linear_size':
            linear_size,
            'seed':
            seed,
            'dtype':
            dtype
        } for linear_size in linear_sizes for seed in seeds
                            for dtype in jtu.dtypes.floating))
    def test_spectral_dac_svd(self, linear_size, seed, dtype):
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            if jtu.device_under_test() != "cpu":
                raise unittest.SkipTest("Skip half precision off CPU.")

        np.random.seed(seed)
        A = np.random.randn(linear_size, linear_size).astype(dtype)
        if jnp.dtype(dtype).name in ("bfloat16", "float16"):
            self.assertRaises(NotImplementedError, jax._src.scipy.eigh.svd, A)
            return
        S_expected = np.linalg.svd(A, compute_uv=False)
        U, S, V = jax._src.scipy.eigh.svd(A)
        recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
        eps = jnp.finfo(dtype).eps
        eps = eps * jnp.linalg.norm(A) * 10
        self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
        self.assertAllClose(A, recon, atol=eps)

        # U is unitary.
        u_unitary_delta = jnp.dot(U.conj().T,
                                  U,
                                  precision=lax.Precision.HIGHEST)
        u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

        # V is unitary.
        v_unitary_delta = jnp.dot(V.conj().T,
                                  V,
                                  precision=lax.Precision.HIGHEST)
        v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
        self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
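For reference, a minimal plain-NumPy sketch (independent of jax._src.scipy.eigh) of the three properties the SVD test above verifies: reconstruction and unitarity of both factors.

import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(8, 8).astype(np.float32)
U, S, Vh = np.linalg.svd(A)

assert np.allclose(U @ np.diag(S) @ Vh, A, atol=1e-4)       # reconstruction
assert np.allclose(U.conj().T @ U, np.eye(8), atol=1e-5)    # U is unitary
assert np.allclose(Vh @ Vh.conj().T, np.eye(8), atol=1e-5)  # V is unitary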
Code Example #13
File: lax_scipy_test.py  Project: jamestwebber/jax
class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        return lambda: [
            rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
        ]

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims,
                return_sign, use_b),
            # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
            "shapes":
            shapes,
            "dtype":
            dtype,
            "axis":
            axis,
            "keepdims":
            keepdims,
            "return_sign":
            return_sign,
            "use_b":
            use_b
        } for shape_group in compatible_shapes for dtype in float_dtypes +
                            complex_dtypes + int_dtypes
                            for use_b in [False, True]
                            for shapes in itertools.product(
                                *((shape_group,
                                   shape_group) if use_b else (shape_group, )))
                            for axis in range(
                                -max(len(shape) for shape in shapes),
                                max(len(shape) for shape in shapes))
                            for keepdims in [False, True]
                            for return_sign in [False, True]))
    @jtu.ignore_warning(category=RuntimeWarning,
                        message="invalid value encountered in .*")
    def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
        if jtu.device_under_test() != "cpu":
            rng = jtu.rand_some_inf_and_nan(self.rng())
        else:
            rng = jtu.rand_default(self.rng())
        # TODO(mattjj): test autodiff
        if use_b:

            def scipy_fun(array_to_reduce, scale_array):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            def lax_fun(array_to_reduce, scale_array):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign,
                                             b=scale_array)

            args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
        else:

            def scipy_fun(array_to_reduce):
                return osp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            def lax_fun(array_to_reduce):
                return lsp_special.logsumexp(array_to_reduce,
                                             axis,
                                             keepdims=keepdims,
                                             return_sign=return_sign)

            args_maker = lambda: [rng(shapes[0], dtype)]
        tol = {np.float32: 1E-6, np.float64: 1E-14}
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

    def testLogSumExpZeros(self):
        # Regression test for https://github.com/google/jax/issues/5370
        scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
        lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
        args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
        self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
        self._CompileAndCheck(lax_fun, args_maker)

    @parameterized.named_parameters(
        itertools.chain.from_iterable(
            jtu.cases_from_list(
                {
                    "testcase_name":
                    jtu.format_test_name_suffix(rec.test_name, shapes, dtypes),
                    "rng_factory":
                    rec.rng_factory,
                    "shapes":
                    shapes,
                    "dtypes":
                    dtypes,
                    "test_autodiff":
                    rec.test_autodiff,
                    "nondiff_argnums":
                    rec.nondiff_argnums,
                    "scipy_op":
                    getattr(osp_special, rec.name),
                    "lax_op":
                    getattr(lsp_special, rec.name)
                } for shapes in itertools.combinations_with_replacement(
                    all_shapes, rec.nargs)
                for dtypes in (itertools.combinations_with_replacement(
                    rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                               itertools.product(*rec.dtypes)))
            for rec in JAX_SPECIAL_FUNCTION_RECORDS))
    def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes,
                            dtypes, test_autodiff, nondiff_argnums):
        if (jtu.device_under_test() == "cpu"
                and (lax_op is lsp_special.gammainc
                     or lax_op is lsp_special.gammaincc)):
            # TODO(b/173608403): re-enable test when LLVM bug is fixed.
            raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
        rng = rng_factory(self.rng())
        args_maker = self._GetArgsMaker(rng, shapes, dtypes)
        args = args_maker()
        self.assertAllClose(scipy_op(*args),
                            lax_op(*args),
                            atol=1e-3,
                            rtol=1e-3,
                            check_dtypes=False)
        self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

        if test_autodiff:

            def partial_lax_op(*vals):
                list_args = list(vals)
                for i in nondiff_argnums:
                    list_args.insert(i, args[i])
                return lax_op(*list_args)

            assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
            diff_args = [
                x for i, x in enumerate(args) if i not in nondiff_argnums
            ]
            jtu.check_grads(partial_lax_op,
                            diff_args,
                            order=1,
                            atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                            rtol=.1,
                            eps=1e-3)

    @parameterized.named_parameters(
        jtu.cases_from_list({
            "testcase_name":
            "_inshape={}_d={}".format(
                jtu.format_shape_dtype_string(shape, dtype), d),
            "shape":
            shape,
            "dtype":
            dtype,
            "d":
            d
        } for shape in all_shapes for dtype in float_dtypes
                            for d in [1, 2, 5]))
    def testMultigammaln(self, shape, dtype, d):
        def scipy_fun(a):
            return osp_special.multigammaln(a, d)

        def lax_fun(a):
            return lsp_special.multigammaln(a, d)

        rng = jtu.rand_positive(self.rng())
        args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
        self._CheckAgainstNumpy(scipy_fun,
                                lax_fun,
                                args_maker,
                                tol={
                                    np.float32: 1e-3,
                                    np.float64: 1e-14
                                })
        self._CompileAndCheck(lax_fun, args_maker)

    def testIssue980(self):
        x = np.full((4, ), -1e20, dtype=np.float32)
        self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                            lsp_special.expit(x))

    def testIssue3758(self):
        x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
        q = np.array([1., 40., 30.], dtype=np.float32)
        self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                            lsp_special.zeta(x, q))

    def testXlogyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

    def testGradOfXlogyAtZero(self):
        partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
        self.assertAllClose(api.grad(partial_xlogy)(0.),
                            0.,
                            check_dtypes=False)

    def testXlog1pyShouldReturnZero(self):
        self.assertAllClose(lsp_special.xlog1py(0., -1.),
                            0.,
                            check_dtypes=False)

    def testGradOfXlog1pyAtZero(self):
        partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
        self.assertAllClose(api.grad(partial_xlog1py)(-1.),
                            0.,
                            check_dtypes=False)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_maxdegree={}_inputsize={}".format(
                    l_max, num_z),
                "l_max": l_max,
                "num_z": num_z
            } for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
    def testLpmn(self, l_max, num_z):
        # Points on which the associated Legendre functions are evaluated.
        z = np.linspace(-0.2, 0.9, num_z)
        actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max,
                                                               n=l_max,
                                                               z=z)

        # The expected results are obtained from scipy.
        expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
        expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

        for i in range(num_z):
            val, derivative = osp_special.lpmn(l_max, l_max, z[i])
            expected_p_vals[:, :, i] = val
            expected_p_derivatives[:, :, i] = derivative

        with self.subTest('Test values.'):
            self.assertAllClose(actual_p_vals,
                                expected_p_vals,
                                rtol=1e-6,
                                atol=3.2e-6)

        with self.subTest('Test derivatives.'):
            self.assertAllClose(actual_p_derivatives,
                                expected_p_derivatives,
                                rtol=1e-6,
                                atol=8.4e-4)

        with self.subTest('Test JIT compatibility'):
            args_maker = lambda: [z]
            lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
            self._CompileAndCheck(lsp_special_fn, args_maker)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            {
                "testcase_name": "_maxdegree={}_inputsize={}".format(
                    l_max, num_z),
                "l_max": l_max,
                "num_z": num_z
            } for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
    def testNormalizedLpmnValues(self, l_max, num_z):
        # Points on which the associated Legendre functions are evaluated.
        z = np.linspace(-0.2, 0.9, num_z)
        is_normalized = True
        actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

        # The expected results are obtained from scipy.
        expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
        for i in range(num_z):
            expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]

        def apply_normalization(a):
            """Applies normalization to the associated Legendre functions."""
            num_m, num_l, _ = a.shape
            a_normalized = np.zeros_like(a)
            for m in range(num_m):
                for l in range(num_l):
                    c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
                    c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
                    c2 = np.sqrt(c0 / c1)
                    a_normalized[m, l] = c2 * a[m, l]
            return a_normalized

        # The results from scipy are not normalized and the comparison requires
        # normalizing the results.
        expected_p_vals_normalized = apply_normalization(expected_p_vals)

        with self.subTest('Test accuracy.'):
            self.assertAllClose(actual_p_vals,
                                expected_p_vals_normalized,
                                rtol=1e-6,
                                atol=3.2e-6)

        with self.subTest('Test JIT compatibility'):
            args_maker = lambda: [z]
            lsp_special_fn = lambda z: lsp_special.lpmn_values(
                l_max, l_max, z, is_normalized)
            self._CompileAndCheck(lsp_special_fn, args_maker)
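For reference, a minimal SciPy-only sketch (independent of the snippet above) of the normalization applied in testNormalizedLpmnValues; l, m, and z are arbitrary example values.

import numpy as np
from scipy.special import factorial, lpmn

l, m, z = 3, 2, 0.4
# Fully normalized value = sqrt((2l + 1)(l - m)! / (4*pi*(l + m)!)) * P_l^m(z).
c = np.sqrt((2 * l + 1) * factorial(l - m) / (4 * np.pi * factorial(l + m)))
print(c * lpmn(m, l, z)[0][m, l])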
Code Example #14
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.test_name, shapes, dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "modes": rec.diff_modes,
       "scipy_op": getattr(osp_special, rec.name),
       "lax_op": getattr(lsp_special, rec.name)}
      for rec in JAX_SPECIAL_FUNCTION_RECORDS
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes, modes):
    # TODO(mattjj): unskip this test combination when real() on tpu is improved
    # TODO(mattjj): test autodiff
    if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
        and not shapes[0]):
      return absltest.unittest.skip("real() on scalar not supported on tpu")

    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          "", shapes, dtypes),
       "rng": rng, "shapes": shapes, "dtypes": dtypes}
      for shapes in CombosWithReplacement(all_shapes, 3)
      for dtypes in CombosWithReplacement(default_dtypes, 3)
      for rng in [jtu.rand_default()])
  def testNormLogPdfThreeArgs(self, rng, shapes, dtypes):
    # TODO(mattjj): test autodiff
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf
    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      scale = 0.5 + onp.abs(scale)
      return [x, loc, scale]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          "", shapes, dtypes),
       "rng": rng, "shapes": shapes, "dtypes": dtypes}
      for shapes in CombosWithReplacement(all_shapes, 2)
      for dtypes in CombosWithReplacement(default_dtypes, 2)
      for rng in [jtu.rand_default()])
  def testNormLogPdfTwoArgs(self, rng, shapes, dtypes):
    # TODO(mattjj): test autodiff
    scale = 0.5
    scipy_fun = functools.partial(osp_stats.norm.logpdf, scale=scale)
    lax_fun = functools.partial(lsp_stats.norm.logpdf, scale=scale)
    def args_maker():
      return list(map(rng, shapes, dtypes))
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)
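For reference, a minimal NumPy sketch of the closed form that the norm.logpdf comparisons above reduce to (not part of the snippet).

import numpy as np

def norm_logpdf(x, loc=0.0, scale=1.0):
  # log N(x; loc, scale) = -((x - loc) / scale)**2 / 2 - log(scale) - log(2*pi) / 2
  z = (x - loc) / scale
  return -0.5 * z * z - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

print(norm_logpdf(1.0, loc=0.5, scale=2.0))  # agrees with scipy.stats.norm.logpdf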