def _getCladeEstimates(tree, txp, hs, nhsd) :
  """Record split-height estimates for clades whose taxa occupy contiguous positions.

  Post-order pass over `tree`. For every node it sets
    data.clade   - (list of taxon positions in the clade, flag); the flag is True
                   when those positions form one contiguous run.
    data.pheight - the node's own height plus its branch length, i.e. the height
                   of its parent along this path (not set for the root).
  A tip's own height is nhsd[taxon] (or 0 when nhsd is falsy); an internal node's
  height is the maximum pheight of its children.

  When all children are flagged valid and their union is contiguous, the node's
  height is appended to hs[m], where m is the largest position of the lower of the
  two child blocks (the split point between them).

  txp  - maps a tip's taxon to its integer position.
  hs   - indexable by position; each entry must support append().
  nhsd - optional map taxon -> tip height; pass a falsy value for all-zero tips.

  Assumes a binary tree: only the first two child clades are merged.
  """
  for node in getPostOrder(tree):
    d = node.data
    if not node.succ:
      d.clade = ([txp[d.taxon]], True)
      d.pheight = d.branchlength + (nhsd[d.taxon] if nhsd else 0)
      continue

    kids = [tree.node(s).data for s in node.succ]
    childClades = [k.clade[0] for k in kids]
    childrenValid = all([k.clade[1] for k in kids])
    nodeHeight = max([k.pheight for k in kids])

    lo = min([min(c) for c in childClades])
    hi = max([max(c) for c in childClades])
    # Contiguous iff the span of positions equals the number of taxa.
    contiguous = (hi - lo + 1 == sum([len(c) for c in childClades]))

    if contiguous and childrenValid:
      # Split point: top position of whichever child block does not hold `hi`.
      split = max(childClades[0])
      if split >= hi:
        split = max(childClades[1])
      assert (0 <= split < hi)
      hs[split].append(nodeHeight)

    if node.id != tree.root:
      d.pheight = nodeHeight + d.branchlength
    d.clade = (childClades[0] + childClades[1], contiguous)
def parsimonyScore(tree, taxVals) :
  """Return the (Fitch) parsimony score of a rooted binary tree.

  taxVals is a function returning the (text) value for a tip given its node
  data; '?' marks missing data, which contributes an empty state set.

  A temporary `vals` attribute is attached to every node's data during the
  computation and removed again before returning.
  """
  for tipId in tree.get_terminals() :
    tip = tree.node(tipId)
    value = taxVals(tip.data)
    tip.data.vals = set() if value == '?' else set([value])

  score = 0
  for node in getPostOrder(tree) :
    if not node.succ :
      continue
    childSets = [tree.node(c).data.vals for c in node.succ]
    states = reduce(set.intersection, childSets)
    if not states :
      # No shared state: take the union; this costs one change unless one of
      # the children carries no information (missing data).
      states = reduce(set.union, childSets)
      informative = len([s for s in childSets if s])
      if informative > 1 :
        assert len(childSets) == 2
        score += 1
    node.data.vals = states

  for nid in tree.all_ids() :
    del tree.node(nid).data.vals
  return score
def nonInformative(tree, vals) :
  """Return True when the character given by `vals` looks parsimony-non-informative on `tree`.

  vals - maps each tip taxon to its character value; '?' marks missing data.

  Returns True immediately when at most one distinct known value occurs.
  Otherwise it uses the most common known value (mcv) and, per node, the number
  of tips below carrying a different known value (data.nmcv). It returns False as
  soon as the cost tables produced by parsimonyScoreWithPolytomies disagree with
  that count:
    - non-root nodes: the cost of mcv (or of '?' when mcv is absent) must equal nmcv;
    - root: mcv must achieve the minimum cost (unless it is 0), and that minimum
      must equal nmcv (or nmcv must be 1 when mcv is absent).

  Side effects: leaves node.data.vals (cost tables) and node.data.nmcv on nodes.
  """
  # Called for its side effect: with keepInternalInfo=True the per-node cost
  # tables stay in node.data.vals, which the checks below read.
  parsimonyScoreWithPolytomies(tree, lambda x : vals[x.taxon], True)

  # Counter of known values only ('?' is excluded here, so it can never be a key).
  counts = Counter([v for v in vals.values() if v != '?'])
  if len(counts) <= 1 :
    # one known value or none
    return True

  mcv = counts.most_common()[0][0]

  for n in getPostOrder(tree) :
    if not n.succ :
      n.data.nmcv = int(vals[n.data.taxon] not in (mcv, '?'))
      continue

    nvals = n.data.vals
    pv = min(nvals.values()) if len(nvals) > 0 else 0
    n.data.nmcv = sum([tree.node(ch).data.nmcv for ch in n.succ])

    # mcv gives parsimony minimum
    if n.id != tree.root:
      mcvCost = nvals[mcv] if mcv in nvals else nvals['?']
      if mcvCost != n.data.nmcv:
        return False
    else :
      if not (pv == 0 or nvals.get(mcv) == pv) :
        return False
      if not (pv == n.data.nmcv or (mcv not in nvals and n.data.nmcv == 1)) :
        return False
  return True
