Example #1
0
 def g(params):
     """Smooth (differentiable) part of the η objective.

     ``params[1:]`` holds the history values (log-scale when
     ``log_transform`` is set). When the SFS is not folded, ``params[0]`` is
     mapped through ``expit`` to a misidentification rate that mixes in
     ``self.AM_freq @ ξ``. Returns the data-fit loss plus a generalized
     Tikhonov (ridge) penalty about ``y_ref`` weighted by ``Γ``.
     """
     y = params[1:]
     sizes = np.exp(y) if log_transform else y
     lin_map = self.C @ utils.M(self.n, t, sizes)
     ξ = self.mu0 * lin_map.sum(1)
     if not folded:
         misid = expit(params[0])
         ξ = (1 - misid) * ξ + misid * self.AM_freq @ ξ
     else:
         ξ = utils.fold(ξ)
     fit = loss(np.squeeze(ξ), x)
     delta = y - y_ref
     penalty = (ridge_penalty / 2) * (delta.T @ Γ @ delta)
     return fit + penalty
Example #2
0
    def set_eta(self, eta: hst.eta):
        r"""Fix the demographic history :math:`\eta(t)` to a given object.

        Besides storing ``eta``, this recomputes the cached ``M`` and ``L``
        matrices on the history's time grid and resets the ancestral
        misidentification rate ``r`` to ``0``.

        Args:
            eta: demographic history object
        """
        self.η = eta
        time_grid, _ = self.η.arrays()
        self.M = utils.M(self.n, time_grid, self.η.y)
        self.L = self.C @ self.M
        # any previously fitted misidentification rate no longer applies
        self.r = 0
Example #3
0
 def g(logy):
     """Smooth part of the η objective, parameterized by log η.

     Sum of three terms: the data-fit loss (restricted to ``mask`` rows when
     a mask is given), a squared first-difference penalty weighted by
     ``α_spline``, and a generalized Tikhonov penalty about ``logy_ref``
     weighted by ``α_ridge`` and ``Γ``.
     """
     op = self.C @ utils.M(self.n, t, np.exp(logy))
     if mask is None:
         fit = loss(z, x, op)
     else:
         fit = loss(z, x[mask, :], op[mask, :])
     smooth = (α_spline / 2) * ((D1 @ logy)**2).sum()
     delta = logy - logy_ref
     ridge = (α_ridge / 2) * (delta.T @ Γ @ delta)
     return fit + smooth + ridge
Example #4
0
    def test_constant_history(self):
        """Compare the expected SFS for a constant demography and mutation
        rate with the analytic formula of Fu (1995).
        """
        sample_size = 198
        η0, μ0 = 3e4, 40
        no_changes = np.array([])
        η = histories.eta(no_changes, np.array([η0]))
        t, y = η.arrays()
        μ = histories.mu(no_changes, np.array([[μ0]]))

        C = utils.C(sample_size)
        M = utils.M(sample_size, t, y)
        ξ_mushi = np.squeeze(C @ M @ μ.Z)
        # analytic expectation: 2 η0 μ0 / i for i = 1, ..., n - 1
        ξ_Fu = 2 * η0 * μ0 / np.arange(1, sample_size)

        self.assertTrue(np.allclose(ξ_mushi, ξ_Fu),
                        msg=f'\nξ_mushi:\n{ξ_mushi}\nξ_Fu:\n{ξ_Fu}')
Example #5
0
    def simulate(self,
                 eta: hst.eta,
                 mu: Union[hst.mu, np.float64],
                 seed: int = None) -> None:
        r"""Simulate a SFS under the Poisson random field model (no linkage)
        assigns simulated SFS to ``X`` attribute

        Args:
            eta: demographic history
            mu: mutation spectrum history (or constant rate)
            seed: random seed (used to seed NumPy's global RNG)

        Raises:
            ValueError: if ``mu`` is a mutation history whose time grid does
                        not match that of ``eta``
        """
        # Seeds NumPy's global RNG. poisson.rvs is presumably scipy.stats,
        # which draws from that global state when no random_state is given.
        onp.random.seed(seed)
        t, y = eta.arrays()
        M = utils.M(self.n, t, y)
        L = self.C @ M
        if isinstance(mu, hst.mu):
            if not eta.check_grid(mu):
                raise ValueError('η(t) and μ(t) must use the same time grid')
        else:
            # promote a constant rate to a history on η's change points
            mu = hst.mu(eta.change_points, mu * np.ones_like(y))
        self.X = poisson.rvs(L @ mu.Z)
        self.mutation_types = mu.mutation_types
