Ejemplo n.º 1
0
    def test_dtype_warning(self):
        """Explicit 64-bit dtype requests must warn when x64 is disabled.

        Regression test for issue #1230. For each pair of callables, the
        ``warn`` one explicitly requests a 64-bit dtype (float64/int64) and
        must emit a warning whose message starts with
        "Explicitly requested dtype ", while the ``nowarn`` one (32-bit or
        default dtype) must emit nothing.
        """
        # cf. issue #1230
        if FLAGS.jax_enable_x64:
            # Skip (rather than return) so the runner reports that nothing
            # was checked instead of a silent pass.
            self.skipTest("test only applies when x64 is disabled")

        def check_warning(warn, nowarn):
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")

                nowarn()  # get rid of extra startup warning

                prev_len = len(w)
                nowarn()
                self.assertEqual(len(w), prev_len)

                warn()
                # Compare against prev_len, not 0: a warning recorded earlier
                # (e.g. the startup one) must not satisfy this on warn()'s
                # behalf.
                self.assertGreater(len(w), prev_len)
                msg = str(w[-1].message)
                expected_prefix = "Explicitly requested dtype "
                self.assertEqual(expected_prefix, msg[:len(expected_prefix)])

                prev_len = len(w)
                nowarn()
                self.assertEqual(len(w), prev_len)

        check_warning(
            lambda: np.array([1, 2, 3], dtype="float64"),
            lambda: np.array([1, 2, 3], dtype="float32"),
        )
        check_warning(lambda: np.ones(3, dtype=onp.float64),
                      lambda: np.ones(3))
        check_warning(lambda: np.ones_like(3, dtype=onp.int64),
                      lambda: np.ones_like(3, dtype=onp.int32))
        check_warning(lambda: np.zeros(3, dtype="int64"),
                      lambda: np.zeros(3, dtype="int32"))
        check_warning(lambda: np.zeros_like(3, dtype="float64"),
                      lambda: np.zeros_like(3, dtype="float32"))
        check_warning(lambda: np.full((2, 3), 1, dtype="int64"),
                      lambda: np.full((2, 3), 1))
        check_warning(lambda: np.ones(3).astype("float64"),
                      lambda: np.ones(3).astype("float32"))
        check_warning(lambda: np.eye(3, dtype=onp.float64), lambda: np.eye(3))
        check_warning(lambda: np.arange(3, dtype=onp.float64),
                      lambda: np.arange(3, dtype=onp.float32))
        check_warning(lambda: np.linspace(0, 3, dtype=onp.float64),
                      lambda: np.linspace(0, 3, dtype=onp.float32))
        check_warning(lambda: np.tri(2, dtype="float64"),
                      lambda: np.tri(2, dtype="float32"))
Ejemplo n.º 2
0
def make_cholesky_factor(l_param: np.ndarray) -> np.ndarray:
    """Build the lower-triangular Cholesky factor from its parameterization.

    Entries strictly below the diagonal are copied from ``l_param``.
    Entries above the diagonal are zeroed. Each diagonal entry ``d`` is
    replaced by ``exp(d)`` so it is strictly positive. The top-left entry is
    masked to 0 before exponentiation, so the result's ``[0, 0]`` is always
    ``exp(0) == 1`` regardless of ``l_param[0, 0]``.

    ``l_param`` presumably is a square matrix; it is multiplied elementwise
    with an ``(n, n)`` mask where ``n = l_param.shape[0]``.
    """
    size = l_param.shape[0]
    # Lower-triangular (diagonal included) mask with the (0, 0) slot
    # switched off.
    mask = index_update(np.tri(size), (0, 0), 0)
    lower = l_param * mask
    diag = np.diag_indices(size)
    positive_diag = np.exp(lower[diag])
    return index_update(lower, diag, positive_diag)
Ejemplo n.º 3
0
 def testTri(self, m, n, k, dtype, rng):
     """Compare ``lnp.tri`` with ``onp.tri`` and check that it compiles.

     ``tri`` takes no array arguments, so ``args_maker`` yields nothing and
     ``rng`` (presumably supplied by the test parameterization) is unused.
     """
     def onp_fun():
         return onp.tri(n, M=m, k=k, dtype=dtype)

     def lnp_fun():
         return lnp.tri(n, M=m, k=k, dtype=dtype)

     def args_maker():
         return []

     self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
     self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
