Example #1
0
def make_forward_tree_defn(
    subst_model, tree, bin_names, with_indel_params=True, kn=True
):
    """Build the pairwise forward-algorithm definition for ``tree``.

    The indel model, the substitution model's fundamental parameter
    controller definitions and the per-bin ``BinData`` calculations are
    assembled and handed to ``_recursive_defns``; the top-level definition
    it returns is wrapped in a ``FwdDefn``.

    Args:
        subst_model: object providing
            ``make_fundamental_param_controller_defns(bin_names)``.
        tree: passed through unchanged to ``_recursive_defns``.
        bin_names: names of the bins.  With more than one bin the edge
            definitions are built by ``EdgeSumAndAlignDefnWithBins`` and
            receive extra "bin_switch" and "bprobs" definitions; otherwise
            ``EdgeSumAndAlignDefn`` is used with no extra arguments.
        with_indel_params: forwarded to ``make_indel_model_defn``.
        kn: forwarded to ``make_indel_model_defn``.

    Returns:
        ``AnnotateFloatDefn`` wrapping the ``FwdDefn`` together with the
        top-level definition.
    """
    indel_defn = make_indel_model_defn(with_indel_params, kn)
    subst_defns = subst_model.make_fundamental_param_controller_defns(
        bin_names
    )
    leaf_defn = NonParamDefn("leaf", dimensions=("edge",))

    bin_count = len(bin_names)
    if bin_count <= 1:
        edge_defn_constructor = EdgeSumAndAlignDefn
        edge_args = []
    else:
        bin_switch = ProbabilityParamDefn("bin_switch", dimensions=["locus"])
        # Every bin starts with the same prior probability.
        bin_probs = PartitionDefn(
            [1.0 / bin_count] * bin_count,
            name="bprobs",
            dimensions=["locus"],
            dimension=("bin", bin_names),
        )
        edge_defn_constructor = EdgeSumAndAlignDefnWithBins
        edge_args = [bin_switch, bin_probs]

    # One BinData definition per bin is appended after any bin parameters.
    word_probs = subst_defns["word_probs"]
    all_bin_data = CalcDefn(BinData)(word_probs, indel_defn, subst_defns["Qd"])
    edge_args.extend(all_bin_data.across_dimension("bin", bin_names))

    # Only the top-level definition is used; the second item is ignored.
    top, _ = _recursive_defns(
        tree, subst_defns, leaf_defn, edge_defn_constructor, edge_args
    )
    return AnnotateFloatDefn(FwdDefn(top), top)
Example #2
0
def make_total_loglikelihood_defn(tree, leaves, psubs, mprobs, bprobs,
                                  bin_names, locus_names, sites_independent):
    """Build the definition that yields the total log-likelihood.

    Partial likelihoods are built for ``tree`` and combined with the root
    motif probabilities; the result is then reduced over bins and loci.

    Args:
        tree: tree passed to ``LikelihoodTreeDefn`` and
            ``make_partial_likelihood_defns``.
        leaves: passed to ``LikelihoodTreeDefn``.
        psubs: passed to ``make_partial_likelihood_defns``.
        mprobs: definition from which the "root" entry of the "edge"
            dimension is selected.
        bprobs: bin probability definition; only used with several bins.
        bin_names: names of the bins.
        locus_names: names of the loci.
        sites_independent: when true (and there are several bins) a
            ``BinnedSiteDistribution`` is used, otherwise a
            ``PatchSiteDistribution`` driven by a "bin_switch" parameter.

    Returns:
        A ``SumDefn`` across loci when there is more than one locus,
        otherwise the definition selected for the single locus.
    """
    fixed_motif_defn = NonParamDefn("fixed_motif", ["edge"])

    like_tree = LikelihoodTreeDefn(leaves, tree=tree)
    partials = make_partial_likelihood_defns(
        tree, like_tree, psubs, fixed_motif_defn)

    # Once the root partial likelihoods exist, what remains is to sum over
    # motifs, local sites, other sites (i.e. CPUs), bins and loci.  Motifs
    # always come first; the rest depends on the model.  With a bin HMM the
    # sites from the different CPUs must be interleaved first; otherwise the
    # sum over CPUs is done last to minimise inter-CPU communication.

    root_motif_probs = mprobs.select_from_dimension("edge", "root")
    motif_summed = CalcDefn(numpy.inner, name="lh")(partials, root_motif_probs)

    if len(bin_names) <= 1:
        single_bin = motif_summed.select_from_dimension("bin", bin_names[0])
        total = CalcDefn(log_sum_across_sites, name="logsum")(
            like_tree, single_bin)
    else:
        if sites_independent:
            bin_dist = CalcDefn(BinnedSiteDistribution, name="bdist")(bprobs)
        else:
            bin_switch = ProbabilityParamDefn(
                "bin_switch", dimensions=["locus"])
            bin_dist = CalcDefn(PatchSiteDistribution, name="bdist")(
                bin_switch, bprobs)
        bin_index = CallDefn(bin_dist, like_tree, name="bindex")
        per_bin = motif_summed.across_dimension("bin", bin_names)
        total = CallDefn(bin_index, *per_bin, name="tll")

    if len(locus_names) <= 1:
        return total.select_from_dimension("locus", locus_names[0])

    # NOTE: currently has no .make_likelihood_function() method.
    return SumDefn(*total.across_dimension("locus", locus_names))