def genNamedParametersNArgs(n, rng):
  """Builds a ``parameterized.named_parameters`` decorator for n-ary tests.

  One test case is produced for each pairing of an ``n``-element shape
  combination (drawn from ``all_shapes`` by ``CombosWithReplacement``) with an
  ``n``-element dtype combination (drawn from ``float_dtypes`` the same way).
  Every case carries the given ``rng`` unchanged.

  NOTE(review): this file also contains a one-argument
  ``genNamedParametersNArgs(n)`` definition that rebinds this name; whichever
  is defined later wins at import time -- confirm which one is intended.

  Args:
    n: number of positional arguments each generated case describes.
    rng: random-input generator stored in every case under the ``"rng"`` key.

  Returns:
    The decorator returned by ``parameterized.named_parameters``.
  """
  def _cases():
    for shape_combo in CombosWithReplacement(all_shapes, n):
      for dtype_combo in CombosWithReplacement(float_dtypes, n):
        yield {
            "testcase_name": jtu.format_test_name_suffix(
                "", shape_combo, dtype_combo),
            "rng": rng,
            "shapes": shape_combo,
            "dtypes": dtype_combo,
        }

  return parameterized.named_parameters(jtu.cases_from_list(_cases()))
def genNamedParametersNArgs(n):
  """Builds a ``parameterized.named_parameters`` decorator for n-ary tests.

  Cases enumerate every ``n``-element combination with replacement of
  ``all_shapes`` crossed with every ``n``-element combination with
  replacement of ``jtu.dtypes.floating``. The shape combination is the outer
  loop and the dtype combination is the inner loop.

  NOTE(review): this definition shares its name with the two-argument
  ``genNamedParametersNArgs(n, rng)`` in this file; the later definition
  rebinds the name -- confirm which one callers rely on.

  Args:
    n: number of positional arguments each generated case describes.

  Returns:
    The decorator returned by ``parameterized.named_parameters``.
  """
  shape_dtype_pairs = itertools.product(
      itertools.combinations_with_replacement(all_shapes, n),
      itertools.combinations_with_replacement(jtu.dtypes.floating, n))
  return parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
           "shapes": shapes,
           "dtypes": dtypes}
          for shapes, dtypes in shape_dtype_pairs))
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a thunk producing one ``rng(shape, dtype)`` array per argument."""
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    """Compares ``lsp_special.logsumexp`` with SciPy's over every valid axis."""
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.test_name, shapes, dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "test_autodiff": rec.test_autodiff,
       "scipy_op": getattr(osp_special, rec.name),
       "lax_op": getattr(lsp_special, rec.name)}
      for rec in JAX_SPECIAL_FUNCTION_RECORDS
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    """Compares a LAX-backed special function against its SciPy counterpart.

    Also checks that the LAX op compiles, and checks first-order gradients
    when the record enables ``test_autodiff``.
    """
    # TODO(mattjj): unskip this test combination when real() on tpu is improved
    if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
        and not shapes[0]):
      # skipTest raises SkipTest so the case is reported as skipped. Returning
      # unittest.skip(...) only builds a decorator, which made the test pass
      # silently.
      self.skipTest("real() on scalar not supported on tpu")
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3)
class LaxVmapTest(jtu.JaxTestCase):
  """Checks ``api.vmap`` of ``lax`` primitives against per-slice evaluation."""

  def _CheckBatching(self, op, bdim_size, bdims, shapes, dtypes, rng,
                     rtol=None, atol=None):
    """Asserts that ``api.vmap(op, bdims)`` matches an unbatched reference.

    Args:
      op: function under test, called positionally with one array per shape.
      bdim_size: size of the batch dimension added to each argument's shape.
      bdims: per-argument batch axes, passed as ``in_axes`` to ``api.vmap``
        (presumably ``None`` marks an unbatched argument -- see ``add_bdim``).
      shapes: unbatched argument shapes.
      dtypes: argument dtypes, parallel to ``shapes``.
      rng: callable ``rng(shape, dtype)`` that produces test inputs.
      rtol, atol: tolerances forwarded to ``assertAllClose``.
    """
    batched_shapes = map(partial(add_bdim, bdim_size), bdims, shapes)
    args = [rng(shape, dtype) for shape, dtype in zip(batched_shapes, dtypes)]
    args_slice = args_slicer(args, bdims)
    ans = api.vmap(op, bdims)(*args)
    if bdim_size == 0:
      # There are no slices to stack, so build an empty array whose trailing
      # shape and dtype come from one call on fresh unbatched inputs.
      args = [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]
      out = op(*args)
      expected = np.zeros((0,) + out.shape, out.dtype)
    else:
      expected = np.stack([op(*args_slice(i)) for i in range(bdim_size)])
    self.assertAllClose(ans, expected, rtol=rtol, atol=atol)

  @parameterized.named_parameters(
      itertools.chain.from_iterable(
          jtu.cases_from_list(
              {"testcase_name": "{}_bdims={}".format(
                  jtu.format_test_name_suffix(rec.op, shapes,
                                              itertools.repeat(dtype)),
                  bdims),
               "op_name": rec.op, "rng_factory": rec.rng_factory,
               "shapes": shapes, "dtype": dtype, "bdims": bdims,
               "tol": rec.tol}
              for shape_group in compatible_shapes
              for shapes in itertools.combinations_with_replacement(
                  shape_group, rec.nargs)
              for bdims in all_bdims(*shapes)
              for dtype in rec.dtypes)
          for rec in LAX_OPS))
  def testOp(self, op_name, rng_factory, shapes, dtype, bdims, tol):
    rng = rng_factory(self.rng())
    op = getattr(lax, op_name)
    self._CheckBatching(op, 10, bdims, shapes, [dtype] * len(shapes), rng,
                        atol=tol, rtol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
           "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
           "_lhs_bdim={}_rhs_bdim={}".format(
               jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count, lhs_bdim, rhs_bdim),
           "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
           "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
           "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, "perms": perms,
           "lhs_bdim": lhs_bdim, "rhs_bdim": rhs_bdim,
           "feature_group_count": feature_group_count,
           "batch_group_count": batch_group_count,
          }
          for batch_group_count, feature_group_count in (
              [(1, 1), (2, 1), (1, 2)])
          for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in [
              ((b * batch_group_count, i * feature_group_count, 6, 7),  # lhs_shape
               (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
               [(1, 1), (1, 2), (2, 1)],  # strides
               [((0, 0), (0, 0)), ((1, 0), (0, 1)), ((0, -1), (0, 0))],  # pads
               [(1, 1), (2, 1)],  # lhs_dils
               [(1, 1), (2, 2)])  # rhs_dils
              for b, i, j in itertools.product([1, 2], repeat=3)]
          for strides in all_strides
          for rhs_dil in rhs_dils
          for lhs_dil in lhs_dils
          for dtype in [np.float32]
          for padding in all_pads
          for dim_nums, perms in [
              (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
              (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
              (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]
          for lhs_bdim in itertools.chain([cast(Optional[int], None)],
                                          range(len(lhs_shape) + 1))
          for rhs_bdim in itertools.chain([cast(Optional[int], None)],
                                          range(len(rhs_shape) + 1))
          if (lhs_bdim, rhs_bdim) != (None, None)))
  def testConvGeneralDilatedBatching(
      self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil,
      dimension_numbers, perms, feature_group_count, batch_group_count,
      lhs_bdim, rhs_bdim):
    rng = jtu.rand_default(self.rng())
    tol = 1e-1 if dtypes.finfo(dtype).bits <= 32 else 1e-3

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    self._CheckBatching(conv, 5, (lhs_bdim, rhs_bdim), (lhs_shape, rhs_shape),
                        (dtype, dtype), rng, rtol=tol, atol=tol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
              shape, from_dtype, to_dtype, bdims),
           "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
           "bdims": bdims}
          for from_dtype, to_dtype in itertools.product(
              [np.float32, np.int32, "float32", "int32"], repeat=2)
          for shape in [(2, 3)]
          for bdims in all_bdims(shape)))
  def testConvertElementType(self, shape, from_dtype, to_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.convert_element_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_from_dtype={}_to_dtype={}_bdims={}".format(
              shape, from_dtype, to_dtype, bdims),
           "shape": shape, "from_dtype": from_dtype, "to_dtype": to_dtype,
           "bdims": bdims}
          for from_dtype, to_dtype in itertools.product(
              [np.float32, np.int32, "float32", "int32"], repeat=2)
          for shape in [(2, 3)]
          for bdims in all_bdims(shape)))
  def testBitcastElementType(self, shape, from_dtype, to_dtype, bdims,):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.bitcast_convert_type(x, to_dtype)
    self._CheckBatching(op, 10, bdims, (shape,), (from_dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_min_shape={}_operand_shape={}_max_shape={}_bdims={}".format(
               jtu.format_shape_dtype_string(min_shape, dtype),
               jtu.format_shape_dtype_string(operand_shape, dtype),
               jtu.format_shape_dtype_string(max_shape, dtype),
               bdims),
           "min_shape": min_shape, "operand_shape": operand_shape,
           "max_shape": max_shape, "dtype": dtype, "bdims": bdims}
          for min_shape, operand_shape, max_shape in [
              [(), (2, 3), ()],
              [(2, 3), (2, 3), ()],
              [(), (2, 3), (2, 3)],
              [(2, 3), (2, 3), (2, 3)],
          ]
          for dtype in default_dtypes
          for bdims in all_bdims(min_shape, operand_shape, max_shape)))
  def testClamp(self, min_shape, operand_shape, max_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    raise SkipTest("batching rule for clamp not implemented")  # TODO(mattj)
    shapes = [min_shape, operand_shape, max_shape]
    self._CheckBatching(lax.clamp, 10, bdims, shapes, [dtype] * 3, rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_lhs_shape={}_rhs_shape={}_bdims={}".format(
              jtu.format_shape_dtype_string(lhs_shape, dtype),
              jtu.format_shape_dtype_string(rhs_shape, dtype),
              bdims),
           "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
           "bdims": bdims}
          for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]
          for bdims in all_bdims(lhs_shape, rhs_shape)
          for dtype in default_dtypes))
  def testDot(self, lhs_shape, rhs_shape, dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = partial(lax.dot, precision=lax.Precision.HIGHEST)
    self._CheckBatching(op, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng, rtol={np.float16: 5e-2, np.float64: 5e-14})

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_lhs_contracting={}_rhs_contracting={}_bdims={}"
           .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
                   jtu.format_shape_dtype_string(rhs_shape, dtype),
                   lhs_contracting, rhs_contracting, bdims),
           "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
           "lhs_contracting": lhs_contracting,
           "rhs_contracting": rhs_contracting, "bdims": bdims}
          for lhs_shape, rhs_shape, lhs_contracting, rhs_contracting in [
              [(5,), (5,), [0], [0]],
              [(5, 7), (5,), [0], [0]],
              [(7, 5), (5,), [1], [0]],
              [(3, 5), (2, 5), [1], [1]],
              [(5, 3), (5, 2), [0], [0]],
              [(5, 3, 2), (5, 2, 4), [0], [0]],
              [(5, 3, 2), (5, 2, 4), [0, 2], [0, 1]],
              [(5, 3, 2), (3, 5, 2, 4), [0, 2], [1, 2]],
              [(1, 2, 2, 3), (1, 2, 3, 1), [1], [1]],
              [(3, 2), (2, 4), [1], [0]],
          ]
          for bdims in all_bdims(lhs_shape, rhs_shape)
          for dtype in default_dtypes))
  def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                                 lhs_contracting, rhs_contracting, bdims):
    rng = jtu.rand_small(self.rng())
    dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_lhs_shape={}_rhs_shape={}_dimension_numbers={}_bdims={}".format(
               jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers, bdims),
           "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
           "dimension_numbers": dimension_numbers, "bdims": bdims}
          for lhs_shape, rhs_shape, dimension_numbers in [
              ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
              ((3, 3, 2), (2, 3, 4), (([2], [0]), ([0], [1]))),
              ((3, 4, 2, 4), (3, 4, 3, 2), (([2], [3]), ([0, 1], [0, 1]))),
          ]
          for bdims in all_bdims(lhs_shape, rhs_shape)
          for dtype in default_dtypes))
  def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                     dimension_numbers, bdims):
    rng = jtu.rand_small(self.rng())
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

    # Checks that no transposes or broadcasts appear in the traced jaxpr.
    # Primitive names are compared explicitly. A Primitive object never equals
    # a str, so the previous `eqn.primitive in [...]` check could not fail.
    # NOTE(review): this traces the un-vmapped `dot`; tracing
    # api.vmap(dot, bdims) on batched shapes would be needed to check the
    # batching rule itself -- confirm the intended target.
    jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                np.zeros(rhs_shape, dtype))
    for eqn in jtu.iter_eqns(jaxpr.jaxpr):
      self.assertFalse(eqn.primitive.name in ["transpose", "broadcast"])

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}_bdims={}".format(
              shape, np.dtype(dtype).name, broadcast_sizes, bdims),
           "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes,
           "bdims": bdims}
          for shape in [(), (2, 3)]
          for dtype in default_dtypes
          for broadcast_sizes in [(), (2,), (1, 2)]
          for bdims in all_bdims(shape)))
  def testBroadcast(self, shape, dtype, broadcast_sizes, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.broadcast(x, broadcast_sizes)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_outshape={}_bcdims={}_bdims={}".format(
              jtu.format_shape_dtype_string(inshape, dtype),
              outshape, broadcast_dimensions, bdims),
           "inshape": inshape, "dtype": dtype, "outshape": outshape,
           "dimensions": broadcast_dimensions, "bdims": bdims}
          for inshape, outshape, broadcast_dimensions in [
              ([2], [2, 2], [0]),
              ([2], [2, 2], [1]),
              ([2], [2, 3], [0]),
              ([], [2, 3], []),
          ]
          for dtype in default_dtypes
          for bdims in all_bdims(inshape)))
  def testBroadcastInDim(self, inshape, dtype, outshape, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    raise SkipTest("this test has failures in some cases")  # TODO(mattjj)
    op = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    self._CheckBatching(op, 5, bdims, (inshape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_dimensions={}_bdims={}".format(
              jtu.format_shape_dtype_string(arg_shape, np.float32),
              dimensions, bdims),
           "arg_shape": arg_shape, "dimensions": dimensions, "bdims": bdims}
          for arg_shape, dimensions in [
              [(1,), (0,)],
              [(1,), (-1,)],
              [(2, 1, 4), (1,)],
              [(2, 1, 4), (-2,)],
              [(2, 1, 3, 1), (1,)],
              [(2, 1, 3, 1), (1, 3)],
              [(2, 1, 3, 1), (3,)],
              [(2, 1, 3, 1), (1, -1)],
          ]
          for bdims in all_bdims(arg_shape)))
  def testSqueeze(self, arg_shape, dimensions, bdims):
    dtype = np.float32
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.squeeze(x, dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_outshape={}_dims={}_bdims={}".format(
              jtu.format_shape_dtype_string(arg_shape, dtype),
              jtu.format_shape_dtype_string(out_shape, dtype),
              dimensions, bdims),
           "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
           "dimensions": dimensions, "bdims": bdims}
          for dtype in default_dtypes
          for arg_shape, dimensions, out_shape in [
              [(3, 4), None, (12,)],
              [(2, 1, 4), None, (8,)],
              [(2, 2, 4), None, (2, 8)],
              [(2, 2, 4), (0, 1, 2), (2, 8)],
              [(2, 2, 4), (1, 0, 2), (8, 2)],
              [(2, 2, 4), (2, 1, 0), (4, 2, 2)]
          ]
          for bdims in all_bdims(arg_shape)))
  def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.reshape(x, out_shape, dimensions=dimensions)
    self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_pads={}_bdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), pads, bdims),
           "shape": shape, "dtype": dtype, "pads": pads, "bdims": bdims}
          for shape in [(2, 3)]
          for bdims in all_bdims(shape)
          for dtype in default_dtypes
          for pads in [[(1, 2, 1), (0, 1, 0)]]))
  def testPad(self, shape, dtype, pads, bdims):
    rng = jtu.rand_small(self.rng())
    fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_predshape={}_argshapes={}_bdims={}".format(
              jtu.format_shape_dtype_string(pred_shape, np.bool_),
              jtu.format_shape_dtype_string(arg_shape, arg_dtype),
              bdims),
           "pred_shape": pred_shape, "arg_shape": arg_shape,
           "arg_dtype": arg_dtype, "bdims": bdims}
          for arg_shape in [(), (3,), (2, 3)]
          for pred_shape in ([(), arg_shape] if arg_shape else [()])
          for bdims in all_bdims(pred_shape, arg_shape, arg_shape)
          for arg_dtype in default_dtypes))
  def testSelect(self, pred_shape, arg_shape, arg_dtype, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda c, x, y: lax.select(c < 0, x, y)
    self._CheckBatching(op, 5, bdims, (pred_shape, arg_shape, arg_shape,),
                        (np.bool_, arg_dtype, arg_dtype), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_shape={}_start_indices={}_limit_indices={}_strides={}_bdims={}"
           .format(jtu.format_shape_dtype_string(shape, dtype),
                   start_indices, limit_indices, strides, bdims),
           "shape": shape, "dtype": dtype, "starts": start_indices,
           "limits": limit_indices, "strides": strides, "bdims": bdims}
          for shape, start_indices, limit_indices, strides in [
              [(3,), (1,), (2,), None],
              [(7,), (4,), (7,), None],
              [(5,), (1,), (5,), (2,)],
              [(8,), (1,), (6,), (2,)],
              [(5, 3), (1, 1), (3, 2), None],
              [(5, 3), (1, 1), (3, 1), None],
              [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
              [(5, 3), (1, 1), (2, 1), (1, 1)],
              [(5, 3), (1, 1), (5, 3), (2, 1)],
          ]
          for bdims in all_bdims(shape)
          for dtype in default_dtypes))
  def testSlice(self, shape, dtype, starts, limits, strides, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.slice(x, starts, limits, strides)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_perm={}_bdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), perm, bdims),
           "shape": shape, "dtype": dtype, "perm": perm, "bdims": bdims}
          for shape, perm in [
              [(3, 4), (1, 0)],
              [(3, 4), (0, 1)],
              [(3, 4, 5), (2, 1, 0)],
              [(3, 4, 5), (1, 0, 2)],
          ]
          for bdims in all_bdims(shape)
          for dtype in default_dtypes))
  def testTranspose(self, shape, dtype, perm, bdims):
    rng = jtu.rand_default(self.rng())
    op = lambda x: lax.transpose(x, perm)
    self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}_bdims={}"
           .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
                   dims, init_val, bdims),
           "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
           "dims": dims, "bdims": bdims}
          # The per-op dtype list is named op_dtypes so it does not shadow the
          # `dtypes` module that the init values below are computed with.
          for init_val, op, op_dtypes in [
              (0, lax.add, default_dtypes),
              (1, lax.mul, default_dtypes),
              (0, lax.max, all_dtypes),  # non-monoidal
              (-np.inf, lax.max, float_dtypes),
              (dtypes.iinfo(np.int32).min, lax.max, [np.int32]),
              (dtypes.iinfo(np.int64).min, lax.max, [np.int64]),
              (dtypes.iinfo(np.uint32).min, lax.max, [np.uint32]),
              (dtypes.iinfo(np.uint64).min, lax.max, [np.uint64]),
              (np.inf, lax.min, float_dtypes),
              (dtypes.iinfo(np.int32).max, lax.min, [np.int32]),
              (dtypes.iinfo(np.int64).max, lax.min, [np.int64]),
              (dtypes.iinfo(np.uint32).max, lax.min, [np.uint32]),
              (dtypes.iinfo(np.uint64).max, lax.min, [np.uint64]),
          ]
          for dtype in op_dtypes
          for shape, dims in [
              [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
              [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
          ]
          for bdims in all_bdims(shape)))
  def testReduce(self, op, init_val, shape, dtype, dims, bdims):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}".format(
              op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
              bdims),
           "op": op, "shape": shape, "dtype": dtype, "dim": dim, "bdims": bdims}
          for op in [lax.argmin, lax.argmax]
          for dtype in default_dtypes
          for shape in [(3, 4, 5)]
          for dim in range(len(shape))
          for bdims in all_bdims(shape)))
  def testArgminmax(self, op, shape, dtype, dim, bdims):
    rng = jtu.rand_default(self.rng())
    fun = lambda operand: op(operand, dim, np.int32)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                             "_basedilation={}_windowdilation={}").format(
               op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               dims, strides, padding, base_dilation, window_dilation),
           "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
           "dims": dims, "strides": strides, "padding": padding,
           "base_dilation": base_dilation, "window_dilation": window_dilation}
          for init_val, op, op_dtypes in [
              (0, lax.add, [np.float32]),
              (-np.inf, lax.max, [np.float32]),
              (np.inf, lax.min, [np.float32]),
          ]
          for shape, dims, strides, padding, base_dilation, window_dilation in (
              itertools.chain(
                  itertools.product(
                      [(4, 6)],
                      [(2, 1), (1, 2)],
                      [(1, 1), (2, 1), (1, 2)],
                      ["VALID", "SAME", [(0, 3), (1, 2)]],
                      [(1, 1), (2, 3)],
                      [(1, 1), (1, 2)]),
                  itertools.product(
                      [(3, 2, 4, 6)],
                      [(1, 1, 2, 1), (2, 1, 2, 1)],
                      [(1, 2, 2, 1), (1, 1, 1, 1)],
                      ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
                      [(1, 1, 1, 1), (2, 1, 3, 2)],
                      [(1, 1, 1, 1), (1, 2, 2, 1)])))
          for dtype in op_dtypes))
  def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
                       base_dilation, window_dilation):
    rng = jtu.rand_small(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    for bdims in all_bdims(shape):
      self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_op={}_shape={}_axis={}_bdims={}_reverse={}".format(
              op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis,
              bdims, reverse),
           "op": op, "shape": shape, "dtype": dtype, "bdims": bdims,
           "axis": axis, "reverse": reverse}
          for op, types in [
              (lax.cumsum, [np.float32, np.float64]),
              (lax.cumprod, [np.float32, np.float64]),
          ]
          for dtype in types
          for shape in [[10], [3, 4, 5]]
          for axis in range(len(shape))
          for bdims in all_bdims(shape)
          for reverse in [False, True]))
  def testCumulativeReduce(self, op, shape, dtype, axis, bdims, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    self._CheckBatching(partial(op, axis=axis, reverse=reverse), 7, bdims,
                        (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_dtype={}_padding={}".format(np.dtype(dtype).name,
                                                          padding),
           "dtype": dtype, "padding": padding}
          for dtype in float_dtypes
          for padding in ["VALID", "SAME"]))
  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  @jtu.ignore_warning(message="Using reduced precision for gradient.*")
  def testSelectAndGatherAdd(self, dtype, padding):
    if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
      raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")
    rng = jtu.rand_small(self.rng())
    all_configs = itertools.chain(
        itertools.product(
            [(4, 6)],
            [(2, 1), (1, 2)],
            [(1, 1), (2, 1), (1, 2)]),
        itertools.product(
            [(3, 2, 4, 6)], [(1, 1, 2, 1), (2, 1, 2, 1)],
            [(1, 2, 2, 1), (1, 1, 1, 1)]))

    # `dims` and `strides` are read at call time from the loop variables below.
    def fun(operand, tangents):
      pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
      ones = (1,) * len(operand.shape)
      return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                        strides, pads, ones, ones)

    for shape, dims, strides in all_configs:
      for bdims in all_bdims(shape, shape):
        self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           f"_dtype={jtu.format_shape_dtype_string(shape, dtype)}"
           f"_padding={padding}_dims={dims}_strides={strides}",
           "dtype": dtype, "padding": padding, "shape": shape,
           "dims": dims, "strides": strides}
          for dtype in float_dtypes
          for padding in ["VALID", "SAME"]
          for shape in [(3, 2, 4, 6)]
          for dims in [(1, 1, 2, 1)]
          for strides in [(1, 2, 2, 1), (1, 1, 1, 1)]))
  def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
    rng = jtu.rand_small(self.rng())

    pads = lax.padtype_to_pads(shape, dims, strides, padding)

    # The first argument has the cotangent (window-output) shape and the second
    # has the operand shape; see the _CheckBatching call below.
    def fun(source, operand):
      return lax._select_and_scatter_add(source, operand, lax.ge_p, dims,
                                         strides, pads)

    ones = (1,) * len(shape)
    cotangent_shape = api.eval_shape(
        lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                             pads, ones, ones),
        np.ones(shape, dtype)).shape

    for bdims in all_bdims(cotangent_shape, shape):
      self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                          (dtype, dtype), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_bdims={}_fft_ndims={}".format(
              shape, bdims, fft_ndims),
           "shape": shape, "bdims": bdims, "fft_ndims": fft_ndims}
          for shape in [(5,), (3, 4, 5), (2, 3, 4, 5)]
          for bdims in all_bdims(shape)
          for fft_ndims in range(0, min(3, len(shape)) + 1)))
  @jtu.skip_on_devices("tpu")  # TODO(b/137993701): unimplemented cases.
  def testFft(self, fft_ndims, shape, bdims):
    rng = jtu.rand_default(self.rng())
    ndims = len(shape)
    axes = range(ndims - fft_ndims, ndims)
    fft_lengths = [shape[axis] for axis in axes]
    op = lambda x: lax.fft(x, xla_client.FftType.FFT, fft_lengths)
    self._CheckBatching(op, 5, bdims, [shape], [np.complex64], rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_shape={}_idxs={}_dnums={}_slice_sizes={}_bdims={}".format(
               jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
               slice_sizes, bdims),
           "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
           "slice_sizes": slice_sizes, "bdims": bdims}
          for dtype in all_dtypes
          for shape, idxs, dnums, slice_sizes in [
              ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
                  offset_dims=(), collapsed_slice_dims=(0,),
                  start_index_map=(0,)),
               (1,)),
              ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
                  offset_dims=(1,), collapsed_slice_dims=(),
                  start_index_map=(0,)),
               (2,)),
              ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
                  offset_dims=(1,), collapsed_slice_dims=(0,),
                  start_index_map=(0,)),
               (1, 3)),
              ((10, 5), np.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers(
                  offset_dims=(1,), collapsed_slice_dims=(0,),
                  start_index_map=(0, 1)),
               (1, 3)),
          ]
          for bdims in all_bdims(shape, idxs.shape)))
  def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
    fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
    self._CheckBatching(fun, 0, bdims, [shape, idxs.shape],
                        [dtype, idxs.dtype], jtu.rand_default(self.rng()))
    self._CheckBatching(fun, 5, bdims, [shape, idxs.shape],
                        [dtype, idxs.dtype], jtu.rand_default(self.rng()))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}_bdims={}".format(
              jtu.format_shape_dtype_string(arg_shape, dtype),
              idxs, update_shape, dnums, bdims),
           "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
           "update_shape": update_shape, "dnums": dnums, "bdims": bdims}
          for dtype in float_dtypes
          for arg_shape, idxs, update_shape, dnums in [
              ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
                  update_window_dims=(), inserted_window_dims=(0,),
                  scatter_dims_to_operand_dims=(0,))),
              ((10,), np.array([[0], [0], [0]]), (3, 2),
               lax.ScatterDimensionNumbers(
                   update_window_dims=(1,), inserted_window_dims=(),
                   scatter_dims_to_operand_dims=(0,))),
              ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
               lax.ScatterDimensionNumbers(
                   update_window_dims=(1,), inserted_window_dims=(0,),
                   scatter_dims_to_operand_dims=(0,))),
          ]
          for bdims in all_bdims(arg_shape, idxs.shape, update_shape)))
  def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
    fun = partial(lax.scatter_add, dimension_numbers=dnums)
    self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                        [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                        rtol={np.float16: 5e-3})

  def testShapeUsesBuiltinInt(self):
    x = lax.iota(np.int32, 3) + 1
    self.assertIsInstance(x.shape[0], int)  # not np.int64

  def testBroadcastShapesReturnsPythonInts(self):
    shape1, shape2 = (1, 2, 3), (2, 3)
    out_shape = lax.broadcast_shapes(shape1, shape2)
    self.assertTrue(all(type(s) is int for s in out_shape))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_shape={}_k={}_bdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), k, bdims),
           "shape": shape, "dtype": dtype, "k": k, "bdims": bdims,
           "rng_factory": rng_factory}
          for shape in [(4,), (3, 5, 3)]
          for k in [1, 3]
          for bdims in all_bdims(shape)
          # TODO(b/155170120): test with repeats once the XLA:CPU stable top_k bug is fixed:
          # The top_k indices for integer arrays with identical entries won't match between
          # vmap'd version and manual reference, so only test unique integer arrays for int_dtypes.
          # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
          # values a bfloat16 can represent exactly to avoid ties.
          for dtype, rng_factory in itertools.chain(
              unsafe_zip(default_dtypes, itertools.repeat(jtu.rand_unique_int)))))
  def testTopK(self, shape, dtype, k, bdims, rng_factory):
    rng = rng_factory(self.rng())
    # _CheckBatching doesn't work with tuple outputs, so test outputs separately.
    op1 = lambda x: lax.top_k(x, k=k)[0]
    self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)
    op2 = lambda x: lax.top_k(x, k=k)[1]
    self._CheckBatching(op2, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
           "_shape={}_dimension={}_arity={}_bdims={}_isstable={}".format(
               jtu.format_shape_dtype_string(shape, np.float32), dimension,
               arity, bdims, is_stable),
           "shape": shape, "dimension": dimension, "arity": arity,
           "bdims": bdims, "is_stable": is_stable}
          for shape in [(2, 3)]
          for dimension in [0, 1]
          for arity in range(3)
          for bdims in all_bdims(*((shape,) * arity))
          for is_stable in [False, True]))
  def testSort(self, shape, dimension, arity, bdims, is_stable):
    rng = jtu.rand_default(self.rng())
    if arity == 1:
      # Forward is_stable so the single-operand case exercises both sort modes,
      # matching the multi-operand branch below.
      fun = partial(lax.sort, dimension=dimension, is_stable=is_stable)
      self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                          (np.float32,) * arity, rng)
    else:
      for i in range(arity):
        # `i=i` binds the current index; a plain closure would see the last i.
        fun = lambda *args, i=i: lax.sort(
            args, dimension=dimension, is_stable=is_stable)[i]
        self._CheckBatching(fun, 5, bdims, (shape,) * arity,
                            (np.float32,) * arity, rng)
