예제 #1
    def get_claic_component(self, x0, all_boots, grid_sizes, eps):
        """Return the CLAIC penalty term ``trace(J * H^-1)`` at ``x0``.

        ``H`` is the negated finite-difference Hessian of the multinomial
        log-likelihood on ``self.data``. ``J`` is the mean outer product of
        the log-likelihood gradients, each evaluated on one bootstrapped
        spectrum. Both come from ``Godambe``, imported from dadi if it is
        available and otherwise from moments. Discrete model variables keep
        their ``x0`` values; derivatives are taken only with respect to the
        non-discrete parameters.

        :param x0: Parameter values at which the component is evaluated.
        :param all_boots: Bootstrapped data. Each item is wrapped with
            ``self.base_module.Spectrum``. Must support ``len``.
        :param grid_sizes: Forwarded to ``self.simulate``.
        :param eps: Step size for the finite-difference derivatives.

        :returns: ``trace(J * H^-1)``.
        :raises ImportError: If neither dadi nor moments is available.
        """
        if dadi_available:
            from dadi import Godambe
        elif moments_available:
            from moments import Godambe
        else:
            raise ImportError("For CLAIC evaluation either dadi or moments is"
                              " required.")

        var2val = self.model.var2value(x0)
        # Mask of the parameters that are differentiated. Discrete variables
        # are excluded and stay fixed at their x0 values.
        is_not_discrete = np.array(
            [not isinstance(var, DiscreteVariable) for var in var2val])
        if len(x0) > 0 and len(var2val) > 0:
            x0 = np.array(list(var2val.values()), dtype=object)
            p0 = x0[is_not_discrete].astype(float)
        else:
            p0 = np.asarray(x0, dtype=float)

        def simul_func(x):
            # Copy x0 so that the closed-over vector is never mutated, then
            # put the (continuous) values being differentiated back in place.
            p = np.array(x0)
            if len(p) > 0 and len(var2val) > 0:
                p[is_not_discrete] = x
            else:
                p = x
            return self.simulate(p, self.data.sample_sizes, None, None,
                                 grid_sizes)

        # Cache evaluations of the frequency spectrum inside our hessian/J
        # evaluation function: the same parameter points are revisited for
        # every bootstrap, so each model spectrum is simulated only once.
        cached_simul = cache_func(simul_func)

        def func(x, data):
            model = cached_simul(x)
            # Use the spectrum passed in (original data or a bootstrap),
            # not self.data, otherwise all bootstrap gradients coincide.
            return self.base_module.Inference.ll_multinom(model, data)

        H = -Godambe.get_hess(func, p0, eps, args=[self.data])
        H_inv = np.linalg.inv(H)

        J = np.zeros((len(p0), len(p0)))
        for boot in all_boots:
            boot = self.base_module.Spectrum(boot)
            grad_temp = Godambe.get_grad(func, p0, eps, args=[boot])
            J += np.outer(grad_temp, grad_temp)
        J = J / len(all_boots)

        # G = J*H^-1
        G = np.dot(J, H_inv)

        return np.trace(G)