def fill_lower_diag(a, dtype=int):
    """
    Fill the strictly-lower triangle of a square matrix from a vector.

    The elements of ``a`` are written, in row-major order, into the positions
    below the main diagonal; the diagonal and the upper triangle stay zero.
    Source: https://stackoverflow.com/questions/51439271/convert-1d-array-to-lower-triangular-matrix

    :param a: 1D numpy array (or sequence) whose length is a triangular
        number ``n * (n - 1) // 2`` for some ``n >= 1``.
    :param dtype: dtype of the returned matrix. Defaults to ``int`` for
        backward compatibility. Values of ``a`` are cast to it, so pass e.g.
        ``float`` to avoid truncating real-valued input.
    :return: ``(n, n)`` numpy array holding ``a`` in its strictly-lower
        triangle.
    :raises ValueError: if ``len(a)`` is not a triangular number.
    """
    size = len(a)
    # size == n*(n-1)/2  =>  sqrt(2*size) lies in [n-1, n), so floor + 1 == n.
    n = int(np.sqrt(size * 2)) + 1
    if n * (n - 1) // 2 != size:
        raise ValueError(
            f"len(a)={size} is not a triangular number n*(n-1)/2; "
            "cannot fill a strictly-lower triangle exactly")
    mask = np.tri(n, dtype=bool, k=-1)  # or np.arange(n)[:,None] > np.arange(n)
    out = np.zeros((n, n), dtype=dtype)
    out[mask] = a
    return out
Ejemplo n.º 5
0
def init_params(key: np.ndarray) -> Params:
    """Initialize the optimization parameters from the PRNG ``key``.

    Let ``dim = FLAGS.dim_theta``. The returned tuple ``(L, mu, log_sigma)``
    holds:
      * ``L``: a ``(dim + 1, dim + 1)`` matrix that is zero on and above the
        diagonal, with ``0.05 * N(0, 1)`` entries strictly below it;
      * ``mu``: a ``(dim,)`` vector, ``0.001 * N(0, 1)`` for "poly", constant
        ``1 / dim`` for "gp", and ``0.01 * N(0, 1)`` otherwise;
      * ``log_sigma``: a ``(dim,)`` vector, ``log(1 / (i + 1))`` for "poly"
        and constant 0.5 otherwise.

    The choice between these is made by ``FLAGS.response_type``. An assertion
    checks that the Cholesky factor of the correlation matrix built from
    ``L`` matches ``make_cholesky_factor(L)``.
    """
    dim = FLAGS.dim_theta
    size = dim + 1

    key, l_key = random.split(key)
    # init diagonal at 0, because it will be exponentiated
    L = 0.05 * np.tri(size, k=-1)
    L *= random.normal(l_key, (size, size))
    corr = make_correlation_matrix(L)
    assert np.all(np.isclose(np.linalg.cholesky(corr),
                             make_cholesky_factor(L))), "not PSD"

    key, mu_key = random.split(key)
    response = FLAGS.response_type
    if response == "poly":
        mu = 0.001 * random.normal(mu_key, (dim, ))
        log_sigma = np.array([np.log(1. / (i + 1)) for i in range(dim)])
    else:
        # "gp" and the fallback share the same constant log-scale init.
        log_sigma = 0.5 * np.ones(dim)
        if response == "gp":
            mu = np.ones(dim) / dim
        else:
            mu = 0.01 * random.normal(mu_key, (dim, ))
    return L, mu, log_sigma
Ejemplo n.º 6
0
def log_all_cubes_unique_intersect_volume(points, width):
    """Log of the summed pairwise intersection terms over distinct pairs.

    Evaluates ``log_cubes_intersect_volume(points[i], points[j], width)`` for
    every ordered pair ``(i, j)``. It keeps only the entries with ``i > j``,
    so each unordered pair of distinct points is counted once, and returns
    their ``logsumexp``. Entries on or above the diagonal are replaced by
    ``-inf`` so that they contribute nothing to the sum.

    NOTE(review): ``log_cubes_intersect_volume`` presumably returns the log
    volume of the overlap of two cubes of side ``width`` centred at its
    first two arguments; confirm against its definition.
    """
    def against_all(x):
        return vmap(lambda y: log_cubes_intersect_volume(x, y, width))(points)

    # pairwise[i, j] = log_cubes_intersect_volume(points[i], points[j], width)
    pairwise = vmap(against_all)(points)
    strictly_lower = jnp.tri(points.shape[0], k=-1)
    masked = jnp.where(strictly_lower, pairwise, -jnp.inf)
    return logsumexp(masked)
Ejemplo n.º 7
0
def tri(N, M=None, k=0, dtype=None):
  """Return ``jnp.tri(N, M=M, k=k, dtype=dtype)`` wrapped in a ``JaxArray``.

  All arguments are forwarded unchanged to ``jax.numpy.tri``.
  """
  ones_at_and_below_k = jnp.tri(N, M=M, k=k, dtype=dtype)
  return JaxArray(ones_at_and_below_k)