def setSTspacing(tree, epsi, po) :
  """Set species tree taxa positions (node.data.x) for plotting.

  The horizontal distance between extrema points (of branches) is at least epsi.
  Tips, taken in post-order, receive the running sum of the spacings computed by
  _detemineSpacing, starting from 0.
  """
  spacing = _detemineSpacing(tree, tree.root, epsi, po).spacing
  positions = cumsum((0,) + spacing)
  tips = [nd for nd in getPostOrder(tree) if not nd.succ]
  for tip, pos in zip(tips, positions) :
    tip.data.x = pos
def clusterFromTree(tr, th, caHelper = None) :
  """Return the top-most nodes of `tr` whose data.rh is at most `th`.

  th       - half distance (i.e. height).
  caHelper - optional CAhelper already built for `tr`; a new one is built when it
             is missing or belongs to a different tree (building it populates
             node.data.rh, which is read below).

  A node is returned when its rh <= th and it is the root or its parent's
  rh > th. A single-taxon tree returns a list holding its only terminal node.
  """
  terminals = tr.get_terminals()
  if len(terminals) == 1 :
    # Use the actual terminal id instead of assuming it is 1.
    return [tr.node(terminals[0])]

  if not (caHelper and getattr(caHelper, "tree", None) == tr) :
    caHelper = CAhelper(tr)

  topNodes = []
  for nd in getPostOrder(tr):
    if nd.data.rh > th :
      continue
    if nd.id == tr.root or tr.node(nd.prev).data.rh > th :
      topNodes.append(nd)
  return topNodes
def _collectCladeTaxaNonATZ(tree, taxa, partitions) :
  """Fill `partitions` with the clades of `tree` and return the root height.

  Heights come from nodeHeights(tree, allTipsZero=False). For every internal
  node, partitions[frozenset(indices of the clade's taxa in `taxa`)] is set to
  (node id, node height).

  Works bottom-up through data.clade lists. A child's list is deleted once it
  has been merged into its parent, so only the root keeps data.clade afterwards.
  """
  heights = nodeHeights(tree, allTipsZero = False)
  for node in getPostOrder(tree):
    nd = node.data
    if not node.succ:
      nd.clade = [taxa.index(nd.taxon)]
      continue
    merged = []
    for cid in node.succ :
      cdata = tree.node(cid).data
      merged.extend(cdata.clade)
      del cdata.clade
    nd.clade = merged
    partitions[frozenset(merged)] = (node.id, heights[node.id])
  return heights[tree.root]
def kendallVectors(tree, tax = None) :
  """Return (hv, iv, taxa) pairwise/tip vectors for `tree`.

  Tips are given indices 0..nt-1:
    - tax is None: in tree.get_terminals() order;
    - otherwise: the position of the tip's taxon in `tax`. Tips are then listed in
      that order, and `tax` is assumed to name exactly the tree's taxa (TODO confirm).

  For every pair of tips (i, j), entry __ij2pos(i, j, nt) of hv / iv holds
  data.info[0] / data.info[1] of the node where the two tips' clades first join
  (their MRCA). data.info is presumably set by __setAuxHeights, which is called
  first; it is not visible here. The last nt entries hold each tip's branch length
  (hv) and 1 (iv).

  taxa is the list of tip taxon names in index order.
  Side effects: sets data.idx on tips; the root keeps data.clade.
  """
  __setAuxHeights(tree, tree.root, 0.0, 0)
  terms = tree.get_terminals()
  if tax is None :
    for k, i in enumerate(terms) :
      tree.node(i).data.idx = k
  else :
    tax2i = dict((t, k) for k, t in enumerate(tax))
    # Resolve every index first so a missing taxon fails before any mutation.
    idxs = [tax2i[tree.node(i).data.taxon] for i in terms]
    for i, k in zip(terms, idxs) :
      tree.node(i).data.idx = k
    terms = sorted(terms, key = lambda i : tree.node(i).data.idx)

  nt = len(terms)
  npairs = (nt * (nt - 1)) // 2
  hv = [None] * (nt + npairs)
  iv = [None] * len(hv)

  for n in getPostOrder(tree):
    data = n.data
    if not n.succ:
      data.clade = [data.idx]
      continue
    parts = []
    for s in n.succ :
      d = tree.node(s).data
      parts.append(d.clade)
      del d.clade
    data.clade = list(itertools.chain(*parts))
    # Each pair of tips taken from two different children coalesces here.
    for cc in itertools.combinations(parts, 2) :
      for i, j in itertools.product(*cc) :
        pos = __ij2pos(i, j, nt)
        assert hv[pos] is None
        assert iv[pos] is None
        hv[pos] = data.info[0]
        iv[pos] = data.info[1]

  for i in terms:
    data = tree.node(i).data
    hv[npairs + data.idx] = data.branchlength
    iv[npairs + data.idx] = 1
  return hv, iv, [tree.node(x).data.taxon for x in terms]
def parsimonyScoreWithPolytomies(tree, taxVals, keepInternalInfo = False, parsimonyStats = False) :
  """Parsimony score for non-binary trees.

  Missing data is marked with a '?'. taxVals is a function which returns the tip
  value based on node data.

  Each node gets a cost table in node.data.vals (value -> cost, always including
  '?'). A tip costs 0 for its own value and, when known, 1 for '?'. An internal
  node's cost for value x is sum(cost(childTable, x)) over its children (`cost`
  is defined elsewhere). The score is the minimum root cost over known values,
  or 0 when no known value reaches the root.

  keepInternalInfo - if True, leave node.data.vals on every node.
  parsimonyStats   - if True, return (score, (minChanges, maxChanges)) when
                     score > 0 and (score, None) otherwise. minChanges is
                     (#distinct known tip values - 1); maxChanges is
                     (#known tips - count of the most common value).
  Tips always keep node.data.character (the raw tip value).
  """
  for itx in tree.get_terminals() :
    n = tree.node(itx)
    v = taxVals(n.data)
    n.data.vals = {v : 0}
    n.data.character = v
    if v != '?' :
      n.data.vals['?'] = 1

  missing = set(['?'])
  for n in getPostOrder(tree) :
    if not n.succ :
      continue
    vch = [tree.node(x).data.vals for x in n.succ]
    # Candidate values: every known value seen in any child table.
    allv = set().union(*vch) - missing
    v1 = dict([(x, sum([cost(cv, x) for cv in vch])) for x in allv])
    v1['?'] = sum([cost(cv, '?') for cv in vch])
    n.data.vals = v1

  vals = tree.node(tree.root).data.vals
  pval = min([vals[x] for x in vals if x != '?']) if len(vals) > 1 else 0

  if parsimonyStats :
    # CI #minChanges / #actual
    if pval > 0 :
      c = Counter([tree.node(itx).data.character for itx in tree.get_terminals()])
      c.pop('?', None)
      minChanges = len(c) - 1
      assert minChanges >= 0
      maxChanges = sum(c.values()) - c.most_common(1)[0][1]
      pval = (pval, (minChanges, maxChanges))
    else :
      pval = (pval, None)

  if not keepInternalInfo :
    for i in tree.all_ids() :
      del tree.node(i).data.vals
  return pval
def _setTreeHeights(tree, opts,fctr) :
  """Assign node heights bottom-up, then recompute branch lengths from them.

  Tip heights are nodeHeights(tree, allTipsZero=False) scaled by fctr. An internal
  node's height is the mean, over its children, of (child height + opts[child id]),
  raised if necessary so it is not below any child. Each non-root branch length is
  then set to parent height minus node height.

  data.height is left on every node. Returns the tree (modified in place).
  """
  postOrder = getPostOrder(tree)
  baseHeights = nodeHeights(tree, allTipsZero = False)
  for node in postOrder :
    if node.succ:
      childHeights = [tree.node(c).data.height for c in node.succ]
      proposals = [ch + opts[c] for c, ch in zip(node.succ, childHeights)]
      mean = sum(proposals)/len(childHeights)
      node.data.height = max(childHeights + [mean])
    else :
      node.data.height = baseHeights[node.id]*fctr

  for nid in tree.all_ids() :
    node = tree.node(nid)
    if node.prev is None:
      continue
    parentHeight = tree.node(node.prev).data.height
    node.data.branchlength = parentHeight - node.data.height
    assert node.data.branchlength >= 0
  return tree
def _setTreeHeightsForTargets(tree, ftargets, fctr) :
  """Set node heights from targets, then recompute branch lengths; returns None.

  ftargets() yields (node id, height) pairs that seed data.height. Tip heights are
  then overwritten with nodeHeights(tree, allTipsZero=False) scaled by fctr. Each
  internal node is raised, if needed, to be at least as high as its children.

  Every internal node is expected to already carry data.height (normally from
  ftargets), because it is read before being raised.

  Each non-root branch length becomes parent height minus node height. The
  temporary data.height attributes are deleted at the end.
  """
  for nid, target in ftargets() :
    tree.node(nid).data.height = target

  baseHeights = nodeHeights(tree, allTipsZero = False)
  for node in getPostOrder(tree) :
    if node.succ:
      kidHeights = [tree.node(c).data.height for c in node.succ]
      node.data.height = max([node.data.height] + kidHeights)
    else :
      node.data.height = baseHeights[node.id]*fctr

  for nid in tree.all_ids() :
    node = tree.node(nid)
    if node.prev is None:
      continue
    node.data.branchlength = tree.node(node.prev).data.height - node.data.height
    assert node.data.branchlength >= 0

  # Heights are only scaffolding here; drop them.
  for nid in tree.all_ids() :
    del tree.node(nid).data.height
def summaryTreeUsingMedianHeights(tree, xtrees) :
  """Return a copy of `tree` whose node heights are medians over `xtrees`.

  Every clade of `tree` gets the median of its heights in `xtrees`, and the root
  gets the median root height. Tip heights are taken from xtrees[0]; all trees are
  assumed to share the same tip heights (not checked). Each internal node is
  raised, if needed, above its children, and branch lengths are recomputed as
  parent height minus node height.

  Raises RuntimeError("tree incompatible with trees") when a clade of `tree` does
  not occur in `xtrees`. The input `tree` itself is not modified.
  """
  tree = copy.deepcopy(tree)

  def func(t, nodeAndHeight) :
    # Keep only the height from each (node, height) record.
    n, h = nodeAndHeight
    return h

  posteriorParts, rhs = allPartitions(tree, xtrees, func = func,
                                      withHeights = True, withRoot = True)
  treeParts = allPartitions(tree, [tree])
  for part in treeParts :
    # Node id
    nodeId = treeParts[part][0][1]
    if part not in posteriorParts :
      raise RuntimeError("tree incompatible with trees")
    tree.node(nodeId).data.height = median(posteriorParts[part])

  # Assume all trees share same tip heights (not checked)
  ref = xtrees[0]
  refHeights = nodeHeights(ref)
  tree.node(tree.root).data.height = median(rhs)

  for node in getPostOrder(tree):
    if node.succ :
      # Make sure node is higher than its descendants
      kidHeights = [tree.node(c).data.height for c in node.succ]
      node.data.height = max([node.data.height] + kidHeights)
    else :
      node.data.height = refHeights[ref.search_taxon(node.data.taxon)]

  for nid in tree.all_ids() :
    node = tree.node(nid)
    if node.prev is None:
      continue
    node.data.branchlength = tree.node(node.prev).data.height - node.data.height
    assert node.data.branchlength >= 0
  return tree