class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation.

  Most tests build inputs with an `rng(shape, dtype)` factory, compare the
  `lnp` (LAX-backed) function against its `onp` (NumPy) counterpart with
  `_CheckAgainstNumpy(onp_fun, lnp_fun, ...)`, and then run
  `_CompileAndCheck` on the `lnp` function.
  """

  def _GetArgsMaker(self, rng, shapes, dtypes):
    """Returns a thunk producing one `rng(shape, dtype)` array per pair."""
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                        dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
           "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
          for shapes in filter(
              _shapes_are_broadcast_compatible,
              CombosWithReplacement(rec.shapes, rec.nargs))
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    """Compares each one-to-one / compound op record against NumPy."""
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(
              rec.test_name, shapes, dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
           "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
          for shapes in filter(
              _shapes_are_broadcast_compatible,
              CombosWithReplacement(rec.shapes, rec.nargs))
          for dtypes in filter(
              _dtypes_are_compatible_for_bitwise_ops,
              CombosWithReplacement(rec.dtypes, rec.nargs)))
      for rec in JAX_BITWISE_OP_RECORDS))
  def testBitwiseOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    """Compares bitwise op records against NumPy; skips 64-bit without x64."""
    if not FLAGS.jax_enable_x64 and any(
        onp.iinfo(dtype).bits == 64 for dtype in dtypes):
      self.skipTest("x64 types are disabled by jax_enable_x64")
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_dtype={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis,
          "None" if out_dtype is None else onp.dtype(out_dtype).name, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype, "out_dtype": out_dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for out_dtype in [None] + rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, out_dtype, axis,
                  keepdims):
    """Compares reducers that take a `dtype` argument, over every axis."""
    onp_fun = lambda x: onp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, dtype=out_dtype, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_NO_DTYPE_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])
      for keepdims in [False, True]))
  def testReducerNoDtype(self, onp_op, lnp_op, rng, shape, dtype, axis,
                         keepdims):
    """Compares reducers that take no `dtype` argument, over every axis."""
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "shape": shape, "dtype": dtype, "axis": axis}
      for shape in all_shapes for dtype in all_dtypes
      for axis in set(range(-len(shape), len(shape))) | set([None])))
  def testCountNonzero(self, shape, dtype, axis):
    """Compares lnp.count_nonzero with onp.count_nonzero."""
    rng = jtu.rand_some_zero()
    onp_fun = lambda x: onp.count_nonzero(x, axis)
    lnp_fun = lambda x: lnp.count_nonzero(x, axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in rec.shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):
    """Compares arg-min/max style reductions along each explicit axis."""

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """Compares lnp.dot with onp.dot over scalar/vector/matrix/tensor pairs."""
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """Compares lnp.matmul with onp.matmul, including batch broadcasting."""
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype),
          axes),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "axes": axes, "rng": rng}
      for rng in [jtu.rand_default()]
      for lhs_shape, rhs_shape, axes in [
          [(2, 3, 4), (3, 4, 5, 6), 2],
          [(2, 3, 4), (5, 4, 3, 6), [1, 2]],
          [(2, 3, 4), (5, 4, 3, 6), [[1, 2], [2, 1]]],
          [(1, 2, 3, 4), (4, 5, 3, 6), [[2, 3], [2, 0]]],
      ]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)))
  def testTensordot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, axes,
                    rng):
    """Compares lnp.tensordot with onp.tensordot for int and list `axes`."""
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    lnp_fun = lambda a, b: lnp.tensordot(a, b, axes)
    onp_fun = lambda a, b: onp.tensordot(a, b, axes)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_{}".format(
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": jtu.rand_default()}
      # TODO(phawkins): support integer dtypes too.
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)
      # Keep only pairs where one side is a scalar or the trailing dims agree,
      # since inner() contracts over the last axis of each operand.
      for lhs_shape, rhs_shape in [
          (l, r) for l, r in CombosWithReplacement(all_shapes, 2)
          if len(jtu._dims_of_shape(l)) == 0
          or len(jtu._dims_of_shape(r)) == 0
          or l[-1] == r[-1]]))
  def testInner(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """Compares lnp.inner with onp.inner on compatible shape pairs."""
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    onp_fun = lambda lhs, rhs: onp.inner(lhs, rhs)
    lnp_fun = lambda lhs, rhs: lnp.inner(lhs, rhs)
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)]))
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    """Compares lnp.clip with onp.clip for static (Python scalar) bounds."""
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for decimals in [0, 1, -2]))
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    """Compares lnp.round with onp.round for static `decimals`."""
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape))))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    """Compares lnp.concatenate with onp.concatenate on mixed-dtype inputs."""
    wrapped_axis = axis % len(base_shape)
    # Vary the size along the concatenation axis (3, 1, 4, ...) per operand.
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape=[{}]_axis={}_repeats={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, repeats),
       "axis": axis, "shape": shape, "dtype": dtype, "repeats": repeats,
       "rng": jtu.rand_default()}
      for repeats in [0, 1, 2]
      for dtype in default_dtypes
      for shape in all_shapes
      for axis in [None] + list(range(-len(shape), len(shape)))))
  def testRepeat(self, axis, shape, dtype, repeats, rng):
    """Compares lnp.repeat with onp.repeat, including axis=None."""
    onp_fun = lambda arg: onp.repeat(arg, repeats=repeats, axis=axis)
    lnp_fun = lambda arg: lnp.repeat(arg, repeats=repeats, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dtype={}_m={}_n={}_k={}".format(
          onp.dtype(dtype).name, m, n, k),
       "m": m, "n": n, "k": k, "dtype": dtype, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for n in [0, 4]
      for m in [None, 0, 1, 3, 4]
      for k in list(range(-4, 4))))
  def testTri(self, m, n, k, dtype, rng):
    """Compares lnp.tri with onp.tri (no array inputs)."""
    onp_fun = lambda: onp.tri(n, M=m, k=k, dtype=dtype)
    lnp_fun = lambda: lnp.tri(n, M=m, k=k, dtype=dtype)
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_k={}".format(
          op, jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "op": op, "k": k,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 1]
      for op in ["tril", "triu"]
      for k in list(range(-3, 3))))
  def testTriLU(self, dtype, shape, op, k, rng):
    """Compares lnp.tril / lnp.triu with their NumPy counterparts."""
    onp_fun = lambda arg: getattr(onp, op)(arg, k=k)
    lnp_fun = lambda arg: getattr(lnp, op)(arg, k=k)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "dtype": dtype, "shape": shape, "k": k, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) in (1, 2)]
      for k in list(range(-4, 4))))
  def testDiag(self, shape, dtype, k, rng):
    """Compares lnp.diag with onp.diag on 1-D and 2-D inputs."""
    onp_fun = lambda arg: onp.diag(arg, k)
    lnp_fun = lambda arg: lnp.diag(arg, k)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_offset={}_axis1={}_axis2={}".format(
          jtu.format_shape_dtype_string(shape, dtype), offset, axis1, axis2),
       "dtype": dtype, "shape": shape, "offset": offset, "axis1": axis1,
       "axis2": axis2, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
      for offset in list(range(-4, 4))))
  def testDiagonal(self, shape, dtype, offset, axis1, axis2, rng):
    """Compares lnp.diagonal with onp.diagonal for every axis pair."""
    onp_fun = lambda arg: onp.diagonal(arg, offset, axis1, axis2)
    lnp_fun = lambda arg: lnp.diagonal(arg, offset, axis1, axis2)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      # The first field is the dtype name, so label it as such.
      {"testcase_name": "_dtype={}_n={}".format(onp.dtype(dtype).name, n),
       "dtype": dtype, "n": n}
      for dtype in default_dtypes
      for n in list(range(4))))
  def testIdentity(self, n, dtype):
    """Compares lnp.identity with onp.identity (no array inputs)."""
    onp_fun = lambda: onp.identity(n, dtype)
    lnp_fun = lambda: lnp.identity(n, dtype)
    args_maker = lambda: []
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      # Format out_dtype by its NumPy name (as testReducer does) rather than
      # the raw class object, which would embed "<class '...'>" in the name.
      {"testcase_name": "_shape={}_outdtype={}_offset={}_axis1={}_axis2={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          "None" if out_dtype is None else onp.dtype(out_dtype).name,
          offset, axis1, axis2),
       "dtype": dtype, "out_dtype": out_dtype, "shape": shape, "offset": offset,
       "axis1": axis1, "axis2": axis2, "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for out_dtype in [None] + default_dtypes
      for shape in [shape for shape in all_shapes if len(shape) >= 2]
      for (axis1, axis2) in itertools.combinations(range(len(shape)), 2)
      for offset in list(range(-4, 4))))
  def testTrace(self, shape, dtype, out_dtype, offset, axis1, axis2, rng):
    """Compares lnp.trace with onp.trace, with and without an output dtype."""
    onp_fun = lambda arg: onp.trace(arg, offset, axis1, axis2, out_dtype)
    lnp_fun = lambda arg: lnp.trace(arg, offset, axis1, axis2, out_dtype)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
          [onp.float32],
          [onp.float32, onp.float32],
          [onp.float32, onp.int32, onp.float32],
          [onp.float32, onp.int64, onp.float32],
          [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()]))
  def testStack(self, shape, dtypes, rng):
    """Compares lnp.stack with onp.stack on a list of same-shape arrays."""
    # args_maker yields a single argument: the list of arrays to stack.
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    # The NumPy reference goes first, matching every other test in this class.
    self._CheckAgainstNumpy(onp.stack, lnp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outdtype={}".format(
          jtu.format_shape_dtype_string(shape, fill_value_dtype),
          onp.dtype(out_dtype).name),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in array_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    """Compares lnp.full with onp.full for a scalar array fill value."""
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_filldtype={}_outdtype={}".format(
          jtu.format_shape_dtype_string(shape, in_dtype),
          onp.dtype(fill_value_dtype).name,
          onp.dtype(out_dtype).name),
       "shape": shape, "in_dtype": in_dtype,
       "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype,
       "rng": jtu.rand_default()}
      for shape in array_shapes
      for in_dtype in default_dtypes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes))
  def testFullLike(self, shape, in_dtype, fill_value_dtype, out_dtype, rng):
    """Compares lnp.full_like with onp.full_like across dtype combinations."""
    onp_fun = lambda x, fill_value: onp.full_like(x, fill_value, dtype=out_dtype)
    lnp_fun = lambda x, fill_value: lnp.full_like(x, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng(shape, in_dtype), rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes))
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    """Compares lnp.split with onp.split for a static section count."""
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          (jtu.NUMPY_SCALAR_SHAPE, (1, 1, 1)),
          ((), (1, 1, 1)),
          ((7, 0), (0, 42, 101)),
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ]))
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    """Compares lnp.reshape with onp.reshape, including -1 and int shapes."""
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: lnp.reshape(x, out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape))))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    """Compares lnp.expand_dims with onp.expand_dims for a static axis."""
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes))
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    """Compares lnp.swapaxes with onp.swapaxes for static axes."""
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes))
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    """Compares lnp.squeeze with onp.squeeze for None, int and tuple axes."""
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_arg{}".format(i), "arg": arg}
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ])))
  def testArray(self, arg):
    """Compares lnp.array with onp.array on nested Python/NumPy inputs."""
    args_maker = lambda: [arg]
    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testArrayAsarrayMethod(self):
    """Checks that lnp.array honors an object's __asarray__ method."""

    class arraylike(object):

      def __asarray__(self, dtype=None):
        return 3.

    a = arraylike()
    ans = lnp.array(a)
    # Use a unittest assertion: a bare `assert` is stripped under `python -O`.
    self.assertEqual(ans, 3.)

  def testAllClose(self):
    """Checks lnp.allclose over tuples of arrays, eagerly and under jit."""
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def testOnesBroadcastingConstantHandler(self):
    """Checks that a stride-zero constant is lowered via Broadcast (skipped)."""
    # TODO(mattjj): update this test for jax3
    self.skipTest("test needs jax3 update")

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    """Checks that a broadcast (zero-stride) NumPy constant works under jit."""
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    """Checks that traced values pass isinstance(x, lnp.ndarray)."""
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)

  def testNonArrayErrorMessage(self):
    """Checks that passing a Python list to lnp ops raises TypeError."""
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):
    """Checks that Python control flow on traced values raises TypeError."""

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))

  def testTracingPrimitiveWithNoTranslationErrorMessage(self):
    """Checks the error for tracing an unimplemented primitive (skipped)."""
    # TODO(mattjj): update this for jax3
    self.skipTest("test needs jax3 update")
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_axis={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rng, "shape": shape, "dtype": dtype, "axis": axis}
      for shape in [(3,), (2, 3)]
      for dtype in default_dtypes
      for axis in range(len(shape))
      for rng in [jtu.rand_default()]))
  def testFlip(self, shape, dtype, axis, rng):
    """Compares lnp.flip with onp.flip along each axis."""
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.flip(x, axis)
    onp_op = lambda x: onp.flip(x, axis)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_{}_k={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k, axes),
       "rng": rng, "shape": shape, "dtype": dtype, "k": k, "axes": axes}
      for shape, axes in [
          [(2, 3), (0, 1)],
          [(2, 3), (1, 0)],
          [(4, 3, 2), (0, 2)],
          [(4, 3, 2), (2, 1)],
      ]
      for k in range(-3, 4)
      for dtype in default_dtypes
      for rng in [jtu.rand_default()]))
  def testRot90(self, shape, dtype, k, axes, rng):
    """Compares lnp.rot90 with onp.rot90 for k in [-3, 3] and several planes."""
    args_maker = self._GetArgsMaker(rng, [shape], [dtype])
    lnp_op = lambda x: lnp.rot90(x, k, axes)
    onp_op = lambda x: onp.rot90(x, k, axes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  # TODO(mattjj): test infix operator overrides

  def testRavel(self):
    """Checks the .ravel() method under _CompileAndCheck."""
    # TODO(mattjj): support this method-based syntax?
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)

  def testAstype(self):
    """Checks the .astype() method with an lnp dtype."""
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    op = lambda x: x.astype(lnp.int32)
    # NOTE(review): the same `op` is passed as both reference and
    # implementation on NumPy inputs, so this comparison looks vacuous; the
    # LAX path is presumably only exercised by _CompileAndCheck below.
    self._CheckAgainstNumpy(op, op, args_maker, check_dtypes=True)
    self._CompileAndCheck(op, args_maker, check_dtypes=True)

  # TODO(mattjj): test other ndarray-like method overrides

  def testOnpMean(self):
    """Checks that onp.mean accepts an lnp/lax-produced array."""
    # from https://github.com/google/jax/issues/125
    x = lax.add(lnp.eye(3), 0.)
    ans = onp.mean(x)
    self.assertAllClose(ans, onp.array([1./3, 1./3, 1./3]), check_dtypes=False)

  # TODO(mattjj): more exhaustive arange tests
  def testArangeOnFloats(self):
    """Compares lnp.arange with onp.arange for float start/stop/step."""
    # from https://github.com/google/jax/issues/145
    expected = onp.arange(0.0, 1.0, 0.1)
    ans = lnp.arange(0.0, 1.0, 0.1)
    self.assertAllClose(expected, ans, check_dtypes=True)
class LaxBackedScipyStatsTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.stats implementations.

  Most tests follow the same pattern:
  - `args_maker` draws random arguments and maps them into the
    distribution's valid domain (floored counts, clipped positive
    scales/rates, probabilities in [0, 1]).
  - The `lsp_stats` function is compared with its `osp_stats` (SciPy)
    counterpart via `_CheckAgainstNumpy`.
  - The `lsp_stats` function is then exercised with `_CompileAndCheck`.
  """

  @genNamedParametersNArgs(3)
  def testPoissonLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.logpmf
    lax_fun = lsp_stats.poisson.logpmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testPoissonPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.pmf
    lax_fun = lsp_stats.poisson.pmf

    def args_maker():
      k, mu, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testPoissonCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.poisson.cdf
    lax_fun = lsp_stats.poisson.cdf

    def args_maker():
      # Unlike the pmf tests, k and loc are not floored here.
      k, mu, loc = map(rng, shapes, dtypes)
      # clipping to ensure that rate parameter is strictly positive
      mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
      return [k, mu, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testBernoulliLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.bernoulli.logpmf
    lax_fun = lsp_stats.bernoulli.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      # Map an unconstrained logit to a probability in (0, 1).
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testGeomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.geom.logpmf
    lax_fun = lsp_stats.geom.logpmf

    def args_maker():
      x, logit, loc = map(rng, shapes, dtypes)
      x = np.floor(x)
      # Map an unconstrained logit to a probability in (0, 1).
      p = expit(logit)
      loc = np.floor(loc)
      return [x, p, loc]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.beta.logpdf
    lax_fun = lsp_stats.beta.logpdf

    def args_maker():
      x, a, b, loc, scale = map(rng, shapes, dtypes)
      return [x, a, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker,
                          rtol={np.float32: 2e-3, np.float64: 1e-4})

  @genNamedParametersNArgs(3)
  def testCauchyLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.cauchy.logpdf
    lax_fun = lsp_stats.cauchy.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(
    jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix("", [x_shape, alpha_shape], dtypes),
       "shapes": [x_shape, alpha_shape], "dtypes": dtypes}
      for x_shape in one_and_two_dim_shapes
      for alpha_shape in [(x_shape[0],), (x_shape[0] + 1,)]
      for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, 2)
  ))
  def testDirichletLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())

    def _normalize(x, alpha):
      # When alpha has one more entry than x, the extra 0.1 in the divisor
      # makes x sum to less than 1 along axis 0 (presumably leaving room for
      # the implied last component).
      x_norm = x.sum(0) + (0.0 if x.shape[0] == alpha.shape[0] else 0.1)
      return (x / x_norm).astype(x.dtype), alpha

    def lax_fun(x, alpha):
      return lsp_stats.dirichlet.logpdf(*_normalize(x, alpha))

    def scipy_fun(x, alpha):
      # scipy validates the x normalization using float64 arithmetic, so we must
      # cast x to float64 before normalization to ensure this passes.
      x, alpha = _normalize(x.astype('float64'), alpha)

      result = osp_stats.dirichlet.logpdf(x, alpha)
      # if x.shape is (N, 1), scipy flattens the output, while JAX returns arrays
      # of a consistent rank. This check ensures the results have the same shape.
      return result if x.ndim == 1 else np.atleast_1d(result)

    def args_maker():
      # Don't normalize here, because we want normalization to happen at 64-bit
      # precision in the scipy version.
      x, alpha = map(rng, shapes, dtypes)
      return x, alpha

    tol = {np.float32: 1E-3, np.float64: 1e-5}

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)

  @genNamedParametersNArgs(3)
  def testExponLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.expon.logpdf
    lax_fun = lsp_stats.expon.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testGammaLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.gamma.logpdf
    lax_fun = lsp_stats.gamma.logpdf

    def args_maker():
      x, a, loc, scale = map(rng, shapes, dtypes)
      return [x, a, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testLaplaceLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.laplace.logpdf
    lax_fun = lsp_stats.laplace.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testLaplaceCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.laplace.cdf
    lax_fun = lsp_stats.laplace.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # ensure that scale is not too low
      scale = np.clip(scale, a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol={np.float32: 1e-5, np.float64: 1e-6})
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.cdf
    lax_fun = lsp_stats.logistic.cdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticLogpdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.logpdf
    lax_fun = lsp_stats.logistic.logpdf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.ppf
    lax_fun = lsp_stats.logistic.ppf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(1)
  def testLogisticSf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.logistic.sf
    lax_fun = lsp_stats.logistic.sf

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormLogCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.logcdf
    lax_fun = lsp_stats.norm.logcdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormCdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.cdf
    lax_fun = lsp_stats.norm.cdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-6)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(3)
  def testNormPpf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.norm.ppf
    lax_fun = lsp_stats.norm.ppf

    def args_maker():
      q, loc, scale = map(rng, shapes, dtypes)
      # ensure probability is between 0 and 1:
      q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [q, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)

  @genNamedParametersNArgs(4)
  def testParetoLogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.pareto.logpdf
    lax_fun = lsp_stats.pareto.logpdf

    def args_maker():
      x, b, loc, scale = map(rng, shapes, dtypes)
      return [x, b, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testTLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.t.logpdf
    lax_fun = lsp_stats.t.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      # clipping to ensure that scale is not too low
      scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-3)
    self._CompileAndCheck(lax_fun, args_maker,
                          rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})

  @genNamedParametersNArgs(3)
  def testUniformLogPdf(self, shapes, dtypes):
    rng = jtu.rand_default(self.rng())
    scipy_fun = osp_stats.uniform.logpdf
    lax_fun = lsp_stats.uniform.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      return [x, loc, np.abs(scale)]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=1e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(4)
  def testChi2LogPdf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    scipy_fun = osp_stats.chi2.logpdf
    lax_fun = lsp_stats.chi2.logpdf

    def args_maker():
      x, df, loc, scale = map(rng, shapes, dtypes)
      return [x, df, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
                            tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker)

  @genNamedParametersNArgs(5)
  def testBetaBinomLogPmf(self, shapes, dtypes):
    rng = jtu.rand_positive(self.rng())
    lax_fun = lsp_stats.betabinom.logpmf

    def args_maker():
      k, n, a, b, loc = map(rng, shapes, dtypes)
      k = np.floor(k)
      n = np.ceil(n)
      # Keep both shape parameters away from zero. Each is clipped from its
      # own sample, so `b` keeps its own parameterized shape and dtype instead
      # of being a copy of `a`.
      a = np.clip(a, a_min=0.1, a_max=None)
      b = np.clip(b, a_min=0.1, a_max=None)
      loc = np.floor(loc)
      return [k, n, a, b, loc]

    # scipy.stats.betabinom is only available in SciPy >= 1.4, so the
    # comparison against SciPy is conditional; the compile check is not.
    if scipy_version >= (1, 4):
      scipy_fun = osp_stats.betabinom.logpmf
      self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                              check_dtypes=False, tol=5e-4)
    self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)

  def testIssue972(self):
    # norm.cdf at +inf must be exactly 1.
    self.assertAllClose(
      np.ones((4,), np.float32),
      lsp_stats.norm.cdf(np.full((4,), np.inf, np.float32)),
      check_dtypes=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_x={}_mean={}_cov={}".format(
          jtu.format_shape_dtype_string(x_shape, x_dtype),
          jtu.format_shape_dtype_string(mean_shape, mean_dtype)
          if mean_shape is not None else None,
          jtu.format_shape_dtype_string(cov_shape, cov_dtype)
          if cov_shape is not None else None),
       "x_shape": x_shape, "x_dtype": x_dtype,
       "mean_shape": mean_shape, "mean_dtype": mean_dtype,
       "cov_shape": cov_shape, "cov_dtype": cov_dtype}
      for x_shape, mean_shape, cov_shape in [
          # # These test cases cover default values for mean/cov, but we don't
          # # support those yet (and they seem not very valuable).
          # [(), None, None],
          # [(), (), None],
          # [(2,), None, None],
          # [(2,), (), None],
          # [(2,), (2,), None],
          # [(3, 2), (3, 2,), None],
          # [(5, 3, 2), (5, 3, 2,), None],

          [(), (), ()],
          [(3,), (), ()],
          [(3,), (3,), ()],
          [(3,), (3,), (3, 3)],
          [(3, 4), (4,), (4, 4)],

          # # These test cases are where scipy flattens things, which has
          # # different batch semantics than some might expect
          # [(5, 3, 2), (5, 3, 2,), ()],
          # [(5, 3, 2), (5, 3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (5, 3, 2, 2)],
          # [(5, 3, 2), (3, 2,), (2, 2)],
      ]
      for x_dtype, mean_dtype, cov_dtype in itertools.combinations_with_replacement(jtu.dtypes.floating, 3)
      if (mean_shape is not None or mean_dtype == np.float32)
      and (cov_shape is not None or cov_dtype == np.float32)))
  def testMultivariateNormalLogpdf(self, x_shape, x_dtype, mean_shape,
                                   mean_dtype, cov_shape, cov_dtype):
    rng = jtu.rand_default(self.rng())

    def args_maker():
      args = [rng(x_shape, x_dtype)]
      if mean_shape is not None:
        args.append(5 * rng(mean_shape, mean_dtype))
      if cov_shape is not None:
        if cov_shape == ():
          # Scalar variance: square plus offset keeps it strictly positive.
          args.append(0.1 + rng(cov_shape, cov_dtype) ** 2)
        else:
          # F @ F^T with F of shape (..., d, 2d) gives a symmetric PSD matrix.
          factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
          factor = rng(factor_shape, cov_dtype)
          args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
      return args

    self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
                            lsp_stats.multivariate_normal.logpdf,
                            args_maker, tol=1e-3)
    self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
                          rtol=1e-4, atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_ndim={}_nbatch={}_dtype={}".format(ndim, nbatch, dtype.__name__),
       "ndim": ndim, "nbatch": nbatch, "dtype": dtype}
      for ndim in [2, 3]
      for nbatch in [1, 3, 5]
      for dtype in jtu.dtypes.floating))
  def testMultivariateNormalLogpdfBatch(self, ndim, nbatch, dtype):
    # Regression test for #5570
    rng = jtu.rand_default(self.rng())
    x = rng((nbatch, ndim), dtype)
    mean = 5 * rng((nbatch, ndim), dtype)
    factor = rng((nbatch, ndim, 2 * ndim), dtype)
    cov = factor @ factor.transpose(0, 2, 1)

    # Batched evaluation must equal vmapping the unbatched function.
    result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
    result2 = api.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
    self.assertArraysEqual(result1, result2)
class LaxAutodiffTest(jtu.JaxTestCase):
  """Gradient checks (forward and reverse mode) for LAX primitives."""

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(
            rec.name, shapes, itertools.repeat(dtype)),
         "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes, "dtype": dtype,
         "order": rec.order, "tol": rec.tol}
        for shape_group in compatible_shapes
        for shapes in CombosWithReplacement(shape_group, rec.nargs)
        for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    """Checks derivatives up to `order` of a LAX_GRAD_OPS entry."""
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    # 32-bit inputs: combine the record's tolerance with 1e-1.
    if num_float_bits(dtype) == 32:
      tol = jtu.join_tolerance(1e-1, tol)
    inputs = tuple(rng(s, dtype) for s in shapes)
    check_grads(op, inputs, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    """Second-order gradient check of `op` at one special input value."""
    inputs = (special_value,)
    check_grads(op, inputs, 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype, "rng_factory": rng_factory}
      for from_dtype, to_dtype in itertools.product(
          float_dtypes + complex_dtypes, repeat=2)
      for rng_factory in [jtu.rand_default]))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype, rng_factory):
    """Gradient check of lax.convert_element_type for every float/complex pair."""
    rng = rng_factory(self.rng())
    # Use the looser of the two dtypes' default gradient tolerances.
    default_tol = jtu.default_gradient_tolerance
    tol = max(jtu.tolerance(to_dtype, default_tol),
              jtu.tolerance(from_dtype, default_tol))
    inputs = (rng((2, 3), from_dtype),)
    # Casting complex to real may emit onp.ComplexWarning; ignore it.
    silence_complex = jtu.ignore_warning(category=onp.ComplexWarning)
    convert = silence_complex(lambda x: lax.convert_element_type(x, to_dtype))
    check_grads(convert, inputs, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng_factory": rng_factory}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes
      for rng_factory in [jtu.rand_default]))
  def testClampGrad(self, shape, dtype, rng_factory):
    """Gradient check of lax.clamp with the random operand in each position."""
    rng = rng_factory(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    arg_orders = [(operand, low, high), (low, operand, high), (low, high, operand)]
    for clamp_args in arg_orders:
      check_grads(lax.clamp, clamp_args, 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), onp.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs, "rng_factory": rng_factory}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))
      for rng_factory in [jtu.rand_default]))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs, rng_factory):
    """Gradient check of lax.concatenate along `dim` with uneven piece sizes."""
    rng = rng_factory(self.rng())
    # Sizes along `dim` cycle through 3, 1, 4 -- one size per array.
    sizes = itertools.islice(itertools.cycle([3, 1, 4]), num_arrs)
    piece_shapes = [base_shape[:dim] + (n,) + base_shape[dim + 1:] for n in sizes]
    pieces = tuple(rng(s, dtype) for s in piece_shapes)

    def concatenate(*args):
      return lax.concatenate(args, dim)

    check_grads(concatenate, pieces, 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "rng_factory": rng_factory,} for lhs_shape, rhs_shape, all_strides in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)])]) for strides in all_strides for dtype in float_dtypes for padding in ["VALID", "SAME"] for rng_factory in [jtu.rand_small])) def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv, window_strides=strides, padding=padding, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory} for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils in itertools.chain( [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)], [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))], [(1, 1), (2, 1)], [(1, 1)]) for b, i, j in itertools.product([2, 3], repeat=3)], [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)], [(1,), (2,)], [(1,), (2,)])]) for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in 
float_dtypes for padding in all_pads for rng_factory in [jtu.rand_small])) def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv_with_general_padding, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=1e-2, rtol=1e-2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_" "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums), feature_group_count, batch_group_count), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "strides": strides, "padding": padding, "lhs_dil": lhs_dil, "rhs_dil": rhs_dil, "rng_factory": rng_factory, "dimension_numbers": dim_nums, "perms": perms, "feature_group_count": feature_group_count, "batch_group_count": batch_group_count} for batch_group_count, feature_group_count in ([(1, 1), (2, 1), (1, 2)]) for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [ ([(b * batch_group_count, i * feature_group_count, 6, 7), (b * batch_group_count, i * feature_group_count, 0, 4)], # lhs_shape (j * batch_group_count * feature_group_count, i, 1, 2), # rhs_shape [(1, 1), (1, 2), (2, 1)], # strides [(1, 1), (2, 1)], # lhs_dils [(1, 1), (2, 2)]) # rhs_dils for b, i, j in itertools.product([1, 2], repeat=3)] for lhs_shape in lhs_shapes for strides in all_strides for rhs_dil in rhs_dils for lhs_dil in lhs_dils for dtype in grad_float_dtypes for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] + ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else [])) for dim_nums, perms in [ 
(("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])), (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])), (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))] for rng_factory in [jtu.rand_default] )) def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides, padding, lhs_dil, rhs_dil, dimension_numbers, perms, feature_group_count, batch_group_count, rng_factory): if dtype == onp.float16: raise SkipTest("float16 numerical issues") # TODO(mattjj): resolve rng = rng_factory(self.rng()) tol = {dtypes.bfloat16: 1e-0, onp.float16: 5e-1, onp.float32: 2e-4} # permute shapes to match dim_spec, scale by feature_group_count lhs_perm, rhs_perm = perms lhs_shape = list(onp.take(lhs_shape, lhs_perm)) rhs_shape = list(onp.take(rhs_shape, rhs_perm)) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) conv = partial(lax.conv_general_dilated, window_strides=strides, padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil, dimension_numbers=dimension_numbers, feature_group_count=feature_group_count, batch_group_count=batch_group_count, precision=lax.Precision.HIGHEST) check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype)), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng_factory": jtu.rand_default} for lhs_shape in [(2,), (3, 2)] for rhs_shape in [(2,), (2, 4)] for dtype in float_dtypes)) def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory): rng = rng_factory(self.rng()) tol = {onp.float16: 1e-1, onp.float32: 1e-4} lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot = partial(lax.dot, precision=lax.Precision.HIGHEST) check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"], atol=tol, rtol=tol) # check that precision config is preserved result, pullback = 
api.vjp(dot, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) assert "precision=HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}" .format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), dimension_numbers), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "dimension_numbers": dimension_numbers, "rng_factory": jtu.rand_small} for lhs_shape, rhs_shape, dimension_numbers in [ ((3, 2), (2, 4), (([1], [0]), ([], []))), ((3, 5), (2, 5), (([1], [1]), ([], []))), ((5, 3), (5, 2), (([0], [0]), ([], []))), ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))), ] for dtype in float_dtypes)) def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, dimension_numbers, rng_factory): rng = rng_factory(self.rng()) lhs = rng(lhs_shape, dtype) rhs = rng(rhs_shape, dtype) dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST) check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"]) # check that precision config is preserved result, pullback = api.vjp(dot_general, lhs, rhs) gresult = lax.zeros_like_array(result) s = str(api.make_jaxpr(pullback)(gresult)) assert "precision=HIGHEST" in s @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format( shape, onp.dtype(dtype).name, broadcast_sizes), "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes, "rng_factory": rng_factory} for shape in [(), (2, 3)] for dtype in float_dtypes for broadcast_sizes in [(), (2,), (1, 2)] for rng_factory in [jtu.rand_default])) def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory): rng = rng_factory(self.rng()) args = (rng(shape, dtype),) broadcast = lambda x: lax.broadcast(x, broadcast_sizes) check_grads(broadcast, args, 2, ["fwd", "rev"], 
eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format( jtu.format_shape_dtype_string(inshape, dtype), outshape, broadcast_dimensions), "inshape": inshape, "dtype": dtype, "outshape": outshape, "dimensions": broadcast_dimensions, "rng_factory": rng_factory} for inshape, outshape, broadcast_dimensions in [ ([2], [2, 2], [0]), ([2], [2, 2], [1]), ([2], [2, 3], [0]), ([], [2, 3], []), ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions, rng_factory): rng = rng_factory(self.rng()) operand = rng(inshape, dtype) broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions) check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_outshape={}_perm={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype), permutation), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "rng_factory": rng_factory, "permutation": permutation} for dtype in float_dtypes for arg_shape, out_shape, permutation in [ [(3, 4), (12,), None], [(2, 1, 4), (8,), None], [(2, 2, 4), (2, 8), None], [(3, 4), (12,), (0, 1)], [(3, 4), (12,), (1, 0)], [(2, 1, 4), (8,), (0, 2, 1)], [(2, 1, 4), (8,), (2, 0, 1)], [(2, 2, 4), (2, 8), (0, 2, 1)], [(2, 2, 4), (2, 8), (2, 0, 1)], ] for rng_factory in [jtu.rand_default])) def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype, rng_factory): rng = rng_factory(self.rng()) operand = rng(arg_shape, dtype) reshape = lambda x: lax.reshape(x, out_shape, permutation) check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.) 
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_pads={}" .format(jtu.format_shape_dtype_string(shape, dtype), pads), "shape": shape, "dtype": dtype, "pads": pads, "rng_factory": jtu.rand_small} for shape in [(2, 3)] for dtype in float_dtypes for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]])) def testPadGrad(self, shape, dtype, pads, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) pad = lambda operand: lax.pad(operand, onp.array(0, dtype), pads) check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.) operand = rng(shape, dtype) padding_value = onp.array(0., dtype) pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads) check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.) def testReverseGrad(self): rev = lambda operand: lax.rev(operand, dimensions) dimensions = [0] check_grads(rev, (onp.array([3., 2., 1.]),), 2) dimensions = [0, 1] check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2, rtol={onp.float32: 3e-3}) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}".format( jtu.format_shape_dtype_string(pred_shape, onp.bool_), jtu.format_shape_dtype_string(arg_shape, dtype)), "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype, "rng_factory": rng_factory} for arg_shape in [(), (3,), (2, 3)] for pred_shape in ([(), arg_shape] if arg_shape else [()]) for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testSelectGrad(self, pred_shape, arg_shape, dtype, rng_factory): rng = rng_factory(self.rng()) pred = rng(pred_shape, onp.bool_) on_true = rng(arg_shape, dtype) on_false = rng(arg_shape, dtype) select = lambda on_true, on_false: lax.select(pred, on_true, on_false) check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.) 
@parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_limit_indices={}_strides={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, limit_indices, strides), "shape": shape, "dtype": dtype, "starts": start_indices, "limits": limit_indices, "strides": strides, "rng_factory": rng_factory} for shape, start_indices, limit_indices, strides in [ [(3,), (1,), (2,), None], [(7,), (4,), (7,), None], [(5,), (1,), (5,), (2,)], [(8,), (1,), (6,), (2,)], [(5, 3), (1, 1), (3, 2), None], [(5, 3), (1, 1), (3, 1), None], [(7, 5, 3), (4, 0, 1), (7, 1, 3), None], [(5, 3), (1, 1), (2, 1), (1, 1)], [(5, 3), (1, 1), (5, 3), (2, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testSliceGrad(self, shape, dtype, starts, limits, strides, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) slice = lambda x: lax.slice(x, starts, limits, strides) check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format( jtu.format_shape_dtype_string(shape, dtype), start_indices, size_indices), "shape": shape, "dtype": dtype, "start_indices": start_indices, "size_indices": size_indices, "rng_factory": rng_factory} for shape, start_indices, size_indices in [ [(3,), (1,), (1,)], [(5, 3), (1, 1), (3, 1)], [(7, 5, 3), (4, 1, 0), (2, 0, 1)], ] for dtype in float_dtypes for rng_factory in [jtu.rand_default])) def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices, rng_factory): rng = rng_factory(self.rng()) operand = rng(shape, dtype) dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices) check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.) 
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
        jtu.format_shape_dtype_string(shape, dtype),
        start_indices, update_shape),
     "shape": shape, "dtype": dtype, "start_indices": start_indices,
     "update_shape": update_shape, "rng_factory": jtu.rand_default}
    for shape, start_indices, update_shape in [
        [(3,), (1,), (1,)],
        [(5, 3), (1, 1), (3, 1)],
        [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
    ]
    for dtype in float_dtypes))
def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, update_shape,
                               rng_factory):
  """Second-order gradient checks for `lax.dynamic_update_slice`.

  Differentiates w.r.t. (operand, update) jointly, then each one alone while
  the other is held fixed.
  """
  rng = rng_factory(self.rng())
  operand = rng(shape, dtype)
  update = rng(update_shape, dtype)
  starts = onp.array(start_indices)

  checks = [
      (lambda x, y: lax.dynamic_update_slice(x, y, starts), (operand, update)),
      (lambda x: lax.dynamic_update_slice(x, update, starts), (operand,)),
      (lambda y: lax.dynamic_update_slice(operand, y, starts), (update,)),
  ]
  for fn, fn_args in checks:
    check_grads(fn, fn_args, 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_perm={}".format(
        jtu.format_shape_dtype_string(shape, dtype), perm),
     "shape": shape, "dtype": dtype, "perm": perm,
     "rng_factory": jtu.rand_default}
    for shape, perm in [
        [(3, 4), (1, 0)],
        [(3, 4), (0, 1)],
        [(3, 4, 5), (2, 1, 0)],
        [(3, 4, 5), (1, 0, 2)],
    ]
    for dtype in float_dtypes))
def testTransposeGrad(self, shape, dtype, perm, rng_factory):
  """Second-order gradient checks for `lax.transpose` under several perms."""
  rng = rng_factory(self.rng())
  x = rng(shape, dtype)

  def transpose_fn(arr):
    return lax.transpose(arr, perm)

  check_grads(transpose_fn, (x,), 2, ["fwd", "rev"], eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_inshape={}_reducedims={}"
     .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims),
     "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
     "dims": dims, "rng_factory": rng_factory}
    # `op_dtypes` (not `dtypes`) so the jax `dtypes` module is not shadowed.
    for init_val, op, op_dtypes, rng_factory in [
        (0, lax.add, inexact_dtypes, jtu.rand_default),
        (-onp.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
        (onp.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
        (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
    ]
    for dtype in op_dtypes
    for shape, dims in [
        [(3, 4, 5), ()],
        [(3, 4, 5), (0,)],
        [(3, 4, 5), (1, 2)],
        [(3, 4, 5), (0, 2)],
        [(3, 4, 5), (0, 1, 2)],
        [(3, 1), (1,)],
    ]))
def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
  """Second-order gradient checks for `lax.reduce` with add/max/min/mul."""
  rng = rng_factory(self.rng())
  if jtu.device_under_test() == "tpu" and op is lax.mul:
    raise SkipTest("unimplemented case")
  tol = {dtypes.bfloat16: 2e-1, onp.float16: 1e-1, onp.float32: 1e-1,
         onp.float64: 1e-3, onp.complex64: 1e-1}
  x = rng(shape, dtype)
  init = onp.asarray(init_val, dtype=dtype)

  def reduce_fn(arr):
    return lax.reduce(arr, init, op, dims)

  # Finite-difference step: widest for 16-bit sums, narrower otherwise;
  # None falls back to check_grads' default.
  bits = dtypes.finfo(dtype).bits
  if bits == 16 and op is lax.add:
    eps = 1.0
  elif dtype == dtypes.bfloat16:
    eps = 1e-1
  elif bits == 32:
    eps = 1e-2
  else:
    eps = None
  check_grads(reduce_fn, (x,), 2, ["fwd", "rev"], atol=tol, rtol=tol, eps=eps)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_dtype={}_padding={}"
     .format(op.__name__, onp.dtype(dtype).name, padding),
     "op": op, "init_val": init_val, "dtype": dtype, "padding": padding,
     "rng_factory": rng_factory}
    for init_val, op, op_dtypes, rng_factory in [
        (0, lax.add, grad_float_dtypes, jtu.rand_small),
        (-onp.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
        (onp.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
    ]
    for dtype in op_dtypes
    for padding in ["VALID", "SAME"]))
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.ignore_warning(category=UserWarning,
                    message="Using reduced precision for gradient.*")
def testReduceWindowGrad(self, op, init_val, dtype, padding, rng_factory):
  """Gradient checks for `lax.reduce_window` over several window configs."""
  rng = rng_factory(self.rng())
  init = onp.asarray(init_val, dtype=dtype)

  # We need this conditional and the corresponding loop logic to be in the
  # test method, rather than at the parameterized test level, because it
  # depends on FLAGS for the device under test.
  # TODO(b/31565929): enable when fixed.
  if jtu.device_under_test() == "tpu" and op is not lax.add:
    all_configs = [((6, 5, 4, 3), (2, 2, 1, 1), (1, 2, 1, 1))]
    # TODO(b/73062247): need variadic reduce-window for better precision.
    gradient_order = 1
  else:
    configs_2d = itertools.product(
        [(4, 6)],                        # shapes
        [(2, 1), (1, 2)],                # window_dimensions
        [(1, 1), (2, 1), (1, 2)])        # strides
    configs_4d = itertools.product(
        [(3, 2, 4, 6)],                  # shapes
        [(1, 1, 2, 1), (2, 1, 2, 1)],    # window_dimensions
        [(1, 2, 2, 1), (1, 1, 1, 1)])    # strides
    all_configs = itertools.chain(configs_2d, configs_4d)
    gradient_order = 3

  for shape, dims, strides in all_configs:
    x = rng(shape, dtype)
    if op is lax.add:
      eps, tol = 1., None
    else:
      # this test can fail if there are duplicates in operand
      self.assertEqual(onp.unique(x).size, x.size,
                       msg="test requires operand elements to be unique.")
      eps = 1e-2
      tol = {onp.float16: 1e-1, onp.float32: 6e-2, onp.float64: 6e-2}

    # Defined per iteration so it closes over this config's dims/strides.
    def reduce_window_fn(arr):
      return lax.reduce_window(arr, init, op, dims, strides, padding)

    check_grads(reduce_window_fn, (x,), gradient_order, ["fwd", "rev"],
                atol=tol, rtol=tol, eps=eps)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_op={}_shape={}_axis={}"
     .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), axis),
     "op": op, "shape": shape, "dtype": dtype, "axis": axis,
     "rng_factory": (jtu.rand_default
                     if dtypes.issubdtype(dtype, onp.integer)
                     else jtu.rand_small)}
    for op, op_dtypes in [
        (lax.cumsum, [onp.float32, onp.float64]),
        (lax.cumprod, [onp.float32, onp.float64]),
    ]
    for dtype in op_dtypes
    for shape in [[10], [3, 4, 5]]
    for axis in range(len(shape))))
def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
  """Second-order gradient checks for cumsum/cumprod along every axis."""
  rng = rng_factory(self.rng())
  cumulative = partial(op, axis=axis)
  check_grads(cumulative, (rng(shape, dtype),), order=2)

# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
     "rng_factory": jtu.rand_default, "shape": shape, "dtype": dtype,
     "axis": axis, "is_stable": is_stable}
    for dtype in [onp.float32]
    for shape in [(5,), (5, 7)]
    for axis in [len(shape) - 1]
    for is_stable in [False, True]))
def testSortGrad(self, shape, dtype, axis, is_stable, rng_factory):
  """Second-order gradient checks for `lax.sort` along the last axis."""
  rng = rng_factory(self.rng())
  x = rng(shape, dtype)

  def sort_fn(arr):
    return lax.sort(arr, dimension=axis, is_stable=is_stable)

  check_grads(sort_fn, (x,), 2, ["fwd", "rev"], eps=1e-2)

# TODO(b/205052657): enable more tests when supported
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
        jtu.format_shape_dtype_string(shape, key_dtype),
        jtu.format_shape_dtype_string(shape, val_dtype),
        axis, is_stable),
     "rng_factory": jtu.rand_default, "shape": shape,
     "key_dtype": key_dtype, "val_dtype": val_dtype, "axis": axis,
     "is_stable": is_stable}
    for key_dtype in [onp.float32]
    for val_dtype in [onp.float32]
    for shape in [(3,), (5, 3)]
    for axis in [len(shape) - 1]
    for is_stable in [False, True]))
def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable,
                       rng_factory):
  """Second-order gradient checks for `lax.sort_key_val`."""
  rng = rng_factory(self.rng())
  # This test relies on the property that wherever keys are tied, values are
  # too, since we don't guarantee the same ordering of values with equal keys.
  # To avoid that case, we generate unique keys (globally in the key array):
  # a shuffled arange has no duplicates.
  num_keys = onp.prod(shape, dtype=int)
  shuffled = self.rng().permutation(onp.arange(num_keys, dtype=key_dtype))
  keys = shuffled.reshape(shape)
  values = rng(shape, val_dtype)

  def sort_kv(k, v):
    return lax.sort_key_val(k, v, axis, is_stable)

  check_grads(sort_kv, (keys, values), 2, ["fwd", "rev"],
              atol=1e-2, rtol=1e-2, eps=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_k={}".format(
        jtu.format_shape_dtype_string(shape, dtype), k),
     "rng_factory": jtu.rand_default, "shape": shape, "dtype": dtype,
     "k": k}
    for dtype in [onp.float32,]
    for shape in [(4,), (5, 5), (2, 1, 4)]
    for k in [1, 3]))
def testTopKGrad(self, shape, dtype, k, rng_factory):
  """Second-order gradient checks for the values output of `lax.top_k`."""
  # A shuffled arange has no repeated entries (presumably so that ties cannot
  # make the top-k selection ambiguous); rng_factory is unused here.
  flat = onp.arange(onp.prod(shape, dtype=int), dtype=dtype)
  values = self.rng().permutation(flat).reshape(shape)

  def top_values(vs):
    return lax.top_k(vs, k=k)[0]

  check_grads(top_values, (values,), 2, ["fwd", "rev"], eps=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_axes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes,
     "rng_factory": jtu.rand_default}
    for dtype in float_dtypes
    for shape, idxs, axes in [
        [(3, 4, 5), (onp.array([0, 2, 1]),), (0,)],
        [(3, 4, 5), (onp.array([-1, -2]),), (0,)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 1)],
        [(3, 4, 5), (onp.array([0, 2]), onp.array([1, 3])), (0, 2)],
    ]))
def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
  """Second-order gradient checks for `lax.index_take` w.r.t. the source."""
  rng = rng_factory(self.rng())
  src = rng(shape, dtype)

  def index_take_fn(arr):
    return lax.index_take(arr, idxs, axes)

  check_grads(index_take_fn, (src,), 2, ["fwd", "rev"], eps=1.)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
        jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
        slice_sizes),
     "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
     "slice_sizes": slice_sizes, "rng_factory": jtu.rand_default,
     "rng_idx_factory": partial(jtu.rand_int, high=max_idx)}
    for dtype in grad_float_dtypes
    for shape, idxs, dnums, slice_sizes, max_idx in [
        ((5,), onp.array([[0], [2]]), lax.GatherDimensionNumbers(
            offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1,), 5),
        ((10,), onp.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,)),
         (2,), 9),
        ((10, 5,), onp.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
            offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)),
         (1, 3), 3),
    ]))
def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes, rng_factory,
                   rng_idx_factory):
  """Second-order gradient checks for `lax.gather` w.r.t. the operand.

  The template `idxs` only supplies the shape/dtype; actual indices are drawn
  from `rng_idx_factory`.
  """
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())
  indices = rng_idx(idxs.shape, idxs.dtype)

  def gather_fn(arr):
    return lax.gather(arr, indices, dimension_numbers=dnums,
                      slice_sizes=slice_sizes)

  x = rng(shape, dtype)
  check_grads(gather_fn, (x,), 2, ["fwd", "rev"], atol=1e-2, rtol=1e-2,
              eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": jtu.rand_default,
     "rng_idx_factory": partial(jtu.rand_int, high=max_idx)}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums, max_idx in [
        ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
        ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
        ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
    ]))
def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                       rng_factory, rng_idx_factory):
  """Second-order gradient checks for `lax.scatter_add` w.r.t. both inputs.

  The template `idxs` only supplies the shape/dtype; actual indices are drawn
  from `rng_idx_factory`.
  """
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())
  indices = rng_idx(idxs.shape, idxs.dtype)

  def scatter_add_fn(operand, updates):
    return lax.scatter_add(operand, indices, updates, dimension_numbers=dnums)

  x = rng(arg_shape, dtype)
  y = rng(update_shape, dtype)
  check_grads(scatter_add_fn, (x, y), 2, ["fwd", "rev"],
              atol=1e-2, rtol=1e-2, eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": jtu.rand_default,
     # Scatters with conflicting indices are not deterministic on GPU, so we
     # use indices that do not collide.
     "rng_idx_factory": partial(jtu.rand_unique_int, high=max_idx)}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums, max_idx in [
        ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 4),
        ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,)), 9),
        ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,)), 3),
    ]))
def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                    rng_factory, rng_idx_factory):
  """Second-order gradient checks for overwriting `lax.scatter`."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())
  indices = rng_idx(idxs.shape, idxs.dtype)

  def scatter_fn(operand, updates):
    return lax.scatter(operand, indices, updates, dimension_numbers=dnums)

  x = rng(arg_shape, dtype)
  y = rng(update_shape, dtype)
  check_grads(scatter_fn, (x, y), 2, ["fwd", "rev"],
              atol=1e-2, rtol=1e-2, eps=1.)

