Beispiel #1
0
    def testQdwhWithRandomMatrix(self, m, n, log_cond):
        """Checks qdwh on a random (m, n) matrix with a prescribed spectrum.

        The input is drawn uniformly from [0.3, 0.9]; its singular values are
        then replaced by values spaced linearly from ``10**log_cond`` down to 1
        while the singular vectors of the random draw are kept.
        """
        sampler = jtu.rand_uniform(self.rng(), low=0.3, high=0.9)
        matrix = sampler((m, n), _QDWH_TEST_DTYPE)

        # Only the singular vectors of the random draw are reused; the
        # singular values are substituted below.
        left, _, right = jnp.linalg.svd(matrix, full_matrices=False)
        spectrum = jnp.linspace(10**log_cond, 1, min(m, n))
        matrix = (left * spectrum) @ right

        symmetric = _check_symmetry(matrix)
        iteration_cap = 10

        def lsp_linalg_fn(x):
            # Only the first two outputs of qdwh are compared; the remaining
            # two are discarded.
            polar_u, polar_h, _, _ = qdwh.qdwh(
                x, is_symmetric=symmetric, max_iterations=iteration_cap)
            return polar_u, polar_h

        def args_maker():
            return [matrix]

        # Relative tolerance used for the comparison against the reference.
        rtol = 1E6 * _QDWH_TEST_EPS

        with self.subTest('Test JIT compatibility'):
            self._CompileAndCheck(lsp_linalg_fn, args_maker)

        with self.subTest('Test against numpy.'):
            self._CheckAgainstNumpy(
                osp_linalg.polar,
                lsp_linalg_fn,
                args_maker,
                rtol=rtol,
                atol=1E-3,
            )
Beispiel #2
0
 def binary_check(self,
                  fun,
                  lims=[-2, 2],
                  order=3,
                  finite=True,
                  dtype=None):
     """Runs a jet check of the two-argument function ``fun`` on random input.

     Args:
       fun: the binary function under test.
       lims: either a single ``[low, high]`` pair applied to both arguments,
         or a *tuple* ``(x_lims, y_lims)`` giving one pair per argument. The
         default is deliberately a list (it is never mutated): a tuple is
         interpreted as per-argument limits.
       order: number of series terms generated for each argument.
       finite: when true the comparison is done with ``self.check_jet``,
         otherwise with ``self.check_jet_finite``.
       dtype: when given, primals and series terms are drawn with
         ``jtu.rand_uniform`` in that dtype; when ``None`` they are drawn
         directly from a seeded ``np.random.RandomState``.
     """
     dims = 2, 3
     rng = np.random.RandomState(0)
     if isinstance(lims, tuple):
         x_lims, y_lims = lims
     else:
         x_lims, y_lims = lims, lims
     if dtype is None:
         primal_in = (transform(x_lims, rng.rand(*dims)),
                      transform(y_lims, rng.rand(*dims)))
         series_in = ([rng.randn(*dims) for _ in range(order)],
                      [rng.randn(*dims) for _ in range(order)])
     else:
         # One sampler per argument so that per-argument limits are honoured
         # here as well. Both samplers share `rng`, and the draws happen in
         # the order x primal, y primal, x series, y series.
         x_rng = jtu.rand_uniform(rng, *x_lims)
         y_rng = jtu.rand_uniform(rng, *y_lims)
         primal_in = (x_rng(dims, dtype), y_rng(dims, dtype))
         series_in = ([x_rng(dims, dtype) for _ in range(order)],
                      [y_rng(dims, dtype) for _ in range(order)])
     if finite:
         self.check_jet(fun, primal_in, series_in, atol=1e-4, rtol=1e-4)
     else:
         self.check_jet_finite(fun,
                               primal_in,
                               series_in,
                               atol=1e-4,
                               rtol=1e-4)
Beispiel #3
0
 def unary_check(self, fun, lims=[-2, 2], order=3, dtype=None):
   """Runs ``self.check_jet`` on the one-argument function ``fun``.

   A (2, 3) primal and ``order`` series terms are generated from a
   ``RandomState`` seeded with 0. With ``dtype`` set they are drawn through
   ``jtu.rand_uniform`` using ``lims`` as its bounds; otherwise the primal is
   ``transform(lims, ...)`` of a uniform draw (presumably a rescaling into
   ``lims`` -- verify against ``transform``) and the terms are normal draws.
   """
   shape = (2, 3)
   state = np.random.RandomState(0)
   if dtype is not None:
     sample = jtu.rand_uniform(state, *lims)
     primal = sample(shape, dtype)
     terms = [sample(shape, dtype) for _ in range(order)]
   else:
     primal = transform(lims, state.rand(*shape))
     terms = [state.randn(*shape) for _ in range(order)]
   self.check_jet(fun, (primal,), (terms,), atol=1e-4, rtol=1e-4)
