def _update_T_Z(m, T, Z):
    """Apply one complex Givens rotation at index ``m`` to ``T`` and ``Z``.

    The 2x2 rotation ``G`` is built from two values:
    ``mu[0]``, the first eigenvalue of the block ``T[m-1:m+1, m-1:m+1]``
    shifted by ``-T[m, m]``, and the subdiagonal entry ``T[m, m-1]``.
    It is then applied three times:

    * from the left, ``G @ T``, on rows ``m-1, m`` restricted to columns
      ``m-1`` onward;
    * from the right, ``T @ G^H``, on columns ``m-1, m`` restricted to rows
      ``0..m``;
    * accumulated into ``Z`` as ``Z @ G^H`` on columns ``m-1, m``.

    NOTE(review): this matches the loop body of ``scipy.linalg.rsf2csf``
    (real Schur form -> complex Schur form); confirm against the caller.

    ``N`` is read from the enclosing scope. It must equal both dimensions of
    ``T`` for the column/row masks below to broadcast against the 2-row and
    2-column slabs.

    NOTE(review): if ``mu[0]`` and ``T[m, m-1]`` are both zero then ``r`` is
    zero and ``c``/``s`` become NaN. The caller presumably only invokes this
    when ``T[m, m-1]`` is non-negligible; confirm.

    Args:
        m: index of the subdiagonal entry ``T[m, m-1]``. Presumably a traced
            value (e.g. a loop index), which would explain why only
            fixed-size dynamic slices plus masks are used instead of Python
            slicing. TODO confirm.
        T: square matrix being rotated.
        Z: matrix whose columns ``m-1, m`` accumulate the rotation.

    Returns:
        The updated ``(T, Z)`` pair as new arrays.
    """
    mu = np_linalg.eigvals(lax.dynamic_slice(T, (m - 1, m - 1), (2, 2))) - T[m, m]
    # Cast the real norm to T's dtype so c and s are computed in that dtype.
    r = np_linalg.norm(jnp.array([mu[0], T[m, m - 1]])).astype(T.dtype)
    c = mu[0] / r
    s = T[m, m - 1] / r
    G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)
    # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
    # The column range `m-1:` has an m-dependent length. So the full-width
    # two-row slab is taken, columns < m-1 are zeroed before the product,
    # and those columns are restored from the original slab afterwards.
    T_rows = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
    col_mask = jnp.arange(N) >= m - 1
    G_dot_T_zeroed_cols = G @ jnp.where(col_mask, T_rows, 0)
    T_rows_new = jnp.where(~col_mask, T_rows, G_dot_T_zeroed_cols)
    T = lax.dynamic_update_slice_in_dim(T, T_rows_new, m - 1, axis=0)
    # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
    # Reads T *after* the row update above. Same masking trick on rows:
    # rows > m keep their values.
    T_cols = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
    row_mask = jnp.arange(N)[:, jnp.newaxis] < m + 1
    T_zeroed_rows_dot_GH = jnp.where(row_mask, T_cols, 0) @ G.conj().T
    T_cols_new = jnp.where(~row_mask, T_cols, T_zeroed_rows_dot_GH)
    T = lax.dynamic_update_slice_in_dim(T, T_cols_new, m - 1, axis=1)
    # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
    # Every row of Z is affected, so no mask is needed here.
    Z_cols = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
    Z = lax.dynamic_update_slice_in_dim(Z, Z_cols @ G.conj().T, m - 1, axis=1)
    return T, Z
def _process_parameters(self, n, p):
    """Normalise multinomial parameters and flag out-of-domain entries.

    The last category probability is replaced by ``1 - sum(p[..., :-1])``
    so that every row of ``p`` sums to one. This mirrors
    ``scipy.stats.multinomial._process_parameters``.

    Args:
        n: number of trials; converted to an ``int32`` array.
        p: array of category probabilities with categories on the last axis.
            Leading axes are treated as batch axes.

    Returns:
        ``(n, p, cond)``, where ``p`` is the adjusted probability array and
        ``cond`` is True wherever ``n <= 0`` or any entry of the adjusted
        ``p`` lies outside ``[0, 1]``.
    """
    p_last = (1. - p[..., :-1].sum(axis=-1)).astype(p.dtype)
    # Arrays here are immutable: the update must be assigned back.
    # The complement of the first k-1 probabilities belongs in the *last*
    # slot. A trailing unit axis (rather than promote_shapes, which prepends
    # axes) keeps batched inputs aligned with `p`.
    p = lax.dynamic_update_slice_in_dim(
        p, p_last[..., None], p.shape[-1] - 1, axis=-1)
    # true for bad p
    pcond = np.any(p < 0, axis=-1) | np.any(p > 1, axis=-1)
    # true for bad n
    n = np.array(n, dtype=np.int32)
    ncond = n <= 0
    return n, p, ncond | pcond