Example #6
0
    def infer_history(
            self,  # noqa: C901
            change_points: np.array,
            mu0: np.float64,
            eta: hst.eta = None,
            eta_ref: hst.eta = None,
            mu_ref: hst.mu = None,
            infer_eta: bool = True,
            infer_mu: bool = True,
            alpha_tv: np.float64 = 0,
            alpha_spline: np.float64 = 0,
            alpha_ridge: np.float64 = 0,
            beta_tv: np.float64 = 0,
            beta_spline: np.float64 = 0,
            beta_rank: np.float64 = 0,
            hard: bool = False,
            beta_ridge: np.float64 = 0,
            max_iter: int = 1000,
            s0: int = 1,
            max_line_iter=100,
            gamma: np.float64 = 0.8,
            tol: np.float64 = 0,
            loss: str = 'prf',
            mask: np.array = None,
            verbose: bool = False) -> None:
        r"""Perform sequential inference to fit :math:`\eta(t)` and
        :math:`\mu(t)`

        :math:`\eta(t)` is fit first (if ``infer_eta``). :math:`\mu(t)` is then
        fit conditioned on :math:`\eta(t)`, but only if ``infer_mu`` is set
        and there is more than one mutation type. Results are stored in
        the ``η``, ``M``, ``L`` and ``μ`` attributes.

        Args:
            change_points: epoch change points (ordered times > 0)
            mu0: total mutation rate (per genome per generation)
            eta: initial demographic history. By default, a
                 constant MLE is computed
            eta_ref: reference demographic history for ridge penalty. If
                     ``None``, the constant MLE is used
            mu_ref: reference MuSH for ridge penalty. If None, the constant
                    MLE is used
            infer_eta: perform :math:`\eta` inference if ``True``
            infer_mu: perform :math:`\mu` inference if ``True``
            loss: loss function, 'prf' for Poisson random field, 'kl' for
                  Kullback-Leibler divergence, 'lsq' for least-squares
            mask: array of bools, with False indicating exclusion of that
                  frequency
            alpha_tv: total variation penalty on :math:`\eta(t)`
            alpha_spline: L2 on first differences penalty on :math:`\eta(t)`
            alpha_ridge: L2 for strong convexity penalty on :math:`\eta(t)`
            hard: hard rank penalty on :math:`\mu(t)` (non-convex)
            beta_tv: total variation penalty on :math:`\mu(t)`
            beta_spline: penalty on :math:`\mu(t)`
            beta_rank: rank penalty on :math:`\mu(t)`
            beta_ridge: L2 penalty on :math:`\mu(t)`
            max_iter: maximum number of proximal gradient steps
            tol: relative tolerance in objective function (if ``0``, not used)
            s0: max step size
            max_line_iter: maximum number of line search steps
            gamma: step size shrinkage rate for line search
            verbose: print verbose messages if ``True``

        Raises:
            ValueError: if ``X`` is ``None`` (no data yet), if ``mask`` does
                        not have one entry per row of ``X``, or if ``loss``
                        is not recognized
        """

        # pithify reg parameter names
        α_tv = alpha_tv
        α_spline = alpha_spline
        α_ridge = alpha_ridge

        β_tv = beta_tv
        β_spline = beta_spline
        β_rank = beta_rank
        β_ridge = beta_ridge

        # explicit raises rather than asserts: asserts are stripped under -O,
        # and the former assert made this ValueError unreachable
        if self.X is None:
            raise ValueError('use simulate() to generate data first')
        if mask is not None and len(mask) != self.X.shape[0]:
            raise ValueError('mask must have n-1 elements')

        # initialize with MLE constant η and μ
        x = self.X.sum(1, keepdims=True)
        μ_total = hst.mu(change_points,
                         mu0 * np.ones((len(change_points) + 1, 1)))
        t, z = μ_total.arrays()
        # number of segregating variants in each mutation type
        S = self.X.sum(0, keepdims=True)

        if eta is not None:
            self.η = eta
        elif self.η is None:
            # Harmonic number H_{n-1}: the SFS has n-1 sample frequency
            # classes, so the sum runs over i = 1, ..., n-1
            H = (1 / np.arange(1, self.n)).sum()
            # constant MLE
            y = (S.sum() / 2 / H / mu0) * np.ones(len(z))
            self.η = hst.eta(change_points, y)

        μ_const = hst.mu(self.η.change_points,
                         mu0 * (S / S.sum()) * np.ones(
                             (self.η.m, self.X.shape[1])),
                         mutation_types=self.mutation_types.values)

        if self.μ is None:
            self.μ = μ_const
        self.M = utils.M(self.n, t, self.η.y)
        self.L = self.C @ self.M

        # badness of fit
        if loss == 'prf':

            def loss(*args, **kwargs):
                return -utils.prf(*args, **kwargs)
        elif loss == 'kl':
            loss = utils.d_kl
        elif loss == 'lsq':
            loss = utils.lsq
        else:
            raise ValueError(f'unrecognized loss argument {loss}')

        # some matrices we'll need for the first difference penalties
        D = (np.eye(self.η.m, k=0) - np.eye(self.η.m, k=-1))
        # W matrix deals with boundary condition (zeroes the first row)
        W = np.eye(self.η.m)
        W = index_update(W, index[0, 0], 0)
        D1 = W @ D  # 1st difference matrix
        # D2 = D1.T @ D1  # 2nd difference matrix

        if infer_eta:
            if verbose:
                print('inferring η(t)', flush=True)

            # Accelerated proximal gradient method: our objective function
            # decomposes as f = g + h, where g is differentiable and h is not.
            # https://people.eecs.berkeley.edu/~elghaoui/Teaching/EE227A/lecture18.pdf

            # Tikhonov matrix
            if eta_ref is None:
                eta_ref = self.η
                Γ = np.diag(np.ones_like(eta_ref.y))
            else:
                # - log(1 - CDF)
                Γ = np.diag(-np.log(utils.tmrca_sf(t, eta_ref.y, self.n))[:-1])
            logy_ref = np.log(eta_ref.y)

            @jit
            def g(logy):
                """differentiable piece of objective in η problem"""
                L = self.C @ utils.M(self.n, t, np.exp(logy))
                if mask is not None:
                    loss_term = loss(z, x[mask, :], L[mask, :])
                else:
                    loss_term = loss(z, x, L)
                spline_term = (α_spline / 2) * ((D1 @ logy)**2).sum()
                # generalized Tikhonov
                logy_delta = logy - logy_ref
                ridge_term = (α_ridge / 2) * (logy_delta.T @ Γ @ logy_delta)
                return loss_term + spline_term + ridge_term

            if α_tv > 0:

                @jit
                def h(logy):
                    """nondifferentiable piece of objective in η problem"""
                    return α_tv * np.abs(D1 @ logy).sum()

                def prox(logy, s):
                    """total variation prox operator"""
                    return ptv.tv1_1d(logy, s * α_tv)
            else:

                @jit
                def h(logy):
                    return 0

                def prox(logy, s):
                    return logy

            # initial iterate
            logy = np.log(self.η.y)

            logy = opt.acc_prox_grad_method(logy,
                                            g,
                                            jit(grad(g)),
                                            h,
                                            prox,
                                            tol=tol,
                                            max_iter=max_iter,
                                            s0=s0,
                                            max_line_iter=max_line_iter,
                                            gamma=gamma,
                                            verbose=verbose)

            y = np.exp(logy)

            self.η = hst.eta(self.η.change_points, y)
            self.M = utils.M(self.n, t, y)
            self.L = self.C @ self.M

        if infer_mu and len(self.mutation_types) > 1:
            if verbose:
                print('inferring μ(t) conditioned on η(t)', flush=True)

            if mu_ref is None:
                mu_ref = μ_const
                # Tikhonov matrix
                Γ = np.diag(np.ones_like(self.η.y))
            else:
                # - log(1 - CDF)
                Γ = np.diag(-np.log(utils.tmrca_sf(t, self.η.y, self.n))[:-1])

            # orthonormal basis for Aitchison simplex
            # NOTE: instead of Gram-Schmidt could try SVD of clr transformed X
            #       https://en.wikipedia.org/wiki/Compositional_data#Isometric_logratio_transform
            basis = cmp._gram_schmidt_basis(self.μ.Z.shape[1])
            # initial iterate in inverse log-ratio transform
            Z = cmp.ilr(self.μ.Z, basis)
            Z_const = cmp.ilr(μ_const.Z, basis)
            Z_ref = cmp.ilr(mu_ref.Z, basis)

            @jit
            def g(Z):
                """differentiable piece of objective in μ problem"""
                if mask is not None:
                    loss_term = loss(mu0 * cmp.ilr_inv(Z, basis),
                                     self.X[mask, :], self.L[mask, :])
                else:
                    loss_term = loss(mu0 * cmp.ilr_inv(Z, basis), self.X,
                                     self.L)
                spline_term = (β_spline / 2) * ((D1 @ Z)**2).sum()
                # generalized Tikhonov
                Z_delta = Z - Z_ref
                ridge_term = (β_ridge / 2) * np.trace(Z_delta.T @ Γ @ Z_delta)
                return loss_term + spline_term + ridge_term

            if β_tv and β_rank:

                @jit
                def h1(Z):
                    """1st nondifferentiable piece of objective in μ problem"""
                    return β_tv * np.abs(D1 @ Z).sum()

                # per-element TV weights; the zeroed last column is meant to
                # decouple adjacent rows of Z.T in the weighted TV prox
                shape = Z.T.shape
                w = β_tv * onp.ones(shape)
                w[:, -1] = 0
                w = w.flatten()[:-1]

                def prox1(Z, s):
                    """total variation prox operator on row dimension
                    """
                    return ptv.tv1w_1d(Z.T, s * w).reshape(shape).T

                @jit
                def h2(Z):
                    """2nd nondifferentiable piece of objective in μ problem"""
                    σ = np.linalg.svd(Z - Z_const, compute_uv=False)
                    return β_rank * np.linalg.norm(σ, 0 if hard else 1)

                def prox2(Z, s):
                    """singular value thresholding"""
                    U, σ, Vt = np.linalg.svd(Z - Z_const, full_matrices=False)
                    if hard:
                        σ = index_update(σ, index[σ <= s * β_rank], 0)
                    else:
                        σ = np.maximum(0, σ - s * β_rank)
                    Σ = np.diag(σ)
                    return Z_const + U @ Σ @ Vt

                Z = opt.three_op_prox_grad_method(Z,
                                                  g,
                                                  jit(grad(g)),
                                                  h1,
                                                  prox1,
                                                  h2,
                                                  prox2,
                                                  tol=tol,
                                                  max_iter=max_iter,
                                                  s0=s0,
                                                  max_line_iter=max_line_iter,
                                                  gamma=gamma,
                                                  ls_tol=0,
                                                  verbose=verbose)

            else:
                if β_tv:

                    @jit
                    def h(Z):
                        """nondifferentiable piece of objective in μ problem"""
                        return β_tv * np.abs(D1 @ Z).sum()

                    shape = Z.T.shape
                    w = β_tv * onp.ones(shape)
                    w[:, -1] = 0
                    w = w.flatten()[:-1]

                    def prox(Z, s):
                        """total variation prox operator on row dimension
                        """
                        return ptv.tv1w_1d(Z.T, s * w).reshape(shape).T

                elif β_rank:

                    @jit
                    def h(Z):
                        """nondifferentiable piece of objective in μ problem"""
                        σ = np.linalg.svd(Z - Z_const, compute_uv=False)
                        return β_rank * np.linalg.norm(σ, 0 if hard else 1)

                    def prox(Z, s):
                        """singular value thresholding"""
                        U, σ, Vt = np.linalg.svd(Z - Z_const,
                                                 full_matrices=False)
                        if hard:
                            σ = index_update(σ, index[σ <= s * β_rank], 0)
                        else:
                            σ = np.maximum(0, σ - s * β_rank)
                        Σ = np.diag(σ)
                        return Z_const + U @ Σ @ Vt
                else:

                    @jit
                    def h(Z):
                        return 0

                    @jit
                    def prox(Z, s):
                        return Z

                Z = opt.acc_prox_grad_method(Z,
                                             g,
                                             jit(grad(g)),
                                             h,
                                             prox,
                                             tol=tol,
                                             max_iter=max_iter,
                                             s0=s0,
                                             max_line_iter=max_line_iter,
                                             gamma=gamma,
                                             verbose=verbose)

            self.μ = hst.mu(self.η.change_points,
                            mu0 * cmp.ilr_inv(Z, basis),
                            mutation_types=self.mutation_types.values)
