class TestBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_func={}_maxiter={}".format(
          func_and_init[0].__name__, maxiter),
       "maxiter": maxiter, "func_and_init": func_and_init}
      for maxiter in [None]
      for func_and_init in [(rosenbrock, np.zeros(2)),
                            (himmelblau, np.zeros(2)),
                            (matyas, np.ones(2) * 6.),
                            (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):
    # Note, cannot compare step for step with scipy BFGS because our line
    # search is _slightly_ different.
    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='BFGS',
          options=dict(maxiter=maxiter, gtol=1e-6),
      )
      return result.x

    jax_res = min_op(x0)
    scipy_res = scipy.optimize.minimize(func(np), x0, method='BFGS').x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_fixes4594(self):
    n = 2
    A = jnp.eye(n) * 1e4

    def f(x):
      return jnp.mean((A @ x) ** 2)

    results = jax.scipy.optimize.minimize(f, jnp.ones(n), method='BFGS')
    self.assertAllClose(results.x, jnp.zeros(n), atol=1e-6, rtol=1e-6)

  @jtu.skip_on_flag('jax_enable_x64', False)
  def test_zakharov(self):
    def zakharov_fn(x):
      ii = jnp.arange(1, len(x) + 1, step=1)
      answer = zakharovFromIndices(x=x, ii=ii)
      return answer

    x0 = jnp.array([600.0, 700.0, 200.0, 100.0, 90.0, 1e4])
    eval_func = jax.jit(zakharov_fn)
    jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS')
    self.assertLess(jax_res.fun, 1e-6)

  def test_minimize_bad_initial_values(self):
    # This test runs deliberately "bad" initial values to test that handling
    # of failed line search, etc. is the same across implementations.
    initial_value = jnp.array([92, 0.001])
    opt_fn = himmelblau(jnp)
    jax_res = jax.scipy.optimize.minimize(
        fun=opt_fn,
        x0=initial_value,
        method='BFGS',
    ).x
    scipy_res = scipy.optimize.minimize(
        fun=opt_fn,
        jac=jax.grad(opt_fn),
        method='BFGS',
        x0=initial_value
    ).x
    self.assertAllClose(scipy_res, jax_res, atol=2e-5, check_dtypes=False)

  def test_args_must_be_tuple(self):
    A = jnp.eye(2) * 1e4

    def f(x):
      return jnp.mean((A @ x) ** 2)

    with self.assertRaisesRegex(TypeError, "args .* must be a tuple"):
      jax.scipy.optimize.minimize(f, jnp.ones(2), args=45, method='BFGS')
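
# For reference (a sketch, not part of the test suite): the final test above
# asserts that `args` must be a tuple. Correct usage threads extra positional
# arguments to the objective through `args`; `loss` and `scale` here are
# hypothetical names introduced only for illustration.
#
#   def loss(x, scale):
#     return scale * jnp.sum(x ** 2)
#
#   result = jax.scipy.optimize.minimize(loss, jnp.ones(3), args=(2.0,),
#                                        method='BFGS')
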
class LaxBackedScipySignalTests(jtu.JaxTestCase):
  """Tests for LAX-backed scipy.signal implementations"""

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
          jtu.format_shape_dtype_string(xshape, dtype),
          jtu.format_shape_dtype_string(yshape, dtype),
          mode),
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
       "jsp_op": getattr(jsp_signal, op),
       "osp_op": getattr(osp_signal, op)}
      for mode in ['full', 'same', 'valid']
      for op in ['convolve', 'correlate']
      for dtype in default_dtypes
      for shapeset in [onedim_shapes, twodim_shapes, threedim_shapes]
      for xshape in shapeset
      for yshape in shapeset))
  def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-8}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
          op,
          jtu.format_shape_dtype_string(xshape, dtype),
          jtu.format_shape_dtype_string(yshape, dtype),
          mode),
       "xshape": xshape, "yshape": yshape, "dtype": dtype, "mode": mode,
       "jsp_op": getattr(jsp_signal, op),
       "osp_op": getattr(osp_signal, op)}
      for mode in ['full', 'same', 'valid']
      for op in ['convolve2d', 'correlate2d']
      for dtype in default_dtypes
      for xshape in twodim_shapes
      for yshape in twodim_shapes))
  def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(xshape, dtype), rng(yshape, dtype)]
    osp_fun = partial(osp_op, mode=mode)
    jsp_fun = partial(jsp_op, mode=mode, precision=lax.Precision.HIGHEST)
    tol = {np.float16: 1e-2, np.float32: 1e-2, np.float64: 1e-14}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False,
                            tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
       "shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp}
      for shape in [(5,), (4, 5), (3, 4, 5)]
      for dtype in default_dtypes
      for axis in [0, -1]
      for type in ['constant', 'linear']
      for bp in [0, [0, 2]]))
  def testDetrend(self, shape, dtype, axis, type, bp):
    rng = jtu.rand_default(self.rng())
    args_maker = lambda: [rng(shape, dtype)]
    osp_fun = partial(osp_signal.detrend, axis=axis, type=type, bp=bp)
    jsp_fun = partial(jsp_signal.detrend, axis=axis, type=type, bp=bp)
    tol = {np.float32: 1e-5, np.float64: 1e-12}
    self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
    self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
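
# For reference (a sketch, not part of the test suite): the three `mode`
# values exercised above control the output length of a convolution, matching
# scipy.signal semantics. For these inputs:
#
#   x = jnp.array([1., 2., 3., 4.])
#   y = jnp.array([1., 1.])
#   jsp_signal.convolve(x, y, mode='full')   # length len(x) + len(y) - 1 = 5
#   jsp_signal.convolve(x, y, mode='same')   # length len(x) = 4
#   jsp_signal.convolve(x, y, mode='valid')  # length len(x) - len(y) + 1 = 3
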
class LaxAutodiffTest(jtu.JaxTestCase):

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(
              rec.name, shapes, itertools.repeat(dtype)),
           "op": rec.op, "rng_factory": rec.rng_factory, "shapes": shapes,
           "dtype": dtype, "order": rec.order, "tol": rec.tol}
          for shape_group in compatible_shapes
          for shapes in itertools.combinations_with_replacement(
              shape_group, rec.nargs)
          for dtype in rec.dtypes)
      for rec in LAX_GRAD_OPS))
  def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.pow:
      raise SkipTest("pow grad imprecise on tpu")
    tol = jtu.join_tolerance(1e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
    check_grads(op, args, order, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
          {"testcase_name": "_{}_{}".format(rec.op.__name__, special_value),
           "op": rec.op, "special_value": special_value, "tol": rec.tol}
          for special_value in rec.values)
      for rec in LAX_GRAD_SPECIAL_VALUE_TESTS))
  def testOpGradSpecialValue(self, op, special_value, tol):
    check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_from_dtype={}_to_dtype={}".format(
          jtu.dtype_str(from_dtype), jtu.dtype_str(to_dtype)),
       "from_dtype": from_dtype, "to_dtype": to_dtype}
      for from_dtype, to_dtype in itertools.product(inexact_dtypes, repeat=2)))
  def testConvertElementTypeGrad(self, from_dtype, to_dtype):
    rng = jtu.rand_default(self.rng())
    tol = max(jtu.tolerance(to_dtype, jtu.default_gradient_tolerance),
              jtu.tolerance(from_dtype, jtu.default_gradient_tolerance))
    args = (rng((2, 3), from_dtype),)
    convert_element_type = lambda x: lax.convert_element_type(x, to_dtype)
    convert_element_type = jtu.ignore_warning(category=np.ComplexWarning)(
        convert_element_type)
    check_grads(convert_element_type, args, 2, ["fwd", "rev"], tol, tol, eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype}
      for shape in [(), (2, 3)]
      for dtype in grad_float_dtypes))
  def testClampGrad(self, shape, dtype):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    low = operand - dtype(10)
    high = operand + dtype(10)
    # Avoids points near the boundary where the gradient may be inaccurate.
    check_grads(lax.clamp, (operand, low, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, operand, high), 2, ["fwd", "rev"], eps=1e-2)
    check_grads(lax.clamp, (low, high, operand), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_dim={}_baseshape=[{}]_dtype={}_narrs={}".format(
          dim, ",".join(str(d) for d in base_shape), np.dtype(dtype).name,
          num_arrs),
       "dim": dim, "base_shape": base_shape, "dtype": dtype,
       "num_arrs": num_arrs}
      for num_arrs in [3]
      for dtype in float_dtypes
      for base_shape in [(4,), (3, 4), (2, 3, 4)]
      for dim in range(len(base_shape))))
  def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs):
    rng = jtu.rand_default(self.rng())
    shapes = [base_shape[:dim] + (size,) + base_shape[dim+1:]
              for size, _ in zip(itertools.cycle([3, 1, 4]), range(num_arrs))]
    operands = tuple(rng(shape, dtype) for shape in shapes)
    concatenate = lambda *args: lax.concatenate(args, dim)
    check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_strides={}_padding={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding}
      for lhs_shape, rhs_shape, all_strides in itertools.chain(
          [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)])
           for b, i, j in itertools.product([2, 3], repeat=3)],
          [((4, 2, 1), (3, 2, 1), [(1,)])])
      for strides in all_strides
      for dtype in float_dtypes
      for padding in ["VALID", "SAME"]))
  def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv, window_strides=strides, padding=padding,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}".format(
           jtu.format_shape_dtype_string(lhs_shape, dtype),
           jtu.format_shape_dtype_string(rhs_shape, dtype),
           strides, padding, lhs_dil, rhs_dil),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil}
      for lhs_shape, rhs_shape, all_strides, all_pads, lhs_dils, rhs_dils
      in itertools.chain(
          [((b, i, 3, 4), (j, i, 1, 2), [(1, 1), (1, 2), (2, 1)],
            [((0, 0), (0, 0)), ((-1, 0), (0, -1)), ((1, 0), (0, 1))],
            [(1, 1), (2, 1)], [(1, 1)])
           for b, i, j in itertools.product([2, 3], repeat=3)],
          [((4, 2, 1), (3, 2, 1), [(1,)], [((1, 1),), ((0, 0),)],
            [(1,), (2,)], [(1,), (2,)])])
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in float_dtypes
      for padding in all_pads))
  def testConvWithGeneralPaddingGrad(self, lhs_shape, rhs_shape, dtype,
                                     strides, padding, lhs_dil, rhs_dil):
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_with_general_padding, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=1e-2, rtol=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_lhs_shape={}_rhs_shape={}_strides={}_padding={}_lhs_dilation={}_"
       "rhs_dilation={}_dims={}_feature_group_count={}_batch_group_count={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               strides, padding, lhs_dil, rhs_dil, ",".join(dim_nums),
               feature_group_count, batch_group_count),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "strides": strides, "padding": padding, "lhs_dil": lhs_dil,
       "rhs_dil": rhs_dil, "dimension_numbers": dim_nums, "perms": perms,
       "feature_group_count": feature_group_count,
       "batch_group_count": batch_group_count}
      for batch_group_count, feature_group_count in [(1, 1), (2, 1), (1, 2)]
      for lhs_shapes, rhs_shape, all_strides, lhs_dils, rhs_dils in [
          ([(b * batch_group_count, i * feature_group_count, 6, 7),
            (b * batch_group_count, i * feature_group_count, 0, 4)],  # lhs_shape
           (j * batch_group_count * feature_group_count, i, 1, 2),  # rhs_shape
           [(1, 1), (1, 2), (2, 1)],  # strides
           [(1, 1), (2, 1)],  # lhs_dils
           [(1, 1), (2, 2)])  # rhs_dils
          for b, i, j in itertools.product([1, 2], repeat=3)]
      for lhs_shape in lhs_shapes
      for strides in all_strides
      for rhs_dil in rhs_dils
      for lhs_dil in lhs_dils
      for dtype in grad_inexact_dtypes
      for padding in ([((0, 0), (0, 0)), ((1, 0), (0, 1))] +
                      ([((0, -1), (0, 0))] if lhs_shape[2] != 0 else []))
      for dim_nums, perms in [
          (("NCHW", "OIHW", "NCHW"), ([0, 1, 2, 3], [0, 1, 2, 3])),
          (("NHWC", "HWIO", "NHWC"), ([0, 2, 3, 1], [2, 3, 1, 0])),
          (("NHWC", "OIHW", "NCHW"), ([0, 2, 3, 1], [0, 1, 2, 3]))]))
  def testConvGeneralDilatedGrad(self, lhs_shape, rhs_shape, dtype, strides,
                                 padding, lhs_dil, rhs_dil, dimension_numbers,
                                 perms, feature_group_count,
                                 batch_group_count):
    if dtype == np.float16:
      raise SkipTest("float16 numerical issues")  # TODO(mattjj): resolve

    if (jtu.device_under_test() == "cpu" and dtype == np.float64 and
        lhs_shape == (1, 1, 6, 7) and rhs_shape == (2, 1, 1, 2) and
        strides == (2, 1) and padding == ((0, -1), (0, 0)) and
        lhs_dil == (1, 1) and rhs_dil == (1, 1)):
      # TODO(b/173608403): reenable after LLVM fix.
      raise SkipTest("Skipping test due to LLVM lowering bug")

    rng = jtu.rand_default(self.rng())
    tol = {dtypes.bfloat16: 1e-0, np.float16: 5e-1, np.float32: 1e-3}

    # permute shapes to match dim_spec, scale by feature_group_count
    lhs_perm, rhs_perm = perms
    lhs_shape = list(np.take(lhs_shape, lhs_perm))
    rhs_shape = list(np.take(rhs_shape, rhs_perm))

    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    conv = partial(lax.conv_general_dilated, window_strides=strides,
                   padding=padding, lhs_dilation=lhs_dil, rhs_dilation=rhs_dil,
                   dimension_numbers=dimension_numbers,
                   feature_group_count=feature_group_count,
                   batch_group_count=batch_group_count,
                   precision=lax.Precision.HIGHEST)
    check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}".format(
          jtu.format_shape_dtype_string(lhs_shape, dtype),
          jtu.format_shape_dtype_string(rhs_shape, dtype)),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype}
      for lhs_shape in [(2,), (3, 2)]
      for rhs_shape in [(2,), (2, 4)]
      for dtype in float_dtypes))
  def testDotGrad(self, lhs_shape, rhs_shape, dtype):
    rng = jtu.rand_default(self.rng())
    tol = {np.float16: 1e-1, np.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)
    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "Precision.HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_lhs_shape={}_rhs_shape={}_dimension_numbers={}"
       .format(jtu.format_shape_dtype_string(lhs_shape, dtype),
               jtu.format_shape_dtype_string(rhs_shape, dtype),
               dimension_numbers),
       "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype,
       "dimension_numbers": dimension_numbers}
      for lhs_shape, rhs_shape, dimension_numbers in [
          ((3, 2), (2, 4), (([1], [0]), ([], []))),
          ((3, 5), (2, 5), (([1], [1]), ([], []))),
          ((5, 3), (5, 2), (([0], [0]), ([], []))),
          ((3, 3, 2), (3, 2, 4), (([2], [1]), ([0], [0]))),
          ((3, 5, 2), (2, 4, 5), (([2], [0]), ([1], [2]))),
          ((7, 3, 5, 2), (2, 2, 4, 5), (([3], [0]), ([2], [3]))),
      ]
      for dtype in float_dtypes))
  def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                          dimension_numbers):
    rng = jtu.rand_small(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general,
                          dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2,
                         modes=["fwd", "rev"])
    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    assert "Precision.HIGHEST" in s

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_dtype={}_broadcast_sizes={}".format(
          shape, np.dtype(dtype).name, broadcast_sizes),
       "shape": shape, "dtype": dtype, "broadcast_sizes": broadcast_sizes}
      for shape in [(), (2, 3)]
      for dtype in float_dtypes
      for broadcast_sizes in [(), (2,), (1, 2)]))
  def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
    rng = jtu.rand_default(self.rng())
    args = (rng(shape, dtype),)
    broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
    check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_bcdims={}".format(
          jtu.format_shape_dtype_string(inshape, dtype),
          outshape, broadcast_dimensions),
       "inshape": inshape, "dtype": dtype, "outshape": outshape,
       "dimensions": broadcast_dimensions}
      for inshape, outshape, broadcast_dimensions in [
          ([2], [2, 2], [0]),
          ([2], [2, 2], [1]),
          ([2], [2, 3], [0]),
          ([], [2, 3], []),
      ]
      for dtype in float_dtypes))
  def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
    rng = jtu.rand_default(self.rng())
    operand = rng(inshape, dtype)
    broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
    check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_outshape={}_perm={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          jtu.format_shape_dtype_string(out_shape, dtype),
          permutation),
       "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype,
       "permutation": permutation}
      for dtype in float_dtypes
      for arg_shape, out_shape, permutation in [
          [(3, 4), (12,), None],
          [(2, 1, 4), (8,), None],
          [(2, 2, 4), (2, 8), None],
          [(3, 4), (12,), (0, 1)],
          [(3, 4), (12,), (1, 0)],
          [(2, 1, 4), (8,), (0, 2, 1)],
          [(2, 1, 4), (8,), (2, 0, 1)],
          [(2, 2, 4), (2, 8), (0, 2, 1)],
          [(2, 2, 4), (2, 8), (2, 0, 1)],
      ]))
  def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
    rng = jtu.rand_default(self.rng())
    operand = rng(arg_shape, dtype)
    reshape = lambda x: lax.reshape(x, out_shape, permutation)
    check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_pads={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), pads),
       "shape": shape, "dtype": dtype, "pads": pads}
      for shape in [(2, 3)]
      for dtype in float_dtypes
      for pads in [[(1, 2, 1), (0, 1, 0)], [(-1, 0, 0), (-1, 0, 2)]]))
  def testPadGrad(self, shape, dtype, pads):
    rng = jtu.rand_small(self.rng())
    operand = rng(shape, dtype)
    pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
    check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

    operand = rng(shape, dtype)
    padding_value = np.array(0., dtype)
    pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
    check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)
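
  # For reference (a sketch, not part of the test suite): each entry of the
  # `pads` config above is (low, high, interior) padding for one dimension;
  # negative values crop. For example:
  #
  #   lax.pad(jnp.array([1., 2., 3.]), 0., [(1, 2, 1)])
  #   # -> [0., 1., 0., 2., 0., 3., 0., 0.]
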
  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (np.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={np.float32: 3e-3})

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_predshape={}_argshapes={}".format(
          jtu.format_shape_dtype_string(pred_shape, np.bool_),
          jtu.format_shape_dtype_string(arg_shape, dtype)),
       "pred_shape": pred_shape, "arg_shape": arg_shape, "dtype": dtype}
      for arg_shape in [(), (3,), (2, 3)]
      for pred_shape in ([(), arg_shape] if arg_shape else [()])
      for dtype in float_dtypes))
  def testSelectGrad(self, pred_shape, arg_shape, dtype):
    rng = jtu.rand_default(self.rng())
    pred = rng(pred_shape, np.bool_)
    on_true = rng(arg_shape, dtype)
    on_false = rng(arg_shape, dtype)
    select = lambda on_true, on_false: lax.select(pred, on_true, on_false)
    check_grads(select, (on_true, on_false), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_shape={}_start_indices={}_limit_indices={}_strides={}".format(
           jtu.format_shape_dtype_string(shape, dtype),
           start_indices, limit_indices, strides),
       "shape": shape, "dtype": dtype, "starts": start_indices,
       "limits": limit_indices, "strides": strides}
      for shape, start_indices, limit_indices, strides in [
          [(3,), (1,), (2,), None],
          [(7,), (4,), (7,), None],
          [(5,), (1,), (5,), (2,)],
          [(8,), (1,), (6,), (2,)],
          [(5, 3), (1, 1), (3, 2), None],
          [(5, 3), (1, 1), (3, 1), None],
          [(7, 5, 3), (4, 0, 1), (7, 1, 3), None],
          [(5, 3), (1, 1), (2, 1), (1, 1)],
          [(5, 3), (1, 1), (5, 3), (2, 1)],
          [(3, 3, 5), (0, 2, 0), (3, 2, 5), (1, 2, 1)]
      ]
      for dtype in float_dtypes))
  def testSliceGrad(self, shape, dtype, starts, limits, strides):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    slice = lambda x: lax.slice(x, starts, limits, strides)
    check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_size_indices={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, size_indices),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "size_indices": size_indices}
      for shape, start_indices, size_indices in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes))
  def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
    check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_start_indices={}_update_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype),
          start_indices, update_shape),
       "shape": shape, "dtype": dtype, "start_indices": start_indices,
       "update_shape": update_shape}
      for shape, start_indices, update_shape in [
          [(3,), (1,), (1,)],
          [(5, 3), (1, 1), (3, 1)],
          [(7, 5, 3), (4, 1, 0), (2, 0, 1)],
      ]
      for dtype in float_dtypes))
  def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices,
                                 update_shape):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    update = rng(update_shape, dtype)
    start_indices = np.array(start_indices)

    dus = lambda x, y: lax.dynamic_update_slice(x, y, start_indices)
    check_grads(dus, (operand, update), 2, ["fwd", "rev"], eps=1.)

    dus = lambda x: lax.dynamic_update_slice(x, update, start_indices)
    check_grads(dus, (operand,), 2, ["fwd", "rev"], eps=1.)

    dus = lambda y: lax.dynamic_update_slice(operand, y, start_indices)
    check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_perm={}".format(
          jtu.format_shape_dtype_string(shape, dtype), perm),
       "shape": shape, "dtype": dtype, "perm": perm}
      for shape, perm in [
          [(3, 4), (1, 0)],
          [(3, 4), (0, 1)],
          [(3, 4, 5), (2, 1, 0)],
          [(3, 4, 5), (1, 0, 2)],
      ]
      for dtype in float_dtypes))
  def testTransposeGrad(self, shape, dtype, perm):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    transpose = lambda x: lax.transpose(x, perm)
    check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               dims),
       "op": op, "init_val": init_val, "shape": shape, "dtype": dtype,
       "dims": dims, "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, float_dtypes + jtu.dtypes.complex, jtu.rand_default),
          (-np.inf, lax.max, grad_inexact_dtypes, jtu.rand_unique_int),
          (np.inf, lax.min, grad_inexact_dtypes, jtu.rand_unique_int),
          (1, lax.mul, grad_float_dtypes, partial(jtu.rand_default, scale=1)),
      ]
      for dtype in dtypes
      for shape, dims in [
          [(), ()],
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReduceGrad(self, op, init_val, shape, dtype, dims, rng_factory):
    rng = rng_factory(self.rng())
    if jtu.device_under_test() == "tpu" and op is lax.mul:
      raise SkipTest("unimplemented case")
    tol = {dtypes.bfloat16: 2e-1, np.float16: 1e-1, np.float32: 1e-1,
           np.float64: 1e-3, np.complex64: 1e-1}
    operand = rng(shape, dtype)
    init_val = np.asarray(init_val, dtype=dtype)
    reduce = lambda operand: lax.reduce(operand, init_val, op, dims)
    eps = (1.0 if dtypes.finfo(dtype).bits == 16 and op is lax.add else
           1e-1 if dtype == dtypes.bfloat16 else
           1e-2 if dtypes.finfo(dtype).bits == 32 else None)
    if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
      check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_reducedims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
       "shape": shape, "dtype": dtype, "dims": dims}
      for dtype in grad_float_dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReducePairGrad(self, shape, dtype, dims):
    rng = jtu.rand_default(self.rng(), scale=1)
    tol = {np.float32: 1e-2, np.float64: 1e-4}
    operands = (rng(shape, dtype), rng(shape, dtype))
    init_vals = (np.array(0, dtype), np.array(1, dtype))
    def op(xs, ys):
      return (xs[0] + ys[0], xs[1] * ys[1])
    reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
    check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               dims, strides, padding, base_dilation, window_dilation),
       "op": op, "init_val": init_val, "dtype": dtype, "shape": shape,
       "dims": dims, "strides": strides, "padding": padding,
       "base_dilation": base_dilation, "window_dilation": window_dilation,
       "rng_factory": rng_factory}
      for init_val, op, dtypes, rng_factory in [
          (0, lax.add, grad_float_dtypes, jtu.rand_small),
          (-np.inf, lax.max, grad_float_dtypes, jtu.rand_unique_int),
          (np.inf, lax.min, grad_float_dtypes, jtu.rand_unique_int),
      ]
      for shape, dims, strides, padding, base_dilation, window_dilation in (
          itertools.chain(
              itertools.product(
                  [(4, 6)],
                  [(2, 1), (1, 2)],
                  [(1, 1), (2, 1), (1, 2)],
                  ["VALID", "SAME", [(0, 3), (1, 2)]],
                  [(1, 1)] + ([(2, 3)] if op is lax.add else []),
                  [(1, 1)] + ([(1, 2)] if op is lax.add else [])),
              itertools.product(
                  [(3, 2, 4, 6)],
                  [(1, 1, 2, 1), (2, 1, 2, 1)],
                  [(1, 2, 2, 1), (1, 1, 1, 1)],
                  ["VALID", "SAME", [(0, 1), (1, 0), (2, 3), (0, 2)]],
                  [(1, 1, 1, 1)] + ([(2, 1, 3, 2)] if op is lax.add else []),
                  [(1, 1, 1, 1)] + ([(1, 2, 2, 1)] if op is lax.add else []))))
      for dtype in dtypes))
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def testReduceWindowGrad(
      self, op, init_val, dtype, shape, dims, strides,
      padding, base_dilation, window_dilation, rng_factory):
    rng = rng_factory(self.rng())
    init_val = np.asarray(init_val, dtype=dtype)

    gradient_order = 3
    # We need this conditional and the corresponding loop logic to be in the
    # test method, rather than at the parameterized test level, because it
    # depends on FLAGS for the device under test.
    # TODO(b/31565929): enable when fixed.
    if jtu.device_under_test() == "tpu" and op is not lax.add:
      if (len(shape) != 4 or dims != (1, 1, 2, 1)
          or not isinstance(padding, str)):
        raise SkipTest("Only R4 SelectAndScatter implemented on TPU")
      # TODO(b/73062247): need variadic reduce-window for better precision.
      gradient_order = 1

    def fun(operand):
      return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                               base_dilation, window_dilation)

    operand = rng(shape, dtype)
    if op is lax.add:
      eps = 1.
      tol = None
    else:
      # this test can fail if there are duplicates in operand
      self.assertEqual(np.unique(operand).size, operand.size,
                       msg="test requires operand elements to be unique.")
      eps = 1e-2
      tol = {np.float16: 1e-1, np.float32: 6e-2, np.float64: 6e-2}
    check_grads(fun, (operand,), gradient_order, ["fwd", "rev"], tol, tol,
                eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_shape={}_axis={}_reverse={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype),
               axis, reverse),
       "op": op, "shape": shape, "dtype": dtype, "axis": axis,
       "reverse": reverse}
      for op, types in [
          (lax.cumsum, [np.float32, np.float64]),
          (lax.cumprod, [np.float32, np.float64]),
      ]
      for dtype in types
      for shape in [[10], [3, 4, 5]]
      for axis in range(len(shape))
      for reverse in [False, True]))
  def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
    rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                   else jtu.rand_small)
    rng = rng_factory(self.rng())
    check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
                order=2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, dtype), axis, is_stable),
       "shape": shape, "dtype": dtype, "axis": axis, "is_stable": is_stable}
      for dtype in [np.float32]
      for shape in [(5,), (5, 7)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]))
  def testSortGrad(self, shape, dtype, axis, is_stable):
    rng = jtu.rand_default(self.rng())
    operand = rng(shape, dtype)
    sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
    check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)

  # TODO(b/205052657): enable more tests when supported
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_keyshape={}_valshape={}_axis={}_isstable={}".format(
          jtu.format_shape_dtype_string(shape, key_dtype),
          jtu.format_shape_dtype_string(shape, val_dtype),
          axis, is_stable),
       "shape": shape, "key_dtype": key_dtype, "val_dtype": val_dtype,
       "axis": axis, "is_stable": is_stable}
      for key_dtype in [np.float32]
      for val_dtype in [np.float32]
      for shape in [(3,), (5, 3)]
      for axis in [len(shape) - 1]
      for is_stable in [False, True]))
  def testSortKeyValGrad(self, shape, key_dtype, val_dtype, axis, is_stable):
    rng = jtu.rand_default(self.rng())
    # This test relies on the property that wherever keys are tied, values
    # are too, since we don't guarantee the same ordering of values with
    # equal keys. To avoid that case, we generate unique keys (globally in
    # the key array).
    def args_maker():
      flat_keys = np.arange(prod(shape), dtype=key_dtype)
      keys = self.rng().permutation(flat_keys).reshape(shape)
      values = rng(shape, val_dtype)
      return keys, values
    keys, values = args_maker()

    fun = lambda keys, values: lax.sort_key_val(keys, values, axis, is_stable)
    check_grads(fun, (keys, values), 2, ["fwd", "rev"], 1e-2, 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_k={}".format(
          jtu.format_shape_dtype_string(shape, dtype), k),
       "shape": shape, "dtype": dtype, "k": k}
      for dtype in [np.float32]
      for shape in [(4,), (5, 5), (2, 1, 4)]
      for k in [1, 3]))
  def testTopKGrad(self, shape, dtype, k):
    flat_values = np.arange(prod(shape), dtype=dtype)
    values = self.rng().permutation(flat_values).reshape(shape)
    fun = lambda vs: lax.top_k(vs, k=k)[0]
    check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_axes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, axes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "axes": axes}
      for dtype in float_dtypes
      for shape, idxs, axes in [
          [(3, 4, 5), (np.array([0, 2, 1]),), (0,)],
          [(3, 4, 5), (np.array([-1, -2]),), (0,)],
          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 1)],
          [(3, 4, 5), (np.array([0, 2]), np.array([1, 3])), (0, 2)],
      ]))
  def testIndexTakeGrad(self, shape, dtype, idxs, axes):
    rng = jtu.rand_default(self.rng())
    src = rng(shape, dtype)
    index_take = lambda src: lax.index_take(src, idxs, axes)
    check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_dnums={}_slice_sizes={}".format(
          jtu.format_shape_dtype_string(shape, dtype), idxs, dnums,
          slice_sizes),
       "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums,
       "slice_sizes": slice_sizes, "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for shape, idxs, dnums, slice_sizes, max_idx in [
          ((5,), np.array([[0], [2]]), lax.GatherDimensionNumbers(
              offset_dims=(), collapsed_slice_dims=(0,),
              start_index_map=(0,)), (1,), 5),
          ((10,), np.array([[0], [0], [0]]), lax.GatherDimensionNumbers(
              offset_dims=(1,), collapsed_slice_dims=(),
              start_index_map=(0,)), (2,), 9),
          ((10, 5,), np.array([[0], [2], [1]]), lax.GatherDimensionNumbers(
              offset_dims=(1,), collapsed_slice_dims=(0,),
              start_index_map=(0,)), (1, 3), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]))
  def testGatherGrad(self, shape, dtype, idxs, dnums, slice_sizes,
                     rng_idx_factory):
    rng = jtu.rand_default(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    gather = lambda x: lax.gather(x, idxs, dimension_numbers=dnums,
                                  slice_sizes=slice_sizes)
    x = rng(shape, dtype)
    check_grads(gather, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)
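
  # For reference (a sketch, not part of the test suite): the first gather
  # case above, written out concretely. With operand shape (5,), index rows
  # [[0], [2]], slice_sizes (1,), and the sliced dimension collapsed, each
  # index row yields one scalar:
  #
  #   op = jnp.arange(5.)
  #   lax.gather(op, np.array([[0], [2]]),
  #              lax.GatherDimensionNumbers(offset_dims=(),
  #                                         collapsed_slice_dims=(0,),
  #                                         start_index_map=(0,)),
  #              slice_sizes=(1,))
  #   # -> [0., 2.]
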
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
              update_window_dims=(), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), np.array([[0], [0], [0]]), (3, 2),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(),
               scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(0,),
               scatter_dims_to_operand_dims=(0,)), 3),
      ]
      for rng_idx_factory in [partial(jtu.rand_int, high=max_idx)]))
  def testScatterAddGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                         rng_idx_factory):
    rng = jtu.rand_default(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_add = lambda x, y: lax.scatter_add(x, idxs, y,
                                               dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_add, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums,
       "rng_idx_factory": rng_idx_factory}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums, max_idx in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
              update_window_dims=(), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,)), 4),
          ((10,), np.array([[0], [0], [0]]), (3, 2),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(),
               scatter_dims_to_operand_dims=(0,)), 9),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(0,),
               scatter_dims_to_operand_dims=(0,)), 3),
      ]
      # Scatters with conflicting indices are not deterministic on GPU, so we
      # use indices that do not collide.
      for rng_idx_factory in [partial(jtu.rand_unique_int, high=max_idx)]))
  def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums,
                      rng_idx_factory):
    rng = jtu.rand_default(self.rng())
    rng_idx = rng_idx_factory(self.rng())
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter = lambda x, y: lax.scatter(x, idxs, y, dimension_numbers=dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.)

  def testScatterGradSymbolicZeroUpdate(self):
    # https://github.com/google/jax/issues/1901
    def f(x):
      n = x.shape[0]
      y = np.arange(n, dtype=x.dtype)
      return jax.ops.index_update(x, np.diag_indices(n), y)

    rng = jtu.rand_default(self.rng())
    check_grads(f, (rng((5, 5), np.float32),), 2, ["fwd", "rev"], 1e-2, 1e-2,
                1.)
  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
              update_window_dims=(), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(),
               scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(0,),
               scatter_dims_to_operand_dims=(0,))),
      ]))
  def testScatterMax(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_max = lambda x, y: lax.scatter_max(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_max, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_idxs={}_update={}_dnums={}".format(
          jtu.format_shape_dtype_string(arg_shape, dtype),
          idxs, update_shape, dnums),
       "arg_shape": arg_shape, "dtype": dtype, "idxs": idxs,
       "update_shape": update_shape, "dnums": dnums}
      for dtype in grad_float_dtypes
      for arg_shape, idxs, update_shape, dnums in [
          ((5,), np.array([[0], [2]]), (2,), lax.ScatterDimensionNumbers(
              update_window_dims=(), inserted_window_dims=(0,),
              scatter_dims_to_operand_dims=(0,))),
          ((10,), np.array([[0], [0], [0]]), (3, 2),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(),
               scatter_dims_to_operand_dims=(0,))),
          ((10, 5,), np.array([[0], [2], [1]]), (3, 3),
           lax.ScatterDimensionNumbers(
               update_window_dims=(1,), inserted_window_dims=(0,),
               scatter_dims_to_operand_dims=(0,))),
      ]))
  def testScatterMin(self, arg_shape, dtype, idxs, update_shape, dnums):
    rng = jtu.rand_default(self.rng())
    rng_idx = jtu.rand_int(self.rng(), high=max(arg_shape))
    idxs = rng_idx(idxs.shape, idxs.dtype)
    scatter_min = lambda x, y: lax.scatter_min(x, idxs, y, dnums)
    x = rng(arg_shape, dtype)
    y = rng(update_shape, dtype)
    check_grads(scatter_min, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2)

  def testStopGradient(self):
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    ans = api.grad(f)(x)
    expected = api.grad(f2)(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(api.grad(f))(x)
    expected = api.grad(api.grad(f2))(x, x)
    self.assertAllClose(ans, expected)

    ans = api.grad(lambda x: lax.stop_gradient({'foo': x})['foo'])(3.)
    expected = np.array(0.0)
    self.assertAllClose(ans, expected, check_dtypes=False)

    with jax.enable_checks(False):
      with self.assertRaises(TypeError):
        lax.stop_gradient(lambda x: x)

  # TODO(mattjj): make this a more systematic test
  def testRemainder(self):
    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(3, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 1))
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

    rng = np.random.RandomState(0)
    x = rng.uniform(-0.9, 9, size=(1, 4))
    y = rng.uniform(0.7, 1.9, size=(3, 4))
    assert not set(np.unique(x)) & set(np.unique(y))
    tol = 1e-1 if jtu.num_float_bits(np.float64) == 32 else 1e-3
    check_grads(lax.rem, (x, y), 2, ["fwd", "rev"], tol, tol)

  def testHigherOrderGradientOfReciprocal(self):
    # Regression test for https://github.com/google/jax/issues/3136
    def inv(x):
      # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x)
      return 1 / x
    grad_fn = jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(jax.grad(inv))))))
    self.assertAllClose(np.float32(0.0439453125), grad_fn(np.float32(4.)))

  def test_linear_transpose_real(self):
    f = lambda x: x.real
    transpose = api.linear_transpose(f, 1.j)
    actual, = transpose(1.)
    expected = 1.
    self.assertEqual(actual, expected)

  def test_linear_transpose_imag(self):
    f = lambda x: x.imag
    transpose = api.linear_transpose(f, 1.j)
    actual, = transpose(1.)
    expected = -1.j
    self.assertEqual(actual, expected)
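
# For context (a sketch, not part of the test suite): jtu.check_grads, used
# throughout the class above, compares JAX's forward- and reverse-mode
# derivatives of a function against numerical finite differences up to the
# given order. A minimal standalone use:
#
#   from jax.test_util import check_grads
#   check_grads(jnp.tanh, (0.5,), order=2, modes=["fwd", "rev"])
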
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

  def test_primitive_coverage(self):
    """Fail if there are JAX primitives that are not implemented."""
    # Harvest primitives from XLA translation tables
    all_primitives = (set(xla.translations)
                      | set(xla.backend_specific_translations['cpu'])
                      | set(xla.backend_specific_translations['gpu'])
                      | set(xla.backend_specific_translations['tpu'])
                      | set(xla.initial_style_translations)
                      | set(xla.parallel_translations))
    tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl)
    tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl)

    all_primitives = tuple(sorted(all_primitives, key=str))
    for p in all_primitives:
      if p in tf_not_yet_impl:
        # Should not be in both tf_impl and tf_not_yet_impl
        self.assertNotIn(p, tf_impl)
      else:
        self.assertIn(p, tf_impl)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in [jnp.add, jnp.subtract, jnp.multiply, jnp.divide,
                    jnp.less, jnp.less_equal, jnp.equal, jnp.greater,
                    jnp.greater_equal, jnp.not_equal, jnp.maximum,
                    jnp.minimum]))
  def test_type_promotion(self, f_jax=jnp.add):
    # We only test a few types here, as tensorflow does not support many
    # types like uint* or bool in binary ops.
    types = [np.int32, np.int64, np.float32]
    for x_dtype in types:
      for y_dtype in types:
        x = np.array([1, 2], dtype=x_dtype)
        y = np.array([3, 4], dtype=y_dtype)
        self.ConvertAndCompare(f_jax, x, y, with_function=True)

  def test_concat(self):
    values = [np.array([1, 2], dtype=np.float32),
              np.array([1, 2], dtype=np.int32),
              np.array([1, 2], dtype=np.int8)]
    f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0))
    self.ConvertAndCompare(f_jax, values, with_function=True)

  @primitive_harness.parameterized(primitive_harness.lax_pad)
  def test_pad(self, harness: primitive_harness.Harness):
    # TODO: figure out the bfloat16 story
    if harness.params["dtype"] is dtypes.bfloat16:
      raise unittest.SkipTest("bfloat16 not implemented")
    # TODO: fix pad with negative padding in XLA (fixed on 06/16/2020)
    if any([lo < 0 or hi < 0 for lo, hi, mid in harness.params["pads"]]):
      raise unittest.SkipTest("pad with negative pad not supported")
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=False)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in LAX_ELEMENTWISE_UNARY))
  def test_unary_elementwise(self, f_jax=lax.abs):
    x = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
                 dtype=np.float32)
    f_tf = tf.function(jax2tf.convert(f_jax))
    r_jax = f_jax(x)
    r_tf = f_tf(x)
    self.assertAllClose(r_jax[np.isfinite(r_jax)],
                        r_tf[np.isfinite(r_tf)], atol=1e-4)

  def test_bitwise_not(self):
    x = np.array([-1, 3, -2, 0, 0, 2, 1, 3], dtype=np.int32)
    f_jax = jax.jit(lax.bitwise_not)
    f_tf = tf.function(jax2tf.convert(f_jax))
    r_jax = f_jax(x)
    r_tf = f_tf(x)
    self.assertAllClose(r_jax[np.isfinite(r_jax)],
                        r_tf[np.isfinite(r_tf)], atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in LAX_ELEMENTWISE_BINARY))
  def test_binary_elementwise(self, f_jax=lax.add):
    a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.2, 1, 1.4, 1.6],
                 dtype=np.float32)
    b = np.array([-1.6, 1.4, 1.0, 0.0, 0.1, 0.2, 1, 1.4, -1.6],
                 dtype=np.float32)
    f_tf = tf.function(jax2tf.convert(f_jax))
    r_jax = f_jax(a, b)
    r_tf = f_tf(a, b)
    # JAX outputs 0 and 1 instead of NaN for values outside the domain,
    # whereas TensorFlow does this for other combinations.
    if f_jax in (lax.igamma, lax.igammac):
      # Make returned array writeable.
      r_jax = np.copy(r_jax)
      r_jax[r_jax == 0] = np.nan
      r_jax[r_jax == 1] = np.nan
      r_tf = np.copy(r_tf)
      r_tf[r_tf == 0] = np.nan
      r_tf[r_tf == 1] = np.nan
    self.assertAllClose(r_jax[np.isfinite(r_jax)],
                        r_tf[np.isfinite(r_tf)], atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
  def test_binary_logical_elementwise(self, f_jax):
    a = np.array([1, 3, 2, 0, 0, 2, 1, 3], dtype=np.uint32)
    b = np.array([1, 2, 3, 0, 1, 0, 2, 3], dtype=np.uint32)
    f_tf = tf.function(jax2tf.convert(f_jax))
    r_jax = f_jax(a, b)
    r_tf = f_tf(a, b)
    self.assertAllClose(r_jax[np.isfinite(r_jax)],
                        r_tf[np.isfinite(r_tf)], atol=1e-4)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in LAX_LOGICAL_ELEMENTWISE_BINARY))
  def test_binary_logical_elementwise_bool(self, f_jax):
    if f_jax == lax.shift_left:
      self.skipTest("Shift of bool not supported")
    a = np.array([0, 0, 1, 1, 0, 0, 1, 1], dtype=np.bool_)
    b = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=np.bool_)
    f_tf = tf.function(jax2tf.convert(f_jax))
    r_jax = f_jax(a, b)
    r_tf = f_tf(a, b)
    self.assertAllClose(r_jax, r_tf)

  # TODO(necula): combine tests that are identical except for the harness;
  # wait until we get more experience with using harnesses.
  @primitive_harness.parameterized(primitive_harness.lax_shift_left)
  def test_shift_left(self, harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  @primitive_harness.parameterized(primitive_harness.lax_shift_right_logical)
  def test_shift_right_logical(self, harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  @primitive_harness.parameterized(
      primitive_harness.lax_shift_right_arithmetic)
  def test_shift_right_arithmetic(self, harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  @primitive_harness.parameterized(primitive_harness.lax_slice)
  def test_slice(self, harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  @primitive_harness.parameterized(primitive_harness.lax_dynamic_slice)
  def test_dynamic_slice(self, harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  @primitive_harness.parameterized(primitive_harness.lax_dynamic_update_slice)
  def test_dynamic_update_slice(self, harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in (lax.betainc,)))
  def test_trinary_elementwise(self, f_jax):
    a = np.array([-1.6, -1.4, -1.0, 0.0, 0.1, 0.3, 1, 1.4, 1.6],
                 dtype=np.float32)
    b = np.array([-1.6, 1.4, 1.0, 0.0, 0.2, 0.1, 1, 1.4, -1.6],
                 dtype=np.float32)
    c = np.array([1.0, -1.0, 2.0, 1.0, 0.3, 0.3, -1.0, 2.4, 1.6],
                 dtype=np.float32)
    f_tf = tf.function(jax2tf.convert(f_jax))
    r_jax = f_jax(a, b, c)
    r_tf = f_tf(a, b, c)
    self.assertAllClose(r_jax[np.isfinite(r_jax)],
                        r_tf[np.isfinite(r_tf)], atol=1e-4)

  @primitive_harness.parameterized(primitive_harness.lax_squeeze)
  def test_squeeze(self, harness: primitive_harness.Harness):
    self.ConvertAndCompare(harness.dyn_fun,
                           *harness.dyn_args_maker(self.rng()),
                           with_function=True)

  def test_gather(self):
    values = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
    indices = np.array([0, 1], dtype=np.int32)
    for axis in (0, 1):
      f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
      self.ConvertAndCompare(f_jax, values, indices, with_function=True)

  def test_boolean_gather(self):
    values = np.array([[True, True], [False, True], [False, False]],
                      dtype=np.bool_)
    indices = np.array([0, 1], dtype=np.int32)
    for axis in [0, 1]:
      f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
      self.ConvertAndCompare(f_jax, values, indices, with_function=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in REDUCE))
  def test_reduce_ops_with_numerical_input(self, f_jax):
    values = [np.array([1, 2, 3], dtype=np.float32)]
    self.ConvertAndCompare(f_jax, values, with_function=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in (jnp.cumsum, jnp.cumprod)))
  def test_cumulated_ops(self, f_jax):
    values = np.array([1, 2, 3], dtype=np.float32)
    self.ConvertAndCompare(f_jax, values, with_function=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{op.__name__}", op=op)
      for op in INDEX))
  def test_scatter_static(self, op):
    values = np.ones((5, 6), dtype=np.float32)
    update = np.float32(6.)
    f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
    self.ConvertAndCompare(f_jax, values, update, with_function=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in REDUCE))
  def test_reduce_ops_with_boolean_input(self, f_jax):
    values = [np.array([True, False, True], dtype=np.bool_)]
    self.ConvertAndCompare(f_jax, values, with_function=True)

  def test_gather_rank_change(self):
    params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
    indices = jnp.array([[1, 1, 2], [0, 1, 0]])
    f_jax = jax.jit(lambda i: params[i])
    self.ConvertAndCompare(f_jax, indices, with_function=True)

  def test_prngsplit(self):
    f_jax = jax.jit(lambda key: jax.random.split(key, 2))
    for rng_key in [jax.random.PRNGKey(42),
                    np.array([0, 0], dtype=np.uint32),
                    np.array([0xFFFFFFFF, 0], dtype=np.uint32),
                    np.array([0, 0xFFFFFFFF], dtype=np.uint32),
                    np.array([0xFFFFFFFF, 0xFFFFFFFF], dtype=np.uint32)]:
      self.ConvertAndCompare(f_jax, rng_key, with_function=True)

  def test_zeros_like(self):
    v = np.float32(2.)
    f_jax = jax.ad_util.zeros_like_jaxval
    self.ConvertAndCompare(f_jax, v)

  def test_stop_gradient(self):
    f = jax2tf.convert(lax.stop_gradient)
    self.assertEqual(f(tf.ones([])), 1.)
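
# For reference (a sketch, not part of the test suite): the conversion
# pattern exercised throughout JaxPrimitiveTest wraps the converted function
# in tf.function before calling it:
#
#   f_tf = tf.function(jax2tf.convert(jnp.sin))
#   f_tf(np.float32(1.0))  # should match jnp.sin(np.float32(1.0))
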
class NumpyLinalgTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  def testCholesky(self, shape, dtype, rng):
    _skip_if_unsupported_type(dtype)
    def args_maker():
      factor_shape = shape[:-1] + (2 * shape[-1],)
      a = rng(factor_shape, dtype)
      return [onp.matmul(a, np.conj(T(a)))]

    if np.issubdtype(dtype, np.complexfloating) and (
        len(shape) > 2 or jtu.device_under_test() != "cpu"):
      self.skipTest("Unimplemented case for complex Cholesky decomposition.")

    self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky,
                            args_maker, check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True)

    if onp.finfo(dtype).bits == 64:
      jtu.check_grads(np.linalg.cholesky, args_maker(), order=2)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}".format(
          jtu.format_shape_dtype_string((n, n), dtype)),
       "n": n, "dtype": dtype, "rng": rng}
      for n in [0, 4, 5, 25]  # TODO(mattjj): complex64 unstable on large sizes?
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  def testDet(self, n, dtype, rng):
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True)

  def testDetOfSingularMatrix(self):
    x = np.array([[-1., 3. / 2], [2. / 3, -1.]], dtype=onp.float32)
    self.assertAllClose(onp.float32(0), jsp.linalg.det(x), check_dtypes=True)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200),
                    (2, 2, 2), (2, 3, 3), (3, 2, 2)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  def testSlogdet(self, shape, dtype, rng):
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng(shape, dtype)]

    self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker,
                            check_dtypes=True, tol=1e-3)
    self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True)

  def testIssue1213(self):
    for n in range(5):
      mat = np.array([onp.diag(onp.ones([5], dtype=onp.float32)) * (-.01)] * 2)
      args_maker = lambda: [mat]
      self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet,
                              args_maker, check_dtypes=True, tol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for GPU/TPU.
  @jtu.skip_on_devices("gpu", "tpu")
  def testEig(self, shape, dtype, rng):
    _skip_if_unsupported_type(dtype)
    n = shape[-1]
    args_maker = lambda: [rng(shape, dtype)]

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
    w, v = np.linalg.eig(a)
    self.assertTrue(
        onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100))

    self._CompileAndCheck(partial(np.linalg.eig), args_maker,
                          check_dtypes=True, rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("gpu", "tpu")
  def testEigBatching(self, shape, dtype, rng):
    _skip_if_unsupported_type(dtype)
    shape = (10,) + shape
    args = rng(shape, dtype)
    ws, vs = vmap(np.linalg.eig)(args)
    self.assertTrue(onp.all(onp.linalg.norm(
        onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_lower={}".format(
          jtu.format_shape_dtype_string((n, n), dtype), lower),
       "n": n, "dtype": dtype, "lower": lower, "rng": rng}
      for n in [0, 4, 5, 50]
      for dtype in float_types + complex_types
      for lower in [False, True]
      for rng in [jtu.rand_default()]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for TPU.
  @jtu.skip_on_devices("tpu")
  def testEigh(self, n, dtype, lower, rng):
    _skip_if_unsupported_type(dtype)
    args_maker = lambda: [rng((n, n), dtype)]

    uplo = "L" if lower else "U"

    # Norm, adjusted for dimension and type.
    def norm(x):
      norm = onp.linalg.norm(x, axis=(-2, -1))
      return norm / ((n + 1) * onp.finfo(dtype).eps)

    a, = args_maker()
    a = (a + onp.conj(a.T)) / 2
    w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
                          UPLO=uplo, symmetrize_input=False)
    self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
    self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

    self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
                          check_dtypes=True, rtol=1e-3)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_lower={}".format(
          jtu.format_shape_dtype_string(shape, dtype), lower),
       "shape": shape, "dtype": dtype, "rng": rng, "lower": lower}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]
      for lower in [True, False]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for TPU.
  @jtu.skip_on_devices("tpu")
  def testEighGrad(self, shape, dtype, rng, lower):
    self.skipTest("Test fails with numeric errors.")
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    # Gradient checks will fail without symmetrization as the eigh jvp rule
    # is only correct for tangents in the symmetric subspace, whereas the
    # checker checks against unconstrained (co)tangents.
    if dtype not in complex_types:
      f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)
    else:  # only check eigenvalue grads for complex matrices
      f = lambda a: partial(np.linalg.eigh, UPLO=uplo,
                            symmetrize_input=True)(a)[0]
    jtu.check_grads(f, (a,), 2, rtol=1e-1)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_lower={}".format(
          jtu.format_shape_dtype_string(shape, dtype), lower),
       "shape": shape, "dtype": dtype, "rng": rng, "lower": lower,
       "eps": eps}
      for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
      for dtype in complex_types
      for rng in [jtu.rand_default()]
      for lower in [True, False]
      for eps in [1e-4]))
  # TODO(phawkins): enable when there is an eigendecomposition implementation
  # for TPU.
  @jtu.skip_on_devices("tpu")
  def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
    _skip_if_unsupported_type(dtype)
    # Special case to test for complex eigenvector grad correctness. Exact
    # eigenvector coordinate gradients are hard to test numerically for
    # complex eigensystem solvers given the extra degrees of per-eigenvector
    # phase freedom. Instead, we numerically verify the eigensystem
    # properties on the perturbed eigenvectors. You only ever want to
    # optimize eigenvector directions, not coordinates!
    uplo = "L" if lower else "U"
    a = rng(shape, dtype)
    a = (a + onp.conj(a.T)) / 2
    a = onp.tril(a) if lower else onp.triu(a)
    a_dot = eps * rng(shape, dtype)
    a_dot = (a_dot + onp.conj(a_dot.T)) / 2
    a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
    # evaluate eigenvector gradient and ground-truth eigensystem for
    # perturbed input matrix
    f = partial(np.linalg.eigh, UPLO=uplo)
    (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
    new_a = a + a_dot
    new_w, new_v = f(new_a)
    new_a = (new_a + onp.conj(new_a.T)) / 2
    # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true
    # eigenvalues.
    RTOL = 1e-2
    assert onp.max(onp.abs(
        (onp.diag(onp.dot(onp.conj((v + dv).T), onp.dot(new_a, (v + dv))))
         - new_w) / new_w)) < RTOL
    # Redundant to the above, but also assert rtol for the eigenvector
    # property with the new true eigenvalues.
    assert onp.max(
        onp.linalg.norm(onp.abs(new_w * (v + dv) - onp.dot(new_a, (v + dv))),
                        axis=0) /
        onp.linalg.norm(onp.abs(new_w * (v + dv)), axis=0)) < RTOL

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(
          jtu.format_shape_dtype_string(shape, dtype)),
       "shape": shape, "dtype": dtype, "rng": rng}
      for shape in [(1, 1), (4, 4), (5, 5)]
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("tpu")
  def testEighBatching(self, shape, dtype, rng):
    _skip_if_unsupported_type(dtype)
    shape = (10,) + shape
    args = rng(shape, dtype)
    args = (args + onp.conj(T(args))) / 2
    ws, vs = vmap(jsp.linalg.eigh)(args)
    self.assertTrue(onp.all(onp.linalg.norm(
        onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
          jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims),
       "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims,
       "ord": ord, "rng": rng}
      for axis, shape in [(None, (1,)), (None, (7,)), (None, (5, 8)),
                          (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)),
                          ((-2,), (4, 8)), (-1, (6, 3)), ((0, 2), (3, 4, 5)),
                          ((2, 0), (7, 8, 9)), (None, (7, 8, 11))]
      for keepdims in [False, True]
      for ord in (
          [None] if axis is None and len(shape) > 2
          else [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf]
          if (axis is None and len(shape) == 1) or isinstance(axis, int) or
             (isinstance(axis, tuple) and len(axis) == 1)
          else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc'])
      for dtype in float_types + complex_types
      for rng in [jtu.rand_default()]))
  def testNorm(self, shape, dtype, ord, axis, keepdims, rng):
    _skip_if_unsupported_type(dtype)
    if (ord in ('nuc', 2, -2) and
        (jtu.device_under_test() != "cpu" or
         (isinstance(axis, tuple) and len(axis) == 2))):
      raise unittest.SkipTest("No adequate SVD implementation available")

    args_maker = lambda: [rng(shape, dtype)]
    onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
    # Older numpy versions promote to float64 unnecessarily.
    check_dtypes = numpy_version >= (1, 15)
    self._CheckAgainstNumpy(onp_fn, np_fn, args_maker,
                            check_dtypes=check_dtypes, tol=1e-3)
    self._CompileAndCheck(np_fn, args_maker, check_dtypes=check_dtypes)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format(
          jtu.format_shape_dtype_string(b + (m, n), dtype), full_matrices,
          compute_uv),
       "b": b, "m": m, "n": n, "dtype": dtype,
       "full_matrices": full_matrices, "compute_uv": compute_uv, "rng": rng}
      for b in [(), (3,), (2, 3)]
      for m in [2, 7, 29, 53]
      for n in [2, 7, 29, 53]
      for dtype in float_types + complex_types
      for full_matrices in [False, True]
      for compute_uv in [False, True]
      for rng in [jtu.rand_default()]))
  @jtu.skip_on_devices("tpu")
  def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, rng):
    _skip_if_unsupported_type(dtype)
    if b != () and jax.lib.version <= (0, 1, 28):
      raise unittest.SkipTest("Batched SVD requires jaxlib 0.1.29")
    args_maker = lambda: [rng(b + (m, n), dtype)]

    # Norm, adjusted for dimension and type.
def norm(x): norm = onp.linalg.norm(x, axis=(-2, -1)) return norm / (max(m, n) * onp.finfo(dtype).eps) a, = args_maker() out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) if compute_uv: # Check the reconstructed matrices if full_matrices: k = min(m, n) if m < n: self.assertTrue( onp.all( norm(a - onp.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :])) < 50)) else: self.assertTrue( onp.all( norm(a - onp.matmul( out[1][..., None, :] * out[0][..., :, :k], out[2])) < 350)) else: self.assertTrue( onp.all( norm(a - onp.matmul(out[1][..., None, :] * out[0], out[2])) < 300)) # Check the unitary properties of the singular vector matrices. self.assertTrue( onp.all( norm( onp.eye(out[0].shape[-1]) - onp.matmul(onp.conj(T(out[0])), out[0])) < 10)) if m >= n: self.assertTrue( onp.all( norm( onp.eye(out[2].shape[-1]) - onp.matmul(onp.conj(T(out[2])), out[2])) < 10)) else: self.assertTrue( onp.all( norm( onp.eye(out[2].shape[-2]) - onp.matmul(out[2], onp.conj(T(out[2])))) < 20)) else: self.assertTrue( onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4)) self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv), args_maker, check_dtypes=True) if not full_matrices: svd = partial(np.linalg.svd, full_matrices=False) jtu.check_jvp(svd, partial(jvp, svd), (a, ), rtol=1e-2, atol=1e-1) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_fullmatrices={}".format( jtu.format_shape_dtype_string(shape, dtype), full_matrices), "shape": shape, "dtype": dtype, "full_matrices": full_matrices, "rng": rng } for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)] for dtype in float_types + complex_types for full_matrices in [False, True] for rng in [jtu.rand_default()])) def testQr(self, shape, dtype, full_matrices, rng): _skip_if_unsupported_type(dtype) if (onp.issubdtype(dtype, onp.complexfloating) and (jtu.device_under_test() == "tpu" or jax.lib.version <= (0, 1, 27))): raise unittest.SkipTest("No complex QR implementation") m, n = shape[-2:] if full_matrices: mode, k = "complete", m else: mode, k = "reduced", min(m, n) a = rng(shape, dtype) lq, lr = np.linalg.qr(a, mode=mode) # onp.linalg.qr doesn't support batch dimensions. But it seems like an # inevitable extension so we support it in our version. nq = onp.zeros(shape[:-2] + (m, k), dtype) nr = onp.zeros(shape[:-2] + (k, n), dtype) for index in onp.ndindex(*shape[:-2]): nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode) max_rank = max(m, n) # Norm, adjusted for dimension and type. def norm(x): n = onp.linalg.norm(x, axis=(-2, -1)) return n / (max_rank * onp.finfo(dtype).eps) def compare_orthogonal(q1, q2): # Q is unique up to sign, so normalize the sign first. sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True) phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios)) q1 *= phases self.assertTrue(onp.all(norm(q1 - q2) < 30)) # Check a ~= qr self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30)) # Compare the first 'k' vectors of Q; the remainder form an arbitrary # orthonormal basis for the null space. compare_orthogonal(nq[..., :k], lq[..., :k]) # Check that q is close to unitary. 
self.assertTrue( onp.all(norm(onp.eye(k) - onp.matmul(onp.conj(T(lq)), lq)) < 5)) if not full_matrices and m >= n: jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a, ), atol=1e-3) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(10, 4, 5), (5, 3, 3), (7, 6, 4)] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testQrBatching(self, shape, dtype, rng): args = rng(shape, np.float32) qs, rs = vmap(jsp.linalg.qr)(args) self.assertTrue( onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3)) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype)), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng } for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4, )), ((8, 8), (8, 4)), ((1, 2, 2), (3, 2)), ((2, 1, 3, 3), (2, 4, 3, 4)), ] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testSolve(self, lhs_shape, rhs_shape, dtype, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)] for dtype in float_types for rng in [jtu.rand_default()])) def testInv(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) if jtu.device_under_test() == "gpu" and shape == (200, 200): raise unittest.SkipTest("Test is flaky on GPU") def args_maker(): invertible = False while not invertible: a = rng(shape, dtype) try: onp.linalg.inv(a) invertible = True except onp.linalg.LinAlgError: pass return [a] self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True) # Regression test for incorrect type for eigenvalues of a complex matrix. @jtu.skip_on_devices("tpu" ) # TODO(phawkins): No eigh implementation on TPU. def testIssue669(self): def test(x): val, vec = np.linalg.eigh(x) return np.real(np.sum(val)) grad_test_jc = jit(grad(jit(test))) xc = onp.eye(3, dtype=onp.complex) self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True) def testIssue1151(self): A = np.array(onp.random.randn(100, 3, 3), dtype=np.float32) b = np.array(onp.random.randn(100, 3), dtype=np.float32) x = np.linalg.solve(A, b) self.assertAllClose(vmap(np.dot)(A, x), b, atol=1e-3, rtol=1e-3, check_dtypes=True) jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A, b) jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A, b) jac0 = jax.jacobian(np.linalg.solve, argnums=0)(A[0], b[0]) jac1 = jax.jacobian(np.linalg.solve, argnums=1)(A[0], b[0])
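# A standalone sketch (not part of the original suite) of the error metric used
# throughout the linalg tests above: residual Frobenius norms are divided by
# `dim * eps(dtype)`, so that thresholds like `< 30` read as "within a few tens
# of ulps" independently of matrix size and precision. The helper name is
# illustrative only; `onp` is assumed to be plain NumPy, as elsewhere in this
# file.
def _scaled_residual_sketch(residual, dim, dtype):
  # Norm of the residual, measured in units of dim * machine epsilon.
  return onp.linalg.norm(residual, axis=(-2, -1)) / (dim * onp.finfo(dtype).eps)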
class HostCallbackTest(jtu.JaxTestCase): def setUp(self): testing_stream.reset() testing_stream.test_method_name = self._testMethodName self.old_flags = os.getenv("XLA_FLAGS", "") def tearDown(self) -> None: if os.getenv("XLA_FLAGS") != self.old_flags: os.environ["XLA_FLAGS"] = self.old_flags xla_bridge.get_backend.cache_clear() hcb.barrier_wait() def helper_set_devices(self, nr_devices): flags_str = os.getenv("XLA_FLAGS", "") os.environ["XLA_FLAGS"] = ( flags_str + " --xla_force_host_platform_device_count={}".format(nr_devices)) # Clear any cached backends so new CPU backend will pick up the env var. xla_bridge.get_backend.cache_clear() return api.devices() def helper_set_hlo_dump(self): flags_str = os.getenv("XLA_FLAGS", "") os.environ["XLA_FLAGS"] = f"{flags_str} --xla_dump_to=/tmp/xla_dump" # Clear any cached backends so new CPU backend will pick up the env var. xla_bridge.get_backend.cache_clear() def test_eval(self): # TODO: re-enable jaxpr golden tests when changing host_callback #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.))) self.assertAllClose((5. * 2.)**2, fun1(5.)) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ what: a * 2 10.00 what: y * 3 30.00""", testing_stream.output) testing_stream.reset() def test_with_tuple_results(self): def func2(x): x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream) return x1 + y1 #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(func2)(3.))) self.assertEqual(3. * (2. + 3.), func2(3.)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ [ 6.00 9.00 ]""", testing_stream.output) testing_stream.reset() def test_with_dict_results(self): def func2(x): res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream) return res["a"] + res["b"] self.assertEqual(3. * (2. + 3.), func2(3.)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ { a=6.00 b=9.00 }""", testing_stream.output) testing_stream.reset() def test_with_result(self): def func2(x): x1 = hcb.id_print((x * 2., x * 3.), result=x * 4., output_stream=testing_stream) return x1 self.assertEqual(3. * 4., func2(3.)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ [ 6.00 9.00 ]""", testing_stream.output) testing_stream.reset() def test_eval_tap_exception(self): # Simulate a tap error def tap_err(*args, **kwargs): raise NotImplementedError def func(x): x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream) x2 = hcb.id_tap(tap_err, x1 + 1, what="err") x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 with self.assertRaises(hcb.TapFunctionException): func(0) hcb.barrier_wait() # We should have received everything before the error assertMultiLineStrippedEqual(self, """ what: x1 1 what: x3 3""", testing_stream.output) testing_stream.reset() def test_jit_simple(self): jit_fun1 = api.jit(lambda x: 3. * hcb.id_print( 2. * x, what="here", output_stream=testing_stream)) self.assertAllClose(6.
* 5., jit_fun1(5.)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ what: here 10.00""", testing_stream.output) testing_stream.reset() def test_jit_constant(self): def func(x): return hcb.id_print(42, result=x, output_stream=testing_stream) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5))) self.assertAllClose(5, api.jit(func)(5)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ 42""", testing_stream.output) testing_stream.reset() def test_jit_sequence1(self): def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) logging.info("%s: %s", self._testMethodName, api.make_jaxpr(func)(1)) logging.info("%s: %s", self._testMethodName, api.xla_computation(func)(1).as_hlo_text()) self.assertEqual(2, api.jit(func)(1)) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ where: 1 1 where: 2 2""", testing_stream.output) testing_stream.reset() def test_jit2(self): """A sequence of JIT.""" def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) return x2 self.assertEqual(2, api.jit(func)(1)) self.assertEqual(11, api.jit(func)(10)) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ where: 1 1 where: 2 2 where: 1 10 where: 2 11""", testing_stream.output) testing_stream.reset() def test_jit_nested(self): def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) def func_nested(x): x2 = hcb.id_print(x + 1, where="nested", output_stream=testing_stream) return x2 x3 = api.jit(func_nested)(x1) return hcb.id_print(x3 + 1, where="3", output_stream=testing_stream) self.assertEqual(3, api.jit(func)(1)) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ where: 1 1 where: nested 2 where: 3 3""", testing_stream.output) testing_stream.reset() def test_jit_devices(self): """Running on multiple devices.""" devices = api.local_devices() logging.info(f"{self._testMethodName}: has devices {devices}") def func(x, device_id): x1 = hcb.id_print(x, dev=str(device_id), output_stream=testing_stream) x2 = hcb.id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) return x2 for d in devices: self.assertEqual( 112, api.jit(func, device=d, static_argnums=1)(111, d.id)) hcb.barrier_wait() logging.info( f"{self._testMethodName}: found output {testing_stream.output}") self.assertEqual(len(devices), len(re.findall(r"111", testing_stream.output))) self.assertEqual(len(devices), len(re.findall(r"112", testing_stream.output))) testing_stream.reset() @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit) for with_jit in [True, False])) def test_pytree(self, with_jit=False): def func(x, what=""): """Returns some pytrees depending on x""" if what == "pair_1_x": return (1, x) elif what == "pair_x_2x": return (x, 2 * x) elif what == "dict": return dict(a=2 * x, b=3 * x) else: assert False tap_count = 0 def tap_func(a, what=""): nonlocal tap_count tap_count += 1 self.assertEqual(func(5, what), a) transform = api.jit if with_jit else lambda f: f for what in ("pair_1_x", "pair_x_2x", "dict"): self.assertEqual( func(10, what), transform(lambda x: hcb.id_tap(tap_func, func(x, what), result=func(x * 2, what), what=what))(5)) hcb.barrier_wait() # Wait for receivers to be done self.assertEqual(3, tap_count) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_concurrent_{concurrent}", 
concurrent=concurrent) for concurrent in [True, False])) def test_multiple_tap(self, concurrent=False): """Call id_tap multiple times, concurrently or in sequence.""" if concurrent and jtu.device_under_test() == "gpu": # TODO(necula): it seems that on GPU if multiple host threads run # a jit computation, the multiple computations are interleaved on the # GPU. This can result in the outfeed trains being interleaved, which # will trigger an error. The solution is to fix on GPU the receiving # logic so that we can outfeed the train as one tuple, and receive it # one piece at a time. Then the trains should be atomic. # See also b/160692602. raise SkipTest("concurrent id_tap not supported on GPU") received = set() count = 5 def pause_tap(idx, **kwargs): received.add(int(idx)) logging.info(f"Starting do_tap {idx}. Sleeping 0.3 sec ...") time.sleep(0.3) logging.info(f"Finish do_tap {idx}") def do_tap(idx): api.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx) if concurrent: threads = [ threading.Thread(name=f"enqueue_tap_{idx}", target=do_tap, args=(idx, )) for idx in range(count) ] [t.start() for t in threads] [t.join() for t in threads] else: for idx in range(count): do_tap(idx) hcb.barrier_wait() self.assertEqual(received, set(range(count))) # TODO(necula): see comment for test_multiple_tap. @jtu.skip_on_devices("gpu") def test_multiple_barriers(self): """Call barrier_wait concurrently.""" def pause_tap(*args, **kwargs): logging.info("pause_tap waiting") time.sleep(0.3) logging.info("pause_tap done") def long_run(x): return hcb.id_tap(pause_tap, x) api.jit(long_run)(5.) def try_barrier(idx): logging.info(f"Starting test barrier {idx}") hcb.barrier_wait() logging.info(f"Finished test barrier {idx}") threads = [ threading.Thread(name=f"barrier_{idx}", target=try_barrier, args=(idx, )) for idx in range(3) ] [t.start() for t in threads] [t.join() for t in threads] @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit) for with_jit in [True, False])) def test_cond(self, with_jit=False): """A conditional""" def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) x4 = lax.cond( x % 2 == 0, lambda x: hcb.id_print( x, where="cond_t", output_stream=testing_stream), lambda x: hcb.id_print( -1, where="cond_f", result=x, output_stream=testing_stream ), x2 + 1) x5 = hcb.id_print(x4 + 1, where="end", output_stream=testing_stream) return x5 transform = api.jit if with_jit else lambda f: f self.assertEqual(4, transform(func)(1)) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ where: 1 1 where: 2 2 where: cond_f -1 where: end 4""", testing_stream.output) testing_stream.reset() @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit) for with_jit in [True, False])) def test_while_cond(self, with_jit=False): def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) def body(x): x3 = hcb.id_print(x, where="w_b_1", output_stream=testing_stream) x4 = lax.cond( x % 2 == 0, lambda x: hcb.id_print( x, where="w_b_t", output_stream=testing_stream), lambda x: hcb.id_print(-1, where="w_b_f", result=x, output_stream=testing_stream), x3 + 1) return hcb.id_print(x4, where="w_b_2", output_stream=testing_stream) x10 = lax.while_loop(lambda x: x <= 3, body, x2) res = hcb.id_print(x10, where="end", output_stream=testing_stream) return
res transform = api.jit if with_jit else lambda f: f self.assertEqual(4, transform(func)(1)) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ where: 1 1 where: 2 2 where: w_b_1 2 where: w_b_t 3 where: w_b_2 3 where: w_b_1 3 where: w_b_f -1 where: w_b_2 4 where: end 4""", testing_stream.output) testing_stream.reset() def test_jit_while_pred_tap(self): """While with printing in the conditional.""" def func(x): x1 = hcb.id_print(x, where="1") x10 = lax.while_loop( lambda x: hcb.id_print( x < 3, where="w_p", output_stream=testing_stream), lambda x: hcb.id_print( x + 1, where="w_b", output_stream=testing_stream), x1) res = hcb.id_print(x10, where="3", output_stream=testing_stream) return res self.assertEqual(3, api.jit(func)(1)) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ where: w_p True where: w_b 2 where: w_p True where: w_b 3 where: w_p False where: 3 3""", testing_stream.output) testing_stream.reset() @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit) for with_jit in [True, False])) def test_scan_cond(self, with_jit=False): def func(x): x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream) def body(c, x): x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream) x4 = lax.cond( x % 2 == 0, lambda x: hcb.id_print( x, where="s_t", output_stream=testing_stream), lambda x: hcb.id_print(-1, where="s_f", result=x, output_stream=testing_stream), x3 + 1) return (c, hcb.id_print(x4, where="s_2", output_stream=testing_stream)) _, x10 = lax.scan(body, x2, jnp.arange(3)) res = hcb.id_print(x10, where="10", output_stream=testing_stream) return res if with_jit: func = api.jit(func) res = func(1) self.assertAllClose(jnp.array([1, 2, 3]), res) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ where: 1 1 where: 2 2 where: s_1 0 where: s_t 1 where: s_2 1 where: s_1 1 where: s_f -1 where: s_2 2 where: s_1 2 where: s_t 3 where: s_2 3 where: 10 [1 2 3]""", testing_stream.output) testing_stream.reset() @parameterized.named_parameters( jtu.cases_from_list( dict( testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}", shape=shape, dtype=dtype, nr_args=nr_args) for nr_args in [1, 2] for shape in [(), (2, ), (2, 3), (2, 3, 4)] for dtype in jtu.dtypes.all)) def test_jit_types(self, nr_args=2, dtype=jnp.int16, shape=(2, )): if dtype in (jnp.complex64, jnp.complex128, jnp.bool_): raise SkipTest(f"id_print jit not implemented for {dtype}.") if jtu.device_under_test() == "tpu": if dtype in (jnp.int16, ): raise SkipTest(f"transferring {dtype} not supported on TPU") args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)] if nr_args > 1: args = args * nr_args jit_fun1 = api.jit(lambda xs: hcb.id_print( xs, a_new_test="************", testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}")) res = jit_fun1(args) self.assertAllClose(args, res) def test_jit_large(self): arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1)) api.jit(hcb.id_print)(arg) def test_jit_several_together(self): arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5)) api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))( arg, jnp.ones(100, dtype=jnp.int32)) def test_jit_interleaving(self): # Several jit's without data dependencies; they may interfere count = 0 # Count tap invocations nr_arrays = 5 def tap_func(arg, **_): nonlocal count assert len(arg) == nr_arrays count += 1 # This is the function that we'll run multiple times def func(x, count):
for i in range(count): x = hcb.id_tap(tap_func, [x + i for i in range(nr_arrays)], i=i)[-1] return x x = jnp.array(1, dtype=np.int32) res = 0 for _ in range(10): # No dependencies between the jit invocations res += api.jit(lambda x: func(x, 10))(x) hcb.barrier_wait() self.assertEqual(100, count) def test_jit_tap_exception(self): # Simulate a tap error def tap_err(*args, **kwargs): raise NotImplementedError def func(x): x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream) x2 = hcb.id_tap(tap_err, x1 + 1, what="err") x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream) return x3 res = api.jit(func)(0) # No error yet with self.assertRaises(hcb.TapFunctionException): hcb.barrier_wait() # Even though the receiver thread raised, the main thread should still # return 3. self.assertEqual(3, res) # We should have received all others assertMultiLineStrippedEqual(self, """ what: x1 1 what: x3 3""", testing_stream.output) testing_stream.reset() def test_jit_nested_cond_no_print(self): """A nested conditional, without any prints""" raise SkipTest("skip this") @api.jit def cfun(x): return lax.cond( lax.lt(x, 2), lambda x: x, lambda x: lax.cond(x < 5, 3, lambda x: x, 4, lambda y: y), x) print(self._testMethodName, api.xla_computation(cfun)(1).as_hlo_text()) cfun(1) def test_while(self): """Executing while, even without JIT uses compiled code""" y = jnp.ones(5) # captured const def func(x): return lax.while_loop( lambda c: c[1] < 5, lambda c: (y, hcb.id_print(c[1], output_stream=testing_stream) + 1), (x, 1)) func(y) hcb.barrier_wait() assertMultiLineStrippedEqual(self, """ 1 2 3 4""", testing_stream.output) testing_stream.reset() def test_jvp(self): jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, )) #assertMultiLineStrippedEqual(self, "") res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1)) self.assertAllClose(100., res_primals, check_dtypes=False) self.assertAllClose(4., res_tangents, check_dtypes=False) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ what: a * 2 10.00 transforms: ({'name': 'jvp'},) what: a * 2 0.20 what: y * 3 30.00 transforms: ({'name': 'jvp'},) what: y * 3 0.60""", testing_stream.output) testing_stream.reset() def test_grad_primal_unused(self): # The output of id_print is not needed for backwards pass def func(x): return 2. * hcb.id_print( x * 3., what="x * 3", output_stream=testing_stream) grad_func = api.grad(func) jaxpr = str(api.make_jaxpr(grad_func)(5.)) # Just making the Jaxpr invokes the id_print once hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ { lambda ; a. let in (6.00,) }""", jaxpr) assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""", testing_stream.output) testing_stream.reset() res_grad = grad_func(jnp.float32(5.)) hcb.barrier_wait() self.assertAllClose(6., res_grad, check_dtypes=False) assertMultiLineStrippedEqual( self, """ what: x * 3 15.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3 2.00""", testing_stream.output) testing_stream.reset() def test_grad_simple(self): def func(x): y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) return x * hcb.id_print( y * 3., what="y * 3", output_stream=testing_stream) grad_func = api.grad(func) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.))) res_grad = grad_func(jnp.float32(5.)) self.assertAllClose(2. * 5. 
* 6., res_grad, check_dtypes=False) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ what: x * 2 10.00 what: y * 3 30.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3 5.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00""", testing_stream.output) testing_stream.reset() def test_grad_double(self): def func(x): y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream) return x * (y * 3.) grad_func = api.grad(api.grad(func)) # Just making the Jaxpr invokes the id_print twice _ = api.make_jaxpr(grad_func)(5.) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00""", testing_stream.output) testing_stream.reset() res_grad = grad_func(jnp.float32(5.)) self.assertAllClose(12., res_grad, check_dtypes=False) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ what: x * 2 10.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00""", testing_stream.output) testing_stream.reset() def test_vmap(self): vmap_fun1 = api.vmap(fun1) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs))) vmap_fun1(vargs) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2 [ 8.00 10.00] transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3 [24.00 30.00]""", testing_stream.output) testing_stream.reset() def test_vmap_not_batched(self): x = 3. 
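    # In the expected tap output below, batch_dims (None, 0) records that the
    # first tuple element (the closed-over constant x) is not batched, while
    # the second (y) is mapped along axis 0.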
def func(y): # x is not mapped, y is mapped _, y = hcb.id_print((x, y), output_stream=testing_stream) return x + y vmap_func = api.vmap(func) vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)]) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs))) _ = vmap_func(vargs) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},) [ 3.00 [4.00 5.00] ]""", testing_stream.output) testing_stream.reset() def test_double_vmap(self): # A 2D tensor with x[i, j] = i + j using 2 vmap def sum(x, y): return hcb.id_print(x + y, output_stream=testing_stream) def sum_rows(xv, y): return api.vmap(sum, in_axes=(0, None))(xv, y) def sum_all(xv, yv): return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv) xv = jnp.arange(5, dtype=np.int32) yv = jnp.arange(3, dtype=np.int32) #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(sum_all)(xv, yv))) _ = sum_all(xv, yv) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)}) [[0 1 2 3 4] [1 2 3 4 5] [2 3 4 5 6]]""", testing_stream.output) testing_stream.reset() def test_vmap_while(self): """Vmap of while.""" def func(x): # like max(x, 2) x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x2 = lax.while_loop( lambda x: x < 2, lambda x: hcb.id_print( x + 1, where="w_b", output_stream=testing_stream), x1) res = hcb.id_print(x2, where="3", output_stream=testing_stream) return res inputs = np.arange(5, dtype=np.int32) self.assertAllClose(np.array([2, 2, 2, 3, 4]), api.jit(api.vmap(func))(inputs), check_dtypes=False) hcb.barrier_wait() assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1 [0 1 2 3 4] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b [1 2 3 4 5] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b [2 3 3 4 5] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3 [2 2 2 3 4]""", testing_stream.output) testing_stream.reset() def test_vmap_while_tap_cond(self): """Vmap of while, with a tap in the conditional.""" def func(x): # like max(x, 2) x1 = hcb.id_print(x, where="1", output_stream=testing_stream) x2 = lax.while_loop( lambda x: hcb.id_print( x < 2, where="w_c", output_stream=testing_stream), lambda x: hcb.id_print( x + 1, where="w_b", output_stream=testing_stream), x1) res = hcb.id_print(x2, where="3", output_stream=testing_stream) return res inputs = np.arange(5, dtype=np.int32) res = api.jit(api.vmap(func))(inputs) hcb.barrier_wait() self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False) assertMultiLineStrippedEqual( self, """ transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1 [0 1 2 3 4] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c [ True True False False False] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b [1 2 3 4 5] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c [ True False False False False] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b [2 3 3 4 5] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c [False False False False False] transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3 [2 2 2 3 4]""", testing_stream.output) testing_stream.reset() def test_pmap(self): vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32) pmap_fun1 = api.pmap(fun1, axis_name="i") res = pmap_fun1(vargs) hcb.barrier_wait() expected_res = jnp.stack( [fun1_equiv(2. 
+ a) for a in range(api.local_device_count())]) self.assertAllClose(expected_res, res, check_dtypes=False) def test_mask(self): # TODO(necula) raise SkipTest("masking has regressed") @functools.partial(api.mask, in_shapes=['n'], out_shape='') def padded_sum(x): return jnp.sum( hcb.id_print(x, what="x", output_stream=testing_stream)) args = [jnp.arange(4)], dict(n=np.int64(2)) assertMultiLineStrippedEqual( self, """ { lambda c f ; a b. let d = lt c b e = id_tap[ func=_print logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)] transforms=('mask',) what=x ] a g = select d e f h = reduce_sum[ axes=(0,) ] g in (h,) }""", str(api.make_jaxpr(padded_sum)(*args))) _ = padded_sum(*args) self.assertMultiLineStrippedEqual( """ logical_shapes: [(2,)] transforms: ('mask',) what: x [0 1 2 3] """, testing_stream.output) testing_stream.reset() def test_outfeed_receiver(self): """Test the deprecated outfeed_receiver""" with hcb.outfeed_receiver(): self.assertAllClose((5. * 2.)**2, fun1(5.), check_dtypes=True) assertMultiLineStrippedEqual( self, """ what: a * 2 10.00 what: y * 3 30.00""", testing_stream.output) testing_stream.reset() def test_callback_delay(self): hcb.callback_extra = lambda dev: time.sleep(1) def func(x): for i in range(5): x = hcb.id_print(x * i, what="x times i") return x api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) def test_callback_delay_barrier(self): hcb.callback_extra = lambda dev: time.sleep(2) def func(x): for i in range(1, 4): x = hcb.id_print(x * i, what="x times i", output_stream=testing_stream) return x api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) # Wait for the results hcb.barrier_wait() expected = """ what: x times i [[0. 1. 2.] [3. 4. 5.]] what: x times i [[ 0. 2. 4.] [ 6. 8. 10.]] what: x times i [[ 0. 6. 12.] [18. 24. 30.]]""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() # Call again api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3))) hcb.barrier_wait() self.assertMultiLineStrippedEqual(expected, testing_stream.output) def test_error_bad_consumer_id(self): """Try to use reserved consumer ID 0. Check that we get the proper error from the runtime.""" comp = xla_bridge.make_computation_builder(self._testMethodName) token = hcb.xops.CreateToken(comp) hcb._initialize_outfeed_receiver() # Needed if this is the sole test with self.assertRaisesRegex( RuntimeError, "Consumer ID cannot be a reserved value: 0"): hcb._outfeed_receiver.receiver.add_outfeed(comp, token, 0, [ xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32)) ]) def test_error_different_shapes(self): """Try to register different shapes for the same consumer ID.""" comp = xla_bridge.make_computation_builder(self._testMethodName) token = hcb.xops.CreateToken(comp) hcb._initialize_outfeed_receiver() # Needed if this is the sole test hcb._outfeed_receiver.receiver.add_outfeed( comp, token, 123, [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.float32))]) with self.assertRaisesRegex( RuntimeError, ".*does not match previous shape element_type.*"): hcb._outfeed_receiver.receiver.add_outfeed( comp, token, 123, [xla_bridge.constant(comp, np.zeros((2, 3), dtype=np.int32))]) with self.assertRaisesRegex( RuntimeError, ".*does not match previous shape element_type.*"): hcb._outfeed_receiver.receiver.add_outfeed( comp, token, 123, [xla_bridge.constant(comp, np.zeros((2, ), dtype=np.float32))])
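# A standalone sketch (not part of the original suite) of the birthday-problem
# identity behind _CheckCollisions in the class below: after n uniform draws
# from B = 2**nbits bins, a given bin is occupied with probability
# 1 - ((B - 1) / B)**n, so the expected number of distinct sampled values is
# B * (1 - ((B - 1) / B)**n). The helper name is illustrative only.
def _expected_distinct_values_sketch(n_items, nbits):
  nbins = 2**nbits
  return nbins * (1. - ((nbins - 1.) / nbins)**n_items)
# For example, _expected_distinct_values_sketch(10000, 32) is ~9999.99: with
# 2**32 bins, 10000 draws almost surely land in distinct bins.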
class LaxRandomTest(jtu.JaxTestCase): def _CheckCollisions(self, samples, nbits): fail_prob = 0.01 # conservative bound on statistical fail prob by Chebyshev nitems = len(samples) nbins = 2**nbits nexpected = nbins * (1 - ((nbins - 1) / nbins)**nitems) ncollisions = len(np.unique(samples)) sq_percent_deviation = ((ncollisions - nexpected) / nexpected)**2 self.assertLess(sq_percent_deviation, 1 / np.sqrt(nexpected * fail_prob)) def _CheckKolmogorovSmirnovCDF(self, samples, cdf): fail_prob = 0.01 # conservative bound on statistical fail prob by Kolmo CDF self.assertGreater(scipy.stats.kstest(samples, cdf).pvalue, fail_prob) def _CheckChiSquared(self, samples, pmf): alpha = 0.01 # significance level, threshold for p-value values, actual_freq = np.unique(samples, return_counts=True) expected_freq = pmf(values) * samples.size # per scipy: "A typical rule is that all of the observed and expected # frequencies should be at least 5." valid = (actual_freq > 5) & (expected_freq > 5) self.assertGreater( valid.sum(), 1, msg='not enough valid frequencies for chi-squared test') _, p_value = scipy.stats.chisquare(actual_freq[valid], expected_freq[valid]) self.assertGreater(p_value, alpha, msg=f'Failed chi-squared test with p={p_value}.\n' 'Expected vs. actual frequencies:\n' f'{expected_freq[valid]}\n{actual_freq[valid]}') @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in jtu.dtypes.floating)) def testNumpyAndXLAAgreeOnFloatEndianness(self, dtype): bits_dtype = np.uint32 if jnp.finfo(dtype).bits == 32 else np.uint64 numpy_bits = np.array(1., dtype).view(bits_dtype) xla_bits = api.jit(lambda: lax.bitcast_convert_type( np.array(1., dtype), bits_dtype))() self.assertEqual(numpy_bits, xla_bits) def testThreefry2x32(self): # We test the hash by comparing to known values provided in the test code of # the original reference implementation of Threefry. For the values, see # https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32 def result_to_hex(result): return tuple([hex(x.copy()).rstrip("L") for x in result]) expected = ("0x6b200159", "0x99ba4efe") result = random.threefry_2x32(np.uint32([0, 0]), np.uint32([0, 0])) self.assertEqual(expected, result_to_hex(result)) expected = ("0x1cb996fc", "0xbb002be7") result = random.threefry_2x32(np.uint32([-1, -1]), np.uint32([-1, -1])) self.assertEqual(expected, result_to_hex(result)) expected = ("0xc4923a9c", "0x483df7a0") result = random.threefry_2x32(np.uint32([0x13198a2e, 0x03707344]), np.uint32([0x243f6a88, 0x85a308d3])) self.assertEqual(expected, result_to_hex(result)) def testThreefry2x32Large(self): n = 10000000 result = random.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.concatenate([ jnp.full((n, ), 0x243f6a88, jnp.uint32), jnp.full((n, ), 0x85a308d3, jnp.uint32) ])) np.testing.assert_equal(result[:n], np.full((n, ), 0xc4923a9c, dtype=np.uint32)) np.testing.assert_equal(result[n:], np.full((n, ), 0x483df7a0, dtype=np.uint32)) def testThreefry2x32Empty(self): # Regression test for an op-by-op crash for empty arrays in CUDA mode. with api.disable_jit(): result = random.threefry_2x32( (np.uint32(0x13198a2e), np.uint32(0x03707344)), jnp.ones(( 10, 0, ), jnp.uint32)) np.testing.assert_equal(result, np.zeros(( 10, 0, ), dtype=np.uint32)) def testRngRandomBitsViewProperty(self): # TODO: add 64-bit if it ever supports this property. 
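    # The property under test: the raw bit stream produced by _random_bits is
    # width-independent, so drawing N * 64 bits as uint8, uint16, or uint32
    # and reinterpreting each buffer as uint32 should yield identical arrays.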
# TODO: will this property hold across endian-ness? N = 10 key = random.PRNGKey(1701) nbits = [8, 16, 32] rand_bits = [ jax._src.random._random_bits(key, n, (N * 64 // n, )) for n in nbits ] rand_bits_32 = np.array( [np.array(r).view(np.uint32) for r in rand_bits]) assert np.all(rand_bits_32 == rand_bits_32[0]) def testRngRandomBits(self): # Test specific outputs to ensure consistent random values between JAX versions. key = random.PRNGKey(1701) bits8 = jax._src.random._random_bits(key, 8, (3, )) expected8 = np.array([216, 115, 43], dtype=np.uint8) self.assertArraysEqual(bits8, expected8) bits16 = jax._src.random._random_bits(key, 16, (3, )) expected16 = np.array([41682, 1300, 55017], dtype=np.uint16) self.assertArraysEqual(bits16, expected16) bits32 = jax._src.random._random_bits(key, 32, (3, )) expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32) self.assertArraysEqual(bits32, expected32) with jtu.ignore_warning(category=UserWarning, message="Explicitly requested dtype.*"): bits64 = jax._src.random._random_bits(key, 64, (3, )) if FLAGS.jax_enable_x64: expected64 = np.array([ 3982329540505020460, 16822122385914693683, 7882654074788531506 ], dtype=np.uint64) else: expected64 = np.array([676898860, 3164047411, 4010691890], dtype=np.uint32) self.assertArraysEqual(bits64, expected64) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testRngUniform(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.uniform(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckCollisions(samples, jnp.finfo(dtype).nmant) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.uniform().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in int_dtypes + uint_dtypes)) def testRngRandint(self, dtype): lo = 5 hi = 10 key = random.PRNGKey(0) rand = lambda key: random.randint(key, (10000, ), lo, hi, dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self.assertTrue(np.all(lo <= samples)) self.assertTrue(np.all(samples < hi)) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testNormal(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.normal(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.norm().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testTruncatedNormal(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.truncated_normal(key, -0.3, 0.3, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) min_val = np.min(uncompiled_samples) max_val = np.max(uncompiled_samples) self.assertTrue(min_val > -0.3) self.assertTrue(max_val < 0.3) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF( samples, scipy.stats.truncnorm(-0.3, 0.3).cdf) @parameterized.named_parameters( 
jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in jtu.dtypes.floating + jtu.dtypes.integer)) def testShuffle(self, dtype): key = random.PRNGKey(0) x = np.arange(100).astype(dtype) rand = lambda key: random.shuffle(key, x) crand = api.jit(rand) with self.assertWarns(FutureWarning): perm1 = rand(key) with self.assertWarns(FutureWarning): perm2 = crand(key) self.assertAllClose(perm1, perm2) self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1), x, check_dtypes=False) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}_shape={}_replace={}_weighted={}_array_input={}".format( np.dtype(dtype).name, shape, replace, weighted, array_input), "dtype": dtype, "shape": shape, "replace": replace, "weighted": weighted, "array_input": array_input } for dtype in jtu.dtypes.floating + jtu.dtypes.integer for shape in [(), (5, ), (4, 5)] for replace in [True, False] for weighted in [True, False] for array_input in [False, 'jnp', 'np'])) def testChoice(self, dtype, shape, replace, weighted, array_input): N = 100 key = random.PRNGKey(0) x = (N if not array_input else jnp.arange(N, dtype=dtype) if array_input == 'jnp' else np.arange(N, dtype=dtype)) p = None if not weighted else jnp.arange(N) rand = lambda key: random.choice(key, x, shape, p=p, replace=replace) crand = api.jit(rand) sample1 = rand(key) sample2 = crand(key) self.assertEqual(shape, sample1.shape) if array_input == 'jnp': self.assertEqual(x.dtype, sample1.dtype) if not replace: assert len(np.unique(sample1)) == len(np.ravel(sample1)) self.assertAllClose(sample1, sample2) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)), "dtype": dtype, "shape": shape } for dtype in jtu.dtypes.floating + jtu.dtypes.integer for shape in [100, (10, 10), (10, 5, 2)])) def testPermutationArray(self, dtype, shape): key = random.PRNGKey(0) x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype) rand = lambda key: random.permutation(key, x) crand = api.jit(rand) perm1 = rand(key) perm2 = crand(key) self.assertAllClose(perm1, perm2) self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False) self.assertArraysAllClose( x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)) def testPermutationInteger(self): key = random.PRNGKey(0) x = 100 rand = lambda key: random.permutation(key, x) crand = api.jit(rand) perm1 = rand(key) perm2 = crand(key) self.assertAllClose(perm1, perm2) self.assertEqual(perm1.dtype, perm2.dtype) self.assertFalse(np.all(perm1 == np.arange(100))) # seems unlikely! self.assertAllClose(np.sort(perm1), np.arange(100), check_dtypes=False) def testPermutationErrors(self): key = random.PRNGKey(0) with self.assertRaises(TypeError): random.permutation(key, 10.) 
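    # Under jit the integer argument becomes an abstract tracer, and
    # random.permutation needs a concrete value to build the output shape,
    # hence the ConcretizationTypeError below.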
with self.assertRaises(core.ConcretizationTypeError): api.jit(random.permutation)(key, 10) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_p={}_dtype={}".format( p, np.dtype(dtype).name), "p": p, "dtype": dtype } for p in [0.1, 0.5, 0.9] for dtype in jtu.dtypes.floating)) def testBernoulli(self, p, dtype): key = random.PRNGKey(0) p = np.array(p, dtype=dtype) rand = lambda key, p: random.bernoulli(key, p, (10000, )) crand = api.jit(rand) uncompiled_samples = rand(key, p) compiled_samples = crand(key, p) for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.bernoulli(p).pmf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_p={}_{}_{}".format(p, np.dtype(dtype).name, sample_shape), "p": p, "axis": axis, "dtype": dtype, 'sample_shape': sample_shape } for (p, axis) in [ ([.25] * 4, -1), ([.1, .2, .3, .4], -1), ([[.5, .5], [.1, .9]], 1), ([[.5, .1], [.5, .9]], 0), ] for sample_shape in [(10000, ), (5000, 2)] for dtype in jtu.dtypes.floating)) def testCategorical(self, p, axis, dtype, sample_shape): key = random.PRNGKey(0) p = np.array(p, dtype=dtype) logits = np.log(p) - 42 # test unnormalized out_shape = tuple(np.delete(logits.shape, axis)) shape = sample_shape + out_shape rand = partial(random.categorical, shape=shape, axis=axis) crand = api.jit(rand) uncompiled_samples = rand(key, logits) compiled_samples = crand(key, logits) if axis < 0: axis += len(logits.shape) for samples in [uncompiled_samples, compiled_samples]: assert samples.shape == shape samples = jnp.reshape(samples, (10000, ) + out_shape) if len(p.shape[:-1]) > 0: ps = np.transpose(p, (1, 0)) if axis == 0 else p for cat_samples, cat_p in zip(samples.transpose(), ps): self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x]) else: self._CheckChiSquared(samples, pmf=lambda x: p[x]) def testBernoulliShape(self): key = random.PRNGKey(0) x = random.bernoulli(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_a={}_b={}_dtype={}".format(a, b, np.dtype(dtype).name), "a": a, "b": b, "dtype": dtype } for a in [0.2, 5.] for b in [0.2, 5.] 
for dtype in [np.float64]) ) # NOTE: KS test fails with float32 def testBeta(self, a, b, dtype): if not FLAGS.jax_enable_x64: raise SkipTest("skip test except on X64") key = random.PRNGKey(0) rand = lambda key, a, b: random.beta(key, a, b, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, a, b) compiled_samples = crand(key, a, b) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testCauchy(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.cauchy(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.cauchy().cdf) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_alpha={}_dtype={}".format(alpha, np.dtype(dtype).name), "alpha": alpha, "dtype": dtype } for alpha in [ np.array([0.2, 1., 5.]), ] for dtype in jtu.dtypes.floating)) @jtu.skip_on_devices("tpu") # TODO(mattjj): slow compilation times def testDirichlet(self, alpha, dtype): key = random.PRNGKey(0) rand = lambda key, alpha: random.dirichlet(key, alpha, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, alpha) compiled_samples = crand(key, alpha) for samples in [uncompiled_samples, compiled_samples]: self.assertAllClose(samples.sum(-1), np.ones(10000, dtype=dtype)) alpha_sum = sum(alpha) for i, a in enumerate(alpha): self._CheckKolmogorovSmirnovCDF( samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testExponential(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.exponential(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_a={}_dtype={}".format( a, np.dtype(dtype).name), "a": a, "dtype": dtype } for a in [0.1, 1., 10.] 
for dtype in jtu.dtypes.floating)) def testGamma(self, a, dtype): key = random.PRNGKey(0) rand = lambda key, a: random.gamma(key, a, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, a) compiled_samples = crand(key, a) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gamma(a).cdf) def testGammaShape(self): key = random.PRNGKey(0) x = random.gamma(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_a={}".format(alpha), "alpha": alpha } for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4])) def testGammaGrad(self, alpha): rng = random.PRNGKey(0) alphas = np.full((100, ), alpha) z = random.gamma(rng, alphas) actual_grad = api.grad(lambda x: random.gamma(rng, x).sum())(alphas) eps = 0.01 * alpha / (1.0 + np.sqrt(alpha)) cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps) - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps) pdf = scipy.stats.gamma.pdf(z, alpha) expected_grad = -cdf_dot / pdf self.assertAllClose( actual_grad, expected_grad, check_dtypes=True, rtol=2e-2 if jtu.device_under_test() == "tpu" else 7e-4) def testGammaGradType(self): # Regression test for https://github.com/google/jax/issues/2130 key = random.PRNGKey(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) f = lambda x, y: random.gamma(key=key, a=x, dtype=jnp.float32) / y # Should not crash with a type error. api.vjp(f, a, b) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lam={}_dtype={}".format(lam, np.dtype(dtype).name), "lam": lam, "dtype": np.dtype(dtype) } for lam in [0.5, 3, 9, 11, 50, 500] for dtype in [np.int16, np.int32, np.int64])) def testPoisson(self, lam, dtype): key = random.PRNGKey(0) rand = lambda key, lam: random.poisson(key, lam, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, lam) compiled_samples = crand(key, lam) for samples in [uncompiled_samples, compiled_samples]: self._CheckChiSquared(samples, scipy.stats.poisson(lam).pmf) # TODO(shoyer): determine error bounds for moments more rigorously (e.g., # based on the central limit theorem). 
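    # For Poisson(lam), E[X] = Var[X] = lam; with n = 10000 samples the
    # standard error of the sample mean is sqrt(lam / n), which motivates the
    # rtol choices below (the key is fixed, so the check is deterministic).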
self.assertAllClose(samples.mean(), lam, rtol=0.01, check_dtypes=False) self.assertAllClose(samples.var(), lam, rtol=0.03, check_dtypes=False) def testPoissonBatched(self): key = random.PRNGKey(0) lam = jnp.concatenate([2 * jnp.ones(10000), 20 * jnp.ones(10000)]) samples = random.poisson(key, lam, shape=(20000, )) self._CheckChiSquared(samples[:10000], scipy.stats.poisson(2.0).pmf) self._CheckChiSquared(samples[10000:], scipy.stats.poisson(20.0).pmf) def testPoissonShape(self): key = random.PRNGKey(0) x = random.poisson(key, np.array([2.0, 20.0]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in jtu.dtypes.floating)) def testGumbel(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.gumbel(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.gumbel_r().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testLaplace(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.laplace(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.laplace().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype } for dtype in float_dtypes)) def testLogistic(self, dtype): key = random.PRNGKey(0) rand = lambda key: random.logistic(key, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key) compiled_samples = crand(key) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.logistic().cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_b={}_dtype={}".format( b, np.dtype(dtype).name), "b": b, "dtype": dtype } for b in [0.1, 1., 10.] for dtype in jtu.dtypes.floating)) def testPareto(self, b, dtype): key = random.PRNGKey(0) rand = lambda key, b: random.pareto(key, b, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, b) compiled_samples = crand(key, b) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.pareto(b).cdf) def testParetoShape(self): key = random.PRNGKey(0) x = random.pareto(key, np.array([0.2, 0.3]), shape=(3, 2)) assert x.shape == (3, 2) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_df={}_dtype={}".format( df, np.dtype(dtype).name), "df": df, "dtype": dtype } for df in [0.1, 1., 10.] 
for dtype in jtu.dtypes.floating)) @jtu.skip_on_devices("cpu", "tpu") # TODO(phawkins): slow compilation times def testT(self, df, dtype): key = random.PRNGKey(0) rand = lambda key, df: random.t(key, df, (10000, ), dtype) crand = api.jit(rand) uncompiled_samples = rand(key, df) compiled_samples = crand(key, df) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.t(df).cdf) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dim={}_dtype={}".format( dim, np.dtype(dtype)), "dim": dim, "dtype": dtype } for dim in [1, 3, 5] for dtype in float_dtypes)) def testMultivariateNormal(self, dim, dtype): r = np.random.RandomState(dim) mean = r.randn(dim) cov_factor = r.randn(dim, dim) cov = np.dot(cov_factor, cov_factor.T) + dim * np.eye(dim) key = random.PRNGKey(0) rand = partial(random.multivariate_normal, mean=mean, cov=cov, shape=(10000, )) crand = api.jit(rand) uncompiled_samples = np.asarray(rand(key), np.float64) compiled_samples = np.asarray(crand(key), np.float64) inv_scale = scipy.linalg.lapack.dtrtri(np.linalg.cholesky(cov), lower=True)[0] for samples in [uncompiled_samples, compiled_samples]: centered = samples - mean whitened = np.einsum('nj,ij->ni', centered, inv_scale) # This is a quick-and-dirty multivariate normality check that tests that a # uniform mixture of the marginals along the covariance matrix's # eigenvectors follows a standard normal distribution. self._CheckKolmogorovSmirnovCDF(whitened.ravel(), scipy.stats.norm().cdf) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_dim={}_mean_batch_size={}_cov_batch_size={}_shape={}"\ .format(dim, mean_batch_size, cov_batch_size, shape), "dim": dim, "mean_batch_size": mean_batch_size, "cov_batch_size": cov_batch_size, "shape": shape} for dim in [1, 2, 4] for mean_batch_size in [(), (3,), (2, 3)] for cov_batch_size in [(), (3,), (2, 3)] for shape in [(), (1,), (5,)])) def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, shape): r = np.random.RandomState(0) key = random.PRNGKey(0) eff_batch_size = mean_batch_size \ if len(mean_batch_size) > len(cov_batch_size) else cov_batch_size mean = r.randn(*(mean_batch_size + (dim, ))) cov_factor = r.randn(*(cov_batch_size + (dim, dim))) cov = np.einsum('...ij,...kj->...ik', cov_factor, cov_factor) cov += 1e-3 * np.eye(dim) shape = shape + eff_batch_size samples = random.multivariate_normal(key, mean, cov, shape=shape) assert samples.shape == shape + (dim, ) def testMultivariateNormalCovariance(self): # test code based on https://github.com/google/jax/issues/1869 N = 100000 cov = jnp.array([[0.19, 0.00, -0.13, 0.00], [0.00, 0.29, 0.00, -0.23], [-0.13, 0.00, 0.39, 0.00], [0.00, -0.23, 0.00, 0.49]]) mean = jnp.zeros(4) out_np = np.random.RandomState(0).multivariate_normal(mean, cov, N) key = random.PRNGKey(0) out_jnp = random.multivariate_normal(key, mean=mean, cov=cov, shape=(N, )) var_np = out_np.var(axis=0) var_jnp = out_jnp.var(axis=0) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) var_np = np.cov(out_np, rowvar=False) var_jnp = np.cov(out_jnp, rowvar=False) self.assertAllClose(var_np, var_jnp, rtol=1e-2, atol=1e-2, check_dtypes=False) def testIssue222(self): x = random.randint(random.PRNGKey(10003), (), 0, 0) assert x == 0 def testFoldIn(self): key = random.PRNGKey(0) keys = [random.fold_in(key, i) for i in range(10)] assert np.unique(np.ravel(keys)).shape == (20, ) def testStaticShapeErrors(self): if config.read("jax_disable_jit"): raise
SkipTest("test only relevant when jit enabled") @api.jit def feature_map(n, d, sigma=1.0, seed=123): key = random.PRNGKey(seed) W = random.normal(key, (d, n)) / sigma w = random.normal(key, (d, )) / sigma b = 2 * jnp.pi * random.uniform(key, (d, )) phi = lambda x, t: jnp.sqrt(2.0 / d) * jnp.cos( jnp.matmul(W, x) + w * t + b) return phi self.assertRaisesRegex(TypeError, 'Shapes must be 1D.*', lambda: feature_map(5, 3)) def testIssue756(self): key = random.PRNGKey(0) w = random.normal(key, ()) if FLAGS.jax_enable_x64: self.assertEqual(np.result_type(w), np.float64) else: self.assertEqual(np.result_type(w), np.float32) def testIssue1789(self): def f(x): return random.gamma(random.PRNGKey(0), x) grad(lambda x: jnp.sum(vmap(f)(x)))(jnp.ones(2)) def testNoOpByOpUnderHash(self): def fail(*args, **kwargs): assert False apply_primitive, xla.apply_primitive = xla.apply_primitive, fail try: _ = random.threefry_2x32(np.zeros(2, np.uint32), np.arange(10, dtype=np.uint32)) finally: xla.apply_primitive = apply_primitive def testPRNGValues(self): # Test to ensure consistent random values between JAX versions k = random.PRNGKey(0) if FLAGS.jax_enable_x64: self.assertAllClose( random.randint(k, (3, 3), 0, 8), np.array([[7, 2, 6], [2, 1, 0], [6, 7, 7]], dtype='int64')) else: self.assertAllClose( random.randint(k, (3, 3), 0, 8), np.array([[2, 1, 3], [6, 1, 5], [6, 3, 4]], dtype='int32')) self.assertAllClose( random.split(k, 4), np.array([[2285895361, 1501764800], [1518642379, 4090693311], [433833334, 4221794875], [839183663, 3740430601]], dtype='uint32')) self.assertAllClose(random.fold_in(k, 4), np.array([2285895361, 433833334], dtype='uint32')) def testDtypeErrorMessage(self): with self.assertRaisesRegex(ValueError, r"dtype argument to.*"): random.normal(random.PRNGKey(0), (), dtype=jnp.int32) def testRandomBroadcast(self): """Issue 4033""" # test for broadcast issue in https://github.com/google/jax/issues/4033 key = random.PRNGKey(0) shape = (10, 2) x = random.uniform(key, shape, minval=jnp.zeros(2), maxval=jnp.ones(2)) assert x.shape == shape x = random.randint(key, shape, jnp.array([0, 1]), jnp.array([1, 2])) assert x.shape == shape def testMaxwellSample(self): num_samples = 10**5 rng = random.PRNGKey(0) rand = lambda x: random.maxwell(x, (num_samples, )) crand = api.jit(rand) loc = scipy.stats.maxwell.mean() std = scipy.stats.maxwell.std() uncompiled_samples = rand(rng) compiled_samples = crand(rng) for samples in [uncompiled_samples, compiled_samples]: # Check first and second moments. self.assertEqual((num_samples, ), samples.shape) self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1) self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1) self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.maxwell().cdf) @parameterized.named_parameters(('test1', 4.0, 1.0), ('test2', 2.0, 3.0)) def testWeibullSample(self, concentration, scale): num_samples = 10**5 rng = random.PRNGKey(0) rand = lambda x: random.weibull_min(x, scale, concentration, (num_samples, )) crand = api.jit(rand) loc = scipy.stats.weibull_min.mean(c=concentration, scale=scale) std = scipy.stats.weibull_min.std(c=concentration, scale=scale) uncompiled_samples = rand(rng) compiled_samples = crand(rng) for samples in [uncompiled_samples, compiled_samples]: # Check first and second moments. 
self.assertEqual((num_samples, ), samples.shape) self.assertAllClose(np.mean(samples), loc, atol=0., rtol=0.1) self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1) self._CheckKolmogorovSmirnovCDF( samples, scipy.stats.weibull_min(c=concentration, scale=scale).cdf) @parameterized.named_parameters(('test1', 4.0, 1.0), ('test2', 2.0, 3.0)) def testDoublesidedMaxwellSample(self, loc, scale): num_samples = 10**5 rng = random.PRNGKey(0) rand = lambda key: random.double_sided_maxwell(key, loc, scale, (num_samples, )) crand = api.jit(rand) mean = loc std = np.sqrt(3.) * scale uncompiled_samples = rand(rng) compiled_samples = crand(rng) # Compute the double-sided maxwell CDF through the one-sided maxwell CDF. # This is done as follows: # P(DSM <= x) = P(loc + scale * rademacher_sample * one_sided_sample <= x) = # P(rademacher_sample * one_sided_sample <= (x - loc) / scale) = # 1/2 P(one_sided_sample <= (x - loc) / scale) # + 1/2 P( - one_sided_sample <= (x - loc) / scale) = # 1/2 P(one_sided_sample <= (x - loc) / scale) # + 1/2 P(one_sided_sample >= - (x - loc) / scale) = # 1/2 CDF_one_maxwell((x - loc) / scale) # + 1/2 (1 - CDF_one_maxwell(- (x - loc) / scale)) def double_sided_maxwell_cdf(x, loc, scale): pos = scipy.stats.maxwell().cdf((x - loc) / scale) neg = (1 - scipy.stats.maxwell().cdf((-x + loc) / scale)) return (pos + neg) / 2 for samples in [uncompiled_samples, compiled_samples]: # Check first and second moments. self.assertEqual((num_samples, ), samples.shape) self.assertAllClose(np.mean(samples), mean, atol=0., rtol=0.1) self.assertAllClose(np.std(samples), std, atol=0., rtol=0.1) self._CheckKolmogorovSmirnovCDF( samples, lambda x: double_sided_maxwell_cdf(x, loc, scale)) def testRademacher(self): rng = random.PRNGKey(0) num_samples = 10**5 rand = lambda x: random.rademacher(x, (num_samples, )) crand = api.jit(rand) uncompiled_samples = rand(rng) compiled_samples = crand(rng) for samples in [uncompiled_samples, compiled_samples]: unique_values, counts = np.unique(samples, return_counts=True) assert len(unique_values) == 2 assert len(counts) == 2 self.assertAllClose(counts[0] / num_samples, 0.5, rtol=1e-02, atol=1e-02) self.assertAllClose(counts[1] / num_samples, 0.5, rtol=1e-02, atol=1e-02) def testChoiceShapeIsNotSequenceError(self): key = random.PRNGKey(0) with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=False) with self.assertRaises(TypeError): random.choice(key, 5, 2, replace=True) def test_eval_shape_big_random_array(self): def f(x): return random.normal(random.PRNGKey(x), (int(1e12), )) with core.skipping_checks(): # check_jaxpr will materialize array api.eval_shape(f, 0) # doesn't error @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "seed={seed}_type={type}_jit={jit}".format( **dct), **dct } for dct in [ { "seed": 0, "type": int, "jit": True, "key": [0, 0] }, { "seed": 0, "type": int, "jit": False, "key": [0, 0] }, { "seed": 1, "type": np.int32, "jit": True, "key": [0, 1] }, { "seed": 1, "type": np.int32, "jit": False, "key": [0, 1] }, { "seed": 2, "type": np.uint32, "jit": True, "key": [0, 2] }, { "seed": 2, "type": np.uint32, "jit": False, "key": [0, 2] }, { "seed": 3, "type": np.int64, "jit": True, "key": [0, 3] }, { "seed": 3, "type": np.int64, "jit": False, "key": [0, 3] }, { "seed": -1, "type": int, "jit": True, "key": [4294967295, 4294967295] if FLAGS.jax_enable_x64 else [0, 4294967295] }, { "seed": -1, "type": int, "jit": False, "key": [4294967295, 4294967295] if FLAGS.
jax_enable_x64 else [0, 4294967295] }, { "seed": -2, "type": np.int32, "jit": True, "key": [0, 4294967294] }, { "seed": -2, "type": np.int32, "jit": False, "key": [0, 4294967294] }, { "seed": -3, "type": np.int64, "jit": True, "key": [4294967295, 4294967293] if FLAGS. jax_enable_x64 else [0, 4294967293] }, { "seed": -3, "type": np.int64, "jit": False, "key": [4294967295, 4294967293] if FLAGS. jax_enable_x64 else [0, 4294967293] }, { "seed": np.iinfo(np.int32).max + 100, "type": int, "jit": True, "key": [0, 2147483747] }, { "seed": np.iinfo(np.int32).max + 100, "type": int, "jit": False, "key": [0, 2147483747] }, { "seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": True, "key": [0, 2147483748] }, { "seed": np.iinfo(np.int32).max + 101, "type": np.uint32, "jit": False, "key": [0, 2147483748] }, { "seed": np.iinfo(np.int32).min - 100, "type": int, "jit": True, "key": [4294967295, 2147483548] if FLAGS. jax_enable_x64 else [0, 2147483548] }, { "seed": np.iinfo(np.int32).min - 100, "type": int, "jit": False, "key": [4294967295, 2147483548] if FLAGS. jax_enable_x64 else [0, 2147483548] }, { "seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": True, "key": [4294967295, 2147483547] if FLAGS. jax_enable_x64 else [0, 2147483547] }, { "seed": np.iinfo(np.int32).min - 101, "type": np.int64, "jit": False, "key": [4294967295, 2147483547] if FLAGS. jax_enable_x64 else [0, 2147483547] }, ])) def test_prng_seeds_and_keys(self, seed, type, jit, key): seed = type(seed) if jit: actual = api.jit(random.PRNGKey)(seed) else: actual = random.PRNGKey(seed) expected = jnp.array(key, dtype=jnp.uint32) self.assertArraysEqual(actual, expected) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": f"_seed={seed}_type={type}", "seed": seed, "type": type } for type in ["int", "np.array", "jnp.array"] for seed in [-1, 0, 1, (1 << 32) - 1, (1 << 63) - 1, np.uint64((1 << 64) - 1)])) def test_prng_jit_invariance(self, seed, type): if type == "int" and seed == (1 << 64) - 1: self.skipTest("Expected failure: Python int too large.") type = {"int": int, "np.array": np.array, "jnp.array": jnp.array}[type] args_maker = lambda: [type(seed)] self._CompileAndCheck(random.PRNGKey, args_maker) def test_prng_errors(self): seed = np.iinfo(np.uint64).max with self.assertRaises(OverflowError): random.PRNGKey(seed) with self.assertRaises(OverflowError): api.jit(random.PRNGKey)(seed)
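# A minimal NumPy-only sketch, not part of the test suite, of the seed-to-key
# packing that test_prng_seeds_and_keys above exercises for Python-int seeds
# (the helper name is hypothetical): the low 32 bits of the two's-complement-
# wrapped seed land in key[1], and the high 32 bits land in key[0] only when
# 64-bit mode is enabled.
import numpy as onp

def _reference_prng_key(seed, x64_enabled):
  """Hypothetical reference for the uint32 packing checked above."""
  wrapped = int(seed) % (1 << 64)  # wrap negative seeds two's-complement style
  hi = (wrapped >> 32) & 0xFFFFFFFF
  lo = wrapped & 0xFFFFFFFF
  return onp.array([hi if x64_enabled else 0, lo], dtype=onp.uint32)

# _reference_prng_key(-1, True)  -> [4294967295, 4294967295]
# _reference_prng_key(-1, False) -> [0, 4294967295]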
class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string( shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default])) def testStaticIndexing(self, shape, dtype, rng_factory, indexer): rng = rng_factory() args_maker = lambda: [rng(shape, dtype)] fun = lambda x: x[indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format(name, jtu.format_shape_dtype_string( shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer } for name, index_specs in STATIC_INDEXING_GRAD_TESTS for shape, indexer in index_specs for dtype in float_dtypes for rng_factory in [jtu.rand_default]) def testStaticIndexingGrads(self, shape, dtype, rng_factory, indexer): rng = rng_factory() tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: x[indexer]**2 check_grads(fun, (arg,), 2, tol, tol, tol) def _ReplaceSlicesWithTuples(self, idx): """Helper method to replace slices with tuples for dynamic indexing args.""" if isinstance(idx, slice): triple = idx.start, idx.stop, idx.step isnone = [i for i, elt in enumerate(triple) if elt is None] zeros = itertools.repeat(0) nones = itertools.repeat(None) out = util.subvals(triple, zip(isnone, zeros)) return out, lambda out: slice(*util.subvals(out, zip(isnone, nones))) elif isinstance(idx, (tuple, list)) and idx: t = type(idx) elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts))) else: return idx, lambda x: x @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} for name, index_specs in [ ("OneSliceIndex", [IndexSpec(shape=(5,), indexer=slice(1, 3)), IndexSpec(shape=(5, 4), indexer=slice(1, 3))]), ("TwoSliceIndices", [IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2)))]), ("NonUnitStrides", [ IndexSpec(shape=(3,), indexer=slice(None, None, -1)), IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) ]), ("OnlyStartOrStopDynamic", [ IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) ]), ] for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng_factory, indexer): rng = rng_factory() unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @api.jit def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self.assertRaises(IndexError, lambda: fun(*args_maker())) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} for name, index_specs in [ 
("OneIntIndex", [IndexSpec(shape=(3,), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3,), indexer=-1), IndexSpec(shape=(3,), indexer=-2)]), ("TwoIntIndices", [IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2))]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) def testDynamicIndexingWithIntegers(self, shape, dtype, rng_factory, indexer): rng = rng_factory() unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} for name, index_specs in [ ("OneIntIndex", [IndexSpec(shape=(3,), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3,), indexer=-1), IndexSpec(shape=(3,), indexer=-2), ]), ("TwoIntIndices", [IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), ]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer in index_specs for dtype in float_dtypes for rng_factory in [jtu.rand_default]) def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng_factory, indexer): rng = rng_factory() tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @api.jit def fun(unpacked_indexer, x): indexer = pack_indexer(unpacked_indexer) return x[indexer] arr = rng(shape, dtype) check_grads(partial(fun, unpacked_indexer), (arr,), 2, tol, tol, tol) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} for name, index_specs in ADVANCED_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) def testAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer): rng = rng_factory() args_maker = lambda: [rng(shape, dtype), indexer] fun = lambda x, idx: x[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} for name, index_specs in [ ("One1DIntArrayIndex", [IndexSpec(shape=(3,), indexer=onp.array([0, 1])), IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), IndexSpec(shape=(3,), indexer=onp.array([-1, 1])), IndexSpec(shape=(3,), indexer=onp.array([-2, -1])), ]), ("One2DIntArrayIndex", [IndexSpec(shape=(3,), indexer=onp.array([[0, 0]])), IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])), IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])), ]), ("Two1DIntArrayIndicesNoBroadcasting", [IndexSpec(shape=(3, 3), indexer=[onp.array([0, 
1]), onp.array([1, 2])]), IndexSpec(shape=(3, 4, 5), indexer=[onp.array([0, 2, 0, 1]), onp.array([-1, 0, -1, 2])]), ]), ("Two1DIntArrayIndicesWithBroadcasting", [IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]), onp.array([1, 2])]), IndexSpec(shape=(3, 4, 5), indexer=[onp.array([[0, 2, 0, 1]]), onp.array([-1, 0, -1, 2])]), ]), ("ListOfPythonInts", [IndexSpec(shape=(3,), indexer=[0, 1, 0]), IndexSpec(shape=(3, 4, 5), indexer=[0, -1]), ]), ("ListOfListsOfPythonInts", [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]), IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]), ]), ("ListOfPythonIntsAndIntArrays", [IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]), IndexSpec(shape=(3, 4, 5), indexer=[0, 1, onp.array([[2, 3, 0, 3]])]), ]), ("ListOfListsOfPythonIntsAndIntArrays", [IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]), IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], onp.array([[2, 3, 0, 3]])]), ]), ] for shape, indexer in index_specs for dtype in float_dtypes for rng_factory in [jtu.rand_default]) def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng_factory, indexer): rng = rng_factory() tol = 1e-2 if lnp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: x[indexer]**2 check_grads(fun, (arg,), 2, tol, tol, tol) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_indexer={}" .format(name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer} for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng_factory in [jtu.rand_default]) def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng_factory, indexer): rng = rng_factory() indexer_with_dummies = [e if isinstance(e, onp.ndarray) else () for e in indexer] substitutes = [(i, e) for i, e in enumerate(indexer) if not isinstance(e, onp.ndarray)] args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] def fun(x, indexer_with_dummies): idx = type(indexer)(util.subvals(indexer_with_dummies, substitutes)) return x[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) def testAdvancedIndexingManually(self): x = onp.random.RandomState(0).randn(3, 4, 5) index_array = onp.array([0, 2, -1, 0]) op = lambda x, index_array: x[..., index_array, :] cop = api.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) op = lambda x, index_array: x[..., index_array, :, index_array, None] cop = api.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] cop = api.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) def testUnpacking(self): def foo(x): a, b, c = x return a + b + c cfoo = api.jit(foo) a1 = foo(onp.arange(3)) a2 = cfoo(onp.arange(3)) self.assertAllClose(a1, a2, check_dtypes=True) def testBooleanIndexingArray1D(self): idx = onp.array([True, True, False]) x = api.device_put(onp.arange(3)) ans = x[idx] expected = onp.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingList1D(self): idx = [True, True, False] x = api.device_put(onp.arange(3)) ans = x[idx] expected = onp.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingArray2DBroadcast(self): idx = onp.array([True, True, False, True]) x = onp.arange(8).reshape(4, 2) 
ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingList2DBroadcast(self): idx = [True, True, False, True] x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingArray2D(self): idx = onp.array([[True, False], [False, True], [False, False], [True, True]]) x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingDynamicShapeError(self): x = onp.zeros(3) i = onp.array([True, True, False]) self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i)) def testIssue187(self): x = lnp.ones((5, 5)) x[[0, 2, 4], [0, 2, 4]] # doesn't crash x = onp.arange(25).reshape((5, 5)) ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) expected = x[[0, 2, 4], [0, 2, 4]] self.assertAllClose(ans, expected, check_dtypes=False) def testJVPOfGradOfIndexing(self): # Should return a value, even though we didn't pass a symbolic zero as the # index tangent. x = lnp.ones((3, 4), lnp.float32) i = lnp.ones((3,), lnp.int32) f = lambda x, i: lnp.sum(x[i]) primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i))) expected = onp.broadcast_to( onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4)) self.assertAllClose(expected, primals, check_dtypes=True) self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True) def testTrivialGatherIsntGenerated(self): # https://github.com/google/jax/issues/1621 jaxpr = api.make_jaxpr(lambda x: x[:, None])(onp.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertNotIn('gather', str(jaxpr)) def testBooleanIndexingWithEmptyResult(self): # based on a TensorFlow Probability test that started failing after #1622 x = lnp.array([-1]) mask = lnp.array([False]) ans = x[mask] # doesn't crash expected = onp.array([-1])[onp.array([False])] self.assertAllClose(ans, expected, check_dtypes=False) def testFloatIndexingError(self): x = lnp.array([1, 2, 3]) self.assertRaises(TypeError, lambda: x[3.5])
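# Illustrative sketch, not a test: why testBooleanIndexingDynamicShapeError
# above expects an IndexError. The output shape of x[mask] depends on the
# values in mask, which a jit-traced function cannot know; integer-array
# (gather) indexing has a static output shape and compiles fine.
import jax
import jax.numpy as jnp

_x = jnp.arange(3.)
_gather = jax.jit(lambda x, i: x[i])
_gather(_x, jnp.array([0, 2]))  # OK: output shape (2,) is known statically
# jax.jit(lambda x, m: x[m])(_x, jnp.array([True, False, True]))
#   would raise IndexError: a boolean mask gives a value-dependent shape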
class IndexedUpdateTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for op in UpdateOps for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes) for rng_factory in [jtu.rand_default])) def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op): rng = rng_factory() args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS for shape, indexer in index_specs for op in UpdateOps for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes) for rng_factory in [jtu.rand_default])) def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op): rng = rng_factory() args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS for shape, indexer in index_specs for op in UpdateOps for dtype in (all_dtypes if op == UpdateOps.UPDATE else default_dtypes) for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) for update_dtype in ([dtype] if op == UpdateOps.ADD else all_dtypes) for rng_factory in [jtu.rand_default])) def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op): rng = rng_factory() args_maker = lambda: [rng(shape, dtype), rng(update_shape, update_dtype)] onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) 
self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for op in UpdateOps for dtype in float_dtypes for update_shape in _broadcastable_shapes(_update_shape(shape, indexer)) for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes) for rng_factory in [jtu.rand_default])) @jtu.skip_on_devices("tpu") # TODO(mattjj,phawkins): tpu issues def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, rng_factory, indexer, op): rng = rng_factory() jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add jax_fn = lambda x, y: jax_op(x, indexer, y) x = rng(shape, dtype) y = rng(update_shape, update_dtype) check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) def testSegmentSumBehavior(self): # testAdvancedIndexing compares against NumPy, and as a result doesn't check # repeated indices. This test is just a simple manual check, based on # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum data = onp.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3]) ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids, data) expected = onp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) def testSegmentSum(self): data = onp.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3]) # test with explicit num_segments ans = ops.segment_sum(data, segment_ids, num_segments=4) expected = onp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) # test without explicit num_segments ans = ops.segment_sum(data, segment_ids) expected = onp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False)
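# Standalone recap of the equivalence checked in testSegmentSum above:
# ops.index_add accumulates over repeated indices, so scatter-adding into a
# zero buffer with num_segments entries reproduces ops.segment_sum.
import numpy as onp
from jax import ops

_data = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
_segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])
_num_segments = int(onp.max(_segment_ids)) + 1
_via_scatter = ops.index_add(onp.zeros(_num_segments), _segment_ids, _data)
_via_segments = ops.segment_sum(_data, _segment_ids, num_segments=_num_segments)
# both equal [13., 2., 7., 4.]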
class SimulateTest(jtu.JaxTestCase): # pylint: disable=g-complex-comprehension @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in DTYPE)) def test_nve_ensemble(self, spatial_dimension, dtype): key = random.PRNGKey(0) pos_key, center_key, vel_key, mass_key = random.split(key, 4) R = random.normal(pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) R0 = random.normal(center_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) mass = random.uniform(mass_key, (PARTICLE_COUNT, ), minval=0.1, maxval=5.0, dtype=dtype) _, shift = space.free() E = lambda R, **kwargs: np.sum((R - R0)**2) init_fn, apply_fn = simulate.nve(E, shift, 1e-3) apply_fn = jit(apply_fn) state = init_fn(vel_key, R, kT=0.5, mass=mass) E_T = lambda state: \ E(state.position) + quantity.kinetic_energy(state.velocity, state.mass) E_initial = E_T(state) for _ in range(DYNAMICS_STEPS): state = apply_fn(state) E_total = E_T(state) assert np.abs(E_total - E_initial) < E_initial * 0.01 assert state.position.dtype == dtype @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in DTYPE)) def test_nve_jammed(self, spatial_dimension, dtype): key = random.PRNGKey(0) state = test_util.load_test_state('simulation_test_state.npy', dtype) displacement_fn, shift_fn = space.periodic(state.box[0, 0]) E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma) init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3) apply_fn = jit(apply_fn) state = init_fn(key, state.real_position, kT=1e-3) E_T = lambda state: \ E(state.position) + quantity.kinetic_energy(state.velocity, state.mass) E_initial = E_T(state) * np.ones((DYNAMICS_STEPS, )) def step_fn(i, state_and_energy): state, energy = state_and_energy state = apply_fn(state) energy = ops.index_update(energy, i, E_T(state)) return state, energy Es = np.zeros((DYNAMICS_STEPS, )) state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es)) tol = 1e-3 if dtype is f32 else 1e-7 self.assertEqual(state.position.dtype, dtype) self.assertAllClose(Es, E_initial, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list({
'testcase_name': f'_dtype={dtype.__name__}_coordinates={coords}', 'dtype': dtype, 'coords': coords } for dtype in DTYPE for coords in COORDS)) def test_nve_jammed_periodic_general(self, dtype, coords): key = random.PRNGKey(0) state = test_util.load_test_state('simulation_test_state.npy', dtype) displacement_fn, shift_fn = space.periodic_general( state.box, coords == 'fractional') E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma) init_fn, apply_fn = simulate.nve(E, shift_fn, 1e-3) apply_fn = jit(apply_fn) state = init_fn(key, getattr(state, coords + '_position'), kT=1e-3) E_T = lambda state: \ E(state.position) + quantity.kinetic_energy(state.velocity, state.mass) E_initial = E_T(state) * np.ones((DYNAMICS_STEPS, )) def step_fn(i, state_and_energy): state, energy = state_and_energy state = apply_fn(state) energy = ops.index_update(energy, i, E_T(state)) return state, energy Es = np.zeros((DYNAMICS_STEPS, )) state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es)) tol = 1e-3 if dtype is f32 else 1e-7 self.assertEqual(state.position.dtype, dtype) self.assertAllClose(Es, E_initial, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in DTYPE)) def test_nve_neighbor_list(self, spatial_dimension, dtype): Nx = particles_per_side = 8 spacing = f32(1.25) tol = 5e-12 if dtype == np.float64 else 5e-3 L = Nx * spacing if spatial_dimension == 2: R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx)]) * spacing elif spatial_dimension == 3: R = np.stack([np.array(r) for r in onp.ndindex(Nx, Nx, Nx)]) * spacing R = np.array(R, dtype) displacement, shift = space.periodic(L) neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list( displacement, L) exact_energy_fn = energy.lennard_jones_pair(displacement) init_fn, apply_fn = simulate.nve(energy_fn, shift, 1e-3) exact_init_fn, exact_apply_fn = simulate.nve(exact_energy_fn, shift, 1e-3) nbrs = neighbor_fn(R) state = init_fn(random.PRNGKey(0), R, kT=0.5, neighbor=nbrs) exact_state = exact_init_fn(random.PRNGKey(0), R, kT=0.5) def body_fn(i, state): state, nbrs, exact_state = state nbrs = neighbor_fn(state.position, nbrs) state = apply_fn(state, neighbor=nbrs) return state, nbrs, exact_apply_fn(exact_state) step = 0 for i in range(20): new_state, nbrs, new_exact_state = lax.fori_loop( 0, 100, body_fn, (state, nbrs, exact_state)) if nbrs.did_buffer_overflow: nbrs = neighbor_fn(state.position) else: state = new_state exact_state = new_exact_state step += 1 assert state.position.dtype == dtype self.assertAllClose(state.position, exact_state.position, atol=tol, rtol=tol) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': f'_dim={dim}_dtype={dtype.__name__}_sy_steps={sy_steps}', 'spatial_dimension': dim, 'dtype': dtype, 'sy_steps': sy_steps, } for dim in SPATIAL_DIMENSION for dtype in DTYPE for sy_steps in [1, 3, 5, 7])) def test_nvt_nose_hoover(self, spatial_dimension, dtype, sy_steps): key = random.PRNGKey(0) box_size = quantity.box_size_at_number_density(PARTICLE_COUNT, f32(1.2), spatial_dimension) displacement_fn, shift_fn = space.periodic(box_size) bonds_i = np.arange(PARTICLE_COUNT) bonds_j = np.roll(bonds_i, 1) bonds = np.stack([bonds_i, bonds_j]) E = energy.simple_spring_bond(displacement_fn, bonds) invariant = partial(simulate.nvt_nose_hoover_invariant, E) for _ in range(STOCHASTIC_SAMPLES): key, pos_key, vel_key, T_key, masses_key 
= random.split(key, 5) R = box_size * random.uniform(pos_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype) mass = 1 + random.uniform(masses_key, (PARTICLE_COUNT, ), dtype=dtype) init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift_fn, 1e-3, T, sy_steps=sy_steps) apply_fn = jit(apply_fn) state = init_fn(vel_key, R, mass=mass) initial = invariant(state, T) for _ in range(DYNAMICS_STEPS): state = apply_fn(state) T_final = quantity.temperature(state.velocity, state.mass) assert np.abs(T_final - T) / T < 0.1 self.assertAllClose(invariant(state, T), initial, rtol=1e-4) self.assertEqual(state.position.dtype, dtype) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': f'_dtype={dtype.__name__}_sy_steps={sy_steps}', 'dtype': dtype, 'sy_steps': sy_steps, } for dtype in DTYPE for sy_steps in [1, 3, 5, 7])) def test_nvt_nose_hoover_jammed(self, dtype, sy_steps): key = random.PRNGKey(0) state = test_util.load_test_state('simulation_test_state.npy', dtype) displacement_fn, shift_fn = space.periodic(state.box[0, 0]) E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma) invariant = partial(simulate.nvt_nose_hoover_invariant, E) kT = 1e-3 init_fn, apply_fn = simulate.nvt_nose_hoover(E, shift_fn, 1e-3, kT=kT, sy_steps=sy_steps) apply_fn = jit(apply_fn) state = init_fn(key, state.real_position) E_initial = invariant(state, kT) * np.ones((DYNAMICS_STEPS, )) def step_fn(i, state_and_energy): state, energy = state_and_energy state = apply_fn(state) energy = ops.index_update(energy, i, invariant(state, kT)) return state, energy Es = np.zeros((DYNAMICS_STEPS, )) state, Es = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es)) tol = 1e-3 if dtype is f32 else 1e-7 self.assertEqual(state.position.dtype, dtype) self.assertAllClose(Es, E_initial, rtol=tol, atol=tol) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': f'_dtype={dtype.__name__}_sy_steps={sy_steps}', 'dtype': dtype, 'sy_steps': sy_steps, } for dtype in DTYPE for sy_steps in [1, 3, 5, 7])) def test_npt_nose_hoover_jammed(self, dtype, sy_steps): key = random.PRNGKey(0) state = test_util.load_test_state('simulation_test_state.npy', dtype) displacement_fn, shift_fn = space.periodic_general(state.box) E = energy.soft_sphere_pair(displacement_fn, state.species, state.sigma) invariant = partial(simulate.npt_nose_hoover_invariant, E) pressure_fn = partial(quantity.pressure, E) nhc_kwargs = {'sy_steps': sy_steps} kT = 1e-3 P = state.pressure init_fn, apply_fn = simulate.npt_nose_hoover(E, shift_fn, 1e-3, P, kT, nhc_kwargs, nhc_kwargs) apply_fn = jit(apply_fn) state = init_fn(key, state.fractional_position, state.box) E_initial = invariant(state, P, kT) * np.ones((DYNAMICS_STEPS, )) P_target = P * np.ones((DYNAMICS_STEPS, )) def step_fn(i, state_energy_pressure): state, energy, pressure = state_energy_pressure state = apply_fn(state) energy = ops.index_update(energy, i, invariant(state, P, kT)) box = simulate.npt_box(state) KE = quantity.kinetic_energy(state.velocity, state.mass) p = pressure_fn(state.position, box, KE) pressure = ops.index_update(pressure, i, p) return state, energy, pressure Es = np.zeros((DYNAMICS_STEPS, )) Ps = np.zeros((DYNAMICS_STEPS, )) state, Es, Ps = lax.fori_loop(0, DYNAMICS_STEPS, step_fn, (state, Es, Ps)) tol = 1e-3 if dtype is f32 else 1e-7 self.assertEqual(state.position.dtype, dtype) self.assertAllClose(Es, E_initial, rtol=tol, atol=tol) self.assertAllClose(Ps, P_target, rtol=0.05, atol=0.05)
@parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in DTYPE)) def test_nvt_langevin(self, spatial_dimension, dtype): key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, R_key, R0_key, T_key, masses_key = random.split(key, 5) R = random.normal(R_key, (LANGEVIN_PARTICLE_COUNT, spatial_dimension), dtype=dtype) R0 = random.normal(R0_key, (LANGEVIN_PARTICLE_COUNT, spatial_dimension), dtype=dtype) _, shift = space.free() E = functools.partial(lambda R, R0, **kwargs: np.sum((R - R0)**2), R0=R0) T = random.uniform(T_key, (), minval=0.3, maxval=1.4, dtype=dtype) mass = random.uniform(masses_key, (LANGEVIN_PARTICLE_COUNT, ), minval=0.1, maxval=10.0, dtype=dtype) init_fn, apply_fn = simulate.nvt_langevin(E, shift, f32(1e-2), T, gamma=f32(0.3)) apply_fn = jit(apply_fn) state = init_fn(key, R, mass=mass, T_initial=dtype(1.0)) T_list = [] for step in range(LANGEVIN_DYNAMICS_STEPS): state = apply_fn(state) if step > 4000 and step % 100 == 0: T_list += [ quantity.temperature(state.velocity, state.mass) ] # TODO(schsam): It would be good to check Gaussianity of R and V in the # noninteracting case. T_emp = np.mean(np.array(T_list)) assert np.abs(T_emp - T) < 0.1 assert state.position.dtype == dtype @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in DTYPE)) def test_brownian(self, spatial_dimension, dtype): key = random.PRNGKey(0) key, T_split, mass_split = random.split(key, 3) _, shift = space.free() energy_fn = lambda R, **kwargs: f32(0) R = np.zeros((BROWNIAN_PARTICLE_COUNT, spatial_dimension), dtype=dtype) mass = random.uniform(mass_split, (), minval=0.1, maxval=10.0, dtype=dtype) T = random.uniform(T_split, (), minval=0.3, maxval=1.4, dtype=dtype) dt = f32(1e-2) gamma = f32(0.1) init_fn, apply_fn = simulate.brownian(energy_fn, shift, dt, T, gamma=gamma) apply_fn = jit(apply_fn) state = init_fn(key, R, mass) sim_t = f32(BROWNIAN_DYNAMICS_STEPS * dt) for _ in range(BROWNIAN_DYNAMICS_STEPS): state = apply_fn(state) msd = np.var(state.position) th_msd = dtype(2 * T / (mass * gamma) * sim_t) assert np.abs(msd - th_msd) / msd < 1e-2 assert state.position.dtype == dtype
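# Worked form of the moment check at the end of test_brownian above, assuming
# the overdamped-Langevin (Einstein) relation the test relies on: each
# coordinate diffuses with D = kT / (mass * gamma), so after time t the
# per-coordinate position variance is MSD = 2 * D * t, i.e. the th_msd that
# np.var(state.position) is compared against. The helper name is hypothetical.
def _theoretical_msd(kT, mass, gamma, t):
  """Einstein-relation MSD, mirroring th_msd in test_brownian."""
  return 2.0 * kT / (mass * gamma) * t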
class FftTest(jtu.JaxTestCase): def testNotImplemented(self): for name in jnp.fft._NOT_IMPLEMENTED: func = getattr(jnp.fft, name) with self.assertRaises(NotImplementedError): func() @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format( inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes), "axes": axes, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "inverse": inverse, "real": real} for inverse in [False, True] for real in [False, True] for rng_factory in [jtu.rand_default] for dtype in (real_dtypes if real and not inverse else all_dtypes) for shape in [(10,), (10, 10), (9,), (2, 3, 4), (2, 3, 4, 5)] for axes in _get_fftn_test_axes(shape))) def testFftn(self, inverse, real, shape, dtype, axes, rng_factory): rng = rng_factory(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_op = _get_fftn_func(jnp.fft, inverse, real) np_op = _get_fftn_func(np.fft, inverse, real) jnp_fn = lambda a: jnp_op(a, axes=axes) np_fn = lambda a: np_op(a, axes=axes) if axes is None or axes else a # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if (FLAGS.jax_enable_x64 and dtype in (float_dtypes if real and not inverse else inexact_dtypes)): # TODO(skye): can we be more precise? tol = 0.15 jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}".format(inverse, real), "inverse": inverse, "real": real} for inverse in [False, True] for real in [False, True])) def testFftnErrors(self, inverse, real): rng = jtu.rand_default(self.rng()) name = 'fftn' if real: name = 'r' + name if inverse: name = 'i' + name func = _get_fftn_func(jnp.fft, inverse, real) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 1D, 2D, and 3D FFTs. " "Got axes None with input rank 4.".format(name), lambda: func(rng([2, 3, 4, 5], dtype=np.float64), axes=None)) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} does not support repeated axes. Got axes \\[1, 1\\].".format(name), lambda: func(rng([2, 3], dtype=np.float64), axes=[1, 1])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3])) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}_hermitian={}_shape={}_axis={}".format( inverse, real, hermitian, jtu.format_shape_dtype_string(shape, dtype), axis), "axis": axis, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "inverse": inverse, "real": real, "hermitian": hermitian} for inverse in [False, True] for real in [False, True] for hermitian in [False, True] for rng_factory in [jtu.rand_default] for dtype in (real_dtypes if (real and not inverse) or (hermitian and inverse) else all_dtypes) for shape in [(10,)] for axis in [-1, 0])) def testFft(self, inverse, real, hermitian, shape, dtype, axis, rng_factory): rng = rng_factory(self.rng()) args_maker = lambda: (rng(shape, dtype),) name = 'fft' if real: name = 'r' + name elif hermitian: name = 'h' + name if inverse: name = 'i' + name jnp_op = getattr(jnp.fft, name) np_op = getattr(np.fft, name) jnp_fn = lambda a: jnp_op(a, axis=axis) np_fn = lambda a: np_op(a, axis=axis) # Numpy promotes to complex128 aggressively. 
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}_hermitian={}".format(inverse, real, hermitian), "inverse": inverse, "real": real, "hermitian": hermitian} for inverse in [False, True] for real in [False, True] for hermitian in [False, True])) def testFftErrors(self, inverse, real, hermitian): rng = jtu.rand_default(self.rng()) name = 'fft' if real: name = 'r' + name elif hermitian: name = 'h' + name if inverse: name = 'i' + name func = getattr(jnp.fft, name) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} does not support multiple axes. " "Please use jax.numpy.fft.{}n. " "Got axis = \\[1, 1\\].".format(name, name), lambda: func(rng([2, 3], dtype=np.float64), axis=[1, 1]) ) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} does not support multiple axes. " "Please use jax.numpy.fft.{}n. " "Got axis = \\(1, 1\\).".format(name, name), lambda: func(rng([2, 3], dtype=np.float64), axis=(1, 1)) ) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[2])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axis=[-3])) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}_shape={}_axes={}".format( inverse, real, jtu.format_shape_dtype_string(shape, dtype), axes), "axes": axes, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "inverse": inverse, "real": real} for inverse in [False, True] for real in [False, True] for rng_factory in [jtu.rand_default] for dtype in (real_dtypes if real and not inverse else all_dtypes) for shape in [(16, 8, 4, 8), (16, 8, 4, 8, 4)] for axes in [(-2, -1), (0, 1), (1, 3), (-1, 2)])) def testFft2(self, inverse, real, shape, dtype, axes, rng_factory): rng = rng_factory(self.rng()) args_maker = lambda: (rng(shape, dtype),) name = 'fft2' if real: name = 'r' + name if inverse: name = 'i' + name jnp_op = getattr(jnp.fft, name) np_op = getattr(np.fft, name) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_op, jnp_op, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_op, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inverse={}_real={}".format(inverse, real), "inverse": inverse, "real": real} for inverse in [False, True] for real in [False, True])) def testFft2Errors(self, inverse, real): rng = jtu.rand_default(self.rng()) name = 'fft2' if real: name = 'r' + name if inverse: name = 'i' + name func = getattr(jnp.fft, name) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 2 axes. " "Got axes = \\[0\\].".format(name), lambda: func(rng([2, 3], dtype=np.float64), axes=[0]) ) self.assertRaisesRegex( ValueError, "jax.numpy.fft.{} only supports 2 axes. 
" "Got axes = \\(0, 1, 2\\).".format(name), lambda: func(rng([2, 3, 3], dtype=np.float64), axes=(0, 1, 2)) ) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[2, 3])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=np.float64), axes=[-3, -4])) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_size={}_d={}".format( jtu.format_shape_dtype_string([size], dtype), d), "dtype": dtype, "size": size, "rng_factory": rng_factory, "d": d} for rng_factory in [jtu.rand_default] for dtype in all_dtypes for size in [9, 10, 101, 102] for d in [0.1, 2.])) def testFftfreq(self, size, d, dtype, rng_factory): rng = rng_factory(self.rng()) args_maker = lambda: (rng([size], dtype),) jnp_op = jnp.fft.fftfreq np_op = np.fft.fftfreq jnp_fn = lambda a: jnp_op(size, d=d) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(n), "n": n} for n in [[0,1,2]])) def testFftfreqErrors(self, n): name = 'fftfreq' func = jnp.fft.fftfreq self.assertRaisesRegex( ValueError, "The n argument of jax.numpy.fft.{} only takes an int. " "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n) ) self.assertRaisesRegex( ValueError, "The d argument of jax.numpy.fft.{} only takes a single value. " "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n) ) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_size={}_d={}".format( jtu.format_shape_dtype_string([size], dtype), d), "dtype": dtype, "size": size, "rng_factory": rng_factory, "d": d} for rng_factory in [jtu.rand_default] for dtype in all_dtypes for size in [9, 10, 101, 102] for d in [0.1, 2.])) def testRfftfreq(self, size, d, dtype, rng_factory): rng = rng_factory(self.rng()) args_maker = lambda: (rng([size], dtype),) jnp_op = jnp.fft.rfftfreq np_op = np.fft.rfftfreq jnp_fn = lambda a: jnp_op(size, d=d) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker) # Test gradient for differentiable types. if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(n), "n": n} for n in [[0, 1, 2]])) def testRfftfreqErrors(self, n): name = 'rfftfreq' func = jnp.fft.rfftfreq self.assertRaisesRegex( ValueError, "The n argument of jax.numpy.fft.{} only takes an int. " "Got n = \\[0, 1, 2\\].".format(name), lambda: func(n=n) ) self.assertRaisesRegex( ValueError, "The d argument of jax.numpy.fft.{} only takes a single value. 
" "Got d = \\[0, 1, 2\\].".format(name), lambda: func(n=10, d=n) ) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "dtype={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), axes), "dtype": dtype, "shape": shape, "rng_factory": rng_factory, "axes": axes} for rng_factory in [jtu.rand_default] for dtype in all_dtypes for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]] for axes in _get_fftn_test_axes(shape))) def testFftshift(self, shape, dtype, rng_factory, axes): rng = rng_factory(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda arg: jnp.fft.fftshift(arg, axes=axes) np_fn = lambda arg: np.fft.fftshift(arg, axes=axes) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "dtype={}_axes={}".format( jtu.format_shape_dtype_string(shape, dtype), axes), "dtype": dtype, "shape": shape, "rng_factory": rng_factory, "axes": axes} for rng_factory in [jtu.rand_default] for dtype in all_dtypes for shape in [[9], [10], [101], [102], [3, 5], [3, 17], [5, 7, 11]] for axes in _get_fftn_test_axes(shape))) def testIfftshift(self, shape, dtype, rng_factory, axes): rng = rng_factory(self.rng()) args_maker = lambda: (rng(shape, dtype),) jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes) np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
class PredictTest(jtu.JaxTestCase): @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}_{}'.format( train, test, network, out_logits, name), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, 'fn_and_kernel': fn } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for out_logits in OUTPUT_LOGITS for name, fn in KERNELS.items())) def testNTKMSEPrediction(self, train_shape, test_shape, network, out_logits, fn_and_kernel): key = random.PRNGKey(0) key, split = random.split(key) data_train = random.normal(split, train_shape) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = random.normal(split, test_shape) params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits) # Regress to an MSE loss. loss = lambda params, x: \ 0.5 * np.mean((f(params, x) - data_labels) ** 2) grad_loss = jit(grad(loss)) g_dd = ntk(data_train, None, 'ntk') g_td = ntk(data_test, data_train, 'ntk') predictor = predict.gradient_descent_mse(g_dd, data_labels, g_td) atol = ATOL rtol = RTOL step_size = 0.1 if len(train_shape) > 2: # Hacky way to up the tolerance just for convolutions. atol = ATOL * 2 rtol = RTOL * 2 step_size = 0.1 train_time = 100.0 steps = int(train_time / step_size) opt_init, opt_update, get_params = optimizers.sgd(step_size) opt_state = opt_init(params) fx_initial_train = f(params, data_train) fx_initial_test = f(params, data_test) fx_pred_train, fx_pred_test = predictor(0.0, fx_initial_train, fx_initial_test) self.assertAllClose(fx_initial_train, fx_pred_train, True) self.assertAllClose(fx_initial_test, fx_pred_test, True) for i in range(steps): params = get_params(opt_state) opt_state = opt_update(i, grad_loss(params, data_train), opt_state) params = get_params(opt_state) fx_train = f(params, data_train) fx_test = f(params, data_test) fx_pred_train, fx_pred_test = predictor(train_time, fx_initial_train, fx_initial_test) fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2)) fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2)) fx_error_train = (fx_train - fx_pred_train) / fx_disp_train fx_error_test = (fx_test - fx_pred_test) / fx_disp_test self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train), True, rtol, atol) self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True, rtol, atol) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}_{}'.format( train, test, network, out_logits, name), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, 'fn_and_kernel': fn } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for out_logits in OUTPUT_LOGITS for name, fn in KERNELS.items())) def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits, fn_and_kernel): key = random.PRNGKey(0) key, split = random.split(key) data_train = random.normal(split, train_shape) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = random.normal(split, test_shape) params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits) # Regress to an MSE loss. 
loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2) grad_loss = jit( grad(lambda params, x: loss(f(params, x), data_labels))) g_dd = ntk(data_train, None, 'ntk') g_td = ntk(data_test, data_train, 'ntk') predictor = predict.gradient_descent(g_dd, data_labels, loss, g_td) atol = ATOL rtol = RTOL step_size = 0.5 if len(train_shape) > 2: # Hacky way to up the tolerance just for convolutions. atol = ATOL * 2 rtol = RTOL * 2 step_size = 0.1 train_time = 100.0 steps = int(train_time / step_size) opt_init, opt_update, get_params = optimizers.sgd(step_size) opt_state = opt_init(params) fx_initial_train = f(params, data_train) fx_initial_test = f(params, data_test) fx_pred_train, fx_pred_test = predictor(0.0, fx_initial_train, fx_initial_test) self.assertAllClose(fx_initial_train, fx_pred_train, True) self.assertAllClose(fx_initial_test, fx_pred_test, True) for i in range(steps): params = get_params(opt_state) opt_state = opt_update(i, grad_loss(params, data_train), opt_state) params = get_params(opt_state) fx_train = f(params, data_train) fx_test = f(params, data_test) fx_pred_train, fx_pred_test = predictor(train_time, fx_initial_train, fx_initial_test) fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2)) fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2)) fx_error_train = (fx_train - fx_pred_train) / fx_disp_train fx_error_test = (fx_test - fx_pred_test) / fx_disp_test self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train), True, rtol, atol) self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True, rtol, atol) # TODO(schsam): Get this test passing with theoretical conv. @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}_{}'.format( train, test, network, out_logits, name), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, 'fn_and_kernel': fn } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for out_logits in OUTPUT_LOGITS for name, fn in KERNELS.items() if len(train) == 2)) def testNTKMomentumPrediction(self, train_shape, test_shape, network, out_logits, fn_and_kernel): key = random.PRNGKey(0) key, split = random.split(key) data_train = random.normal(split, train_shape) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = random.normal(split, test_shape) params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits) # Regress to an MSE loss. loss = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2) grad_loss = jit( grad(lambda params, x: loss(f(params, x), data_labels))) g_dd = ntk(data_train, None, 'ntk') g_td = ntk(data_test, data_train, 'ntk') atol = ATOL rtol = RTOL step_size = 0.5 if len(train_shape) > 2: # Hacky way to up the tolerance just for convolutions. 
atol = ATOL * 2 rtol = RTOL * 2 step_size = 0.1 train_time = 100.0 steps = int(train_time / np.sqrt(step_size)) init, predictor, get = predict.momentum(g_dd, data_labels, loss, step_size, g_td) opt_init, opt_update, get_params = momentum(step_size, 0.9) opt_state = opt_init(params) fx_initial_train = f(params, data_train) fx_initial_test = f(params, data_test) lin_state = init(fx_initial_train, fx_initial_test) fx_pred_train, fx_pred_test = get(lin_state) self.assertAllClose(fx_initial_train, fx_pred_train, True) self.assertAllClose(fx_initial_test, fx_pred_test, True) for i in range(steps): params = get_params(opt_state) opt_state = opt_update(i, grad_loss(params, data_train), opt_state) params = get_params(opt_state) fx_train = f(params, data_train) fx_test = f(params, data_test) lin_state = predictor(lin_state, train_time) fx_pred_train, fx_pred_test = get(lin_state) fx_disp_train = np.sqrt(np.mean((fx_train - fx_initial_train)**2)) fx_disp_test = np.sqrt(np.mean((fx_test - fx_initial_test)**2)) fx_error_train = (fx_train - fx_pred_train) / fx_disp_train fx_error_test = (fx_test - fx_pred_test) / fx_disp_test self.assertAllClose(fx_error_train, np.zeros_like(fx_error_train), True, rtol, atol) self.assertAllClose(fx_error_test, np.zeros_like(fx_error_test), True, rtol, atol) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}'.format( train, test, network, out_logits), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1]) for out_logits in OUTPUT_LOGITS)) def testNTKMeanPrediction(self, train_shape, test_shape, network, out_logits): key = random.PRNGKey(0) key, split = random.split(key) data_train = np.cos(random.normal(split, train_shape)) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = np.cos(random.normal(split, test_shape)) _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits) mean_pred, var = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, 'ntk', diag_reg=0., compute_var=True) if xla_bridge.get_backend().platform == 'tpu': eigh = np.onp.linalg.eigh else: eigh = np.linalg.eigh self.assertEqual(var.shape[0], data_test.shape[0]) min_eigh = np.min(eigh(var)[0]) self.assertGreater(min_eigh + 1e-10, 0.) def mc_sampling(count=10): empirical_mean = 0. 
key = random.PRNGKey(100) init_fn, f, _ = _build_network(train_shape[1:], network, out_logits) _kernel_fn = empirical.empirical_kernel_fn(f) kernel_fn = jit( lambda x1, x2, params: _kernel_fn(x1, x2, params, 'ntk')) for _ in range(count): key, split = random.split(key) _, params = init_fn(split, train_shape) g_dd = kernel_fn(data_train, None, params) g_td = kernel_fn(data_test, data_train, params) predictor = predict.gradient_descent_mse( g_dd, data_labels, g_td) fx_initial_train = f(params, data_train) fx_initial_test = f(params, data_test) _, fx_pred_test = predictor(1.0e8, fx_initial_train, fx_initial_test) empirical_mean += fx_pred_test return empirical_mean / count atol = ATOL rtol = RTOL mean_emp = mc_sampling(100) self.assertAllClose(mean_pred, mean_emp, True, rtol, atol) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}'.format( train, test, network, out_logits), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1]) for out_logits in OUTPUT_LOGITS)) def testGPInferenceGet(self, train_shape, test_shape, network, out_logits): key = random.PRNGKey(0) key, split = random.split(key) data_train = np.cos(random.normal(split, train_shape)) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = np.cos(random.normal(split, test_shape)) _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits) out = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, 'ntk', diag_reg=0., compute_var=True) assert isinstance(out, predict.Gaussian) out = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, 'nngp', diag_reg=0., compute_var=True) assert isinstance(out, predict.Gaussian) out = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, ('ntk', ), diag_reg=0., compute_var=True) assert len(out) == 1 and isinstance(out[0], predict.Gaussian) out = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, ('ntk', 'nngp'), diag_reg=0., compute_var=True) assert (len(out) == 2 and isinstance(out[0], predict.Gaussian) and isinstance(out[1], predict.Gaussian)) out2 = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, ('nngp', 'ntk'), diag_reg=0., compute_var=True) self.assertAllClose(out[0], out2[1], True) self.assertAllClose(out[1], out2[0], True) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}_get={}'.format( train, test, network, out_logits, get), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, 'get': get, } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1]) for out_logits in OUTPUT_LOGITS for get in GETS)) def testInfiniteTimeAgreement(self, train_shape, test_shape, network, out_logits, get): key = random.PRNGKey(0) key, split = random.split(key) data_train = np.cos(random.normal(split, train_shape)) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = np.cos(random.normal(split, test_shape)) _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits) reg = 1e-7 inf_prediction = predict.gp_inference(kernel_fn, data_train, data_labels, data_test, get, diag_reg=reg, 
compute_var=True) prediction = predict.gradient_descent_mse_gp(kernel_fn, data_train, data_labels, data_test, get, diag_reg=reg, compute_var=True) finite_prediction = prediction(np.inf) self.assertAllClose(inf_prediction, finite_prediction, True) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train={}_test={}_network={}_logits={}'.format( train, test, network, out_logits), 'train_shape': train, 'test_shape': test, 'network': network, 'out_logits': out_logits, } for train, test, network in zip(TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1]) for out_logits in OUTPUT_LOGITS)) def testZeroTimeAgreement(self, train_shape, test_shape, network, out_logits): """Test that the NTK and NNGP agree at t=0.""" key = random.PRNGKey(0) key, split = random.split(key) data_train = np.cos(random.normal(split, train_shape)) key, split = random.split(key) data_labels = np.array( random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32) key, split = random.split(key) data_test = np.cos(random.normal(split, test_shape)) _, _, ker_fun = _build_network(train_shape[1:], network, out_logits) reg = 1e-7 prediction = predict.gradient_descent_mse_gp(ker_fun, data_train, data_labels, data_test, diag_reg=reg, get=('NTK', 'NNGP'), compute_var=True) zero_prediction = prediction(0.0) self.assertAllClose(zero_prediction.ntk, zero_prediction.nngp, True) reference = (np.zeros((test_shape[0], out_logits)), ker_fun(data_test, data_test, get='nngp')) self.assertAllClose((reference, ) * 2, zero_prediction, True)
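
# A compact sketch, using the same (older) neural_tangents `predict` API as
# the tests above, of the t=0 invariant they all assert: a linearized
# training-dynamics predictor must return the initial network outputs at time
# zero. The kernel matrices below are synthetic PSD stand-ins, not NTKs
# computed from a real network.
import jax.numpy as np
from neural_tangents import predict

_g_dd = np.eye(4) * 2.          # train/train kernel stand-in
_g_td = np.ones((2, 4)) * 0.1   # test/train kernel stand-in
_y_train = np.ones((4, 1))
_fx_train_0 = np.zeros((4, 1))
_fx_test_0 = np.zeros((2, 1))

_predictor = predict.gradient_descent_mse(_g_dd, _y_train, _g_td)
_fx_train_t, _fx_test_t = _predictor(0.0, _fx_train_0, _fx_test_0)
# At t=0 no training has happened, so the prediction is the initial output.
assert np.allclose(_fx_train_t, _fx_train_0)
assert np.allclose(_fx_test_t, _fx_test_0)
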
class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [ rng(shape, dtype) for shape, dtype in zip(shapes, dtypes) ] @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format( jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims, return_sign, use_b), # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU. "shapes": shapes, "dtype": dtype, "axis": axis, "keepdims": keepdims, "return_sign": return_sign, "use_b": use_b } for shape_group in compatible_shapes for dtype in float_dtypes for use_b in [False, True] for shapes in itertools.product( *((shape_group, shape_group) if use_b else (shape_group, ))) for axis in range( -max(len(shape) for shape in shapes), max(len(shape) for shape in shapes)) for keepdims in [False, True] for return_sign in [False, True])) @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*") def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b): if jtu.device_under_test() != "cpu": rng = jtu.rand_some_inf_and_nan(self.rng()) else: rng = jtu.rand_default(self.rng()) # TODO(mattjj): test autodiff if use_b: def scipy_fun(array_to_reduce, scale_array): return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign, b=scale_array) def lax_fun(array_to_reduce, scale_array): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign, b=scale_array) args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)] else: def scipy_fun(array_to_reduce): return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign) def lax_fun(array_to_reduce): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign) args_maker = lambda: [rng(shapes[0], dtype)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters( itertools.chain.from_iterable( jtu.cases_from_list( { "testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes), "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "test_autodiff": rec.test_autodiff, "nondiff_argnums": rec.nondiff_argnums, "scipy_op": getattr(osp_special, rec.name), "lax_op": getattr(lsp_special, rec.name) } for shapes in itertools.combinations_with_replacement( all_shapes, rec.nargs) for dtypes in (itertools.combinations_with_replacement( rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes))) for rec in JAX_SPECIAL_FUNCTION_RECORDS)) def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes, test_autodiff, nondiff_argnums): rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) args = args_maker() self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, check_dtypes=False) self._CompileAndCheck(lax_op, args_maker, rtol=1e-4) if test_autodiff: def partial_lax_op(*vals): list_args = list(vals) for i in nondiff_argnums: list_args.insert(i, args[i]) return lax_op(*list_args) assert list(nondiff_argnums) == sorted(set(nondiff_argnums)) diff_args = [ x for i, x in enumerate(args) if i not in nondiff_argnums ] jtu.check_grads(partial_lax_op, diff_args, order=1, atol=jtu.if_device_under_test("tpu", .1, 1e-3), rtol=.1, eps=1e-3) @parameterized.named_parameters( jtu.cases_from_list({ 
"testcase_name": "_inshape={}_d={}".format( jtu.format_shape_dtype_string(shape, dtype), d), "rng_factory": jtu.rand_positive, "shape": shape, "dtype": dtype, "d": d } for shape in all_shapes for dtype in float_dtypes for d in [1, 2, 5])) def testMultigammaln(self, rng_factory, shape, dtype, d): def scipy_fun(a): return osp_special.multigammaln(a, d) def lax_fun(a): return lsp_special.multigammaln(a, d) rng = rng_factory(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={ np.float32: 1e-3, np.float64: 1e-14 }) self._CompileAndCheck(lax_fun, args_maker) def testIssue980(self): x = np.full((4, ), -1e20, dtype=np.float32) self.assertAllClose(np.zeros((4, ), dtype=np.float32), lsp_special.expit(x)) def testIssue3758(self): x = np.array([1e5, 1e19, 1e10], dtype=np.float32) q = np.array([1., 40., 30.], dtype=np.float32) self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q)) def testXlogyShouldReturnZero(self): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): partial_xlogy = functools.partial(lsp_special.xlogy, 0.) self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False) def testXlog1pyShouldReturnZero(self): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.) self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False)
class SMapTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_no_type_static(self, spatial_dimension, dtype): harmonic = lambda dr, **kwargs: (dr - f32(1))**f32(2) disp, _ = space.free() metric = space.metric(disp) mapped = smap.bond(harmonic, metric, np.array([[0, 1], [0, 2]], i32)) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2])) self.assertAllClose(mapped(R), dtype(accum)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_no_type_dynamic(self, spatial_dimension, dtype): harmonic = lambda dr, **kwargs: (dr - f32(1))**f32(2) disp, _ = space.free() metric = space.metric(disp) mapped = smap.bond(harmonic, metric) bonds = np.array([[0, 1], [0, 2]], i32) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2])) self.assertAllClose(mapped(R, bonds), dtype(accum)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_type_static(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond(harmonic, metric, np.array([[0, 1], [0, 2]], i32), np.array([0, 1], i32), sigma=sigma) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic( metric(R[0], R[2]), 2) self.assertAllClose(mapped(R), dtype(accum)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_type_dynamic(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond(harmonic, metric, sigma=sigma) bonds = np.array([[0, 1], [0, 2]], i32) bond_types = np.array([0, 1], i32) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic( metric(R[0], R[2]), 2) self.assertAllClose(mapped(R, bonds, bond_types), dtype(accum)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_params_dynamic(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - 
sigma)**f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond(harmonic, metric, sigma=1.0) bonds = np.array([[0, 1], [0, 2]], i32) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic( metric(R[0], R[2]), 2) self.assertAllClose(mapped(R, bonds, sigma=sigma), dtype(accum)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_per_bond_static(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - sigma)**f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond(harmonic, metric, np.array([[0, 1], [0, 2]], i32), sigma=sigma) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic( metric(R[0], R[2]), 2) self.assertAllClose(mapped(R), dtype(accum)) def test_get_species_parameters(self): species = [(0, 0), (0, 1), (1, 0), (1, 1)] params = np.array([[2.0, 3.0], [3.0, 1.0]]) global_params = 3.0 self.assertAllClose(smap._get_species_parameters(params, species[0]), 2.0) self.assertAllClose(smap._get_species_parameters(params, species[1]), 3.0) self.assertAllClose(smap._get_species_parameters(params, species[2]), 3.0) self.assertAllClose(smap._get_species_parameters(params, species[3]), 1.0) for s in species: self.assertAllClose(smap._get_species_parameters(global_params, s), 3.0) def test_get_matrix_parameters(self): params = np.array([1.0, 2.0]) params_mat_test = np.array([[1.0, 1.5], [1.5, 2.0]]) params_mat = smap._get_matrix_parameters(params, lambda x, y: 0.5 * (x + y)) self.assertAllClose(params_mat, params_mat_test) params_mat_direct = np.array([[1.0, 2.0], [3.0, 4.0]]) self.assertAllClose( smap._get_matrix_parameters(params_mat_direct, None), params_mat_direct) params_scalar = 1.0 self.assertAllClose(smap._get_matrix_parameters(params_scalar, None), params_scalar) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_scalar(self, spatial_dimension, dtype): square = lambda dr: dr**2 displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric) metric = space.map_product(metric) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) self.assertAllClose( mapped_square(R), np.array(0.5 * np.sum(square(metric(R, R))), dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype): square = lambda dr, epsilon: epsilon * dr**2 displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, 
**kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric, epsilon=1.0) metric = space.map_product(metric) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split1, split2 = random.split(key, 3) R = random.uniform(split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) epsilon = random.uniform(split2, (PARTICLE_COUNT, ), dtype=dtype) mat_epsilon = 0.5 * (epsilon[:, np.newaxis] + epsilon[np.newaxis, :]) self.assertAllClose( mapped_square(R, epsilon=epsilon), np.array(0.5 * np.sum(square(metric(R, R), mat_epsilon)), dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_vector(self, spatial_dimension, dtype): square = lambda dr: np.sum(dr**2, axis=2) disp, _ = space.free() mapped_square = smap.pair(square, disp) disp = space.map_product(disp) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) mapped_ref = np.array(0.5 * np.sum(square(disp(R, R))), dtype=dtype) self.assertAllClose(mapped_square(R), mapped_ref) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_vector_nonadditive(self, spatial_dimension, dtype): square = lambda dr, params: params * np.sum(dr**2, axis=2) disp, _ = space.free() mapped_square = smap.pair(square, disp, params=lambda x, y: x * y) disp = space.map_product(disp) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, R_key, params_key = random.split(key, 3) R = random.uniform(R_key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) params = random.uniform(params_key, (PARTICLE_COUNT, ), dtype=dtype, minval=0.1, maxval=1.5) pp_params = params[None, :] * params[:, None] mapped_ref = np.array(0.5 * np.sum(square(disp(R, R), pp_params)), dtype=dtype) self.assertAllClose(mapped_square(R, params=params), mapped_ref) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_static_species_scalar(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * dr**2 params = f32(np.array([[1.0, 2.0], [2.0, 3.0]])) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT, ), 0, 2) displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric, species=species, param=params) metric = space.map_product(metric) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum( square(metric(R_1, R_2), param)) self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in 
POSITION_DTYPE)) def test_pair_static_species_scalar_dynamic(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * dr**2 key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT, ), 0, 2) displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric, species=species, param=1.0) metric = space.map_product(metric) for _ in range(STOCHASTIC_SAMPLES): key, split1, split2 = random.split(key, 3) R = random.uniform(split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) params = random.uniform(split2, (2, 2), dtype=dtype) params = f32(0.5) * (params + params.T) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum( square(metric(R_1, R_2), param)) self.assertAllClose(mapped_square(R, param=params), np.array(total, dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_scalar_dummy_arg(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=f32(1.0), **unused_kwargs: param * dr**2 key, split = random.split(key) R = random.normal(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) displacement, shift = space.free() mapped = smap.pair(square, space.metric(displacement)) mapped(R, t=f32(0)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_static_species_vector(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * np.sum(dr**2, axis=2) params = np.array([[1.0, 2.0], [2.0, 3.0]], dtype=f32) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT, ), 0, 2) disp, _ = space.free() mapped_square = smap.pair(square, disp, species=species, param=params) disp = space.map_product(disp) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(disp(R_1, R_2), param)) self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * dr**2 params = f32(np.array([[1.0, 2.0], [2.0, 3.0]])) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT, ), 0, 2) displacement, _ = space.free() metric = space.metric(displacement) mapped_square = smap.pair(square, metric, species=2, param=params) metric = space.map_product(metric) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = 
total + 0.5 * np.sum( square(metric(R_1, R_2), param)) self.assertAllClose(mapped_square(R, species), np.array(total, dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_dynamic_species_vector(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * np.sum(dr**2, axis=-1) params = f32(np.array([[1.0, 2.0], [2.0, 3.0]])) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT, ), 0, 2) disp, _ = space.free() mapped_square = smap.pair(square, disp, species=2, param=params) disp = vmap(vmap(disp, (0, None), 0), (None, 0), 0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(disp(R_1, R_2), param)) self.assertAllClose(mapped_square(R, species), np.array(total, dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_scalar(self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr**2, f32(0.)) N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 4. * N**(1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = smap.pair_neighbor_list(truncated_square, d, sigma=1.0) neighbor_square = jit(neighbor_square) mapped_square = jit(smap.pair(truncated_square, d, sigma=1.0)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (), minval=0.5, maxval=2.5) neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0, format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_scalar_diverging_potential( self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def potential(dr, sigma): return np.where(dr < sigma, dr**-6, f32(0.)) N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 4. * N**(1. 
/ spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = jit(smap.pair_neighbor_list(potential, d, sigma=1.0)) mapped_square = jit(smap.pair(potential, d, sigma=1.0)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (), minval=0.5, maxval=2.5) neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0, format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_force_scalar_diverging_potential( self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def potential(dr, sigma): return np.where(dr < sigma, dr**-6, f32(0.)) N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 4. * N**(1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = smap.pair_neighbor_list(potential, d, sigma=1.0) neighbor_square = jit(quantity.force(neighbor_square)) mapped_square = jit(quantity.force(smap.pair(potential, d, sigma=1.0))) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (), minval=0.5, maxval=4.5) neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0.0, format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_scalar_params_no_species( self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr**2, f32(0.)) N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 2. * N**(1. 
/ spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = smap.pair_neighbor_list(truncated_square, d, sigma=1.0) neighbor_square = jit(neighbor_square) mapped_square = jit(smap.pair(truncated_square, d, sigma=1.0)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5) neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma), 0., format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_scalar_params_matrix(self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr**2, f32(0.)) N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 2. * N**(1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = smap.pair_neighbor_list(truncated_square, d, sigma=1.0) neighbor_square = jit(neighbor_square) mapped_square = jit(smap.pair(truncated_square, d, sigma=1.0)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5) sigma = 0.5 * (sigma + sigma.T) neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma), 0., format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_scalar_params_species(self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr**2, f32(0.)) N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 2. * N**(1. 
/ spatial_dimension) species = np.zeros((N, ), np.int32) species = np.where(np.arange(N) > N / 3, 1, species) species = np.where(np.arange(N) > 2 * N / 3, 2, species) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = smap.pair_neighbor_list(truncated_square, d, species=species, sigma=1.0) neighbor_square = jit(neighbor_square) mapped_square = smap.pair(truncated_square, d, species=species, sigma=1.0) mapped_square = jit(mapped_square) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5) sigma = 0.5 * (sigma + sigma.T) neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma), 0., format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={str(format).split(".")[-1]}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_vector(self, spatial_dimension, dtype, format): if format is partition.OrderedSparse: self.skipTest('Vector valued pair_neighbor_list not supported.') key = random.PRNGKey(0) def truncated_square(dR, sigma): dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1, )) return np.where(dr < sigma, dR**2, f32(0.)) N = PARTICLE_COUNT box_size = 2. * N**(1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) neighbor_square = jit( smap.pair_neighbor_list(truncated_square, disp, sigma=1.0, reduce_axis=(1, ))) mapped_square = jit( smap.pair(truncated_square, disp, sigma=1.0, reduce_axis=(1, ))) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (), minval=0.5, maxval=1.5) neighbor_fn = partition.neighbor_list(disp, box_size, sigma, 0., format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_vector_nonadditive(self, spatial_dimension, dtype, format): if format is partition.OrderedSparse: self.skipTest('Vector valued pair_neighbor_list not supported.') key = random.PRNGKey(0) def truncated_square(dR, sigma): dr = space.distance(dR) return np.where(dr < sigma, dr**2, f32(0.)) N = PARTICLE_COUNT box_size = 2. * N**(1. 
/ spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) neighbor_square = jit( smap.pair_neighbor_list(truncated_square, disp, sigma=lambda x, y: x * y, reduce_axis=(1, ))) mapped_square = jit( smap.pair(truncated_square, disp, sigma=1.0, reduce_axis=(1, ))) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5) sigma_pair = sigma[:, None] * sigma[None, :] neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma)**2, 0., format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma_pair), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': (f'_dim={dim}_dtype={dtype.__name__}' f'_format={format}'), 'spatial_dimension': dim, 'dtype': dtype, 'format': format } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE for format in NEIGHBOR_LIST_FORMAT)) def test_pair_neighbor_list_scalar_nonadditive(self, spatial_dimension, dtype, format): key = random.PRNGKey(0) def truncated_square(dR, sigma): dr = space.distance(dR) return np.where(dr < sigma, dr**2, f32(0.)) N = PARTICLE_COUNT box_size = 2. * N**(1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) neighbor_square = jit( smap.pair_neighbor_list(truncated_square, disp, sigma=lambda x, y: x * y)) mapped_square = jit(smap.pair(truncated_square, disp, sigma=1.0)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (N, ), minval=0.5, maxval=1.5) sigma_pair = sigma[:, None] * sigma[None, :] neighbor_fn = partition.neighbor_list(disp, box_size, np.max(sigma)**2, 0., format=format) nbrs = neighbor_fn.allocate(R) self.assertAllClose(mapped_square(R, sigma=sigma_pair), neighbor_square(R, nbrs, sigma=sigma)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_triplet_no_species_scalar(self, spatial_dimension, dtype): key = random.PRNGKey(0) angle_fn = lambda dR1, dR2: np.sum(np.square(dR1) + np.square(dR2)) square = lambda dR: np.sum(np.square(dR)) displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) triplet_square = smap.triplet(angle_fn, displacement) metric = space.map_product(metric) count = PARTICLE_COUNT // 50 for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform(split, (count, spatial_dimension), dtype=dtype) self.assertAllClose( triplet_square(R) / count / 2., np.array(0.5 * np.sum(metric(R, R)), dtype=dtype)) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_triplet_static_species_scalar(self, spatial_dimension, dtype): key = random.PRNGKey(0) angle_fn = lambda dR1, dR2, param=5.0: param * np.sum(np.square(dR1)) square = lambda dR, param: param * np.sum(np.square(dR)) params = f32(np.array([[[1., 1.], [2., 0.]], [[0., 2.], [1., 1.]]])) count = PARTICLE_COUNT // 50 key, split = random.split(key) species = random.randint(split, (count, ), 0, 2) displacement, _ = 
    space.free()
    metric = lambda Ra, Rb, **kwargs: \
        np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1)
    triplet_square = smap.triplet(angle_fn,
                                  displacement,
                                  species=species,
                                  param=params,
                                  reduce_axis=None)
    metric = space.map_product(metric)
    for _ in range(STOCHASTIC_SAMPLES):
      key, split = random.split(key)
      R = random.uniform(split, (count, spatial_dimension), dtype=dtype)
      total = 0.
      for i in range(2):
        for j in range(2):
          R_1 = R[species == i]
          R_2 = R[species == j]
          total += 0.5 * np.sum(metric(R_1, R_2))
      self.assertAllClose(
          triplet_square(R) / count,
          np.array(total, dtype=dtype))
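
# A compact sketch of the contract the pair tests above verify, written
# against the same jax_md `smap`/`space` API they use: smap.pair promotes a
# two-body function to a total energy equal to half the sum of that function
# over all displacement pairs (the diagonal self-terms vanish for this
# potential, so including them is harmless).
import jax.numpy as np
from jax import random
from jax_md import smap, space

_displacement, _ = space.free()
_metric = space.metric(_displacement)           # pairwise distance function
_energy_fn = smap.pair(lambda dr: dr ** 2, _metric)

_R = random.uniform(random.PRNGKey(0), (8, 3))
_pair_dr = space.map_product(_metric)(_R, _R)   # (8, 8) distance matrix
assert np.allclose(_energy_fn(_R), 0.5 * np.sum(_pair_dr ** 2))
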
class SMapTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_no_type_static(self, spatial_dimension, dtype): harmonic = lambda dr, **kwargs: (dr - f32(1)) ** f32(2) disp, _ = space.free() metric = space.metric(disp) mapped = smap.bond(harmonic, metric, np.array([[0, 1], [0, 2]], i32)) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2])) self.assertAllClose(mapped(R), dtype(accum), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_no_type_dynamic(self, spatial_dimension, dtype): harmonic = lambda dr, **kwargs: (dr - f32(1)) ** f32(2) disp, _ = space.free() metric = space.metric(disp) mapped = smap.bond(harmonic, metric) bonds = np.array([[0, 1], [0, 2]], i32) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1])) + harmonic(metric(R[0], R[2])) self.assertAllClose(mapped(R, bonds), dtype(accum), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_type_static(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond( harmonic, metric, np.array([[0, 1], [0, 2]], i32), np.array([0, 1], i32), sigma=sigma) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2) self.assertAllClose(mapped(R), dtype(accum), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_type_dynamic(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond(harmonic, metric, sigma=sigma) bonds = np.array([[0, 1], [0, 2]], i32) bond_types = np.array([0, 1], i32) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2) self.assertAllClose(mapped(R, bonds, bond_types), dtype(accum), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_params_dynamic(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, 
**kwargs: (dr - sigma) ** f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond(harmonic, metric) bonds = np.array([[0, 1], [0, 2]], i32) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2) self.assertAllClose(mapped(R, bonds, sigma=sigma), dtype(accum), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_bond_per_bond_static(self, spatial_dimension, dtype): harmonic = lambda dr, sigma, **kwargs: (dr - sigma) ** f32(2) disp, _ = space.free() metric = space.metric(disp) sigma = np.array([1.0, 2.0], f32) mapped = smap.bond( harmonic, metric, np.array([[0, 1], [0, 2]], i32), sigma=sigma) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) accum = harmonic(metric(R[0], R[1]), 1) + harmonic(metric(R[0], R[2]), 2) self.assertAllClose(mapped(R), dtype(accum), True) def test_get_species_parameters(self): species = [(0, 0), (0, 1), (1, 0), (1, 1)] params = np.array([[2.0, 3.0], [3.0, 1.0]]) global_params = 3.0 self.assertAllClose( smap._get_species_parameters(params, species[0]), 2.0, True) self.assertAllClose( smap._get_species_parameters(params, species[1]), 3.0, True) self.assertAllClose( smap._get_species_parameters(params, species[2]), 3.0, True) self.assertAllClose( smap._get_species_parameters(params, species[3]), 1.0, True) for s in species: self.assertAllClose( smap._get_species_parameters(global_params, s), 3.0, True) def test_get_matrix_parameters(self): params = np.array([1.0, 2.0]) params_mat_test = np.array([[1.0, 1.5], [1.5, 2.0]]) params_mat = smap._get_matrix_parameters(params) self.assertAllClose(params_mat, params_mat_test, True) params_mat_direct = np.array([[1.0, 2.0], [3.0, 4.0]]) self.assertAllClose( smap._get_matrix_parameters(params_mat_direct), params_mat_direct, True) params_scalar = 1.0 self.assertAllClose( smap._get_matrix_parameters(params_scalar), params_scalar, True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_scalar(self, spatial_dimension, dtype): square = lambda dr: dr ** 2 displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric) metric = space.map_product(metric) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) self.assertAllClose( mapped_square(R), np.array(0.5 * np.sum(square(metric(R, R))), dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_scalar_dynamic(self, spatial_dimension, dtype): square = lambda dr, epsilon: epsilon * dr ** 2 displacement, _ = space.free() metric = lambda Ra, Rb, 
**kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric) metric = space.map_product(metric) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split1, split2 = random.split(key, 3) R = random.uniform( split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) epsilon = random.uniform(split2, (PARTICLE_COUNT,), dtype=dtype) mat_epsilon = 0.5 * (epsilon[:, np.newaxis] + epsilon[np.newaxis, :]) self.assertAllClose( mapped_square(R, epsilon=epsilon), np.array(0.5 * np.sum( square(metric(R, R), mat_epsilon)), dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_no_species_vector(self, spatial_dimension, dtype): square = lambda dr: np.sum(dr ** 2, axis=2) disp, _ = space.free() mapped_square = smap.pair(square, disp) disp = space.map_product(disp) key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) mapped_ref = np.array(0.5 * np.sum(square(disp(R, R))), dtype=dtype) self.assertAllClose(mapped_square(R), mapped_ref, True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_static_species_scalar(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * dr ** 2 params = f32(np.array([[1.0, 2.0], [2.0, 3.0]])) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT,), 0, 2) displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair( square, metric, species=species, param=params) metric = space.map_product(metric) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(metric(R_1, R_2), param)) self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_static_species_scalar_dynamic(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * dr ** 2 key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT,), 0, 2) displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair(square, metric, species=species) metric = space.map_product(metric) for _ in range(STOCHASTIC_SAMPLES): key, split1, split2 = random.split(key, 3) R = random.uniform( split1, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) params = random.uniform( split2, (2, 2), dtype=dtype) params = f32(0.5) * (params + params.T) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(metric(R_1, R_2), param)) 
self.assertAllClose( mapped_square(R, param=params), np.array(total, dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_scalar_dummy_arg( self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=f32(1.0), **unused_kwargs: param * dr ** 2 key, split = random.split(key) R = random.normal(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) displacement, shift = space.free() mapped = smap.pair(square, space.metric(displacement)) mapped(R, t=f32(0)) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_static_species_vector(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * np.sum(dr ** 2, axis=2) params = np.array([[1.0, 2.0], [2.0, 3.0]], dtype=f32) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT,), 0, 2) disp, _ = space.free() mapped_square = smap.pair(square, disp, species=species, param=params) disp = space.map_product(disp) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(disp(R_1, R_2), param)) self.assertAllClose(mapped_square(R), np.array(total, dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_dynamic_species_scalar(self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * dr ** 2 params = f32(np.array([[1.0, 2.0], [2.0, 3.0]])) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT,), 0, 2) displacement, _ = space.free() metric = lambda Ra, Rb, **kwargs: \ np.sum(displacement(Ra, Rb, **kwargs) ** 2, axis=-1) mapped_square = smap.pair( square, metric, species=quantity.Dynamic, param=params) metric = space.map_product(metric) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(metric(R_1, R_2), param)) self.assertAllClose( mapped_square(R, species, 2), np.array(total, dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_dynamic_species_vector( self, spatial_dimension, dtype): key = random.PRNGKey(0) square = lambda dr, param=1.0: param * np.sum(dr ** 2, axis=2) params = f32(np.array([[1.0, 2.0], [2.0, 3.0]])) key, split = random.split(key) species = random.randint(split, (PARTICLE_COUNT,), 0, 2) disp, _ = space.free() mapped_square = smap.pair( square, disp, species=quantity.Dynamic, param=params) disp = vmap(vmap(disp, (0, None), 0), (None, 0), 0) for _ in 
range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = random.uniform( split, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) total = 0.0 for i in range(2): for j in range(2): param = params[i, j] R_1 = R[species == i] R_2 = R[species == j] total = total + 0.5 * np.sum(square(disp(R_1, R_2), param)) self.assertAllClose( mapped_square(R, species, 2), np.array(total, dtype=dtype), True) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_neighbor_list_scalar( self, spatial_dimension, dtype): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr ** 2, f32(0.)) tol = 2e-10 if dtype == np.float32 else None N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 4. * N ** (1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d)) mapped_square = jit(smap.pair(truncated_square, d)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform( split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (), minval=0.5, maxval=2.5) neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R)) idx = neighbor_fn(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, idx, sigma=sigma), True, tol, tol) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_neighbor_list_scalar_params_no_species( self, spatial_dimension, dtype): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr ** 2, f32(0.)) tol = 2e-10 if dtype == np.float32 else None N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 2. * N ** (1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d)) mapped_square = jit(smap.pair(truncated_square, d)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (N,), minval=0.5, maxval=1.5) neighbor_fn = jit( partition.neighbor_list(disp, box_size, np.max(sigma), R)) idx = neighbor_fn(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, idx, sigma=sigma), True, tol, tol) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_neighbor_list_scalar_params_matrix( self, spatial_dimension, dtype): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr ** 2, f32(0.)) tol = 2e-10 if dtype == np.float32 else None N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 2. * N ** (1. 
/ spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = jit(smap.pair_neighbor_list(truncated_square, d)) mapped_square = jit(smap.pair(truncated_square, d)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (N, N), minval=0.5, maxval=1.5) sigma = 0.5 * (sigma + sigma.T) neighbor_fn = jit( partition.neighbor_list(disp, box_size, np.max(sigma), R)) idx = neighbor_fn(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, idx, sigma=sigma), True, tol, tol) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_neighbor_list_scalar_params_species( self, spatial_dimension, dtype): key = random.PRNGKey(0) def truncated_square(dr, sigma): return np.where(dr < sigma, dr ** 2, f32(0.)) tol = 2e-10 if dtype == np.float32 else None N = NEIGHBOR_LIST_PARTICLE_COUNT box_size = 2. * N ** (1. / spatial_dimension) species = np.zeros((N,), np.int32) species = np.where(np.arange(N) > N / 3, 1, species) species = np.where(np.arange(N) > 2 * N / 3, 2, species) key, split = random.split(key) disp, _ = space.periodic(box_size) d = space.metric(disp) neighbor_square = jit( smap.pair_neighbor_list(truncated_square, d, species=species)) mapped_square = jit(smap.pair(truncated_square, d, species=species)) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform(split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (3, 3), minval=0.5, maxval=1.5) sigma = 0.5 * (sigma + sigma.T) neighbor_fn = jit( partition.neighbor_list(disp, box_size, np.max(sigma), R)) idx = neighbor_fn(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, idx, sigma=sigma), True, tol, tol) @parameterized.named_parameters(jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format(dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_pair_neighbor_list_vector(self, spatial_dimension, dtype): key = random.PRNGKey(0) def truncated_square(dR, sigma): dr = np.reshape(space.distance(dR), dR.shape[:-1] + (1,)) return np.where(dr < sigma, dR ** 2, f32(0.)) tol = 5e-6 if dtype == np.float32 else 1e-14 N = PARTICLE_COUNT box_size = 2. * N ** (1. / spatial_dimension) key, split = random.split(key) disp, _ = space.periodic(box_size) neighbor_square = jit(smap.pair_neighbor_list( truncated_square, disp, reduce_axis=(1,))) mapped_square = jit(smap.pair(truncated_square, disp, reduce_axis=(1,))) for _ in range(STOCHASTIC_SAMPLES): key, split = random.split(key) R = box_size * random.uniform( split, (N, spatial_dimension), dtype=dtype) sigma = random.uniform(key, (), minval=0.5, maxval=1.5) neighbor_fn = jit(partition.neighbor_list(disp, box_size, sigma, R)) idx = neighbor_fn(R) self.assertAllClose(mapped_square(R, sigma=sigma), neighbor_square(R, idx, sigma=sigma), True, tol, tol)
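# A minimal illustrative sketch of the pattern exercised by the pair tests
# above, using only modules already used in this file (space, smap, random,
# f32, with np = jax.numpy). The pair potential `harmonic` below is a
# made-up example, not part of the library.
def _example_smap_pair():
  displacement, _ = space.free()
  metric = space.metric(displacement)  # (Ra, Rb) -> scalar distance

  def harmonic(dr, sigma=f32(1.0)):
    # A toy pair potential of the scalar pairwise distance dr.
    return (dr - sigma) ** 2

  # smap.pair sums the potential over all pairs, halving the double count;
  # keyword arguments (here sigma) are threaded through to the potential,
  # just as `mapped_square(R, epsilon=epsilon)` does in the tests above.
  energy_fn = smap.pair(harmonic, metric)
  R = random.uniform(random.PRNGKey(0), (16, 3), dtype=f32)
  return energy_fn(R, sigma=f32(1.5))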
class BatchTest(jtu.JaxTestCase): @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format( train, test, network, name, batch_size), 'train_shape': train, 'test_shape': test, 'network': network, 'name': name, 'kernel_fn': kernel_fn, 'batch_size': batch_size } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for name, kernel_fn in KERNELS.items() for batch_size in [2, 8])) def testSerial(self, train_shape, test_shape, network, name, kernel_fn, batch_size): key = random.PRNGKey(0) key, self_split, other_split = random.split(key, 3) data_self = random.normal(self_split, train_shape) data_other = random.normal(other_split, test_shape) kernel_fn = kernel_fn(key, train_shape[1:], network) kernel_batched = batch._serial(kernel_fn, batch_size=batch_size) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other) # We also exclude tests for dropout + parallel. It is not clear what is the # best way to handle this case. @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train_shape={}_test_shape={}_network={}_{}'.format( train, test, network, name), 'train_shape': train, 'test_shape': test, 'network': network, 'name': name, 'kernel_fn': kernel_fn } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for name, kernel_fn in KERNELS.items())) def testParallel(self, train_shape, test_shape, network, name, kernel_fn): utils.stub_out_pmap(batch, 2) key = random.PRNGKey(0) key, self_split, other_split = random.split(key, 3) data_self = random.normal(self_split, train_shape) data_other = random.normal(other_split, test_shape) kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False) kernel_batched = batch._parallel(kernel_fn) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other, True) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format( train, test, network, name, batch_size), 'train_shape': train, 'test_shape': test, 'network': network, 'name': name, 'kernel_fn': kernel_fn, 'batch_size': batch_size } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for name, kernel_fn in KERNELS.items() for batch_size in [2, 8])) def testComposition(self, train_shape, test_shape, network, name, kernel_fn, batch_size): utils.stub_out_pmap(batch, 2) key = random.PRNGKey(0) key, self_split, other_split = random.split(key, 3) data_self = random.normal(self_split, train_shape) data_other = random.normal(other_split, test_shape) kernel_fn = kernel_fn(key, train_shape[1:], network) kernel_batched = batch._parallel( batch._serial(kernel_fn, batch_size=batch_size)) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other) kernel_batched = batch._serial(batch._parallel(kernel_fn), batch_size=batch_size) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_train_shape={}_test_shape={}_network={}_{}_batch_size={}'.format( train, test, network, name, batch_size), 'train_shape': train, 'test_shape': test, 'network': network, 'name': name, 'kernel_fn': kernel_fn, 'batch_size': batch_size } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK) for name, kernel_fn in KERNELS.items() for batch_size in [2, 8])) def testAutomatic(self, train_shape, test_shape, network, 
name, kernel_fn, batch_size): utils.stub_out_pmap(batch, 2) key = random.PRNGKey(0) key, self_split, other_split = random.split(key, 3) data_self = random.normal(self_split, train_shape) data_other = random.normal(other_split, test_shape) kernel_fn = kernel_fn(key, train_shape[1:], network) kernel_batched = batch.batch(kernel_fn, batch_size=batch_size) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other) kernel_batched = batch.batch(kernel_fn, batch_size=batch_size, store_on_device=False) _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self, data_other) def _test_analytic_kernel_composition(self, batching_fn): # Check Fully-Connected. rng = random.PRNGKey(0) rng_self, rng_other = random.split(rng) x_self = random.normal(rng_self, (8, 10)) x_other = random.normal(rng_other, (2, 10)) Block = stax.serial(stax.Dense(256), stax.Relu()) _, _, ker_fn = Block ker_fn = batching_fn(ker_fn) _, _, composed_ker_fn = stax.serial(Block, Block) ker_out = ker_fn(ker_fn(x_self)) composed_ker_out = composed_ker_fn(x_self) if batching_fn == batch._parallel: # In the parallel setting, `x1_is_x2` is not computed correctly # when x1==x2. composed_ker_out = composed_ker_out._replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out, True) ker_out = ker_fn(ker_fn(x_self, x_other)) composed_ker_out = composed_ker_fn(x_self, x_other) if batching_fn == batch._parallel: composed_ker_out = composed_ker_out._replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out, True) # Check convolutional + pooling. x_self = random.normal(rng, (8, 10, 10, 3)) x_other = random.normal(rng, (2, 10, 10, 3)) Block = stax.serial(stax.Conv(256, (2, 2)), stax.Relu()) Readout = stax.serial(stax.GlobalAvgPool(), stax.Dense(10)) block_ker_fn, readout_ker_fn = Block[2], Readout[2] _, _, composed_ker_fn = stax.serial(Block, Readout) block_ker_fn = batching_fn(block_ker_fn) readout_ker_fn = batching_fn(readout_ker_fn) ker_out = readout_ker_fn(block_ker_fn(x_self, marginalization='none')) composed_ker_out = composed_ker_fn(x_self) if batching_fn == batch._parallel: composed_ker_out = composed_ker_out._replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out, True) ker_out = readout_ker_fn( block_ker_fn(x_self, x_other, marginalization='none')) composed_ker_out = composed_ker_fn(x_self, x_other) if batching_fn == batch._parallel: composed_ker_out = composed_ker_out._replace( x1_is_x2=ker_out.x1_is_x2) self.assertAllClose(ker_out, composed_ker_out, True) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_on_device={}_batch_size={}'.format(store_on_device, batch_size), 'store_on_device': store_on_device, 'batch_size': batch_size } for store_on_device in [True, False] for batch_size in [2, 8])) def testAnalyticKernelComposeSerial(self, store_on_device, batch_size): self._test_analytic_kernel_composition( partial(batch._serial, batch_size=batch_size, store_on_device=store_on_device)) def testAnalyticKernelComposeParallel(self): utils.stub_out_pmap(batch, 2) self._test_analytic_kernel_composition(batch._parallel) @jtu.parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_on_device={}_batch_size={}'.format(store_on_device, batch_size), 'store_on_device': store_on_device, 'batch_size': batch_size } for store_on_device in [True, False] for batch_size in [2, 8])) def testAnalyticKernelComposeAutomatic(self, store_on_device, batch_size): utils.stub_out_pmap(batch, 2) 
self._test_analytic_kernel_composition( partial(batch.batch, batch_size=batch_size, store_on_device=store_on_device)) def test_jit_or_pmap_broadcast(self): def kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.65): res = np.abs(np.matmul(x1, x2)) if do_square: res *= res if do_flip: res = -res res *= random.uniform(keys) * p return [res, params] params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5]))) x2 = np.arange(0, 10).reshape((10, )) keys = random.PRNGKey(1) kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=0) x1 = np.arange(0, 10).reshape((1, 10)) for do_flip in [True, False]: for do_square in [True, False]: with self.subTest(do_flip=do_flip, do_square=do_square, device_count=0): res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=True, p=0.65) res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params, _unused=True) self.assertAllClose(res_1, res_2, True) utils.stub_out_pmap(batch, 1) x1 = np.arange(0, 10).reshape((1, 10)) kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=1) for do_flip in [True, False]: for do_square in [True, False]: with self.subTest(do_flip=do_flip, do_square=do_square, device_count=1): res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, _unused=False, p=0.65) res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params, _unused=None) self.assertAllClose(res_1[0], res_2[0], True) self.assertAllClose( tree_map(partial(np.expand_dims, axis=0), res_1[1]), res_2[1], True) kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn, device_count=2) x1 = np.arange(0, 20).reshape((2, 10)) utils.stub_out_pmap(batch, 2) def broadcast(arg): return np.broadcast_to(arg, (2, ) + arg.shape) for do_flip in [True, False]: for do_square in [True, False]: with self.subTest(do_flip=do_flip, do_square=do_square, device_count=2): res_1 = kernel_fn(x1, x2, do_flip, keys, do_square, params, p=0.2) res_2 = kernel_fn_pmapped(x1, x2, do_flip, keys, do_square, params, _unused=None, p=0.2) self.assertAllClose(res_1[0][0], res_2[0][0], True) self.assertAllClose(res_1[0][1], res_2[0][1], True) self.assertAllClose(tree_map(broadcast, res_1[1]), res_2[1], True)
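# A minimal sketch of what the batching tests above check, assuming the
# neural-tangents style modules already used in this file (stax, batch,
# random). batch.batch wraps an analytic kernel function so that it is
# evaluated in column blocks of size batch_size (and across devices when
# several are present); its output should match the unbatched kernel.
def _example_batched_kernel():
  _, _, kernel_fn = stax.serial(stax.Dense(128), stax.Relu(), stax.Dense(1))
  kernel_batched = batch.batch(kernel_fn, batch_size=4)
  x1 = random.normal(random.PRNGKey(0), (8, 10))
  x2 = random.normal(random.PRNGKey(1), (6, 10))
  # Same kernel, different scheduling: the two results should agree.
  return kernel_fn(x1, x2), kernel_batched(x1, x2)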
class ImageTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_target={}_method={}_antialias={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method, antialias), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method, "antialias": antialias } for dtype in float_dtypes for target_shape, image_shape in itertools. combinations_with_replacement([[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2) for method in ["nearest", "bilinear", "lanczos3", "lanczos5", "bicubic"] for antialias in [False, True])) @unittest.skipIf(not tf, "Test requires TensorFlow") def testResizeAgainstTensorFlow(self, dtype, image_shape, target_shape, method, antialias): # TODO(phawkins): debug this. There is a small mismatch between TF and JAX # for some cases of non-antialiased bicubic downscaling; we would expect # exact equality. if method == "bicubic" and any( x < y for x, y in zip(target_shape, image_shape)): raise unittest.SkipTest( "non-antialiased bicubic downscaling mismatch") rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(image_shape, dtype), ) def tf_fn(x): out = tf.image.resize(x.astype(np.float64), tf.constant(target_shape[1:-1]), method=method, antialias=antialias).numpy().astype(dtype) return out jax_fn = partial(image.resize, shape=target_shape, method=method, antialias=antialias) self._CheckAgainstNumpy(tf_fn, jax_fn, args_maker, check_dtypes=True, tol={ np.float16: 2e-2, jnp.bfloat16: 1e-1, np.float32: 1e-4, np.float64: 1e-4 }) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method } for dtype in [np.float32] for target_shape, image_shape in itertools.combinations_with_replacement( [[3, 2], [6, 4], [33, 17], [50, 39]], 2) for method in ["nearest", "bilinear", "lanczos3", "bicubic"])) @unittest.skipIf(not PIL_Image, "Test requires PIL") def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method): rng = jtu.rand_uniform(self.rng()) args_maker = lambda: (rng(image_shape, dtype), ) def pil_fn(x): pil_methods = { "nearest": PIL_Image.NEAREST, "bilinear": PIL_Image.BILINEAR, "bicubic": PIL_Image.BICUBIC, "lanczos3": PIL_Image.LANCZOS, } img = PIL_Image.fromarray(x.astype(np.float32)) out = np.asarray(img.resize(target_shape[::-1], pil_methods[method]), dtype=dtype) return out jax_fn = partial(image.resize, shape=target_shape, method=method, antialias=True) self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method } for dtype in inexact_dtypes for image_shape, target_shape in [ ([3, 1, 2], [6, 1, 4]), ([1, 3, 2, 1], [1, 6, 4, 1]), ] for method in ["nearest", "linear", "lanczos3", "lanczos5", "cubic"]) ) def testResizeUp(self, dtype, image_shape, target_shape, method): data = [64, 32, 32, 64, 50, 100] expected_data = {} expected_data["nearest"] = [ 64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 32.0, 32.0, 64.0, 64.0, 32.0, 32.0, 64.0, 64.0, 
50.0, 50.0, 100.0, 100.0, 50.0, 50.0, 100.0, 100.0 ] expected_data["linear"] = [ 64.0, 56.0, 40.0, 32.0, 56.0, 52.0, 44.0, 40.0, 40.0, 44.0, 52.0, 56.0, 36.5, 45.625, 63.875, 73.0, 45.5, 56.875, 79.625, 91.0, 50.0, 62.5, 87.5, 100.0 ] expected_data["lanczos3"] = [ 75.8294, 59.6281, 38.4313, 22.23, 60.6851, 52.0037, 40.6454, 31.964, 35.8344, 41.0779, 47.9383, 53.1818, 24.6968, 43.0769, 67.1244, 85.5045, 35.7939, 56.4713, 83.5243, 104.2017, 44.8138, 65.1949, 91.8603, 112.2413 ] expected_data["lanczos5"] = [ 77.5699, 60.0223, 40.6694, 23.1219, 61.8253, 51.2369, 39.5593, 28.9709, 35.7438, 40.8875, 46.5604, 51.7041, 21.5942, 43.5299, 67.7223, 89.658, 32.1213, 56.784, 83.984, 108.6467, 44.5802, 66.183, 90.0082, 111.6109 ] expected_data["cubic"] = [ 70.1453, 59.0252, 36.9748, 25.8547, 59.3195, 53.3386, 41.4789, 35.4981, 36.383, 41.285, 51.0051, 55.9071, 30.2232, 42.151, 65.8032, 77.731, 41.6492, 55.823, 83.9288, 98.1026, 47.0363, 62.2744, 92.4903, 107.7284 ] x = np.array(data, dtype=dtype).reshape(image_shape) output = image.resize(x, target_shape, method) expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape) self.assertAllClose(output, expected, atol=1e-04) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_target={}_method={}_antialias={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method, antialias), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "method": method, "antialias": antialias } for dtype in [np.float32] for target_shape, image_shape in itertools. combinations_with_replacement([[2, 3, 2, 4], [2, 6, 4, 4], [2, 33, 17, 4], [2, 50, 38, 4]], 2) for method in ["bilinear", "lanczos3", "lanczos5", "bicubic"] for antialias in [False, True])) def testResizeGradients(self, dtype, image_shape, target_shape, method, antialias): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng(image_shape, dtype), ) jax_fn = partial(image.resize, shape=target_shape, method=method, antialias=antialias) jtu.check_grads(jax_fn, args_maker(), order=2, rtol=1e-2, eps=1.) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_target={}_method={}".format( jtu.format_shape_dtype_string(image_shape, dtype), jtu.format_shape_dtype_string(target_shape, dtype), method), "dtype": dtype, "image_shape": image_shape, "target_shape": target_shape, "scale": scale, "translation": translation, "method": method } for dtype in inexact_dtypes for image_shape, target_shape, scale, translation in [([3, 1, 2], [6, 1, 4], [2.0, 1.0, 2.0], [1.0, 0.0, -1.0]), ([1, 3, 2, 1], [1, 6, 4, 1], [1.0, 2.0, 2.0, 1.0], [0.0, 1.0, -1.0, 0.0])] for method in ["linear", "lanczos3", "lanczos5", "cubic"])) def testScaleAndTranslateUp(self, dtype, image_shape, target_shape, scale, translation, method): data = [64, 32, 32, 64, 50, 100] # Note zeros occur in the output because the sampling location is outside # the boundaries of the input image. 
expected_data = {} expected_data["linear"] = [ 0.0, 0.0, 0.0, 0.0, 56.0, 40.0, 32.0, 0.0, 52.0, 44.0, 40.0, 0.0, 44.0, 52.0, 56.0, 0.0, 45.625, 63.875, 73.0, 0.0, 56.875, 79.625, 91.0, 0.0 ] expected_data["lanczos3"] = [ 0.0, 0.0, 0.0, 0.0, 59.6281, 38.4313, 22.23, 0.0, 52.0037, 40.6454, 31.964, 0.0, 41.0779, 47.9383, 53.1818, 0.0, 43.0769, 67.1244, 85.5045, 0.0, 56.4713, 83.5243, 104.2017, 0.0 ] expected_data["lanczos5"] = [ 0.0, 0.0, 0.0, 0.0, 60.0223, 40.6694, 23.1219, 0.0, 51.2369, 39.5593, 28.9709, 0.0, 40.8875, 46.5604, 51.7041, 0.0, 43.5299, 67.7223, 89.658, 0.0, 56.784, 83.984, 108.6467, 0.0 ] expected_data["cubic"] = [ 0.0, 0.0, 0.0, 0.0, 59.0252, 36.9748, 25.8547, 0.0, 53.3386, 41.4789, 35.4981, 0.0, 41.285, 51.0051, 55.9071, 0.0, 42.151, 65.8032, 77.731, 0.0, 55.823, 83.9288, 98.1026, 0.0 ] x = np.array(data, dtype=dtype).reshape(image_shape) # Should we test different float types here? scale_a = jnp.array(scale, dtype=jnp.float32) translation_a = jnp.array(translation, dtype=jnp.float32) output = image.scale_and_translate(x, target_shape, range(len(image_shape)), scale_a, translation_a, method) expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape) self.assertAllClose(output, expected, atol=2e-03) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_dtype={}_method={}_antialias={}".format( jtu.dtype_str(dtype), method, antialias), "dtype": dtype, "method": method, "antialias": antialias } for dtype in inexact_dtypes for method in ["linear", "lanczos3", "lanczos5", "cubic"] for antialias in [True, False])) def testScaleAndTranslateDown(self, dtype, method, antialias): image_shape = [1, 6, 7, 1] target_shape = [1, 3, 3, 1] data = [ 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, 71, 32, 23, 23, 35, 93 ] if antialias: expected_data = {} expected_data["linear"] = [ 43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0 ] expected_data["lanczos3"] = [ 43.2884, 57.9091, 54.6439, 48.5856, 58.2427, 53.7551, 0, 0, 0 ] expected_data["lanczos5"] = [ 43.9209, 57.6360, 54.9575, 48.9272, 58.1865, 53.1948, 0, 0, 0 ] expected_data["cubic"] = [ 42.9935, 59.1687, 54.2138, 48.2640, 58.2678, 54.4088, 0, 0, 0 ] else: expected_data = {} expected_data["linear"] = [ 43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0 ] expected_data["lanczos3"] = [ 44.1390, 87.8786, 63.3111, 25.1161, 20.8795, 53.6165, 0, 0, 0 ] expected_data["lanczos5"] = [ 44.8835, 85.5896, 66.7231, 16.9983, 19.8891, 47.1446, 0, 0, 0 ] expected_data["cubic"] = [ 43.6426, 88.8854, 60.6638, 31.4685, 22.1204, 58.3457, 0, 0, 0 ] x = np.array(data, dtype=dtype).reshape(image_shape) expected = np.array(expected_data[method], dtype=dtype).reshape(target_shape) scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) output = image.scale_and_translate(x, target_shape, (0, 1, 2, 3), scale_a, translation_a, method, antialias=antialias) self.assertAllClose(output, expected, atol=2e-03) # Test that running with just the subset of dimensions that have # non-trivial scale and translation produces the same result.
output = image.scale_and_translate(x, target_shape, (1, 2), scale_a[1:3], translation_a[1:3], method, antialias=antialias) self.assertAllClose(output, expected, atol=2e-03) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "antialias={}".format(antialias), "antialias": antialias } for antialias in [True, False])) def testScaleAndTranslateJITs(self, antialias): image_shape = [1, 6, 7, 1] target_shape = [1, 3, 3, 1] data = [ 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, 71, 32, 23, 23, 35, 93 ] if antialias: expected_data = [ 43.5372, 59.3694, 53.6907, 49.3221, 56.8168, 55.4849, 0, 0, 0 ] else: expected_data = [ 43.6071, 89, 59, 37.1785, 27.2857, 58.3571, 0, 0, 0 ] x = jnp.array(data, dtype=jnp.float32).reshape(image_shape) expected = jnp.array(expected_data, dtype=jnp.float32).reshape(target_shape) scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) def jit_fn(in_array, s, t): return jax.image.scale_and_translate( in_array, target_shape, (0, 1, 2, 3), s, t, "linear", antialias, precision=jax.lax.Precision.HIGHEST) output = jax.jit(jit_fn)(x, scale_a, translation_a) self.assertAllClose(output, expected, atol=2e-03) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "antialias={}".format(antialias), "antialias": antialias } for antialias in [True, False])) def testScaleAndTranslateGradFinite(self, antialias): image_shape = [1, 6, 7, 1] target_shape = [1, 3, 3, 1] data = [ 51, 38, 32, 89, 41, 21, 97, 51, 33, 87, 89, 34, 21, 97, 43, 25, 25, 92, 41, 11, 84, 11, 55, 111, 23, 99, 50, 83, 13, 92, 52, 43, 90, 43, 14, 89, 71, 32, 23, 23, 35, 93 ] x = jnp.array(data, dtype=jnp.float32).reshape(image_shape) scale_a = jnp.array([1.0, 0.35, 0.4, 1.0], dtype=jnp.float32) translation_a = jnp.array([0.0, 0.2, 0.1, 0.0], dtype=jnp.float32) def scale_fn(s): return jnp.sum( jax.image.scale_and_translate( x, target_shape, (0, 1, 2, 3), s, translation_a, "linear", antialias, precision=jax.lax.Precision.HIGHEST)) scale_out = jax.grad(scale_fn)(scale_a) self.assertTrue(jnp.all(jnp.isfinite(scale_out))) def translate_fn(t): return jnp.sum( jax.image.scale_and_translate( x, target_shape, (0, 1, 2, 3), scale_a, t, "linear", antialias, precision=jax.lax.Precision.HIGHEST)) translate_out = jax.grad(translate_fn)(translation_a) self.assertTrue(jnp.all(jnp.isfinite(translate_out)))
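# A minimal sketch of the two APIs exercised above, assuming only the
# jnp/image imports of this file. image.resize resamples to a target shape
# with a named kernel; scale_and_translate generalizes it with an explicit
# per-axis scale and translation.
def _example_resize():
  x = jnp.arange(12., dtype=jnp.float32).reshape(3, 4)
  # Upscale 3x4 -> 6x8 with the separable triangle ("linear") kernel.
  up = image.resize(x, (6, 8), method="linear")
  # Downscale back; antialias=True low-pass filters before resampling,
  # which only matters when downscaling.
  down = image.resize(up, (3, 4), method="linear", antialias=True)
  return up, down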
class Jax2TfTest(tf_test_util.JaxToTfTestCase): def test_basics(self): f_jax = lambda x: jnp.sin(jnp.cos(x)) _, res_tf = self.ConvertAndCompare(f_jax, jnp.float_(0.7)) def test_input_output_naming(self): @jax2tf.convert def f(xs, y): return [jnp.add(x, y) for x in xs] @tf.function(autograph=False) def u(xs, y): xs = tf.nest.map_structure(tf.convert_to_tensor, xs) with tf.GradientTape() as tape: tf.nest.map_structure(tape.watch, xs) y = f(xs, y) tape.gradient(y, xs) return y cf = u.get_concrete_function([1., 2., 3.], 4.) g = cf.graph g.get_operation_by_name("jax2tf_arg_0") g.get_operation_by_name("jax2tf_arg_0_1") g.get_operation_by_name("jax2tf_arg_0_2") g.get_operation_by_name("jax2tf_arg_1") g.get_operation_by_name("jax2tf_out") g.get_operation_by_name("jax2tf_out_1") g.get_operation_by_name("jax2tf_out_2") with self.assertRaises(KeyError): g.get_operation_by_name("jax2tf_arg_2") with self.assertRaises(KeyError): g.get_operation_by_name("jax2tf_out_3") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_0") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1_1") g.get_operation_by_name("jax2tf_vjp/jax2tf_arg_1_2") g.get_operation_by_name("jax2tf_vjp/jax2tf_out") g.get_operation_by_name("jax2tf_vjp/jax2tf_out_1") g.get_operation_by_name("jax2tf_vjp/jax2tf_out_2") g.get_operation_by_name("jax2tf_vjp/jax2tf_out_3") def test_pytrees(self): # Take and return pytrees def f_jax( x: Tuple[float, Dict[str, float]]) -> Tuple[float, Dict[str, float]]: x_a, x_dict = x return x_a * 2., {k: v * 3. for k, v in x_dict.items()} x = (jnp.float_(.7), {"a": jnp.float_(.8), "b": jnp.float_(.9)}) self.ConvertAndCompare(f_jax, x) def test_variable_input(self): f_jax = lambda x: jnp.sin(jnp.cos(x)) f_tf = jax2tf.convert(f_jax) v = tf.Variable(0.7, dtype=dtypes.canonicalize_dtype(jnp.float_)) self.assertIsInstance(f_tf(v), tf.Tensor) self.assertAllClose(f_jax(0.7), f_tf(v)) def test_jit(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) self.ConvertAndCompare(f_jax, jnp.float_(0.7)) def test_nested_jit(self): f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x))) f_tf = jax2tf.convert(f_jax) np.testing.assert_allclose(f_jax(0.7), f_tf(0.7)) def test_converts_jax_arrays(self): f_tf = tf.function(lambda x: x) self.assertEqual(f_tf(jnp.zeros([])).numpy(), 0.) self.assertEqual(f_tf(jnp.ones([])).numpy(), 1.) f_tf = tf.function(lambda x: x + x) self.assertEqual(f_tf(jnp.ones([])).numpy(), 2.) # Test with ShardedDeviceArray. 
n = jax.local_device_count() mk_sharded = lambda f: jax.pmap(lambda x: x)(f([n])) f_tf = tf.function(lambda x: x) self.assertAllClose(f_tf(mk_sharded(jnp.zeros)).numpy(), np.zeros([n])) self.assertAllClose(f_tf(mk_sharded(jnp.ones)).numpy(), np.ones([n])) @jtu.skip_on_devices("gpu") def test_bfloat16_passed_by_tf(self): f_jax = lambda a, b: a + b f_tf = tf.function(jax2tf.convert(f_jax), input_signature=[ tf.TensorSpec([512, 512], tf.bfloat16), tf.TensorSpec([512, 512], tf.bfloat16) ]) self.assertIsNotNone(f_tf.get_concrete_function()) @jtu.skip_on_devices("gpu") def test_bfloat16_returned_by_jax(self): f_jax = lambda a, b: (a + b).astype(jnp.bfloat16) f_tf = jax2tf.convert(f_jax) self.assertEqual(f_tf(1., 2.).dtype, tf.bfloat16) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_dtype={dtype.__name__}", dtype=dtype) for dtype in [np.int64, np.float64])) def test_converts_64bit(self, dtype=np.int64, with_function=False): if not config.jax_enable_x64: self.skipTest("requires x64 mode") big_const = np.full((5, ), 2**33, dtype=dtype) self.ConvertAndCompare(jnp.sin, big_const) f_conv = jax2tf.convert(jnp.sin) if with_function: f_conv = tf.function(f_conv) # We also check that we can pass a tf.Variable or a tf.Tensor into the # converted function. self.assertAllClose(jnp.sin(big_const), f_conv(tf.Variable(big_const))) self.assertAllClose(jnp.sin(big_const), f_conv(tf.constant(big_const))) def test_function(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) self.ConvertAndCompare(f_jax, jnp.float_(0.7)) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_disabled(self, with_function=False): f_tf = jax2tf.convert(jnp.tan, with_gradient=False) if with_function: f_tf = tf.function(f_tf, autograph=False) x = tf.ones([]) # With tf.function the error is raised when we evaluate f_tf(x); in # eager mode it is raised when we evaluate tape.gradient(y, x). with self.assertRaisesRegex( LookupError, "Gradient explicitly disabled.*The jax2tf-converted function does not support gradients" ): with tf.GradientTape() as tape: tape.watch(x) y = f_tf(x) _ = tape.gradient(y, x) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients(self, with_function=True): def f(x, y): return x * x, x * y f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) default_float_type = dtypes.canonicalize_dtype(jnp.float_) x = tf.Variable(4., dtype=default_float_type) y = tf.Variable(5., dtype=default_float_type) with tf.GradientTape(persistent=True) as tape: u, v = f_tf(x, y) self.assertAllClose(2.
* 4., tape.gradient(u, x)) self.assertAllClose(0., tape.gradient(u, y)) self.assertAllClose(5., tape.gradient(v, x)) self.assertAllClose(4., tape.gradient(v, y)) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_pytree(self, with_function=True): def f(xy: Tuple[float, float]) -> Dict[str, float]: x, y = xy return dict(one=x * x, two=x * y) f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) default_float_dtype = dtypes.canonicalize_dtype(jnp.float_) x = tf.Variable(4., dtype=default_float_dtype) y = tf.Variable(5., dtype=default_float_dtype) with tf.GradientTape(persistent=True) as tape: uv = f_tf((x, y)) self.assertAllClose(2. * 4., tape.gradient(uv["one"], x)) self.assertAllClose(0., tape.gradient(uv["one"], y)) self.assertAllClose(5., tape.gradient(uv["two"], x)) self.assertAllClose(4., tape.gradient(uv["two"], y)) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_with_custom_jvp(self, with_function=True): """Check gradients, for a function with custom JVP.""" @jax.custom_jvp def f(x): return x * x @f.defjvp def f_jvp(primals, tangents): # 3 * x * x_t x, = primals x_dot, = tangents primal_out = f(x) tangent_out = 3. * x * x_dot return primal_out, tangent_out self.assertAllClose(4. * 4., f(4.)) self.assertAllClose(3. * 4., jax.grad(f)(4.)) f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) self.assertAllClose(4. * 4., f_tf(jnp.float_(4.))) x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_)) with tf.GradientTape() as tape: tape.watch(x) y = f_tf(x) self.assertAllClose(4. * 4., y) self.assertAllClose(3. * 4., tape.gradient(y, x)) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"function={with_function}", with_function=with_function) for with_function in [False, True])) def test_gradients_with_custom_vjp(self, with_function=True): """Check gradients, for a function with custom VJP.""" @jax.custom_vjp def f(x): return x * x # f_fwd: a -> (b, residual) def f_fwd(x): return f(x), 3. * x # f_bwd: (residual, CT b) -> [CT a] def f_bwd(residual, ct_b): return residual * ct_b, f.defvjp(f_fwd, f_bwd) self.assertAllClose(4. * 4., f(4.)) self.assertAllClose(3. * 4., jax.grad(f)(4.)) f_tf = jax2tf.convert(f, with_gradient=True) if with_function: f_tf = tf.function(f_tf, autograph=False) self.assertAllClose(4. * 4., f_tf(jnp.float_(4.))) x = tf.Variable(4., dtype=dtypes.canonicalize_dtype(jnp.float_)) with tf.GradientTape() as tape: tape.watch(x) y = f_tf(x) self.assertAllClose(4. * 4., y) self.assertAllClose(3. * 4., tape.gradient(y, x)) def test_convert_argument_non_callable_error(self): with self.assertRaisesRegex(TypeError, "Expected a callable value"): jax2tf.convert(5.) def test_convert_argument_non_tensor_error(self): with self.assertRaisesRegex(TypeError, "Argument.*should be NumPy array"): jax2tf.convert(lambda x: x)(lambda y: y) def test_argument_eager_tensor(self): x = jax2tf.convert(jnp.sin)(1.) 
jax2tf.convert(jnp.cos)(x) # No error def test_checkpoint_wrapper_types(self): m = tf.Module() m.a = [tf.Module(), tf.Module()] m.b = (tf.Module(), tf.Module()) m.c = {'a': tf.Module(), 'b': tf.Module()} self.assertNotEqual(type(m.a), list) self.assertNotEqual(type(m.b), tuple) self.assertNotEqual(type(m.c), dict) self.assertLen(jax.tree_leaves(m.a), 2) self.assertLen(jax.tree_leaves(m.b), 2) self.assertLen(jax.tree_leaves(m.c), 2) def test_custom_jvp(self): """Conversion of function with custom JVP""" @jax.custom_jvp def f(x): return x * x @f.defjvp def f_jvp(primals, tangents): x, = primals x_dot, = tangents primal_out = f(x) tangent_out = 3. * x * x_dot return primal_out, tangent_out arg = jnp.float_(0.7) self.TransformConvertAndCompare(f, arg, None) self.TransformConvertAndCompare(f, arg, "jvp") self.TransformConvertAndCompare(f, arg, "vmap") self.TransformConvertAndCompare(f, arg, "jvp_vmap") self.TransformConvertAndCompare(f, arg, "grad") self.TransformConvertAndCompare(f, arg, "grad_vmap") def test_custom_vjp(self): """Conversion of function with custom VJP""" @jax.custom_vjp def f(x): return x * x # f_fwd: a -> (b, residual) def f_fwd(x): return f(x), 3. * x # f_bwd: (residual, CT b) -> [CT a] def f_bwd(residual, ct_b): return residual * ct_b, f.defvjp(f_fwd, f_bwd) arg = jnp.float_(0.7) self.TransformConvertAndCompare(f, arg, None) self.TransformConvertAndCompare(f, arg, "vmap") self.TransformConvertAndCompare(f, arg, "grad") self.TransformConvertAndCompare(f, arg, "grad_vmap") def test_remat1(self): @jax.remat def f(x1): x2 = jnp.sin(x1) x3 = jnp.sin(x2) x4 = jnp.sin(x3) return jnp.sum(x4) # The computation of grad_f computes "sin" 5 times, 3 for the forward pass # and then to rematerialize "x2" and "x3" in the backward pass. arg = np.arange(3.) self.TransformConvertAndCompare(f, arg, "grad") # TODO: check that the TF code also computes "sin" 5 times def test_remat_free_var(self): def f(x): y = 2 * x @jax.remat def g(): return y return g() arg = jnp.float_(3.) self.TransformConvertAndCompare(f, arg, None) self.TransformConvertAndCompare(f, arg, "grad") def test_convert_nullary_func(self): # Even nullary functions are converted to TF (as opposed to constant-folded # in JAX prior to conversion). def f_jax(): return jnp.sin(1.) f_tf = tf.function(jax2tf.convert(f_jax), autograph=False) f_tf_graph = f_tf.get_concrete_function().graph.as_graph_def() self.assertIn('op: "Sin"', str(f_tf_graph)) def test_convert_of_nested_independent_jit(self): def func(x): def inner1(y): return x + y # The JIT does not have data dependency return jax.jit(inner1)(1.) jax2tf.convert(func)(2.) def test_convert_of_nested_dependent_jit(self): def func(x): def inner1(y): return x + y # The JIT does have data dependency return jax.jit(inner1)(x) jax2tf.convert(func)(2.) # No error def test_nested_convert_error(self): def outer(y): return jax2tf.convert(jnp.sin)( y) # Inner convert takes tracer args with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): jax2tf.convert(outer)(np.ones((4, ))) def test_nested_convert_error_non_tracer(self): """The inner convert takes non-tracer arguments""" def outer(y): sin_1 = jax2tf.convert(jnp.sin)( 1.) # Inner convert takes non-tracer arg return y + sin_1 with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): jax2tf.convert(outer)(2.) 
@parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_{transform}", transform=transform) for transform in ["jit", "jvp", "grad", "vmap"])) def test_convert_under_transform_error(self, transform="vmap"): def outer(y): return jax2tf.convert(jnp.sin)( y) # Inner convert takes tracer args with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): self.TransformConvertAndCompare(outer, np.ones((4, )), transform) @parameterized.named_parameters( jtu.cases_from_list( dict(testcase_name=f"_{transform}", transform=transform) for transform in ["jit", "jvp", "grad", "vmap"])) def test_convert_under_transform_error_non_tracer(self, transform="vmap"): def outer(y): sin_1 = jax2tf.convert(jnp.sin)( 1.) # Inner convert takes non-tracer arg return y + sin_1 with self.assertRaisesRegex( ValueError, "convert must be used outside all JAX transformations"): self.TransformConvertAndCompare(outer, np.ones((4, )), transform) def test_name_scope(self): log = [] @jax.named_call def my_test_function(x): y = tf.Variable(1., name="foo") log.append(y.name) return x * x jax2tf.convert(my_test_function)(2) self.assertIn("my_test_function/foo", log[0]) def test_bfloat16_constant(self): # Re: https://github.com/google/jax/issues/3942 def jax_fn_scalar(x): x = x.astype(jnp.bfloat16) x *= 2. return x def jax_fn_array(x): x = x.astype(jnp.bfloat16) x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16) return x tf_fn_scalar = jax2tf.convert(jax_fn_scalar) self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750)) tf_fn_array = jax2tf.convert(jax_fn_array) self.assertAllClose(tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5], jnp.bfloat16))
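# A minimal sketch of the round trip these tests exercise, assuming the
# jax2tf/tf imports at the top of this file. jax2tf.convert returns a
# function on TF tensors; tf.function stages it into a TF graph, and
# with_gradient=True attaches a TF custom gradient derived from the JAX
# VJP, which is what the gradient tests above rely on.
def _example_convert_with_gradient():
  f_jax = lambda x: jnp.sin(jnp.cos(x))
  f_tf = tf.function(jax2tf.convert(f_jax, with_gradient=True),
                     autograph=False)
  x = tf.Variable(0.7, dtype=tf.float32)
  with tf.GradientTape() as tape:
    y = f_tf(x)
  # dy/dx = -cos(cos(x)) * sin(x), computed through the converted VJP.
  return y, tape.gradient(y, x)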
class FftTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inverse={}_shape={}_axes={}".format( inverse, jtu.format_shape_dtype_string(shape, dtype), axes), "axes": axes, "shape": shape, "dtype": dtype, "rng_factory": rng_factory, "inverse": inverse } for inverse in [False, True] for rng_factory in [jtu.rand_default] for dtype in all_dtypes for shape in [(10, ), (10, 10), (2, 3, 4), (2, 3, 4, 5)] for axes in _get_fftn_test_axes(shape))) def testFftn(self, inverse, shape, dtype, axes, rng_factory): rng = rng_factory() args_maker = lambda: (rng(shape, dtype), ) np_op = np.fft.ifftn if inverse else np.fft.fftn onp_op = onp.fft.ifftn if inverse else onp.fft.fftn np_fn = lambda a: np_op(a, axes=axes) onp_fn = lambda a: onp_op(a, axes=axes) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=False, tol=1e-4) self._CompileAndCheck(np_fn, args_maker, check_dtypes=True) # Test gradient for differentiable types. if dtype in inexact_dtypes: # TODO(skye): can we be more precise? tol = 1e-1 jtu.check_grads(np_fn, args_maker(), order=1, atol=tol, rtol=tol) jtu.check_grads(np_fn, args_maker(), order=2, atol=tol, rtol=tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_inverse={}".format(inverse), "inverse": inverse } for inverse in [False, True])) def testFftnErrors(self, inverse): rng = jtu.rand_default() name = 'ifftn' if inverse else 'fftn' func = np.fft.ifftn if inverse else np.fft.fftn self.assertRaisesRegex( ValueError, "jax.np.fft.{} only supports 1D, 2D, and 3D FFTs. " "Got axes None with input rank 4.".format(name), lambda: func(rng([2, 3, 4, 5], dtype=onp.float64), axes=None)) self.assertRaisesRegex( ValueError, "jax.np.fft.{} does not support repeated axes. Got axes \\[1, 1\\]." .format(name), lambda: func(rng([2, 3], dtype=onp.float64), axes=[1, 1])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[2])) self.assertRaises( ValueError, lambda: func(rng([2, 3], dtype=onp.float64), axes=[-3]))
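# A minimal sketch of the FFT property the tests above rely on, written
# with the older aliases used in this file (np = jax.numpy, onp = numpy):
# ifftn inverts fftn up to floating point error, over at most 3 axes.
def _example_fftn_roundtrip():
  x = np.linspace(0., 1., 16).reshape(4, 4).astype(np.complex64)
  freq = np.fft.fftn(x, axes=(0, 1))
  back = np.fft.ifftn(freq, axes=(0, 1))
  return onp.allclose(onp.asarray(back), onp.asarray(x), atol=1e-5)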
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase): def test_primitive_coverage(self): """Fail if there are JAX primitives that are not implemented.""" # Harvest primitives from XLA translation tables all_primitives = (set(xla.translations) | set(xla.backend_specific_translations['cpu']) | set(xla.backend_specific_translations['gpu']) | set(xla.backend_specific_translations['tpu']) | set(xla.initial_style_translations) | set(xla.parallel_translations)) tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(jax.experimental.jax2tf.jax2tf.tf_impl_with_avals) tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl) all_primitives = tuple(sorted(all_primitives, key=str)) for p in all_primitives: # TODO: remove tie_in once omnistaging is on by default if p.name == "axis_index" or p.name == "tie_in": continue if p in tf_not_yet_impl: self.assertNotIn(p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl else: self.assertIn(p, tf_impl) @parameterized.named_parameters( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in [jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less, jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal, jnp.not_equal, jnp.maximum, jnp.minimum]) def test_type_promotion(self, f_jax=jnp.add): # We only test a few types here, as tensorflow does not support many # types like uint* or bool in binary ops. types = [dtypes.bfloat16, np.int32, np.int64, np.float32] for x_dtype in types: for y_dtype in types: x = np.array([1, 2], dtype=x_dtype) y = np.array([3, 4], dtype=y_dtype) self.ConvertAndCompare(f_jax, x, y) def test_concat(self): values = [np.array([1, 2], dtype=np.float32), np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int8)] f_jax = jax.jit(lambda x: jnp.concatenate(x, axis=0)) self.ConvertAndCompare(f_jax, values) @primitive_harness.parameterized(primitive_harness.lax_pad) def test_pad(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_top_k) def test_top_k(self, harness: primitive_harness.Harness): if (harness.params["k"] > harness.params["shape"][-1] or harness.params["k"] < 0): with self.assertRaisesRegex(ValueError, "k argument to top_k must be"): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) elif harness.params["dtype"] in jtu.dtypes.complex: # TODO(necula): fix top_k complex bug on TPU if jtu.device_under_test() == "tpu": raise unittest.SkipTest("top_k complex on TPU raises different error") with self.assertRaisesRegex(RuntimeError, "Unimplemented: complex comparison"): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) # TODO: TF and JAX sort [inf, nan] differently. 
elif harness.name.startswith("nan_"): raise unittest.SkipTest("inconsistent [nan, inf] sorting") else: self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_sort) def test_sort(self, harness: primitive_harness.Harness): if (jtu.device_under_test() == "gpu" and len(harness.arg_descriptors) == 4 and not harness.params["is_stable"]): # TODO: fix the TF GPU test raise unittest.SkipTest("GPU tests are running TF on CPU") if jtu.device_under_test() == "tpu" and harness.params["dtype"] in jtu.dtypes.complex: raise unittest.SkipTest("JAX sort is not implemented on TPU for complex") self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_fft) @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_fft(self, harness: primitive_harness.Harness): if len(harness.params["fft_lengths"]) > 3: with self.assertRaisesRegex(RuntimeError, "FFT only supports ranks 1-3"): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) elif (jtu.device_under_test() == "tpu" and len(harness.params["fft_lengths"]) > 1): # TODO(b/140351181): FFT is mostly unimplemented on TPU, even for JAX with self.assertRaisesRegex(RuntimeError, "only 1D FFT is currently supported."): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) else: tol = None if jtu.device_under_test() == "gpu": if harness.params["dtype"] in jtu.dtypes.boolean: tol = 0.01 else: tol = 1e-3 self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=tol, rtol=tol) @primitive_harness.parameterized(primitive_harness.lax_linalg_cholesky) def test_cholesky(self, harness: primitive_harness.Harness): dtype = harness.params["dtype"] if dtype in [dtypes.bfloat16, np.float16]: raise unittest.SkipTest("Cholesky decomposition not supported for " "(b)float16 in JAX.") operand = harness.dyn_args_maker(self.rng())[0] operand = np.matmul(operand, jnp.conj(np.swapaxes(operand, -1, -2))) tol = None # TODO(bchetioui): very high discrepancy in the float32/complex64 case if dtype in [np.float32, np.complex64]: tol = 1e-2 # TODO(bchetioui): also high discrepancy in the float64/complex128 case elif dtype in [np.float64, np.complex128]: tol = 1e-11 def custom_assert(result_jax, result_tf): # cholesky_p returns garbage in the strictly upper triangular part of the # result, so we can safely ignore that part. self.assertAllClose(jnp.tril(result_jax), result_tf, atol=tol) self.ConvertAndCompare(harness.dyn_fun, operand, custom_assert=custom_assert, always_custom_assert=True) @primitive_harness.parameterized(primitive_harness.lax_linalg_qr) def test_qr(self, harness: primitive_harness.Harness): # See jax.lib.lapack.geqrf for the list of compatible types dtype = harness.params["dtype"] dut = jtu.device_under_test() # These cases are not implemented in JAX if dtype in (jtu.dtypes.all_integer + [jnp.bfloat16]): unimplemented_jax = True elif dtype is np.complex64 and dut == "tpu": unimplemented_jax = True elif dtype is np.float16 and dut in ("cpu", "gpu"): unimplemented_jax = True else: unimplemented_jax = False if unimplemented_jax: raise unittest.SkipTest(f"QR not implemented in JAX for {dtype} on {dut}") # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. # - for now, the HLO QR implementation called when compiling with TF is # expected to perform worse than the custom calls made in JAX.
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=1e-5, rtol=1e-5) @primitive_harness.parameterized(primitive_harness.lax_linalg_svd) @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_svd(self, harness: primitive_harness.Harness): if harness.params["dtype"] in [np.float16, dtypes.bfloat16]: if jtu.device_under_test() != "tpu": # Does not work in JAX with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) return if harness.params["dtype"] in [np.complex64, np.complex128]: if jtu.device_under_test() == "tpu": # TODO: on JAX on TPU there is no SVD implementation for complex with self.assertRaisesRegex(RuntimeError, "Binary op compare with different element types"): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) return def _custom_assert(r_jax, r_tf, atol=1e-6, rtol=1e-6): def _reconstruct_operand(result, is_tf: bool): # Reconstructing operand as documented in numpy.linalg.svd (see # https://numpy.org/doc/stable/reference/generated/numpy.linalg.svd.html) s, u, v = result if is_tf: s = s.numpy() u = u.numpy() v = v.numpy() U = u[..., :s.shape[-1]] V = v[..., :s.shape[-1], :] S = s[..., None, :] return jnp.matmul(U * S, V), s.shape, u.shape, v.shape if harness.params["compute_uv"]: r_jax_reconstructed = _reconstruct_operand(r_jax, False) r_tf_reconstructed = _reconstruct_operand(r_tf, True) self.assertAllClose(r_jax_reconstructed, r_tf_reconstructed, atol=atol, rtol=rtol) else: self.assertAllClose(r_jax, r_tf, atol=atol, rtol=rtol) tol = 1e-4 custom_assert = partial(_custom_assert, atol=tol, rtol=tol) self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=tol, rtol=tol, custom_assert=custom_assert, always_custom_assert=True) @primitive_harness.parameterized(primitive_harness.lax_select_and_gather_add) @jtu.ignore_warning(category=UserWarning, message="Using reduced precision for gradient.*") def test_select_and_gather_add(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_reduce_window) def test_reduce_window(self, harness: primitive_harness.Harness): dtype = harness.params['dtype'] if (jtu.device_under_test() == 'tpu' and dtype is np.complex64): raise unittest.SkipTest( 'TODO: JAX reduce_window on TPU does not handle complex64' ) self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_linalg_eig) def test_eig(self, harness: primitive_harness.Harness): operand = harness.dyn_args_maker(self.rng())[0] compute_left_eigenvectors = harness.params["compute_left_eigenvectors"] compute_right_eigenvectors = harness.params["compute_right_eigenvectors"] dtype = harness.params["dtype"] if jtu.device_under_test() != "cpu": raise unittest.SkipTest("eig only supported on CPU in JAX") if dtype in [np.float16, dtypes.bfloat16]: raise unittest.SkipTest("eig unsupported with (b)float16 in JAX") def custom_assert(result_jax, result_tf): result_tf = tuple(map(lambda e: e.numpy(), result_tf)) inner_dimension = operand.shape[-1] # Test ported from tests.lax_test.testEig # Norm, adjusted for dimension and type. 
def norm(x): norm = np.linalg.norm(x, axis=(-2, -1)) return norm / ((inner_dimension + 1) * jnp.finfo(dtype).eps) def check_right_eigenvectors(a, w, vr): self.assertTrue( np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100)) def check_left_eigenvectors(a, w, vl): rank = len(a.shape) aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2])) wC = jnp.conj(w) check_right_eigenvectors(aH, wC, vl) def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): tol = None # TODO(bchetioui): numerical discrepancies if dtype in [np.float32, np.complex64]: tol = 1e-4 elif dtype in [np.float64, np.complex128]: tol = 1e-13 closest_diff = min(abs(eigenvalues_array - eigenvalue)) self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype), atol=tol) all_w_jax, all_w_tf = result_jax[0], result_tf[0] for idx in itertools.product(*map(range, operand.shape[:-2])): w_jax, w_tf = all_w_jax[idx], all_w_tf[idx] for i in range(inner_dimension): check_eigenvalue_is_in_array(w_jax[i], w_tf) check_eigenvalue_is_in_array(w_tf[i], w_jax) if compute_left_eigenvectors: check_left_eigenvectors(operand, all_w_tf, result_tf[1]) if compute_right_eigenvectors: check_right_eigenvectors(operand, all_w_tf, result_tf[1 + compute_left_eigenvectors]) self.ConvertAndCompare(harness.dyn_fun, operand, custom_assert=custom_assert) @primitive_harness.parameterized(primitive_harness.lax_linalg_eigh) def test_eigh(self, harness: primitive_harness.Harness): operand = harness.dyn_args_maker(self.rng())[0] lower = harness.params["lower"] # Make operand self-adjoint operand = (operand + np.conj(np.swapaxes(operand, -1, -2))) / 2 # Make operand lower/upper triangular triangular_operand = np.tril(operand) if lower else np.triu(operand) dtype = harness.params["dtype"] if (dtype in [np.complex64, np.complex128] and jtu.device_under_test() == "tpu"): raise unittest.SkipTest("TODO: complex eigh not supported on TPU in JAX") def custom_assert(result_jax, result_tf): result_tf = tuple(map(lambda e: e.numpy(), result_tf)) inner_dimension = operand.shape[-1] def check_right_eigenvectors(a, w, vr): tol = 1e-16 # TODO(bchetioui): tolerance needs to be very high in compiled mode, # specifically for eigenvectors. 
if dtype == np.float64: tol = 1e-6 elif dtype == np.float32: tol = 1e-2 elif dtype in [dtypes.bfloat16, np.complex64]: tol = 1e-3 elif dtype == np.complex128: tol = 1e-13 self.assertAllClose(np.matmul(a, vr) - w[..., None, :] * vr, np.zeros(a.shape, dtype=vr.dtype), atol=tol) def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): tol = None if dtype in [dtypes.bfloat16, np.float32, np.complex64]: tol = 1e-3 elif dtype in [np.float64, np.complex128]: tol = 1e-11 closest_diff = min(abs(eigenvalues_array - eigenvalue)) self.assertAllClose(closest_diff, np.array(0., closest_diff.dtype), atol=tol) _, all_w_jax = result_jax all_vr_tf, all_w_tf = result_tf for idx in itertools.product(*map(range, operand.shape[:-2])): w_jax, w_tf = all_w_jax[idx], all_w_tf[idx] for i in range(inner_dimension): check_eigenvalue_is_in_array(w_jax[i], w_tf) check_eigenvalue_is_in_array(w_tf[i], w_jax) check_right_eigenvectors(operand, all_w_tf, all_vr_tf) # On CPU and GPU, JAX makes custom calls always_custom_assert = True # On TPU, JAX calls xops.Eigh if jtu.device_under_test() == "tpu": always_custom_assert = False self.ConvertAndCompare(harness.dyn_fun, triangular_operand, custom_assert=custom_assert, always_custom_assert=always_custom_assert) @primitive_harness.parameterized( primitive_harness.lax_linalg_triangular_solve) def test_triangular_solve(self, harness: primitive_harness.Harness): dtype = harness.params["dtype"] if dtype == np.float16 and jtu.device_under_test() == "gpu": raise unittest.SkipTest( f"Triangular solve is not implemented in JAX for dtype {dtype}") atol = rtol = None if dtype == np.float32: atol = rtol = 1e-5 self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=atol, rtol=rtol) @primitive_harness.parameterized(primitive_harness.lax_unary_elementwise) def test_unary_elementwise(self, harness: primitive_harness.Harness): dtype = harness.params["dtype"] lax_name = harness.params["lax_name"] arg, = harness.dyn_args_maker(self.rng()) custom_assert = None if lax_name == "digamma": # TODO(necula): fix bug with digamma/(f32|f16) on TPU if dtype in [np.float16, np.float32] and jtu.device_under_test() == "tpu": raise unittest.SkipTest("TODO: fix bug: nan vs not-nan") # In the bfloat16 case, TF and lax both return NaN in undefined cases. if dtype is not dtypes.bfloat16: # digamma is not defined at 0 and -1 def custom_assert(result_jax, result_tf): # lax.digamma returns NaN and tf.math.digamma returns inf special_cases = (arg == 0.) | (arg == -1.) nr_special_cases = np.count_nonzero(special_cases) self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan)), result_jax[special_cases]) self.assertAllClose(np.full((nr_special_cases,), dtype(np.inf)), result_tf[special_cases]) # non-special cases are equal self.assertAllClose(result_jax[~ special_cases], result_tf[~ special_cases]) if lax_name == "erf_inv": # TODO(necula): fix erf_inv bug on TPU if jtu.device_under_test() == "tpu": raise unittest.SkipTest("erf_inv bug on TPU: nan vs non-nan") # TODO: investigate: in the (b)float16 cases, TF and lax both return the # same result in undefined cases. if dtype not in [np.float16, dtypes.bfloat16]: # erf_inv is not defined for arg <= -1 or arg >= 1 def custom_assert(result_jax, result_tf): # noqa: F811 # for arg < -1 or arg > 1 # lax.erf_inv returns NaN; tf.math.erf_inv returns +/- inf special_cases = (arg < -1.) | (arg > 1.) 
nr_special_cases = np.count_nonzero(special_cases) self.assertAllClose(np.full((nr_special_cases,), dtype(np.nan), dtype=dtype), result_jax[special_cases]) signs = np.where(arg[special_cases] < 0., -1., 1.) self.assertAllClose(np.full((nr_special_cases,), signs * dtype(np.inf), dtype=dtype), result_tf[special_cases]) # non-special cases are equal self.assertAllClose(result_jax[~ special_cases], result_tf[~ special_cases]) atol = None if jtu.device_under_test() == "gpu": # TODO(necula): revisit once we fix the GPU tests atol = 1e-3 self.ConvertAndCompare(harness.dyn_fun, arg, custom_assert=custom_assert, atol=atol) @primitive_harness.parameterized(primitive_harness.lax_bitwise_not) def test_bitwise_not(self, harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_population_count) def test_population_count(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_add_mul) def test_add_mul(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_min_max) def test_min_max(self, harness: primitive_harness.Harness): # TODO(bchetioui): discrepancies between TF & JAX when comparing with NaN; # JAX always returns NaN, while TF returns the value NaN is compared with. def custom_assert(result_jax, result_tf): mask = np.isnan(result_jax) self.assertAllClose(result_jax[~ mask], result_tf[~ mask]) # TODO(bchetioui): figure out why we need always_custom_assert=True always_custom_assert = True self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), custom_assert=custom_assert, always_custom_assert=always_custom_assert) @primitive_harness.parameterized(primitive_harness.lax_binary_elementwise) def test_binary_elementwise(self, harness): tol = None lax_name, dtype = harness.params["lax_name"], harness.params["dtype"] if lax_name in ("igamma", "igammac"): # TODO(necula): fix bug with igamma/f16 if dtype in [np.float16, dtypes.bfloat16]: raise unittest.SkipTest("TODO: igamma(c) unsupported with (b)float16 in JAX") # TODO(necula): fix bug with igamma/f32 on TPU if dtype is np.float32 and jtu.device_under_test() == "tpu": raise unittest.SkipTest("TODO: fix bug: nan vs not-nan") arg1, arg2 = harness.dyn_args_maker(self.rng()) custom_assert = None if lax_name == "igamma": # igamma is not defined when the first argument is <=0 def custom_assert(result_jax, result_tf): # lax.igamma returns NaN when arg1 == arg2 == 0; tf.math.igamma returns 0 special_cases = (arg1 == 0.) & (arg2 == 0.) nr_special_cases = np.count_nonzero(special_cases) self.assertAllClose(np.full((nr_special_cases,), np.nan, dtype=dtype), result_jax[special_cases]) self.assertAllClose(np.full((nr_special_cases,), 0., dtype=dtype), result_tf[special_cases]) # non-special cases are equal self.assertAllClose(result_jax[~ special_cases], result_tf[~ special_cases]) if lax_name == "igammac": # On GPU, tolerance also needs to be adjusted in compiled mode if dtype == np.float64 and jtu.device_under_test() == 'gpu': tol = 1e-14 # igammac is not defined when the first argument is <=0 def custom_assert(result_jax, result_tf): # noqa: F811 # lax.igammac returns 1. when arg1 <= 0; tf.math.igammac returns NaN special_cases = (arg1 <= 0.) 
| (arg2 <= 0) nr_special_cases = np.count_nonzero(special_cases) self.assertAllClose(np.full((nr_special_cases,), 1., dtype=dtype), result_jax[special_cases]) self.assertAllClose(np.full((nr_special_cases,), np.nan, dtype=dtype), result_tf[special_cases]) # On CPU, tolerance only needs to be adjusted in eager & graph modes tol = None if dtype == np.float64: tol = 1e-14 # non-special cases are equal self.assertAllClose(result_jax[~ special_cases], result_tf[~ special_cases], atol=tol, rtol=tol) self.ConvertAndCompare(harness.dyn_fun, arg1, arg2, custom_assert=custom_assert, atol=tol, rtol=tol) @primitive_harness.parameterized(primitive_harness.lax_binary_elementwise_logical) def test_binary_elementwise_logical(self, harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_betainc) def test_betainc(self, harness: primitive_harness.Harness): dtype = harness.params["dtype"] # TODO: https://www.tensorflow.org/api_docs/python/tf/math/betainc only # supports float32/64 tests. # TODO(bchetioui): investigate why the test actually fails in JAX. if dtype in [np.float16, dtypes.bfloat16]: raise unittest.SkipTest("(b)float16 not implemented in TF") tol = None if dtype is np.float64: tol = 1e-14 self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=tol, rtol=tol) # TODO(necula): combine tests that are identical except for the harness # wait until we get more experience with using harnesses. @primitive_harness.parameterized(primitive_harness.lax_shift_left) def test_shift_left(self, harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_shift_right_logical) def test_shift_right_logical(self, harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_shift_right_arithmetic) def test_shift_right_arithmetic(self, harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_slice) def test_slice(self, harness): # JAX.slice rejects negative indices; check, and skip jax2tf if any(si < 0 or si >= sh or li < 0 or li > sh for sh, si, li in zip(harness.params["shape"], harness.params["start_indices"], harness.params["limit_indices"])): with self.assertRaisesRegex(TypeError, ""): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) else: self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_dynamic_slice) def test_dynamic_slice(self, harness): # JAX.dynamic_slice rejects slice sizes too big; check this, and skip jax2tf args = harness.dyn_args_maker(self.rng()) if any(li - si < 0 or li - si >= sh for sh, si, li in zip(harness.params["shape"], harness.params["start_indices"], harness.params["limit_indices"])): with self.assertRaisesRegex(TypeError, ""): harness.dyn_fun(*args) return self.ConvertAndCompare(harness.dyn_fun, *args) @primitive_harness.parameterized(primitive_harness.lax_dynamic_update_slice) def test_dynamic_update_slice(self, harness): # JAX.dynamic_update_slice rejects update slices too big; check, and skip jax2tf if any(ush > sh for sh, ush in zip(harness.params["shape"], harness.params["update_shape"])): with self.assertRaisesRegex(TypeError, ""): harness.dyn_fun(*harness.dyn_args_maker(self.rng())) else: self.ConvertAndCompare(harness.dyn_fun, 
*harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_squeeze) def test_squeeze(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_dot_general) def test_dot_general(self, harness: primitive_harness.Harness): tol, dtype = None, harness.params["dtype"] if dtype == dtypes.bfloat16: tol = 0.3 elif dtype in [np.complex64, np.float32]: if jtu.device_under_test() == "tpu": tol = 0.1 if dtype == np.float32 else 0.3 else: tol = 1e-5 elif dtype == np.float16: if jtu.device_under_test() == "gpu": tol = 0.1 else: tol = 0.01 self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=tol, rtol=tol) @primitive_harness.parameterized(primitive_harness.lax_conv_general_dilated) def test_conv_general_dilated(self, harness: primitive_harness.Harness): dtype, device = harness.params["dtype"], jtu.device_under_test() if device == "gpu" and dtype in [np.complex64, np.complex128]: raise unittest.SkipTest("TODO: crash on GPU in TF") tol = None if device == "gpu": tol = 1e-4 elif device == "tpu": tol = 1e-3 # TODO(bchetioui): significant discrepancies in some float16 cases. if dtype == np.float16: tol = 1. # TODO(bchetioui): slight occasional discrepancy in float32 cases. elif dtype == np.float32: tol = 0.5 if device == "tpu" else (1e-3 if device == "gpu" else 1e-4) elif dtype == np.complex64 and device == "tpu": tol = 0.1 # TODO(bchetioui): slight discrepancy when going through the path using # tf.nn.convolution. elif dtype == np.float64 and device == "cpu": tol = 1e-13 self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=tol, rtol=tol) @primitive_harness.parameterized(primitive_harness.lax_gather) def test_gather(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_scatter) def test_scatter(self, harness: primitive_harness.Harness): f_name = harness.params['f_lax'].__name__ dtype = harness.params['dtype'] if jtu.device_under_test() == 'tpu': if dtype is np.complex64 and f_name in ['scatter_min', 'scatter_max']: raise unittest.SkipTest(f"TODO: complex {f_name} on TPU fails in JAX") self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) def test_boolean_gather(self): values = np.array([[True, True], [False, True], [False, False]], dtype=np.bool_) indices = np.array([0, 1], dtype=np.int32) for axis in [0, 1]: f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis)) # pylint: disable=cell-var-from-loop self.ConvertAndCompare(f_jax, values, indices) def test_gather_rank_change(self): params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]]) indices = jnp.array([[1, 1, 2], [0, 1, 0]]) f_jax = jax.jit(lambda i: params[i]) self.ConvertAndCompare(f_jax, indices) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in REDUCE)) def test_reduce_ops_with_numerical_input(self, f_jax): values = np.array([1, 2, 3], dtype=np.float32) self.ConvertAndCompare(f_jax, values) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in (jnp.cumsum, jnp.cumprod))) def test_cumulated_ops(self, f_jax): values = np.array([1, 2, 3], dtype=np.float32) self.ConvertAndCompare(f_jax, values) @parameterized.named_parameters(jtu.cases_from_list( 
dict(testcase_name=f"_{op.__name__}", op=op) for op in INDEX)) def test_scatter_static(self, op): values = np.ones((5, 6), dtype=np.float32) update = np.float32(6.) f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u)) self.ConvertAndCompare(f_jax, values, update) @parameterized.named_parameters(jtu.cases_from_list( dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax) for f_jax in REDUCE)) def test_reduce_ops_with_boolean_input(self, f_jax): values = np.array([True, False, True], dtype=np.bool_) self.ConvertAndCompare(f_jax, values) @primitive_harness.parameterized(primitive_harness.random_gamma) def test_random_gamma(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), rtol=1e-5) @primitive_harness.parameterized(primitive_harness.random_split) def test_random_split(self, harness: primitive_harness.Harness): self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) def test_zeros_like(self): v = np.float32(2.) f_jax = jax.ad_util.zeros_like_jaxval self.ConvertAndCompare(f_jax, v) def test_stop_gradient(self): f = jax2tf.convert(lax.stop_gradient) self.assertEqual(f(tf.ones([])), 1.) # test_bfloat16_constant checks that https://github.com/google/jax/issues/3942 is # fixed def test_bfloat16_constant(self): def jax_fn_scalar(x): x = x.astype(jnp.bfloat16) x *= 2. return x def jax_fn_array(x): x = x.astype(jnp.bfloat16) x *= np.array([1.5, 2.5, 3.5], jnp.bfloat16) return x tf_fn_scalar = jax2tf.convert(jax_fn_scalar) self.assertAllClose(tf_fn_scalar(1.375).numpy(), jnp.bfloat16(2.750)) tf_fn_array = jax2tf.convert(jax_fn_array) self.assertAllClose(tf_fn_array(np.array([3, 4, 5])), np.array([4.5, 10, 17.5], jnp.bfloat16))
class LaxBackedScipyTests(jtu.JaxTestCase): def _fetch_preconditioner(self, preconditioner, A, rng=None): """ Returns one of various preconditioning matrices depending on the identifier `preconditioner' and the input matrix A whose inverse it supposedly approximates. """ if preconditioner == 'identity': M = np.eye(A.shape[0], dtype=A.dtype) elif preconditioner == 'random': if rng is None: rng = jtu.rand_default(self.rng()) M = np.linalg.inv(rand_sym_pos_def(rng, A.shape, A.dtype)) elif preconditioner == 'exact': M = np.linalg.inv(A) else: M = None return M @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner), "shape": shape, "dtype": dtype, "preconditioner": preconditioner } for shape in [(4, 4), (7, 7)] for dtype in [np.float64, np.complex128] for preconditioner in [None, 'identity', 'exact', 'random'])) def test_cg_against_scipy(self, shape, dtype, preconditioner): if not config.x64_enabled: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) A = rand_sym_pos_def(rng, shape, dtype) b = rng(shape[:1], dtype) M = self._fetch_preconditioner(preconditioner, A, rng=rng) def args_maker(): return A, b self._CheckAgainstNumpy(partial(scipy_cg, M=M, maxiter=1), partial(lax_cg, M=M, maxiter=1), args_maker, tol=1e-12) self._CheckAgainstNumpy(partial(scipy_cg, M=M, maxiter=3), partial(lax_cg, M=M, maxiter=3), args_maker, tol=1e-12) self._CheckAgainstNumpy(np.linalg.solve, partial(lax_cg, M=M, atol=1e-10), args_maker, tol=1e-6) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype } for shape in [(2, 2)] for dtype in float_types + complex_types)) def test_cg_as_solve(self, shape, dtype): rng = jtu.rand_default(self.rng()) a = rng(shape, dtype) b = rng(shape[:1], dtype) expected = np.linalg.solve(posify(a), b) actual = lax_cg(posify(a), b) self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5) actual = jit(lax_cg)(posify(a), b) self.assertAllClose(expected, actual, atol=1e-5, rtol=1e-5) # numerical gradients are only well defined if ``a`` is guaranteed to be # positive definite. 
jtu.check_grads(lambda x, y: lax_cg(posify(x), y), (a, b), order=2, rtol=2e-1) def test_cg_ndarray(self): A = lambda x: 2 * x b = jnp.arange(9.0).reshape((3, 3)) expected = b / 2 actual, _ = jax.scipy.sparse.linalg.cg(A, b) self.assertAllClose(expected, actual) def test_cg_pytree(self): A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]} b = {"a": 1.0, "b": -4.0} expected = {"a": 4.0, "b": -6.0} actual, _ = jax.scipy.sparse.linalg.cg(A, b) self.assertEqual(expected.keys(), actual.keys()) self.assertAlmostEqual(expected["a"], actual["a"], places=6) self.assertAlmostEqual(expected["b"], actual["b"], places=6) def test_cg_errors(self): A = lambda x: x b = jnp.zeros((2, )) with self.assertRaisesRegex( ValueError, "x0 and b must have matching tree structure"): jax.scipy.sparse.linalg.cg(A, {'x': b}, {'y': b}) with self.assertRaisesRegex(ValueError, "x0 and b must have matching shape"): jax.scipy.sparse.linalg.cg(A, b, b[:, np.newaxis]) with self.assertRaisesRegex(ValueError, "must be a square matrix"): jax.scipy.sparse.linalg.cg(jnp.zeros((3, 2)), jnp.zeros((2, ))) with self.assertRaisesRegex( TypeError, "linear operator must be either a function or ndarray"): jax.scipy.sparse.linalg.cg([[1]], jnp.zeros((1, ))) def test_cg_without_pytree_equality(self): @register_pytree_node_class class MinimalPytree: def __init__(self, value): self.value = value def tree_flatten(self): return [self.value], None @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children) A = lambda x: MinimalPytree(2 * x.value) b = MinimalPytree(jnp.arange(5.0)) expected = b.value / 2 actual, _ = jax.scipy.sparse.linalg.cg(A, b) self.assertAllClose(expected, actual.value) # BICGSTAB @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner), "shape": shape, "dtype": dtype, "preconditioner": preconditioner } for shape in [(5, 5)] for dtype in [np.float64, np.complex128] for preconditioner in [None, 'identity', 'exact', 'random'])) def test_bicgstab_against_scipy(self, shape, dtype, preconditioner): if not config.jax_enable_x64: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) b = rng(shape[:1], dtype) M = self._fetch_preconditioner(preconditioner, A, rng=rng) def args_maker(): return A, b self._CheckAgainstNumpy(partial(scipy_bicgstab, M=M, maxiter=1), partial(lax_bicgstab, M=M, maxiter=1), args_maker, tol=1e-5) self._CheckAgainstNumpy(partial(scipy_bicgstab, M=M, maxiter=2), partial(lax_bicgstab, M=M, maxiter=2), args_maker, tol=1e-4) self._CheckAgainstNumpy(partial(scipy_bicgstab, M=M, maxiter=1), partial(lax_bicgstab, M=M, maxiter=1), args_maker, tol=1e-4) self._CheckAgainstNumpy(np.linalg.solve, partial(lax_bicgstab, M=M, atol=1e-6), args_maker, tol=1e-4) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner), "shape": shape, "dtype": dtype, "preconditioner": preconditioner } for shape in [(2, 2), (7, 7)] for dtype in float_types + complex_types for preconditioner in [None, 'identity', 'exact'])) @jtu.skip_on_devices("gpu") def test_bicgstab_on_identity_system(self, shape, dtype, preconditioner): A = jnp.eye(shape[1], dtype=dtype) solution = jnp.ones(shape[1], dtype=dtype) rng = jtu.rand_default(self.rng()) M = self._fetch_preconditioner(preconditioner, A, rng=rng) b = matmul_high_precision(A, solution) tol = 
shape[0] * jnp.finfo(dtype).eps x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M) using_x64 = solution.dtype.kind in {np.float64, np.complex128} solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner), "shape": shape, "dtype": dtype, "preconditioner": preconditioner } for shape in [(2, 2), (4, 4)] for dtype in float_types + complex_types for preconditioner in [None, 'identity', 'exact'])) @jtu.skip_on_devices("gpu") def test_bicgstab_on_random_system(self, shape, dtype, preconditioner): rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) solution = rng(shape[1:], dtype) M = self._fetch_preconditioner(preconditioner, A, rng=rng) b = matmul_high_precision(A, solution) tol = shape[0] * jnp.finfo(A.dtype).eps x, info = jax.scipy.sparse.linalg.bicgstab(A, b, tol=tol, atol=tol, M=M) using_x64 = solution.dtype.kind in {np.float64, np.complex128} solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) # solve = lambda A, b: jax.scipy.sparse.linalg.bicgstab(A, b)[0] # jtu.check_grads(solve, (A, b), order=1, rtol=3e-1) def test_bicgstab_pytree(self): A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]} b = {"a": 1.0, "b": -4.0} expected = {"a": 4.0, "b": -6.0} actual, _ = jax.scipy.sparse.linalg.bicgstab(A, b) self.assertEqual(expected.keys(), actual.keys()) self.assertAlmostEqual(expected["a"], actual["a"], places=5) self.assertAlmostEqual(expected["b"], actual["b"], places=5) # GMRES @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_shape={}_preconditioner={}_solve_method={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner, solve_method), "shape": shape, "dtype": dtype, "preconditioner": preconditioner, "solve_method": solve_method } for shape in [(3, 3)] for dtype in [np.float64, np.complex128] for preconditioner in [None, 'identity', 'exact', 'random'] for solve_method in ['incremental', 'batched'])) def test_gmres_against_scipy(self, shape, dtype, preconditioner, solve_method): if not config.x64_enabled: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) b = rng(shape[:1], dtype) M = self._fetch_preconditioner(preconditioner, A, rng=rng) def args_maker(): return A, b self._CheckAgainstNumpy(partial(scipy_gmres, M=M, restart=1, maxiter=1), partial(lax_gmres, M=M, restart=1, maxiter=1, solve_method=solve_method), args_maker, tol=1e-10) self._CheckAgainstNumpy(partial(scipy_gmres, M=M, restart=1, maxiter=2), partial(lax_gmres, M=M, restart=1, maxiter=2, solve_method=solve_method), args_maker, tol=1e-10) self._CheckAgainstNumpy(partial(scipy_gmres, M=M, restart=2, maxiter=1), partial(lax_gmres, M=M, restart=2, maxiter=1, solve_method=solve_method), args_maker, tol=1e-10) self._CheckAgainstNumpy(np.linalg.solve, partial(lax_gmres, M=M, atol=1e-6, solve_method=solve_method), args_maker, tol=1e-10) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}_solve_method={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner, solve_method), "shape": shape, "dtype": dtype, "preconditioner": preconditioner, "solve_method": solve_method } for shape in [(2, 2), (7, 7)] for dtype in float_types + complex_types for preconditioner 
in [None, 'identity', 'exact'] for solve_method in ['batched', 'incremental'])) @jtu.skip_on_devices("gpu") def test_gmres_on_identity_system(self, shape, dtype, preconditioner, solve_method): A = jnp.eye(shape[1], dtype=dtype) solution = jnp.ones(shape[1], dtype=dtype) rng = jtu.rand_default(self.rng()) M = self._fetch_preconditioner(preconditioner, A, rng=rng) b = matmul_high_precision(A, solution) restart = shape[-1] tol = shape[0] * jnp.finfo(dtype).eps x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol, restart=restart, M=M, solve_method=solve_method) using_x64 = solution.dtype.kind in {np.float64, np.complex128} solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}_solve_method={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner, solve_method), "shape": shape, "dtype": dtype, "preconditioner": preconditioner, "solve_method": solve_method } for shape in [(2, 2), (4, 4)] for dtype in float_types + complex_types for preconditioner in [None, 'identity', 'exact'] for solve_method in ['incremental', 'batched'])) @jtu.skip_on_devices("gpu") def test_gmres_on_random_system(self, shape, dtype, preconditioner, solve_method): rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) solution = rng(shape[1:], dtype) M = self._fetch_preconditioner(preconditioner, A, rng=rng) b = matmul_high_precision(A, solution) restart = shape[-1] tol = shape[0] * jnp.finfo(A.dtype).eps x, info = jax.scipy.sparse.linalg.gmres(A, b, tol=tol, atol=tol, restart=restart, M=M, solve_method=solve_method) using_x64 = solution.dtype.kind in {np.float64, np.complex128} solution_tol = 1e-8 if using_x64 else 1e-4 self.assertAllClose(x, solution, atol=solution_tol, rtol=solution_tol) # solve = lambda A, b: jax.scipy.sparse.linalg.gmres(A, b)[0] # jtu.check_grads(solve, (A, b), order=1, rtol=2e-1) def test_gmres_pytree(self): A = lambda x: {"a": x["a"] + 0.5 * x["b"], "b": 0.5 * x["a"] + x["b"]} b = {"a": 1.0, "b": -4.0} expected = {"a": 4.0, "b": -6.0} actual, _ = jax.scipy.sparse.linalg.gmres(A, b) self.assertEqual(expected.keys(), actual.keys()) self.assertAlmostEqual(expected["a"], actual["a"], places=5) self.assertAlmostEqual(expected["b"], actual["b"], places=5) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}_preconditioner={}".format( jtu.format_shape_dtype_string(shape, dtype), preconditioner), "shape": shape, "dtype": dtype, "preconditioner": preconditioner } for shape in [(2, 2), (3, 3)] for dtype in float_types + complex_types for preconditioner in [None, 'identity'])) def test_gmres_arnoldi_step(self, shape, dtype, preconditioner): """ The Arnoldi decomposition within GMRES is correct. 
""" if not config.x64_enabled: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) A = rng(shape, dtype) M = self._fetch_preconditioner(preconditioner, A, rng=rng) if preconditioner is None: M = lambda x: x else: M = partial(matmul_high_precision, M) n = shape[0] x0 = rng(shape[:1], dtype) Q = np.zeros((n, n + 1), dtype=dtype) Q[:, 0] = x0 / jnp.linalg.norm(x0) Q = jnp.array(Q) H = jnp.eye(n, n + 1, dtype=dtype) @jax.tree_util.Partial def A_mv(x): return matmul_high_precision(A, x) for k in range(n): Q, H, _ = jax._src.scipy.sparse.linalg._kth_arnoldi_iteration( k, A_mv, M, Q, H) QA = matmul_high_precision(Q[:, :n].conj().T, A) QAQ = matmul_high_precision(QA, Q[:, :n]) self.assertAllClose(QAQ, H.T[:n, :], rtol=1e-5, atol=1e-5)
class EnergyTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_simple_spring(self, spatial_dimension, dtype): key = random.PRNGKey(0) disp, _ = space.free() if spatial_dimension == 2: R = np.array([[0., 0.], [1., 1.]], dtype=dtype) dist = np.sqrt(2.) elif spatial_dimension == 3: R = np.array([[0., 0., 0.], [1., 1., 1.]], dtype=dtype) dist = np.sqrt(3.) bonds = np.array([[0, 1]], np.int32) for _ in range(STOCHASTIC_SAMPLES): key, l_key, a_key = random.split(key, 3) length = random.uniform(l_key, (), minval=0.1, maxval=3.0) alpha = random.uniform(a_key, (), minval=2., maxval=4.) E = energy.simple_spring_bond(disp, bonds, length=length, alpha=alpha) E_exact = dtype((dist - length)**alpha / alpha) self.assertAllClose(E(R), E_exact, True) # pylint: disable=g-complex-comprehension @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_dim={}_alpha={}_dtype={}'.format(dim, alpha, dtype.__name__), 'spatial_dimension': dim, 'alpha': alpha, 'dtype': dtype } for dim in SPATIAL_DIMENSION for alpha in SOFT_SPHERE_ALPHA for dtype in POSITION_DTYPE)) def test_soft_sphere(self, spatial_dimension, alpha, dtype): key = random.PRNGKey(0) alpha = f32(alpha) for _ in range(STOCHASTIC_SAMPLES): key, split_sigma, split_epsilon = random.split(key, 3) sigma = np.array(random.uniform(split_sigma, (1, ), minval=0.0, maxval=3.0)[0], dtype=dtype) epsilon = np.array(random.uniform(split_epsilon, (1, ), minval=0.0, maxval=4.0)[0], dtype=dtype) self.assertAllClose( energy.soft_sphere(dtype(0), sigma, epsilon, alpha), epsilon / alpha, True) self.assertAllClose( energy.soft_sphere(dtype(sigma), sigma, epsilon, alpha), np.array(0.0, dtype=dtype), True) if alpha == 3.0: grad_energy = grad(energy.soft_sphere) g = grad_energy(dtype(sigma), sigma, epsilon, alpha) self.assertAllClose(g, np.array(0, dtype=dtype), True) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_lennard_jones(self, spatial_dimension, dtype): key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split_sigma, split_epsilon = random.split(key, 3) sigma = dtype( random.uniform(split_sigma, (1, ), minval=0.5, maxval=3.0)[0]) epsilon = dtype( random.uniform(split_epsilon, (1, ), minval=0.0, maxval=4.0)[0]) dr = dtype(sigma * 2**(1.0 / 6.0)) self.assertAllClose(energy.lennard_jones(dr, sigma, epsilon), np.array(-epsilon, dtype=dtype), True) g = grad(energy.lennard_jones)(dr, sigma, epsilon) self.assertAllClose(g, np.array(0, dtype=dtype), True) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_isotropic_cutoff(self, spatial_dimension, dtype): key = random.PRNGKey(0) for _ in range(STOCHASTIC_SAMPLES): key, split_rs, split_rl, split_sigma, split_epsilon = random.split( key, 5) sigma = f32( random.uniform(split_sigma, (1, ), minval=0.5, maxval=3.0)[0]) epsilon = f32( random.uniform(split_epsilon, (1, ), minval=0.0, maxval=4.0)[0]) r_small = random.uniform(split_rs, (10, ), minval=0.0, maxval=2.0 * sigma, dtype=dtype) r_large = random.uniform(split_rl, (10, ), minval=2.5 * sigma, 
maxval=3.0 * sigma, dtype=dtype) r_onset = f32(2.0 * sigma) r_cutoff = f32(2.5 * sigma) E = energy.multiplicative_isotropic_cutoff(energy.lennard_jones, r_onset, r_cutoff) self.assertAllClose(E(r_small, sigma, epsilon), energy.lennard_jones(r_small, sigma, epsilon), True) self.assertAllClose(E(r_large, sigma, epsilon), np.zeros_like(r_large, dtype=dtype), True) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_soft_sphere_neighbor_list_energy(self, spatial_dimension, dtype): key = random.PRNGKey(1) box_size = f32(15.0) displacement, _ = space.periodic(box_size) exact_energy_fn = energy.soft_sphere_pair(displacement) R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) neighbor_fn, energy_fn = energy.soft_sphere_neighbor_list( displacement, box_size, R) idx = neighbor_fn(R) self.assertAllClose(np.array(exact_energy_fn(R), dtype=dtype), energy_fn(R, idx), True) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_lennard_jones_cell_neighbor_list_energy(self, spatial_dimension, dtype): key = random.PRNGKey(1) box_size = f32(15) displacement, _ = space.periodic(box_size) metric = space.metric(displacement) exact_energy_fn = energy.lennard_jones_pair(displacement) R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list( displacement, box_size, R) idx = neighbor_fn(R) self.assertAllClose(np.array(exact_energy_fn(R), dtype=dtype), energy_fn(R, idx), True) @parameterized.named_parameters( jtu.cases_from_list( { 'testcase_name': '_dim={}_dtype={}'.format( dim, dtype.__name__), 'spatial_dimension': dim, 'dtype': dtype, } for dim in SPATIAL_DIMENSION for dtype in POSITION_DTYPE)) def test_lennard_jones_neighbor_list_force(self, spatial_dimension, dtype): key = random.PRNGKey(1) box_size = f32(15.0) displacement, _ = space.periodic(box_size) metric = space.metric(displacement) exact_force_fn = quantity.force( energy.lennard_jones_pair(displacement)) R = box_size * random.uniform(key, (PARTICLE_COUNT, spatial_dimension), dtype=dtype) neighbor_fn, energy_fn = energy.lennard_jones_neighbor_list( displacement, box_size, R) force_fn = quantity.force(energy_fn) idx = neighbor_fn(R) self.assertAllClose(np.array(exact_force_fn(R), dtype=dtype), force_fn(R, idx), True) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_num_reps={}_dtype={}'.format(num_repetitions, dtype.__name__), 'num_repetitions': num_repetitions, 'dtype': dtype, } for num_repetitions in UNIT_CELL_SIZE for dtype in POSITION_DTYPE)) def test_eam(self, num_repetitions, dtype): latvec = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=dtype) * f32(4.05 / 2) atoms = np.array([[0, 0, 0]], dtype=dtype) atoms_repeated, latvec_repeated = lattice_repeater( atoms, latvec, num_repetitions) inv_latvec = np.array(onp.linalg.inv(onp.array(latvec_repeated)), dtype=dtype) displacement, _ = space.periodic_general(latvec_repeated) charge_fn, embedding_fn, pairwise_fn = make_eam_test_splines() assert charge_fn(np.array(1.0, dtype)).dtype == dtype assert embedding_fn(np.array(1.0, dtype)).dtype == dtype assert pairwise_fn(np.array(1.0, dtype)).dtype == dtype 
eam_energy = energy.eam(displacement, charge_fn, embedding_fn, pairwise_fn) tol = 1e-5 if dtype == np.float32 else 1e-6 self.assertAllClose( eam_energy(np.dot(atoms_repeated, inv_latvec)) / np.array(num_repetitions**3, dtype), dtype(-3.363338), True, tol, tol)
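# A small numeric sketch of the identity behind test_lennard_jones: the
# 12-6 potential attains its minimum -epsilon at r = sigma * 2**(1/6),
# where its gradient vanishes. The function is a plain restatement of the
# textbook formula (with jax.grad for the derivative), not jax_md's
# energy.lennard_jones itself.
import jax.numpy as jnp
from jax import grad

def lennard_jones(r, sigma=1.0, epsilon=1.0):
  s6 = (sigma / r) ** 6
  return 4.0 * epsilon * (s6 ** 2 - s6)

r_min = 2.0 ** (1.0 / 6.0)
assert jnp.allclose(lennard_jones(r_min), -1.0)
assert jnp.allclose(grad(lennard_jones)(r_min), 0.0, atol=1e-4)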
class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shapes={}_axis={}_keepdims={}_return_sign={}_use_b_{}".format( jtu.format_shape_dtype_string(shapes, dtype), axis, keepdims, return_sign, use_b), # TODO(b/133842870): re-enable when exp(nan) returns NaN on CPU. "shapes": shapes, "dtype": dtype, "axis": axis, "keepdims": keepdims, "return_sign": return_sign, "use_b": use_b} for shape_group in compatible_shapes for dtype in float_dtypes + complex_dtypes + int_dtypes for use_b in [False, True] for shapes in itertools.product(*( (shape_group, shape_group) if use_b else (shape_group,))) for axis in range(-max(len(shape) for shape in shapes), max(len(shape) for shape in shapes)) for keepdims in [False, True] for return_sign in [False, True])) @jtu.ignore_warning(category=RuntimeWarning, message="invalid value encountered in .*") def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b): if jtu.device_under_test() != "cpu": rng = jtu.rand_some_inf_and_nan(self.rng()) else: rng = jtu.rand_default(self.rng()) # TODO(mattjj): test autodiff if use_b: def scipy_fun(array_to_reduce, scale_array): return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign, b=scale_array) def lax_fun(array_to_reduce, scale_array): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign, b=scale_array) args_maker = lambda: [rng(shapes[0], dtype), rng(shapes[1], dtype)] else: def scipy_fun(array_to_reduce): return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign) def lax_fun(array_to_reduce): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims, return_sign=return_sign) args_maker = lambda: [rng(shapes[0], dtype)] tol = {np.float32: 1E-6, np.float64: 1E-14} self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) def testLogSumExpZeros(self): # Regression test for https://github.com/google/jax/issues/5370 scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b) lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b) args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( rec.test_name, shapes, dtypes), "rng_factory": rec.rng_factory, "shapes": shapes, "dtypes": dtypes, "test_autodiff": rec.test_autodiff, "nondiff_argnums": rec.nondiff_argnums, "scipy_op": getattr(osp_special, rec.name), "lax_op": getattr(lsp_special, rec.name)} for shapes in itertools.combinations_with_replacement(all_shapes, rec.nargs) for dtypes in (itertools.combinations_with_replacement(rec.dtypes, rec.nargs) if isinstance(rec.dtypes, list) else itertools.product(*rec.dtypes))) for rec in JAX_SPECIAL_FUNCTION_RECORDS)) def testScipySpecialFun(self, scipy_op, lax_op, rng_factory, shapes, dtypes, test_autodiff, nondiff_argnums): if (jtu.device_under_test() == "cpu" and (lax_op is lsp_special.gammainc or lax_op is lsp_special.gammaincc)): # TODO(b/173608403): re-enable test when LLVM bug is fixed. 
raise unittest.SkipTest("Skipping test due to LLVM lowering bug") rng = rng_factory(self.rng()) args_maker = self._GetArgsMaker(rng, shapes, dtypes) args = args_maker() self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, check_dtypes=False) self._CompileAndCheck(lax_op, args_maker, rtol=1e-4) if test_autodiff: def partial_lax_op(*vals): list_args = list(vals) for i in nondiff_argnums: list_args.insert(i, args[i]) return lax_op(*list_args) assert list(nondiff_argnums) == sorted(set(nondiff_argnums)) diff_args = [x for i, x in enumerate(args) if i not in nondiff_argnums] jtu.check_grads(partial_lax_op, diff_args, order=1, atol=jtu.if_device_under_test("tpu", .1, 1e-3), rtol=.1, eps=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_d={}".format( jtu.format_shape_dtype_string(shape, dtype), d), "shape": shape, "dtype": dtype, "d": d} for shape in all_shapes for dtype in float_dtypes for d in [1, 2, 5])) def testMultigammaln(self, shape, dtype, d): def scipy_fun(a): return osp_special.multigammaln(a, d) def lax_fun(a): return lsp_special.multigammaln(a, d) rng = jtu.rand_positive(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol={np.float32: 1e-3, np.float64: 1e-14}) self._CompileAndCheck(lax_fun, args_maker) def testIssue980(self): x = np.full((4,), -1e20, dtype=np.float32) self.assertAllClose(np.zeros((4,), dtype=np.float32), lsp_special.expit(x)) def testIssue3758(self): x = np.array([1e5, 1e19, 1e10], dtype=np.float32) q = np.array([1., 40., 30.], dtype=np.float32) self.assertAllClose(np.array([1., 0., 0.], dtype=np.float32), lsp_special.zeta(x, q)) def testXlogyShouldReturnZero(self): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): partial_xlogy = functools.partial(lsp_special.xlogy, 0.) self.assertAllClose(api.grad(partial_xlogy)(0.), 0., check_dtypes=False) def testXlog1pyShouldReturnZero(self): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): partial_xlog1py = functools.partial(lsp_special.xlog1py, 0.) self.assertAllClose(api.grad(partial_xlog1py)(-1.), 0., check_dtypes=False) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_maxdegree={}_inputsize={}".format(l_max, num_z), "l_max": l_max, "num_z": num_z} for l_max, num_z in zip([1, 2, 3], [6, 7, 8]))) def testLpmn(self, l_max, num_z): # Points on which the associated Legendre functions areevaluated. z = np.linspace(-0.2, 0.9, num_z) actual_p_vals, actual_p_derivatives = lsp_special.lpmn(m=l_max, n=l_max, z=z) # The expected results are obtained from scipy. expected_p_vals = np.zeros((l_max + 1, l_max + 1, num_z)) expected_p_derivatives = np.zeros((l_max + 1, l_max + 1, num_z)) for i in range(num_z): val, derivative = osp_special.lpmn(l_max, l_max, z[i]) expected_p_vals[:, :, i] = val expected_p_derivatives[:, :, i] = derivative with self.subTest('Test values.'): self.assertAllClose(actual_p_vals, expected_p_vals, rtol=1e-6, atol=3.2e-6) with self.subTest('Test derivatives.'): self.assertAllClose(actual_p_derivatives,expected_p_derivatives, rtol=1e-6, atol=8.4e-4)
class ScipyLinalgTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(1, 1), (4, 5), (10, 5), (50, 50)] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testLu(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng(shape, dtype)] x, = args_maker() p, l, u = jsp.linalg.lu(x) self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True) self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True) def testLuOfSingularMatrix(self): x = np.array([[-1., 3. / 2], [2. / 3, -1.]], dtype=onp.float32) p, l, u = jsp.linalg.lu(x) self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) @jtu.skip_on_devices("tpu") # TODO(phawkins): precision problems on TPU. def testLuGrad(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) a = rng(shape, dtype) lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu jtu.check_grads(lu, (a, ), 2, atol=5e-2, rtol=1e-1) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(4, 5), (6, 5)] for dtype in [np.float32] for rng in [jtu.rand_default()])) def testLuBatching(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) args = [rng(shape, np.float32) for _ in range(10)] expected = list(osp.linalg.lu(x) for x in args) ps = onp.stack([out[0] for out in expected]) ls = onp.stack([out[1] for out in expected]) us = onp.stack([out[2] for out in expected]) actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args)) self.assertAllClose(ps, actual_ps, check_dtypes=True) self.assertAllClose(ls, actual_ls, check_dtypes=True) self.assertAllClose(us, actual_us, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)), "n": n, "dtype": dtype, "rng": rng } for n in [1, 4, 5, 200] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testLuFactor(self, n, dtype, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng((n, n), dtype)] x, = args_maker() lu, piv = jsp.linalg.lu_factor(x) l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype) u = onp.triu(lu) for i in range(n): x[[i, piv[i]], ] = x[[piv[i], i], ] self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3) self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}_trans={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), trans), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "trans": trans, "rng": rng } for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4, )), ((8, 8), (8, 4, 2)), ] for trans in [0, 1, 2] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans, rng): _skip_if_unsupported_type(dtype) 
osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve( (lu, piv), rhs, trans=trans) jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve( (lu, piv), rhs, trans=trans) def args_maker(): a = rng(lhs_shape, dtype) lu, piv = osp.linalg.lu_factor(a) return [lu, piv, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}_sym_pos={}_lower={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), sym_pos, lower), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "sym_pos": sym_pos, "lower": lower, "rng": rng } for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4, )), ((8, 8), (8, 4)), ] for sym_pos, lower in [ (False, False), (True, False), (True, True), ] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng): _skip_if_unsupported_type(dtype) if (sym_pos and onp.issubdtype(dtype, onp.complexfloating) and jtu.device_under_test() == "tpu"): raise unittest.SkipTest( "Complex Cholesky decomposition not implemented on TPU") osp_fun = lambda lhs, rhs: osp.linalg.solve( lhs, rhs, sym_pos=sym_pos, lower=lower) jsp_fun = lambda lhs, rhs: jsp.linalg.solve( lhs, rhs, sym_pos=sym_pos, lower=lower) def args_maker(): a = rng(lhs_shape, dtype) if sym_pos: a = onp.matmul(a, onp.conj(T(a))) a = onp.tril(a) if lower else onp.triu(a) return [a, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lower, transpose_a, unit_diagonal), "lower": lower, "transpose_a": transpose_a, "unit_diagonal": unit_diagonal, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng } for lower in [False, True] for transpose_a in [False, True] for unit_diagonal in [False, True] for lhs_shape, rhs_shape in [ ((4, 4), (4, )), ((4, 4), (4, 3)), ((2, 8, 8), (2, 8, 10)), ] for dtype in float_types for rng in [jtu.rand_default()])) def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape, rhs_shape, dtype, rng): _skip_if_unsupported_type(dtype) k = rng(lhs_shape, dtype) l = onp.linalg.cholesky( onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1])) l = l.astype(k.dtype) b = rng(rhs_shape, dtype) if unit_diagonal: a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype) else: a = l a = a if lower else T(a) inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype) if len(lhs_shape) == len(rhs_shape): onp_ans = onp.matmul(inv, b) else: onp_ans = onp.einsum("...ij,...j->...i", inv, b) # The standard scipy.linalg.solve_triangular doesn't support broadcasting. # But it seems like an inevitable extension so we support it. ans = jsp.linalg.solve_triangular(l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower, unit_diagonal=unit_diagonal) self.assertAllClose(onp_ans, ans, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}". 
format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lower, transpose_a, unit_diagonal), "lower": lower, "transpose_a": transpose_a, "unit_diagonal": unit_diagonal, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng } for lower in [False, True] for unit_diagonal in [False, True] for dtype in float_types + complex_types for transpose_a in ( [0, 1] if onp.issubdtype(dtype, np.floating) else [0, 1, 2]) for lhs_shape, rhs_shape in [ ((4, 4), (4, )), ((4, 4), (4, 3)), ((2, 8, 8), (2, 8, 10)), ] for rng in [jtu.rand_default()])) @jtu.skip_on_devices("tpu") # TODO(phawkins): Test fails on TPU. def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal, lhs_shape, rhs_shape, dtype, rng): _skip_if_unsupported_type(dtype) A = np.tril( rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype)) A = A if lower else T(A) B = rng(rhs_shape, dtype) f = partial(jsp.linalg.solve_triangular, lower=lower, trans=transpose_a, unit_diagonal=unit_diagonal) jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
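# A short usage sketch of the lu_factor/lu_solve pairing tested above:
# factor once, then reuse the pivoted factors for a right-hand side. The
# matrix and tolerance are illustrative; float32 (x64 disabled) is assumed.
import numpy as np
import jax.scipy.linalg as jsp_linalg

rng = np.random.RandomState(0)
a = (rng.randn(4, 4) + 4.0 * np.eye(4)).astype(np.float32)  # well conditioned
b = rng.randn(4).astype(np.float32)

lu_and_piv = jsp_linalg.lu_factor(a)
x = jsp_linalg.lu_solve(lu_and_piv, b)
assert np.allclose(a @ np.asarray(x), b, atol=1e-3)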
class QuantityTest(jtu.JaxTestCase): def test_canonicalize_mass(self): assert quantity.canonicalize_mass(3.0) == 3.0 assert quantity.canonicalize_mass(f32(3.0)) == f32(3.0) assert quantity.canonicalize_mass(f64(3.0)) == f64(3.0) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_dim={}'.format(dim), 'spatial_dimension': dim, } for dim in SPATIAL_DIMENSION)) def test_grad_kinetic_energy(self, spatial_dimension): key = random.PRNGKey(0) @jit def do_fn(theta): key = random.PRNGKey(0) V = random.normal(key, (PARTICLE_COUNT, spatial_dimension), dtype=f32) return quantity.kinetic_energy(theta * V) grad(do_fn)(2.0) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_dtype={}'.format(dtype.__name__), 'dtype': dtype, } for dtype in DTYPES)) def test_cosine_angles(self, dtype): displacement, _ = space.free() displacement = space.map_product(displacement) R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype) dR = displacement(R, R) cangles = quantity.cosine_angles(dR) c45 = 1 / np.sqrt(2) true_cangles = np.array([[[0, 0, 0], [0, 1, c45], [0, c45, 1]], [[1, 0, 0], [0, 0, 0], [0, 0, 1]], [[1, c45, 0], [c45, 1, 0], [0, 0, 0]]], dtype=dtype) self.assertAllClose(cangles, true_cangles) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_dtype={}'.format(dtype.__name__), 'dtype': dtype, } for dtype in DTYPES)) def test_cosine_angles_neighbors(self, dtype): displacement, _ = space.free() displacement = vmap(vmap(displacement, (None, 0)), 0) R = np.array([[0, 0], [0, 1], [1, 1]], dtype=dtype) R_neigh = np.array( [[[0, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]], dtype=dtype) dR = displacement(R, R_neigh) cangles = quantity.cosine_angles(dR) c45 = 1 / np.sqrt(2) true_cangles = np.array( [[[1, c45], [c45, 1]], [[1, 1], [1, 1]], [[1, 1], [1, 1]]], dtype=dtype) self.assertAllClose(cangles, true_cangles) @parameterized.named_parameters( jtu.cases_from_list({ 'testcase_name': '_dtype={}'.format(dtype.__name__), 'dtype': dtype, } for dtype in DTYPES)) def test_pair_correlation(self, dtype): displacement = lambda Ra, Rb, **kwargs: Ra - Rb R = np.array([[1, 0], [0, 0], [0, 1]], dtype=dtype) rs = np.linspace(0, 2, 60, dtype=dtype) g = quantity.pair_correlation(displacement, rs, f32(0.1)) gs = g(R) gs = np.mean(gs, axis=0) assert np.argmax(gs) == np.argmin((rs - 1.)**2) assert gs.dtype == dtype
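# A worked check of the 1/sqrt(2) entries asserted in test_cosine_angles:
# the cosine of the angle between the displacements (0, 1) and (1, 1) is
# u.v / (|u| |v|) = 1 / sqrt(2). Plain numpy, independent of jax_md.
import numpy as np

u = np.array([0.0, 1.0])
v = np.array([1.0, 1.0])
cos_uv = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
assert np.allclose(cos_uv, 1.0 / np.sqrt(2.0))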
class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):

  # This test runs for all primitive harnesses. For each primitive "xxx" the
  # test will be called "test_prim_xxx_..." and the custom parameters for the
  # test are defined in the class method
  # "jax2tf_limitations.Jax2TfLimitation.xxx".
  # See more details in the comment at the top of the file and in the
  # Jax2TfLimitation class.
  # If you want to run this test for only one harness, add the parameter
  # `one_containing="foo"` to the parameterized decorator below.
  @primitive_harness.parameterized(primitive_harness.all_harnesses,
                                   include_jax_unimpl=False)
  @jtu.ignore_warning(category=UserWarning,
                      message="Using reduced precision for gradient.*")
  def test_prim(self, harness: primitive_harness.Harness):
    limitations = Jax2TfLimitation.limitations_for_harness(harness)
    device = jtu.device_under_test()
    limitations = tuple(
        filter(lambda l: l.filter(device=device, dtype=harness.dtype),
               limitations))
    func_jax = harness.dyn_fun
    args = harness.dyn_args_maker(self.rng())
    enable_xla = harness.params.get("enable_xla", True)
    self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                           enable_xla=enable_xla)

  def test_primitive_coverage(self):
    """Fail if there are JAX primitives that are not implemented."""
    # Harvest primitives from the XLA translation tables.
    all_primitives = (set(xla.translations)
                      | set(xla.backend_specific_translations["cpu"])
                      | set(xla.backend_specific_translations["gpu"])
                      | set(xla.backend_specific_translations["tpu"])
                      | set(xla.initial_style_translations)
                      | set(xla.parallel_translations))
    tf_impl = set(jax.experimental.jax2tf.jax2tf.tf_impl) | set(
        jax.experimental.jax2tf.jax2tf.tf_impl_with_avals)
    tf_not_yet_impl = set(jax.experimental.jax2tf.jax2tf.tf_not_yet_impl)

    all_primitives = tuple(sorted(all_primitives, key=str))
    for p in all_primitives:
      if p.name == "axis_index":
        continue
      if p.name in tf_not_yet_impl:
        # A primitive must not be in both tf_impl and tf_not_yet_impl.
        self.assertNotIn(p, tf_impl)
      else:
        self.assertIn(p, tf_impl)

  def test_generate_limitations_doc(self):
    """Generates primitives_with_limited_support.md.

    See the doc for instructions.
    """
    harnesses = [
        h for h in primitive_harness.all_harnesses
        if h.filter(h, include_jax_unimpl=True)
    ]
    print(f"Found {len(harnesses)} test harnesses that work in JAX")

    def unique_hash(h: primitive_harness.Harness, l: Jax2TfLimitation):
      return (h.group_name, l.description, l.devices,
              tuple(np.dtype(d).name for d in l.dtypes), l.modes)

    unique_limitations: Dict[Any, Tuple[primitive_harness.Harness,
                                        Jax2TfLimitation]] = {}
    for h in harnesses:
      for l in h.jax_unimplemented:
        if l.enabled:
          # Fake a Jax2TfLimitation from the Limitation.
          tfl = Jax2TfLimitation(
              description="Not implemented in JAX: " + l.description,
              devices=l.devices,
              dtypes=l.dtypes,
              expect_tf_error=False,
              skip_tf_run=True)
          unique_limitations[hash(unique_hash(h, tfl))] = (h, tfl)
    for h in harnesses:
      for l in Jax2TfLimitation.limitations_for_harness(h):
        unique_limitations[hash(unique_hash(h, l))] = (h, l)

    print(f"Found {len(unique_limitations)} unique limitations")
    tf_error_table = ["""
| Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes |
| --- | --- | --- | --- | --- |"""]
    tf_numerical_discrepancies_table = list(tf_error_table)  # A copy.

    for h, l in sorted(
        unique_limitations.values(), key=lambda pair: unique_hash(*pair)):
      devices = ", ".join(sorted(l.devices))
      modes = ", ".join(sorted(l.modes))
      description = l.description
      if l.skip_comparison:
        description = "Numeric comparison disabled: " + description
      if l.expect_tf_error:
        description = "TF error: " + description
      if l.skip_tf_run:
        description = "TF test skipped: " + description

      if l.skip_tf_run or l.expect_tf_error:
        to_table = tf_error_table
      elif l.skip_comparison or l.custom_assert:
        to_table = tf_numerical_discrepancies_table
      else:
        continue

      to_table.append(
          f"| {h.group_name} | {description} | "
          f"{primitive_harness.dtypes_to_str(l.dtypes, empty_means_all=True)} | "
          f"{devices} | {modes} |")

    if not os.environ.get("JAX_OUTPUT_LIMITATIONS_DOC"):
      raise unittest.SkipTest(
          "Set JAX_OUTPUT_LIMITATIONS_DOC=1 to enable the generation of the "
          "documentation")
    # Run on the CPU: it has the most supported types and harnesses.
    self.assertEqual("cpu", jtu.device_under_test())
    self.assertTrue(
        config.x64_enabled,
        "Documentation generation must be run with JAX_ENABLE_X64=1")

    with open(
        os.path.join(
            os.path.dirname(__file__),
            "../g3doc/primitives_with_limited_support.md.template")) as f:
      template = f.read()
    output_file = os.path.join(
        os.path.dirname(__file__),
        "../g3doc/primitives_with_limited_support.md")

    with open(output_file, "w") as f:
      f.write(template.replace(
          "{{generation_date}}", str(datetime.date.today())
      ).replace(
          "{{tf_error_table}}", "\n".join(tf_error_table)
      ).replace(
          "{{tf_numerical_discrepancies_table}}",
          "\n".join(tf_numerical_discrepancies_table)))

  # The rest of the tests check special cases.

  @parameterized.named_parameters(
      dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
      for f_jax in [
          jnp.add, jnp.subtract, jnp.multiply, jnp.divide, jnp.less,
          jnp.less_equal, jnp.equal, jnp.greater, jnp.greater_equal,
          jnp.not_equal, jnp.maximum, jnp.minimum
      ])
  def test_type_promotion(self, f_jax=jnp.add):
    # We only test a few types here, as TensorFlow does not support many
    # types like uint* or bool in binary ops.
    types = [dtypes.bfloat16, np.int32, np.int64, np.float32]
    for x_dtype in types:
      for y_dtype in types:
        x = np.array([1, 2], dtype=x_dtype)
        y = np.array([3, 4], dtype=y_dtype)
        self.ConvertAndCompare(f_jax, x, y)

  def test_integer_div(self):
    x = jnp.array([-4, -3, -1, 0, 1, 3, 6])
    y = np.int32(3)
    self.ConvertAndCompare(jnp.floor_divide, x, y)
    expected = jnp.floor_divide(x, y)
    # Try it with TF 1 as well (#5831).
    with tf.compat.v1.Session() as sess:
      tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y))
      self.assertAllClose(expected, tf1_res)

  def test_disable_xla(self):
    def fun(x):
      return lax.pad(x, np.float32(0), [(-1, 0, 0), (0, 0, 0)])

    with self.assertRaisesRegex(
        NotImplementedError,
        "Call to pad cannot be converted with enable_xla=False."):
      self.ConvertAndCompare(fun, np.ones((2, 3), dtype=np.float32),
                             enable_xla=False)

  def test_boolean_gather(self):
    values = np.array([[True, True], [False, True], [False, False]],
                      dtype=np.bool_)
    indices = np.array([0, 1], dtype=np.int32)
    for axis in [0, 1]:
      f_jax = jax.jit(lambda v, i: jnp.take(v, i, axis=axis))  # pylint: disable=cell-var-from-loop
      self.ConvertAndCompare(f_jax, values, indices)

  def test_gather_rank_change(self):
    params = jnp.array([[1.0, 1.5, 2.0], [2.0, 2.5, 3.0], [3.0, 3.5, 4.0]])
    indices = jnp.array([[1, 1, 2], [0, 1, 0]])
    f_jax = jax.jit(lambda i: params[i])
    self.ConvertAndCompare(f_jax, indices)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
          for f_jax in REDUCE))
  def test_reduce_ops_with_numerical_input(self, f_jax):
    values = np.array([1, 2, 3], dtype=np.float32)
    self.ConvertAndCompare(f_jax, values)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          dict(testcase_name=f"_{op.__name__}", op=op) for op in (
              jax.ops.index_add,
              jax.ops.index_max,
              jax.ops.index_min,
              jax.ops.index_mul,
              jax.ops.index_update,
          )))
  def test_scatter_static(self, op):
    values = np.ones((5, 6), dtype=np.float32)
    update = np.float32(6.)
    f_jax = jax.jit(lambda v, u: op(v, jax.ops.index[::2, 3:], u))
    self.ConvertAndCompare(f_jax, values, update)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          dict(testcase_name=f"_{f_jax.__name__}", f_jax=f_jax)
          for f_jax in REDUCE))
  def test_reduce_ops_with_boolean_input(self, f_jax):
    values = np.array([True, False, True], dtype=np.bool_)
    self.ConvertAndCompare(f_jax, values)
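# As the comment at the top of JaxPrimitiveTest notes, the harness sweep can
# be narrowed to a single primitive via `one_containing`. A hypothetical
# sketch (the fragment "cumsum" is only an illustration; any substring of a
# harness name works):
#
#   @primitive_harness.parameterized(primitive_harness.all_harnesses,
#                                    one_containing="cumsum")
#   def test_prim(self, harness: primitive_harness.Harness):
#     ...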
class StaxTest(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape), "shape": shape}
      for shape in [(2, 3), (5,)]))
  def testRandnInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.randn()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(shape), "shape": shape}
      for shape in [(2, 3), (2, 3, 4)]))
  def testGlorotInitShape(self, shape):
    key = random.PRNGKey(0)
    out = stax.glorot()(key, shape)
    self.assertEqual(out.shape, shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape,
       "padding": padding, "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 10, 11, 1)]))
  def testConvShape(self, channels, filter_shape, padding, strides,
                    input_shape):
    init_fun, apply_fun = stax.Conv(channels, filter_shape, strides=strides,
                                    padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape,
       "padding": padding, "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1, 1), (2, 3), (3, 3)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (2, 1), (2, 2)]
      for input_shape in [(2, 10, 11, 1)]))
  def testConvTransposeShape(self, channels, filter_shape, padding, strides,
                             input_shape):
    init_fun, apply_fun = stax.ConvTranspose(channels, filter_shape,  # 2D
                                             strides=strides, padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_channels={}_filter_shape={}_padding={}_strides={}_input_shape={}"
       .format(channels, filter_shape, padding, strides, input_shape),
       "channels": channels, "filter_shape": filter_shape,
       "padding": padding, "strides": strides, "input_shape": input_shape}
      for channels in [2, 3]
      for filter_shape in [(1,), (2,), (3,)]
      for padding in ["SAME", "VALID"]
      for strides in [None, (1,), (2,)]
      for input_shape in [(2, 10, 1)]))
  def testConv1DTransposeShape(self, channels, filter_shape, padding, strides,
                               input_shape):
    init_fun, apply_fun = stax.Conv1DTranspose(channels, filter_shape,
                                               strides=strides,
                                               padding=padding)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_out_dim={}_input_shape={}"
       .format(out_dim, input_shape),
       "out_dim": out_dim, "input_shape": input_shape}
      for out_dim in [3, 4]
      for input_shape in [(2, 3), (3, 4)]))
  def testDenseShape(self, out_dim, input_shape):
    init_fun, apply_fun = stax.Dense(out_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}_nonlinear={}"
       .format(input_shape, nonlinear),
       "input_shape": input_shape, "nonlinear": nonlinear}
      for input_shape in [(2, 3), (2, 3, 4)]
      for nonlinear in ["Relu", "Sigmoid", "Elu", "LeakyRelu"]))
  def testNonlinearShape(self, input_shape, nonlinear):
    init_fun, apply_fun = getattr(stax, nonlinear)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name":
       "_window_shape={}_padding={}_strides={}_input_shape={}"
       "_maxpool={}_spec={}"
       .format(window_shape, padding, strides, input_shape, max_pool, spec),
       "window_shape": window_shape, "padding": padding, "strides": strides,
       "input_shape": input_shape, "max_pool": max_pool, "spec": spec}
      for window_shape in [(1, 1), (2, 3)]
      for padding in ["VALID"]
      for strides in [None, (2, 1)]
      for input_shape in [(2, 5, 6, 4)]
      for max_pool in [False, True]
      for spec in ["NHWC", "NCHW", "WHNC", "WHCN"]))
  def testPoolingShape(self, window_shape, padding, strides, input_shape,
                       max_pool, spec):
    layer = stax.MaxPool if max_pool else stax.AvgPool
    init_fun, apply_fun = layer(window_shape, padding=padding,
                                strides=strides, spec=spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(2, 3), (2, 3, 4)]))
  def testFlattenShape(self, input_shape):
    init_fun, apply_fun = stax.Flatten
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}_spec={}".format(input_shape, i),
       "input_shape": input_shape, "spec": spec}
      for input_shape in [(2, 5, 6, 1)]
      for i, spec in enumerate([
          [stax.Conv(3, (2, 2))],
          [stax.Conv(3, (2, 2)), stax.Flatten, stax.Dense(4)]])))
  def testSerialComposeLayersShape(self, input_shape, spec):
    init_fun, apply_fun = stax.serial(*spec)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testDropoutShape(self, input_shape):
    init_fun, apply_fun = stax.Dropout(0.9)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_input_shape={}".format(input_shape),
       "input_shape": input_shape}
      for input_shape in [(3, 4), (2, 5, 6, 1)]))
  def testFanInSum(self, input_shape):
    init_fun, apply_fun = stax.FanInSum
    _CheckShapeAgreement(self, init_fun, apply_fun,
                         [input_shape, input_shape])

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshapes={}_axis={}".format(input_shapes, axis),
       "input_shapes": input_shapes, "axis": axis}
      for input_shapes, axis in [
          ([(2, 3), (2, 1)], 1),
          ([(2, 3), (2, 1)], -1),
          ([(1, 2, 4), (1, 1, 4)], 1),
      ]))
  def testFanInConcat(self, input_shapes, axis):
    init_fun, apply_fun = stax.FanInConcat(axis)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shapes)

  def testIssue182(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.Softmax
    input_shape = (10, 3)
    inputs = onp.arange(30.).astype("float32").reshape(input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    assert out_shape == out.shape
    assert onp.allclose(onp.sum(onp.asarray(out), -1), 1.)
  def testBatchNormNoScaleOrCenter(self):
    key = random.PRNGKey(0)
    axes = (0, 1, 2)
    init_fun, apply_fun = stax.BatchNorm(axis=axes, center=False, scale=False)
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)
    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)
    means = onp.mean(out, axis=(0, 1, 2))
    std_devs = onp.std(out, axis=(0, 1, 2))
    assert onp.allclose(means, onp.zeros_like(means), atol=1e-4)
    assert onp.allclose(std_devs, onp.ones_like(std_devs), atol=1e-4)

  def testBatchNormShapeNHWC(self):
    key = random.PRNGKey(0)
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (7,))
    self.assertEqual(gamma.shape, (7,))
    self.assertEqual(out_shape, out.shape)

  def testBatchNormShapeNCHW(self):
    key = random.PRNGKey(0)
    # Regression test for https://github.com/google/jax/issues/461
    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    input_shape = (4, 5, 6, 7)
    inputs = random_inputs(onp.random.RandomState(0), input_shape)

    out_shape, params = init_fun(key, input_shape)
    out = apply_fun(params, inputs)

    self.assertEqual(out_shape, input_shape)
    beta, gamma = params
    self.assertEqual(beta.shape, (5,))
    self.assertEqual(gamma.shape, (5,))
    self.assertEqual(out_shape, out.shape)

  def testAttentionShape(self):
    batch_size = 7
    nhead = 2
    query_dim, value_dim, out_dim = (nhead * 3, nhead * 5, 29)
    query_src_dim, key_src_dim, value_src_dim = (11, 13, 17)
    kv_len, q_len = (19, 23)
    input_shape = [(batch_size, q_len, query_src_dim),
                   (batch_size, kv_len, key_src_dim),
                   (batch_size, kv_len, value_src_dim),
                   (batch_size, kv_len)]
    init_fun, apply_fun = stax.GeneralAttention(
        out_dim, nhead=nhead, query_dim=query_dim, value_dim=value_dim)
    _CheckShapeAgreement(self, init_fun, apply_fun, input_shape)

  def testAttentionBehavior(self):
    key = random.PRNGKey(0)
    batch_size = 7
    nhead = 2
    query_dim, value_dim, out_dim = (nhead * 3, nhead * 3, nhead * 3)
    query_src_dim, key_src_dim, value_src_dim = (out_dim, out_dim, out_dim)
    kv_len, q_len = (19, 19)  # Use the same lengths to test the causal mask.
    input_shape = [(batch_size, q_len, query_src_dim),
                   (batch_size, kv_len, key_src_dim),
                   (batch_size, kv_len, value_src_dim),
                   (batch_size, kv_len)]
    init_fun, apply_fun = stax.GeneralAttention(
        out_dim, nhead=nhead, query_dim=query_dim, value_dim=value_dim,
        W_init=identity_initializer, b_init=nn.initializers.zeros)
    unused_output_shape, params = init_fun(key, input_shape)

    query, key, value = [onp.zeros(shape) for shape in input_shape[:3]]
    mask = onp.ones(input_shape[3])
    output = apply_fun(params, (query, key, value, mask))
    # Attention to zero values with zero keys must be zero.
    self.assertAllClose(output, onp.zeros(output.shape), check_dtypes=True)

    # Make score(K[0, 4], Q[0, 3]) and score(K[0, 4], Q[0, 8]) huge for all
    # heads.
    for n in range(nhead):
      head_offset = out_dim // nhead * n
      key[0, 4, 0 + head_offset] = 100.0
      query[0, 3, 0 + head_offset] = 100.0
      query[0, 8, 0 + head_offset] = 100.0
    rand_vec = onp.random.RandomState(0).randn(out_dim)
    value[0, 4] = rand_vec

    # With identity initializers and the above key and query,
    # output[0, 3] must be close enough to value[0, 4].
    output = apply_fun(params, (query, key, value, mask))
    expected = onp.zeros(output.shape)
    # Queries with all-zero scores attend uniformly, so they average the
    # values; only value[0, 4] is nonzero, giving rand_vec / kv_len.
    expected[0, :, :] = rand_vec / kv_len
    expected[0, 3, :] = rand_vec
    expected[0, 8, :] = rand_vec
    self.assertAllClose(output, expected, check_dtypes=True)

    def causal_mask(shape):
      # Lower-triangular mask: position i may attend only to positions <= i.
      mask = onp.cumsum(onp.identity(shape[2]), axis=0)
      return mask

    init_fun, apply_fun = stax.GeneralAttention(
        out_dim, nhead=nhead, query_dim=query_dim, value_dim=value_dim,
        att_prob_mask_fun=causal_mask,
        W_init=identity_initializer, b_init=nn.initializers.zeros)
    output = apply_fun(params, (query, key, value, mask))
    # With the causal mask, the 3rd query cannot attend to the future
    # value[0, 4].
    self.assertAllClose(output[0, 3, :], onp.zeros(out_dim),
                        check_dtypes=True)
    # But the 8th query must still retrieve value[0, 4].
    self.assertAllClose(output[0, 8, :], rand_vec, check_dtypes=True)
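# _CheckShapeAgreement is used throughout StaxTest and is defined earlier in
# this file. A minimal sketch of what such a helper presumably does for a
# single-input layer (an assumption for illustration, with a distinct name,
# not the actual definition): init_fun declares an output shape from the
# input shape alone, apply_fun computes a concrete output, and the two must
# agree.
def _CheckShapeAgreementSketch(test_case, init_fun, apply_fun, input_shape):
  key = random.PRNGKey(0)
  declared_shape, params = init_fun(key, input_shape)
  inputs = random_inputs(onp.random.RandomState(0), input_shape)
  result = apply_fun(params, inputs)
  test_case.assertEqual(result.shape, declared_shape)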
class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": "_inshape={}_axis={}_keepdims={}".format(
              jtu.format_shape_dtype_string(shape, dtype), axis, keepdims),
           "rng": jtu.rand_default(), "shape": shape, "dtype": dtype,
           "axis": axis, "keepdims": keepdims}
          for shape in all_shapes
          for dtype in float_dtypes
          for axis in range(-len(shape), len(shape))
          for keepdims in [False, True]))
  @jtu.skip_on_flag("jax_xla_backend", "xrt")
  def testLogSumExp(self, rng, shape, dtype, axis, keepdims):
    # TODO(mattjj): test autodiff
    def scipy_fun(array_to_reduce):
      return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    def lax_fun(array_to_reduce):
      return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims)

    args_maker = lambda: [rng(shape, dtype)]
    self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True)
    self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True)

  @parameterized.named_parameters(
      jtu.cases_from_list(
          {"testcase_name": jtu.format_test_name_suffix(
              rec.test_name, shapes, dtypes),
           "rng": rec.rng, "shapes": shapes, "dtypes": dtypes,
           "test_autodiff": rec.test_autodiff,
           "scipy_op": getattr(osp_special, rec.name),
           "lax_op": getattr(lsp_special, rec.name)}
          for rec in JAX_SPECIAL_FUNCTION_RECORDS
          for shapes in CombosWithReplacement(all_shapes, rec.nargs)
          for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)))
  def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes,
                          test_autodiff):
    args_maker = self._GetArgsMaker(rng, shapes, dtypes)
    args = args_maker()
    self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3,
                        check_dtypes=False)
    self._CompileAndCheck(lax_op, args_maker, check_dtypes=True)
    if test_autodiff:
      jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3)

  def testIssue980(self):
    x = onp.full((4,), -1e20, dtype=onp.float32)
    self.assertAllClose(onp.zeros((4,), dtype=onp.float32),
                        lsp_special.expit(x), check_dtypes=True)
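# testLogSumExp above compares against scipy's logsumexp; the reason that
# function exists at all is numerical stability for large-magnitude inputs.
# A minimal sketch of the max-shift identity it relies on (a hypothetical
# helper, not the library implementation):
def _logsumexp_sketch(x):
  # log(sum(exp(x))) == m + log(sum(exp(x - m))) for m = max(x); subtracting
  # the max keeps exp() from overflowing for large inputs.
  m = jnp.max(x)
  return m + jnp.log(jnp.sum(jnp.exp(x - m)))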
class TestLBFGS(jtu.JaxTestCase):

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_func={}_maxiter={}".format(
          func_and_init[0].__name__, maxiter),
       "maxiter": maxiter, "func_and_init": func_and_init}
      for maxiter in [None]
      for func_and_init in [(rosenbrock, np.zeros(2)),
                            (himmelblau, np.zeros(2)),
                            (matyas, np.ones(2) * 6.),
                            (eggholder, np.ones(2) * 100.)]))
  def test_minimize(self, maxiter, func_and_init):
    func, x0 = func_and_init

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          func(jnp),
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(maxiter=maxiter, gtol=1e-7),
      )
      return result.x

    jax_res = min_op(x0)

    # Note that without bounds, L-BFGS-B is just L-BFGS.
    with jtu.ignore_warning(category=DeprecationWarning,
                            message=".*tostring.*is deprecated.*"):
      scipy_res = scipy.optimize.minimize(func(np), x0, method='L-BFGS-B').x

    if func.__name__ == 'matyas':
      # scipy performs badly on Matyas, so compare to the true minimum
      # instead.
      self.assertAllClose(jax_res, jnp.zeros_like(jax_res), atol=1e-7)
      return

    if func.__name__ == 'eggholder':
      # L-BFGS performs poorly on the eggholder function. Neither scipy nor
      # jax finds the true minimum, so we can only loosely (with a high atol)
      # compare the spurious results.
      self.assertAllClose(jax_res, scipy_res, atol=1e-3)
      return

    self.assertAllClose(jax_res, scipy_res, atol=2e-5, check_dtypes=False)

  def test_minimize_complex_sphere(self):
    z0 = jnp.array([1., 2. - 3.j, 4., -5.j])

    def f(z):
      return jnp.real(jnp.dot(jnp.conj(z - z0), z - z0))

    @jit
    def min_op(x0):
      result = jax.scipy.optimize.minimize(
          f,
          x0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(jnp.zeros_like(z0))
    self.assertAllClose(jax_res, z0)

  def test_complex_rosenbrock(self):
    complex_dim = 5

    f_re = rosenbrock(jnp)
    init_re = jnp.zeros((2 * complex_dim,))
    expect_re = jnp.ones((2 * complex_dim,))

    def f(z):
      # Unpack z in C^n into [Re(z); Im(z)] in R^{2n} and evaluate the real
      # Rosenbrock function there.
      x_re = jnp.concatenate([jnp.real(z), jnp.imag(z)])
      return f_re(x_re)

    init = init_re[:complex_dim] + 1.j * init_re[complex_dim:]
    expect = expect_re[:complex_dim] + 1.j * expect_re[complex_dim:]

    @jit
    def min_op(z0):
      result = jax.scipy.optimize.minimize(
          f,
          z0,
          method='l-bfgs-experimental-do-not-rely-on-this',
          options=dict(gtol=1e-6),
      )
      return result.x

    jax_res = min_op(init)
    self.assertAllClose(jax_res, expect, atol=2e-5)
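# test_complex_rosenbrock above works because packing z = x + i*y in C^n as
# the real vector [Re(z); Im(z)] in R^{2n} leaves the objective unchanged, so
# the complex minimizer must map back onto the real one. A minimal sketch of
# that round trip (hypothetical helper names, not part of the test suite):
def _pack_complex(z):
  # C^n -> R^{2n}: real parts first, then imaginary parts.
  return jnp.concatenate([jnp.real(z), jnp.imag(z)])

def _unpack_complex(x):
  # R^{2n} -> C^n: inverse of _pack_complex.
  n = x.shape[0] // 2
  return x[:n] + 1j * x[n:]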