예제 #2
def get_claic_component(func_ex, all_boot, p0, data, pts=None, eps=1e-2):
    """Return the CLAIC penalty term ``trace(J * H^-1)`` for ``func_ex``.

    If pts is None, then moments is used.
    Some help:
    moments.Godambe.get_hess(func, p0, eps, args=())
    moments.Godambe.get_grad(func, p0, eps, args=())

    ``H`` is the negated finite-difference Hessian of the log-likelihood on
    ``data``. ``J`` is the mean outer product of the log-likelihood gradients,
    each evaluated on one bootstrapped spectrum.

    :param func_ex: Model function. It is called as ``func_ex(params, ns)``
        when ``pts`` is None, and as ``func_ex(params, ns, pts)`` otherwise
        (dadi-style grid model).
    :param all_boot: Bootstrapped spectra. Each item is wrapped with
        ``moments.Spectrum``. Must support ``len``.
    :param p0: Parameter values at which the component is evaluated.
    :param data: Observed spectrum. Its ``sample_sizes`` are used for every
        model evaluation.
    :param pts: Grid points for a dadi model, or None to use moments.
    :param eps: Step size for the finite-difference derivatives.

    :returns: ``trace(J * H^-1)``.
    """
    # The log-likelihood and Spectrum wrapper come from moments in both modes.
    import moments

    if pts is None:
        # Pure moments mode: dadi is not needed at all.
        from moments import Godambe
        func_ex_ = func_ex
    else:
        from dadi import Godambe

        def func_ex_(p, ns):
            return func_ex(p, ns, pts)

    ns = data.sample_sizes

    # Cache evaluations of the frequency spectrum inside our hessian/J
    # evaluation function: the same parameter points are revisited for every
    # bootstrap, so each model spectrum is computed only once.
    cache = {}

    def func(params, spectrum):
        key = (tuple(params), tuple(ns))
        if key not in cache:
            cache[key] = func_ex_(params, ns)
        fs = cache[key]
        return moments.Inference.ll(fs, spectrum)

    H = -Godambe.get_hess(func, p0, eps, args=[data])
    H_inv = np.linalg.inv(H)

    J = np.zeros((len(p0), len(p0)))
    for boot in all_boot:
        boot = moments.Spectrum(boot)
        grad_temp = Godambe.get_grad(func, p0, eps, args=[boot])
        J += np.outer(grad_temp, grad_temp)
    J = J / len(all_boot)

    # G = J*H^-1
    G = np.dot(J, H_inv)

    return np.trace(G)
예제 #3
def perform_analysis(mutdf, prefix, args):
    """Fit demographic and DFE models for one dataset and save the results.

    Steps, in order:

    1. Build the two spectra returned by ``compute_sfs``. These are
       presumably nonsynonymous and synonymous.
    2. Fit the demographic model on the synonymous spectrum.
    3. Read the best fit from ``<samples>.<model_name>.optimized.txt``.
    4. Bootstrap the ``'syn'`` spectrum for Godambe uncertainties.
    5. Compute the spectra used for DFE inference.
    6. Fit the simple and ``neugamma`` DFE models.
    7. Pass everything to ``save_results``.

    :param mutdf: Mutation table passed to ``compute_sfs`` and
        ``make_bootstrap_sfs_binom``.
    :param prefix: Output prefix forwarded to ``save_results``.
    :param args: Parsed options. This function reads ``verbose``,
        ``samples`` and ``model_name``, and forwards ``args`` to the helpers.
    """
    def _progress(message):
        # Status lines are only printed in verbose mode.
        if args.verbose:
            print(message)

    nonsyn_sfs, synonymous_sfs = compute_sfs(mutdf, args)

    _progress("Fitting demographic parameters...")
    fit_demography(synonymous_sfs, args)
    best_fit_file = '.'.join(
        (str(args.samples), args.model_name, 'optimized', 'txt'))
    like, theta, demog_params = get_best_demog(best_fit_file)

    _progress("Bootstrapping for Godambe uncertainty")
    # Bootstrapped synonymous spectra, used for the Godambe uncertainty of
    # the demographic parameters.
    bootstraps = make_bootstrap_sfs_binom(mutdf, 'syn', samples=args.samples)
    grid_pts = [args.samples + extra for extra in (10, 20, 30)]
    # NOTE(review): dadi's Godambe.GIM_uncert takes the model function
    # (func_ex) as its first argument. If this Godambe is dadi's, the call
    # below appears to omit it. Confirm against the imported module.
    uncert = Godambe.GIM_uncert(grid_pts, bootstraps, demog_params,
                                synonymous_sfs)

    _progress("Calculating spectra")
    spectra, theta_ns = create_spectra(demog_params, theta, args)

    _progress("Fitting DFE models")
    simple_popt = fit_simple_dfe(spectra, theta_ns, nonsyn_sfs)
    complex_popt = fit_neugamma_dfe(spectra, theta_ns, nonsyn_sfs)
    # Godambe uncertainty is not computed for the DFE parameters. It is
    # possible in principle, but implementation issues are unresolved.

    _progress("Collecting final results")
    # Plot and save final results.
    save_results(nonsyn_sfs, synonymous_sfs, spectra, simple_popt,
                 complex_popt, like, theta, demog_params, theta_ns, uncert,
                 prefix, args)