Example #7
0
    def infer_mush(
        self,
        *trend_penalty: Tuple[int, np.float64],
        ridge_penalty: np.float64 = 0,
        rank_penalty: np.float64 = 0,
        hard: bool = False,
        mu_ref: hst.mu = None,
        misid_penalty: np.float64 = 1e-4,
        loss: str = "prf",
        max_iter: int = 100,
        tol: np.float64 = 0,
        line_search_kwargs: Dict = {},
        trend_kwargs: Dict = {},
        verbose: bool = False,
    ) -> None:
        r"""Infer mutation spectrum history :math:`\mu(t)`

        Args:
            trend_penalty: tuple ``(k, λ)`` for :math:`k`-th order trend
                           penalty (can pass multiple for mixed trends)
            ridge_penalty: ridge penalty
            rank_penalty: rank penalty
            hard: hard rank penalty (non-convex)
            mu_ref: reference MuSH for ridge penalty. If None, the constant
                    MLE is used
            misid_penalty: ridge parameter to shrink misid rates to aggregate
                           rate
            loss: loss function from :mod:`~mushi.loss_functions` module
            max_iter: maximum number of optimization steps
            tol: relative tolerance in objective function (if ``0``, not used)
            line_search_kwargs: line search keyword arguments,
                                see :py:class:`mushi.optimization.LineSearcher`
            trend_kwargs: keyword arguments for trend filtering,
                          see :py:meth:`mushi.optimization.TrendFilter.run`
            verbose: print verbose messages if ``True``

        Raises:
            TypeError: if no data is present (``X`` is ``None``)
            ValueError: if the misidentification rate ``r`` is ``None`` or
                        ``0``, if ``mutation_types`` is ``None``, or if
                        ``rank_penalty`` is used with fewer than 3 mutation
                        types

        Examples:
            Suppose ``ksfs`` is a ``kSFS`` object, and the demography has
            already been fit with ``infer_eta``. The following fits a
            mutation spectrum history with 0-th order (piecewise constant) trend
            penalization of strength 100.

            >>> ksfs.infer_mush((0, 1e2))

            Alternatively, a mixed trend solution, with constant and cubic
            pieces, and with rank penalization 100, is fit with

            >>> ksfs.infer_mush((0, 1e2), (3, 1e1), rank_penalty=1e2)

            The attribute ``ksfs.mu`` is now set and accessible for plotting
            (see :class:`~mushi.mu`).
        """
        # stdlib math: numpy.math (an alias of it) was removed in NumPy 2.0
        import math

        if self.X is None:
            raise TypeError("use simulate() to generate data first")
        self._check_eta()
        if self.r is None or self.r == 0:
            raise ValueError("ancestral misidentification rate has not been "
                             "inferred, possibly due to folded SFS inference")
        if self.mutation_types is None:
            raise ValueError("k-SFS must contain multiple mutation types")

        # number of segregating variants in each mutation type
        S = self.X.sum(0, keepdims=True)
        # initialize with MLE constant μ
        μ_const = hst.mu(
            self.η.change_points,
            self.mu0 * (S / S.sum()) * np.ones((self.η.m, self.X.shape[1])),
            mutation_types=self.mutation_types.values,
        )
        if self.μ is None:
            self.μ = μ_const
        t = μ_const.arrays()[0]
        self.M = utils.M(self.n, t, self.η.y)
        self.L = self.C @ self.M

        # badness of fit
        loss = getattr(loss_functions, loss)

        # rescale trend penalties to be comparable between orders and time grids
        # filter zeros from trend penalties
        trend_penalty = tuple((k, (self.μ.m**k / math.factorial(k)) * λ)
                              for k, λ in trend_penalty if λ > 0)

        if mu_ref is None:
            mu_ref = μ_const
            # Tikhonov matrix
            Γ = np.diag(np.ones_like(self.η.y))
        else:
            # NOTE(review): return value is ignored; confirm check_grid raises
            # on a mismatched grid
            self.μ.check_grid(mu_ref)
            # - log(1 - CDF)
            Γ = np.diag(-np.log(utils.tmrca_sf(t, self.η.y, self.n))[:-1])

        # orthonormal basis for Aitchison simplex
        # NOTE: instead of Gram-Schmidt could try SVD of clr transformed X
        #       https://en.wikipedia.org/wiki/Compositional_data#Isometric_logratio_transform
        basis = cmp._gram_schmidt_basis(self.μ.Z.shape[1])
        check_orth = self.μ.Z.shape[1] > 2

        # constant MLE and reference mush
        Z_const = cmp.ilr(μ_const.Z, basis, check_orth)
        Z_ref = cmp.ilr(mu_ref.Z, basis, check_orth)

        # weights for relating misid rates to aggregate misid rate from eta step
        misid_weights = self.X.sum(0) / self.X.sum()
        # reference composition for weighted misid (if all rates are equal)
        misid_ref = cmp.ilr(misid_weights, basis, check_orth)

        # In the following, params will hold the weighted misid composition in
        # the first row and the mush composition at each time in the remaining rows

        @jit
        def g(params):
            """differentiable piece of objective in μ problem."""
            r = self.r * cmp.ilr_inv(params[0, :], basis) / misid_weights
            Z = params[1:, :]
            Ξ = self.L @ (self.mu0 * cmp.ilr_inv(Z, basis))
            Ξ = Ξ * (1 - r) + self.AM_freq @ Ξ @ (self.AM_mut *
                                                  r[:, np.newaxis])
            loss_term = loss(Ξ, self.X)
            Z_delta = Z - Z_ref
            ridge_term = (ridge_penalty / 2) * np.sum(Z_delta * (Γ @ Z_delta))
            misid_delta = params[0, :] - misid_ref
            misid_ridge_term = misid_penalty * np.sum(misid_delta**2)
            return loss_term + ridge_term + misid_ridge_term

        if trend_penalty:

            @jit
            def h_trend(params):
                """trend filtering penalty."""
                return sum(
                    λ *
                    np.linalg.norm(np.diff(params[1:, :], k + 1, axis=0), 1)
                    for k, λ in trend_penalty)

            def prox_trend(params, s):
                """trend filtering prox operator (no jit due to ptv module)"""
                k, sλ = zip(*((k, s * λ) for k, λ in trend_penalty))
                trend_filterer = opt.TrendFilter(k, sλ)
                return params.at[1:, :].set(
                    trend_filterer.run(params[1:, :], **trend_kwargs))

        if rank_penalty:
            if self.mutation_types.size < 3:
                raise ValueError(
                    "kSFS must have more than 2 mutation types for"
                    " rank penalization")

            @jit
            def h_rank(params):
                """2nd nondifferentiable piece of objective in μ problem."""
                if hard:
                    return rank_penalty * np.linalg.matrix_rank(params[1:, :] -
                                                                Z_const)
                # else:
                return rank_penalty * np.linalg.norm(params[1:, :] - Z_const,
                                                     "nuc")

            def prox_rank(params, s):
                """singular value thresholding."""
                U, σ, Vt = np.linalg.svd(params[1:, :] - Z_const,
                                         full_matrices=False)
                if hard:
                    σ = σ.at[σ <= s * rank_penalty].set(0)
                else:
                    σ = np.maximum(0, σ - s * rank_penalty)
                Σ = np.diag(σ)
                return params.at[1:, :].set(Z_const + U @ Σ @ Vt)

            # only the soft-threshold version is jitted; the hard version uses
            # a boolean-mask update (presumably not jit-traceable)
            if not hard:
                prox_rank = jit(prox_rank)

        # optimizer
        if trend_penalty and rank_penalty:
            optimizer = opt.ThreeOpProxGrad(
                g,
                jit(grad(g)),
                h_trend,
                prox_trend,
                h_rank,
                prox_rank,
                verbose=verbose,
                **line_search_kwargs,
            )
        else:
            if trend_penalty:
                h = h_trend
                prox = prox_trend
            elif rank_penalty:
                h = h_rank
                prox = prox_rank
            else:

                @jit
                def h(params):
                    return 0

                @jit
                def prox(params, s):
                    return params

            optimizer = opt.AccProxGrad(g,
                                        jit(grad(g)),
                                        h,
                                        prox,
                                        verbose=verbose,
                                        **line_search_kwargs)

        # initial point (note initial row is for misid rates)
        # ---------------------------------------------------
        params = np.zeros((self.μ.m + 1, self.mutation_types.size - 1))
        # misid rate for each mutation type
        if self.r_vector is not None:
            r = self.r_vector
        else:
            r = self.r * np.ones(self.mutation_types.size)
        params = params.at[0, :].set(
            cmp.ilr(misid_weights * r, basis, check_orth))
        # ilr transformed mush
        ilr_mush = cmp.ilr(self.μ.Z, basis, check_orth)
        # make sure it's a column vector if only 2 mutation types
        if ilr_mush.ndim == 1:
            ilr_mush = ilr_mush[:, np.newaxis]
        params = params.at[1:, :].set(ilr_mush)
        # ---------------------------------------------------

        # run optimizer
        params = optimizer.run(params, tol=tol, max_iter=max_iter)

        # update attributes
        self.r_vector = self.r * cmp.ilr_inv(params[0, :],
                                             basis) / misid_weights
        self.μ = hst.mu(
            self.η.change_points,
            self.mu0 * cmp.ilr_inv(params[1:, :], basis),
            mutation_types=self.mutation_types.values,
        )
