Example #1
def create_rnn_params_symmetric(input_size, state_size, output_size,
                                g, random_key):

    input_factor = 1.0 / np.sqrt(input_size)
    hidden_scale = 1.0
    hidden_factor = g / np.sqrt(state_size+1)
    predict_factor = 1.0 / np.sqrt(state_size+1)
    normal_mat = random.normal(random_key,(state_size, state_size)) * hidden_factor
    return {'hidden unit': np.array(random.normal(random_key,(1, state_size)) * hidden_scale),
            'change':       np.concatenate((random.normal(random_key,(input_size, state_size)) * input_factor,
                                            np.tril(normal_mat)+np.transpose(np.tril(normal_mat,-1)),
                                            random.normal(random_key,(1,state_size))*hidden_factor),
                                           axis=0),#hidden weights
            'predict':      random.normal(random_key,(state_size+1, output_size)) * predict_factor}#readout weights
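
The recurrent weights above are symmetrized by mirroring the strictly lower triangle across the diagonal. A minimal check of that trick (a sketch, assuming jax.numpy is imported as np, as in the snippet):

import jax.numpy as np
from jax import random

m = random.normal(random.PRNGKey(0), (4, 4))
sym = np.tril(m) + np.transpose(np.tril(m, -1))
assert np.allclose(sym, np.transpose(sym))  # symmetric by construction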
Example #2
def sdcorr_params_to_sds_and_corr_jax(sdcorr_params):
    dim = number_of_triangular_elements_to_dimension_jax(len(sdcorr_params))
    sds = jnp.array(sdcorr_params[:dim])
    corr = jnp.eye(dim)
    # jax.ops.index_update has been removed from JAX; .at[].set() is the
    # idiomatic equivalent
    corr = corr.at[jnp.tril_indices(dim, k=-1)].set(sdcorr_params[dim:])
    corr += jnp.tril(corr, k=-1).T
    return sds, corr
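
Example #2 stores the standard deviations first, then the strictly lower-triangular correlations in C-order. A worked 2x2 case (a sketch; dim is hard-coded because the helper that infers it from the vector length is not shown):

import jax.numpy as jnp

sdcorr_params = jnp.array([2.0, 3.0, 0.5])  # [sd_1, sd_2, corr_21]
dim = 2
sds = sdcorr_params[:dim]
corr = jnp.eye(dim).at[jnp.tril_indices(dim, k=-1)].set(sdcorr_params[dim:])
corr = corr + jnp.tril(corr, k=-1).T
# corr == [[1.0, 0.5],
#          [0.5, 1.0]]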
Example #3
def crossover(parent_1, parent_2, offspring_size):
    all_offspring = []
    for o in range(offspring_size):
        lower_1 = np.tril(parent_1)
        upper_2 = np.triu(parent_2)
        offspring = lower_1 + upper_2
        all_offspring.append(offspring)
    return all_offspring
Example #4
 def custom_assert(tst, result_jax, result_tf, *, tol, err_msg, **_):
     # cholesky_p returns garbage in the strictly upper triangular part of the
     # result, so we can safely ignore that part.
     tst.assertAllClose(jnp.tril(result_jax),
                        result_tf,
                        atol=tol,
                        err_msg=err_msg)
Example #5
def test_correlated_mvn(regularize):
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.0
    a = jnp.tril(
        0.5 * jnp.fliplr(jnp.eye(D))
        + 0.1 * jnp.exp(random.normal(random.PRNGKey(0), shape=(D, D)))
    )
    true_cov = jnp.dot(a, a.T)
    true_prec = jnp.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * jnp.dot(z.T, jnp.dot(true_prec, z))

    init_params = jnp.zeros(D)
    kernel = NUTS(
        potential_fn=potential_fn, dense_mass=True, regularize_mass_matrix=regularize
    )
    mcmc = MCMC(kernel, num_warmup=warmup_steps, num_samples=num_samples)
    mcmc.run(random.PRNGKey(0), init_params=init_params)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples), true_mean, atol=0.02)
    assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D ** 2 < 0.02
Example #6
        def custom_assert(tst, result_jax, result_tf, *, args, tol, err_msg):
            operand, = args
            lu, pivots, perm = result_tf
            batch_dims = operand.shape[:-2]
            m, n = operand.shape[-2], operand.shape[-1]

            def _make_permutation_matrix(perm):
                result = []
                for idx in itertools.product(*map(range, operand.shape[:-1])):
                    result += [0 if c != perm[idx] else 1 for c in range(m)]
                result = np.reshape(np.array(result, dtype=dtype),
                                    [*batch_dims, m, m])
                return result

            k = min(m, n)
            l = jnp.tril(lu, -1)[..., :, :k] + jnp.eye(m, k, dtype=dtype)
            u = jnp.triu(lu)[..., :k, :]
            p_mat = _make_permutation_matrix(perm)

            tst.assertArraysEqual(
                lax.linalg.lu_pivots_to_permutation(pivots, m), perm)
            tst.assertAllClose(jnp.matmul(p_mat, operand),
                               jnp.matmul(l, u),
                               atol=tol,
                               rtol=tol,
                               err_msg=err_msg)
Example #7
    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads}).")

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        dense = partial(
            nn.Dense,
            self.embed_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )

        self.q_proj = dense(use_bias=False)
        self.k_proj = dense(use_bias=False)
        self.v_proj = dense(use_bias=False)
        self.out_proj = dense()

        self.causal_mask = make_causal_mask(jnp.ones(
            (1, config.max_position_embeddings), dtype="bool"),
                                            dtype="bool")
        if self.attention_type == "local":
            self.causal_mask = self.causal_mask ^ jnp.tril(
                self.causal_mask, -config.window_size)
Example #8
 def __call__(self, x):
     tril = jnp.tril(x)
     lower_triangular = jnp.all(jnp.reshape(tril == x, x.shape[:-2] + (-1,)), axis=-1)
     positive_diagonal = jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0, axis=-1)
     x_norm = jnp.linalg.norm(x, axis=-1)
     unit_norm_row = jnp.all((x_norm <= 1) & (x_norm > 1 - 1e-6), axis=-1)
     return lower_triangular & positive_diagonal & unit_norm_row
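
Example #8 is the membership check for Cholesky factors of correlation matrices: lower triangular, positive diagonal, rows on the unit sphere. The identity matrix passes all three tests (a sketch of the same checks outside the class):

import jax.numpy as jnp

x = jnp.eye(3)
assert jnp.all(jnp.tril(x) == x)                         # lower triangular
assert jnp.all(jnp.diagonal(x, axis1=-2, axis2=-1) > 0)  # positive diagonal
norms = jnp.linalg.norm(x, axis=-1)
assert jnp.all((norms <= 1) & (norms > 1 - 1e-6))        # unit-norm rows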
Example #9
def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)):
    eps = 1e-6

    if isinstance(constraint, constraints._Boolean):
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints._GreaterThan):
        return np.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints._IntegerInterval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints._IntegerGreaterThan):
        return constraint.lower_bound + poisson(key, 5, shape=size)
    elif isinstance(constraint, constraints._Interval):
        lower_bound = np.broadcast_to(constraint.lower_bound, size)
        upper_bound = np.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif isinstance(constraint, constraints._Real):
        return random.normal(key, size)
    elif isinstance(constraint, constraints._Simplex):
        return osp.dirichlet.rvs(alpha=np.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints._Multinomial):
        n = size[-1]
        return multinomial(key, p=np.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1])
    elif isinstance(constraint, constraints._CorrCholesky):
        return signed_stick_breaking_tril(
            random.uniform(key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                           minval=-1, maxval=1))
    elif isinstance(constraint, constraints._LowerCholesky):
        return np.tril(random.uniform(key, size))
    elif isinstance(constraint, constraints._PositiveDefinite):
        x = random.normal(key, size)
        return np.matmul(x, np.swapaxes(x, -2, -1))
    else:
        raise NotImplementedError('{} not implemented.'.format(constraint))
Example #10
 def __call__(self, x):
     tril = np.tril(x)
     lower_triangular = np.all(np.reshape(tril == x, x.shape[:-2] + (-1, )),
                               axis=-1)
     positive_diagonal = np.all(np.diagonal(x, axis1=-2, axis2=-1) > 0,
                                axis=-1)
     return lower_triangular & positive_diagonal
Example #11
 def testTriangularSolveGradPrecision(self):
     rng = jtu.rand_default()
     a = np.tril(rng((3, 3), onp.float32))
     b = rng((1, 3), onp.float32)
     jtu.assert_dot_precision(lax.Precision.HIGHEST,
                              partial(jvp, lax_linalg.triangular_solve),
                              (a, b), (a, b))
Example #12
def _band_part(input, num_lower, num_upper, name=None):  # pylint: disable=redefined-builtin
    del name
    result = input
    if num_lower > -1:
        result = np.triu(result, -num_lower)
    if num_upper > -1:
        result = np.tril(result, num_upper)
    return result
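
Example #12 mirrors the semantics of tf.linalg.band_part: keep num_lower subdiagonals and num_upper superdiagonals, with -1 meaning keep that entire side. For instance (a sketch, assuming plain numpy as np):

import numpy as np

m = np.arange(16).reshape(4, 4)
band = _band_part(m, 1, 0)  # main diagonal plus one subdiagonal
assert np.array_equal(band, np.tril(np.triu(m, -1)))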
Example #13
 def testSolveTriangularGrad(self, lower, transpose_a, unit_diagonal,
                             lhs_shape, rhs_shape, dtype, rng):
   _skip_if_unsupported_type(dtype)
   A = np.tril(rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
   A = A if lower else T(A)
   B = rng(rhs_shape, dtype)
   f = partial(jsp.linalg.solve_triangular, lower=lower, trans=transpose_a,
               unit_diagonal=unit_diagonal)
   jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
Example #14
 def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape, rhs_shape,
                             dtype, rng):
     A = np.tril(
         rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
     A = A if lower else T(A)
     B = rng(rhs_shape, dtype)
     f = partial(jsp.linalg.solve_triangular,
                 lower=lower,
                 trans=transpose_a)
     jtu.check_grads(f, (A, B), 2, rtol=2e-2, eps=1e-3)
Example #15
        def inverse_fun(params, inputs, **kwargs):
            L, U, S = params
            L = np.tril(L, -1) + identity
            U = np.triu(U, 1)
            W = P @ L @ (U + np.diag(S))

            outputs = inputs @ linalg.inv(W)
            log_det_jacobian = np.full(inputs.shape[:1],
                                       -np.log(np.abs(S)).sum())
            return outputs, log_det_jacobian
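
Example #15 is the inverse pass of an LU-parameterized invertible linear layer: W = P @ L @ (U + diag(S)) with L unit lower triangular, so log|det W| = sum(log|S|). How W is assembled (a sketch; P and identity come from the enclosing closure, so they are made up here for a 3x3 layer):

import numpy as np

identity = np.eye(3)
P = np.eye(3)[[2, 0, 1]]  # a fixed permutation matrix
L = np.tril(np.random.randn(3, 3), -1) + identity
U = np.triu(np.random.randn(3, 3), 1)
S = np.array([1.5, -0.5, 2.0])
W = P @ L @ (U + np.diag(S))
assert np.isclose(abs(np.linalg.det(W)), np.prod(np.abs(S)))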
Example #16
 def testSolveTriangularGrad(self, lower, transpose_a, lhs_shape,
                             rhs_shape, dtype, rng):
   # TODO(frostig): change ensemble to support a bigger rtol
   self.skipTest("rtol does not cover all devices and precision modes")
   A = np.tril(rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
   A = A if lower else T(A)
   B = rng(rhs_shape, dtype)
   f = partial(jsp.linalg.solve_triangular, lower=lower,
               trans=1 if transpose_a else 0)
   jtu.check_grads(f, (A, B), 2, rtol=1e-3)
Example #17
 def house_qr_j_lt_m(H, j):
     m, n = H.shape
     Htri = jnp.tril(H)
     v, thisbeta = house_padded(Htri[:, j], j)
     #  Hjj = jax.lax.dynamic_slice(H, (j, j), (m-j, n-j))  # H[j:, j:]
     #  H_update = house_leftmult(Hjj, v, thisbeta)
     #  H = index_update(H, index[:, :],
     #                   jax.lax.dynamic_update_slice(H, H_update, [j, j]))
     #  H = index_update(H, index[:, :],
     #                   jax.lax.dynamic_update_slice(H, v[1:], [j+1, j]))
     return H, thisbeta
Example #18
 def testSolveTriangularBlockedGrad(self, lower, transpose_a, lhs_shape,
                                    rhs_shape, dtype, rng):
     # TODO(frostig): change ensemble to support a bigger rtol
     A = np.tril(
         rng(lhs_shape, dtype) + 5 * onp.eye(lhs_shape[-1], dtype=dtype))
     A = A if lower else T(A)
     B = rng(rhs_shape, dtype)
     f = partial(scipy.linalg.solve_triangular,
                 lower=lower,
                 trans=1 if transpose_a else 0)
     jtu.check_grads(f, (A, B), 2, rtol=1e-3)
Example #19
 def testTriangularSolveBatching(self, left_side, a_shape, b_shape, bdims):
   rng = jtu.rand_default()
   A = np.tril(rng(a_shape, onp.float32)
               + 5 * onp.eye(a_shape[-1], dtype=onp.float32))
   B = rng(b_shape, onp.float32)
   solve = partial(lax_linalg.triangular_solve, lower=True,
                   transpose_a=False, conjugate_a=False,
                   unit_diagonal=False, left_side=left_side)
   X = vmap(solve, bdims)(A, B)
   matmul = partial(np.matmul, precision=lax.Precision.HIGHEST)
   Y = matmul(A, X) if left_side else matmul(X, A)
   onp.testing.assert_allclose(Y - B, 0, atol=1e-5)
Example #20
    def build(self, sc: OrderedDict):
        gp_matrices = OrderedDict()
        gp_matrices["l_u_beta"] = aux_math.matrix_diag_transform(np.tril(sc["S_u_beta"]), nn.softplus)
        gp_matrices["l_u_gamma"] = aux_math.matrix_diag_transform(np.tril(sc["S_u_gamma"]), nn.softplus)

        # Kernel
        gp_matrices["k_beta_beta"] = self.kernel.matrix(sc["X_u_beta"], sc["X_u_beta"])
        gp_matrices["l_beta_beta"] = scipy.linalg.cholesky(gp_matrices["k_beta_beta"] +
                                                           self.jitter * np.eye(self.n_inducing_beta),
                                                           lower=True)

        k_gamma_gamma = \
            self.kernel.matrix(sc["X_u_gamma"], sc["X_u_gamma"]) + self.jitter * np.eye(self.n_inducing_gamma)

        k_beta_gamma = self.kernel.matrix(sc["X_u_beta"], sc["X_u_gamma"])

        l_beta_inv_k_beta_gamma = scipy.linalg.solve_triangular(gp_matrices["l_beta_beta"], k_beta_gamma,
                                                                lower=True)
        gp_matrices["l_beta_inv_k_beta_gamma"] = l_beta_inv_k_beta_gamma

        c_gamma_gamma = k_gamma_gamma - np.matmul(
            np.transpose(gp_matrices["l_beta_inv_k_beta_gamma"], (0, 2, 1)), gp_matrices["l_beta_inv_k_beta_gamma"])

        gp_matrices["l_gamma_gamma"] = scipy.linalg.cholesky(c_gamma_gamma +
                                                             self.jitter * np.eye(self.n_inducing_gamma), lower=True)

        # U_beta_dists
        gp_matrices["q_u_beta_mean"] = sc["mu_u_beta"]
        gp_matrices["q_u_beta_tril"] = gp_matrices["l_u_beta"]
        gp_matrices["p_u_beta_mean"] = np.zeros([self.ndims_out, self.n_inducing_beta])
        gp_matrices["p_u_beta_tril"] = gp_matrices["l_beta_beta"]

        # U_gamma_dists
        gp_matrices["q_u_gamma_mean"] = sc["mu_u_gamma"]
        gp_matrices["q_u_gamma_tril"] = gp_matrices["l_u_gamma"]
        gp_matrices["p_u_gamma_mean"] = np.zeros([self.ndims_out, self.n_inducing_gamma])
        gp_matrices["p_u_gamma_tril"] = gp_matrices["l_gamma_gamma"]

        return gp_matrices
Example #21
def K_net(x):
    k = 1024
    mlp = hk.Sequential([
              hk.Linear(k), jax.nn.swish,
              hk.Linear(k), jax.nn.swish,
              hk.Linear(k), jax.nn.swish,
              hk.Linear(k), jax.nn.swish,
              hk.Linear(9)])
    y = radial_hyperbolic_compactification(x)
    out = mlp(y)
    B = jnp.tril(out.reshape((3, 3)))
    K = B + B.T
    return 0 * .05 * K  # the leading 0 factor makes this always return zeros
Example #22
def cov_params_to_matrix_jax(cov_params):
    """Build covariance matrix from 1d array with its lower triangular elements.

    Args:
        cov_params (np.array): 1d array with the lower triangular elements of a
            covariance matrix (in C-order)

    Returns:
        cov (np.array): a covariance matrix

    """
    lower = chol_params_to_lower_triangular_matrix_jax(cov_params)
    cov = lower + jnp.tril(lower, k=-1).T
    return cov
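
A worked example of the docstring above (a sketch; the inline construction stands in for chol_params_to_lower_triangular_matrix_jax, which is defined elsewhere):

import jax.numpy as jnp

cov_params = jnp.array([1.0, 0.2, 2.0])  # lower triangle of a 2x2 matrix, C-order
lower = jnp.zeros((2, 2)).at[jnp.tril_indices(2)].set(cov_params)
cov = lower + jnp.tril(lower, k=-1).T
# cov == [[1.0, 0.2],
#         [0.2, 2.0]]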
Example #23
def gamma_net(x):
    k = 1024
    mlp = hk.Sequential([
              hk.Linear(k), jax.nn.swish,
              hk.Linear(k), jax.nn.swish,
              hk.Linear(k), jax.nn.swish,
              hk.Linear(k), jax.nn.swish,
              hk.Linear(9)])
    y = radial_hyperbolic_compactification(x)
    out = mlp(y)
    A = jnp.tril(out.reshape((3, 3)))
    L = jnp.eye(3) + .3 * (1 - jnp.linalg.norm(y)) * A
    gamma = L @ L.T
    return gamma
Example #24
 def testTriangularSolveGrad(
     self, lower, transpose_a, conjugate_a, unit_diagonal, left_side, a_shape,
     b_shape, dtype, rng_factory):
   _skip_if_unsupported_type(dtype)
   rng = rng_factory()
   # Test lax_linalg.triangular_solve instead of scipy.linalg.solve_triangular
   # because it exposes more options.
   A = np.tril(rng(a_shape, dtype) + 5 * onp.eye(a_shape[-1], dtype=dtype))
   A = A if lower else T(A)
   B = rng(b_shape, dtype)
   f = partial(lax_linalg.triangular_solve, lower=lower,
               transpose_a=transpose_a, conjugate_a=conjugate_a,
               unit_diagonal=unit_diagonal, left_side=left_side)
   jtu.check_grads(f, (A, B), 2, rtol=4e-2, eps=1e-3)
Example #25
    def test_mask_arg(self):
        seq_len = 3
        embed_size = 2
        model_size = 15
        query = key = value = jnp.zeros((seq_len, embed_size))
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))
        causal_mask = causal_mask[None, :, :]

        mha = attention.MultiHeadAttention(key_size=7,
                                           num_heads=11,
                                           value_size=13,
                                           model_size=model_size,
                                           w_init_scale=1.0)(query,
                                                             key,
                                                             value,
                                                             mask=causal_mask)
        self.assertEqual(mha.shape, (seq_len, model_size))
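
The jnp.tril(jnp.ones((seq_len, seq_len))) call above is the usual causal attention mask: row i may attend to positions j <= i. For seq_len = 3 (a sketch):

import jax.numpy as jnp

mask = jnp.tril(jnp.ones((3, 3)))
# [[1., 0., 0.],
#  [1., 1., 0.],
#  [1., 1., 1.]]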
Example #26
def test_correlated_mvn():
    # This requires dense mass matrix estimation.
    D = 5

    warmup_steps, num_samples = 5000, 8000

    true_mean = 0.
    a = np.tril(0.5 * np.fliplr(np.eye(D)) + 0.1 * np.exp(random.normal(random.PRNGKey(0), shape=(D, D))))
    true_cov = np.dot(a, a.T)
    true_prec = np.linalg.inv(true_cov)

    def potential_fn(z):
        return 0.5 * np.dot(z.T, np.dot(true_prec, z))

    init_params = np.zeros(D)
    samples = mcmc(warmup_steps, num_samples, init_params, potential_fn=potential_fn, dense_mass=True)
    assert_allclose(np.mean(samples), true_mean, atol=0.02)
    assert onp.sum(onp.abs(onp.cov(samples.T) - true_cov)) / D**2 < 0.02
Example #27
 def setup(self):
     self.model = DeepNetwork(H=self.H, kernel_size=self.kernel_size)
     if self.params is None:
         self.params = self.model.init(
             jax.random.PRNGKey(0),
             jnp.expand_dims(jnp.ones([self.history_len]),
                             axis=(0, 2)))["params"]
     # linear feature transform:
     # errs -> [average of last h errs, ..., average of last 2 errs, last err]
     # emulates low-pass filter bank
     self.featurizer = jnp.tril(
         jnp.ones((self.history_len, self.history_len)))
     self.featurizer /= jnp.expand_dims(jnp.arange(self.history_len, 0, -1),
                                        axis=0)
     # TODO(dsuo): resolve jit load/unload
     if self.use_model_apply_jit:
         self.model_apply = jax.jit(self.model.apply)
     else:
         self.model_apply = self.model.apply
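
What the featurizer in Example #27 computes, spelled out for history_len = 3 (a sketch):

import jax.numpy as jnp

featurizer = jnp.tril(jnp.ones((3, 3)))
featurizer /= jnp.expand_dims(jnp.arange(3, 0, -1), axis=0)
# [[1/3, 0,   0],
#  [1/3, 1/2, 0],
#  [1/3, 1/2, 1]]
# so errs @ featurizer = [mean of last 3 errs, mean of last 2 errs, last err]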
Example #28
    def _log_probs(self, X: np.ndarray, alpha: np.ndarray) -> np.ndarray:
        """Compute class log probabilities for X.

        Args:
            X: An array of shape ``(n_samples, n_features)`` containing the training
                examples.
            alpha: The SVM normal vector scales. Normally this should be
                ``self.alpha_``, but we leave it as an argument so we can differentiate
                through this function when fitting.

        Returns:
            An array of shape ``(n_samples, n_classes)`` containing the predicted log
            probabilities for each class.

        """
        n = alpha.shape[0]
        L = jnp.tril(np.ones((n, n)))
        A = jnp.diag(alpha)
        likelihoods = (L @ A @ (self.coefs_ @ X.T + self.b_[:, None])).T
        return logsoftmax(likelihoods)
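
In Example #28 the lower-triangular ones matrix acts as a cumulative sum over the per-class scores; multiplying by it accumulates entries (a sketch):

import jax.numpy as jnp

L = jnp.tril(jnp.ones((3, 3)))
v = jnp.array([1.0, 2.0, 3.0])
assert jnp.allclose(L @ v, jnp.cumsum(v))  # [1., 3., 6.]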
Example #29
  def setup(self, waveform=None):
    self.model = ActorCritic()
    if self.params is None:
      self.params = self.model.init(
          jax.random.PRNGKey(0),
          jnp.expand_dims(jnp.ones([self.history_len]), axis=(0, 1)))['params']

    # linear feature transform:
    # errs -> [average of last h errs, ..., average of last 2 errs, last err]
    # emulates low-pass filter bank

    self.featurizer = jnp.tril(jnp.ones((self.history_len, self.history_len)))
    self.featurizer /= jnp.expand_dims(
        jnp.arange(self.history_len, 0, -1), axis=0)

    if waveform is None:
      self.waveform = BreathWaveform.create()

    if self.normalize:
      self.u_scaler = u_scaler
      self.p_scaler = p_scaler
Example #30
    def _validate_butcher_tableau(alphas: jnp.ndarray, betas: jnp.ndarray,
                                  gammas: jnp.ndarray) -> None:
        _error_msg = []
        if len(alphas) != len(gammas):
            _error_msg.append(
                "Alpha and gamma vectors must have the same length")

        if betas.shape[0] != betas.shape[1]:
            _error_msg.append("Betas must be a quadratic matrix with the same "
                              "dimension as the alphas/gammas arrays")

        # for an explicit method, betas must be strictly lower triangular
        if not jnp.allclose(betas, jnp.tril(betas, k=-1)):
            _error_msg.append("The beta matrix has to be strictly lower "
                              "triangular for an explicit Runge-Kutta method, "
                              "i.e. b_ij = 0 for i <= j")

        if _error_msg:
            raise ValueError("An error occurred while validating the Input "
                             "Butcher tableau. More information: "
                             "{}.".format(",".join(_error_msg)))