def test_argminmax(op, rng_factory, shape, dtype, axis, index_dtype):
  """Check lazy evaluation of the ``lax`` reduction named by ``op``.

  ``op`` is looked up on ``lax`` (presumably ``argmin``/``argmax``) and is
  applied along ``axis``, producing indices of type ``index_dtype``.
  """
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)

  def reduce_to_index(x):
    reducer = getattr(lax, op)
    return reducer(x, axis, index_dtype)

  tu.check_lazy_fun(reduce_to_index, operand)
def test_dot_general_contract_and_batch(lhs_shape, rhs_shape, dimension_numbers, dtype, rng_factory):
  """Check lazy ``lax.dot_general`` with both contracting and batch dims."""
  sample = rng_factory(np.random)
  # Draw lhs before rhs so the random stream matches the generated shapes.
  lhs = sample(lhs_shape, dtype)
  rhs = sample(rhs_shape, dtype)
  general_dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
  tu.check_lazy_fun(general_dot, lhs, rhs, atol=1e-5)
def test_custom_jvp():
  """Check lazy evaluation of a ``custom_jvp``-decorated square function."""
  @custom_jvp
  def f(x):
    return x**2

  # A custom_jvp rule must take ``(primals, tangents)`` and return
  # ``(primal_out, tangent_out)``. The previous one-argument lambda did not
  # match that signature and would raise as soon as ``f`` was differentiated.
  def f_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return x**2, 2 * x * x_dot

  f.defjvp(f_jvp)

  rng = jtu.rand_small(np.random)
  tu.check_lazy_fun(f, rng((1, ), 'float32'))
def test_concatenate(dim, base_shape, dtype, num_arrs, rng_factory):
  """Check lazy ``lax.concatenate`` of ``num_arrs`` operands along ``dim``.

  The operands differ only in their ``dim`` extent, which cycles through
  3, 1, 4, ...; all other extents come from ``base_shape``.
  """
  sample = rng_factory(np.random)
  cycle_sizes = [3, 1, 4]
  prefix, suffix = base_shape[:dim], base_shape[dim + 1:]
  shapes = [prefix + (cycle_sizes[i % len(cycle_sizes)], ) + suffix
            for i in range(num_arrs)]
  operands = [sample(s, dtype) for s in shapes]

  def concat(*xs):
    return lax.concatenate(xs, dim)

  tu.check_lazy_fun(concat, *operands)
def test_dot_general_contract_only(lhs_shape, rhs_shape, dtype, lhs_contracting, rhs_contracting, rng_factory):
  """Check lazy ``lax.dot_general`` with contracting dims and no batch dims."""
  sample = rng_factory(np.random)
  lhs = sample(lhs_shape, dtype)
  rhs = sample(rhs_shape, dtype)
  contracting = (lhs_contracting, rhs_contracting)
  no_batch = ((), ())
  general_dot = partial(lax.dot_general, dimension_numbers=(contracting, no_batch))
  tu.check_lazy_fun(general_dot, lhs, rhs, atol=1e-5)
def test_select(pred_shape, arg_shape, arg_dtype, rng_factory):
  """Check lazy ``lax.select`` with a boolean predicate and two branches.

  Unlike the previous version, this does not ``return`` the checker's result.
  Pytest warns (PytestReturnNotNoneWarning) on tests that return non-None,
  and every sibling test simply calls ``check_lazy_fun``.
  """
  rng = rng_factory(np.random)
  args = [
      rng(pred_shape, np.bool_),
      rng(arg_shape, arg_dtype),
      rng(arg_shape, arg_dtype),
  ]
  tu.check_lazy_fun(lax.select, *args)
def test_conv_general_dilated(lhs_shape, rhs_shape, dtype, strides, padding, lhs_dilation, rhs_dilation, feature_group_count, batch_group_count, dimension_numbers, perms, rng_factory):
  """Check lazy ``lax.conv_general_dilated`` under the given layout/dilations."""
  sample = rng_factory(np.random)
  lhs_perm, rhs_perm = perms
  # Permute the freshly drawn operands into shapes compatible with
  # ``dimension_numbers``. lhs is drawn (and permuted) before rhs.
  lhs = lax.transpose(sample(lhs_shape, dtype), lhs_perm)
  rhs = lax.transpose(sample(rhs_shape, dtype), rhs_perm)

  def dilated_conv(lhs_arg, rhs_arg):
    return lax.conv_general_dilated(
        lhs_arg, rhs_arg, strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count)

  tu.check_lazy_fun(dilated_conv, lhs, rhs, rtol=.005, atol=.2)