Example #8
0
    def infer_eta(
        self,
        mu0: np.float64,
        *trend_penalty: Tuple[int, np.float64],
        ridge_penalty: np.float64 = 0,
        folded: bool = False,
        pts: int = 100,
        ta: np.float64 = None,
        log_transform: bool = True,
        eta: hst.eta = None,
        eta_ref: hst.eta = None,
        loss: str = "prf",
        max_iter: int = 100,
        tol: np.float64 = 0,
        line_search_kwargs: Dict = {},
        trend_kwargs: Dict = {},
        verbose: bool = False,
    ) -> None:
        r"""Infer demographic history :math:`\eta(t)`

        Args:
            mu0: total mutation rate (per genome per generation)
            trend_penalty: tuple ``(k, λ)`` for :math:`k`-th order trend
                           penalty (can pass multiple for mixed trends)
            ridge_penalty: ridge penalty
            folded: if ``False``, infer :math:`\eta(t)` using unfolded SFS. If
                    ``True``, infer :math:`\eta(t)` using folded SFS; the
                    ancestral misidentification rate ``r`` is then not
                    estimated
            pts: number of points for time discretization
            ta: time (in WF generations ago) of oldest change point in time
                discretization. If ``None``, set automatically based on
                10 * E[TMRCA] under MLE constant demography
            log_transform: fit :math:`\log\eta(t)`
            eta: initial demographic history. By default, a constant MLE is
                 computed
            eta_ref: reference demographic history for ridge penalty. If
                     ``None``, the constant MLE is used
            loss: loss function name from :mod:`~mushi.loss_functions` module
            max_iter: maximum number of optimization steps
            tol: relative tolerance in objective function (if ``0``, not used)
            line_search_kwargs: line search keyword arguments,
                                see :py:meth:`mushi.optimization.LineSearcher`
            trend_kwargs: keyword arguments for trend filtering,
                          see :py:meth:`mushi.optimization.TrendFilter.run`
            verbose: print verbose messages if ``True``

        Raises:
            TypeError: if no data is present (``X`` is ``None``)

        Examples:

            Suppose ``ksfs`` is a ``kSFS`` object. Then the following fits a
            demographic history with 0-th order (piecewise constant) trend
            penalization of strength 100, assuming a mutation rate of 1
            mutation per genome per generation.

            >>> mu0 = 1
            >>> ksfs.infer_eta(mu0, (0, 1e2))

            Alternatively, a mixed trend solution, with constant and cubic
            pieces, is fit with

            >>> ksfs.infer_eta(mu0, (0, 1e2), (3, 1e1))

            The attribute ``ksfs.eta`` is now set and accessible for plotting
            (see :class:`~mushi.eta`).
        """
        # stdlib math: numpy.math (an alias of it) was removed in NumPy 2.0
        import math

        if self.X is None:
            raise TypeError("use simulate() to generate data first")

        # total SFS
        if self.X.ndim == 1:
            x = self.X
        else:
            x = self.X.sum(1)
        # fold the spectrum if inference is on folded SFS
        if folded:
            x = utils.fold(x)

        # constant MLE
        # Harmonic number H_{n-1}: the SFS has n-1 sample frequency classes,
        # so the sum runs over i = 1, ..., n-1
        H = (1 / np.arange(1, self.n)).sum()
        N_const = self.X.sum() / 2 / H / mu0

        if ta is None:
            tmrca_exp = 4 * N_const * (1 - 1 / self.n)
            ta = 10 * tmrca_exp
        change_points = np.logspace(0, np.log10(ta), pts)

        # initialize with MLE constant η
        if eta is not None:
            self.set_eta(eta)
        elif self.η is None:
            y = N_const * np.ones(change_points.size + 1)
            self.η = hst.eta(change_points, y)
        t = self.η.arrays()[0]

        self.mu0 = mu0

        # badness of fit
        loss = getattr(loss_functions, loss)

        # Accelerated proximal gradient method: our objective function
        # decomposes as f = g + h, where g is differentiable and h is not.
        # https://people.eecs.berkeley.edu/~elghaoui/Teaching/EE227A/lecture18.pdf

        # rescale trend penalties to be comparable between orders and time grids
        # filter zeros from trend penalties
        trend_penalty = tuple((k, (self.η.m**k / math.factorial(k)) * λ)
                              for k, λ in trend_penalty if λ > 0)

        # Tikhonov matrix
        if eta_ref is None:
            eta_ref = self.η
            Γ = np.diag(np.ones_like(eta_ref.y))
        else:
            self.η.check_grid(eta_ref)
            # - log(1 - CDF)
            Γ = np.diag(-np.log(utils.tmrca_sf(t, eta_ref.y, self.n))[:-1])
        y_ref = np.log(eta_ref.y) if log_transform else eta_ref.y

        # In the following, the parameter vector params will contain the
        # logit of the misid rate in params[0], and y in params[1:]
        @jit
        def g(params):
            """differentiable piece of objective in η problem."""
            y = params[1:]
            if log_transform:
                M = utils.M(self.n, t, np.exp(y))
            else:
                M = utils.M(self.n, t, y)
            L = self.C @ M
            ξ = self.mu0 * L.sum(1)
            if folded:
                ξ = utils.fold(ξ)
            else:
                r = expit(params[0])
                ξ = (1 - r) * ξ + r * self.AM_freq @ ξ
            loss_term = loss(np.squeeze(ξ), x)
            y_delta = y - y_ref
            ridge_term = (ridge_penalty / 2) * (y_delta.T @ Γ @ y_delta)
            return loss_term + ridge_term

        @jit
        def h(params):
            """nondifferentiable piece of objective in η problem."""
            return sum(λ * np.linalg.norm(np.diff(params[1:], k + 1), 1)
                       for k, λ in trend_penalty)

        def prox(params, s):
            """trend filtering prox operator (no jit due to ptv module)"""
            if trend_penalty:
                k, sλ = zip(*((k, s * λ) for k, λ in trend_penalty))
                trend_filterer = opt.TrendFilter(k, sλ)
                params = params.at[1:].set(
                    trend_filterer.run(params[1:], **trend_kwargs))
            if log_transform:
                return params
            # else:
            # clip to minimum population size of 1. Only params[1:] (the
            # history) is clipped: params[0] is the logit misid rate, and
            # clipping it to >= 1 would force r >= expit(1).
            return params.at[1:].set(np.clip(params[1:], 1))

        # optimizer
        optimizer = opt.AccProxGrad(g,
                                    jit(grad(g)),
                                    h,
                                    prox,
                                    verbose=verbose,
                                    **line_search_kwargs)
        # initial point
        params = np.concatenate(
            (np.array([logit(1e-2)]),
             np.log(self.η.y) if log_transform else self.η.y))
        # run optimization
        params = optimizer.run(params, tol=tol, max_iter=max_iter)

        if not folded:
            self.r = expit(params[0])
        y = np.exp(params[1:]) if log_transform else params[1:]
        self.η = hst.eta(self.η.change_points, y)
        self.M = utils.M(self.n, t, y)
        self.L = self.C @ self.M