Beispiel #4
0
    def testResizeAgainstPIL(self, dtype, image_shape, target_shape, method):
        """Compares ``image.resize`` (antialiasing on) with PIL's resize."""
        sampler = jtu.rand_uniform(self.rng())

        def args_maker():
            return (sampler(image_shape, dtype), )

        def pil_fn(x):
            # Maps the resize method names used by this test to PIL filters.
            resample_filters = {
                "bilinear": PIL_Image.BILINEAR,
                "bicubic": PIL_Image.BICUBIC,
                "lanczos3": PIL_Image.LANCZOS,
            }
            # The input is handed to PIL as float32 and the result is cast
            # back to the dtype under test.
            pil_image = PIL_Image.fromarray(x.astype(np.float32))
            # The target shape is passed reversed (presumably because PIL
            # expects (width, height) -- confirm against the Pillow docs).
            resized = pil_image.resize(target_shape[::-1],
                                       resample_filters[method])
            return np.asarray(resized, dtype=dtype)

        jax_fn = partial(
            image.resize,
            shape=target_shape,
            method=method,
            antialias=True,
        )
        self._CheckAgainstNumpy(pil_fn, jax_fn, args_maker, check_dtypes=True)
Beispiel #5
0
    test_name = test_name or name
    return OpRecord(name, nargs, dtypes, rng, test_grad, test_name)


# Table of special functions to test, one op_record per function. The
# positional arguments below appear to be: function name, argument count,
# dtypes, random-input generator and a gradient-test flag (op_record forwards
# them to OpRecord as name, nargs, dtypes, rng, test_grad) -- TODO confirm
# against the op_record signature.
JAX_SPECIAL_FUNCTION_RECORDS = [
    # TODO: digamma has no JVP implemented, so its last argument is False.
    op_record("digamma", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("erf", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfc", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("erfinv", 1, float_dtypes, jtu.rand_small_positive(), True),
    op_record("expit", 1, float_dtypes, jtu.rand_small_positive(), True),
    # TODO: gammaln has slightly high error, so its last argument is False.
    op_record("gammaln", 1, float_dtypes, jtu.rand_positive(), False),
    op_record("logit", 1, float_dtypes, jtu.rand_small_positive(), False),
    op_record("log_ndtr", 1, float_dtypes, jtu.rand_small(), True),
    # ndtri draws its inputs with rand_uniform(0., 1.).
    op_record("ndtri", 1, float_dtypes, jtu.rand_uniform(0., 1.), True),
    op_record("ndtr", 1, float_dtypes, jtu.rand_default(), True),
]

# Short alias for itertools.combinations_with_replacement.
CombosWithReplacement = itertools.combinations_with_replacement


class LaxBackedScipyTests(jtu.JaxTestCase):
    """Tests for LAX-backed Scipy implementation."""
    def _GetArgsMaker(self, rng, shapes, dtypes):
        """Returns a zero-argument callable that builds the test arguments.

        Each call of the returned function yields a fresh list holding
        ``rng(shape, dtype)`` for every pair taken from ``shapes`` and
        ``dtypes`` in lockstep.
        """
        def args_maker():
            return [rng(*spec) for spec in zip(shapes, dtypes)]

        return args_maker

    @parameterized.named_parameters(
        jtu.cases_from_list({
Beispiel #6
0
    op_record("less_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), []),
    op_record("log", 1, number_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("logical_and", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_not", 1, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_or", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("logical_xor", 2, all_dtypes, all_shapes, jtu.rand_bool(), []),
    op_record("maximum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("minimum", 2, number_dtypes, all_shapes, jtu.rand_some_inf(), []),
    op_record("multiply", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("negative", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("not_equal", 2, number_dtypes, all_shapes, jtu.rand_some_equal(), ["rev"]),
    op_record("power", 2, inexact_dtypes, all_shapes, jtu.rand_positive(), ["rev"]),
    op_record("subtract", 2, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("sin", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cos", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tan", 1, number_dtypes, all_shapes, jtu.rand_uniform(-1.5, 1.5),
              ["rev"]),
    op_record("sinh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("cosh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("tanh", 1, number_dtypes, all_shapes, jtu.rand_default(), ["rev"]),
    op_record("arcsin", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccos", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan", 1, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctan2", 2, float_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arcsinh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arccosh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
    op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small(), ["rev"]),
]

JAX_COMPOUND_OP_RECORDS = [
    op_record("divide", 2, number_dtypes, all_shapes, jtu.rand_nonzero(), ["rev"]),