def testDotGeneralContractOnly(self, lhs_shape, rhs_shape, dtype,
                               lhs_contracting, rhs_contracting, bdims):
  rng = jtu.rand_small(self.rng())
  dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))
  dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
  self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                      rng)

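# Illustrative sketch (not part of the test suite): with contracting dimensions
# only and empty batch dimensions, lax.dot_general reduces to an ordinary
# matrix product. The shapes below are hypothetical.
import numpy as np
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((3, 4))
rhs = jnp.ones((4, 5))
dn = (((1,), (0,)), ((), ()))  # contract lhs dim 1 with rhs dim 0, no batch dims
out = lax.dot_general(lhs, rhs, dimension_numbers=dn)
assert out.shape == (3, 5)
np.testing.assert_allclose(out, lhs @ rhs)
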
def test_bcoo_rdot_general_contract_and_batch(self, lhs_shape, rhs_shape,
                                              dtype, dimension_numbers,
                                              n_batch, n_dense):
  rng = jtu.rand_small(self.rng())
  rng_sparse = rand_sparse(self.rng())

  def args_maker():
    lhs = rng(lhs_shape, dtype)
    rhs = rng_sparse(rhs_shape, dtype)
    data, indices = sparse_ops.bcoo_fromdense(rhs, n_batch=n_batch,
                                              n_dense=n_dense)
    return data, indices, lhs, rhs

  def f_dense(data, indices, lhs, rhs):
    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

  def f_sparse(data, indices, lhs, rhs):
    return sparse_ops.bcoo_rdot_general(lhs, data, indices, rhs_shape=rhs.shape,
                                        dimension_numbers=dimension_numbers)

  self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
  self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)

def test_bcoo_dot_general_ad(self, lhs_shape, rhs_shape, dtype,
                             dimension_numbers, n_batch, n_dense):
  rng = jtu.rand_small(self.rng())
  rng_sparse = rand_sparse(self.rng())

  X = rng_sparse(lhs_shape, dtype)
  data, indices = sparse_ops.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
  Y = rng(rhs_shape, dtype)

  def f_dense(Y):
    return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

  def f_sparse(Y):
    return sparse_ops.bcoo_dot_general(data, indices, Y, lhs_shape=X.shape,
                                       dimension_numbers=dimension_numbers)

  jf_dense = jax.jacfwd(f_dense)(Y)
  jr_dense = jax.jacrev(f_dense)(Y)
  jf_sparse = jax.jacfwd(f_sparse)(Y)
  jr_sparse = jax.jacrev(f_sparse)(Y)

  tol = {}
  if jtu.device_under_test() == "tpu":
    tol = {np.float32: 5E-3}

  self.assertAllClose(jf_dense, jf_sparse, rtol=tol)
  self.assertAllClose(jr_dense, jr_sparse, rtol=tol)
  self.assertAllClose(jf_sparse, jr_sparse, rtol=tol)

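# Illustrative sketch: jax.jacfwd and jax.jacrev compute the same Jacobian via
# forward- and reverse-mode autodiff, which is the agreement the sparse/dense
# comparison above relies on. The function g is hypothetical.
import jax
import jax.numpy as jnp

g = lambda y: jnp.sin(y) * y
y = jnp.arange(3.0)
assert jnp.allclose(jax.jacfwd(g)(y), jax.jacrev(g)(y))
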
def test_bcoo_dot_general_partial_batch(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers, n_batch, n_dense):
  rng = jtu.rand_small(self.rng())
  rng_sparse = rand_sparse(self.rng())

  X = rng_sparse(lhs_shape, dtype)
  data, indices = sparse_ops.bcoo_fromdense(X, n_batch=n_batch, n_dense=n_dense)
  Y = rng(rhs_shape, dtype)

  def f_dense(X, Y):
    return lax.dot_general(X, Y, dimension_numbers=dimension_numbers)

  def f_sparse(data, indices, Y):
    return sparse_ops.bcoo_dot_general(data, indices, Y, lhs_shape=X.shape,
                                       dimension_numbers=dimension_numbers)

  for data, indices in itertools.product([data, data[:1]],
                                         [indices, indices[:1]]):
    X = sparse_ops.bcoo_todense(data, indices, shape=X.shape)
    self.assertAllClose(f_dense(X, Y), f_sparse(data, indices, Y))

def testConvGrad(self, lhs_shape, rhs_shape, dtype, strides, padding):
  rng = jtu.rand_small(self.rng())
  lhs = rng(lhs_shape, dtype)
  rhs = rng(rhs_shape, dtype)
  conv = partial(lax.conv, window_strides=strides, padding=padding,
                 precision=lax.Precision.HIGHEST)
  check_grads_bilinear(conv, (lhs, rhs), order=2, modes=["fwd", "rev"],
                       atol=1e-2, rtol=1e-2)

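# Illustrative sketch: lax.conv takes NCHW inputs and OIHW kernels; the grad
# checks above differentiate it in both forward and reverse mode. The shapes
# here are hypothetical.
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 3, 8, 8))   # (batch, features, height, width)
rhs = jnp.ones((4, 3, 3, 3))   # (out_features, in_features, kh, kw)
out = lax.conv(lhs, rhs, window_strides=(1, 1), padding='SAME')
assert out.shape == (1, 4, 8, 8)
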
def test_custom_jvp():
  @custom_jvp
  def f(x):
    return x ** 2

  # A custom_jvp rule takes (primals, tangents) and returns
  # (primal_out, tangent_out).
  f.defjvp(lambda primals, tangents: (primals[0] ** 2,
                                      2 * primals[0] * tangents[0]))

  rng = jtu.rand_small(np.random)
  tu.check_lazy_fun(f, rng((1,), 'float32'))

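# Illustrative sketch: the common custom_jvp pattern uses a separately named
# JVP rule registered with defjvp; jax.grad then uses the custom rule instead
# of tracing through the primal function.
import jax

@jax.custom_jvp
def square(x):
  return x ** 2

@square.defjvp
def square_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  return square(x), 2 * x * t

assert float(jax.grad(square)(3.0)) == 6.0
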
def testPadGrad(self, shape, dtype, pads):
  rng = jtu.rand_small(self.rng())

  operand = rng(shape, dtype)
  pad = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
  check_grads(pad, (operand,), 2, ["fwd", "rev"], eps=1.)

  operand = rng(shape, dtype)
  padding_value = np.array(0., dtype)
  pad = lambda operand, padding_value: lax.pad(operand, padding_value, pads)
  check_grads(pad, (operand, padding_value), 2, ["fwd", "rev"], eps=1.)

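# Illustrative sketch: lax.pad takes one (low, high, interior) triple per
# dimension, which is what `pads` holds in the test above. The values here are
# hypothetical.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0)
# one low pad, two high pads, one interior element between neighbours
padded = lax.pad(x, jnp.float32(0), [(1, 2, 1)])
assert padded.shape == (8,)
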
def testReduceWindow(self, op, init_val, dtype, shape, dims, strides, padding,
                     base_dilation, window_dilation):
  rng = jtu.rand_small(self.rng())
  init_val = np.asarray(init_val, dtype=dtype)

  def fun(operand):
    return lax.reduce_window(operand, init_val, op, dims, strides, padding,
                             base_dilation, window_dilation)

  for bdims in all_bdims(shape):
    self._CheckBatching(fun, 3, bdims, (shape,), (dtype,), rng)

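# Illustrative sketch: reduce_window slides a window over the operand and
# reduces each window with the given monoid; max with a 2x2 window and stride 2
# is plain max pooling. The shapes are hypothetical.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(16.0).reshape(4, 4)
pooled = lax.reduce_window(x, -jnp.inf, lax.max,
                           window_dimensions=(2, 2), window_strides=(2, 2),
                           padding='VALID')
assert pooled.shape == (2, 2)
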
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                   dimension_numbers, bdims):
  rng = jtu.rand_small(self.rng())
  dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
  self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                      rng)

  # Checks that batching didn't introduce any transposes or broadcasts.
  jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                              np.zeros(rhs_shape, dtype))
  for eqn in jtu.iter_eqns(jaxpr.jaxpr):
    # Compare primitive names (strings) rather than Primitive objects, so the
    # membership test is meaningful.
    self.assertFalse(eqn.primitive.name in ["transpose", "broadcast"])

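# Illustrative sketch: jax.make_jaxpr exposes the traced equations, so a test
# can assert on which primitives appear after a transformation such as vmap.
import jax
import jax.numpy as jnp

jaxpr = jax.make_jaxpr(lambda a, b: a @ b)(jnp.ones((2, 3)), jnp.ones((3, 4)))
prim_names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
assert "dot_general" in prim_names
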
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers):
  rng = jtu.rand_small(self.rng())
  lhs = rng(lhs_shape, dtype)
  rhs = rng(rhs_shape, dtype)
  dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                        precision=lax.Precision.HIGHEST)
  check_grads_bilinear(dot_general, (lhs, rhs), order=2, modes=["fwd", "rev"])

  # Check that the precision config is preserved in the transposed (VJP) jaxpr.
  result, pullback = api.vjp(dot_general, lhs, rhs)
  gresult = lax.zeros_like_array(result)
  s = str(api.make_jaxpr(pullback)(gresult))
  assert "precision=HIGHEST" in s

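# Illustrative sketch: jax.vjp returns the primal output together with a
# pullback closure; stringifying the pullback's jaxpr is how the test above
# checks that the precision setting survives transposition.
import jax
import jax.numpy as jnp

y, pullback = jax.vjp(jnp.sin, jnp.arange(3.0))
(ct,) = pullback(jnp.ones_like(y))
assert jnp.allclose(ct, jnp.cos(jnp.arange(3.0)))
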
def testSelectAndScatterAdd(self, dtype, padding, shape, dims, strides):
  rng = jtu.rand_small(self.rng())
  pads = lax.padtype_to_pads(shape, dims, strides, padding)

  def fun(operand, cotangents):
    return lax._select_and_scatter_add(operand, cotangents, lax.ge_p, dims,
                                       strides, pads)

  ones = (1,) * len(shape)
  cotangent_shape = api.eval_shape(
      lambda x: lax._select_and_gather_add(x, x, lax.ge_p, dims, strides,
                                           pads, ones, ones),
      np.ones(shape, dtype)).shape

  for bdims in all_bdims(cotangent_shape, shape):
    self._CheckBatching(fun, 3, bdims, (cotangent_shape, shape),
                        (dtype, dtype), rng)

def testSelectAndGatherAdd(self, dtype, padding):
  if jtu.device_under_test() == "tpu" and dtype == dtypes.bfloat16:
    raise SkipTest("bfloat16 _select_and_gather_add doesn't work on tpu")

  rng = jtu.rand_small(self.rng())
  all_configs = itertools.chain(
      itertools.product(
          [(4, 6)],
          [(2, 1), (1, 2)],
          [(1, 1), (2, 1), (1, 2)]),
      itertools.product(
          [(3, 2, 4, 6)],
          [(1, 1, 2, 1), (2, 1, 2, 1)],
          [(1, 2, 2, 1), (1, 1, 1, 1)]))

  def fun(operand, tangents):
    pads = lax.padtype_to_pads(operand.shape, dims, strides, padding)
    ones = (1,) * len(operand.shape)
    return lax._select_and_gather_add(operand, tangents, lax.ge_p, dims,
                                      strides, pads, ones, ones)

  for shape, dims, strides in all_configs:
    for bdims in all_bdims(shape, shape):
      self._CheckBatching(fun, 3, bdims, (shape, shape), (dtype, dtype), rng)

def test_bcoo_dot_general_contract_only(self, lhs_shape, rhs_shape, dtype,
                                        lhs_contracting, rhs_contracting,
                                        n_dense):
  rng = jtu.rand_small(self.rng())
  rng_sparse = rand_sparse(self.rng())

  def args_maker():
    lhs = rng_sparse(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    data, indices = sparse.bcoo_fromdense(lhs, n_dense=n_dense)
    return data, indices, lhs, rhs

  dimension_numbers = ((lhs_contracting, rhs_contracting), ([], []))

  def f_dense(data, indices, lhs, rhs):
    return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

  def f_sparse(data, indices, lhs, rhs):
    return sparse.bcoo_dot_general(data, indices, rhs, lhs_shape=lhs.shape,
                                   dimension_numbers=dimension_numbers)

  self._CheckAgainstNumpy(f_dense, f_sparse, args_maker)
  self._CheckAgainstNumpy(f_dense, jit(f_sparse), args_maker)

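# Illustrative sketch using the public jax.experimental.sparse API (the tests
# above call bcoo_* helpers that take explicit data/indices arrays): a BCOO
# matrix stores (data, indices) and supports products against dense operands.
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

M = jnp.array([[0., 1.], [2., 0.]])
M_sp = jsparse.BCOO.fromdense(M)
v = jnp.array([1., 2.])
assert jnp.allclose(M_sp @ v, M @ v)
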
def testReduce(self, op, init_val, shape, dtype, dims, bdims):
  rng = jtu.rand_small(self.rng())
  init_val = np.asarray(init_val, dtype=dtype)
  fun = lambda operand: lax.reduce(operand, init_val, op, dims)
  self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

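# Illustrative sketch: lax.reduce folds the operand over the given dimensions
# with a binary computation and an identity element, here a plain sum.
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)
total = lax.reduce(x, 0.0, lax.add, dimensions=(0, 1))
assert float(total) == 15.0
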
def test_jit():
  rng = jtu.rand_small(np.random)
  tu.check_lazy_fun(jit(lambda x: x * 2), rng((1,), int))

def testPad(self, shape, dtype, pads, bdims):
  rng = jtu.rand_small(self.rng())
  fun = lambda operand: lax.pad(operand, np.array(0, dtype), pads)
  self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

def test_jit_freevar():
  rng = jtu.rand_small(np.random)
  tu.check_lazy_fun(lambda x, y: jit(lambda x: x * y)(x),
                    rng((1,), int), rng((1,), int))

op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5), ["rev"]), op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]), ] JAX_COMPOUND_OP_RECORDS = [ op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]), op_record("exp2", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_positive(), [], test_name="expm1_large"), op_record("expm1", 1, number_dtypes, all_shapes, jtu.rand_small_positive(), []), op_record("floor_divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
op_record("logical_or", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("logical_xor", 2, default_dtypes, all_shapes, jtu.rand_bool(), []), op_record("maximum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("minimum", 2, default_dtypes, all_shapes, jtu.rand_some_inf(), []), op_record("multiply", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("negative", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("not_equal", 2, default_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]), op_record("power", 2, float_dtypes, all_shapes, jtu.rand_positive(), ["rev"]), op_record("subtract", 2, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sin", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cos", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tan", 1, default_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("sinh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("cosh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("tanh", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("arcsin", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arccos", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctan", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctan2", 2, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arcsinh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arccosh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), op_record("arctanh", 1, default_dtypes, all_shapes, jtu.rand_small(), ["rev"]), ] JAX_COMPOUND_OP_RECORDS = [ op_record("divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]), op_record("exp2", 1, numeric_dtypes, all_shapes, jtu.rand_default(), ["rev"]), op_record("expm1", 1, numeric_dtypes, all_shapes, jtu.rand_positive(), [], test_name="expm1_large"), op_record("expm1", 1, numeric_dtypes, all_shapes, jtu.rand_small_positive(), []), op_record("floor_divide", 2, default_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),
def op_record(name, nargs, dtypes, rng, test_grad, test_name=None):
  test_name = test_name or name
  return OpRecord(name, nargs, dtypes, rng, test_grad, test_name)


JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(), True),
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
]

CombosWithReplacement = itertools.combinations_with_replacement


class LaxBackedScipyTests(jtu.JaxTestCase):
  """Tests for LAX-backed Scipy implementation."""

  def _GetArgsMaker(self, rng, shapes, dtypes):
    return lambda: [rng(shape, dtype) for shape, dtype in zip(shapes, dtypes)]

  @parameterized.named_parameters(