예제 #1
0
파일: linalg.py 프로젝트: xueeinstein/jax
    def _update_T_Z(m, T, Z):
        """Apply one 2x2 transformation step to ``T`` and accumulate it in ``Z``.

        A 2x2 matrix ``G`` is built from the diagonal block of ``T`` at
        rows/columns ``m - 1`` and ``m``.  ``G`` is applied from the left to
        rows ``m-1:m+1`` of ``T``, and ``G.conj().T`` is applied from the right
        to columns ``m-1:m+1`` of both ``T`` and ``Z``.

        ``N`` is read from the enclosing scope (not visible here); it is used
        as the length of the row/column index ranges, so it is presumably the
        side length of ``T`` -- TODO confirm against the enclosing function.

        Args:
            m: Index of the second row/column of the 2x2 block being processed.
            T: Matrix being transformed.
            Z: Matrix accumulating the right-hand transformations.

        Returns:
            The updated ``(T, Z)`` pair.
        """
        # Eigenvalues of the 2x2 block T[m-1:m+1, m-1:m+1], shifted by T[m, m].
        mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1),
                                                 (2, 2))) - T[m, m]
        # r normalizes the pair (mu[0], T[m, m-1]) so that |c|^2 + |s|^2 == 1.
        r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
        c = mu[0] / r
        s = T[m, m - 1] / r
        # 2x2 rotation-like matrix assembled from c and s.
        G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)

        # NOTE(review): dynamic slices plus masks are used below instead of the
        # Python slices shown in the comments, presumably because ``m`` may be
        # a traced value -- confirm against the caller.

        # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
        T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
        # True for columns with index >= m - 1; only those take the new values.
        col_mask = jnp.arange(N) >= m - 1
        G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
        # Columns outside the mask keep their original entries.
        T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
        T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)

        # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
        # Reads the T produced by the row update above, so order matters.
        T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
        # True for rows with index < m + 1; shaped as a column for broadcasting.
        row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
        T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
        # Rows outside the mask keep their original entries.
        T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
        T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)

        # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
        # All rows of Z are updated, so no mask is needed here.
        Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
        Z = lax.dynamic_update_slice_in_dim(Z,
                                            Z_cols @ G.conj().T,
                                            m - 1,
                                            axis=1)
        return T, Z
예제 #2
0
    def _process_parameters(self, n, p):
        p_ = 1. - p[..., :-1].sum(axis=-1)
        p, p_ = promote_shapes(p, p_)
        lax.dynamic_update_slice_in_dim(p, p_, 0, axis=-1)

        # true for bad p
        pcond = np.any(p < 0, axis=-1) | np.any(p > 1, axis=-1)

        # true for bad n
        n = np.array(n, dtype=np.int32)
        ncond = n <= 0

        return n, p, ncond | pcond
예제 #3
0
        def gibbs_update(rng_key, gibbs_sites, gibbs_state):
            """Refresh one block of the reference-point statistics per site.

            ``_block_update_proxy`` supplies, per site name, the padding
            amount, the new indices and the start offset of the block to
            overwrite.  The log-likelihood and its first and second
            forward-mode derivatives with respect to ``ref_params_flat`` are
            evaluated at the new indices and written over that block of the
            values carried in ``gibbs_state``.

            Returns:
                ``(u_new, TaylorProxyState)`` holding the refreshed log
                likelihoods, gradients and Hessians.
            """
            u_new, pads, new_idxs, starts = _block_update_proxy(
                num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)

            # Statistics evaluated at the reference parameters on the new block.
            block_log_liks = log_likelihood(ref_params_flat, new_idxs)
            block_grads = jacfwd(log_likelihood)(ref_params_flat, new_idxs)
            block_hessians = jacfwd(jacfwd(log_likelihood))(
                ref_params_flat, new_idxs)

            # stat label -> (fresh block values, values currently in the state)
            stat_sources = {
                "log_liks": (block_log_liks,
                             gibbs_state.ref_subsample_log_liks),
                "grads": (block_grads,
                          gibbs_state.ref_subsample_log_lik_grads),
                "hessians": (block_hessians,
                             gibbs_state.ref_subsample_log_lik_hessians),
            }

            refreshed = defaultdict(dict)
            for stat, (fresh, previous) in stat_sources.items():
                for name, _ in gibbs_sites.items():
                    _, subsample_size = subsample_plate_sizes[name]
                    old = previous[name]
                    # Pad only the leading axis so the block write fits.
                    trailing = [(0, 0)] * (jnp.ndim(old) - 1)
                    padded = jnp.pad(old, [(0, pads[name])] + trailing)
                    patched = lax.dynamic_update_slice_in_dim(
                        padded, fresh[name], starts[name], 0)
                    # Drop the padding again.
                    refreshed[stat][name] = patched[:subsample_size]

            return u_new, TaylorProxyState(refreshed["log_liks"],
                                           refreshed["grads"],
                                           refreshed["hessians"])
예제 #4
0
def _update_block(rng_key, num_blocks, subsample_idx, plate_size):
    size, subsample_size = plate_size
    rng_key, subkey, block_key = random.split(rng_key, 3)
    block_size = (subsample_size - 1) // num_blocks + 1
    pad = block_size - (subsample_size - 1) % block_size - 1

    chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
    new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,))
    subsample_idx_padded = jnp.pad(subsample_idx, (0, pad))
    start = chosen_block * block_size
    subsample_idx_padded = lax.dynamic_update_slice_in_dim(
        subsample_idx_padded, new_idx, start, 0)
    return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start
예제 #5
0
        def gibbs_update(rng_key, gibbs_sites, gibbs_state):
            """Refresh one block of the subsample weights for every site.

            ``_block_update_proxy`` supplies, per site name, the padding
            amount, the new indices and the start offset of the block to
            overwrite.  The weights looked up at the new indices replace that
            block of the weights carried in ``gibbs_state``.

            Returns:
                ``(u_new, VariationalProxyState)`` holding the refreshed
                subsample weights.
            """
            u_new, pads, new_idxs, starts = _block_update_proxy(
                num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)

            def _refresh(name, current):
                # TODO: the plate-size lookup duplicates work done elsewhere.
                _, subsample_size = subsample_plate_sizes[name]
                # Pad only the leading axis so the block write fits.
                trailing = [(0, 0)] * (jnp.ndim(current) - 1)
                padded = jnp.pad(current, [(0, pads[name])] + trailing)
                block = weights[name][new_idxs[name]]
                patched = lax.dynamic_update_slice_in_dim(
                    padded, block, starts[name], 0)
                # Drop the padding again.
                return patched[:subsample_size]

            refreshed = {
                name: _refresh(name, current)
                for name, current in gibbs_state.subsample_weights.items()
            }
            return u_new, VariationalProxyState(refreshed)
예제 #6
0
def _dynamic_concat(a, b, m, axis=0):
    "Concatenates padded arrays `a` and `b` where the true size of `a` is `m`."
    if m is None:
        return jnp.concatenate([a, b], axis=axis)
    return lax.dynamic_update_slice_in_dim(
        _pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis)