def make_forward_tree_defn(
    subst_model, tree, bin_names, with_indel_params=True, kn=True
):
    """Assemble the definition used for the pairwise forward ("Fwd") calculation.

    Parameters
    ----------
    subst_model
        Object providing ``make_fundamental_param_controller_defns(bin_names)``;
        the mapping it returns must contain ``"word_probs"`` and ``"Qd"``.
    tree
        Passed unchanged to ``_recursive_defns``.
    bin_names
        Names of the bins. With more than one name, ``bin_switch`` and
        ``bprobs`` definitions are created and the binned edge definition
        class is used; otherwise the plain edge definition class is used.
    with_indel_params, kn
        Forwarded to ``make_indel_model_defn``.

    Returns
    -------
    ``AnnotateFloatDefn`` wrapping ``FwdDefn(top)`` together with ``top``,
    where ``top`` is the first item returned by ``_recursive_defns``.
    """
    indel_defn = make_indel_model_defn(with_indel_params, kn)
    subst_defns = subst_model.make_fundamental_param_controller_defns(bin_names)
    leaf_defn = NonParamDefn("leaf", dimensions=("edge",))

    n_bins = len(bin_names)
    if n_bins > 1:
        switch = ProbabilityParamDefn("bin_switch", dimensions=["locus"])
        # Every bin starts with the same probability.
        bprobs = PartitionDefn(
            [1.0 / n_bins] * n_bins,
            name="bprobs",
            dimensions=["locus"],
            dimension=("bin", bin_names),
        )
        leading_args = [switch, bprobs]
        edge_defn_constructor = EdgeSumAndAlignDefnWithBins
    else:
        leading_args = []
        edge_defn_constructor = EdgeSumAndAlignDefn

    # One BinData definition per bin, appended after any bin-switching args.
    per_bin_data = CalcDefn(BinData)(
        subst_defns["word_probs"], indel_defn, subst_defns["Qd"]
    ).across_dimension("bin", bin_names)
    edge_args = [*leading_args, *per_bin_data]

    top, _scores = _recursive_defns(
        tree, subst_defns, leaf_defn, edge_defn_constructor, edge_args
    )
    # NOTE: the per-edge scores are not used here; the result is built from
    # FwdDefn(top) rather than SumDefn(*scores).
    return AnnotateFloatDefn(FwdDefn(top), top)
def make_total_loglikelihood_defn(
    tree,
    leaves,
    psubs,
    mprobs,
    bprobs,
    bin_names,
    locus_names,
    sites_independent,
):
    """Build the definition that yields the total log-likelihood.

    Parameters
    ----------
    tree, leaves
        Used to create the ``LikelihoodTreeDefn``; ``tree`` is also passed to
        ``make_partial_likelihood_defns``.
    psubs
        Passed to ``make_partial_likelihood_defns``.
    mprobs
        Definition from which the ``"root"`` entry of the ``"edge"``
        dimension is selected.
    bprobs
        Bin probabilities definition; only used when there is more than one
        bin.
    bin_names
        Names of the bins. With a single bin, that bin is selected and the
        result is computed with ``log_sum_across_sites``.
    locus_names
        Names of the loci. With more than one locus the per-locus
        definitions are combined with ``SumDefn``; otherwise the single
        locus is selected.
    sites_independent
        When true (and there are several bins) ``BinnedSiteDistribution`` is
        used; otherwise ``PatchSiteDistribution`` with a ``bin_switch``
        parameter.

    Returns
    -------
    The resulting total log-likelihood definition.
    """
    fixed_motifs = NonParamDefn("fixed_motif", ["edge"])
    lht = LikelihoodTreeDefn(leaves, tree=tree)
    partial_lh = make_partial_likelihood_defns(tree, lht, psubs, fixed_motifs)

    # After the root partial likelihoods have been calculated it remains to
    # sum over the motifs, local sites, other sites (ie: cpus), bins and loci.
    # The motifs are always done first, but after that it gets complicated.
    # If a bin HMM is being used then the sites from the different CPUs must
    # be interleaved first, otherwise summing over the CPUs is done last to
    # minimise inter-CPU communication.
    root_mprobs = mprobs.select_from_dimension("edge", "root")
    site_lh = CalcDefn(numpy.inner, name="lh")(partial_lh, root_mprobs)

    if len(bin_names) <= 1:
        # Single bin: pick it out and log-sum across sites directly.
        single_bin_lh = site_lh.select_from_dimension("bin", bin_names[0])
        total = CalcDefn(log_sum_across_sites, name="logsum")(lht, single_bin_lh)
    else:
        if sites_independent:
            distribution, distribution_args = BinnedSiteDistribution, (bprobs,)
        else:
            switch = ProbabilityParamDefn("bin_switch", dimensions=["locus"])
            distribution, distribution_args = PatchSiteDistribution, (switch, bprobs)
        site_pattern = CalcDefn(distribution, name="bdist")(*distribution_args)
        bin_index = CallDefn(site_pattern, lht, name="bindex")
        per_bin_lh = site_lh.across_dimension("bin", bin_names)
        total = CallDefn(bin_index, *per_bin_lh, name="tll")

    if len(locus_names) <= 1:
        return total.select_from_dimension("locus", locus_names[0])

    # currently has no .make_likelihood_function() method.
    return SumDefn(*total.across_dimension("locus", locus_names))