def assembleTree(trees, thFrom, thTo, getSeqForTaxon, nMaxReps = 20, maxPerCons = 100,
                 lowDiversity = 0.02, refineFactor = 1.1, refineUpperLimit = .15,
                 verbose = None) :
  """Assemble one tree from the sub-trees obtained by cutting `trees` at `thFrom`.

  Each pseudo-taxon (a clade produced by cutForestAt) is represented in one of two ways:
    - by a consensus sequence (doTheCons) when its height data.rh is below
      lowDiversity;
    - otherwise by a random sample of member sequences (about log3 of the clade
      size, clamped to [2, nMaxReps]; all members when there are at most 2).
  Pairwise distances between pseudo-taxa are averaged JC-corrected alignment
  distances, never below twice the larger clade height. A tree is built from them
  with treeFromDists, and each pseudo-taxon leaf is replaced by its original
  sub-tree (grafted as Newick text).

  trees          - input trees, cut via cutForestAt.
  thFrom         - cutting height.
  thTo           - not used by this function (kept for interface compatibility).
  getSeqForTaxon - callable: taxon name -> sequence.
  nMaxReps       - maximum number of representative sequences per pseudo-taxon.
  maxPerCons     - maximum number of sequences fed to the consensus.
  lowDiversity   - pseudo-taxa with height below this use a consensus sequence.
  refineFactor, refineUpperLimit - when a consensus-based distance falls below
                   2*max height (or below refineUpperLimit and within refineFactor
                   of that bound), it is refined using representative sequences.
  verbose        - optional file-like object for progress output.

  Returns the assembled tree (parseNewick of the combined Newick string).
  Side effect: sets the module-level counter `acnt` (number of alignments done).
  """
  import sys

  # counts how many alignments done (for display)
  global acnt

  def say(*parts) :
    # Same text as the Python 2 statement `print >> verbose, a, b, ...`.
    verbose.write(" ".join([str(p) for p in parts]) + "\n")

  cahelpers = dict()
  def cahelper(t) :
    # Lazily build and cache one CAhelper per tree (building it populates data.rh).
    # NOTE(review): the cache is keyed by t.name, so distinct trees sharing a name
    # share a helper -- confirm tree names are unique.
    h = cahelpers.get(t.name)
    if h is None :
      h = CAhelper(t)
      cahelpers[t.name] = h
    return h

  if verbose :
    say("cutting", len(trees), "trees at %g" % thFrom)
  # cut trees at thFrom
  pseudoTaxa = cutForestAt(trees, thFrom, cahelper)
  nReps = len(pseudoTaxa)

  reps = [None]*nReps
  def getReps(k) :
    # Cached random sample of member sequences of pseudo-taxon k.
    if reps[k] is None :
      t, n = pseudoTaxa[k]
      terms = n.data.terms
      if len(terms) > 2 :
        nc = min(max(int(math.log(len(terms), 3)), 2), nMaxReps)
        chosen = random.sample(terms, nc)
      else :
        chosen = terms
      reps[k] = [getSeqForTaxon(x.data.taxon) for x in chosen]
    return reps[k]

  cons = [None]*nReps
  def getCons(k) :
    # Cached consensus of (at most maxPerCons randomly chosen) member sequences.
    # Checked with `is None` so an empty consensus is not recomputed from a new sample.
    if cons[k] is None :
      t, n = pseudoTaxa[k]
      terms = n.data.terms
      if len(terms) > maxPerCons :
        terms = random.sample(terms, maxPerCons)
      sq = [getSeqForTaxon(x.data.taxon) for x in terms]
      cons[k] = doTheCons(sq, n.data.rh)
    return cons[k]

  mhs = []
  for t, n in pseudoTaxa :
    cahelper(t)  # populate rh
    mhs.append(n.data.rh)

  # Low-diversity pseudo-taxa are represented by their consensus. If the resulting
  # distance is invalid or close to the cluster height, it is refined by averaging
  # in alignments of log-many representatives; other pseudo-taxa always use
  # representatives.
  acnt = 0

  def getDist(i, j) :
    global acnt
    mi, mj = mhs[i], mhs[j]
    anyCons = False
    if mi < lowDiversity :
      ri = [getCons(i)]
      anyCons = True
    else :
      ri = getReps(i)
    if mj < lowDiversity :
      rj = [getCons(j)]
      anyCons = True
    else :
      rj = getReps(j)

    nhs = len(ri)*len(rj)
    if nhs == 1 :
      h = calign.globalAlign(ri[0], rj[0], scores = defaultMatchScores,
                             report = calign.JCcorrection)
    else :
      ap = calign.allpairs(ri, rj, align=True, scores = defaultMatchScores,
                           report = calign.JCcorrection)
      h = sum([sum(x) for x in ap])/nhs
    acnt += nhs

    # Two clades cut at their heights cannot be closer than twice the larger height.
    lowLim = 2*max(mi, mj)
    if anyCons and (h < lowLim or (h < refineUpperLimit and h < lowLim*refineFactor)) :
      xri = getReps(i) if len(ri) == 1 else ri
      xrj = getReps(j) if len(rj) == 1 else rj
      if ri != xri or rj != xrj :
        ap1 = calign.allpairs(xri, xrj, align=True, scores = defaultMatchScores,
                              report = calign.JCcorrection)
        h1 = sum([sum(x) for x in ap1])
        xnhs = len(xri)*len(xrj)
        acnt += xnhs
        h = (h * nhs + h1)/(nhs + xnhs)
    return max(h, lowLim)

  if verbose :
    say("assembling", nReps, "sub-trees into one tree", time.strftime("%T"))
    say("n-sub-tree #pair-only-alignments #alignments time")
    verbose.flush()
    tnow = time.time()

  # Use array. those can get big
  ds = array.array('f', repeat(0.0, nPairs(nReps)))
  pos = 0
  for i in range(nReps-1) :
    for j in range(i+1, nReps) :
      ds[pos] = getDist(i, j)
      pos += 1
    if verbose :
      # pos is the number of pairs done so far.
      say(i, pos, "%4.3g%%" % ((100.*pos)/len(ds)), acnt, time.strftime("%T"))

  if verbose :
    say(tohms(time.time() - tnow), time.strftime("%T"))

  # Using correct weights can throw off the height guarantee, or not?
  wt = [len(n.data.terms) for t, n in pseudoTaxa]
  tnew = treeFromDists(ds, tax = [str(x) for x in range(nReps)], weights = wt)
  del ds

  # Build the Newick text bottom-up, grafting each pseudo-taxon's sub-tree.
  for n in getPostOrder(tnew) :
    if not n.succ :
      t, nd = pseudoTaxa[int(n.data.taxon)]
      if len(nd.data.terms) == 1 :
        n.data.taxon = nd.data.taxon
        n.data.rtree = "%s:%f" % (n.data.taxon, n.data.branchlength)
      else :
        # Insure heights are there
        cahelper(t)
        s = t.toNewick(nd.id)
        # The sub-tree's own height is taken out of the leaf branch.
        d = n.data.branchlength - nd.data.rh
        if d < -1e-10 :
          sys.stdout.write("***** ERROR %s\n" % d)
        n.data.rtree = "%s:%f" % (s, max(d, 0.0))
    else :
      ch = [tnew.node(x).data.rtree for x in n.succ]
      n.data.rtree = "(%s)" % ",".join(ch)
      if n.id != tnew.root :
        n.data.rtree = n.data.rtree + (":%f" % n.data.branchlength)

  return parseNewick(tnew.node(tnew.root).data.rtree)
def _getGTorderFixedStree(gtree, stree, gtax, gtx, tryPairs) :
  """Search for a plotting order of gene-tree tips given a fixed species-tree layout.

  gtree    - gene tree.
  stree    - species tree; its tips, in post-order, give the plot order of species.
  gtax     - gene-tree tip nodes; the returned order vector follows this sequence.
  gtx      - dict: species-tree node id -> list of gene-tree tip nodes of that species.
  tryPairs - if true, after the clade-swap search also try swapping pairs of gene
             tips within each species.

  Every gene tip gets an ordinal data.o. The score of an order is
  (total, -weighted total), where for each internal gene-tree node
  dd = (span from its leftmost to its rightmost relevant tip) - clade size is
  added to total and dd * data.ht to the weighted total (data.ht is set
  elsewhere). Lower scores are better. Randomly shuffled clade swaps are applied
  (and kept only when they improve the score) until a full pass brings no
  improvement.

  Returns (ms, mp): the best score and the data.o values of gtax at that score.
  """
  allsn = getPostOrder(stree, stree.root)
  # Species taxa in layout order, that is the order they are plotted
  stax = [sn for sn in allsn if not sn.succ]

  # For each gene tree node set
  #  data.grp: the leftmost/rightmost species of taxa in clade (as indices of
  #            plot positions)
  #  data.grpnds: gene taxa nodes of the leftmost/rightmost species (per above).
  #               Those are the only ones that can make a difference for the node
  #               score. The rest are always in.
  #  data.allo: per species (in stax order), the gene tips of that species in clade.
  #  data.sz: number of gene tips in clade.
  # For terminal nodes also set
  #  data.o: ordinal number in sequence
  no = 0
  for i, sn in enumerate(stax):
    for gn in gtx[sn.id] :
      gn.data.grp = [i, i]
      gn.data.grpnds = [[gn], [gn]]
      gn.data.allo = [[] if k != sn else [gn] for k in stax]
      gn.data.o = no
      no += 1
      gn.data.sz = 1

  gtpost = _getRandPostOrder(gtree, gtree.root)
  swaps = []
  for n in gtpost :
    if not n.succ :
      continue
    sns = [gtree.node(x) for x in n.succ]
    lefts, rights = zip(*[x.data.grp for x in sns])
    n.data.grp = min(lefts), max(rights)
    lg, rg = zip(*[x.data.grpnds for x in sns])
    n.data.grpnds = (
      reduce(_plus, [y for x, y in zip(lefts, lg) if n.data.grp[0] == x]),
      reduce(_plus, [y for x, y in zip(rights, rg) if n.data.grp[1] == x]))
    n.data.sz = sum([x.data.sz for x in sns])
    # Materialized: iterated three times below.
    aa = list(zip(*[x.data.allo for x in sns]))
    n.data.allo = [a + b for a, b in aa]
    # Species present in both children: swapping the two groups is a candidate move.
    sw = [(a, b) for a, b in aa if len(b) and len(a)]
    if sw :
      s = sum([1 for a, b in aa if len(b) or len(a)])
      if s > 1:
        swaps.append((n, sw, s))

  nint = [x for x in gtpost if len(x.succ)]

  def score(nodes) :
    htot, tot = 0.0, 0
    for x in nodes :
      lo = [z.data.o for z in x.data.grpnds[0]]
      ro = [z.data.o for z in x.data.grpnds[1]]
      dd = (max(ro) - min(lo) + 1) - x.data.sz
      if dd > 0 :
        tot += dd
        htot += dd * x.data.ht
    return tot, -htot

  ms = score(nint)
  mp = [x.data.o for x in gtax]
  msLast = (float('inf'), 0)
  while ms < msLast:
    msLast = ms
    random.shuffle(swaps)
    for swap in swaps :
      saved = []
      for l, r in swap[1] :
        lo = [x.data.o for x in l]
        ro = [x.data.o for x in r]
        mlo, mro = min(lo), min(ro)
        saved.extend([(x.data.o, x) for x in l + r])
        # Exchange the two adjacent blocks, keeping the internal order of each.
        if mlo < mro :
          ls, rs = mlo + len(r), mlo
        else :
          ls, rs = mro, mro + len(l)
        for k, x in zip([o - mlo for o in lo], l):
          x.data.o = ls + k
        for k, x in zip([o - mro for o in ro], r):
          x.data.o = rs + k
      s = score(nint)
      if s < ms :
        ms = s
        mp = [x.data.o for x in gtax]
      else :
        # Revert this swap.
        for o, x in saved :
          x.data.o = o

  if tryPairs:
    msLast = (float('inf'), 0)
    while ms < msLast:
      msLast = ms
      for kk in gtx:
        gtxkk = gtx[kk]
        for i0, i1 in allPairs(range(len(gtxkk))) :
          a, b = gtxkk[i0], gtxkk[i1]
          o0, o1 = a.data.o, b.data.o
          a.data.o, b.data.o = o1, o0
          s = score(nint)
          if s < ms :
            ms = s
            mp = [x.data.o for x in gtax]
          else :
            a.data.o, b.data.o = o0, o1
  return ms, mp