def gibbs_update(rng_key, gibbs_sites, gibbs_state):
    """Resample one block of subsample indices and refresh the Taylor-proxy cache.

    ``_block_update_proxy`` draws the new subsample state ``u_new``. For each
    site it also returns the padding, the freshly drawn indices, and the
    block start offset.

    The log-likelihood and its first/second ``jacfwd`` derivatives at
    ``ref_params_flat`` are evaluated on the fresh indices only. They are
    then spliced into the matching block of the cached arrays held by
    ``gibbs_state``.

    Returns:
        ``(u_new, TaylorProxyState(log_liks, grads, hessians))``, where each
        field is a dict keyed by site name.
    """
    u_new, pads, new_idxs, starts = _block_update_proxy(
        num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)

    def splice(cached, block, pad, start, subsample_size):
        # Grow the leading axis by `pad` so a full block fits, overwrite the
        # block at `start`, then trim back to the subsample length.
        widths = [(0, pad)] + [(0, 0)] * (jnp.ndim(cached) - 1)
        grown = jnp.pad(cached, widths)
        grown = lax.dynamic_update_slice_in_dim(grown, block, start, 0)
        return grown[:subsample_size]

    stat_names = ("log_liks", "grads", "hessians")
    fresh_values = (
        log_likelihood(ref_params_flat, new_idxs),
        jacfwd(log_likelihood)(ref_params_flat, new_idxs),
        jacfwd(jacfwd(log_likelihood))(ref_params_flat, new_idxs),
    )
    cached_values = (
        gibbs_state.ref_subsample_log_liks,
        gibbs_state.ref_subsample_log_lik_grads,
        gibbs_state.ref_subsample_log_lik_hessians,
    )

    updated = {}
    for stat, fresh, cached in zip(stat_names, fresh_values, cached_values):
        per_site = {}
        for name in gibbs_sites:
            _, subsample_size = subsample_plate_sizes[name]
            per_site[name] = splice(
                cached[name], fresh[name], pads[name], starts[name],
                subsample_size)
        updated[stat] = per_site

    return u_new, TaylorProxyState(
        updated["log_liks"], updated["grads"], updated["hessians"])
def _update_block(rng_key, num_blocks, subsample_idx, plate_size):
    """Replace one randomly chosen block of ``subsample_idx`` with fresh indices.

    ``subsample_idx`` is split into contiguous blocks of
    ``ceil(subsample_size / num_blocks)`` entries, padding the tail with
    zeros so the length is a whole number of blocks. One block is chosen
    uniformly, and its entries are redrawn uniformly from ``[0, size)``.

    Args:
        rng_key: PRNG key; split internally.
        num_blocks: requested number of blocks. The padded array may hold
            fewer blocks than this when ``subsample_size`` does not split
            evenly.
        subsample_idx: 1-D array of current subsample indices.
        plate_size: ``(size, subsample_size)`` of the subsample plate.

    Returns:
        ``(rng_key, new_subsample_idx, pad, new_idx, start)``, where:

        * ``new_subsample_idx`` has length ``subsample_size``;
        * ``pad`` is the number of padding entries used;
        * ``new_idx`` holds the ``block_size`` freshly drawn indices;
        * ``start`` is the (always in-bounds) offset of the replaced block in
          the padded array.
    """
    size, subsample_size = plate_size
    rng_key, subkey, block_key = random.split(rng_key, 3)
    # ceil(subsample_size / num_blocks) entries per block.
    block_size = (subsample_size - 1) // num_blocks + 1
    # Extra entries so the padded array holds a whole number of blocks.
    pad = block_size - (subsample_size - 1) % block_size - 1
    # The padded array may contain fewer than `num_blocks` blocks
    # (e.g. 10 entries and 6 blocks -> 5 blocks of 2). Drawing from
    # `num_blocks` would yield out-of-range starts that dynamic_update_slice
    # silently clamps onto the last block, biasing the choice. Draw from the
    # real count instead; it equals `num_blocks` whenever the split is exact.
    n_valid_blocks = (subsample_size + pad) // block_size

    chosen_block = random.randint(
        block_key, shape=(), minval=0, maxval=n_valid_blocks)
    new_idx = random.randint(subkey, minval=0, maxval=size, shape=(block_size,))
    subsample_idx_padded = jnp.pad(subsample_idx, (0, pad))
    start = chosen_block * block_size
    subsample_idx_padded = lax.dynamic_update_slice_in_dim(
        subsample_idx_padded, new_idx, start, 0)
    return rng_key, subsample_idx_padded[:subsample_size], pad, new_idx, start
def gibbs_update(rng_key, gibbs_sites, gibbs_state):
    """Resample one block of subsample indices and refresh the cached weights.

    For every site in ``gibbs_state.subsample_weights``, the block at
    ``starts[name]`` is overwritten with ``weights[name]`` gathered at the
    freshly drawn indices. The result is trimmed back to the subsample size.

    Returns:
        ``(u_new, VariationalProxyState(new_subsample_weights))``.
    """
    u_new, pads, new_idxs, starts = _block_update_proxy(
        num_blocks, rng_key, gibbs_sites, subsample_plate_sizes)

    def refreshed(name, current):
        # TODO: the pad/splice/truncate pattern is duplicated across proxy
        # implementations.
        _, subsample_size = subsample_plate_sizes[name]
        pad, new_idx, start = pads[name], new_idxs[name], starts[name]
        # Pad the leading axis so the whole block fits, splice the new block
        # in at `start`, then drop the padding again.
        widths = [(0, pad)] + [(0, 0)] * (jnp.ndim(current) - 1)
        grown = jnp.pad(current, widths)
        grown = lax.dynamic_update_slice_in_dim(
            grown, weights[name][new_idx], start, 0)
        return grown[:subsample_size]

    new_subsample_weights = {
        name: refreshed(name, current)
        for name, current in gibbs_state.subsample_weights.items()
    }
    return u_new, VariationalProxyState(new_subsample_weights)
def _dynamic_concat(a, b, m, axis=0):
    """Concatenate ``b`` after the meaningful prefix of ``a`` along ``axis``.

    ``a`` may carry padding past its meaningful prefix, whose length is
    ``m``. When ``m`` is None, ``a`` is taken as-is and a plain
    concatenation is returned.
    """
    if m is not None:
        # Extend `a` on the high side by b's length along `axis`, then write
        # `b` starting at offset `m`, overwriting any padding found there.
        widened = _pad_in_dim(a, high=b.shape[axis], axis=axis)
        return lax.dynamic_update_slice_in_dim(widened, b, m, axis)
    return jnp.concatenate([a, b], axis=axis)