Example #9
0
    def simulate(
        self,
        eta: hst.eta,
        mu: Union[hst.mu, np.float64],
        r: np.float64 = 0,
        seed: int = None,
    ) -> None:
        r"""Simulate a :math:`k`-SFS under the Poisson random field model
        (no linkage), assign to ``X`` attribute

        Args:
            eta: demographic history
            mu: mutation spectrum history (or constant mutation rate)
            r: ancestral state misidentification rate (default 0)
            seed: random seed

        Examples:

           Define sample size:

           >>> ksfs = mushi.kSFS(n=10)

           Define demographic history and MuSH:

           >>> eta = mushi.eta(np.array([1, 100, 10000]), np.array([1e4, 1e4, 1e2, 1e4]))
           >>> mush = mushi.mu(eta.change_points, np.ones((4, 4)),
           ...                 ['AAA>ACA', 'ACA>AAA', 'TCC>TTC', 'GAA>GGA'])

           Define ancestral misidentification rate:

           >>> r = 0.03

           Set random seed:

           >>> seed = 0

           Run simulation and print simulated :math:`k`-SFS

           >>> ksfs.simulate(eta, mush, r, seed)

           >>> ksfs.as_df() # doctest: +NORMALIZE_WHITESPACE
           mutation type     AAA>ACA  ACA>AAA  TCC>TTC  GAA>GGA
           sample frequency
           1                    1118     1123     1106     1108
           2                     147      128      120       98
           3                      65       55       66       60
           4                      49       52       64       46
           5                      44       43       34       36
           6                      35       28       36       33
           7                      23       32       24       35
           8                      34       32       24       24
           9                      52       41       57       56
        """
        # Seeds NumPy's global RNG. poisson.rvs is presumably scipy.stats,
        # which draws from that global state when no random_state is given.
        onp.random.seed(seed)
        t, y = eta.arrays()
        M = utils.M(self.n, t, y)
        L = self.C @ M
        if isinstance(mu, hst.mu):
            # NOTE(review): return value is ignored; confirm check_grid raises
            # on a mismatched grid
            eta.check_grid(mu)
            Ξ = L @ mu.Z
            self.mutation_types = mu.mutation_types
            self.AM_mut = utils.mutype_misid(self.mutation_types)
            # mix in a fraction r of misidentified spectrum: AM_freq acts on
            # the frequency (row) axis, AM_mut on the mutation-type axis
            self.X = np.array(
                poisson.rvs((1 - r) * Ξ + r * self.AM_freq @ Ξ @ self.AM_mut))
        else:
            # constant total rate: a single (1D) SFS
            ξ = mu * L.sum(1)
            self.X = np.array(poisson.rvs((1 - r) * ξ + r * self.AM_freq @ ξ))