def test_conv(lhs_shape, rhs_shape, dtype, strides, padding, rng_factory):
  """Check lazy ``lax.conv`` with the given strides and padding."""
  sample = rng_factory(np.random)
  lhs = sample(lhs_shape, dtype)
  rhs = sample(rhs_shape, dtype)

  def conv(lhs_arg, rhs_arg):
    return lax.conv(lhs_arg, rhs_arg, strides, padding)

  tu.check_lazy_fun(conv, lhs, rhs)
def test_pad(shape, dtype, padding_config, rng_factory):
  """Check lazy ``lax.pad`` using a random scalar as the padding value."""
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)
  pad_value = sample((), dtype)

  def pad(*operands):
    return lax.pad(*operands, padding_config)

  tu.check_lazy_fun(pad, operand, pad_value)
def test_broadcast_in_dim(inshape, dtype, outshape, dimensions, rng_factory):
  """Check lazy ``lax.broadcast_in_dim`` from ``inshape`` to ``outshape``."""
  sample = rng_factory(np.random)
  operand = sample(inshape, dtype)

  def broadcast(x):
    return lax.broadcast_in_dim(x, outshape, dimensions)

  tu.check_lazy_fun(broadcast, operand)
def test_slice(shape, dtype, starts, limits, strides, rng_factory):
  """Check lazy ``lax.slice`` with explicit start/limit/stride indices."""
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)

  def take_slice(x):
    return lax.slice(x, starts, limits, strides)

  tu.check_lazy_fun(take_slice, operand)
def test_reduce(op_name, rng_factory, shape, axes, dtype, tol):
  """Check lazy evaluation of the ``lax.lax`` reduction named ``op_name``."""
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)
  reduction = getattr(lax.lax, op_name)
  tu.check_lazy_fun(partial(reduction, axes=axes), operand, atol=tol, rtol=tol)
def test_jit():
  """Check lazy evaluation of a jitted doubling function."""
  sample = jtu.rand_small(np.random)
  doubled = jit(lambda v: v * 2)
  operand = sample((1, ), int)
  tu.check_lazy_fun(doubled, operand)
def test_dot(lhs_shape, rhs_shape, dtype, rng_factory):
  """Check lazy ``lax.dot`` on two random operands."""
  sample = rng_factory(np.random)
  lhs = sample(lhs_shape, dtype)
  rhs = sample(rhs_shape, dtype)
  tu.check_lazy_fun(lax.dot, lhs, rhs)
def test_rev(shape, dtype, dimensions, rng_factory):
  """Check lazy ``lax.rev`` reversing along ``dimensions``."""
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)

  def reverse(x):
    return lax.rev(x, dimensions=dimensions)

  tu.check_lazy_fun(reverse, operand)
def test_transpose(shape, dtype, permutation, rng_factory):
  """Check lazy ``lax.transpose`` with the given axis ``permutation``."""
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)

  def permute(x):
    return lax.transpose(x, permutation=permutation)

  tu.check_lazy_fun(permute, operand)
def test_squeeze(shape, dtype, dimensions, rng_factory):
  """Check lazy ``lax.squeeze`` removing ``dimensions``."""
  sample = rng_factory(np.random)
  operand = sample(shape, dtype)

  def squeeze(x):
    return lax.squeeze(x, dimensions)

  tu.check_lazy_fun(squeeze, operand)
def test_jit_freevar():
  """Check lazy evaluation of a jitted closure that captures a free variable."""
  sample = jtu.rand_small(np.random)
  # x is drawn before y, matching the argument order below.
  x_val = sample((1, ), int)
  y_val = sample((1, ), int)

  def scale_by_captured(x, y):
    # The inner function closes over ``y`` rather than taking it as an input.
    return jit(lambda z: z * y)(x)

  tu.check_lazy_fun(scale_by_captured, x_val, y_val)
def test_nary(op_name, rng_factory, shapes, dtype, tol):
  """Check lazy evaluation of the n-ary ``lax`` op named ``op_name``."""
  sample = rng_factory(np.random)
  operands = [sample(s, dtype) for s in shapes]
  nary_op = getattr(lax, op_name)
  tu.check_lazy_fun(nary_op, *operands, atol=tol, rtol=tol)