def testTranspose(self, shape, dtype, perm, bdims):
  rng = jtu.rand_default(self.rng())
  op = lambda x: lax.transpose(x, perm)
  self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
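# Illustrative sketch (comments only, not part of the test): the batching
# check above amounts to asking that vmap of lax.transpose act slice-by-slice
# over a leading batch axis, e.g. with perm=(1, 0) on (2, 3) inputs:
#   xs = rng((5, 2, 3), dtype)                        # batch of 5 matrices
#   ys = vmap(lambda x: lax.transpose(x, (1, 0)))(xs)
#   # ys.shape == (5, 3, 2) and ys[i] == xs[i].T for each i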
class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] @parameterized.named_parameters( {"testcase_name": jtu.format_test_name_suffix(rec.test_name, shapes, dtypes), "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name)} for rec in itertools.chain(JAX_ONE_TO_ONE_OP_RECORDS, JAX_COMPOUND_OP_RECORDS) for shapes in CombosWithReplacement(all_shapes, rec.nargs) for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)) def testOp(self, onp_op, lnp_op, rng, shapes, dtypes): args_maker = self._GetArgsMaker(rng, shapes, dtypes) self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), "rng": rec.rng, "shape": shape, "dtype": dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis, "keepdims": keepdims} for rec in JAX_REDUCER_RECORDS for shape in all_shapes for dtype in rec.dtypes for axis in range(-len(shape), len(shape)) for keepdims in [False, True]) def testReducer(self, onp_op, lnp_op, rng, shape, dtype, axis, keepdims): onp_fun = lambda x: onp_op(x, axis, keepdims=keepdims) lnp_fun = lambda x: lnp_op(x, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "{}_inshape={}_axis={}".format( rec.test_name.capitalize(), jtu.format_shape_dtype_string(shape, dtype), axis), "rng": rec.rng, "shape": shape, "dtype": dtype, "onp_op": getattr(onp, rec.name), "lnp_op": getattr(lnp, rec.name), "axis": axis} for rec in JAX_ARGMINMAX_RECORDS for shape in all_shapes for dtype in rec.dtypes for axis in range(-len(shape), len(shape))) def testArgMinMax(self, onp_op, lnp_op, rng, shape, dtype, axis): def onp_fun(array_to_reduce): return onp_op(array_to_reduce, axis) def lnp_fun(array_to_reduce): return lnp_op(array_to_reduce, axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_{}_{}_{}".format( name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": rng} for rng in [jtu.rand_default()] for name, lhs_shape, rhs_shape in [ ("matrix-scalar", (3, 3), ()), ("scalar-matrix", (), (3, 3)), ("matrix-vector", (4, 5), (5,)), ("vector-matrix", (6,), (6, 4)), ("matrix-matrix", (3, 4), (4, 5)), ("tensor-vector", (4, 3, 2), (2,)), ("vector-tensor", (2,), (3, 2, 4)), ("tensor-matrix", (4, 3, 2), (2, 5)), ("matrix-tensor", (5, 2), (3, 2, 4)), ("tensor-tensor", (2, 3, 4), (5, 4, 1))] for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)) def testDot(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CheckAgainstNumpy(onp.dot, lnp.dot, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.dot, 
args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_{}_{}_{}".format( name, jtu.format_shape_dtype_string(lhs_shape, lhs_dtype), jtu.format_shape_dtype_string(rhs_shape, rhs_dtype)), "lhs_shape": lhs_shape, "lhs_dtype": lhs_dtype, "rhs_shape": rhs_shape, "rhs_dtype": rhs_dtype, "rng": rng} for rng in [jtu.rand_default()] for name, lhs_shape, rhs_shape in [ ("vector-vector", (3,), (3,)), ("matrix-vector", (3, 3), (3,)), ("vector-matrix", (3,), (3, 3)), ("matrix-matrix", (3, 3), (3, 3)), ("vector-tensor", (3,), (5, 3, 2)), ("tensor-vector", (5, 3, 2), (2,)), ("matrix-tensor", (5, 2), (3, 2, 4)), ("tensor-matrix", (5, 2, 3), (3, 2)), ("tensor-tensor", (5, 3, 4), (5, 4, 1)), ("tensor-tensor-broadcast", (3, 1, 3, 4), (5, 4, 1))] for lhs_dtype, rhs_dtype in CombosWithReplacement(float_dtypes, 2)) def testMatmul(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, rng): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CheckAgainstNumpy(onp.matmul, lnp.matmul, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.matmul, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_{}_amin={}_amax={}".format( jtu.format_shape_dtype_string(shape, dtype), a_min, a_max), "shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max, "rng": jtu.rand_default()} for shape in all_shapes for dtype in float_dtypes for a_min, a_max in [(-1, None), (None, 1), (-1, 1)]) def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng): onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max) lnp_fun = lambda x: lnp.clip(x, a_min=a_min, a_max=a_max) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_{}_decimals={}".format( jtu.format_shape_dtype_string(shape, dtype), decimals), "shape": shape, "dtype": dtype, "decimals": decimals, "rng": jtu.rand_default()} for shape in all_shapes for dtype in float_dtypes for decimals in [0, 1, -2]) def testRoundStaticDecimals(self, shape, dtype, decimals, rng): onp_fun = lambda x: onp.round(x, decimals=decimals) lnp_fun = lambda x: lnp.round(x, decimals=decimals) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_axis={}_baseshape=[{}]_dtypes=[{}]".format( axis, ",".join(str(d) for d in base_shape), ",".join(onp.dtype(dtype).name for dtype in dtypes)), "axis": axis, "base_shape": base_shape, "dtypes": dtypes, "rng": jtu.rand_default()} for num_arrs in [3] for dtypes in CombosWithReplacement(default_dtypes, num_arrs) for base_shape in [(4,), (3, 4), (2, 3, 4)] for axis in range(-len(base_shape)+1, len(base_shape))) def testConcatenate(self, axis, base_shape, dtypes, rng): wrapped_axis = axis % len(base_shape) shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), dtypes)] onp_fun = lambda *args: onp.concatenate(args, axis=axis) lnp_fun = lambda *args: lnp.concatenate(args, axis=axis) def args_maker(): return [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_{}".format( 
jtu.format_test_name_suffix("", [shape] * len(dtypes), dtypes)), "shape": shape, "dtypes": dtypes, "rng": rng} for dtypes in [ [onp.float32], [onp.float32, onp.float32], [onp.float32, onp.int32, onp.float32], [onp.float32, onp.int64, onp.float32], [onp.float32, onp.int32, onp.float64], ] for shape in [(), (2,), (3, 4), (1, 100)] for rng in [jtu.rand_default()]) def testStack(self, shape, dtypes, rng): args_maker = lambda: [[rng(shape, dtype) for dtype in dtypes]] self._CheckAgainstNumpy(lnp.stack, onp.stack, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_inshape=[{}]_indtype={}_outdtype={}".format( "_".join(str(d) for d in shape), onp.dtype(fill_value_dtype).name, onp.dtype(out_dtype).name), "shape": shape, "fill_value_dtype": fill_value_dtype, "out_dtype": out_dtype, "rng": jtu.rand_default()} for shape in all_shapes for fill_value_dtype in default_dtypes for out_dtype in default_dtypes) def testFull(self, shape, fill_value_dtype, out_dtype, rng): onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype) lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype) args_maker = lambda: [rng((), fill_value_dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_{}_axis={}_{}sections".format( jtu.format_shape_dtype_string(shape, dtype), axis, num_sections), "shape": shape, "num_sections": num_sections, "axis": axis, "dtype": dtype, "rng": jtu.rand_default()} for shape, axis, num_sections in [ ((3,), 0, 3), ((12,), 0, 3), ((12, 4), 0, 4), ((12, 4), 1, 2), ((2, 3, 4), -1, 2), ((2, 3, 4), -2, 3)] for dtype in default_dtypes) def testSplitStaticInt(self, shape, num_sections, axis, dtype, rng): onp_fun = lambda x: onp.split(x, num_sections, axis=axis) lnp_fun = lambda x: lnp.split(x, num_sections, axis=axis) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_inshape={}_outshape={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), jtu.format_shape_dtype_string(out_shape, dtype)), "arg_shape": arg_shape, "out_shape": out_shape, "dtype": dtype, "rng": jtu.rand_default()} for dtype in default_dtypes for arg_shape, out_shape in [ ((3, 4), 12), ((3, 4), (12,)), ((3, 4), -1), ((2, 1, 4), (-1,)), ((2, 2, 4), (2, 8)) ]) def testReshape(self, arg_shape, out_shape, dtype, rng): onp_fun = lambda x: onp.reshape(x, out_shape) lnp_fun = lambda x: lnp.reshape(x, out_shape) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_inshape={}_expanddim={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), dim), "arg_shape": arg_shape, "dtype": dtype, "dim": dim, "rng": jtu.rand_default()} for arg_shape in [(), (3,), (3, 4)] for dtype in default_dtypes for dim in range(-len(arg_shape)+1, len(arg_shape))) def testExpandDimsStaticDim(self, arg_shape, dtype, dim, rng): onp_fun = lambda x: onp.expand_dims(x, dim) lnp_fun = lambda x: lnp.expand_dims(x, dim) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) 
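# Note on testReshape above: the out_shape cases deliberately mix an int (12),
# a tuple ((12,)), and -1, since onp.reshape accepts all three and lnp.reshape
# is expected to as well; -1 infers the remaining dimension, e.g. reshaping a
# (2, 1, 4) array with (-1,) yields shape (8,).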
@parameterized.named_parameters( {"testcase_name": "_inshape={}_axes=({},{})".format( jtu.format_shape_dtype_string(arg_shape, dtype), ax1, ax2), "arg_shape": arg_shape, "dtype": dtype, "ax1": ax1, "ax2": ax2, "rng": jtu.rand_default()} for arg_shape, ax1, ax2 in [ ((3, 4), 0, 1), ((3, 4), 1, 0), ((3, 4, 5), 1, 2), ((3, 4, 5), -1, -2), ((3, 4, 5), 0, 1)] for dtype in default_dtypes) def testSwapAxesStaticAxes(self, arg_shape, dtype, ax1, ax2, rng): onp_fun = lambda x: onp.swapaxes(x, ax1, ax2) lnp_fun = lambda x: lnp.swapaxes(x, ax1, ax2) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_inshape={}_axis={}".format( jtu.format_shape_dtype_string(arg_shape, dtype), ax), "arg_shape": arg_shape, "dtype": dtype, "ax": ax, "rng": jtu.rand_default()} for arg_shape, ax in [ ((3, 1), None), ((3, 1), 1), ((1, 3, 1), (0, 2)), ((1, 4, 1), (0,))] for dtype in default_dtypes) def testSqueeze(self, arg_shape, dtype, ax, rng): onp_fun = lambda x: onp.squeeze(x, ax) lnp_fun = lambda x: lnp.squeeze(x, ax) args_maker = lambda: [rng(arg_shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( {"testcase_name": "_arg{}".format(i), "arg": arg} for i, arg in enumerate([ [1, 2, 3], [1., 2., 3.], [[1, 2], [3, 4], [5, 6]], [[1, 2.], [3, 4], [5, 6]], [[3, onp.array(2), 1], onp.arange(3.)], ])) def testArray(self, arg): args_maker = lambda: [arg] self._CheckAgainstNumpy(onp.array, lnp.array, args_maker, check_dtypes=True) self._CompileAndCheck(lnp.array, args_maker, check_dtypes=True) def testAllClose(self): rng = onp.random.RandomState(0) x = rng.randn(2, 2) y = rng.randn(2) def same(list1, list2): allclose = functools.partial(lnp.allclose, atol=1e-3, rtol=1e-3) elements_close = list(map(allclose, list1, list2)) return lnp.all(lnp.array(elements_close)) csame = api.jit(same) a1 = same((x, y), (x, y)) a2 = csame((x, y), (x, y)) a3 = csame((x, y), (x, 2 * y)) self.assertTrue(a1) self.assertTrue(a2) self.assertFalse(a3) @jtu.skip_on_devices("tpu") # TODO(mattjj): investigate this failure def DISABLED_testOnesBroadcastingConstantHandler(self): # TODO(mattjj): update this test for jax3 def fun(x): ones = lnp.ones((3, 4)) assert isinstance(ones, onp.ndarray) and ones.strides == (0, 0) # To check that the constant handler generates a Broadcast for stride-zero # arrays, we monkey-patch the client instance. # TODO(mattjj): once we have better HLO dumping and inspecting facilities, # we can check the HLO more directly. c = x._node.c Broadcast = c.Broadcast # pylint: disable=invalid-name was_called = [] c.Broadcast = lambda *args: was_called.append(True) or Broadcast(*args) out = x + ones # the ndarray constant handler should call Broadcast here assert was_called, "Broadcast was not called." return out fun = api.jit(fun) out_val = fun(lnp.ones(4)) self.assertAllClose(out_val, onp.full((3, 4), 2.), check_dtypes=False) def testZeroStridesConstantHandler(self): raw_const = onp.random.RandomState(0).randn(1, 2, 1, 1, 5, 1) const = onp.broadcast_to(raw_const, (3, 2, 3, 4, 5, 6)) def fun(x): return x * const fun = api.jit(fun) out_val = fun(3.) self.assertAllClose(out_val, 3. 
* const, check_dtypes=False) def testIsInstanceNdarrayDuringTracing(self): arr = onp.ones(3) @api.jit def f(x): self.assertIsInstance(x, lnp.ndarray) return lnp.sum(x) f(arr) def testNonArrayErrorMessage(self): x = [1., 2.] y = onp.array([3., 4.]) def g(x, y): return lnp.add(x, y) def f(x, y): return lnp.dot(x, y) self.assertRaises(TypeError, lambda: g(x, y)) self.assertRaises(TypeError, lambda: f(x, y)) self.assertRaises(TypeError, lambda: api.jit(g)(x, y)) self.assertRaises(TypeError, lambda: api.jit(f)(x, y)) def testAbstractionErrorMessage(self): @api.jit def f(x, n): for _ in range(n): x = x * x return x self.assertRaises(TypeError, lambda: f(3., 3)) @api.jit def g(x): if x > 0.: return x * 2 else: return x + 2 self.assertRaises(TypeError, lambda: g(3.)) def DISABLED_testTracingPrimitiveWithNoTranslationErrorMessage(self): # TODO(mattjj): update this for jax3 foo = lnp._not_implemented(lambda x: x) # No error if there's no tracing. foo(onp.arange(3)) cfoo = api.jit(foo) self.assertRaises(NotImplementedError, lambda: cfoo(onp.arange(3))) # TODO(mattjj): test infix operator overrides def DISABLED_testRavel(self): # TODO(mattjj): support this method-based syntax? rng = onp.random.RandomState(0) args_maker = lambda: [rng.randn(3, 4).astype("float32")] self._CompileAndCheck(lambda x: x.ravel(), args_maker, check_dtypes=True)
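# Sketch of the failure mode testAbstractionErrorMessage checks (the fix shown
# here is the usual workaround, not something the test exercises): under
# api.jit the loop bound n is a traced value, so range(n) raises a TypeError.
# Marking the argument static restores concrete Python control flow:
#   @functools.partial(api.jit, static_argnums=(1,))
#   def f_ok(x, n):
#     for _ in range(n):
#       x = x * x
#     return x
#   f_ok(3., 3)   # n is a compile-time constant, so the Python loop unrolls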
class NumpyLinalgTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)] for dtype in float_types() for rng in [jtu.rand_default()])) def testCholesky(self, shape, dtype, rng): def args_maker(): factor_shape = shape[:-1] + (2 * shape[-1],) a = rng(factor_shape, dtype) return [onp.matmul(a, np.conj(T(a)))] self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True) if onp.finfo(dtype).bits == 64: jtu.check_grads(np.linalg.cholesky, args_maker(), order=2) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)), "n": n, "dtype": dtype, "rng": rng} for n in [0, 4, 5, 25] # TODO(mattjj): complex64 unstable on large sizes? for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testDet(self, n, dtype, rng): args_maker = lambda: [rng((n, n), dtype)] self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)), "n": n, "dtype": dtype, "rng": rng} for n in [0, 4, 10, 200] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testSlogdet(self, n, dtype, rng): args_maker = lambda: [rng((n, n), dtype)] self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format( jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testEig(self, shape, dtype, rng): self.skipTest("Test disabled until Jaxlib 0.1.15 is released") # TODO(phawkins) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] # Norm, adjusted for dimension and type. 
def norm(x): norm = onp.linalg.norm(x, axis=(-2, -1)) return norm / ((n + 1) * onp.finfo(dtype).eps) a, = args_maker() w, v = np.linalg.eig(a) self.assertTrue(onp.all(norm(onp.matmul(a, v) - w[..., None, :] * v) < 100)) self._CompileAndCheck(partial(np.linalg.eig), args_maker, check_dtypes=True, rtol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 4), (5, 5)] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testEigBatching(self, shape, dtype, rng): self.skipTest("Test disabled until Jaxlib 0.1.15 is released") # TODO(phawkins) shape = (10,) + shape args = rng(shape, dtype) ws, vs = vmap(np.linalg.eig)(args) self.assertTrue(onp.all(onp.linalg.norm( onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}_lower={}".format( jtu.format_shape_dtype_string((n,n), dtype), lower), "n": n, "dtype": dtype, "lower": lower, "rng": rng} for n in [0, 4, 5, 50] for dtype in float_types() + complex_types() for lower in [False, True] for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testEigh(self, n, dtype, lower, rng): args_maker = lambda: [rng((n, n), dtype)] uplo = "L" if lower else "U" # Norm, adjusted for dimension and type. def norm(x): norm = onp.linalg.norm(x, axis=(-2, -1)) return norm / ((n + 1) * onp.finfo(dtype).eps) a, = args_maker() a = (a + onp.conj(a.T)) / 2 w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo, symmetrize_input=False) self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5) self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30) self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker, check_dtypes=True, rtol=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype), lower), "shape": shape, "dtype": dtype, "rng": rng, "lower":lower} for shape in [(1, 1), (4, 4), (5, 5), (50, 50)] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()] for lower in [True, False])) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testEighGrad(self, shape, dtype, rng, lower): self.skipTest("Test fails with numeric errors.") uplo = "L" if lower else "U" a = rng(shape, dtype) a = (a + onp.conj(a.T)) / 2 a = onp.tril(a) if lower else onp.triu(a) # Gradient checks will fail without symmetrization as the eigh jvp rule # is only correct for tangents in the symmetric subspace, whereas the # checker checks against unconstrained (co)tangents. 
if dtype not in complex_types(): f = partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True) else: # only check eigenvalue grads for complex matrices f = lambda a: partial(np.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0] jtu.check_grads(f, (a,), 2, rtol=1e-1) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype), lower), "shape": shape, "dtype": dtype, "rng": rng, "lower":lower, "eps":eps} for shape in [(1, 1), (4, 4), (5, 5), (50, 50)] for dtype in complex_types() for rng in [jtu.rand_default()] for lower in [True, False] for eps in [1e-4])) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps): # Special case to test for complex eigenvector grad correctness. # Exact eigenvector coordinate gradients are hard to test numerically for complex # eigensystem solvers given the extra degrees of per-eigenvector phase freedom. # Instead, we numerically verify the eigensystem properties on the perturbed # eigenvectors. You only ever want to optimize eigenvector directions, not coordinates! uplo = "L" if lower else "U" a = rng(shape, dtype) a = (a + onp.conj(a.T)) / 2 a = onp.tril(a) if lower else onp.triu(a) a_dot = eps * rng(shape, dtype) a_dot = (a_dot + onp.conj(a_dot.T)) / 2 a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot) # evaluate eigenvector gradient and groundtruth eigensystem for perturbed input matrix f = partial(np.linalg.eigh, UPLO=uplo) (w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,)) new_a = a + a_dot new_w, new_v = f(new_a) new_a = (new_a + onp.conj(new_a.T)) / 2 # Assert rtol eigenvalue delta between perturbed eigenvectors vs new true eigenvalues. RTOL=1e-2 assert onp.max( onp.abs((onp.diag(onp.dot(onp.conj((v+dv).T), onp.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL # Redundant to above, but also assert rtol for eigenvector property with new true eigenvalues. 
assert onp.max( onp.linalg.norm(onp.abs(new_w*(v+dv) - onp.dot(new_a, (v+dv))), axis=0) / onp.linalg.norm(onp.abs(new_w*(v+dv)), axis=0) ) < RTOL @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 4), (5, 5)] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testEighBatching(self, shape, dtype, rng): self.skipTest("Test disabled until Jaxlib 0.1.15 is released") # TODO(phawkins) shape = (10,) + shape args = rng(shape, dtype) args = (args + onp.conj(T(args))) / 2 ws, vs = vmap(jsp.linalg.eigh)(args) self.assertTrue(onp.all(onp.linalg.norm( onp.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format( jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims), "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims, "ord": ord, "rng": rng} for axis, shape in [ (None, (1,)), (None, (7,)), (None, (5, 8)), (0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)), (-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9))] for keepdims in [False, True] for ord in ( [None, 0, 1, 2, 3, -1, -2, -3, np.inf, -np.inf] if (axis is None and len(shape) == 1) or isinstance(axis, int) or (isinstance(axis, tuple) and len(axis) == 1) else [None, 'fro', 1, 2, -1, -2, np.inf, -np.inf, 'nuc']) for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) def testNorm(self, shape, dtype, ord, axis, keepdims, rng): # TODO(mattjj,phawkins): re-enable after checking internal tests self.skipTest("internal test failures") if (ord in ('nuc', 2, -2) and isinstance(axis, tuple) and len(axis) == 2 and (not FLAGS.jax_test_dut or not FLAGS.jax_test_dut.startswith("cpu") or len(shape) != 2)): raise SkipTest("No adequate SVD implementation available") args_maker = lambda: [rng(shape, dtype)] onp_fn = partial(onp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims) np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims) self._CheckAgainstNumpy(onp_fn, np_fn, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np_fn, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format( jtu.format_shape_dtype_string((m, n), dtype), full_matrices, compute_uv), "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices, "compute_uv": compute_uv, "rng": rng} for m in [2, 7, 29, 53] for n in [2, 7, 29, 53] for dtype in float_types() + complex_types() for full_matrices in [False, True] for compute_uv in [False, True] for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng): args_maker = lambda: [rng((m, n), dtype)] # Norm, adjusted for dimension and type. 
def norm(x): norm = onp.linalg.norm(x, axis=(-2, -1)) return norm / (max(m, n) * onp.finfo(dtype).eps) a, = args_maker() out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) if compute_uv: # Check the reconstructed matrices if full_matrices: k = min(m, n) if m < n: self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0], out[2][:k, :])) < 50)) else: self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0][:, :k], out[2])) < 50)) else: self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0], out[2])) < 50)) # Check the unitary properties of the singular vector matrices. self.assertTrue(onp.all(norm(onp.eye(out[0].shape[1]) - onp.matmul(onp.conj(T(out[0])), out[0])) < 10)) if m >= n: self.assertTrue(onp.all(norm(onp.eye(out[2].shape[1]) - onp.matmul(onp.conj(T(out[2])), out[2])) < 10)) else: self.assertTrue(onp.all(norm(onp.eye(out[2].shape[0]) - onp.matmul(out[2], onp.conj(T(out[2])))) < 20)) else: self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4)) self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv), args_maker, check_dtypes=True) if not full_matrices: svd = partial(np.linalg.svd, full_matrices=False) jtu.check_jvp(svd, partial(jvp, svd), (a,), atol=1e-1 if FLAGS.jax_enable_x64 else jtu.ATOL) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_fullmatrices={}".format( jtu.format_shape_dtype_string(shape, dtype), full_matrices), "shape": shape, "dtype": dtype, "full_matrices": full_matrices, "rng": rng} for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)] for dtype in float_types() for full_matrices in [False, True] for rng in [jtu.rand_default()])) @jtu.skip_on_devices("cpu") def testQr(self, shape, dtype, full_matrices, rng): m, n = shape[-2:] if full_matrices: mode, k = "complete", m else: mode, k = "reduced", min(m, n) a = rng(shape, dtype) lq, lr = np.linalg.qr(a, mode=mode) # onp.linalg.qr doesn't support broadcasting. But it seems like an # inevitable extension so we support it in our version. nq = onp.zeros(shape[:-2] + (m, k), dtype) nr = onp.zeros(shape[:-2] + (k, n), dtype) for index in onp.ndindex(*shape[:-2]): nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode) max_rank = max(m, n) # Norm, adjusted for dimension and type. def norm(x): n = onp.linalg.norm(x, axis=(-2, -1)) return n / (max_rank * onp.finfo(dtype).eps) def compare_orthogonal(q1, q2): # Q is unique up to sign, so normalize the sign first. sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True) phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios)) q1 *= phases self.assertTrue(onp.all(norm(q1 - q2) < 30)) # Check a ~= qr self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30)) # Compare the first 'k' vectors of Q; the remainder form an arbitrary # orthonormal basis for the null space. compare_orthogonal(nq[..., :k], lq[..., :k]) # Check that q is close to unitary. 
self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5)) if not full_matrices and m >= n: jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,)) @jtu.skip_on_devices("gpu", "tpu") def testQrBatching(self): shape = (10, 4, 5) dtype = np.float32 rng = jtu.rand_default() args = rng(shape, np.float32) qs, rs = vmap(jsp.linalg.qr)(args) self.assertTrue(onp.all(onp.linalg.norm(args - onp.matmul(qs, rs)) < 1e-3)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs={}_rhs={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype)), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng} for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4,)), ((8, 8), (8, 4)), ] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testSolve(self, lhs_shape, rhs_shape, dtype, rng): args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)] for dtype in float_types() for rng in [jtu.rand_default()])) def testInv(self, shape, dtype, rng): def args_maker(): invertible = False while not invertible: a = rng(shape, dtype) try: onp.linalg.inv(a) invertible = True except onp.linalg.LinAlgError: pass return [a] self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True) # Regression test for incorrect type for eigenvalues of a complex matrix. @jtu.skip_on_devices("gpu", "tpu") def testIssue669(self): def test(x): val, vec = np.linalg.eigh(x) return np.real(np.sum(val)) grad_test_jc = jit(grad(jit(test))) xc = onp.eye(3, dtype=onp.complex) self.assertAllClose(xc, grad_test_jc(xc), check_dtypes=True)
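# Reference for the Cholesky cases above (mirrors the args_maker): a valid
# input is manufactured as A @ A^H, Hermitian positive semi-definite by
# construction, so onp.linalg.cholesky and np.linalg.cholesky can be compared:
#   a = rng(shape[:-1] + (2 * shape[-1],), dtype)
#   psd = onp.matmul(a, onp.conj(T(a)))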
def testArgminmax(self, op, shape, dtype, dim, bdims):
  rng = jtu.rand_default(self.rng())
  fun = lambda operand: op(operand, dim, np.int32)
  self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
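# Sketch (comments only): the check above is equivalent to comparing vmap of
# the reducer against a per-slice loop, e.g. for lax.argmax over axis 0:
#   xs = rng((5, 3, 4), dtype)
#   batched = vmap(lambda x: lax.argmax(x, 0, np.int32))(xs)
#   # batched[i] == lax.argmax(xs[i], 0, np.int32) for each i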
def testScatterAdd(self, arg_shape, dtype, idxs, update_shape, dnums, bdims):
  fun = partial(lax.scatter_add, dimension_numbers=dnums)
  self._CheckBatching(fun, 5, bdims, [arg_shape, idxs.shape, update_shape],
                      [dtype, idxs.dtype, dtype], jtu.rand_default(self.rng()),
                      rtol={np.float16: 5e-3})
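# Note: dnums is a lax.ScatterDimensionNumbers describing how update windows
# and scatter indices map onto the operand; the test threads it through
# unchanged and only checks that vmap of scatter_add agrees with the unbatched
# op slice-by-slice (with a looser tolerance for float16 accumulation).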
def test_sum_2d(self):
  self.check(jnp.sum, ['(m, n)'], '', dict(m=2, n=3), [(3, 4)], ['float_'],
             jtu.rand_default(self.rng()))
def test_where(self):
  # Requires mask(jit)
  raise SkipTest
  self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n', {'n': 2},
             [(3,)], ['float_'], jtu.rand_default(self.rng()))
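# Reading of the self.check signature used by these masking tests (an
# interpretation for the reader, not a spec): the first list gives symbolic
# input shape specs (e.g. ['(m, n)'] or ['n']), the following string the
# symbolic output shape, the dict the logical sizes bound to those symbols,
# then the padded physical shapes, the dtypes, and the rng filling the pads.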
def testClassUnaryOp(self, dtype, shape, op):
  rng = jtu.rand_default(self.rng())
  args = (rng(shape, dtype),)
  class_op = lambda x: op(_DoubleDouble(x)).to_array()
  self.assertAllClose(op(*args), class_op(*args))
def testBinaryOp(self, dtype, shape, op):
  rng = jtu.rand_default(self.rng())
  op_doubled = doubledouble(op)
  args = rng(shape, dtype), rng(shape, dtype)
  self.assertAllClose(op(*args), op_doubled(*args))
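# The double-double tests above compare plain float ops against their
# compensated counterparts. The helper below is an illustrative sketch of the
# classic two-sum building block such arithmetic rests on; it is not the
# library's implementation and is not used by the tests.
def _two_sum_reference(a, b):
  """Error-free transformation: returns (s, err) with s = fl(a + b) and err
  the rounding error, so s + err reproduces a + b exactly."""
  s = a + b
  bb = s - a
  err = (a - (s - bb)) + (b - bb)
  return s, err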
def testEighGradPrecision(self):
  rng = jtu.rand_default()
  a = rng((3, 3), onp.float32)
  jtu.assert_dot_precision(
      lax.Precision.HIGHEST, partial(jvp, np.linalg.eigh), (a,), (a,))
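# Context for the assertion above (hedged note): jax.numpy dot-like calls can
# request an explicit precision (e.g. precision=lax.Precision.HIGHEST), and
# the test asserts that the eigh JVP rule uses HIGHEST for its internal
# contractions rather than the backend default.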
class ScipyLinalgTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(1, 1), (4, 5), (10, 5), (50, 50)] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testLu(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng(shape, dtype)] x, = args_maker() p, l, u = jsp.linalg.lu(x) self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True) self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True) def testLuOfSingularMatrix(self): x = np.array([[-1., 3. / 2], [2. / 3, -1.]], dtype=onp.float32) p, l, u = jsp.linalg.lu(x) self.assertAllClose(x, onp.matmul(p, onp.matmul(l, u)), check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(1, 1), (4, 5), (10, 5), (10, 10), (6, 7, 7)] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) @jtu.skip_on_devices("tpu") # TODO(phawkins): precision problems on TPU. def testLuGrad(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) a = rng(shape, dtype) lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu jtu.check_grads(lu, (a, ), 2, atol=5e-2, rtol=1e-1) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng } for shape in [(4, 5), (6, 5)] for dtype in [np.float32] for rng in [jtu.rand_default()])) def testLuBatching(self, shape, dtype, rng): _skip_if_unsupported_type(dtype) args = [rng(shape, np.float32) for _ in range(10)] expected = list(osp.linalg.lu(x) for x in args) ps = onp.stack([out[0] for out in expected]) ls = onp.stack([out[1] for out in expected]) us = onp.stack([out[2] for out in expected]) actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args)) self.assertAllClose(ps, actual_ps, check_dtypes=True) self.assertAllClose(ls, actual_ls, check_dtypes=True) self.assertAllClose(us, actual_us, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n, n), dtype)), "n": n, "dtype": dtype, "rng": rng } for n in [1, 4, 5, 200] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testLuFactor(self, n, dtype, rng): _skip_if_unsupported_type(dtype) args_maker = lambda: [rng((n, n), dtype)] x, = args_maker() lu, piv = jsp.linalg.lu_factor(x) l = onp.tril(lu, -1) + onp.eye(n, dtype=dtype) u = onp.triu(lu) for i in range(n): x[[i, piv[i]], ] = x[[piv[i], i], ] self.assertAllClose(x, onp.matmul(l, u), check_dtypes=True, rtol=1e-3) self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}_trans={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), trans), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "trans": trans, "rng": rng } for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4, )), ((8, 8), (8, 4, 2)), ] for trans in [0, 1, 2] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testLuSolve(self, lhs_shape, rhs_shape, dtype, trans, rng): _skip_if_unsupported_type(dtype) 
osp_fun = lambda lu, piv, rhs: osp.linalg.lu_solve( (lu, piv), rhs, trans=trans) jsp_fun = lambda lu, piv, rhs: jsp.linalg.lu_solve( (lu, piv), rhs, trans=trans) def args_maker(): a = rng(lhs_shape, dtype) lu, piv = osp.linalg.lu_factor(a) return [lu, piv, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}_sym_pos={}_lower={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), sym_pos, lower), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "sym_pos": sym_pos, "lower": lower, "rng": rng } for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4, )), ((8, 8), (8, 4)), ] for sym_pos, lower in [ (False, False), (True, False), (True, True), ] for dtype in float_types + complex_types for rng in [jtu.rand_default()])) def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng): _skip_if_unsupported_type(dtype) if (sym_pos and onp.issubdtype(dtype, onp.complexfloating) and jtu.device_under_test() == "tpu"): raise unittest.SkipTest( "Complex Cholesky decomposition not implemented on TPU") osp_fun = lambda lhs, rhs: osp.linalg.solve( lhs, rhs, sym_pos=sym_pos, lower=lower) jsp_fun = lambda lhs, rhs: jsp.linalg.solve( lhs, rhs, sym_pos=sym_pos, lower=lower) def args_maker(): a = rng(lhs_shape, dtype) if sym_pos: a = onp.matmul(a, onp.conj(T(a))) a = onp.tril(a) if lower else onp.triu(a) return [a, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lower, transpose_a, unit_diagonal), "lower": lower, "transpose_a": transpose_a, "unit_diagonal": unit_diagonal, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng } for lower in [False, True] for transpose_a in [False, True] for unit_diagonal in [False, True] for lhs_shape, rhs_shape in [ ((4, 4), (4, )), ((4, 4), (4, 3)), ((2, 8, 8), (2, 8, 10)), ] for dtype in float_types for rng in [jtu.rand_default()])) def testSolveTriangular(self, lower, transpose_a, unit_diagonal, lhs_shape, rhs_shape, dtype, rng): _skip_if_unsupported_type(dtype) k = rng(lhs_shape, dtype) l = onp.linalg.cholesky( onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1])) l = l.astype(k.dtype) b = rng(rhs_shape, dtype) if unit_diagonal: a = onp.tril(l, -1) + onp.eye(lhs_shape[-1], dtype=dtype) else: a = l a = a if lower else T(a) inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype) if len(lhs_shape) == len(rhs_shape): onp_ans = onp.matmul(inv, b) else: onp_ans = onp.einsum("...ij,...j->...i", inv, b) # The standard scipy.linalg.solve_triangular doesn't support broadcasting. # But it seems like an inevitable extension so we support it. ans = jsp.linalg.solve_triangular(l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower, unit_diagonal=unit_diagonal) self.assertAllClose(onp_ans, ans, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}_unit_diagonal={}". 
format(jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lower, transpose_a, unit_diagonal), "lower": lower, "transpose_a": transpose_a, "unit_diagonal": unit_diagonal, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng } for lower in [False, True] for unit_diagonal in [False, True] for dtype in float_types + complex_types for transpose_a in ( [0, 1] if onp.issubdtype(dtype, np.floating) else [0, 1, 2]) for lhs_shape, rhs_shape in [ ((4, 4), (4, )), ((4, 4), (4, 3)), ((2, 8, 8), (2, 8, 10)), ] for rng in [jtu.rand_default()])) @jtu.skip_on_devices("tpu") # TODO(phawkins): Test fails on TPU. def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal, lhs_shape, rhs_shape, dtype, rng): _skip_if_unsupported_type(dtype) A = np.tril( rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype)) A = A if lower else T(A) B = rng(rhs_shape, dtype) f = partial(jsp.linalg.solve_triangular, lower=lower, trans=transpose_a, unit_diagonal=unit_diagonal) jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
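# scipy round trip the LU tests above mirror (illustration only):
#   lu, piv = osp.linalg.lu_factor(a)        # packed L/U factors plus pivots
#   x = osp.linalg.lu_solve((lu, piv), b)    # solves a @ x = b
# jsp.linalg.lu_factor / lu_solve are checked against exactly this pairing.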
def testIfftshift(self, shape, dtype, axes):
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: (rng(shape, dtype),)
  jnp_fn = lambda arg: jnp.fft.ifftshift(arg, axes=axes)
  np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
  self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)
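# Property behind testIfftshift (for reference): ifftshift undoes fftshift, so
# np.fft.ifftshift(np.fft.fftshift(x, axes=axes), axes=axes) recovers x for
# any choice of axes.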
class IndexedUpdateTest(jtu.JaxTestCase): @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for op in UpdateOps for dtype in ( all_dtypes if op == UpdateOps.UPDATE else default_dtypes) for update_shape in _broadcastable_shapes( _update_shape(shape, indexer)) for update_dtype in ( [dtype] if op == UpdateOps.ADD else all_dtypes) for rng in [jtu.rand_default()])) def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, rng, indexer, op): args_maker = lambda: [ rng(shape, dtype), rng(update_shape, update_dtype) ] onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in ADVANCED_INDEXING_TESTS_NO_REPEATS for shape, indexer in index_specs for op in UpdateOps for dtype in ( all_dtypes if op == UpdateOps.UPDATE else default_dtypes) for update_shape in _broadcastable_shapes( _update_shape(shape, indexer)) for update_dtype in ( [dtype] if op == UpdateOps.ADD else all_dtypes) for rng in [jtu.rand_default()])) def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, rng, indexer, op): args_maker = lambda: [ rng(shape, dtype), rng(update_shape, update_dtype) ] onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) @parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS for shape, indexer in index_specs for op in UpdateOps for dtype in ( all_dtypes if op == UpdateOps.UPDATE else default_dtypes) for update_shape in _broadcastable_shapes( _update_shape(shape, indexer)) for update_dtype in ( [dtype] if op == UpdateOps.ADD else all_dtypes) for rng in [jtu.rand_default()])) def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, rng, indexer, op): args_maker = lambda: [ rng(shape, dtype), rng(update_shape, update_dtype) ] onp_fn = lambda x, y: UpdateOps.onp_fn(op, indexer, x, y) jax_fn = lambda x, y: UpdateOps.jax_fn(op, indexer, x, y) self._CheckAgainstNumpy(onp_fn, jax_fn, args_maker, check_dtypes=True) self._CompileAndCheck(jax_fn, args_maker, check_dtypes=True) 
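# Semantics compared against NumPy in the three tests above (sketch): the jax
# ops are out-of-place analogues of in-place NumPy indexing, roughly
#   ops.index_update(x, idx, y)  ~  z = x.copy(); z[idx] = y;  return z
#   ops.index_add(x, idx, y)     ~  z = x.copy(); z[idx] += y; return z
# except that index_add accumulates every contribution for repeated indices,
# which is exactly what the segment-sum tests below rely on.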
@parameterized.named_parameters( jtu.cases_from_list( { "testcase_name": "{}_inshape={}_indexer={}_update={}_op={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer, jtu.format_shape_dtype_string(update_shape, update_dtype), op.name), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer, "update_shape": update_shape, "update_dtype": update_dtype, "op": op } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for op in UpdateOps for dtype in float_dtypes for update_shape in _broadcastable_shapes( _update_shape(shape, indexer)) for update_dtype in ( [dtype] if op == UpdateOps.ADD else float_dtypes) for rng in [jtu.rand_default()])) def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, rng, indexer, op): jax_op = ops.index_update if op == UpdateOps.UPDATE else ops.index_add jax_fn = lambda x, y: jax_op(x, indexer, y) x = rng(shape, dtype) y = rng(update_shape, update_dtype) check_grads(jax_fn, (x, y), 2, rtol=1e-3, atol=1e-3, eps=1.) def testSegmentSumBehavior(self): # testAdvancedIndexing compares against NumPy, and as a result doesn't check # repeated indices. This test is just a simple manual check, based on # https://www.tensorflow.org/api_docs/python/tf/math/segment_sum data = onp.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3]) ans = ops.index_add(onp.zeros(onp.max(segment_ids) + 1), segment_ids, data) expected = onp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) def testSegmentSum(self): data = onp.array([5, 1, 7, 2, 3, 4, 1, 3]) segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3]) # test with explicit num_segments ans = ops.segment_sum(data, segment_ids, num_segments=4) expected = onp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False) # test without explicit num_segments ans = ops.segment_sum(data, segment_ids) expected = onp.array([13, 2, 7, 4]) self.assertAllClose(ans, expected, check_dtypes=False)
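# Plain-NumPy reference for the segment_sum checks above:
#   data        = onp.array([5, 1, 7, 2, 3, 4, 1, 3])
#   segment_ids = onp.array([0, 0, 0, 1, 2, 2, 3, 3])
#   onp.array([data[segment_ids == i].sum() for i in range(4)])  # [13, 2, 7, 4]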
class IndexingTest(jtu.JaxTestCase): """Tests for Numpy indexing translation rules.""" @parameterized.named_parameters( jtu.cases_from_list({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in STATIC_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng in [jtu.rand_default()])) def testStaticIndexing(self, shape, dtype, rng, indexer): args_maker = lambda: [rng(shape, dtype)] fun = lambda x: x[indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in STATIC_INDEXING_GRAD_TESTS for shape, indexer in index_specs for dtype in float_dtypes for rng in [jtu.rand_default()]) def testStaticIndexingGrads(self, shape, dtype, rng, indexer): tol = 1e-2 if onp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: x[indexer]**2 check_grads(fun, (arg, ), 2, tol, tol, tol) def _ReplaceSlicesWithTuples(self, idx): """Helper method to replace slices with tuples for dynamic indexing args.""" if isinstance(idx, slice): triple = idx.start, idx.stop, idx.step isnone = [i for i, elt in enumerate(triple) if elt is None] zeros = itertools.repeat(0) nones = itertools.repeat(None) out = lax.subvals(triple, zip(isnone, zeros)) return out, lambda out: slice(*lax.subvals(out, zip(isnone, nones)) ) elif isinstance(idx, (tuple, list)) and idx: t = type(idx) elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx)) return elts, lambda elts: t( (pack(i) for pack, i in zip(packs, elts))) else: return idx, lambda x: x @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in [ ("OneSliceIndex", [ IndexSpec(shape=(5, ), indexer=slice(1, 3)), IndexSpec(shape=(5, 4), indexer=slice(1, 3)) ]), ("TwoSliceIndices", [ IndexSpec(shape=(5, 4), indexer=(slice(1, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, 2))) ]), ("NonUnitStrides", [ IndexSpec(shape=(3, ), indexer=slice(None, None, -1)), IndexSpec(shape=(3, 3), indexer=slice(0, 3, -2)), IndexSpec(shape=(3, 4, 5), indexer=slice(0, 4, 2)) ]), ("OnlyStartOrStopDynamic", [ IndexSpec(shape=(5, 4), indexer=(slice(None, 3), slice(0, 2))), IndexSpec(shape=(5, 4, 3), indexer=(slice(1, 3), slice(0, None))) ]), ] for shape, indexer in index_specs for dtype in all_dtypes for rng in [jtu.rand_default()]) def testDynamicIndexingWithSlicesErrors(self, shape, dtype, rng, indexer): unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @api.jit def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self.assertRaises(IndexError, lambda: fun(*args_maker())) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in [ ("OneIntIndex", [ IndexSpec(shape=(3, ), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3, ), 
indexer=-1), IndexSpec(shape=(3, ), indexer=-2) ]), ("TwoIntIndices", [ IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)) ]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer in index_specs for dtype in all_dtypes for rng in [jtu.rand_default()]) def testDynamicIndexingWithIntegers(self, shape, dtype, rng, indexer): unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) def fun(x, unpacked_indexer): indexer = pack_indexer(unpacked_indexer) return x[indexer] args_maker = lambda: [rng(shape, dtype), unpacked_indexer] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in [ ("OneIntIndex", [ IndexSpec(shape=(3, ), indexer=1), IndexSpec(shape=(3, 3), indexer=0), IndexSpec(shape=(3, 4, 5), indexer=2), IndexSpec(shape=(3, ), indexer=-1), IndexSpec(shape=(3, ), indexer=-2), ]), ("TwoIntIndices", [ IndexSpec(shape=(3, 3), indexer=(2, 1)), IndexSpec(shape=(3, 4, 5), indexer=(1, 2)), IndexSpec(shape=(3, 4, 5), indexer=(-1, 2)), ]), ("ThreeIntIndices", [IndexSpec((3, 4, 5), indexer=(1, 2, 3))]), ] for shape, indexer in index_specs for dtype in float_dtypes for rng in [jtu.rand_default()]) def testDynamicIndexingWithIntegersGrads(self, shape, dtype, rng, indexer): tol = 1e-2 if onp.finfo(dtype).bits == 32 else None unpacked_indexer, pack_indexer = self._ReplaceSlicesWithTuples(indexer) @api.jit def fun(unpacked_indexer, x): indexer = pack_indexer(unpacked_indexer) return x[indexer] arr = rng(shape, dtype) check_grads(partial(fun, unpacked_indexer), (arr, ), 2, tol, tol, tol) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in ADVANCED_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng in [jtu.rand_default()]) def testAdvancedIntegerIndexing(self, shape, dtype, rng, indexer): args_maker = lambda: [rng(shape, dtype), indexer] fun = lambda x, idx: x[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in [ ("One1DIntArrayIndex", [ IndexSpec(shape=(3, ), indexer=onp.array([0, 1])), IndexSpec(shape=(3, 3), indexer=onp.array([1, 2, 1])), IndexSpec(shape=(3, 4, 5), indexer=onp.array([0, 2, 0, 1])), IndexSpec(shape=(3, ), indexer=onp.array([-1, 1])), IndexSpec(shape=(3, ), indexer=onp.array([-2, -1])), ]), ("One2DIntArrayIndex", [ IndexSpec(shape=(3, ), indexer=onp.array([[0, 0]])), IndexSpec(shape=(3, 3), indexer=onp.array([[1, 2, 1], [0, 1, -1]])), IndexSpec(shape=(3, 4, 5), indexer=onp.array([[0, 2, 0, 1], [-1, -2, 1, 0]])), ]), ("Two1DIntArrayIndicesNoBroadcasting", [ IndexSpec(shape=(3, 3), indexer=[onp.array([0, 1]), onp.array([1, 2])]), IndexSpec( shape=(3, 4, 5), indexer=[onp.array([0, 2, 0, 1]), onp.array([-1, 0, -1, 2])]), ]), ("Two1DIntArrayIndicesWithBroadcasting", [ IndexSpec(shape=(3, 3), indexer=[onp.array([[0, 1]]), onp.array([1, 2])]), IndexSpec(shape=(3, 4, 5), indexer=[ 
onp.array([[0, 2, 0, 1]]), onp.array([-1, 0, -1, 2]) ]), ]), ("ListOfPythonInts", [ IndexSpec(shape=(3, ), indexer=[0, 1, 0]), IndexSpec(shape=(3, 4, 5), indexer=[0, -1]), ]), ("ListOfListsOfPythonInts", [ IndexSpec(shape=(3, 4, 5), indexer=[[0, 1]]), IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], [[2, 3, 0, 3]]]), ]), ("ListOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=[0, onp.array([0, 1])]), IndexSpec(shape=(3, 4, 5), indexer=[0, 1, onp.array([[2, 3, 0, 3]])]), ]), ("ListOfListsOfPythonIntsAndIntArrays", [ IndexSpec(shape=(3, 4, 5), indexer=[[0, 1], onp.array([0])]), IndexSpec(shape=(3, 4, 5), indexer=[[[0], [-1]], onp.array([[2, 3, 0, 3]])]), ]), ] for shape, indexer in index_specs for dtype in float_dtypes for rng in [jtu.rand_default()]) def testAdvancedIntegerIndexingGrads(self, shape, dtype, rng, indexer): tol = 1e-2 if onp.finfo(dtype).bits == 32 else None arg = rng(shape, dtype) fun = lambda x: x[indexer]**2 check_grads(fun, (arg, ), 2, tol, tol, tol) @parameterized.named_parameters({ "testcase_name": "{}_inshape={}_indexer={}".format( name, jtu.format_shape_dtype_string(shape, dtype), indexer), "shape": shape, "dtype": dtype, "rng": rng, "indexer": indexer } for name, index_specs in MIXED_ADVANCED_INDEXING_TESTS for shape, indexer in index_specs for dtype in all_dtypes for rng in [jtu.rand_default()]) def testMixedAdvancedIntegerIndexing(self, shape, dtype, rng, indexer): indexer_with_dummies = [ e if isinstance(e, onp.ndarray) else () for e in indexer ] substitutes = [(i, e) for i, e in enumerate(indexer) if not isinstance(e, onp.ndarray)] args_maker = lambda: [rng(shape, dtype), indexer_with_dummies] def fun(x, indexer_with_dummies): idx = type(indexer)(lax.subvals(indexer_with_dummies, substitutes)) return x[idx] self._CompileAndCheck(fun, args_maker, check_dtypes=True) def testAdvancedIndexingManually(self): x = onp.random.RandomState(0).randn(3, 4, 5) index_array = onp.array([0, 2, -1, 0]) op = lambda x, index_array: x[..., index_array, :] cop = api.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) op = lambda x, index_array: x[..., index_array, :, index_array, None] cop = api.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) op = lambda x, index_array: x[index_array, ..., index_array[:, None], None] cop = api.jit(op) a1 = op(x, index_array) a2 = cop(x, index_array) self.assertAllClose(a1, a2, check_dtypes=True) def testUnpacking(self): def foo(x): a, b, c = x return a + b + c cfoo = api.jit(foo) a1 = foo(onp.arange(3)) a2 = cfoo(onp.arange(3)) self.assertAllClose(a1, a2, check_dtypes=True) def testBooleanIndexingArray1D(self): idx = onp.array([True, True, False]) x = api.device_put(onp.arange(3)) ans = x[idx] expected = onp.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingList1D(self): idx = [True, True, False] x = api.device_put(onp.arange(3)) ans = x[idx] expected = onp.arange(3)[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingArray2DBroadcast(self): idx = onp.array([True, True, False, True]) x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingList2DBroadcast(self): idx = [True, True, False, True] x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def 
testBooleanIndexingArray2D(self): idx = onp.array([[True, False], [False, True], [False, False], [True, True]]) x = onp.arange(8).reshape(4, 2) ans = api.device_put(x)[idx] expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingDynamicShapeError(self): x = onp.zeros(3) i = onp.array([True, True, False]) self.assertRaises(IndexError, lambda: api.jit(lambda x, i: x[i])(x, i)) def testIssue187(self): x = lnp.ones((5, 5)) x[[0, 2, 4], [0, 2, 4]] # doesn't crash x = onp.arange(25).reshape((5, 5)) ans = api.jit(lambda x: x[[0, 2, 4], [0, 2, 4]])(x) expected = x[[0, 2, 4], [0, 2, 4]] self.assertAllClose(ans, expected, check_dtypes=False) def testJVPOfGradOfIndexing(self): # Should return a value, even though we didn't pass a symbolic zero as the # index tangent. x = lnp.ones((3, 4), lnp.float32) i = lnp.ones((3, ), lnp.int32) f = lambda x, i: lnp.sum(x[i]) primals, tangents = api.jvp(api.grad(f), (x, i), (x, onp.zeros_like(i))) expected = onp.broadcast_to( onp.array([0, 3, 0], dtype=onp.float32)[:, None], (3, 4)) self.assertAllClose(expected, primals, check_dtypes=True) self.assertAllClose(onp.zeros_like(x), tangents, check_dtypes=True)
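# A minimal standalone sketch (not part of the test suite) of the behaviour
# exercised by the indexing tests above: NumPy-style advanced integer indexing
# on a JAX array should match NumPy, both eagerly and under jit. All names and
# values below are illustrative.
import numpy as onp
import jax
import jax.numpy as jnp

x = onp.random.RandomState(0).randn(3, 4, 5)
index_array = onp.array([0, 2, -1, 0])

op = lambda a, idx: a[..., idx, :]   # advanced index on the middle axis
eager = op(jnp.asarray(x), index_array)
jitted = jax.jit(op)(jnp.asarray(x), index_array)

assert onp.allclose(eager, x[..., index_array, :])
assert onp.allclose(jitted, eager)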
def test_lax_slice(self):
  self.check(lambda x: lax.slice(x, (1,), (x.shape[0],)), ['n'], 'n+-1',
             {'n': 2}, [(3,)], ['float_'], jtu.rand_default(self.rng()))
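# Hedged usage sketch for the lax.slice call checked in test_lax_slice above:
# slicing from index 1 up to x.shape[0] drops the first element, so an input
# of length n yields length n - 1 (the 'n+-1' in the masking spec). Values
# here are illustrative.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.0)                    # length n = 5
y = lax.slice(x, (1,), (x.shape[0],))  # start indices, limit indices
assert y.shape == (4,)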
class NumpyLinalgTest(jtu.JaxTestCase): @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)] for dtype in float_types() for rng in [jtu.rand_default()])) def testCholesky(self, shape, dtype, rng): def args_maker(): a = rng(shape, dtype) return [onp.matmul(a, np.conj(T(a)))] self._CheckAgainstNumpy(onp.linalg.cholesky, np.linalg.cholesky, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.cholesky, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)), "n": n, "dtype": dtype, "rng": rng} for n in [0, 4, 5, 50] for dtype in float_types() | complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testDet(self, n, dtype, rng): if not hasattr(lapack, "jax_getrf"): self.skipTest("No LU implementation available") args_maker = lambda: [rng((n, n), dtype)] self._CheckAgainstNumpy(onp.linalg.det, np.linalg.det, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.det, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)), "n": n, "dtype": dtype, "rng": rng} for n in [0, 4, 10, 200] for dtype in float_types() | complex_types() for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testSlogdet(self, n, dtype, rng): if not hasattr(lapack, "jax_getrf"): self.skipTest("No LU implementation available") args_maker = lambda: [rng((n, n), dtype)] self._CheckAgainstNumpy(onp.linalg.slogdet, np.linalg.slogdet, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.slogdet, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}_lower={}".format( jtu.format_shape_dtype_string((n,n), dtype), lower), "n": n, "dtype": dtype, "lower": lower, "rng": rng} for n in [0, 4, 5, 50] for dtype in float_types() | complex_types() for lower in [False, True] for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an eigendecomposition implementation # for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testEigh(self, n, dtype, lower, rng): if not hasattr(lapack, "jax_syevd"): self.skipTest("No symmetric eigendecomposition implementation available") args_maker = lambda: [rng((n, n), dtype)] uplo = "L" if lower else "U" # Norm, adjusted for dimension and type. 
def norm(x): norm = onp.linalg.norm(x, axis=(-2, -1)) return norm / ((n + 1) * onp.finfo(dtype).eps) a, = args_maker() a = (a + onp.conj(a.T)) / 2 w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo) self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5) self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30) self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}_full_matrices={}_compute_uv={}".format( jtu.format_shape_dtype_string((m, n), dtype), full_matrices, compute_uv), "m": m, "n": n, "dtype": dtype, "full_matrices": full_matrices, "compute_uv": compute_uv, "rng": rng} for m in [2, 7, 29, 53] for n in [2, 7, 29, 53] for dtype in float_types() | complex_types() for full_matrices in [False, True] for compute_uv in [False, True] for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testSVD(self, m, n, dtype, full_matrices, compute_uv, rng): if not hasattr(lapack, "jax_gesdd"): self.skipTest("No singular value decomposition implementation available") args_maker = lambda: [rng((m, n), dtype)] # Norm, adjusted for dimension and type. def norm(x): norm = onp.linalg.norm(x, axis=(-2, -1)) return norm / (max(m, n) * onp.finfo(dtype).eps) a, = args_maker() out = np.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv) if compute_uv: # Check the reconstructed matrices if full_matrices: k = min(m, n) if m < n: self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0], out[2][:k, :])) < 50)) else: self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0][:, :k], out[2])) < 50)) else: self.assertTrue(onp.all(norm(a - onp.matmul(out[1] * out[0], out[2])) < 50)) # Check the unitary properties of the singular vector matrices. self.assertTrue(onp.all(norm(onp.eye(out[0].shape[1]) - onp.matmul(onp.conj(T(out[0])), out[0])) < 10)) if m >= n: self.assertTrue(onp.all(norm(onp.eye(out[2].shape[1]) - onp.matmul(onp.conj(T(out[2])), out[2])) < 10)) else: self.assertTrue(onp.all(norm(onp.eye(out[2].shape[0]) - onp.matmul(out[2], onp.conj(T(out[2])))) < 20)) else: self.assertTrue(onp.allclose(onp.linalg.svd(a, compute_uv=False), onp.asarray(out), atol=1e-4, rtol=1e-4)) self._CompileAndCheck(partial(np.linalg.svd, full_matrices=full_matrices, compute_uv=compute_uv), args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}_fullmatrices={}".format( jtu.format_shape_dtype_string(shape, dtype), full_matrices), "shape": shape, "dtype": dtype, "full_matrices": full_matrices, "rng": rng} for shape in [(1, 1), (3, 4), (2, 10, 5), (2, 200, 100)] for dtype in float_types() for full_matrices in [False, True] for rng in [jtu.rand_default()])) @jtu.skip_on_devices("cpu") def testQr(self, shape, dtype, full_matrices, rng): m, n = shape[-2:] if full_matrices: mode, k = "complete", m else: mode, k = "reduced", min(m, n) a = rng(shape, dtype) lq, lr = np.linalg.qr(a, mode=mode) # onp.linalg.qr doesn't support broadcasting. But it seems like an # inevitable extension so we support it in our version. nq = onp.zeros(shape[:-2] + (m, k), dtype) nr = onp.zeros(shape[:-2] + (k, n), dtype) for index in onp.ndindex(*shape[:-2]): nq[index], nr[index] = onp.linalg.qr(a[index], mode=mode) max_rank = max(m, n) # Norm, adjusted for dimension and type. 
def norm(x): n = onp.linalg.norm(x, axis=(-2, -1)) return n / (max_rank * onp.finfo(dtype).eps) def compare_orthogonal(q1, q2): # Q is unique up to sign, so normalize the sign first. sum_of_ratios = onp.sum(onp.divide(q1, q2), axis=-2, keepdims=True) phases = onp.divide(sum_of_ratios, onp.abs(sum_of_ratios)) q1 *= phases self.assertTrue(onp.all(norm(q1 - q2) < 30)) # Check a ~= qr self.assertTrue(onp.all(norm(a - onp.matmul(lq, lr)) < 30)) # Compare the first 'k' vectors of Q; the remainder form an arbitrary # orthonormal basis for the null space. compare_orthogonal(nq[..., :k], lq[..., :k]) # Check that q is close to unitary. self.assertTrue(onp.all(norm(onp.eye(k) - onp.matmul(T(lq), lq)) < 5)) if not full_matrices and m >= n: jtu.check_jvp(np.linalg.qr, partial(jvp, np.linalg.qr), (a,)) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs={}_rhs={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype)), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng} for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4,)), ((8, 8), (8, 4)), ] for dtype in float_types() | complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testSolve(self, lhs_shape, rhs_shape, dtype, rng): if not hasattr(lapack, "jax_getrf"): self.skipTest("No LU implementation available") args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] self._CheckAgainstNumpy(onp.linalg.solve, np.linalg.solve, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.solve, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 4), (2, 5, 5), (200, 200), (5, 5, 5)] for dtype in float_types() for rng in [jtu.rand_default()])) def testInv(self, shape, dtype, rng): def args_maker(): invertible = False while not invertible: a = rng(shape, dtype) try: onp.linalg.inv(a) invertible = True except onp.linalg.LinAlgError: pass return [a] self._CheckAgainstNumpy(onp.linalg.inv, np.linalg.inv, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(np.linalg.inv, args_maker, check_dtypes=True)
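# Sketch of the pattern used throughout NumpyLinalgTest: build a symmetric
# positive-definite matrix, then compare jnp.linalg against onp.linalg. The
# matrix construction and tolerance below are illustrative, not taken from
# the harness.
import numpy as onp
import jax.numpy as jnp

rng = onp.random.RandomState(0)
a = rng.randn(4, 4)
spd = a @ a.T + 4 * onp.eye(4)   # symmetric positive definite

l_jax = jnp.linalg.cholesky(spd)
l_np = onp.linalg.cholesky(spd)
assert onp.allclose(l_jax, l_np, atol=1e-4)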
def test_transpose(self):
  self.check(lambda x: lax.transpose(x, (1, 0, 2)), ['(a, b, c)'], 'b, a, c',
             dict(a=2, b=3, c=4), [(3, 4, 5)], ['float_'],
             jtu.rand_default(self.rng()))
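# lax.transpose permutes axes by position, as in the masked test_transpose
# above: permutation (1, 0, 2) swaps the first two axes. Illustrative only.
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4, 5))
assert lax.transpose(x, (1, 0, 2)).shape == (4, 3, 5)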
def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, rng, test_grad, test_name)

JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_default(), True),
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0.05, 0.95), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
    # TODO(phawkins): gradient of entr yields NaNs.
    op_record("entr", 1, float_dtypes, jtu.rand_default(), False),
]

CombosWithReplacement = itertools.combinations_with_replacement
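# Hedged sketch of what a single JAX_SPECIAL_FUNCTION_RECORDS entry ultimately
# exercises once expanded into test cases: a jax.scipy.special function checked
# against its scipy.special counterpart on small positive inputs. The sample
# and tolerance are illustrative.
import numpy as onp
import scipy.special as osp_special
import jax.scipy.special as jsp_special

x = onp.linspace(0.01, 0.5, 5)
assert onp.allclose(jsp_special.erf(x), osp_special.erf(x), atol=1e-5)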
def test_expit(self):
  raise SkipTest("custom_jvp doesn't work with masking yet")
  self.check(expit, ['n'], 'n', dict(n=3), [(4,)], ['float_'],
             jtu.rand_default(self.rng()))
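# expit is the logistic sigmoid; the skipped masking test above would check
# that it maps an input of shape 'n' to an output of shape 'n'. A plain,
# unmasked sanity check (values illustrative):
import numpy as onp
from jax.scipy.special import expit

x = onp.array([-1.0, 0.0, 1.0])
assert onp.allclose(expit(x), 1.0 / (1.0 + onp.exp(-x)), atol=1e-6)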
class LaxBackedScipyTests(jtu.JaxTestCase): """Tests for LAX-backed Scipy implementation.""" def _GetArgsMaker(self, rng, shapes, dtypes): return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)] @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_axis={}_keepdims={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, keepdims), "rng": jtu.rand_default(), "shape": shape, "dtype": dtype, "axis": axis, "keepdims": keepdims} for shape in all_shapes for dtype in float_dtypes for axis in range(-len(shape), len(shape)) for keepdims in [False, True])) @jtu.skip_on_flag("jax_xla_backend", "xrt") def testLogSumExp(self, rng, shape, dtype, axis, keepdims): # TODO(mattjj): test autodiff def scipy_fun(array_to_reduce): return osp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims) def lax_fun(array_to_reduce): return lsp_special.logsumexp(array_to_reduce, axis, keepdims=keepdims) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix( rec.test_name, shapes, dtypes), "rng": rec.rng, "shapes": shapes, "dtypes": dtypes, "test_autodiff": rec.test_autodiff, "scipy_op": getattr(osp_special, rec.name), "lax_op": getattr(lsp_special, rec.name)} for shapes in CombosWithReplacement(all_shapes, rec.nargs) for dtypes in CombosWithReplacement(rec.dtypes, rec.nargs)) for rec in JAX_SPECIAL_FUNCTION_RECORDS)) def testScipySpecialFun(self, scipy_op, lax_op, rng, shapes, dtypes, test_autodiff): args_maker = self._GetArgsMaker(rng, shapes, dtypes) args = args_maker() self.assertAllClose(scipy_op(*args), lax_op(*args), atol=1e-3, rtol=1e-3, check_dtypes=False) self._CompileAndCheck(lax_op, args_maker, check_dtypes=True) if test_autodiff: jtu.check_grads(lax_op, args, order=1, atol=1e-3, rtol=3e-3, eps=1e-3) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_inshape={}_d={}".format( jtu.format_shape_dtype_string(shape, dtype), d), "rng": jtu.rand_positive(), "shape": shape, "dtype": dtype, "d": d} for shape in all_shapes for dtype in float_dtypes for d in [1, 2, 5])) def testMultigammaln(self, rng, shape, dtype, d): def scipy_fun(a): return osp_special.multigammaln(a, d) def lax_fun(a): return lsp_special.multigammaln(a, d) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=True) self._CompileAndCheck(lax_fun, args_maker, check_dtypes=True) def testIssue980(self): x = onp.full((4,), -1e20, dtype=onp.float32) self.assertAllClose(onp.zeros((4,), dtype=onp.float32), lsp_special.expit(x), check_dtypes=True)
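# Minimal counterpart to testLogSumExp above, as a standalone sketch:
# jax.scipy.special.logsumexp should agree with scipy.special.logsumexp,
# including the keepdims behaviour. Shapes and tolerances are illustrative.
import numpy as onp
import scipy.special as osp_special
import jax.scipy.special as jsp_special

x = onp.random.RandomState(0).randn(3, 4)
got = jsp_special.logsumexp(x, axis=1, keepdims=True)
want = osp_special.logsumexp(x, axis=1, keepdims=True)
assert got.shape == (3, 1)
assert onp.allclose(got, want, atol=1e-5)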
def test_reduce(self, operator):
  self.check(operator, ['(m, n)'], '', {'m': 3, 'n': 4}, [(4, 5)], ['float_'],
             jtu.rand_default(self.rng()))
def testSparseAttrAccess(self, attr):
  rng = jtu.rand_default(self.rng())
  args_maker = lambda: [make_sparse_array(rng, (10,), jnp.float32)]
  f = lambda x: getattr(x, attr)
  self._CompileAndCheck(f, args_maker)
def testGather(self, shape, dtype, idxs, dnums, slice_sizes, bdims):
  fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes)
  self._CheckBatching(fun, 5, bdims, [shape, idxs.shape], [dtype, idxs.dtype],
                      jtu.rand_default(self.rng()))
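# The batched gather test above relies on vmap mapping an indexing/gather op
# over a leading batch axis. A simplified, gather-free illustration of that
# batching rule (all names and shapes are illustrative):
import numpy as onp
import jax
import jax.numpy as jnp

operand = jnp.arange(12.0).reshape(3, 4)   # batch of 3 rows
idxs = jnp.array([1, 3])
take_cols = lambda row: row[idxs]          # per-example op on one row

batched = jax.vmap(take_cols)(operand)     # maps over axis 0
assert batched.shape == (3, 2)
assert onp.allclose(batched, onp.asarray(operand)[:, [1, 3]])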
def test_where(self):
  self.check(lambda x: jnp.where(x < 0, x, 0. * x), ['n'], 'n', {'n': 2},
             [(3,)], ['float_'], jtu.rand_default(self.rng()))
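# jnp.where as used in test_where above keeps negative entries and zeroes out
# the rest while preserving the input shape. Illustrative values only.
import numpy as onp
import jax.numpy as jnp

x = jnp.array([-2.0, 3.0, -1.0])
y = jnp.where(x < 0, x, 0.0 * x)
assert onp.allclose(y, onp.array([-2.0, 0.0, -1.0]))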
int_dtypes = [onp.int32, onp.int64]
bool_dtypes = [onp.bool_]
default_dtypes = float_dtypes + int_dtypes
numeric_dtypes = float_dtypes + complex_dtypes + int_dtypes

OpRecord = collections.namedtuple(
    "OpRecord", ["name", "nargs", "dtypes", "rng", "diff_modes", "test_name"])

def op_record(name, nargs, dtypes, rng, diff_modes, test_name=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, rng, diff_modes, test_name)

JAX_ONE_TO_ONE_OP_RECORDS = [
    op_record("abs", 1, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("add", 2, default_dtypes, jtu.rand_default(), ["rev"]),
    op_record("bitwise_and", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_not", 1, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_or", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("bitwise_xor", 2, default_dtypes, jtu.rand_bool(), []),
    op_record("ceil", 1, float_dtypes, jtu.rand_default(), []),
    op_record("conj", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("conjugate", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("exp", 1, numeric_dtypes, jtu.rand_default(), ["rev"]),
    op_record("floor", 1, float_dtypes, jtu.rand_default(), []),
    op_record("greater", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("greater_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less", 2, default_dtypes, jtu.rand_some_equal(), []),
    op_record("less_equal", 2, default_dtypes, jtu.rand_some_equal(), []),
def test_split(self):
  self.check(lambda x: jnp.split(x, 2), ['2*n'], ['n', 'n'], dict(n=4),
             [(8,)], ['float_'], jtu.rand_default(self.rng()))
  self.check(lambda x: jnp.split(x, [10]), ['n'], ['10', 'n+-10'], dict(n=12),
             [(12,)], ['float_'], jtu.rand_default(self.rng()))
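# jnp.split with an integer splits into equal parts; with a list of indices it
# splits at those positions, matching the two masked cases in test_split above.
import jax.numpy as jnp

x = jnp.arange(8.0)
a, b = jnp.split(x, 2)            # two halves of length 4
assert a.shape == (4,) and b.shape == (4,)

y = jnp.arange(12.0)
head, tail = jnp.split(y, [10])   # lengths 10 and 12 - 10 = 2
assert head.shape == (10,) and tail.shape == (2,)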
class ScipyLinalgTest(jtu.JaxTestCase): # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 5), (10, 5), (50, 50)] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testLu(self, shape, dtype, rng): args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(jsp.linalg.lu, osp.linalg.lu, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp.linalg.lu, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_shape={}".format(jtu.format_shape_dtype_string(shape, dtype)), "shape": shape, "dtype": dtype, "rng": rng} for shape in [(1, 1), (4, 5), (10, 5), (10, 10)] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @jtu.skip_on_devices("gpu", "tpu") def testLuGrad(self, shape, dtype, rng): a = rng(shape, dtype) jtu.check_grads(jsp.linalg.lu, (a,), 2, rtol=1e-1) @jtu.skip_on_devices("gpu", "tpu") def testLuBatching(self): shape = (4, 5) dtype = np.float32 rng = jtu.rand_default() args = [rng(shape, np.float32) for _ in range(10)] expected = list(osp.linalg.lu(x) for x in args) ps = onp.stack([out[0] for out in expected]) ls = onp.stack([out[1] for out in expected]) us = onp.stack([out[2] for out in expected]) actual_ps, actual_ls, actual_us = vmap(jsp.linalg.lu)(np.stack(args)) self.assertAllClose(ps, actual_ps, check_dtypes=True) self.assertAllClose(ls, actual_ls, check_dtypes=True) self.assertAllClose(us, actual_us, check_dtypes=True) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_n={}".format(jtu.format_shape_dtype_string((n,n), dtype)), "n": n, "dtype": dtype, "rng": rng} for n in [1, 4, 5, 200] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) @jtu.skip_on_devices("gpu", "tpu") def testLuFactor(self, n, dtype, rng): args_maker = lambda: [rng((n, n), dtype)] self._CheckAgainstNumpy(jsp.linalg.lu_factor, osp.linalg.lu_factor, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp.linalg.lu_factor, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs={}_rhs={}_sym_pos={}_lower={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), sym_pos, lower), "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "sym_pos": sym_pos, "lower": lower, "rng": rng} for lhs_shape, rhs_shape in [ ((1, 1), (1, 1)), ((4, 4), (4,)), ((8, 8), (8, 4)), ] for sym_pos, lower in [ (False, False), (True, False), (True, True), ] for dtype in float_types() + complex_types() for rng in [jtu.rand_default()])) # TODO(phawkins): enable when there is an LU implementation for GPU/TPU. 
@jtu.skip_on_devices("gpu", "tpu") def testSolve(self, lhs_shape, rhs_shape, dtype, sym_pos, lower, rng): osp_fun = lambda lhs, rhs: osp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower) jsp_fun = lambda lhs, rhs: jsp.linalg.solve(lhs, rhs, sym_pos=sym_pos, lower=lower) def args_maker(): a = rng(lhs_shape, dtype) if sym_pos: a = onp.matmul(a, onp.conj(T(a))) a = onp.tril(a) if lower else onp.triu(a) return [a, rng(rhs_shape, dtype)] self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=True, tol=1e-3) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lower, transpose_a), "lower": lower, "transpose_a": transpose_a, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng} for lower, transpose_a in itertools.product([False, True], repeat=2) for lhs_shape, rhs_shape in [ ((4, 4), (4,)), ((4, 4), (4, 3)), ((2, 8, 8), (2, 8, 10)), ] for dtype in float_types() for rng in [jtu.rand_default()])) def testSolveTriangular(self, lower, transpose_a, lhs_shape, rhs_shape, dtype, rng): k = rng(lhs_shape, dtype) l = onp.linalg.cholesky(onp.matmul(k, T(k)) + lhs_shape[-1] * onp.eye(lhs_shape[-1])) l = l.astype(k.dtype) b = rng(rhs_shape, dtype) a = l if lower else T(l) inv = onp.linalg.inv(T(a) if transpose_a else a).astype(a.dtype) if len(lhs_shape) == len(rhs_shape): onp_ans = onp.matmul(inv, b) else: onp_ans = onp.einsum("...ij,...j->...i", inv, b) # The standard scipy.linalg.solve_triangular doesn't support broadcasting. # But it seems like an inevitable extension so we support it. ans = jsp.linalg.solve_triangular( l if lower else T(l), b, trans=1 if transpose_a else 0, lower=lower) self.assertAllClose(onp_ans, ans, check_dtypes=True) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_lhs={}_rhs={}_lower={}_transposea={}".format( jtu.format_shape_dtype_string(lhs_shape, dtype), jtu.format_shape_dtype_string(rhs_shape, dtype), lower, transpose_a), "lower": lower, "transpose_a": transpose_a, "lhs_shape": lhs_shape, "rhs_shape": rhs_shape, "dtype": dtype, "rng": rng} for lower, transpose_a in itertools.product([False, True], repeat=2) for lhs_shape, rhs_shape in [ ((4, 4), (4,)), ((4, 4), (4, 3)), ((2, 8, 8), (2, 8, 10)), ] for dtype in float_types() for rng in [jtu.rand_default()])) def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape, rhs_shape, dtype, rng): # TODO(frostig): change ensemble to support a bigger rtol self.skipTest("rtol does not cover all devices and precision modes") A = np.tril(rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype)) A = A if lower else T(A) B = rng(rhs_shape, dtype) f = partial(jsp.linalg.solve_triangular, lower=lower, trans=1 if transpose_a else 0) jtu.check_grads(f, (A, B), 2, rtol=1e-3)
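# Standalone counterpart to testSolveTriangular above: solve_triangular on a
# lower-triangular system should reproduce the right-hand side when the result
# is multiplied back. The diagonal shift keeps the system well conditioned;
# names and tolerances are illustrative.
import numpy as onp
import jax.scipy.linalg as jsp_linalg

rng = onp.random.RandomState(0)
l = onp.tril(rng.randn(4, 4)) + 4 * onp.eye(4)
b = rng.randn(4)

x = jsp_linalg.solve_triangular(l, b, lower=True)
assert onp.allclose(l @ onp.asarray(x), b, atol=1e-4)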
def test_mean(self):
  # TODO Shapecheck fails - shape_as_value can't deal with abstract eval yet
  raise SkipTest
  self.check(lambda x: jnp.sum(x) / shape_as_value(x.shape)[0], ['n'], '',
             {'n': 3}, [(4,)], ['float_'], jtu.rand_default(self.rng()))
class BatchingTest(jtu.JaxTestCase): def testConstantFunction(self): ans = vmap(lambda x: 3)(onp.ones(4)) expected = 3 * onp.ones(4) self.assertAllClose(ans, expected, check_dtypes=False) def testNestedBatchingMatMat(self): matvec = vmap(np.vdot, in_axes=(0, None)) matmat = vmap(matvec, in_axes=(None, 1), out_axes=1) R = onp.random.RandomState(0).randn A = R(4, 3) B = R(3, 2) ans = matmat(A, B) expected = onp.dot(A, B) self.assertAllClose(ans, expected, check_dtypes=False) # this is a crude check that we only call a single dot def pv_like(x): aval = ShapedArray(onp.shape(x), onp.result_type(x)) return pe.PartialVal((aval, unit)) def make_jaxpr(fun, example_args): jaxpr, _, _, _ = trace_to_jaxpr(fun, map(pv_like, example_args)) return jaxpr jaxpr = make_jaxpr(matmat, (A, B)) self.assertEqual(len(jaxpr.eqns), 1) def testPerExampleGradients(self): def predict(params, inputs): for W, b in params: outputs = np.dot(W, inputs) + b inputs = np.tanh(outputs) return outputs def loss(params, data): inputs, targets = data predictions = predict(params, inputs) return np.sum((predictions - targets)**2) batch_size = 5 layer_sizes = [3, 2, 4] R = onp.random.RandomState(0).randn params = [(R(m, n), R(m)) for m, n in zip(layer_sizes[1:], layer_sizes[:-1])] input_vec = R(3) target_vec = R(4) datum = (input_vec, target_vec) input_batch = R(5, 3) target_batch = R(5, 4) batch = (input_batch, target_batch) ans = vmap(partial(grad(loss), params))(batch) for ans_pair, param_pair in zip(ans, params): dW, db = ans_pair W, b = param_pair self.assertEqual(dW.shape, (batch_size,) + W.shape) self.assertEqual(db.shape, (batch_size,) + b.shape) def testJacobians(self): def jacbwd(f, x): y, pullback = vjp(f, x) std_basis = onp.eye(onp.size(y)).reshape((-1,) + onp.shape(y)) jac_flat, = vmap(pullback, out_axes=onp.ndim(y))(std_basis) return jac_flat.reshape(onp.shape(y) + onp.shape(x)) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,)) std_basis = onp.eye(onp.size(x)).reshape((-1,) + onp.shape(x)) y, jac_flat = vmap(pushfwd, out_axes=(None, 0))(std_basis) return jac_flat.reshape(onp.shape(y) + onp.shape(x)) R = onp.random.RandomState(0).randn A = R(4, 3) b = R(4) f = lambda x: np.tanh(np.dot(A, x) + b) x = R(3) self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False) def testBatchOfCompile(self): side = [] @jit def f(x): side.append(None) return x + x g = jit(vmap(f)) self.assertAllClose(g(onp.ones(2)), 2 * onp.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) self.assertAllClose(g(2 * onp.ones(2)), 4 * onp.ones(2), check_dtypes=False) self.assertEqual(len(side), 1) def testSliceLax(self): fun = lambda x: lax.slice(x, (2,), (4,)) R = onp.random.RandomState(0).randn x = R(5, 10) ans = vmap(fun)(x) expected_ans = x[:, 2:4] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testSliceNumpy(self): fun = lambda x: x[:, 2] R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = x[:, :, 2] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevLax(self): fun = lambda x: lax.rev(x, [0]) R = onp.random.RandomState(0).randn x = R(2, 3) ans = vmap(fun)(x) expected_ans = x[:, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (1,), 1)(x) expected_ans = x[::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testRevNumpy(self): fun = lambda x: x[:, ::-1] R = onp.random.RandomState(0).randn x = R(3, 2, 4) ans = vmap(fun)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, 
check_dtypes=False) ans = vmap(fun, (1,), 1)(x) expected_ans = x[:, :, ::-1] self.assertAllClose(ans, expected_ans, check_dtypes=False) ans = vmap(fun, (2,), 2)(x) expected_ans = x[:, ::-1, :] self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpMaximum(self): fun = lambda x: np.maximum(x, 0.0) R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(fun)(x) expected_ans = onp.maximum(x, 0.0) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testNpGtrThan(self): R = onp.random.RandomState(0).randn x = R(10, 5, 3, 7) ans = vmap(lambda x: x > 1.0)(x) expected_ans = x > 1.0 self.assertAllClose(ans, expected_ans, check_dtypes=True) def testNpMaximumPerExampleGrad(self): R = onp.random.RandomState(0).randn x = R(10, 5) W = R(5, 5) fun = lambda W, x: np.sum(np.maximum(np.dot(x, W), 0.0) ** 2) ans = vmap(partial(grad(fun), W))(x) W_t = np.transpose(W) for i in range(10): x_ex = x[i:i + 1] expected_ans = 2.0 * np.dot( np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0), x_ex) expected_ans = np.transpose(expected_ans) self.assertAllClose(ans[i], expected_ans, check_dtypes=False) def testDotGeneral(self): R = onp.random.RandomState(0).randn x = R(10, 3, 4, 5) y = R(10, 3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun)(x, y) expected = lax.dot_general(x, y, [((3,), (2,)), ((0, 1), (0, 1))]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 10, 5) y = R(3, 10, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(2, 1))(x, y) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) expected = onp.stack([fun(x[..., i, :], y[:, i, ...]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 5, 10) y = R(3, 5, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(3, None))(x, y) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) expected = onp.stack([fun(x[..., i], y) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) x = R(3, 4, 5) y = R(3, 5, 10, 6) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) ans = vmap(fun, in_axes=(None, 2))(x, y) fun = lambda x, y: lax.dot_general(x, y, [((2,), (1,)), ((0,), (0,))]) expected = onp.stack([fun(x, y[..., i, :]) for i in range(10)]) self.assertAllClose(ans, expected, check_dtypes=True) def testDot(self): # these tests are based on @shoyer's notebook studying gufuncs def vecvec(a, b): dot = np.dot for ndim in range(1, max(a.ndim, b.ndim)): a_ax = 0 if a.ndim > ndim else None b_ax = 0 if b.ndim > ndim else None dot = vmap(dot, in_axes=(a_ax, b_ax)) return dot(a, b) assert vecvec(np.zeros((3,)), np.zeros((3,))).shape == () assert vecvec(np.zeros((2, 3)), np.zeros((3,))).shape == (2,) assert vecvec(np.zeros((4, 2, 3)), np.zeros((3,))).shape == (4, 2) def testDot2(self): R = onp.random.RandomState(0).randn xs = R(10, 3) ys = R(10, 3) ans = vmap(np.dot)(xs, ys) expected = onp.einsum('ni,ni->n', xs, ys) self.assertAllClose(ans, expected, check_dtypes=False) def testPad(self): R = onp.random.RandomState(0).randn fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1)]) x = R(5, 10).astype(onp.float32) ans = vmap(fun)(x) expected_ans = np.stack(list(map(fun, x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda x: lax.pad(x, onp.float32(0), [(1, 2, 1), (0, 1, 0)]) x = R(5, 10, 3).astype(onp.float32) ans = vmap(fun)(x) expected_ans = np.stack(list(map(fun, 
x))) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testConcatenate(self): R = lambda *shape: onp.random.RandomState(0).randn(*shape).astype(onp.float32) fun = lambda *args: lax.concatenate(args, dimension=0) x, y, z = R(10, 2, 3), R(1, 10, 3), R(4, 3) ans = vmap(fun, in_axes=(0, 1, None))(x, y, z) expected_ans = onp.concatenate([x, onp.swapaxes(y, 0, 1), onp.broadcast_to(z, (10, 4, 3))], 1) self.assertAllClose(ans, expected_ans, check_dtypes=False) fun = lambda *args: lax.concatenate(args, dimension=1) x, y, z = R(10, 2, 1), R(2, 3), R(2, 4, 10) ans = vmap(fun, in_axes=(0, None, 2))(x, y, z) expected_ans = onp.concatenate([x, onp.broadcast_to(y, (10, 2, 3)), onp.moveaxis(z, 2, 0)], 2) self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): return np.array([x for x in xs]) xs = np.ones((5, 1)) jacrev(func)(xs) # don't crash jacfwd(func)(xs) # don't crash def testAny(self): # test modeling the code in https://github.com/google/jax/issues/108 ans = vmap(np.any)(np.array([[True, False], [False, False]])) expected = np.array([True, False]) self.assertAllClose(ans, expected, check_dtypes=True) @jtu.skip_on_devices("tpu") def testHessian(self): # test based on code from sindhwani@google def fun(x, t): return np.sum(np.power(np.maximum(x, 0.0), 2)) + t x = onp.array([-1., -0.5, 0., 0.5, 1.0]) ans = hessian(lambda x: fun(x, 0.0))(x) expected = onp.array([[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0.,0.5, 0., 0.], [0., 0., 0., 2., 0.], [0., 0., 0., 0., 2.]]) self.assertAllClose(ans, expected, check_dtypes=False) def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax x = onp.arange(30).reshape((10, 3)) ans = vmap(lambda x, i: x[i], in_axes=(0, None))(x, 1) expected = x[:, 1] self.assertAllClose(ans, expected, check_dtypes=False) idx = onp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(0, 0))(x, idx) expected = x[onp.arange(10), idx] self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(3) idx = onp.array([0, 1, 2, 1, 0] * 2) ans = vmap(lambda x, i: x[i], in_axes=(None, 0))(x, idx) expected = x[idx] self.assertAllClose(ans, expected, check_dtypes=False) def testRandom(self): seeds = vmap(random.PRNGKey)(onp.arange(10)) ans = vmap(partial(random.normal, shape=(3, 2)))(seeds) expected = onp.stack([random.normal(random.PRNGKey(seed), (3, 2)) for seed in onp.arange(10)]) self.assertAllClose(ans, expected, check_dtypes=False) assert len(onp.unique(ans)) == 10 * 3 * 2 def testSortKeyVal(self): k = onp.arange(12)[::-1].reshape(3, 4) v = onp.random.RandomState(0).permutation(12).reshape(3, 4) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 0))(k, v) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 1), 1)(k, v) self.assertAllClose(sk, k[::-1, :], check_dtypes=True) self.assertAllClose(sv, v[::-1, :], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (0, 1))(k, v.T) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, 0))(k.T, v) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (None, 0))(k[0], v) self.assertAllClose(sk, onp.broadcast_to(k[0, 
::-1], (3, 4)), check_dtypes=True) self.assertAllClose(sv, v[:, ::-1], check_dtypes=True) sk, sv = vmap(partial(lax.sort_key_val, dimension=0), (1, None))(k.T, v[0]) self.assertAllClose(sk, k[:, ::-1], check_dtypes=True) self.assertAllClose(sv, onp.broadcast_to(v[0, ::-1], (3, 4)), check_dtypes=True) def testConvGeneralDilated(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) return y grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [ np.reshape(g, (1,) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) def testMaxPool(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window( y, -np.inf, lax.max, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [ np.reshape(g, (1,) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) def testSumPool(self): W = np.array(onp.random.randn(3, 3, 1, 5), dtype=onp.float32) X = np.array(onp.random.randn(10, 5, 5, 1), dtype=onp.float32) def f(params, x): one = (1, 1) dimension_numbers = ('NHWC', 'HWIO', 'NHWC') y = lax.conv_general_dilated( x, params, one, 'SAME', one, one, dimension_numbers) y = lax.reduce_window( y, 0.0, lax.add, (1, 2, 2, 1), (1, 1, 1, 1), 'SAME') return y grad_loss = grad(lambda params, x: np.mean(f(params, x) ** 2)) # Test forward prop. per_example = vmap(partial(f, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example = np.reshape(per_example, (10, 5, 5, 5)) per_example_direct = f(W, X) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) # Test gradients. 
per_example = vmap(partial(grad_loss, W))(np.reshape(X, (10, 1, 5, 5, 1))) per_example_direct = [] for i in range(10): g = grad_loss(W, np.reshape(X[i], (1, 5, 5, 1))) per_example_direct += [ np.reshape(g, (1,) + g.shape)] per_example_direct = np.concatenate(per_example_direct, axis=0) self.assertAllClose(per_example, per_example_direct, check_dtypes=True) def testSelect(self): pred = onp.array([True, False]) on_true = onp.array([0, 1]) on_false = onp.array([2, 3]) ans = vmap(lax.select)(pred, on_true, on_false) expected = onp.array([0, 3]) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([0, 1]) on_false = onp.array([2, 3]) ans = vmap(lax.select, (0, None, None))(pred, on_true, on_false) expected = onp.array([[2, 3], [0, 1]]) self.assertAllClose(ans, expected, check_dtypes=True) pred = True on_true = onp.array([0, 1], onp.float32) on_false = onp.array(3, onp.float32) ans = vmap(lax.select, (None, 0, None))(pred, on_true, on_false) expected = onp.array([0, 1], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([0, 1], onp.float32) on_false = onp.array(3, onp.float32) ans = vmap(lax.select, (0, 0, None))(pred, on_true, on_false) expected = onp.array([3, 1], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) pred = onp.array([False, True]) on_true = onp.array([2], onp.float32) on_false = onp.array([[3, 4]], onp.float32) ans = vmap(lax.select, (0, None, 1), 1)(pred, on_true, on_false) expected = onp.array([[3, 2]], onp.float32) self.assertAllClose(ans, expected, check_dtypes=True) def testLaxLinalgCholesky(self): a = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32) a = onp.matmul(a, onp.conj(onp.swapaxes(a, -1, -2))) ans = vmap(lax_linalg.cholesky)(a) expected = onp.linalg.cholesky(a) self.assertAllClose(ans, expected, check_dtypes=False) b = onp.random.RandomState(0).randn(10, 5, 5).astype(onp.float32) b = onp.matmul(b, onp.conj(onp.swapaxes(b, -1, -2))) b_trans = onp.swapaxes(b, 0, 1) # shape is (5, 10, 5) ans = vmap(lax_linalg.cholesky, in_axes=1, out_axes=0)(b_trans) expected = onp.linalg.cholesky(b) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx} for dtype in [onp.float32, onp.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1,)), (1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,), index_vector_dim=1), (2,)), (1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1, 3)), (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) 
operand = rng(shape, dtype) ans = vmap(fun, (axis, None))(operand, idxs) expected = onp.stack([fun(operand[(slice(None),) * axis + (i,)], idxs) for i in range(operand.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx} for dtype in [onp.float32, onp.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (3, 5), onp.array([0, 2]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1,)), (1, (10, 3), onp.array([0, 0, 0]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,), index_vector_dim=1), (2,)), (1, (10, 3, 5), onp.array([0, 2, 1]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1, 3)), (2, (10, 5, 3), onp.array([[0, 2], [1, 0]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherGradBatchedOperand(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (axis, None))(operand, idxs) expected = onp.stack([gfun(operand[(slice(None),) * axis + (i,)], idxs) for i in range(operand.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx} for dtype in [onp.float32, onp.int32] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1,)), (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,), index_vector_dim=1), (2,)), (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1, 3)), (0, (10, 5), onp.array([[[0, 2], [1, 0]], [[1, 2], [0, 3]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) ans = vmap(fun, (None, axis))(operand, idxs) expected = onp.stack([fun(operand, idxs[(slice(None),) * axis + (i,)]) for i in range(idxs.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_axis={}_idxs={}_dnums={}_slice_sizes={}".format( 
jtu.format_shape_dtype_string(shape, dtype), axis, idxs, dnums, slice_sizes), "axis": axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx} for dtype in [onp.float32, onp.float64] for axis, shape, idxs, dnums, slice_sizes in [ (0, (5,), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1,)), (1, (10,), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,), index_vector_dim=1), (2,)), (1, (10, 5), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1, 3)), (0, (10, 5), onp.array([[[0, 2], [1, 0]], [[1, 2], [0, 3]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherGradBatchedIndices(self, axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) ans = vmap(gfun, (None, axis))(operand, idxs) expected = onp.stack([gfun(operand, idxs[(slice(None),) * axis + (i,)]) for i in range(idxs.shape[axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx} for dtype in [onp.float32, onp.int32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1,)), (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,), index_vector_dim=1), (2,)), (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1, 3)), (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(fun, (op_axis, idxs_axis))(operand, idxs) expected = onp.stack([fun(operand[(slice(None),) * op_axis + (i,)], idxs[(slice(None),) * idxs_axis + (i,)]) for i in range(idxs.shape[idxs_axis])]) self.assertAllClose(ans, expected, check_dtypes=False) @parameterized.named_parameters( {"testcase_name": "_shape={}_op_axis={}_idxs_axis={}_idxs={}_dnums={}_slice_sizes={}".format( jtu.format_shape_dtype_string(shape, dtype), op_axis, idxs_axis, idxs, dnums, 
slice_sizes), "op_axis": op_axis, "idxs_axis": idxs_axis, "shape": shape, "dtype": dtype, "idxs": idxs, "dnums": dnums, "slice_sizes": slice_sizes, "rng": rng, "rng_idx": rng_idx} for dtype in [onp.float32, onp.int32] for op_axis, idxs_axis, shape, idxs, dnums, slice_sizes in [ (0, 0, (2, 5), onp.array([[0, 2], [1, 3]]), lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1,)), (1, 1, (10, 2), onp.array([[0, 0, 0], [0, 2, 1]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,), index_vector_dim=1), (2,)), (0, 1, (2, 10, 5,), onp.array([[0, 2, 1], [0, 3, 3]]).T, lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,), index_vector_dim=1), (1, 3)), (2, 0, (10, 5, 2), onp.array([[[0, 2], [1, 0]], [[1, 0], [2, 0]]]), lax.GatherDimensionNumbers( offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0, 1), index_vector_dim=1), (1, 3)), ] for rng_idx in [jtu.rand_int(max(shape))] for rng in [jtu.rand_default()]) def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs, dnums, slice_sizes, rng, rng_idx): fun = partial(lax.gather, dimension_numbers=dnums, slice_sizes=slice_sizes) gfun = grad(lambda x, idx: np.sum(np.sin(fun(x, idx)))) operand = rng(shape, dtype) assert operand.shape[op_axis] == idxs.shape[idxs_axis] ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs) expected = onp.stack([gfun(operand[(slice(None),) * op_axis + (i,)], idxs[(slice(None),) * idxs_axis + (i,)]) for i in range(idxs.shape[idxs_axis])]) self.assertAllClose(ans, expected, check_dtypes=False) def testNumpyIndexing1(self): a = np.arange(2 * 3 * 4).reshape((2, 3, 4)) ind = onp.array([[0, 1], [2, 0]]) def f(a, ind): return a[:, ind] expected = onp.stack([f(a, ind[i, :]) for i in range(ind.shape[0])]) ans = vmap(f, (None, 0))(a, ind) assert onp.all(ans == expected) def testNumpyIndexing2(self): a = np.arange(2 * 3 * 4).reshape((2, 3, 4)) def f(a): inds = np.array([0, 2]) return a[:, inds] ans = vmap(f)(a) expected = onp.stack([f(a[:, i, :]) for i in range(a.shape[1])], axis=1) assert onp.all(ans == expected) def testTranspose(self): x = onp.arange(4 * 3 * 3).reshape((4, 3, 3)) ans = vmap(lambda x: x + x.T)(x) expected = x + onp.swapaxes(x, -1, -2) self.assertAllClose(ans, expected, check_dtypes=False) def testTransposePermutation(self): x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: np.transpose(x, (1, 0, 2)))(x) expected = onp.transpose(x, (0, 2, 1, 3)) self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(6 * 3 * 4 * 5).reshape((6, 3, 4, 5)) ans = vmap(lambda x: np.transpose(x, (1, 2, 0)))(x) expected = onp.transpose(x, (0, 2, 3, 1)) self.assertAllClose(ans, expected, check_dtypes=False) x = onp.arange(6 * 3 * 4 * 5).reshape((3, 4, 6, 5)) ans = vmap(lambda x: np.transpose(x, (1, 2, 0)), in_axes=2)(x) expected = onp.transpose(x, (2, 1, 3, 0)) self.assertAllClose(ans, expected, check_dtypes=False) def testIssue354(self): psd_mat = onp.random.randn(20, 10) psd_mat = psd_mat.T.dot(psd_mat) vec = onp.random.randn(10) def f(scale): scaled_mat = scale * psd_mat chol = np.linalg.cholesky(scaled_mat) return -0.5 * np.sum((np.einsum('ij,j->i', chol, vec))**2) vmapped_f = vmap(f) vmapped_f_grad = grad(lambda x: np.sum(vmapped_f(x))) scales = onp.array([[0.1], [0.2], [0.3], [0.4], [0.5]]) ans = vmapped_f_grad(scales) # don't crash! 
expected = onp.stack([grad(f)(scale) for scale in scales]) self.assertAllClose(ans, expected, check_dtypes=False) def testIssue387(self): # https://github.com/google/jax/issues/387 R = onp.random.RandomState(0).rand(100, 2) def dist_sq(R): dR = R[:, np.newaxis, :] - R[np.newaxis, :, :] zero = np.zeros_like(dR) dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR)) return np.sum(dR ** 2, axis=2) @jit def f(R): dr = dist_sq(R) return np.sum(R ** 2) H = hessian(f)(R) # don't crash on UnshapedArray
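# Per-example gradients, the core pattern in testPerExampleGradients above:
# vmap(grad(loss)) yields one gradient per batch element, with a new leading
# batch axis. The loss and data below are illustrative stand-ins.
import numpy as onp
import jax
import jax.numpy as jnp

def loss(w, x):
  # squared error of a linear prediction for a single example
  return jnp.sum((jnp.dot(x, w) - 1.0) ** 2)

w = jnp.ones(3)
xs = jnp.asarray(onp.random.RandomState(0).randn(5, 3))  # batch of 5 examples

per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, xs)
assert per_example_grads.shape == (5, 3)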
def testSlice(self, shape, dtype, starts, limits, strides, bdims):
  rng = jtu.rand_default(self.rng())
  op = lambda x: lax.slice(x, starts, limits, strides)
  self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng)
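# lax.slice with explicit strides, the op batched in testSlice above: start
# indices, limit indices, and a stride per dimension. Illustrative values.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10.0)
y = lax.slice(x, (1,), (9,), (2,))   # elements at indices 1, 3, 5, 7
assert y.shape == (4,)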