def g(params): """differentiable piece of objective in η problem.""" y = params[1:] if log_transform: M = utils.M(self.n, t, np.exp(y)) else: M = utils.M(self.n, t, y) L = self.C @ M ξ = self.mu0 * L.sum(1) if folded: ξ = utils.fold(ξ) else: r = expit(params[0]) ξ = (1 - r) * ξ + r * self.AM_freq @ ξ loss_term = loss(np.squeeze(ξ), x) y_delta = y - y_ref ridge_term = (ridge_penalty / 2) * (y_delta.T @ Γ @ y_delta) return loss_term + ridge_term
def set_eta(self, eta: hst.eta):
    r"""Set pre-specified demographic history :math:`\eta(t)`

    Args:
        eta: demographic history object
    """
    self.η = eta
    t = self.η.arrays()[0]
    self.M = utils.M(self.n, t, self.η.y)
    self.L = self.C @ self.M
    self.r = 0
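
# Usage sketch (illustrative, not part of the library): assumes the `mushi`
# package exposes `kSFS`, `eta`, and the `set_eta` method defined above, with
# the constructors shown in the doctests later in this module.
def _demo_set_eta():
    import numpy as np
    import mushi

    ksfs = mushi.kSFS(n=10)
    # piecewise-constant bottleneck demography, as in the simulate() doctest
    eta = mushi.eta(np.array([1, 100, 10000]),
                    np.array([1e4, 1e4, 1e2, 1e4]))
    # fix η(t), so that subsequent μ(t) inference conditions on it
    ksfs.set_eta(eta)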
def g(logy): """differentiable piece of objective in η problem""" L = self.C @ utils.M(self.n, t, np.exp(logy)) if mask is not None: loss_term = loss(z, x[mask, :], L[mask, :]) else: loss_term = loss(z, x, L) spline_term = (α_spline / 2) * ((D1 @ logy)**2).sum() # generalized Tikhonov logy_delta = logy - logy_ref ridge_term = (α_ridge / 2) * (logy_delta.T @ Γ @ logy_delta) return loss_term + spline_term + ridge_term
def test_constant_history(self):
    """Test expected SFS under constant demography and mutation rate
    against the analytic formula from Fu (1995).
    """
    n = 198
    η0 = 3e4
    μ0 = 40
    change_points = np.array([])
    η = histories.eta(change_points, np.array([η0]))
    t, y = η.arrays()
    μ = histories.mu(change_points, np.array([[μ0]]))
    ξ_mushi = np.squeeze(utils.C(n) @ utils.M(n, t, y) @ μ.Z)
    ξ_Fu = 2 * η0 * μ0 / np.arange(1, n)
    self.assertTrue(np.isclose(ξ_mushi, ξ_Fu).all(),
                    msg=f'\nξ_mushi:\n{ξ_mushi}\nξ_Fu:\n{ξ_Fu}')
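
# The constant MLE used to initialize inference elsewhere in this module
# inverts Fu's formula: summing E[ξ_i] = 2·η0·μ0 / i over i gives
# E[S] = 2·η0·μ0·H with H a harmonic number, so η0 = S / (2·H·μ0). A minimal
# standalone check of that identity (numpy only; the exact harmonic-number
# indexing used by the library may differ slightly):
def _demo_constant_mle():
    import numpy as np

    n, eta0, mu0 = 198, 3e4, 40
    i = np.arange(1, n)           # sample frequencies 1, ..., n-1
    xi = 2 * eta0 * mu0 / i       # Fu (1995) expected SFS
    S = xi.sum()                  # expected number of segregating sites
    H = (1 / i).sum()             # harmonic number
    assert np.isclose(S / (2 * H * mu0), eta0)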
def simulate(self, eta: hst.eta, mu: Union[hst.mu, np.float64],
             seed: int = None) -> None:
    r"""Simulate an SFS under the Poisson random field model (no linkage),
    and assign the simulated SFS to the ``X`` attribute

    Args:
        eta: demographic history
        mu: mutation spectrum history (or constant rate)
        seed: random seed
    """
    onp.random.seed(seed)
    t, y = eta.arrays()
    M = utils.M(self.n, t, y)
    L = self.C @ M
    if type(mu) == hst.mu:
        if not eta.check_grid(mu):
            raise ValueError('η(t) and μ(t) must use the same time grid')
    else:
        mu = hst.mu(eta.change_points, mu * np.ones_like(y))
    self.X = poisson.rvs(L @ mu.Z)
    self.mutation_types = mu.mutation_types
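
# Usage sketch (illustrative, not part of the library): constructors as in
# the doctests in this module; passing a scalar `mu` exercises the
# constant-rate branch above.
def _demo_simulate_constant_rate():
    import numpy as np
    import mushi

    ksfs = mushi.kSFS(n=10)
    eta = mushi.eta(np.array([100.0]), np.array([1e4, 1e4]))
    ksfs.simulate(eta, mu=1.0, seed=0)
    print(ksfs.X)  # simulated SFS counts, one entry per sample frequency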
def infer_history(self,  # noqa: C901
                  change_points: np.array,
                  mu0: np.float64,
                  eta: hst.eta = None,
                  eta_ref: hst.eta = None,
                  mu_ref: hst.mu = None,
                  infer_eta: bool = True,
                  infer_mu: bool = True,
                  alpha_tv: np.float64 = 0,
                  alpha_spline: np.float64 = 0,
                  alpha_ridge: np.float64 = 0,
                  beta_tv: np.float64 = 0,
                  beta_spline: np.float64 = 0,
                  beta_rank: np.float64 = 0,
                  hard: bool = False,
                  beta_ridge: np.float64 = 0,
                  max_iter: int = 1000,
                  s0: int = 1,
                  max_line_iter: int = 100,
                  gamma: np.float64 = 0.8,
                  tol: np.float64 = 0,
                  loss: str = 'prf',
                  mask: np.array = None,
                  verbose: bool = False) -> None:
    r"""Perform sequential inference to fit :math:`\eta(t)` and :math:`\mu(t)`

    Args:
        change_points: epoch change points (ordered times > 0)
        mu0: total mutation rate (per genome per generation)
        eta: initial demographic history. By default, a constant MLE is
            computed
        eta_ref: reference demographic history for ridge penalty. If
            ``None``, the constant MLE is used
        mu_ref: reference MuSH for ridge penalty. If ``None``, the constant
            MLE is used
        infer_eta: perform :math:`\eta` inference if ``True``
        infer_mu: perform :math:`\mu` inference if ``True``
        loss: loss function, 'prf' for Poisson random field, 'kl' for
            Kullback-Leibler divergence, 'lsq' for least-squares
        mask: array of bools, with False indicating exclusion of that
            frequency
        alpha_tv: total variation penalty on :math:`\eta(t)`
        alpha_spline: L2 on first differences penalty on :math:`\eta(t)`
        alpha_ridge: L2 for strong convexity penalty on :math:`\eta(t)`
        hard: hard rank penalty on :math:`\mu(t)` (non-convex)
        beta_tv: total variation penalty on :math:`\mu(t)`
        beta_spline: L2 on first differences penalty on :math:`\mu(t)`
        beta_rank: rank penalty on :math:`\mu(t)`
        beta_ridge: L2 penalty on :math:`\mu(t)`
        max_iter: maximum number of proximal gradient steps
        tol: relative tolerance in objective function (if ``0``, not used)
        s0: max step size
        max_line_iter: maximum number of line search steps
        gamma: step size shrinkage rate for line search
        verbose: print verbose messages if ``True``
    """
    # pithify reg parameter names
    α_tv = alpha_tv
    α_spline = alpha_spline
    α_ridge = alpha_ridge
    β_tv = beta_tv
    β_spline = beta_spline
    β_rank = beta_rank
    β_ridge = beta_ridge

    if self.X is None:
        raise ValueError('use simulate() to generate data first')
    if mask is not None:
        assert len(mask) == self.X.shape[0], 'mask must have n-1 elements'

    # initialize with MLE constant η and μ
    x = self.X.sum(1, keepdims=True)
    μ_total = hst.mu(change_points,
                     mu0 * np.ones((len(change_points) + 1, 1)))
    t, z = μ_total.arrays()
    # number of segregating variants in each mutation type
    S = self.X.sum(0, keepdims=True)
    if eta is not None:
        self.η = eta
    elif self.η is None:
        # Harmonic number
        H = (1 / np.arange(1, self.n - 1)).sum()
        # constant MLE
        y = (S.sum() / 2 / H / mu0) * np.ones(len(z))
        self.η = hst.eta(change_points, y)

    μ_const = hst.mu(self.η.change_points,
                     mu0 * (S / S.sum()) * np.ones((self.η.m,
                                                    self.X.shape[1])),
                     mutation_types=self.mutation_types.values)
    if self.μ is None:
        self.μ = μ_const

    self.M = utils.M(self.n, t, self.η.y)
    self.L = self.C @ self.M

    # badness of fit
    if loss == 'prf':
        def loss(*args, **kwargs):
            return -utils.prf(*args, **kwargs)
    elif loss == 'kl':
        loss = utils.d_kl
    elif loss == 'lsq':
        loss = utils.lsq
    else:
        raise ValueError(f'unrecognized loss argument {loss}')

    # some matrices we'll need for the first difference penalties
    D = (np.eye(self.η.m, k=0) - np.eye(self.η.m, k=-1))
    # W matrix deals with boundary condition
    W = np.eye(self.η.m)
    W = index_update(W, index[0, 0], 0)
    D1 = W @ D  # 1st difference matrix
    # D2 = D1.T @ D1  # 2nd difference matrix

    if infer_eta:
        if verbose:
            print('inferring η(t)', flush=True)

        # Accelerated proximal gradient method: our objective function
        # decomposes as f = g + h, where g is differentiable and h is not.
        # https://people.eecs.berkeley.edu/~elghaoui/Teaching/EE227A/lecture18.pdf

        # Tikhonov matrix
        if eta_ref is None:
            eta_ref = self.η
            Γ = np.diag(np.ones_like(eta_ref.y))
        else:
            # - log(1 - CDF)
            Γ = np.diag(-np.log(utils.tmrca_sf(t, eta_ref.y,
                                               self.n))[:-1])
        logy_ref = np.log(eta_ref.y)

        @jit
        def g(logy):
            """differentiable piece of objective in η problem"""
            L = self.C @ utils.M(self.n, t, np.exp(logy))
            if mask is not None:
                loss_term = loss(z, x[mask, :], L[mask, :])
            else:
                loss_term = loss(z, x, L)
            spline_term = (α_spline / 2) * ((D1 @ logy) ** 2).sum()
            # generalized Tikhonov
            logy_delta = logy - logy_ref
            ridge_term = (α_ridge / 2) * (logy_delta.T @ Γ @ logy_delta)
            return loss_term + spline_term + ridge_term

        if α_tv > 0:
            @jit
            def h(logy):
                """nondifferentiable piece of objective in η problem"""
                return α_tv * np.abs(D1 @ logy).sum()

            def prox(logy, s):
                """total variation prox operator"""
                return ptv.tv1_1d(logy, s * α_tv)
        else:
            @jit
            def h(logy):
                return 0

            def prox(logy, s):
                return logy

        # initial iterate
        logy = np.log(self.η.y)
        logy = opt.acc_prox_grad_method(logy, g, jit(grad(g)), h, prox,
                                        tol=tol, max_iter=max_iter, s0=s0,
                                        max_line_iter=max_line_iter,
                                        gamma=gamma, verbose=verbose)
        y = np.exp(logy)
        self.η = hst.eta(self.η.change_points, y)
        self.M = utils.M(self.n, t, y)
        self.L = self.C @ self.M

    if infer_mu and len(self.mutation_types) > 1:
        if verbose:
            print('inferring μ(t) conditioned on η(t)', flush=True)

        if mu_ref is None:
            mu_ref = μ_const
            # Tikhonov matrix
            Γ = np.diag(np.ones_like(self.η.y))
        else:
            # - log(1 - CDF)
            Γ = np.diag(-np.log(utils.tmrca_sf(t, self.η.y,
                                               self.n))[:-1])

        # orthonormal basis for Aitchison simplex
        # NOTE: instead of Gram-Schmidt could try SVD of clr transformed X
        # https://en.wikipedia.org/wiki/Compositional_data#Isometric_logratio_transform
        basis = cmp._gram_schmidt_basis(self.μ.Z.shape[1])
        # initial iterate in inverse log-ratio transform
        Z = cmp.ilr(self.μ.Z, basis)
        Z_const = cmp.ilr(μ_const.Z, basis)
        Z_ref = cmp.ilr(mu_ref.Z, basis)

        @jit
        def g(Z):
            """differentiable piece of objective in μ problem"""
            if mask is not None:
                loss_term = loss(mu0 * cmp.ilr_inv(Z, basis),
                                 self.X[mask, :], self.L[mask, :])
            else:
                loss_term = loss(mu0 * cmp.ilr_inv(Z, basis),
                                 self.X, self.L)
            spline_term = (β_spline / 2) * ((D1 @ Z) ** 2).sum()
            # generalized Tikhonov
            Z_delta = Z - Z_ref
            ridge_term = (β_ridge / 2) * np.trace(Z_delta.T @ Γ @ Z_delta)
            return loss_term + spline_term + ridge_term

        if β_tv and β_rank:
            @jit
            def h1(Z):
                """1st nondifferentiable piece of objective in μ problem"""
                return β_tv * np.abs(D1 @ Z).sum()

            shape = Z.T.shape
            w = β_tv * onp.ones(shape)
            w[:, -1] = 0
            w = w.flatten()[:-1]

            def prox1(Z, s):
                """total variation prox operator on row dimension"""
                return ptv.tv1w_1d(Z.T, s * w).reshape(shape).T

            @jit
            def h2(Z):
                """2nd nondifferentiable piece of objective in μ problem"""
                σ = np.linalg.svd(Z - Z_const, compute_uv=False)
                return β_rank * np.linalg.norm(σ, 0 if hard else 1)

            def prox2(Z, s):
                """singular value thresholding"""
                U, σ, Vt = np.linalg.svd(Z - Z_const, full_matrices=False)
                if hard:
                    σ = index_update(σ, index[σ <= s * β_rank], 0)
                else:
                    σ = np.maximum(0, σ - s * β_rank)
                Σ = np.diag(σ)
                return Z_const + U @ Σ @ Vt

            Z = opt.three_op_prox_grad_method(Z, g, jit(grad(g)),
                                              h1, prox1, h2, prox2,
                                              tol=tol, max_iter=max_iter,
                                              s0=s0,
                                              max_line_iter=max_line_iter,
                                              gamma=gamma, ls_tol=0,
                                              verbose=verbose)
        else:
            if β_tv:
                @jit
                def h(Z):
                    """nondifferentiable piece of objective in μ problem"""
                    return β_tv * np.abs(D1 @ Z).sum()

                shape = Z.T.shape
                w = β_tv * onp.ones(shape)
                w[:, -1] = 0
                w = w.flatten()[:-1]

                def prox(Z, s):
                    """total variation prox operator on row dimension"""
                    return ptv.tv1w_1d(Z.T, s * w).reshape(shape).T
            elif β_rank:
                @jit
                def h(Z):
                    """nondifferentiable piece of objective in μ problem"""
                    σ = np.linalg.svd(Z - Z_const, compute_uv=False)
                    return β_rank * np.linalg.norm(σ, 0 if hard else 1)

                def prox(Z, s):
                    """singular value thresholding"""
                    U, σ, Vt = np.linalg.svd(Z - Z_const,
                                             full_matrices=False)
                    if hard:
                        σ = index_update(σ, index[σ <= s * β_rank], 0)
                    else:
                        σ = np.maximum(0, σ - s * β_rank)
                    Σ = np.diag(σ)
                    return Z_const + U @ Σ @ Vt
            else:
                @jit
                def h(Z):
                    return 0

                @jit
                def prox(Z, s):
                    return Z

            Z = opt.acc_prox_grad_method(Z, g, jit(grad(g)), h, prox,
                                         tol=tol, max_iter=max_iter,
                                         s0=s0,
                                         max_line_iter=max_line_iter,
                                         gamma=gamma, verbose=verbose)

        self.μ = hst.mu(self.η.change_points,
                        mu0 * cmp.ilr_inv(Z, basis),
                        mutation_types=self.mutation_types.values)
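
# The optimizers above split the objective as f = g + h, with g smooth and h
# nonsmooth, alternating gradient steps on g with prox steps on h. A minimal
# sketch of one proximal gradient step for a plain L1 penalty, whose prox is
# elementwise soft-thresholding (numpy only; the TV penalties above use
# ptv.tv1_1d / ptv.tv1w_1d as their prox operators instead):
def _demo_prox_grad_step(x, grad_g, step, alpha):
    import numpy as np

    z = x - step * grad_g(x)  # gradient step on the smooth piece g
    # prox of α‖·‖₁ with step size `step`: soft-threshold each coordinate
    return np.sign(z) * np.maximum(np.abs(z) - step * alpha, 0)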
def infer_mush(
    self,
    *trend_penalty: Tuple[int, np.float64],
    ridge_penalty: np.float64 = 0,
    rank_penalty: np.float64 = 0,
    hard: bool = False,
    mu_ref: hst.mu = None,
    misid_penalty: np.float64 = 1e-4,
    loss: str = "prf",
    max_iter: int = 100,
    tol: np.float64 = 0,
    line_search_kwargs: Dict = {},
    trend_kwargs: Dict = {},
    verbose: bool = False,
) -> None:
    r"""Infer mutation spectrum history :math:`\mu(t)`

    Args:
        trend_penalty: tuple ``(k, λ)`` for :math:`k`-th order trend
            penalty (can pass multiple for mixed trends)
        ridge_penalty: ridge penalty
        rank_penalty: rank penalty
        hard: hard rank penalty (non-convex)
        mu_ref: reference MuSH for ridge penalty. If ``None``, the
            constant MLE is used
        misid_penalty: ridge parameter to shrink misid rates to the
            aggregate rate
        loss: loss function from :mod:`~mushi.loss_functions` module
        max_iter: maximum number of optimization steps
        tol: relative tolerance in objective function (if ``0``, not used)
        line_search_kwargs: line search keyword arguments, see
            :py:class:`mushi.optimization.LineSearcher`
        trend_kwargs: keyword arguments for trend filtering, see
            :py:meth:`mushi.optimization.TrendFilter.run`
        verbose: print verbose messages if ``True``

    Examples:

        Suppose ``ksfs`` is a ``kSFS`` object, and the demography has
        already been fit with ``infer_eta``. The following fits a mutation
        spectrum history with 0-th order (piecewise constant) trend
        penalization of strength 100.

        >>> ksfs.infer_mush((0, 1e2))

        Alternatively, a mixed trend solution, with constant and cubic
        pieces, and with rank penalization 100, is fit with

        >>> ksfs.infer_mush((0, 1e2), (3, 1e1), rank_penalty=1e2)

        The attribute ``ksfs.mu`` is now set and accessible for plotting
        (see :class:`~mushi.mu`).
    """
    if self.X is None:
        raise TypeError("use simulate() to generate data first")
    self._check_eta()
    if self.r is None or self.r == 0:
        raise ValueError("ancestral misidentification rate has not been "
                         "inferred, possibly due to folded SFS inference")
    if self.mutation_types is None:
        raise ValueError("k-SFS must contain multiple mutation types")

    # number of segregating variants in each mutation type
    S = self.X.sum(0, keepdims=True)

    # initialize with MLE constant μ
    μ_const = hst.mu(
        self.η.change_points,
        self.mu0 * (S / S.sum()) * np.ones((self.η.m, self.X.shape[1])),
        mutation_types=self.mutation_types.values,
    )
    if self.μ is None:
        self.μ = μ_const
    t = μ_const.arrays()[0]

    self.M = utils.M(self.n, t, self.η.y)
    self.L = self.C @ self.M

    # badness of fit
    loss = getattr(loss_functions, loss)

    # rescale trend penalties to be comparable between orders and time
    # grids, and filter out zero penalties
    trend_penalty = tuple((k, (self.μ.m ** k / onp.math.factorial(k)) * λ)
                          for k, λ in trend_penalty if λ > 0)

    if mu_ref is None:
        mu_ref = μ_const
        # Tikhonov matrix
        Γ = np.diag(np.ones_like(self.η.y))
    else:
        self.μ.check_grid(mu_ref)
        # - log(1 - CDF)
        Γ = np.diag(-np.log(utils.tmrca_sf(t, self.η.y, self.n))[:-1])

    # orthonormal basis for Aitchison simplex
    # NOTE: instead of Gram-Schmidt could try SVD of clr transformed X
    # https://en.wikipedia.org/wiki/Compositional_data#Isometric_logratio_transform
    basis = cmp._gram_schmidt_basis(self.μ.Z.shape[1])
    check_orth = self.μ.Z.shape[1] > 2
    # constant MLE and reference MuSH
    Z_const = cmp.ilr(μ_const.Z, basis, check_orth)
    Z_ref = cmp.ilr(mu_ref.Z, basis, check_orth)
    # weights for relating misid rates to aggregate misid rate from η step
    misid_weights = self.X.sum(0) / self.X.sum()
    # reference composition for weighted misid (if all rates are equal)
    misid_ref = cmp.ilr(misid_weights, basis, check_orth)

    # In the following, params will hold the weighted misid composition in
    # the first row and the MuSH composition at each time in the remaining
    # rows

    @jit
    def g(params):
        """differentiable piece of objective in μ problem."""
        r = self.r * cmp.ilr_inv(params[0, :], basis) / misid_weights
        Z = params[1:, :]
        Ξ = self.L @ (self.mu0 * cmp.ilr_inv(Z, basis))
        Ξ = Ξ * (1 - r) + self.AM_freq @ Ξ @ (self.AM_mut
                                              * r[:, np.newaxis])
        loss_term = loss(Ξ, self.X)
        Z_delta = Z - Z_ref
        ridge_term = (ridge_penalty / 2) * np.sum(Z_delta * (Γ @ Z_delta))
        misid_delta = params[0, :] - misid_ref
        misid_ridge_term = misid_penalty * np.sum(misid_delta ** 2)
        return loss_term + ridge_term + misid_ridge_term

    if trend_penalty:
        @jit
        def h_trend(params):
            """trend filtering penalty."""
            return sum(
                λ * np.linalg.norm(np.diff(params[1:, :], k + 1, axis=0),
                                   1)
                for k, λ in trend_penalty)

        def prox_trend(params, s):
            """trend filtering prox operator (no jit due to ptv module)"""
            k, sλ = zip(*((k, s * λ) for k, λ in trend_penalty))
            trend_filterer = opt.TrendFilter(k, sλ)
            return params.at[1:, :].set(
                trend_filterer.run(params[1:, :], **trend_kwargs))

    if rank_penalty:
        if self.mutation_types.size < 3:
            raise ValueError("kSFS must have more than 2 mutation types"
                             " for rank penalization")

        @jit
        def h_rank(params):
            """2nd nondifferentiable piece of objective in μ problem."""
            if hard:
                return rank_penalty * np.linalg.matrix_rank(
                    params[1:, :] - Z_const)
            # else:
            return rank_penalty * np.linalg.norm(params[1:, :] - Z_const,
                                                 "nuc")

        def prox_rank(params, s):
            """singular value thresholding."""
            U, σ, Vt = np.linalg.svd(params[1:, :] - Z_const,
                                     full_matrices=False)
            if hard:
                σ = σ.at[σ <= s * rank_penalty].set(0)
            else:
                σ = np.maximum(0, σ - s * rank_penalty)
            Σ = np.diag(σ)
            return params.at[1:, :].set(Z_const + U @ Σ @ Vt)

        if not hard:
            prox_rank = jit(prox_rank)

    # optimizer
    if trend_penalty and rank_penalty:
        optimizer = opt.ThreeOpProxGrad(
            g, jit(grad(g)), h_trend, prox_trend, h_rank, prox_rank,
            verbose=verbose, **line_search_kwargs,
        )
    else:
        if trend_penalty:
            h = h_trend
            prox = prox_trend
        elif rank_penalty:
            h = h_rank
            prox = prox_rank
        else:
            @jit
            def h(params):
                return 0

            @jit
            def prox(params, s):
                return params

        optimizer = opt.AccProxGrad(g, jit(grad(g)), h, prox,
                                    verbose=verbose, **line_search_kwargs)

    # initial point (note initial row is for misid rates)
    # ---------------------------------------------------
    params = np.zeros((self.μ.m + 1, self.mutation_types.size - 1))
    # misid rate for each mutation type
    if self.r_vector is not None:
        r = self.r_vector
    else:
        r = self.r * np.ones(self.mutation_types.size)
    params = params.at[0, :].set(
        cmp.ilr(misid_weights * r, basis, check_orth))
    # ilr transformed MuSH
    ilr_mush = cmp.ilr(self.μ.Z, basis, check_orth)
    # make sure it's a column vector if only 2 mutation types
    if ilr_mush.ndim == 1:
        ilr_mush = ilr_mush[:, np.newaxis]
    params = params.at[1:, :].set(ilr_mush)
    # ---------------------------------------------------

    # run optimizer
    params = optimizer.run(params, tol=tol, max_iter=max_iter)

    # update attributes
    self.r_vector = self.r * cmp.ilr_inv(params[0, :],
                                         basis) / misid_weights
    self.μ = hst.mu(
        self.η.change_points,
        self.mu0 * cmp.ilr_inv(params[1:, :], basis),
        mutation_types=self.mutation_types.values,
    )
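
# Singular value thresholding, used above as the prox of the soft rank
# (nuclear norm) penalty: shrink the singular values of the deviation from a
# reference matrix and reassemble, biasing the solution toward low rank. A
# standalone sketch (numpy only):
def _demo_svt(A, thresh):
    import numpy as np

    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    s = np.maximum(0, s - thresh)  # soft-threshold the spectrum
    return U @ np.diag(s) @ Vt     # low-rank-biased reconstruction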
def infer_eta(
    self,
    mu0: np.float64,
    *trend_penalty: Tuple[int, np.float64],
    ridge_penalty: np.float64 = 0,
    folded: bool = False,
    pts: np.float64 = 100,
    ta: np.float64 = None,
    log_transform: bool = True,
    eta: hst.eta = None,
    eta_ref: hst.eta = None,
    loss: str = "prf",
    max_iter: int = 100,
    tol: np.float64 = 0,
    line_search_kwargs: Dict = {},
    trend_kwargs: Dict = {},
    verbose: bool = False,
) -> None:
    r"""Infer demographic history :math:`\eta(t)`

    Args:
        mu0: total mutation rate (per genome per generation)
        trend_penalty: tuple ``(k, λ)`` for :math:`k`-th order trend
            penalty (can pass multiple for mixed trends)
        ridge_penalty: ridge penalty
        folded: if ``False``, infer :math:`\eta(t)` using the unfolded
            SFS. If ``True``, can only be used with ``infer_mu=False``,
            and infer :math:`\eta(t)` using the folded SFS.
        pts: number of points for time discretization
        ta: time (in WF generations ago) of oldest change point in time
            discretization. If ``None``, set automatically based on
            10 * E[TMRCA] under the MLE constant demography
        log_transform: fit :math:`\log\eta(t)`
        eta: initial demographic history. By default, a constant MLE is
            computed
        eta_ref: reference demographic history for ridge penalty. If
            ``None``, the constant MLE is used
        loss: loss function name from :mod:`~mushi.loss_functions` module
        max_iter: maximum number of optimization steps
        tol: relative tolerance in objective function (if ``0``, not used)
        line_search_kwargs: line search keyword arguments, see
            :py:meth:`mushi.optimization.LineSearcher`
        trend_kwargs: keyword arguments for trend filtering, see
            :py:meth:`mushi.optimization.TrendFilter.run`
        verbose: print verbose messages if ``True``

    Examples:

        Suppose ``ksfs`` is a ``kSFS`` object. The following fits a
        demographic history with 0-th order (piecewise constant) trend
        penalization of strength 100, assuming a mutation rate of 10
        mutations per genome per generation.

        >>> mu0 = 10
        >>> ksfs.infer_eta(mu0, (0, 1e2))

        Alternatively, a mixed trend solution, with constant and cubic
        pieces, is fit with

        >>> ksfs.infer_eta(mu0, (0, 1e2), (3, 1e1))

        The attribute ``ksfs.eta`` is now set and accessible for plotting
        (see :class:`~mushi.eta`).
    """
    if self.X is None:
        raise TypeError("use simulate() to generate data first")

    # total SFS
    if self.X.ndim == 1:
        x = self.X
    else:
        x = self.X.sum(1)
    # fold the spectrum if inference is on folded SFS
    if folded:
        x = utils.fold(x)

    # constant MLE
    # Harmonic number
    H = (1 / np.arange(1, self.n - 1)).sum()
    N_const = self.X.sum() / 2 / H / mu0
    if ta is None:
        tmrca_exp = 4 * N_const * (1 - 1 / self.n)
        ta = 10 * tmrca_exp
    change_points = np.logspace(0, np.log10(ta), pts)

    # initialize with MLE constant η
    if eta is not None:
        self.set_eta(eta)
    elif self.η is None:
        y = N_const * np.ones(change_points.size + 1)
        self.η = hst.eta(change_points, y)
    t = self.η.arrays()[0]
    self.mu0 = mu0

    # badness of fit
    loss = getattr(loss_functions, loss)

    # Accelerated proximal gradient method: our objective function
    # decomposes as f = g + h, where g is differentiable and h is not.
    # https://people.eecs.berkeley.edu/~elghaoui/Teaching/EE227A/lecture18.pdf

    # rescale trend penalties to be comparable between orders and time
    # grids, and filter out zero penalties
    trend_penalty = tuple((k, (self.η.m ** k / onp.math.factorial(k)) * λ)
                          for k, λ in trend_penalty if λ > 0)

    # Tikhonov matrix
    if eta_ref is None:
        eta_ref = self.η
        Γ = np.diag(np.ones_like(eta_ref.y))
    else:
        self.η.check_grid(eta_ref)
        # - log(1 - CDF)
        Γ = np.diag(-np.log(utils.tmrca_sf(t, eta_ref.y, self.n))[:-1])
    y_ref = np.log(eta_ref.y) if log_transform else eta_ref.y

    # In the following, the parameter vector params will contain the misid
    # rate in params[0], and y in params[1:]

    @jit
    def g(params):
        """differentiable piece of objective in η problem."""
        y = params[1:]
        if log_transform:
            M = utils.M(self.n, t, np.exp(y))
        else:
            M = utils.M(self.n, t, y)
        L = self.C @ M
        ξ = self.mu0 * L.sum(1)
        if folded:
            ξ = utils.fold(ξ)
        else:
            r = expit(params[0])
            ξ = (1 - r) * ξ + r * self.AM_freq @ ξ
        loss_term = loss(np.squeeze(ξ), x)
        y_delta = y - y_ref
        ridge_term = (ridge_penalty / 2) * (y_delta.T @ Γ @ y_delta)
        return loss_term + ridge_term

    @jit
    def h(params):
        """nondifferentiable piece of objective in η problem."""
        return sum(λ * np.linalg.norm(np.diff(params[1:], k + 1), 1)
                   for k, λ in trend_penalty)

    def prox(params, s):
        """trend filtering prox operator (no jit due to ptv module)"""
        if trend_penalty:
            k, sλ = zip(*((k, s * λ) for k, λ in trend_penalty))
            trend_filterer = opt.TrendFilter(k, sλ)
            params = params.at[1:].set(
                trend_filterer.run(params[1:], **trend_kwargs))
        if log_transform:
            return params
        # else: clip to minimum population size of 1
        return np.clip(params, 1)

    # optimizer
    optimizer = opt.AccProxGrad(g, jit(grad(g)), h, prox,
                                verbose=verbose, **line_search_kwargs)

    # initial point
    params = np.concatenate(
        (np.array([logit(1e-2)]),
         np.log(self.η.y) if log_transform else self.η.y))

    # run optimization
    params = optimizer.run(params, tol=tol, max_iter=max_iter)

    if not folded:
        self.r = expit(params[0])
    y = np.exp(params[1:]) if log_transform else params[1:]
    self.η = hst.eta(self.η.change_points, y)
    self.M = utils.M(self.n, t, y)
    self.L = self.C @ self.M
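
# The k-th order trend penalty used above is an L1 norm on (k+1)-th
# differences, λ‖Δ^{k+1} y‖₁, which promotes piecewise-degree-k solutions
# (k = 0 recovers total variation, i.e. piecewise-constant fits). A
# standalone sketch of evaluating the penalty (numpy only):
def _demo_trend_penalty(y, k, lam):
    import numpy as np

    return lam * np.abs(np.diff(y, k + 1)).sum()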
def simulate(
    self,
    eta: hst.eta,
    mu: Union[hst.mu, np.float64],
    r: np.float64 = 0,
    seed: int = None,
) -> None:
    r"""Simulate a :math:`k`-SFS under the Poisson random field model
    (no linkage), and assign it to the ``X`` attribute

    Args:
        eta: demographic history
        mu: mutation spectrum history (or constant mutation rate)
        r: ancestral state misidentification rate (default 0)
        seed: random seed

    Examples:

        Define sample size:

        >>> ksfs = mushi.kSFS(n=10)

        Define demographic history and MuSH:

        >>> eta = mushi.eta(np.array([1, 100, 10000]),
        ...                 np.array([1e4, 1e4, 1e2, 1e4]))
        >>> mush = mushi.mu(eta.change_points, np.ones((4, 4)),
        ...                 ['AAA>ACA', 'ACA>AAA', 'TCC>TTC', 'GAA>GGA'])

        Define ancestral misidentification rate:

        >>> r = 0.03

        Set random seed:

        >>> seed = 0

        Run simulation and print simulated :math:`k`-SFS:

        >>> ksfs.simulate(eta, mush, r, seed)
        >>> ksfs.as_df()  # doctest: +NORMALIZE_WHITESPACE
        mutation type     AAA>ACA  ACA>AAA  TCC>TTC  GAA>GGA
        sample frequency
        1                    1118     1123     1106     1108
        2                     147      128      120       98
        3                      65       55       66       60
        4                      49       52       64       46
        5                      44       43       34       36
        6                      35       28       36       33
        7                      23       32       24       35
        8                      34       32       24       24
        9                      52       41       57       56
    """
    onp.random.seed(seed)
    t, y = eta.arrays()
    M = utils.M(self.n, t, y)
    L = self.C @ M
    if type(mu) == hst.mu:
        eta.check_grid(mu)
        Ξ = L @ mu.Z
        self.mutation_types = mu.mutation_types
        self.AM_mut = utils.mutype_misid(self.mutation_types)
        self.X = np.array(
            poisson.rvs((1 - r) * Ξ + r * self.AM_freq @ Ξ @ self.AM_mut))
    else:
        ξ = mu * L.sum(1)
        self.X = np.array(poisson.rvs((1 - r) * ξ
                                      + r * self.AM_freq @ ξ))
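
# Ancestral state misidentification, as modeled above: a fraction r of
# variants have ancestral and derived alleles swapped, which reflects a
# variant at sample frequency i to frequency n - i. A standalone sketch for
# the total SFS, assuming AM_freq acts as the frequency-reversal operator
# (an assumption for illustration; numpy only):
def _demo_misid_total_sfs(xi, r):
    import numpy as np

    xi = np.asarray(xi)
    # xi[::-1] plays the role of AM_freq @ xi: entry i maps to entry n - i
    return (1 - r) * xi + r * xi[::-1]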