def testScatterGradSymbolicZeroUpdate(self):
  """Regression test: scatter whose update does not depend on the input."""
  # https://github.com/google/jax/issues/1901
  def set_diagonal(x):
    size = x.shape[0]
    diag_values = onp.arange(size, dtype=x.dtype)
    return jax.ops.index_update(x, onp.diag_indices(size), diag_values)

  rng = jtu.rand_default(self.rng())
  matrix = rng((5, 5), onp.float32)
  check_grads(set_diagonal, (matrix,), 2, ["fwd", "rev"],
              atol=1e-2, rtol=1e-2, eps=1.)
@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": jtu.rand_default,
     "rng_idx_factory": partial(jtu.rand_int, high=max(arg_shape))}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
    ]))
def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  """Second-order gradient checks for `lax.scatter_max`."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())
  indices = rng_idx(idxs.shape, idxs.dtype)

  def scatter_max_fn(operand, updates):
    return lax.scatter_max(operand, indices, updates, dnums)

  x = rng(arg_shape, dtype)
  y = rng(update_shape, dtype)
  check_grads(scatter_max_fn, (x, y), 2, ["fwd", "rev"], atol=1e-2, rtol=1e-2)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
        jtu.format_shape_dtype_string(arg_shape, dtype),
        idxs, update_shape, dnums),
     "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
     "update_shape": update_shape, "dnums": dnums,
     "rng_factory": jtu.rand_default,
     "rng_idx_factory": partial(jtu.rand_int, high=max(arg_shape))}
    for dtype in grad_float_dtypes
    for arg_shape, idxs, update_shape, dnums in [
        ((5,), onp.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
            update_window_dims=(), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
        ((10,), onp.array([[0], [0], [0]]), (3, 2), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(),
            scatter_dims_to_operand_dims=(0,))),
        ((10, 5,), onp.array([[0], [2], [1]]), (3, 3), lax.ScatterDimensionNumbers(
            update_window_dims=(1,), inserted_window_dims=(0,),
            scatter_dims_to_operand_dims=(0,))),
    ]))
def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums,
                   rng_factory, rng_idx_factory):
  """Second-order gradient checks for `lax.scatter_min`."""
  rng = rng_factory(self.rng())
  rng_idx = rng_idx_factory(self.rng())
  indices = rng_idx(idxs.shape, idxs.dtype)

  def scatter_min_fn(operand, updates):
    return lax.scatter_min(operand, indices, updates, dnums)

  x = rng(arg_shape, dtype)
  y = rng(update_shape, dtype)
  check_grads(scatter_min_fn, (x, y), 2, ["fwd", "rev"], atol=1e-2, rtol=1e-2)

def testStopGradient(self):
  """`lax.stop_gradient` blocks differentiation through its argument."""
  def f(x):
    return lax.sin(x) * lax.cos(lax.stop_gradient(x))

  def f2(x, y):
    return lax.sin(x) * lax.cos(y)

  point = 3.14
  # First and second derivatives of f must equal those of f2 taken w.r.t. its
  # first argument only, i.e. the stop_gradient'ed input acts like a constant.
  df, df2 = f, f2
  for _ in range(2):
    df, df2 = api.grad(df), api.grad(df2)
    self.assertAllClose(df(point), df2(point, point))

  # stop_gradient also applies through pytrees (here a dict).
  ans = api.grad(lambda x: lax.stop_gradient({'foo': x})['foo'])(3.)
  self.assertAllClose(ans, onp.array(0.0), check_dtypes=False)

  with core.skipping_checks(), self.assertRaises(TypeError):
    lax.stop_gradient(lambda x: x)

# TODO(mattjj): make this a more systematic test
def testRemainder(self):
  """Second-order gradient checks for `lax.rem` with broadcasting."""
  for x_shape, y_shape in [((3, 4), (3, 1)), ((1, 4), (3, 4))]:
    # A fresh RandomState(0) for each case.
    rs = onp.random.RandomState(0)
    x = rs.uniform(-0.9, 9, size=x_shape)
    y = rs.uniform(0.7, 1.9, size=y_shape)
    assert not set(onp.unique(x)) & set(onp.unique(y))
    tol = 1e-1 if num_float_bits(onp.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], atol=tol, rtol=tol)

def testHigherOrderGradientOfReciprocal(self):
  """Sixth derivative of 1/x at x=4."""
  # Regression test for https://github.com/google/jax/issues/3136
  def inv(x):
    # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
    return 1 / x

  grad_fn = inv
  for _ in range(6):
    grad_fn = jax.grad(grad_fn)
  # d^6/dx^6 (1/x) = 720 / x**7, i.e. 720 / 4**7 = 0.0439453125 at x = 4.
  self.assertAllClose(onp.float32(0.0439453125), grad_fn(onp.float32(4.)))
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Compares LAX-backed jax.scipy.special functions with scipy.special."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # One rng(shape, dtype) sample per (shape, dtype) pair.
    return lambda: list(map(rng, shapes, dtypes))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
           # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
           "rng_factory": (jtu.rand_default
                           if jtu.device_under_test() == "cpu"
                           else jtu.rand_some_inf_and_nan),
           "shape": shape, "dtype": dtype, "axis": axis,
           "keepdims": keepdims}
          for shape in all_shapes
          for dtype in float_dtypes
          for axis in range(-len(shape), len(shape))
          for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng_factory, shape, dtype, axis, keepdims):
    rng = rng_factory(self.rng())
    # TODO(mattjj): test autodiff
    scipy_fun = lambda arr: osp_special.logsumexp(arr, axis, keepdims=keepdims)
    lax_fun = lambda arr: lsp_special.logsumexp(arr, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      itertools.chain.from_iterable(
          jtu.cases_from_list(
              {"testcase_name": jtu.format_test_name_suffix(
                  rec.test_name, shapes, dtypes),
               "rng_factory": rec.rng_factory, "shapes": shapes,
               "dtypes": dtypes, "test_autodiff": rec.test_autodiff,
               "scipy_op": getattr(osp_special, rec.name),
               "lax_op": getattr(lsp_special, rec.name)}
              for shapes in CombosWithReplacement(all_shapes, rec.nargs)
              for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
          for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff):
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True, rtol=1e-5)

    if not test_autodiff:
      return
    # TPU gets a looser absolute tolerance for the gradient check.
    grad_atol = jtu.if_device_under_test("tpu", .1, 1e-3)
    jtu.check_grads(lax_op, args, order=1, atol=grad_atol, rtol=.1, eps=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_d={}".format(
              jtu.format_shape_dtype_string(shape, dtype), d),
           "rng_factory": jtu.rand_positive, "shape": shape, "dtype": dtype,
           "d": d}
          for shape in all_shapes
          for dtype in float_dtypes
          for d in [1, 2, 5]))
  def testMultigammaln(self, rng_factory, shape, dtype, d):
    scipy_fun = lambda a: osp_special.multigammaln(a, d)
    lax_fun = lambda a: lsp_special.multigammaln(a, d)
    rng = rng_factory(self.rng())
    # Positive samples shifted by (d - 1) / 2, presumably to stay inside
    # multigammaln's domain (a > (d - 1) / 2) -- confirm against scipy docs.
    offset = (d - 1) / 2.
    args_maker = lambda: [rng(shape, dtype) + offset]
    tolerances = {onp.float32: 1e-3, onp.float64: 1e-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True,
                            tol=tolerances)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    # expit of a huge negative input must underflow cleanly to zero.
    very_negative = onp.full((4, ), -1e20, dtype=onp.float32)
    expected = onp.zeros((4, ), dtype=onp.float32)
    self.assertAllClose(expected, lsp_special.expit(very_negative),
                        check_dtypes=True)

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    xlogy_x0 = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(xlogy_x0)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    xlog1py_x0 = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(xlog1py_x0)(-1.), 0., check_dtypes=False)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Compares LAX-backed jax.scipy.special functions with scipy.special."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # One rng(shape, dtype) sample per (shape, dtype) pair.
    return lambda: list(map(rng, shapes, dtypes))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name":
               "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                   jtu.format_shape_dtype_string(shapes, dtype),
                   axis, keepdims, return_sign, use_b),
           "shapes": shapes, "dtype": dtype, "axis": axis,
           "keepdims": keepdims, "return_sign": return_sign, "use_b": use_b}
          for shape_group in compatible_shapes
          for dtype in float_dtypes + int_dtypes
          for use_b in [False, True]
          # With use_b, pair every shape in the group with every other one
          # (array shape, scale shape); otherwise just the array shape.
          for shapes in itertools.product(
              *((shape_group, shape_group) if use_b else (shape_group, )))
          for axis in range(-max(len(shape) for shape in shapes),
                            max(len(shape) for shape in shapes))
          for keepdims in [False, True]
          for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
    # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
    make_rng = (jtu.rand_default if jtu.device_under_test() == "cpu"
                else jtu.rand_some_inf_and_nan)
    rng = make_rng(self.rng())
    # TODO(mattjj): test autodiff
    options = dict(keepdims=keepdims, return_sign=return_sign)
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, b=scale_array,
                                     **options)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, b=scale_array,
                                     **options)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, **options)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, **options)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1e-6, np.float64: 1e-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(
      itertools.chain.from_iterable(
          jtu.cases_from_list(
              {"testcase_name": jtu.format_test_name_suffix(
                  rec.test_name, shapes, dtypes),
               "rng_factory": rec.rng_factory, "shapes": shapes,
               "dtypes": dtypes, "test_autodiff": rec.test_autodiff,
               "nondiff_argnums": rec.nondiff_argnums,
               "scipy_op": getattr(osp_special, rec.name),
               "lax_op": getattr(lsp_special, rec.name)}
              for shapes in itertools.combinations_with_replacement(
                  all_shapes, rec.nargs)
              # A list of dtypes is shared by all args; anything else is a
              # per-argument sequence of candidate dtypes.
              for dtypes in (
                  itertools.combinations_with_replacement(rec.dtypes, rec.nargs)
                  if isinstance(rec.dtypes, list)
                  else itertools.product(*rec.dtypes)))
          for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    on_cpu = jtu.device_under_test() == "cpu"
    if on_cpu and (lax_op is lsp_special.gammainc or
                   lax_op is lsp_special.gammaincc):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if not test_autodiff:
      return

    def partial_lax_op(*vals):
      # Re-insert the held-fixed arguments at their original positions; this
      # is only correct because nondiff_argnums is sorted and duplicate-free.
      full_args = list(vals)
      for position in nondiff_argnums:
        full_args.insert(position, args[position])
      return lax_op(*full_args)

    assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
    diff_args = [arg for pos, arg in enumerate(args)
                 if pos not in nondiff_argnums]
    # TPU gets a looser absolute tolerance for the gradient check.
    grad_atol = jtu.if_device_under_test("tpu", .1, 1e-3)
    jtu.check_grads(partial_lax_op, diff_args, order=1, atol=grad_atol,
                    rtol=.1, eps=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_d={}".format(
              jtu.format_shape_dtype_string(shape, dtype), d),
           "shape": shape, "dtype": dtype, "d": d}
          for shape in all_shapes
          for dtype in float_dtypes
          for d in [1, 2, 5]))
  def testMultigammaln(self, shape, dtype, d):
    scipy_fun = lambda a: osp_special.multigammaln(a, d)
    lax_fun = lambda a: lsp_special.multigammaln(a, d)
    rng = jtu.rand_positive(self.rng())
    # Positive samples shifted by (d - 1) / 2, presumably to stay inside
    # multigammaln's domain (a > (d - 1) / 2) -- confirm against scipy docs.
    offset = (d - 1) / 2.
    args_maker = lambda: [rng(shape, dtype) + offset]
    tolerances = {np.float32: 1e-3, np.float64: 1e-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=tolerances)
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    # expit of a huge negative input must underflow cleanly to zero.
    very_negative = np.full((4, ), -1e20, dtype=np.float32)
    expected = np.zeros((4, ), dtype=np.float32)
    self.assertAllClose(expected, lsp_special.expit(very_negative))

  def testIssue3758(self):
    s = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    expected = np.array([1., 0., 0.], dtype=np.float32)
    self.assertAllClose(expected, lsp_special.zeta(s, q))

  def testXlogyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    xlogy_x0 = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(xlogy_x0)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    xlog1py_x0 = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(xlog1py_x0)(-1.), 0., check_dtypes=False)
class LaxBackedNumpyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Numpy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes,
                                                    dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)}
      for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS,
                                 JAX_COMPOUND_OP_RECORDS)
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testOp(self, onp_op, lnp_op, rng, shapes, dtypes):
    """Compares an `lnp` op with the same-named `onp` op, eagerly and jitted."""
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis, "keepdims": keepdims}
      for rec in JAX_REDUCER_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testReducer(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims):
    """Checks a reducer over every valid (possibly negative) axis."""
    onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims)
    lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "{}_inshape={}_axis={}".format(
          rec.test_name.capitalize(),
          jtu.format_shape_dtype_string(shape, dtype), axis),
       "rng": rec.rng, "shape": shape, "dtype": dtype,
       "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name),
       "axis": axis}
      for rec in JAX_ARGMINMAX_RECORDS
      for shape in all_shapes for dtype in rec.dtypes
      for axis in range(-len(shape), len(shape)))
  def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis):
    """Checks argmin/argmax-style reductions over every valid axis."""

    def onp_fun(array_to_reduce):
      return onp_op(array_to_reduce, axis)

    def lnp_fun(array_to_reduce):
      return lnp_op(array_to_reduce, axis)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("matrix-scalar", (3, 3), ()),
          ("scalar-matrix", (), (3, 3)),
          ("matrix-vector", (4, 5), (5,)),
          ("vector-matrix", (6,), (6, 4)),
          ("matrix-matrix", (3, 4), (4, 5)),
          ("tensor-vector", (4, 3, 2), (2,)),
          ("vector-tensor", (2,), (3, 2, 4)),
          ("tensor-matrix", (4, 3, 2), (2, 5)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-tensor", (2, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """Checks `lnp.dot` against `onp.dot` for the listed operand ranks."""
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.dot, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_{}_{}".format(
          name,
          jtu.format_shape_dtype_string(lhs_shape, lhs_dtype),
          jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)),
       "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype,
       "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype,
       "rng": rng}
      for rng in [jtu.rand_default()]
      for name, lhs_shape, rhs_shape in [
          ("vector-vector", (3,), (3,)),
          ("matrix-vector", (3, 3), (3,)),
          ("vector-matrix", (3,), (3, 3)),
          ("matrix-matrix", (3, 3), (3, 3)),
          ("vector-tensor", (3,), (5, 3, 2)),
          ("tensor-vector", (5, 3, 2), (2,)),
          ("matrix-tensor", (5, 2), (3, 2, 4)),
          ("tensor-matrix", (5, 2, 3), (3, 2)),
          ("tensor-tensor", (5, 3, 4), (5, 4, 1)),
          ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))]
      for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2))
  def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng):
    """Checks `lnp.matmul` against `onp.matmul`, including batch broadcasting."""
    args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)]
    self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker,
                            check_dtypes=True)
    self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_amin={}_amax={}".format(
          jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
       "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for a_min, a_max in [(-1, None), (None, 1), (-1, 1)])
  def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng):
    """Checks `clip` with Python-constant bounds, either of which may be None."""
    onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
    lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_decimals={}".format(
          jtu.format_shape_dtype_string(shape, dtype), decimals),
       "shape": shape, "dtype": dtype, "decimals": decimals,
       "rng": jtu.rand_default()}
      for shape in all_shapes for dtype in float_dtypes
      for decimals in [0, 1, -2])
  def testRoundStaticDecimals(self, shape, dtype, decimals, rng):
    """Checks `round` with a static (including negative) `decimals`."""
    onp_fun = lambda x: onp.round(x, decimals=decimals)
    lnp_fun = lambda x: lnp.round(x, decimals=decimals)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format(
          axis, ",".join(str(d) for d in base_shape),
          ",".join(onp.dtype(dtype).name for dtype in dtypes)),
       "axis": axis, "base_shape": base_shape, "dtypes": dtypes,
       "rng": jtu.rand_default()}
      for num_arrs in [3]
      for dtypes in CombosWithReplacement(default_dtypes, num_arrs)
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for axis in range(-len(base_shape)+1, len(base_shape)))
  def testConcatenate(self, axis, base_shape, dtypes, rng):
    """Concatenates arrays whose sizes differ only along `axis`."""
    wrapped_axis = axis % len(base_shape)
    # Each operand gets a different length (3, 1, 4, ...) along the
    # concatenation axis; all other dimensions match `base_shape`.
    shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)]
    onp_fun = lambda *args: onp.concatenate(args, axis=axis)
    lnp_fun = lambda *args: lnp.concatenate(args, axis=axis)

    def args_maker():
      return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}".format(
          jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)),
       "shape": shape, "dtypes": dtypes, "rng": rng}
      for dtypes in [
          [onp.float32],
          [onp.float32, onp.float32],
          [onp.float32, onp.int32, onp.float32],
          [onp.float32, onp.int64, onp.float32],
          [onp.float32, onp.int32, onp.float64],
      ]
      for shape in [(), (2,), (3, 4), (1, 100)]
      for rng in [jtu.rand_default()])
  def testStack(self, shape, dtypes, rng):
    """Stacks a list of same-shape arrays with possibly mixed dtypes."""
    # `stack` takes a single list argument, hence the nested list.
    args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]]
    # NumPy reference first, LAX implementation second, matching every other
    # `_CheckAgainstNumpy` call in this class.
    self._CheckAgainstNumpy(onp.stack, lnp.stack, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape=[{}]_indtype={}_outdtype={}".format(
          "_".join(str(d) for d in shape),
          onp.dtype(fill_value_dtype).name,
          onp.dtype(out_dtype).name),
       "shape": shape, "fill_value_dtype": fill_value_dtype,
       "out_dtype": out_dtype, "rng": jtu.rand_default()}
      for shape in all_shapes
      for fill_value_dtype in default_dtypes
      for out_dtype in default_dtypes)
  def testFull(self, shape, fill_value_dtype, out_dtype, rng):
    """Checks `full` with a scalar fill value and an explicit output dtype."""
    onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
    lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
    args_maker = lambda: [rng((), fill_value_dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_{}_axis={}_{}sections".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, num_sections),
       "shape": shape, "num_sections": num_sections, "axis": axis,
       "dtype": dtype, "rng": jtu.rand_default()}
      for shape, axis, num_sections in [
          ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2),
          ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)]
      for dtype in default_dtypes)
  def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng):
    """Checks `split` into a static integer number of equal sections."""
    onp_fun = lambda x: onp.split(x, num_sections, axis=axis)
    lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis)
    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_outshape={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype)),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "rng": jtu.rand_default()}
      for dtype in default_dtypes
      for arg_shape, out_shape in [
          ((3, 4), 12),
          ((3, 4), (12,)),
          ((3, 4), -1),
          ((2, 1, 4), (-1,)),
          ((2, 2, 4), (2, 8))
      ])
  def testReshape(self, arg_shape, out_shape, dtype, rng):
    """Checks `reshape` with int, tuple and -1 (inferred) target shapes."""
    onp_fun = lambda x: onp.reshape(x, out_shape)
    lnp_fun = lambda x: lnp.reshape(x, out_shape)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_expanddim={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), dim),
       "arg_shape": arg_shape, "dtype": dtype, "dim": dim,
       "rng": jtu.rand_default()}
      for arg_shape in [(), (3,), (3, 4)]
      for dtype in default_dtypes
      for dim in range(-len(arg_shape)+1, len(arg_shape)))
  def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng):
    """Checks `expand_dims` with a static dimension index."""
    onp_fun = lambda x: onp.expand_dims(x, dim)
    lnp_fun = lambda x: lnp.expand_dims(x, dim)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axes=({},{})".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2),
       "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2,
       "rng": jtu.rand_default()}
      for arg_shape, ax1, ax2 in [
          ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2),
          ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)]
      for dtype in default_dtypes)
  def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng):
    """Checks `swapaxes` with static (including negative) axes."""
    onp_fun = lambda x: onp.swapaxes(x, ax1, ax2)
    lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axis={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype), ax),
       "arg_shape": arg_shape, "dtype": dtype, "ax": ax,
       "rng": jtu.rand_default()}
      for arg_shape, ax in [
          ((3, 1), None),
          ((3, 1), 1),
          ((1, 3, 1), (0, 2)),
          ((1, 4, 1), (0,))]
      for dtype in default_dtypes)
  def testSqueeze(self, arg_shape, dtype, ax, rng):
    """Checks `squeeze` with no axis, a single axis, and a tuple of axes."""
    onp_fun = lambda x: onp.squeeze(x, ax)
    lnp_fun = lambda x: lnp.squeeze(x, ax)
    args_maker = lambda: [rng(arg_shape, dtype)]
    self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": "_arg{}".format(i), "arg": arg}
      for i, arg in enumerate([
          [1, 2, 3], [1., 2., 3.],
          [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]],
          [[3, onp.array(2), 1], onp.arange(3.)],
      ]))
  def testArray(self, arg):
    """Checks `array` on nested Python lists, including mixed ndarray items."""
    args_maker = lambda: [arg]
    self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True)
    self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True)

  def testAllClose(self):
    """Checks `lnp.allclose` over tuples of arrays, eagerly and under jit."""
    rng = onp.random.RandomState(0)
    x = rng.randn(2, 2)
    y = rng.randn(2)

    def same(list1, list2):
      allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3)
      elements_close = list(map(allclose, list1, list2))
      return lnp.all(lnp.array(elements_close))

    csame = api.jit(same)

    a1 = same((x, y), (x, y))
    a2 = csame((x, y), (x, y))
    a3 = csame((x, y), (x, 2 * y))

    self.assertTrue(a1)
    self.assertTrue(a2)
    self.assertFalse(a3)

  @jtu.skip_on_devices("tpu")  # TODO(mattjj): investigate this failure
  def DISABLED_testOnesBroadcastingConstantHandler(self):
    # TODO(mattjj): update this test for jax3

    def fun(x):
      ones = lnp.ones((3, 4))
      assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0)

      # To check that the constant handler generates a Broadcast for stride-zero
      # arrays, we monkey-patch the client instance.
      # TODO(mattjj): once we have better HLO dumping and inspecting facilities,
      # we can check the HLO more directly.
      c = x._node.c
      Broadcast = c.Broadcast  # pylint: disable=invalid-name
      was_called = []
      c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args)
      out = x + ones  # the ndarray constant handler should call Broadcast here
      assert was_called, "Broadcast was not called."

      return out

    fun = api.jit(fun)
    out_val = fun(lnp.ones(4))
    self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False)

  def testZeroStridesConstantHandler(self):
    """Closes a jitted function over a zero-stride (broadcast) constant."""
    raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1)
    const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6))

    def fun(x):
      return x * const

    fun = api.jit(fun)
    out_val = fun(3.)
    self.assertAllClose(out_val, 3. * const, check_dtypes=False)

  def testIsInstanceNdarrayDuringTracing(self):
    """Traced values must still pass `isinstance(x, lnp.ndarray)`."""
    arr = onp.ones(3)

    @api.jit
    def f(x):
      self.assertIsInstance(x, lnp.ndarray)
      return lnp.sum(x)

    f(arr)

  def testNonArrayErrorMessage(self):
    """Passing a Python list to `lnp` ops raises TypeError, eager and jitted."""
    x = [1., 2.]
    y = onp.array([3., 4.])

    def g(x, y):
      return lnp.add(x, y)

    def f(x, y):
      return lnp.dot(x, y)

    self.assertRaises(TypeError, lambda: g(x, y))
    self.assertRaises(TypeError, lambda: f(x, y))
    self.assertRaises(TypeError, lambda: api.jit(g)(x, y))
    self.assertRaises(TypeError, lambda: api.jit(f)(x, y))

  def testAbstractionErrorMessage(self):
    """Python control flow on traced values under jit raises TypeError."""

    @api.jit
    def f(x, n):
      for _ in range(n):
        x = x * x
      return x

    self.assertRaises(TypeError, lambda: f(3., 3))

    @api.jit
    def g(x):
      if x > 0.:
        return x * 2
      else:
        return x + 2

    self.assertRaises(TypeError, lambda: g(3.))

  def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self):
    # TODO(mattjj): update this for jax3
    foo = lnp._not_implemented(lambda x: x)

    # No error if there's no tracing.
    foo(onp.arange(3))

    cfoo = api.jit(foo)
    self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3)))

  # TODO(mattjj): test infix operator overrides

  def DISABLED_testRavel(self):
    # TODO(mattjj): support this method-based syntax?
    rng = onp.random.RandomState(0)
    args_maker = lambda: [rng.randn(3, 4).astype("float32")]
    self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # Returns a zero-argument callable producing one `rng(shape, dtype)` array
    # per (shape, dtype) pair, in order.
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  # NOTE: the generator order below determines the test-case names and the
  # subset that `jtu.cases_from_list` keeps, so reordering it changes coverage.
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    """Compares `logsumexp` with SciPy over every axis, with/without keepdims."""
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  # One `cases_from_list` sample per special-function record, chained together
  # so every record contributes cases.
  @parameterized.named_parameters(itertools.chain.from_iterable(
    jtu.cases_from_list(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.test_name, shapes, dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "test_autodiff": rec.test_autodiff,
       "scipy_op": getattr(osp_special, rec.name),
       "lax_op": getattr(lsp_special, rec.name)}
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
    for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    """Compares a `lsp_special` function with the same-named SciPy function.

    Values are compared with loose (1e-3) tolerances and without dtype
    checking. When the record enables it, first-order gradients are also
    numerically checked with `jtu.check_grads`.
    """
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1,
                      atol=1e-3, rtol=3e-3, eps=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_d={}".format(
          jtu.format_shape_dtype_string(shape, dtype), d),
       "rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d}
      for shape in all_shapes
      for dtype in float_dtypes
      for d in [1, 2, 5]))
  def testMultigammaln(self, rng, shape, dtype, d):
    """Compares `multigammaln(a, d)` with SciPy for several dimensions `d`."""
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    # Shift the positive samples by (d - 1) / 2; presumably this keeps `a`
    # inside SciPy's domain a > (d - 1) / 2 -- confirm against SciPy docs.
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  def testIssue980(self):
    """`expit` of large negative float32 inputs should be (close to) zero."""
    # Presumably a regression test for GitHub issue #980.
    x = onp.full((4,), -1e20, dtype=onp.float32)
    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
                        lsp_special.expit(x), check_dtypes=True)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # Returns a zero-argument callable producing one `rng(shape, dtype)` array
    # per (shape, dtype) pair, in order.
    return lambda: [
        rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
    ]

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                  jtu.format_shape_dtype_string(shapes, dtype),
                  axis, keepdims, return_sign, use_b),
          "shapes": shapes, "dtype": dtype,
          "axis": axis, "keepdims": keepdims,
          "return_sign": return_sign, "use_b": use_b
      }
      for shape_group in compatible_shapes
      for dtype in float_dtypes + complex_dtypes + int_dtypes
      for use_b in [False, True]
      for shapes in itertools.product(
          *((shape_group, shape_group) if use_b else (shape_group, )))
      for axis in range(
          -max(len(shape) for shape in shapes),
          max(len(shape) for shape in shapes))
      for keepdims in [False, True]
      for return_sign in [False, True]))
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
    """Compares `logsumexp` (optionally with `b` and `return_sign`) to SciPy."""
    # TODO(b/133842870): re-enable inf/nan inputs on CPU when exp(nan) returns
    # NaN on CPU.
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    """`logsumexp` with zero entries in `b` matches SciPy."""
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(
      itertools.chain.from_iterable(
          jtu.cases_from_list(
              {
                  "testcase_name":
                      jtu.format_test_name_suffix(rec.test_name, shapes,
                                                  dtypes),
                  "rng_factory": rec.rng_factory,
                  "shapes": shapes,
                  "dtypes": dtypes,
                  "test_autodiff": rec.test_autodiff,
                  "nondiff_argnums": rec.nondiff_argnums,
                  "scipy_op": getattr(osp_special, rec.name),
                  "lax_op": getattr(lsp_special, rec.name)
              }
              for shapes in itertools.combinations_with_replacement(
                  all_shapes, rec.nargs)
              for dtypes in (itertools.combinations_with_replacement(
                  rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list)
                             else itertools.product(*rec.dtypes)))
          for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    """Compares a `lsp_special` function with the same-named SciPy function.

    If the record enables autodiff, gradients are checked only with respect
    to the arguments not listed in `nondiff_argnums`. Those arguments are held
    fixed at their sampled values.
    """
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        # Re-insert the fixed non-differentiable arguments at their original
        # positions; this relies on `nondiff_argnums` being sorted.
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
      diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_inshape={}_d={}".format(
                  jtu.format_shape_dtype_string(shape, dtype), d),
          "shape": shape, "dtype": dtype, "d": d
      } for shape in all_shapes for dtype in float_dtypes for d in [1, 2, 5]))
  def testMultigammaln(self, shape, dtype, d):
    """Compares `multigammaln(a, d)` with SciPy for several dimensions `d`."""
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = jtu.rand_positive(self.rng())
    # Shift the positive samples by (d - 1) / 2; presumably this keeps `a`
    # inside SciPy's domain a > (d - 1) / 2.
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={np.float32: 1e-3, np.float64: 1e-14})
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    """`expit` of large negative float32 inputs should be (close to) zero."""
    x = np.full((4, ), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                        lsp_special.expit(x))

  def testIssue3758(self):
    """`zeta(x, q)` for large `x` matches the listed reference values."""
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                        lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    """`xlogy(0, 0)` is defined as 0."""
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    """d/dy xlogy(0, y) at y = 0 is 0."""
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    """`xlog1py(0, -1)` is defined as 0."""
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    """d/dy xlog1py(0, y) at y = -1 is 0."""
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  "_maxdegree={}_inputsize={}".format(l_max, num_z),
              "l_max": l_max,
              "num_z": num_z
          } for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
  def testLpmn(self, l_max, num_z):
    """Compares `lpmn` values and derivatives with SciPy, pointwise in z."""
    # Points on which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max,
                                                           z=z)

    # The expected results are obtained from scipy, one point at a time, and
    # stacked along the last axis.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

    for i in range(num_z):
      val, derivative = osp_special.lpmn(l_max, l_max, z[i])
      expected_p_vals[:, :, i] = val
      expected_p_derivatives[:, :, i] = derivative

    with self.subTest('Test values.'):
      self.assertAllClose(actual_p_vals, expected_p_vals,
                          rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test derivatives.'):
      self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
                          rtol=1e-6, atol=8.4e-4)

    with self.subTest('Test JIT compatibility'):
      args_maker = lambda: [z]
      lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
      self._CompileAndCheck(lsp_special_fn, args_maker)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {
              "testcase_name":
                  "_maxdegree={}_inputsize={}".format(l_max, num_z),
              "l_max": l_max,
              "num_z": num_z
          } for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
  def testNormalizedLpmnValues(self, l_max, num_z):
    """Compares normalized `lpmn_values` with normalized SciPy `lpmn` values."""
    # Points on which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    is_normalized = True
    actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

    # The expected results are obtained from scipy.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    for i in range(num_z):
      expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]

    def apply_normalization(a):
      """Applies normalization to the associated Legendre functions.

      Entry [m, l] is scaled by sqrt((2l + 1) (l - m)! / (4 pi (l + m)!)).
      """
      num_m, num_l, _ = a.shape
      a_normalized = np.zeros_like(a)
      for m in range(num_m):
        for l in range(num_l):
          c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
          c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
          c2 = np.sqrt(c0 / c1)
          a_normalized[m, l] = c2 * a[m, l]
      return a_normalized

    # The results from scipy are not normalized and the comparison requires
    # normalizing the results.
    expected_p_vals_normalized = apply_normalization(expected_p_vals)

    with self.subTest('Test accuracy.'):
      self.assertAllClose(actual_p_vals, expected_p_vals_normalized,
                          rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test JIT compatibility'):
      args_maker = lambda: [z]
      lsp_special_fn = lambda z: lsp_special.lpmn_values(
          l_max, l_max, z, is_normalized)
      self._CompileAndCheck(lsp_special_fn, args_maker)

  def testSphHarmAccuracy(self):
    """Compares `sph_harm` with SciPy on a grid of orders and degrees."""
    m = jnp.arange(-3, 3)[:, None]
    n = jnp.arange(3, 6)
    n_max = 5
    theta = 0.0
    phi = jnp.pi

    jax_vals = lsp_special.sph_harm(m, n, theta, phi, n_max)
    scipy_vals = osp_special.sph_harm(m, n, theta, phi)

    self.assertAllClose(scipy_vals, jax_vals, rtol=1e-8, atol=9e-5)

  def testSphHarmOrderZeroDegreeZero(self):
    """Tests the spherical harmonics of order zero and degree zero."""
    theta = jnp.array([0.3])
    phi = jnp.array([2.3])
    n_max = 0

    expected = jnp.array([1.0 / jnp.sqrt(4.0 * np.pi)])
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([0]), theta, phi, n_max))

    self.assertAllClose(actual, expected, rtol=1.1e-7, atol=3e-8)

  def testSphHarmOrderZeroDegreeOne(self):
    """Tests the spherical harmonics of order zero and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([3.1])
    n_max = 1

    expected = jnp.sqrt(3.0 / (4.0 * np.pi)) * jnp.cos(phi)
    actual = jnp.real(
        lsp_special.sph_harm(jnp.array([0]), jnp.array([1]), theta, phi, n_max))

    self.assertAllClose(actual, expected, rtol=7e-8, atol=1.5e-8)

  def testSphHarmOrderOneDegreeOne(self):
    """Tests the spherical harmonics of order one and degree one."""
    theta = jnp.array([2.0])
    phi = jnp.array([2.5])
    n_max = 1

    expected = (-1.0 / 2.0 * jnp.sqrt(3.0 / (2.0 * np.pi)) *
                jnp.sin(phi) * jnp.exp(1j * theta))
    actual = lsp_special.sph_harm(
        jnp.array([1]), jnp.array([1]), theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=6e-8)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {'testcase_name': '_maxdegree={}_inputsize={}_dtype={}'.format(
              l_max, num_z, dtype),
           'l_max': l_max, 'num_z': num_z, 'dtype': dtype}
          for l_max, num_z in zip([1, 3, 8, 10], [2, 6, 7, 8])
          for dtype in jtu.dtypes.all_integer))
  def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype):
    """Tests against JIT compatibility and Numpy."""
    n_max = l_max
    shape = (num_z,)
    rng = jtu.rand_int(self.rng(), -l_max, l_max + 1)

    lsp_special_fn = partial(lsp_special.sph_harm, n_max=n_max)

    def args_maker():
      m = rng(shape, dtype)
      n = abs(m)
      theta = jnp.linspace(-4.0, 5.0, num_z)
      phi = jnp.linspace(-2.0, 1.0, num_z)
      return m, n, theta, phi

    with self.subTest('Test JIT compatibility'):
      self._CompileAndCheck(lsp_special_fn, args_maker)

    with self.subTest('Test against numpy.'):
      self._CheckAgainstNumpy(osp_special.sph_harm, lsp_special_fn, args_maker)

  def testSphHarmCornerCaseWithWrongNmax(self):
    """Tests the corner case where `n_max` is not the maximum value of `n`.

    The result for n = 10 with n_max = 6 is expected to equal the result for
    n = 6, i.e. `n` is effectively clipped to `n_max`.
    """
    m = jnp.array([2])
    n = jnp.array([10])
    n_clipped = jnp.array([6])
    n_max = 6
    theta = jnp.array([0.9])
    phi = jnp.array([0.2])

    expected = lsp_special.sph_harm(m, n, theta, phi, n_max)

    actual = lsp_special.sph_harm(m, n_clipped, theta, phi, n_max)

    self.assertAllClose(actual, expected, rtol=1e-8, atol=9e-5)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {'testcase_name':
           '_shape={}' '_n_zero_sv={}_degeneracy={}_geometric_spectrum={}'
           '_max_sv={}_method={}_side={}'
           '_nonzero_condition_number={}_seed={}'.format(
               jtu.format_shape_dtype_string(
                   shape, jnp.dtype(dtype).name).replace(" ", ""),
               n_zero_sv, degeneracy, geometric_spectrum, max_sv,
               method, side, nonzero_condition_number, seed),
           'n_zero_sv': n_zero_sv, 'degeneracy': degeneracy,
           'geometric_spectrum': geometric_spectrum,
           'max_sv': max_sv, 'shape': shape, 'method': method,
           'side': side, 'nonzero_condition_number': nonzero_condition_number,
           'dtype': dtype, 'seed': seed}
          for n_zero_sv in n_zero_svs
          for degeneracy in degeneracies
          for geometric_spectrum in geometric_spectra
          for max_sv in max_svs
          for shape in polar_shapes
          for method in methods
          for side in sides
          for nonzero_condition_number in nonzero_condition_numbers
          for dtype in jtu.dtypes.floating
          for seed in seeds))
  def testPolar(
      self, n_zero_sv, degeneracy, geometric_spectrum, max_sv, shape, method,
      side, nonzero_condition_number, dtype, seed):
    """Tests jax.scipy.linalg.polar.

    Checks that the unitary factor is unitary, that the positive factor is
    Hermitian with no negative eigenvalues beyond tolerance, and that the two
    factors reconstruct the input in the order implied by `side`.
    """
    if jtu.device_under_test() != "cpu":
      if jnp.dtype(dtype).name in ("bfloat16", "float16"):
        raise unittest.SkipTest("Skip half precision off CPU.")
      if method == "svd":
        raise unittest.SkipTest("Can't use SVD mode on TPU/GPU.")

    np.random.seed(seed)
    matrix, _ = _initialize_polar_test(
        shape, n_zero_sv, degeneracy, geometric_spectrum, max_sv,
        nonzero_condition_number, dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      # Half precision is expected to be rejected rather than computed.
      self.assertRaises(
          NotImplementedError, jsp.linalg.polar, matrix, method=method,
          side=side)
      return

    unitary, posdef = jsp.linalg.polar(matrix, method=method, side=side)
    # Use the smaller Gram product so the identity check is well-posed for
    # both tall and wide inputs.
    if shape[0] >= shape[1]:
      should_be_eye = np.matmul(unitary.conj().T, unitary)
    else:
      should_be_eye = np.matmul(unitary, unitary.conj().T)
    tol = 10 * jnp.finfo(matrix.dtype).eps
    eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype)
    with self.subTest('Test unitarity.'):
      self.assertAllClose(
          eye_mat, should_be_eye, atol=tol * min(shape))

    with self.subTest('Test Hermiticity.'):
      self.assertAllClose(
          posdef, posdef.conj().T, atol=tol * jnp.linalg.norm(posdef))

    # Ignore eigenvalues that are zero up to tolerance, then count negatives.
    ev, _ = np.linalg.eigh(posdef)
    ev = ev[np.abs(ev) > tol * np.linalg.norm(posdef)]
    negative_ev = jnp.sum(ev < 0.)
    with self.subTest('Test positive definiteness.'):
      # A bare `assert` would be stripped under `python -O`.
      self.assertEqual(int(negative_ev), 0)

    if side == "right":
      recon = jnp.matmul(unitary, posdef, precision=lax.Precision.HIGHEST)
    elif side == "left":
      recon = jnp.matmul(posdef, unitary, precision=lax.Precision.HIGHEST)
    else:
      # Avoid an UnboundLocalError on `recon` for an unexpected side.
      self.fail("Unknown side: {}".format(side))
    with self.subTest('Test reconstruction.'):
      self.assertAllClose(
          matrix, recon, atol=tol * jnp.linalg.norm(matrix))

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {'testcase_name': '_linear_size_={}_seed={}_dtype={}'.format(
              linear_size, seed, jnp.dtype(dtype).name),
           'linear_size': linear_size,
           'seed': seed,
           'dtype': dtype}
          for linear_size in linear_sizes
          for seed in seeds
          for dtype in jtu.dtypes.floating))
  def test_spectral_dac_eigh(self, linear_size, seed, dtype):
    """Checks the divide-and-conquer `eigh` against `jnp.linalg.eigh`."""
    # `device_under_test` must be called: comparing the function object itself
    # with "cpu" is always unequal and silently skipped the test everywhere.
    # This guard also covers half precision off CPU.
    if jtu.device_under_test() != "cpu":
      raise unittest.SkipTest("Skip eigh off CPU for now.")

    np.random.seed(seed)
    H = np.random.randn(linear_size, linear_size)
    H = jnp.array(0.5 * (H + H.conj().T)).astype(dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      # Half precision is expected to be rejected rather than computed.
      self.assertRaises(
          NotImplementedError, jax._src.scipy.eigh.eigh, H)
      return
    evs, V = jax._src.scipy.eigh.eigh(H)
    ev_exp, _ = jnp.linalg.eigh(H)
    HV = jnp.dot(H, V, precision=lax.Precision.HIGHEST)
    vV = evs * V
    eps = jnp.finfo(H.dtype).eps
    atol = jnp.linalg.norm(H) * eps
    self.assertAllClose(ev_exp, jnp.sort(evs), atol=20 * atol)
    self.assertAllClose(HV, vV, atol=30 * atol)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {'testcase_name': '_linear_size_={}_seed={}_dtype={}'.format(
              linear_size, seed, jnp.dtype(dtype).name),
           'linear_size': linear_size,
           'seed': seed,
           'dtype': dtype}
          for linear_size in linear_sizes
          for seed in seeds
          for dtype in jtu.dtypes.floating))
  def test_spectral_dac_svd(self, linear_size, seed, dtype):
    """Checks the divide-and-conquer `svd`: spectrum, reconstruction, unitarity."""
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      if jtu.device_under_test() != "cpu":
        raise unittest.SkipTest("Skip half precision off CPU.")

    np.random.seed(seed)
    A = np.random.randn(linear_size, linear_size).astype(dtype)
    if jnp.dtype(dtype).name in ("bfloat16", "float16"):
      # Half precision is expected to be rejected rather than computed.
      self.assertRaises(
          NotImplementedError, jax._src.scipy.eigh.svd, A)
      return

    S_expected = np.linalg.svd(A, compute_uv=False)
    U, S, V = jax._src.scipy.eigh.svd(A)
    recon = jnp.dot((U * S), V, precision=lax.Precision.HIGHEST)
    eps = jnp.finfo(dtype).eps
    eps = eps * jnp.linalg.norm(A) * 10
    self.assertAllClose(np.sort(S), np.sort(S_expected), atol=eps)
    self.assertAllClose(A, recon, atol=eps)

    # U is unitary.
    u_unitary_delta = jnp.dot(U.conj().T, U, precision=lax.Precision.HIGHEST)
    u_eye = jnp.eye(u_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(u_unitary_delta, u_eye, atol=eps)

    # V is unitary.
    v_unitary_delta = jnp.dot(V.conj().T, V, precision=lax.Precision.HIGHEST)
    v_eye = jnp.eye(v_unitary_delta.shape[0], dtype=dtype)
    self.assertAllClose(v_unitary_delta, v_eye, atol=eps)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # Returns a zero-argument callable that draws one array per
    # (shape, dtype) pair from `rng`, in order. Used as `args_maker` by the
    # `_CheckAgainstNumpy` / `_CompileAndCheck` helpers.
    return lambda: [
        rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)
    ]

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name":
              "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format(
                  jtu.format_shape_dtype_string(shapes, dtype),
                  axis, keepdims, return_sign, use_b),
          # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU.
          "shapes": shapes,
          "dtype": dtype,
          "axis": axis,
          "keepdims": keepdims,
          "return_sign": return_sign,
          "use_b": use_b
      } for shape_group in compatible_shapes
        for dtype in float_dtypes + complex_dtypes + int_dtypes
        for use_b in [False, True]
        # With `use_b`, draw a (value shape, b shape) pair from the same
        # compatible group; otherwise a single shape.
        for shapes in itertools.product(
            *((shape_group, shape_group) if use_b else (shape_group, )))
        # Every axis valid for the highest-rank shape in `shapes`.
        for axis in range(
            -max(len(shape) for shape in shapes),
            max(len(shape) for shape in shapes))
        for keepdims in [False, True]
        for return_sign in [False, True]))
  # NOTE(review): presumably silences scipy warnings for inf/nan inputs
  # drawn below; confirm.
  @jtu.ignore_warning(category=RuntimeWarning,
                      message="invalid value encountered in .*")
  def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
    """Compares `lsp_special.logsumexp` against scipy's `logsumexp`.

    When `use_b` is True, a second array of shape `shapes[1]` is passed as
    the `b` scale factor to both implementations.
    """
    # Off CPU the inputs may contain infs and NaNs. On CPU the default
    # generator is used instead, presumably because of the b/133842870
    # exp(nan) issue noted in the parameters above.
    if jtu.device_under_test() != "cpu":
      rng = jtu.rand_some_inf_and_nan(self.rng())
    else:
      rng = jtu.rand_default(self.rng())
    # TODO(mattjj): test autodiff
    if use_b:
      def scipy_fun(array_to_reduce, scale_array):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      def lax_fun(array_to_reduce, scale_array):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign, b=scale_array)

      args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)]
    else:
      def scipy_fun(array_to_reduce):
        return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      def lax_fun(array_to_reduce):
        return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims,
                                     return_sign=return_sign)

      args_maker = lambda: [rng(shapes[0], dtype)]
    # Per-dtype tolerances; only the compile check below uses them.
    tol = {np.float32: 1E-6, np.float64: 1E-14}
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)

  def testLogSumExpZeros(self):
    """Checks logsumexp when one of the `b` weights is zero."""
    # Regression test for https://github.com/google/jax/issues/5370
    scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
    lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
    # The zero in `b` masks out -2, leaving only the very negative -1000.
    args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
    self._CompileAndCheck(lax_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {
              "testcase_name": jtu.format_test_name_suffix(
                  rec.test_name, shapes, dtypes),
              "rng_factory": rec.rng_factory,
              "shapes": shapes,
              "dtypes": dtypes,
              "test_autodiff": rec.test_autodiff,
              "nondiff_argnums": rec.nondiff_argnums,
              "scipy_op": getattr(osp_special, rec.name),
              "lax_op": getattr(lsp_special, rec.name)
          }
          for shapes in itertools.combinations_with_replacement(
              all_shapes, rec.nargs)
          # A list of dtypes is shared by all arguments; any other form is
          # treated as one dtype collection per argument.
          for dtypes in (itertools.combinations_with_replacement(
              rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else
                         itertools.product(*rec.dtypes)))
      for rec in JAX_SPECIAL_FUNCTION_RECORDS))
  def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes,
                          test_autodiff, nondiff_argnums):
    """Compares a `lsp_special` function with its scipy counterpart.

    Also runs `_CompileAndCheck` on it and, when `test_autodiff` is set,
    checks first-order gradients with respect to every argument not listed
    in `nondiff_argnums`.
    """
    if (jtu.device_under_test() == "cpu" and
        (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)):
      # TODO(b/173608403): re-enable test when LLVM bug is fixed.
      raise unittest.SkipTest("Skipping test due to LLVM lowering bug")
    rng = rng_factory(self.rng())
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, rtol=1e-4)

    if test_autodiff:
      def partial_lax_op(*vals):
        # Re-insert the fixed (non-differentiated) arguments at their
        # original positions; ascending inserts keep the indices correct.
        list_args = list(vals)
        for i in nondiff_argnums:
          list_args.insert(i, args[i])
        return lax_op(*list_args)

      # The insertion above relies on sorted, duplicate-free indices.
      assert list(nondiff_argnums) == sorted(set(nondiff_argnums))
      diff_args = [
          x for i, x in enumerate(args) if i not in nondiff_argnums
      ]
      jtu.check_grads(partial_lax_op, diff_args, order=1,
                      atol=jtu.if_device_under_test("tpu", .1, 1e-3),
                      rtol=.1, eps=1e-3)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_inshape={}_d={}".format(
              jtu.format_shape_dtype_string(shape, dtype), d),
          "shape": shape,
          "dtype": dtype,
          "d": d
      } for shape in all_shapes
        for dtype in float_dtypes
        for d in [1, 2, 5]))
  def testMultigammaln(self, shape, dtype, d):
    """Compares `lsp_special.multigammaln(a, d)` with scipy's."""
    def scipy_fun(a):
      return osp_special.multigammaln(a, d)

    def lax_fun(a):
      return lsp_special.multigammaln(a, d)

    rng = jtu.rand_positive(self.rng())
    # Shift the positive samples by (d - 1) / 2, presumably to keep `a`
    # inside multigammaln's domain (a > (d - 1) / 2).
    args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
                            tol={
                                np.float32: 1e-3,
                                np.float64: 1e-14
                            })
    self._CompileAndCheck(lax_fun, args_maker)

  def testIssue980(self):
    """expit of very negative float32 inputs should be exactly zero."""
    x = np.full((4, ), -1e20, dtype=np.float32)
    self.assertAllClose(np.zeros((4, ), dtype=np.float32),
                        lsp_special.expit(x))

  def testIssue3758(self):
    """zeta with large exponents should match the expected values."""
    x = np.array([1e5, 1e19, 1e10], dtype=np.float32)
    q = np.array([1., 40., 30.], dtype=np.float32)
    self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32),
                        lsp_special.zeta(x, q))

  def testXlogyShouldReturnZero(self):
    """xlogy(0, 0) should be 0."""
    self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False)

  def testGradOfXlogyAtZero(self):
    """The gradient of y -> xlogy(0, y) at y = 0 should be 0."""
    partial_xlogy = functools.partial(lsp_special.xlogy, 0.)
    self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False)

  def testXlog1pyShouldReturnZero(self):
    """xlog1py(0, -1) should be 0."""
    self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False)

  def testGradOfXlog1pyAtZero(self):
    """The gradient of y -> xlog1py(0, y) at y = -1 should be 0."""
    partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.)
    self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0.,
                        check_dtypes=False)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_maxdegree={}_inputsize={}".format(
              l_max, num_z),
          "l_max": l_max,
          "num_z": num_z
      } for l_max, num_z in zip([1, 2, 3], [6, 7, 8])))
  def testLpmn(self, l_max, num_z):
    """Compares `lsp_special.lpmn` values and derivatives with scipy's."""
    # Points at which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max,
                                                           z=z)

    # The expected results are obtained from scipy, which is called once per
    # scalar z; the results are stacked along the last axis.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z))

    for i in range(num_z):
      val, derivative = osp_special.lpmn(l_max, l_max, z[i])
      expected_p_vals[:, :, i] = val
      expected_p_derivatives[:, :, i] = derivative

    with self.subTest('Test values.'):
      self.assertAllClose(actual_p_vals, expected_p_vals,
                          rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test derivatives.'):
      self.assertAllClose(actual_p_derivatives, expected_p_derivatives,
                          rtol=1e-6, atol=8.4e-4)

    with self.subTest('Test JIT compatibility'):
      args_maker = lambda: [z]
      lsp_special_fn = lambda z: lsp_special.lpmn(l_max, l_max, z)
      self._CompileAndCheck(lsp_special_fn, args_maker)

  @parameterized.named_parameters(
      jtu.cases_from_list({
          "testcase_name": "_maxdegree={}_inputsize={}".format(
              l_max, num_z),
          "l_max": l_max,
          "num_z": num_z
      } for l_max, num_z in zip([3, 4, 6, 32], [2, 3, 4, 64])))
  def testNormalizedLpmnValues(self, l_max, num_z):
    """Compares normalized `lsp_special.lpmn_values` with normalized scipy."""
    # Points at which the associated Legendre functions are evaluated.
    z = np.linspace(-0.2, 0.9, num_z)
    is_normalized = True
    actual_p_vals = lsp_special.lpmn_values(l_max, l_max, z, is_normalized)

    # The expected results are obtained from scipy.
    expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z))
    for i in range(num_z):
      expected_p_vals[:, :, i] = osp_special.lpmn(l_max, l_max, z[i])[0]

    def apply_normalization(a):
      """Applies normalization to the associated Legendre functions.

      Scales entry [m, l] by sqrt((2l + 1) (l - m)! / (4 pi (l + m)!)).
      Entries with m > l become 0, because scipy's factorial returns 0 for
      negative arguments.
      """
      num_m, num_l, _ = a.shape
      a_normalized = np.zeros_like(a)
      for m in range(num_m):
        for l in range(num_l):
          c0 = (2.0 * l + 1.0) * osp_special.factorial(l - m)
          c1 = (4.0 * np.pi) * osp_special.factorial(l + m)
          c2 = np.sqrt(c0 / c1)
          a_normalized[m, l] = c2 * a[m, l]
      return a_normalized

    # The results from scipy are not normalized, so they must be normalized
    # before the comparison.
    expected_p_vals_normalized = apply_normalization(expected_p_vals)

    with self.subTest('Test accuracy.'):
      self.assertAllClose(actual_p_vals, expected_p_vals_normalized,
                          rtol=1e-6, atol=3.2e-6)

    with self.subTest('Test JIT compatibility'):
      args_maker = lambda: [z]
      lsp_special_fn = lambda z: lsp_special.lpmn_values(
          l_max, l_max, z, is_normalized)
      self._CompileAndCheck(lsp_special_fn, args_maker)
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    # Zero-argument callable drawing one array per (shape, dtype) pair.
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
       "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
       "axis": axis, "keepdims": keepdims}
      for shape in all_shapes for dtype in float_dtypes
      for axis in range(-len(shape), len(shape))
      for keepdims in [False, True])
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    """Compares `lsp_misc.logsumexp` with scipy's over every valid axis."""
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_misc.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          rec.test_name, shapes, dtypes),
       "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
       "modes": rec.diff_modes,
       "scipy_op": getattr(osp_special, rec.name),
       "lax_op": getattr(lsp_special, rec.name)}
      for rec in JAX_SPECIAL_FUNCTION_RECORDS
      for shapes in CombosWithReplacement(all_shapes, rec.nargs)
      for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes, modes):
    """Compares a `lsp_special` function with its scipy counterpart.

    `modes` (the record's diff modes) is accepted but not used yet; see the
    autodiff TODO.
    """
    # TODO(mattjj): unskip this test combination when real() on tpu is improved
    # TODO(mattjj): test autodiff
    if (FLAGS.jax_test_dut and FLAGS.jax_test_dut.startswith("tpu")
        and not shapes[0]):
      # skipTest raises SkipTest, so the runner reports a skip. Returning a
      # `unittest.skip(...)` decorator from the test body would only create
      # and return a function, and the test would silently pass.
      self.skipTest("real() on scalar not supported on tpu")
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          "", shapes, dtypes),
       "rng": rng, "shapes": shapes, "dtypes": dtypes}
      for shapes in CombosWithReplacement(all_shapes, 3)
      for dtypes in CombosWithReplacement(default_dtypes, 3)
      for rng in [jtu.rand_default()])
  def testNormLogPdfThreeArgs(self, rng, shapes, dtypes):
    """Compares `norm.logpdf(x, loc, scale)` between lax and scipy."""
    # TODO(mattjj): test autodiff
    scipy_fun = osp_stats.norm.logpdf
    lax_fun = lsp_stats.norm.logpdf

    def args_maker():
      x, loc, scale = map(rng, shapes, dtypes)
      # Force scale >= 0.5 so it is strictly positive.
      scale = 0.5 + onp.abs(scale)
      return [x, loc, scale]

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      {"testcase_name": jtu.format_test_name_suffix(
          "", shapes, dtypes),
       "rng": rng, "shapes": shapes, "dtypes": dtypes}
      for shapes in CombosWithReplacement(all_shapes, 2)
      for dtypes in CombosWithReplacement(default_dtypes, 2)
      for rng in [jtu.rand_default()])
  def testNormLogPdfTwoArgs(self, rng, shapes, dtypes):
    """Compares `norm.logpdf(x, loc)` with a fixed scale of 0.5."""
    # TODO(mattjj): test autodiff
    scale = 0.5
    scipy_fun = functools.partial(osp_stats.norm.logpdf, scale=scale)
    lax_fun = functools.partial(lsp_stats.norm.logpdf, scale=scale)

    def args_maker():
      return list(map(rng, shapes, dtypes))

    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)