Example #1
def test_ancestral():
    import os
    from Bio import AlignIO
    import numpy as np
    from treetime import TreeAnc, GTR
    root_dir = os.path.dirname(os.path.realpath(__file__))
    fasta = str(os.path.join(root_dir, '../data/H3N2_NA_allyears_NA.20.fasta'))
    nwk = str(os.path.join(root_dir, '../data/H3N2_NA_allyears_NA.20.nwk'))

    for marginal in [True, False]:
        print('loading flu example')
        t = TreeAnc(gtr='Jukes-Cantor', tree=nwk, aln=fasta)
        print('ancestral reconstruction ' + ("marginal" if marginal else "joint"))
        t.reconstruct_anc(method='ml', marginal=marginal)
        assert "".join(t.tree.root.sequence) == 'ATGAATCCAAATCAAAAGATAATAACGATTGGCTCTGTTTCTCTCACCATTTCCACAATATGCTTCTTCATGCAAATTGCCATCTTGATAACTACTGTAACATTGCATTTCAAGCAATATGAATTCAACTCCCCCCCAAACAACCAAGTGATGCTGTGTGAACCAACAATAATAGAAAGAAACATAACAGAGATAGTGTATCTGACCAACACCACCATAGAGAAGGAAATATGCCCCAAACCAGCAGAATACAGAAATTGGTCAAAACCGCAATGTGGCATTACAGGATTTGCACCTTTCTCTAAGGACAATTCGATTAGGCTTTCCGCTGGTGGGGACATCTGGGTGACAAGAGAACCTTATGTGTCATGCGATCCTGACAAGTGTTATCAATTTGCCCTTGGACAGGGAACAACACTAAACAACGTGCATTCAAATAACACAGTACGTGATAGGACCCCTTATCGGACTCTATTGATGAATGAGTTGGGTGTTCCTTTTCATCTGGGGACCAAGCAAGTGTGCATAGCATGGTCCAGCTCAAGTTGTCACGATGGAAAAGCATGGCTGCATGTTTGTATAACGGGGGATGATAAAAATGCAACTGCTAGCTTCATTTACAATGGGAGGCTTGTAGATAGTGTTGTTTCATGGTCCAAAGAAATTCTCAGGACCCAGGAGTCAGAATGCGTTTGTATCAATGGAACTTGTACAGTAGTAATGACTGATGGAAGTGCTTCAGGAAAAGCTGATACTAAAATACTATTCATTGAGGAGGGGAAAATCGTTCATACTAGCACATTGTCAGGAAGTGCTCAGCATGTCGAAGAGTGCTCTTGCTATCCTCGATATCCTGGTGTCAGATGTGTCTGCAGAGACAACTGGAAAGGCTCCAATCGGCCCATCGTAGATATAAACATAAAGGATCATAGCATTGTTTCCAGTTATGTGTGTTCAGGACTTGTTGGAGACACACCCAGAAAAAACGACAGCTCCAGCAGTAGCCATTGTTTGGATCCTAACAATGAAGAAGGTGGTCATGGAGTGAAAGGCTGGGCCTTTGATGATGGAAATGACGTGTGGATGGGAAGAACAATCAACGAGACGTCACGCTTAGGGTATGAAACCTTCAAAGTCATTGAAGGCTGGTCCAACCCTAAGTCCAAATTGCAGATAAATAGGCAAGTCATAGTTGACAGAGGTGATAGGTCCGGTTATTCTGGTATTTTCTCTGTTGAAGGCAAAAGCTGCATCAATCGGTGCTTTTATGTGGAGTTGATTAGGGGAAGAAAAGAGGAAACTGAAGTCTTGTGGACCTCAAACAGTATTGTTGTGTTTTGTGGCACCTCAGGTACATATGGAACAGGCTCATGGCCTGATGGGGCGGACCTCAATCTCATGCCTATA'

    print('testing LH normalization')
    from io import StringIO
    from Bio import Phylo,AlignIO
    tiny_tree = Phylo.read(StringIO("((A:0.60100000009,B:0.3010000009):0.1,C:0.2):0.001;"), 'newick')
    tiny_aln = AlignIO.read(StringIO(">A\nAAAAAAAAAAAAAAAACCCCCCCCCCCCCCCCGGGGGGGGGGGGGGGGTTTTTTTTTTTTTTTT\n"
                                     ">B\nAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTT\n"
                                     ">C\nACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT\n"), 'fasta')

    mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']), pi = np.array([0.9, 0.06, 0.02, 0.02]), W=np.ones((4,4)))
    t = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln)
    t.reconstruct_anc('ml', marginal=True, debug=True)
    lhsum =  (t.tree.root.marginal_profile.sum(axis=1) * np.exp(t.tree.root.marginal_subtree_LH_prefactor)).sum()
    print (lhsum)
    assert(np.abs(lhsum-1.0)<1e-6)

    t.optimize_branch_len()
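
A minimal, self-contained sketch isolating the GTR.custom call that all of these snippets build on. The inspected attributes (alphabet, Pi, profile_map) are the ones the examples themselves access; everything else about the model object is assumed.

# Sketch: build a custom nucleotide GTR as in the test above and inspect it.
import numpy as np
from treetime import GTR

mygtr = GTR.custom(alphabet=np.array(['A', 'C', 'G', 'T']),
                   pi=np.array([0.9, 0.06, 0.02, 0.02]),
                   W=np.ones((4, 4)))
print(mygtr.alphabet)                    # custom state space
print(mygtr.Pi)                          # equilibrium frequencies
print(sorted(mygtr.profile_map.keys()))  # characters the model can translate into profiles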
Example #2
def infer_gene_gain_loss(path, rates=[1.0, 1.0]):
    import os
    import numpy as np
    from treetime import TreeAnc, GTR
    # initialize GTR model with default parameters
    mu = np.sum(rates)
    gene_pi = np.array(rates) / mu
    gain_loss_model = GTR.custom(pi=gene_pi,
                                 mu=mu,
                                 W=np.ones((2, 2)),
                                 alphabet=np.array(['0', '1']))
    # add "unknown" state to profile
    gain_loss_model.profile_map['-'] = np.ones(2)
    root_dir = os.path.dirname(os.path.realpath(__file__))

    # define file names for pseudo alignment of presence/absence patterns as in 001001010110
    sep = '/'
    fasta = sep.join([path.rstrip(sep), 'geneCluster', 'genePresence.aln'])
    # strain tree based on core gene SNPs
    nwk = sep.join([path.rstrip(sep), 'geneCluster', 'strain_tree.nwk'])

    # instantiate treetime with custom GTR
    t = TreeAnc(nwk, gtr=gain_loss_model, verbose=2)
    # fix leaves names since Bio.Phylo interprets numeric leaf names as confidence
    for leaf in t.tree.get_terminals():
        if leaf.name is None:
            leaf.name = str(leaf.confidence)
    t.aln = fasta
    t.tree.root.branch_length = 0.0001
    t.reconstruct_anc(method='ml')

    for n in t.tree.find_clades():
        n.genepresence = n.sequence

    return t
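
A hypothetical follow-up sketch (not part of the original snippet): once infer_gene_gain_loss has attached a genepresence array to every node, gains and losses per branch can be counted by comparing each child to its parent.

import numpy as np

def count_gain_loss_events(t):
    # count gene gains ('0' -> '1') and losses ('1' -> '0') on each branch
    events = {}
    for node in t.tree.find_clades():
        for child in node.clades:
            parent_state = np.asarray(node.genepresence)
            child_state = np.asarray(child.genepresence)
            gains = int(np.sum((parent_state == '0') & (child_state == '1')))
            losses = int(np.sum((parent_state == '1') & (child_state == '0')))
            events[child.name] = (gains, losses)
    return events

# usage (with a hypothetical pan-genome folder):
# t = infer_gene_gain_loss('/path/to/pangenome')
# print(count_gain_loss_events(t))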
Example #3
def test_ancestral():
    import os
    from Bio import AlignIO
    import numpy as np
    from treetime import TreeAnc, GTR
    root_dir = os.path.dirname(os.path.realpath(__file__))
    fasta = str(os.path.join(root_dir, 'treetime_examples/data/h3n2_na/h3n2_na_20.fasta'))
    nwk = str(os.path.join(root_dir, 'treetime_examples/data/h3n2_na/h3n2_na_20.nwk'))

    for marginal in [True, False]:
        print('loading flu example')
        t = TreeAnc(gtr='Jukes-Cantor', tree=nwk, aln=fasta)
        print('ancestral reconstruction ' + ("marginal" if marginal else "joint"))
        t.reconstruct_anc(method='ml', marginal=marginal)
        assert "".join(t.tree.root.sequence) == 'ATGAATCCAAATCAAAAGATAATAACGATTGGCTCTGTTTCTCTCACCATTTCCACAATATGCTTCTTCATGCAAATTGCCATCTTGATAACTACTGTAACATTGCATTTCAAGCAATATGAATTCAACTCCCCCCCAAACAACCAAGTGATGCTGTGTGAACCAACAATAATAGAAAGAAACATAACAGAGATAGTGTATCTGACCAACACCACCATAGAGAAGGAAATATGCCCCAAACCAGCAGAATACAGAAATTGGTCAAAACCGCAATGTGGCATTACAGGATTTGCACCTTTCTCTAAGGACAATTCGATTAGGCTTTCCGCTGGTGGGGACATCTGGGTGACAAGAGAACCTTATGTGTCATGCGATCCTGACAAGTGTTATCAATTTGCCCTTGGACAGGGAACAACACTAAACAACGTGCATTCAAATAACACAGTACGTGATAGGACCCCTTATCGGACTCTATTGATGAATGAGTTGGGTGTTCCTTTTCATCTGGGGACCAAGCAAGTGTGCATAGCATGGTCCAGCTCAAGTTGTCACGATGGAAAAGCATGGCTGCATGTTTGTATAACGGGGGATGATAAAAATGCAACTGCTAGCTTCATTTACAATGGGAGGCTTGTAGATAGTGTTGTTTCATGGTCCAAAGAAATTCTCAGGACCCAGGAGTCAGAATGCGTTTGTATCAATGGAACTTGTACAGTAGTAATGACTGATGGAAGTGCTTCAGGAAAAGCTGATACTAAAATACTATTCATTGAGGAGGGGAAAATCGTTCATACTAGCACATTGTCAGGAAGTGCTCAGCATGTCGAAGAGTGCTCTTGCTATCCTCGATATCCTGGTGTCAGATGTGTCTGCAGAGACAACTGGAAAGGCTCCAATCGGCCCATCGTAGATATAAACATAAAGGATCATAGCATTGTTTCCAGTTATGTGTGTTCAGGACTTGTTGGAGACACACCCAGAAAAAACGACAGCTCCAGCAGTAGCCATTGTTTGGATCCTAACAATGAAGAAGGTGGTCATGGAGTGAAAGGCTGGGCCTTTGATGATGGAAATGACGTGTGGATGGGAAGAACAATCAACGAGACGTCACGCTTAGGGTATGAAACCTTCAAAGTCATTGAAGGCTGGTCCAACCCTAAGTCCAAATTGCAGATAAATAGGCAAGTCATAGTTGACAGAGGTGATAGGTCCGGTTATTCTGGTATTTTCTCTGTTGAAGGCAAAAGCTGCATCAATCGGTGCTTTTATGTGGAGTTGATTAGGGGAAGAAAAGAGGAAACTGAAGTCTTGTGGACCTCAAACAGTATTGTTGTGTTTTGTGGCACCTCAGGTACATATGGAACAGGCTCATGGCCTGATGGGGCGGACCTCAATCTCATGCCTATA'

    print('testing LH normalization')
    from io import StringIO
    from Bio import Phylo,AlignIO
    tiny_tree = Phylo.read(StringIO("((A:0.60100000009,B:0.3010000009):0.1,C:0.2):0.001;"), 'newick')
    tiny_aln = AlignIO.read(StringIO(">A\nAAAAAAAAAAAAAAAACCCCCCCCCCCCCCCCGGGGGGGGGGGGGGGGTTTTTTTTTTTTTTTT\n"
                                     ">B\nAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTTAAAACCCCGGGGTTTT\n"
                                     ">C\nACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGTACGT\n"), 'fasta')

    mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']), pi = np.array([0.9, 0.06, 0.02, 0.02]), W=np.ones((4,4)))
    t = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln)
    t.reconstruct_anc('ml', marginal=True, debug=True)
    lhsum =  np.exp(t.sequence_LH(pos=np.arange(4**3))).sum()
    print (lhsum)
    assert(np.abs(lhsum-1.0)<1e-6)

    t.optimize_branch_len()
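
The normalization assert in the two tests above checks that tree likelihoods summed over every possible leaf pattern equal one. A tiny plain-numpy illustration of the same identity with a hand-built two-state model (branch lengths mirror the tiny tree; none of this uses treetime itself):

import numpy as np
from itertools import product

pi = np.array([0.7, 0.3])

def P(t, mu=1.0):
    # two-state reversible transition matrix: P_ij(t) = e^{-mu t} delta_ij + (1 - e^{-mu t}) pi_j
    decay = np.exp(-mu * t)
    return decay * np.eye(2) + (1.0 - decay) * np.tile(pi, (2, 1))

tA, tB, tC, tN = 0.601, 0.301, 0.2, 0.1      # branch lengths as in the tiny tree above
total = 0.0
for a, b, c in product(range(2), repeat=3):  # every possible leaf pattern
    site_lh = 0.0
    for r in range(2):                       # state of the root
        for n in range(2):                   # state of the (A,B) ancestor
            site_lh += pi[r] * P(tN)[r, n] * P(tA)[n, a] * P(tB)[n, b] * P(tC)[r, c]
    total += site_lh
print(total)   # ~1.0 up to floating point error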
Example #4
def test_seq_joint_reconstruction_correct():
    """
    evolve the random sequence, get the alignment at the leaf nodes.
    Reconstruct the sequences of the internal nodes (joint)
    and prove the reconstruction is correct.
    In addition, compute the likelihood of the particular realization of the
    sequences on the tree and prove that this likelihood is exactly the same
    as calculated in the joint reconstruction
    """

    from treetime import TreeAnc, GTR
    from treetime import seq_utils
    from Bio import Phylo, AlignIO
    from io import StringIO
    import numpy as np
    try:
        from itertools import izip
    except ImportError:  #python3.x
        izip = zip
    from collections import defaultdict

    def exclusion(a, b):
        """
        Intersection of two lists
        """
        return list(set(a) - set(b))

    tiny_tree = Phylo.read(
        StringIO("((A:.060,B:.01200)C:.020,D:.0050)E:.004;"), 'newick')
    mygtr = GTR.custom(alphabet=np.array(['A', 'C', 'G', 'T']),
                       pi=np.array([0.15, 0.95, 0.05, 0.3]),
                       W=np.ones((4, 4)))
    seq = np.random.choice(mygtr.alphabet, p=mygtr.Pi, size=400)

    myTree = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=None, verbose=4)

    # simulate evolution, set resulting sequence as ref_seq
    tree = myTree.tree
    seq_len = 400
    tree.root.ref_seq = np.random.choice(mygtr.alphabet,
                                         p=mygtr.Pi,
                                         size=seq_len)
    print("Root sequence: " + ''.join(tree.root.ref_seq))
    mutation_list = defaultdict(list)
    for node in tree.find_clades():
        for c in node.clades:
            c.up = node
        if hasattr(node, 'ref_seq'):
            continue
        t = node.branch_length
        p = mygtr.propagate_profile(
            seq_utils.seq2prof(node.up.ref_seq, mygtr.profile_map), t)
        # normalize profile
        p = (p.T / p.sum(axis=1)).T
        # sample mutations randomly
        ref_seq_idxs = np.array([
            int(np.random.choice(np.arange(p.shape[1]), p=p[k]))
            for k in np.arange(p.shape[0])
        ])

        node.ref_seq = np.array([mygtr.alphabet[k] for k in ref_seq_idxs])

        node.ref_mutations = [
            (anc, pos, der)
            for pos, (anc,
                      der) in enumerate(izip(node.up.ref_seq, node.ref_seq))
            if anc != der
        ]
        for anc, pos, der in node.ref_mutations:
            print(pos)
            mutation_list[pos].append((node.name, anc, der))
        print(node.name, len(node.ref_mutations), node.ref_mutations)

    # set as the starting sequences to the terminal nodes:
    alnstr = ""
    i = 1
    for leaf in tree.get_terminals():
        alnstr += ">" + leaf.name + "\n" + ''.join(leaf.ref_seq) + '\n'
        i += 1
    print(alnstr)
    myTree.aln = AlignIO.read(StringIO(alnstr), 'fasta')
    myTree._attach_sequences_to_nodes()
    # reconstruct ancestral sequences:
    myTree._ml_anc_joint(debug=True)

    diff_count = 0
    mut_count = 0
    for node in myTree.tree.find_clades():
        if node.up is not None:
            mut_count += len(node.ref_mutations)
            diff_count += np.sum(node.sequence != node.ref_seq) == 0
            if np.sum(node.sequence != node.ref_seq):
                print(
                    "%s: True sequence does not equal inferred sequence. parent %s"
                    % (node.name, node.up.name))
            else:
                print("%s: True sequence equals inferred sequence. parent %s" %
                      (node.name, node.up.name))
        print(node.name, np.sum(node.sequence != node.ref_seq),
              np.where(node.sequence != node.ref_seq), len(node.mutations),
              node.mutations)

    # the assignment of mutations to the root node is probabilistic. Hence some differences are expected
    assert diff_count / seq_len < 2 * (1.0 * mut_count / seq_len)**2

    # prove the likelihood value calculation is correct
    LH = myTree.ancestral_likelihood()
    LH_p = (myTree.tree.sequence_LH)

    print("Difference between reference and inferred LH:", (LH - LH_p).sum())
    assert ((LH - LH_p).sum()) < 1e-9

    return myTree
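
A plain-numpy sketch of the sequence-to-profile conversion the simulation above relies on (seq_utils.seq2prof): each character becomes a one-hot row over the alphabet. The helper below is purely illustrative and not treetime's implementation.

import numpy as np

def seq_to_profile(seq, alphabet=np.array(['A', 'C', 'G', 'T'])):
    # one row per position, one column per alphabet character
    return np.array([(alphabet == c).astype(float) for c in seq])

print(seq_to_profile('ACCT'))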
Example #5
def test_seq_joint_lh_is_max():
    """
    For a single-char sequence, perform joint ancestral sequence reconstruction
    and prove that this reconstruction is the most likely one by comparing to all
    possible reconstruction variants (brute-force).
    """

    from treetime import TreeAnc, GTR
    from treetime import seq_utils
    from Bio import Phylo, AlignIO
    from io import StringIO
    import numpy as np

    mygtr = GTR.custom(alphabet=np.array(['A', 'C', 'G', 'T']),
                       pi=np.array([0.91, 0.05, 0.02, 0.02]),
                       W=np.ones((4, 4)))
    tiny_tree = Phylo.read(StringIO("((A:.0060,B:.30)C:.030,D:.020)E:.004;"),
                           'newick')

    #terminal node sequences (single nuc)
    A_char = 'A'
    B_char = 'C'
    D_char = 'G'

    # for brute-force, expand them to the strings
    A_seq = ''.join(np.repeat(A_char, 16))
    B_seq = ''.join(np.repeat(B_char, 16))
    D_seq = ''.join(np.repeat(D_char, 16))

    #
    def ref_lh():
        """
        reference likelihood - LH values for all possible variants
        of the internal node sequences
        """

        tiny_aln = AlignIO.read(
            StringIO(">A\n" + A_seq + "\n"
                     ">B\n" + B_seq + "\n"
                     ">D\n" + D_seq + "\n"
                     ">C\nAAAACCCCGGGGTTTT\n"
                     ">E\nACGTACGTACGTACGT\n"), 'fasta')

        myTree = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=tiny_aln, verbose=4)

        logLH_ref = myTree.ancestral_likelihood()

        return logLH_ref

    #
    def real_lh():
        """
        Likelihood of the sequences calculated by the joint ancestral
        sequence reconstruction
        """
        tiny_aln_1 = AlignIO.read(
            StringIO(">A\n" + A_char + "\n"
                     ">B\n" + B_char + "\n"
                     ">D\n" + D_char + "\n"), 'fasta')

        myTree_1 = TreeAnc(gtr=mygtr,
                           tree=tiny_tree,
                           aln=tiny_aln_1,
                           verbose=4)

        myTree_1.reconstruct_anc(method='ml', marginal=False, debug=True)
        logLH = myTree_1.tree.sequence_LH
        return logLH

    ref = ref_lh()
    real = real_lh()

    print(abs(ref.max() - real))
    # joint chooses the most likely realization of the tree
    assert (abs(ref.max() - real) < 1e-10)
    return ref, real
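
A small sketch of what the brute-force alignment in ref_lh encodes: its 16 columns pair every possible state of internal node C with every possible state of the root E, so a single ancestral_likelihood call scores all 4*4 internal reconstructions.

from itertools import product

C_states = 'AAAACCCCGGGGTTTT'   # column-wise states assigned to internal node C
E_states = 'ACGTACGTACGTACGT'   # column-wise states assigned to the root E
assert list(zip(C_states, E_states)) == list(product('ACGT', repeat=2))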
Example #6
def mugration_inference(tree=None, seq_meta=None, field='country', confidence=True,
                        infer_gtr=True, root_state=None, missing='?'):
    from treetime import GTR
    from Bio.Align import MultipleSeqAlignment
    from Bio.SeqRecord import SeqRecord
    from Bio.Seq import Seq
    from Bio import Phylo

    T = Phylo.read(tree, 'newick')
    nodes = {n.name:n for n in T.get_terminals()}

    # Determine alphabet only counting tips in the tree
    places = set()
    for name, meta in seq_meta.items():
        if field in meta and name in nodes:
            places.add(meta[field])
    if root_state is not None:
        places.add(root_state)

    # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
    places = sorted(places)
    nc = len(places)
    if nc>180:
        print("ERROR: geo_inference: can't have more than 180 places!", file=sys.stderr)
        return None,None
    elif nc==1:
        print("WARNING: geo_inference: only one place found -- set every internal node to %s!"%places[0], file=sys.stderr)
        return None,None
    elif nc==0:
        print("ERROR: geo_inference: list of places is empty!", file=sys.stderr)
        return None,None
    else:
        # set up model
        alphabet = {chr(65+i):place for i,place in enumerate(places)}
        model = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                          alphabet = np.array(sorted(alphabet.keys())))

        missing_char = chr(65+nc)
        alphabet[missing_char]=missing
        model.profile_map[missing_char] = np.ones(nc)
        model.ambiguous = missing_char
        alphabet_rev = {v:k for k,v in alphabet.items()}

        # construct pseudo alignment
        pseudo_seqs = []
        for name, meta in seq_meta.items():
            if name in nodes:
                s=alphabet_rev[meta[field]] if field in meta else missing_char
                pseudo_seqs.append(SeqRecord(Seq(s), name=name, id=name))
        aln = MultipleSeqAlignment(pseudo_seqs)

        # set up treetime and infer
        from treetime import TreeAnc
        tt = TreeAnc(tree=tree, aln=aln, gtr=model, convert_upper=False, verbose=0)
        tt.use_mutation_length=False
        tt.infer_ancestral_sequences(infer_gtr=infer_gtr, store_compressed=False, pc=5.0,
                                     marginal=True, normalized_rate=False)

        # attach inferred states as e.g. node.region = 'africa'
        for node in tt.tree.find_clades():
            node.__setattr__(field, alphabet[node.sequence[0]])

        # if desired, attach entropy and confidence as e.g. node.region_entropy = 0.03
        if confidence:
            for node in tt.tree.find_clades():
                pdis = node.marginal_profile[0]
                S = -np.sum(pdis*np.log(pdis+TINY))

                marginal = [(alphabet[tt.gtr.alphabet[i]], pdis[i]) for i in range(len(tt.gtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
                marginal = [(a, b) for a, b in marginal if b > 0.001][:4] #only take stuff over .1% and the top 4 elements
                conf = {a:b for a,b in marginal}
                node.__setattr__(field + "_entropy", S)
                node.__setattr__(field + "_confidence", conf)

        return tt, alphabet
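
A hypothetical usage sketch for mugration_inference above; the tree file name and metadata entries are placeholders, and the snippet assumes the module-level imports (numpy as np, sys) and the TINY constant that the function body references.

seq_meta = {'strainA': {'country': 'kenya'},
            'strainB': {'country': 'uganda'},
            'strainC': {'country': 'kenya'}}
tt, alphabet = mugration_inference(tree='strain_tree.nwk', seq_meta=seq_meta,
                                   field='country', confidence=True)
if tt is not None:
    for node in tt.tree.find_clades():
        # inferred state plus the entropy attached when confidence=True
        print(node.name, node.country, getattr(node, 'country_entropy', None))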
Example #7
            sep='\t' if params.states[-3:] == 'tsv' else ',',
            skipinitialspace=True)
        weights = {row[0]: row[1] for ri, row in tmp_weights.iterrows()}
        mean_weight = np.mean(list(weights.values()))
        weights = np.array([
            weights[c] if c in weights else mean_weight for c in unique_states
        ],
                           dtype=float)
        weights /= weights.sum()
    else:
        weights = np.ones(nc, dtype=float) / nc

    # set up dummy matrix
    W = np.ones((nc, nc), dtype=float)

    mugration_GTR = GTR.custom(pi=weights, W=W, alphabet=np.array(alphabet))
    mugration_GTR.profile_map[missing_char] = np.ones(nc)
    mugration_GTR.ambiguous = missing_char

    ###########################################################################
    ### set up treeanc
    ###########################################################################
    treeanc = TreeAnc(params.tree, gtr=mugration_GTR, verbose=params.verbose)
    pseudo_seqs = [
        SeqRecord(id=n.name,
                  name=n.name,
                  seq=Seq(reverse_alphabet[leaf_to_attr[n.name]] if n.name in
                          leaf_to_attr else missing))
        for n in treeanc.tree.get_terminals()
    ]
    treeanc.aln = MultipleSeqAlignment(pseudo_seqs)
Example #8
def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling_bias_correction=None,
                                weights=None, verbose=0, iterations=5):
    """take a set of discrete states associated with tips of a tree
    and reconstruct their ancestral states along with a GTR model that
    approximately maximizes the likelihood of the states on the tree.

    Parameters
    ----------
    tree : str, Bio.Phylo.Tree
        name of tree file or Biopython tree object
    traits : dict
        dictionary linking tips to traits
    missing_data : str, optional
        string indicating missing data
    pc : float, optional
        number of pseudo-counts to be used during GTR inference, default 1.0
    sampling_bias_correction : float, optional
        factor to inflate overall switching rate by to counteract sampling bias
    weights : str, optional
        name of file with equilibrium frequencies
    verbose : int, optional
        level of verbosity in output
    iterations : int, optional
        number of times non-linear optimization of overall rate and
        transmission estimation are iterated

    Returns
    -------
    tuple
        tuple of treeanc object, forward and reverse alphabets

    Raises
    ------
    TreeTimeError
        raise error if ancestral reconstruction errors out
    """
    ###########################################################################
    ### make a single character alphabet that maps to discrete states
    ###########################################################################

    unique_states = set(traits.values())
    n_observed_states = len(unique_states)

    # load weights from file and convert to dict if supplied as string
    if type(weights)==str:
        try:
            tmp_weights = pd.read_csv(weights, sep='\t' if weights[-3:]=='tsv' else ',',
                                 skipinitialspace=True)
            weight_dict = {row[0]:row[1] for ri,row in tmp_weights.iterrows() if not np.isnan(row[1])}
        except:
            raise ValueError("Loading of weights file '%s' failed!"%weights)
    elif type(weights)==dict:
        weight_dict = weights
    else:
        weight_dict = None

    # add weights to unique states for alphabet construction
    if weight_dict is not None:
        unique_states.update(weight_dict.keys())
        missing_weights = [c for c in unique_states if c not in weight_dict and c != missing_data]
        if len(missing_weights):
            print("Missing weights for values: " + ", ".join(missing_weights))

        if len(missing_weights)>0.5*n_observed_states:
            print("More than half of discrete states missing from the weights file")
            print("Weights read from file are:", weights)
            raise TreeTimeError("More than half of discrete states missing from the weights file")

    unique_states=sorted(unique_states)
    # make a map from states (excluding missing data) to characters in the alphabet
    # note that gap character '-' is chr(45) and will never be included here
    reverse_alphabet = {state:chr(65+i) for i,state in enumerate(unique_states) if state!=missing_data}
    alphabet = list(reverse_alphabet.values())
    # construct a look up from alphabet character to states
    letter_to_state = {v:k for k,v in reverse_alphabet.items()}

    # construct the vector with weights to be used as equilibrium frequency
    if weight_dict is not None:
        mean_weight = np.mean(list(weight_dict.values()))
        weights = np.array([weight_dict[letter_to_state[c]] if letter_to_state[c] in weight_dict else mean_weight
                            for c in alphabet], dtype=float)
        weights/=weights.sum()

    # consistency checks
    if len(alphabet)<2:
        print("mugration: only one or zero states found -- this doesn't make any sense", file=sys.stderr)
        return None, None, None

    n_states = len(alphabet)
    missing_char = chr(65+n_states)
    reverse_alphabet[missing_data]=missing_char
    letter_to_state[missing_char]=missing_data

    ###########################################################################
    ### construct gtr model
    ###########################################################################

    # set up dummy matrix
    W = np.ones((n_states,n_states), dtype=float)

    mugration_GTR = GTR.custom(pi = weights, W=W, alphabet = np.array(alphabet))
    mugration_GTR.profile_map[missing_char] = np.ones(n_states)
    mugration_GTR.ambiguous=missing_char


    ###########################################################################
    ### set up treeanc
    ###########################################################################
    treeanc = TreeAnc(tree, gtr=mugration_GTR, verbose=verbose,
                      convert_upper=False, one_mutation=0.001)
    treeanc.use_mutation_length = False
    pseudo_seqs = [SeqRecord(id=n.name,name=n.name,
                   seq=Seq(reverse_alphabet[traits[n.name]]
                           if n.name in traits else missing_char))
                   for n in treeanc.tree.get_terminals()]
    valid_seq = np.array([str(s.seq)!=missing_char for s in pseudo_seqs])
    print("Assigned discrete traits to %d out of %d taxa.\n"%(np.sum(valid_seq),len(valid_seq)))
    treeanc.aln = MultipleSeqAlignment(pseudo_seqs)

    try:
        ndiff = treeanc.infer_ancestral_sequences(method='ml', infer_gtr=True,
            store_compressed=False, pc=pc, marginal=True, normalized_rate=False,
            fixed_pi=weights, reconstruct_tip_states=True)
        treeanc.optimize_gtr_rate()
    except TreeTimeError as e:
        print("\nAncestral reconstruction failed, please see above for error messages and/or rerun with --verbose 4\n")
        raise e

    for i in range(iterations):
        treeanc.infer_gtr(marginal=True, normalized_rate=False, pc=pc, fixed_pi=weights)
        treeanc.optimize_gtr_rate()

    if sampling_bias_correction:
        treeanc.gtr.mu *= sampling_bias_correction

    treeanc.infer_ancestral_sequences(infer_gtr=False, store_compressed=False,
                                 marginal=True, normalized_rate=False,
                                 reconstruct_tip_states=True)

    print(fill("NOTE: previous versions (<0.7.0) of this command made a 'short-branch length assumption. "
          "TreeTime now optimizes the overall rate numerically and thus allows for long branches "
          "along which multiple changes accumulated. This is expected to affect estimates of the "
          "overall rate while leaving the relative rates mostly unchanged."))

    return treeanc, letter_to_state, reverse_alphabet
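
A hypothetical usage sketch for reconstruct_discrete_traits; 'tree.nwk' and the trait dictionary are placeholders, and the node.sequence indexing follows the convention used elsewhere in these examples.

traits = {'tipA': 'host_human', 'tipB': 'host_swine', 'tipC': 'host_human'}
treeanc, letter_to_state, reverse_alphabet = reconstruct_discrete_traits(
    'tree.nwk', traits, missing_data='?', pc=1.0, iterations=5)
if treeanc is not None:
    for node in treeanc.tree.find_clades():
        # translate the single-character state back to the original trait value
        print(node.name, letter_to_state[node.sequence[0]])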
Example #9
def mugration_inference(tree=None,
                        seq_meta=None,
                        field='country',
                        confidence=True,
                        infer_gtr=True,
                        root_state=None,
                        missing='?',
                        sampling_bias_correction=None):
    """
    Infer likely ancestral states of a discrete character assuming a time reversible model.

    Parameters
    ----------
    tree : str
        name of tree file
    seq_meta : dict
        meta data associated with sequences
    field : str, optional
        meta data field to use
    confidence : bool, optional
        calculate confidence values for inferences
    infer_gtr : bool, optional
        infer a GTR model for trait transitions (otherwise uses a flat model with rate 1)
    root_state : None, optional
        force the state of the root node (currently not implemented)
    missing : str, optional
        character that is to be interpreted as missing data, default='?'

    Returns
    -------
    T : Phylo.Tree
        Biopython tree
    gtr : treetime.GTR
        GTR model
    alphabet : dict
        mapping of character states to places
    """
    from treetime import GTR
    from Bio.Align import MultipleSeqAlignment
    from Bio.SeqRecord import SeqRecord
    from Bio.Seq import Seq
    from Bio import Phylo

    T = Phylo.read(tree, 'newick')
    nodes = {n.name: n for n in T.get_terminals()}

    # Determine alphabet only counting tips in the tree
    places = set()
    for name, meta in seq_meta.items():
        if field in meta and name in nodes:
            places.add(meta[field])
    if root_state is not None:
        places.add(root_state)

    # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
    places = sorted(places)
    nc = len(places)
    if nc > 180:
        print("ERROR: geo_inference: can't have more than 180 places!",
              file=sys.stderr)
        return None, None, None
    elif nc == 0:
        print("ERROR: geo_inference: list of places is empty!",
              file=sys.stderr)
        return None, None, None
    elif nc == 1:
        print(
            "WARNING: geo_inference: only one place found -- set every internal node to %s!"
            % places[0],
            file=sys.stderr)
        alphabet = {'A': places[0]}
        alphabet_values = ['A']
        gtr = None
        for node in T.find_clades():
            node.sequence = ['A']
            node.marginal_profile = np.array([[1.0]])
    else:
        # set up model
        alphabet = {chr(65 + i): place for i, place in enumerate(places)}
        model = GTR.custom(pi=np.ones(nc, dtype=float) / nc,
                           W=np.ones((nc, nc)),
                           alphabet=np.array(sorted(alphabet.keys())))

        missing_char = chr(65 + nc)
        alphabet[missing_char] = missing
        model.profile_map[missing_char] = np.ones(nc)
        model.ambiguous = missing_char
        alphabet_rev = {v: k for k, v in alphabet.items()}

        # construct pseudo alignment
        pseudo_seqs = []
        for name, meta in seq_meta.items():
            if name in nodes:
                s = alphabet_rev[
                    meta[field]] if field in meta else missing_char
                pseudo_seqs.append(SeqRecord(Seq(s), name=name, id=name))
        aln = MultipleSeqAlignment(pseudo_seqs)

        # set up treetime and infer
        from treetime import TreeAnc
        tt = TreeAnc(tree=tree,
                     aln=aln,
                     gtr=model,
                     convert_upper=False,
                     verbose=0)
        tt.use_mutation_length = False
        tt.infer_ancestral_sequences(infer_gtr=infer_gtr,
                                     store_compressed=False,
                                     pc=1.0,
                                     marginal=True,
                                     normalized_rate=False)

        if sampling_bias_correction:
            tt.gtr.mu *= sampling_bias_correction
            tt.infer_ancestral_sequences(infer_gtr=False,
                                         store_compressed=False,
                                         marginal=True,
                                         normalized_rate=False)

        T = tt.tree
        gtr = tt.gtr
        alphabet_values = tt.gtr.alphabet

    # attach inferred states as e.g. node.region = 'africa'
    for node in T.find_clades():
        node.__setattr__(field, alphabet[node.sequence[0]])

    # if desired, attach entropy and confidence as e.g. node.region_entropy = 0.03
    if confidence:
        for node in T.find_clades():
            pdis = node.marginal_profile[0]
            S = -np.sum(pdis * np.log(pdis + TINY))

            marginal = [(alphabet[alphabet_values[i]], pdis[i])
                        for i in range(len(alphabet_values))]
            marginal.sort(key=lambda x: x[1],
                          reverse=True)  # sort on likelihoods
            marginal = [(a, b) for a, b in marginal if b > 0.001
                        ][:4]  #only take stuff over .1% and the top 4 elements
            conf = {a: b for a, b in marginal}
            node.__setattr__(field + "_entropy", S)
            node.__setattr__(field + "_confidence", conf)

    return T, gtr, alphabet
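
A minimal sketch of the confidence post-processing above, using a made-up marginal profile; TINY is the small constant the snippet references but does not define, so a value is assumed here.

import numpy as np

TINY = 1e-12                               # assumed value of the undefined constant
pdis = np.array([0.80, 0.15, 0.04, 0.01])  # made-up marginal profile of one node
S = -np.sum(pdis * np.log(pdis + TINY))    # entropy of the state distribution
top = sorted(zip(['kenya', 'uganda', 'rwanda', 'tanzania'], pdis),
             key=lambda x: x[1], reverse=True)
conf = {a: b for a, b in top[:4] if b > 0.001}  # top 4 states above 0.1%
print(S, conf)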
Example #10
def test_seq_joint_reconstruction_correct():
    """
    evolve the random sequence, get the alignment at the leaf nodes.
    Reconstruct the sequences of the internal nodes (joint)
    and prove the reconstruction is correct.
    In addition, compute the likelihood of the particular realization of the
    sequences on the tree and prove that this likelihood is exactly the same
    as calculated in the joint reconstruction
    """

    from treetime import TreeAnc, GTR
    from treetime import seq_utils
    from Bio import Phylo, AlignIO
    from io import StringIO
    import numpy as np
    try:
        from itertools import izip
    except ImportError:  #python3.x
        izip = zip
    from collections import defaultdict
    def exclusion(a, b):
        """
        Intersection of two lists
        """
        return list(set(a) - set(b))

    tiny_tree = Phylo.read(StringIO("((A:.060,B:.01200)C:.020,D:.0050)E:.004;"), 'newick')
    mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']),
                       pi = np.array([0.15, 0.95, 0.05, 0.3]), W=np.ones((4,4)))
    seq = np.random.choice(mygtr.alphabet, p=mygtr.Pi, size=400)


    myTree = TreeAnc(gtr=mygtr, tree=tiny_tree, aln=None, verbose=4)

    # simulate evolution, set resulting sequence as ref_seq
    tree = myTree.tree
    seq_len = 400
    tree.root.ref_seq = np.random.choice(mygtr.alphabet, p=mygtr.Pi, size=seq_len)
    print ("Root sequence: " + ''.join(tree.root.ref_seq))
    mutation_list = defaultdict(list)
    for node in tree.find_clades():
        for c in node.clades:
            c.up = node
        if hasattr(node, 'ref_seq'):
            continue
        t = node.branch_length
        p = mygtr.propagate_profile( seq_utils.seq2prof(node.up.ref_seq, mygtr.profile_map), t)
        # normalize profile
        p=(p.T/p.sum(axis=1)).T
        # sample mutations randomly
        ref_seq_idxs = np.array([int(np.random.choice(np.arange(p.shape[1]), p=p[k])) for k in np.arange(p.shape[0])])

        node.ref_seq = np.array([mygtr.alphabet[k] for k in ref_seq_idxs])

        node.ref_mutations = [(anc, pos, der) for pos, (anc, der) in
                            enumerate(izip(node.up.ref_seq, node.ref_seq)) if anc!=der]
        for anc, pos, der in node.ref_mutations:
            print(pos)
            mutation_list[pos].append((node.name, anc, der))
        print (node.name, len(node.ref_mutations), node.ref_mutations)

    # set as the starting sequences to the terminal nodes:
    alnstr = ""
    i = 1
    for leaf in tree.get_terminals():
        alnstr += ">" + leaf.name + "\n" + ''.join(leaf.ref_seq) + '\n'
        i += 1
    print (alnstr)
    myTree.aln = AlignIO.read(StringIO(alnstr), 'fasta')
    myTree._attach_sequences_to_nodes()
    # reconstruct ancestral sequences:
    myTree._ml_anc_joint(debug=True)

    diff_count = 0
    mut_count = 0
    for node in myTree.tree.find_clades():
        if node.up is not None:
            mut_count += len(node.ref_mutations)
            diff_count += np.sum(node.sequence != node.ref_seq)==0
            if np.sum(node.sequence != node.ref_seq):
                print("%s: True sequence does not equal inferred sequence. parent %s"%(node.name, node.up.name))
            else:
                print("%s: True sequence equals inferred sequence. parent %s"%(node.name, node.up.name))
        print (node.name, np.sum(node.sequence != node.ref_seq), np.where(node.sequence != node.ref_seq), len(node.mutations), node.mutations)

    # the assignment of mutations to the root node is probabilistic. Hence some differences are expected
    assert diff_count/seq_len<2*(1.0*mut_count/seq_len)**2

    # prove the likelihood value calculation is correct
    LH = myTree.ancestral_likelihood()
    LH_p = (myTree.tree.sequence_LH)

    print ("Difference between reference and inferred LH:", (LH - LH_p).sum())
    assert ((LH - LH_p).sum())<1e-9

    return myTree
Example #11
def mugration(params):
    """
    implementing treetime mugration
    """

    ###########################################################################
    ### Parse states
    ###########################################################################
    if os.path.isfile(params.states):
        states = pd.read_csv(params.states, sep='\t' if params.states[-3:]=='tsv' else ',',
                             skipinitialspace=True)
    else:
        print("file with states does not exist")
        return 1

    outdir = get_outdir(params, '_mugration')

    taxon_name = 'name' if 'name' in states.columns else states.columns[0]
    if params.attribute:
        if params.attribute in states.columns:
            attr = params.attribute
        else:
            print("The specified attribute was not found in the metadata file "+params.states, file=sys.stderr)
            print("Available columns are: "+", ".join(states.columns), file=sys.stderr)
            return 1
    else:
        attr = states.columns[1]
        print("Attribute for mugration inference was not specified. Using "+attr, file=sys.stderr)

    leaf_to_attr = {x[taxon_name]:x[attr] for xi, x in states.iterrows()
                    if x[attr]!=params.missing_data}
    unique_states = sorted(set(leaf_to_attr.values()))
    nc = len(unique_states)
    if nc>180:
        print("mugration: can't have more than 180 states!", file=sys.stderr)
        return 1
    elif nc<2:
        print("mugration: only one or zero states found -- this doesn't make any sense", file=sys.stderr)
        return 1

    ###########################################################################
    ### make a single character alphabet that maps to discrete states
    ###########################################################################
    alphabet = [chr(65+i) for i,state in enumerate(unique_states)]
    missing_char = chr(65+nc)
    letter_to_state = {a:unique_states[i] for i,a in enumerate(alphabet)}
    letter_to_state[missing_char]=params.missing_data
    reverse_alphabet = {v:k for k,v in letter_to_state.items()}

    ###########################################################################
    ### construct gtr model
    ###########################################################################
    if params.weights:
        params.infer_gtr = True
        tmp_weights = pd.read_csv(params.weights, sep='\t' if params.states[-3:]=='tsv' else ',',
                             skipinitialspace=True)
        weights = {row[0]:row[1] for ri,row in tmp_weights.iterrows()}
        mean_weight = np.mean(list(weights.values()))
        weights = np.array([weights[c] if c in weights else mean_weight for c in unique_states], dtype=float)
        weights/=weights.sum()
    else:
        weights = np.ones(nc, dtype=float)/nc

    # set up dummy matrix
    W = np.ones((nc,nc), dtype=float)

    mugration_GTR = GTR.custom(pi = weights, W=W, alphabet = np.array(alphabet))
    mugration_GTR.profile_map[missing_char] = np.ones(nc)
    mugration_GTR.ambiguous=missing_char

    ###########################################################################
    ### set up treeanc
    ###########################################################################
    treeanc = TreeAnc(params.tree, gtr=mugration_GTR, verbose=params.verbose,
                      convert_upper=False, one_mutation=0.001)
    pseudo_seqs = [SeqRecord(id=n.name,name=n.name,
                   seq=Seq(reverse_alphabet[leaf_to_attr[n.name]]
                           if n.name in leaf_to_attr else missing_char))
                   for n in treeanc.tree.get_terminals()]
    treeanc.aln = MultipleSeqAlignment(pseudo_seqs)

    ndiff = treeanc.infer_ancestral_sequences(method='ml', infer_gtr=True,
            store_compressed=False, pc=params.pc, marginal=True, normalized_rate=False,
            fixed_pi=weights if params.weights else None)
    if ndiff==ttconf.ERROR: # if reconstruction failed, exit
        return 1


    ###########################################################################
    ### output
    ###########################################################################
    print("\nCompleted mugration model inference of attribute '%s' for"%attr,params.tree)

    basename = get_basename(params, outdir)
    gtr_name = basename + 'GTR.txt'
    with open(gtr_name, 'w') as ofile:
        ofile.write('Character to attribute mapping:\n')
        for state in unique_states:
            ofile.write('  %s: %s\n'%(reverse_alphabet[state], state))
        ofile.write('\n\n'+str(treeanc.gtr)+'\n')
        print("\nSaved inferred mugration model as:", gtr_name)

    terminal_count = 0
    for n in treeanc.tree.find_clades():
        if n.up is None:
            continue
        n.confidence=None
        # due to a bug in older versions of biopython that truncated filenames in nexus export
        # we truncate them by hand and make them unique.
        if n.is_terminal() and len(n.name)>40 and bioversion<"1.69":
            n.name = n.name[:35]+'_%03d'%terminal_count
            terminal_count+=1
        n.comment= '&%s="'%attr + letter_to_state[n.sequence[0]] +'"'

    if params.confidence:
        conf_name = basename+'confidence.csv'
        with open(conf_name, 'w') as ofile:
            ofile.write('#name, '+', '.join(unique_states)+'\n')
            for n in treeanc.tree.find_clades():
                ofile.write(n.name + ', '+', '.join([str(x) for x in n.marginal_profile[0]])+'\n')
        print("Saved table with ancestral state confidences as:", conf_name)

    # write tree to file
    outtree_name = basename+'annotated_tree.nexus'
    Phylo.write(treeanc.tree, outtree_name, 'nexus')
    print("Saved annotated tree as:",outtree_name)

    return 0
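
A minimal sketch of the single-character alphabet construction shared by the mugration examples: each discrete state maps to an uppercase letter starting at 'A' (chr(65)), and the next unused letter is reserved for missing data.

unique_states = ['africa', 'asia', 'europe']
alphabet = [chr(65 + i) for i in range(len(unique_states))]   # ['A', 'B', 'C']
missing_char = chr(65 + len(unique_states))                   # 'D'
letter_to_state = dict(zip(alphabet, unique_states))
reverse_alphabet = {v: k for k, v in letter_to_state.items()}
print(alphabet, missing_char, reverse_alphabet)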
Example #12
    def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction
        '''
        from treetime import GTR
        # Determine alphabet
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                if attr in node.attr and attr!=missing:
                    places.add(node.attr[attr])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc>180:
            self.logger("geo_inference: can't have more than 180 places!",1)
            return
        elif nc==1:
            self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0],1)
            for node in self.tree.find_clades():
                node.attr[attr] = places[0]
                node.__setattr__(attr+'_transitions',[])
            return
        elif nc==0:
            self.logger("geo_inference: list of places is empty!",1)
            return

        # store previously reconstructed sequences
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree,'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                node.__delattr__('mutations')


        alphabet = {chr(65+i):place for i,place in enumerate(places)}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                              alphabet = np.array(sorted(alphabet.keys())))
        missing_char = chr(65+nc)
        alphabet[missing_char]=missing
        myGeoGTR.profile_map[missing_char] = np.ones(nc)
        alphabet_rev = {v:k for k,v in alphabet.items()}

        # set geo info to nodes as one letter sequence.
        self.tt.seq_len = 1
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr'):
                if attr in node.attr:
                    node.sequence=np.array([alphabet_rev[node.attr[attr]]])
                else:
                    node.sequence=np.array([missing_char])
            else:
                node.sequence=np.array([missing_char])
        for node in self.tree.get_nonterminals():
            node.__delattr__('sequence')
        if root_state is not None:
            self.tree.root.split(n=1, branch_length=0.0)
            extra_clade = self.tree.root.clades[-1]
            extra_clade.name = "dummy_root_node"
            extra_clade.up = self.tree.root
            extra_clade.sequence = np.array([alphabet_rev[root_state]])
        self.tt.make_reduced_alignment()
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        # import pdb; pdb.set_trace()
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length=False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False,
            store_compressed=False, pc=5.0, marginal=True, normalized_rate=False)

        if root_state is not None:
            self.tree.prune(extra_clade)
        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        if hasattr(self.tt.tree,'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                node.__setattr__(attr+'_transitions', node.mutations)
                if node in nuc_muts:
                    node.mutations = nuc_muts[node]
            # save marginal likelihoods if desired
            if report_confidence:
                node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in node.marginal_profile[0]]) * -1 / math.log(len(node.marginal_profile[0]))
                # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length);
                marginal = [(alphabet[self.tt.geogtr.alphabet[i]], node.marginal_profile[0][i]) for i in range(0, len(self.tt.geogtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
                marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements
                node.attr[attr + "_confidence"] = {a:b for a,b in marginal}
        self.tt.use_mutation_length=tmp_use_mutation_length

        # store saved attrs for save/restore functionality
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if report_confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])
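
A short sketch of the normalized entropy stored as attr + "_entropy" above: Shannon entropy of the marginal profile divided by the log of the number of states, so the value lies between 0 (certain assignment) and 1 (uniform profile).

import math

profile = [0.7, 0.2, 0.1]   # made-up marginal profile over three places
norm_entropy = -sum(v * math.log(v + 1e-20) for v in profile) / math.log(len(profile))
print(norm_entropy)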
Example #13
    def geo_inference(self, attr, missing='?', root_state=None, report_confidence=False):
        '''
        infer a "mugration" model by pretending each region corresponds to a sequence
        state and repurposing the GTR inference and ancestral reconstruction
        '''
        from treetime import GTR
        # Determine alphabet
        places = set()
        for node in self.tree.find_clades():
            if hasattr(node, 'attr'):
                if attr in node.attr and attr!=missing:
                    places.add(node.attr[attr])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). The missing DATA symbol is a '-' (ord('-')==45)
        places = sorted(places)
        nc = len(places)
        if nc>180:
            self.logger("geo_inference: can't have more than 180 places!",1)
            return
        elif nc==1:
            self.logger("geo_inference: only one place found -- setting every internal node to %s!"%places[0],1)
            for node in self.tree.find_clades():
                node.attr[attr] = places[0]
                node.__setattr__(attr+'_transitions',[])
            return
        elif nc==0:
            self.logger("geo_inference: list of places is empty!",1)
            return

        # store previously reconstructed sequences
        nuc_seqs = {}
        nuc_muts = {}
        nuc_seq_LH = None
        if hasattr(self.tt.tree,'sequence_LH'):
            nuc_seq_LH = self.tt.tree.sequence_LH
        for node in self.tree.find_clades():
            if hasattr(node, 'sequence'):
                nuc_seqs[node] = node.sequence
            if hasattr(node, 'mutations'):
                nuc_muts[node] = node.mutations
                node.__delattr__('mutations')


        alphabet = {chr(65+i):place for i,place in enumerate(places)}
        sequence_gtr = self.tt.gtr
        myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                              alphabet = np.array(sorted(alphabet.keys())))
        missing_char = chr(65+nc)
        alphabet[missing_char]=missing
        myGeoGTR.profile_map[missing_char] = np.ones(nc)
        alphabet_rev = {v:k for k,v in alphabet.items()}

        # set geo info to nodes as one letter sequence.
        self.tt.seq_len = 1
        for node in self.tree.get_terminals():
            if hasattr(node, 'attr'):
                if attr in node.attr:
                    node.sequence=np.array([alphabet_rev[node.attr[attr]]])
                else:
                    node.sequence=np.array([missing_char])
            else:
                node.sequence=np.array([missing_char])
        for node in self.tree.get_nonterminals():
            node.__delattr__('sequence')
        if root_state is not None:
            self.tree.root.split(n=1, branch_length=0.0)
            extra_clade = self.tree.root.clades[-1]
            extra_clade.name = "dummy_root_node"
            extra_clade.up = self.tree.root
            extra_clade.sequence = np.array([alphabet_rev[root_state]])
        self.tt.make_reduced_alignment()
        # set custom GTR model, run inference
        self.tt._gtr = myGeoGTR
        # import pdb; pdb.set_trace()
        tmp_use_mutation_length = self.tt.use_mutation_length
        self.tt.use_mutation_length=False
        self.tt.infer_ancestral_sequences(method='ml', infer_gtr=False,
            store_compressed=False, pc=5.0, marginal=True, normalized_rate=False)

        if root_state is not None:
            self.tree.prune(extra_clade)
        # restore the nucleotide sequence and mutations to maintain expected behavior
        self.tt.geogtr = self.tt.gtr
        self.tt.geogtr.alphabet_to_location = alphabet
        self.tt._gtr = sequence_gtr
        if hasattr(self.tt.tree,'sequence_LH'):
            self.tt.tree.geo_LH = self.tt.tree.sequence_LH
            self.tt.tree.sequence_LH = nuc_seq_LH
        for node in self.tree.find_clades():
            node.attr[attr] = alphabet[node.sequence[0]]
            if node in nuc_seqs:
                node.sequence = nuc_seqs[node]
            if node.up is not None:
                node.__setattr__(attr+'_transitions', node.mutations)
                if node in nuc_muts:
                    node.mutations = nuc_muts[node]
            # save marginal likelihoods if desired
            if report_confidence:
                node.attr[attr + "_entropy"] = sum([v * math.log(v+1E-20) for v in node.marginal_profile[0]]) * -1 / math.log(len(node.marginal_profile[0]))
                # javascript: vals.map((v) => v * Math.log(v + 1E-10)).reduce((a, b) => a + b, 0) * -1 / Math.log(vals.length);
                marginal = [(alphabet[self.tt.geogtr.alphabet[i]], node.marginal_profile[0][i]) for i in range(0, len(self.tt.geogtr.alphabet))]
                marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
                marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements
                node.attr[attr + "_confidence"] = {a:b for a,b in marginal}
        self.tt.use_mutation_length=tmp_use_mutation_length

        # store saved attrs for save/restore functionality
        if not hasattr(self, "mugration_attrs"):
            self.mugration_attrs = []
        self.mugration_attrs.append(attr)
        if report_confidence:
            self.mugration_attrs.extend([attr + "_entropy", attr + "_confidence"])
Example #14
def moogration(params):
    """
    implementing treetime moogration
    """

    ###########################################################################
    ### Parse states
    ###########################################################################
    if os.path.isfile(params.states):
        states = pd.read_csv(params.states,
                             sep='\t' if params.states[-3:] == 'tsv' else ',',
                             skipinitialspace=True)
    else:
        print("file with states does not exist")
        return 1

    outdir = get_outdir(params, '_moogration')

    taxon_name = 'name' if 'name' in states.columns else states.columns[0]
    if params.attribute:
        if params.attribute in states.columns:
            attr = params.attribute
        else:
            print(
                "The specified attribute was not found in the metadata file " +
                params.states,
                file=sys.stderr)
            print("Available columns are: " + ", ".join(states.columns),
                  file=sys.stderr)
            return 1
    else:
        attr = states.columns[1]
        print("Attribute for moogration inference was not specified. Using " +
              attr,
              file=sys.stderr)

    leaf_to_attr = {
        x[taxon_name]: x[attr]
        for xi, x in states.iterrows() if x[attr] != params.missing_data
    }
    unique_states = sorted(set(leaf_to_attr.values()))
    nc = len(unique_states)
    if nc > 180:
        print("moogration: can't have more than 180 states!", file=sys.stderr)
        return 1
    elif nc < 2:
        print(
            "moogration: only one or zero states found -- this doesn't make any sense",
            file=sys.stderr)
        return 1

    ###########################################################################
    ### make a single character alphabet that maps to discrete states
    ###########################################################################
    alphabet = [chr(65 + i) for i, state in enumerate(unique_states)]
    missing_char = chr(65 + nc)
    letter_to_state = {a: unique_states[i] for i, a in enumerate(alphabet)}
    letter_to_state[missing_char] = params.missing_data
    reverse_alphabet = {v: k for k, v in letter_to_state.items()}

    ###########################################################################
    ### construct gtr model
    ###########################################################################
    if params.weights:
        params.infer_gtr = True
        tmp_weights = pd.read_csv(
            params.weights,
            sep='\t' if params.states[-3:] == 'tsv' else ',',
            skipinitialspace=True)
        weights = {row[0]: row[1] for ri, row in tmp_weights.iterrows()}
        mean_weight = np.mean(list(weights.values()))
        weights = np.array([
            weights[c] if c in weights else mean_weight for c in unique_states
        ],
                           dtype=float)
        weights /= weights.sum()
    else:
        weights = np.ones(nc, dtype=float) / nc

    # set up dummy matrix
    W = np.ones((nc, nc), dtype=float)

    moogration_GTR = GTR.custom(pi=weights, W=W, alphabet=np.array(alphabet))
    moogration_GTR.profile_map[missing_char] = np.ones(nc)
    moogration_GTR.ambiguous = missing_char

    ###########################################################################
    ### set up treeanc
    ###########################################################################
    treeanc = TreeAnc(params.tree,
                      gtr=moogration_GTR,
                      verbose=params.verbose,
                      convert_upper=False,
                      one_mutation=0.001)
    pseudo_seqs = [
        SeqRecord(id=n.name,
                  name=n.name,
                  seq=Seq(reverse_alphabet[leaf_to_attr[n.name]] if n.name in
                          leaf_to_attr else missing_char))
        for n in treeanc.tree.get_terminals()
    ]
    treeanc.aln = MultipleSeqAlignment(pseudo_seqs)

    ndiff = treeanc.infer_ancestral_sequences(
        method='ml',
        infer_gtr=True,
        store_compressed=False,
        pc=params.pc,
        marginal=True,
        normalized_rate=False,
        fixed_pi=weights if params.weights else None)
    if ndiff == ttconf.ERROR:  # if reconstruction failed, exit
        return 1

    ###########################################################################
    ### output
    ###########################################################################
    print(
        "\nCompleted moogration model inference of attribute '%s' for" % attr,
        params.tree)

    basename = get_basename(params, outdir)
    gtr_name = basename + 'GTR.txt'
    with open(gtr_name, 'w') as ofile:
        ofile.write('Character to attribute mapping:\n')
        for state in unique_states:
            ofile.write('  %s: %s\n' % (reverse_alphabet[state], state))
        ofile.write('\n\n' + str(treeanc.gtr) + '\n')
        print("\nSaved inferred moogration model as:", gtr_name)

    terminal_count = 0
    for n in treeanc.tree.find_clades():
        if n.up is None:
            continue
        n.confidence = None
        # due to a bug in older versions of biopython that truncated filenames in nexus export
        # we truncate them by hand and make them unique.
        if n.is_terminal() and len(n.name) > 40 and bioversion < "1.69":
            n.name = n.name[:35] + '_%03d' % terminal_count
            terminal_count += 1
        n.comment = '&%s="' % attr + letter_to_state[n.sequence[0]] + '"'

    if params.confidence:
        conf_name = basename + 'confidence.csv'
        with open(conf_name, 'w') as ofile:
            ofile.write('#name, ' + ', '.join(unique_states) + '\n')
            for n in treeanc.tree.find_clades():
                ofile.write(n.name + ', ' +
                            ', '.join([str(x)
                                       for x in n.marginal_profile[0]]) + '\n')
        print("Saved table with ancestral state confidences as:", conf_name)

    # write tree to file
    outtree_name = basename + 'annotated_tree.nexus'
    Phylo.write(treeanc.tree, outtree_name, 'nexus')
    print("Saved annotated tree as:", outtree_name)

    return 0
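The block above encodes each discrete attribute value as a single character (starting at 'A', i.e. ord 65) so that TreeAnc can treat the trait as a one-column alignment, with one extra letter reserved for missing data. A minimal standalone sketch of that encoding, using hypothetical state names:

# Minimal sketch of the single-character encoding used above (hypothetical states).
unique_states = sorted({"asia", "europe", "north_america"})
nc = len(unique_states)

alphabet = [chr(65 + i) for i in range(nc)]     # 'A', 'B', 'C', ...
missing_char = chr(65 + nc)                     # next free letter encodes missing data
letter_to_state = dict(zip(alphabet, unique_states))
letter_to_state[missing_char] = "?"
reverse_alphabet = {state: letter for letter, state in letter_to_state.items()}

print(letter_to_state)   # {'A': 'asia', 'B': 'europe', 'C': 'north_america', 'D': '?'}
print(reverse_alphabet)  # {'asia': 'A', 'europe': 'B', 'north_america': 'C', '?': 'D'}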
Example #15
def test_seq_joint_lh_is_max():
    """
    For a single-character sequence, perform joint ancestral sequence reconstruction
    and verify that this reconstruction is the most likely one by comparing it against
    all possible reconstruction variants (brute force).
    """

    from treetime import TreeAnc, GTR
    from treetime import seq_utils
    from Bio import Phylo, AlignIO
    import numpy as np
    from io import StringIO

    mygtr = GTR.custom(alphabet = np.array(['A', 'C', 'G', 'T']), pi = np.array([0.91, 0.05, 0.02, 0.02]), W=np.ones((4,4)))
    tiny_tree = Phylo.read(StringIO("((A:.0060,B:.30)C:.030,D:.020)E:.004;"), 'newick')

    #terminal node sequences (single nuc)
    A_char = 'A'
    B_char = 'C'
    D_char = 'G'

    # for brute-force, expand them to the strings
    A_seq = ''.join(np.repeat(A_char,16))
    B_seq = ''.join(np.repeat(B_char,16))
    D_seq = ''.join(np.repeat(D_char,16))

    #
    def ref_lh():
        """
        reference likelihood - LH values for all possible variants
        of the internal node sequences
        """

        tiny_aln = AlignIO.read(StringIO(">A\n" + A_seq + "\n"
                                         ">B\n" + B_seq + "\n"
                                         ">D\n" + D_seq + "\n"
                                         ">C\nAAAACCCCGGGGTTTT\n"
                                         ">E\nACGTACGTACGTACGT\n"), 'fasta')

        myTree = TreeAnc(gtr=mygtr, tree=tiny_tree,
                         aln=tiny_aln, verbose=4)

        logLH_ref = myTree.ancestral_likelihood()

        return logLH_ref

    #
    def real_lh():
        """
        Likelihood of the sequences calculated by the joint ancestral
        sequence reconstruction
        """
        tiny_aln_1 = AlignIO.read(StringIO(">A\n"+A_char+"\n"
                                           ">B\n"+B_char+"\n"
                                           ">D\n"+D_char+"\n"), 'fasta')

        myTree_1 = TreeAnc(gtr=mygtr, tree=tiny_tree,
                           aln=tiny_aln_1, verbose=4)

        myTree_1.reconstruct_anc(method='ml', marginal=False, debug=True)
        logLH = myTree_1.tree.sequence_LH
        return logLH

    ref = ref_lh()
    real = real_lh()

    print(abs(ref.max() - real))
    # joint chooses the most likely realization of the tree
    assert(abs(ref.max() - real) < 1e-10)
    return ref, real
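The brute-force check described in the docstring can also be written out without treetime: enumerate every assignment of states to the two internal nodes, compute the likelihood of each variant from the transition probabilities, and take the maximum. The sketch below does this under a plain Jukes-Cantor model (not the custom GTR used in the test) with the same topology and branch lengths, purely as an illustration:

# Standalone illustration of the brute-force idea (Jukes-Cantor, same tree as above).
import itertools
import numpy as np

alphabet = ['A', 'C', 'G', 'T']
pi = 0.25  # Jukes-Cantor equilibrium frequency of every state

def p_jc(a, b, t):
    """Jukes-Cantor transition probability P(a -> b | branch length t)."""
    same = 0.25 + 0.75 * np.exp(-4.0 * t / 3.0)
    diff = 0.25 - 0.25 * np.exp(-4.0 * t / 3.0)
    return same if a == b else diff

# tip states and branch lengths from the newick ((A:.0060,B:.30)C:.030,D:.020)E;
tips = {'A': 'A', 'B': 'C', 'D': 'G'}
t_A, t_B, t_C, t_D = 0.006, 0.30, 0.030, 0.020

best_lh, best_states = -np.inf, None
for c, e in itertools.product(alphabet, repeat=2):   # enumerate internal nodes C and E
    lh = (pi
          * p_jc(e, c, t_C)
          * p_jc(c, tips['A'], t_A)
          * p_jc(c, tips['B'], t_B)
          * p_jc(e, tips['D'], t_D))
    if lh > best_lh:
        best_lh, best_states = lh, (c, e)

# the joint reconstruction should recover exactly this maximum-likelihood assignment
print(best_states, np.log(best_lh))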
Example #16
def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling_bias_correction=None,
                                weights=None, verbose=0, iterations=5):
    """take a set of discrete states associated with tips of a tree
    and reconstruct their ancestral states along with a GTR model that
    approximately maximizes the likelihood of the states on the tree.

    Parameters
    ----------
    tree : str, Bio.Phylo.Tree
        name of tree file or Biopython tree object
    traits : dict
        dictionary linking tips to traits
    missing_data : str, optional
        string indicating missing data
    pc : float, optional
        number of pseudo-counts to be used during GTR inference, default 1.0
    sampling_bias_correction : float, optional
        factor to inflate overall switching rate by to counteract sampling bias
    weights : str, optional
        name of file with equilibrium frequencies
    verbose : int, optional
        level of verbosity in output
    iterations : int, optional
        number of times the non-linear optimization of the overall rate and
        the transition-model estimation are iterated

    Returns
    -------
    tuple
        tuple of treeanc object, forward and reverse alphabets

    Raises
    ------
    TreeTimeError
        raised if the ancestral reconstruction fails
    """
    unique_states = sorted(set(traits.values()))
    nc = len(unique_states)
    if nc>180:
        print("mugration: can't have more than 180 states!", file=sys.stderr)
        return None, None, None
    elif nc<2:
        print("mugration: only one or zero states found -- this doesn't make any sense", file=sys.stderr)
        return None, None, None

    ###########################################################################
    ### make a single character alphabet that maps to discrete states
    ###########################################################################
    alphabet = [chr(65+i) for i,state in enumerate(unique_states)]
    missing_char = chr(65+nc)
    letter_to_state = {a:unique_states[i] for i,a in enumerate(alphabet)}
    letter_to_state[missing_char]=missing_data
    reverse_alphabet = {v:k for k,v in letter_to_state.items()}

    ###########################################################################
    ### construct gtr model
    ###########################################################################
    if type(weights)==str:
        tmp_weights = pd.read_csv(weights, sep='\t' if weights[-3:]=='tsv' else ',',
                             skipinitialspace=True)
        weights = {row[0]:row[1] for ri,row in tmp_weights.iterrows()}
        mean_weight = np.mean(list(weights.values()))
        weights = np.array([weights[c] if c in weights else mean_weight for c in unique_states], dtype=float)
        weights/=weights.sum()
    else:
        weights = None

    # set up dummy matrix
    W = np.ones((nc,nc), dtype=float)

    mugration_GTR = GTR.custom(pi = weights, W=W, alphabet = np.array(alphabet))
    mugration_GTR.profile_map[missing_char] = np.ones(nc)
    mugration_GTR.ambiguous=missing_char

    ###########################################################################
    ### set up treeanc
    ###########################################################################
    treeanc = TreeAnc(tree, gtr=mugration_GTR, verbose=verbose,
                      convert_upper=False, one_mutation=0.001)
    treeanc.use_mutation_length = False
    pseudo_seqs = [SeqRecord(id=n.name,name=n.name,
                   seq=Seq(reverse_alphabet[traits[n.name]]
                           if n.name in traits else missing_char))
                   for n in treeanc.tree.get_terminals()]
    treeanc.aln = MultipleSeqAlignment(pseudo_seqs)

    try:
        ndiff = treeanc.infer_ancestral_sequences(method='ml', infer_gtr=True,
            store_compressed=False, pc=pc, marginal=True, normalized_rate=False,
            fixed_pi=weights, reconstruct_tip_states=True)
        treeanc.optimize_gtr_rate()
    except TreeTimeError as e:
        print("\nAncestral reconstruction failed, please see above for error messages and/or rerun with --verbose 4\n")
        raise e

    for i in range(iterations):
        treeanc.infer_gtr(marginal=True, normalized_rate=False, pc=pc)
        treeanc.optimize_gtr_rate()

    if sampling_bias_correction:
        treeanc.gtr.mu *= sampling_bias_correction

    treeanc.infer_ancestral_sequences(infer_gtr=False, store_compressed=False,
                                 marginal=True, normalized_rate=False, reconstruct_tip_states=True)

    print(fill("NOTE: previous versions (<0.7.0) of this command made a 'short-branch length assumption. "
          "TreeTime now optimizes the overall rate numerically and thus allows for long branches "
          "along which multiple changes accumulated. This is expected to affect estimates of the "
          "overall rate while leaving the relative rates mostly unchanged."))

    return treeanc, letter_to_state, reverse_alphabet
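A hypothetical call to reconstruct_discrete_traits might look like the sketch below. The tip names, trait values, and the import path (the function ships in treetime.wrappers in recent releases) are assumptions for illustration, not part of the example above:

# Hypothetical usage sketch for reconstruct_discrete_traits (toy tree and traits).
from io import StringIO
from Bio import Phylo
from treetime.wrappers import reconstruct_discrete_traits  # import path may differ by version

toy_tree = Phylo.read(StringIO("((tipA:0.1,tipB:0.2):0.1,(tipC:0.2,tipD:0.3):0.1);"), 'newick')
traits = {'tipA': 'europe', 'tipB': 'europe', 'tipC': 'asia', 'tipD': 'asia'}

treeanc, letter_to_state, reverse_alphabet = reconstruct_discrete_traits(
    toy_tree, traits, missing_data='?', pc=1.0, verbose=0)

print(treeanc.gtr)   # inferred transition model between the encoded states
# Following the convention of the examples above, map the root's reconstructed
# single-character state back to a trait value (attribute layout may differ
# between treetime versions).
print(letter_to_state[treeanc.tree.root.sequence[0]])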
Example #17
def reconstruct_discrete_traits(tree, traits, missing_data='?', pc=1.0, sampling_bias_correction=None,
                                weights=None, verbose=0):
    unique_states = sorted(set(traits.values()))
    nc = len(unique_states)
    if nc>180:
        print("mugration: can't have more than 180 states!", file=sys.stderr)
        return 1
    elif nc<2:
        print("mugration: only one or zero states found -- this doesn't make any sense", file=sys.stderr)
        return 1

    ###########################################################################
    ### make a single character alphabet that maps to discrete states
    ###########################################################################
    alphabet = [chr(65+i) for i,state in enumerate(unique_states)]
    missing_char = chr(65+nc)
    letter_to_state = {a:unique_states[i] for i,a in enumerate(alphabet)}
    letter_to_state[missing_char]=missing_data
    reverse_alphabet = {v:k for k,v in letter_to_state.items()}

    ###########################################################################
    ### construct gtr model
    ###########################################################################
    if type(weights)==str:
        tmp_weights = pd.read_csv(weights, sep='\t' if weights[-3:]=='tsv' else ',',
                             skipinitialspace=True)
        weights = {row[0]:row[1] for ri,row in tmp_weights.iterrows()}
        mean_weight = np.mean(list(weights.values()))
        weights = np.array([weights[c] if c in weights else mean_weight for c in unique_states], dtype=float)
        weights/=weights.sum()
    else:
        weights = None

    # set up dummy matrix
    W = np.ones((nc,nc), dtype=float)

    mugration_GTR = GTR.custom(pi = weights, W=W, alphabet = np.array(alphabet))
    mugration_GTR.profile_map[missing_char] = np.ones(nc)
    mugration_GTR.ambiguous=missing_char

    ###########################################################################
    ### set up treeanc
    ###########################################################################
    treeanc = TreeAnc(tree, gtr=mugration_GTR, verbose=verbose,
                      convert_upper=False, one_mutation=0.001)
    treeanc.use_mutation_length = False
    pseudo_seqs = [SeqRecord(id=n.name,name=n.name,
                   seq=Seq(reverse_alphabet[traits[n.name]]
                           if n.name in traits else missing_char))
                   for n in treeanc.tree.get_terminals()]
    treeanc.aln = MultipleSeqAlignment(pseudo_seqs)

    ndiff = treeanc.infer_ancestral_sequences(method='ml', infer_gtr=True,
            store_compressed=False, pc=pc, marginal=True, normalized_rate=False,
            fixed_pi=weights)

    if ndiff==ttconf.ERROR: # if reconstruction failed, exit
        return 1

    if sampling_bias_correction:
        treeanc.gtr.mu *= sampling_bias_correction
        treeanc.infer_ancestral_sequences(infer_gtr=False, store_compressed=False,
                                     marginal=True, normalized_rate=False)
    return treeanc, letter_to_state, reverse_alphabet
Пример #18
0
def mugration_inference(tree=None, seq_meta=None, field='country', confidence=True,
                        infer_gtr=True, root_state=None, missing='?', sampling_bias_correction=None):
    """
    Infer likely ancestral states of a discrete character assuming a time reversible model.

    Parameters
    ----------
    tree : str
        name of tree file
    seq_meta : dict
        meta data associated with sequences
    field : str, optional
        meta data field to use
    confidence : bool, optional
        calculate confidence values for inferences
    infer_gtr : bool, optional
        infer a GTR model for trait transitions (otherwise uses a flat model with rate 1)
    root_state : None, optional
        force the state of the root node (currently not implemented)
    missing : str, optional
        character that is to be interpreted as missing data, default='?'

    Returns
    -------
    T : Phylo.Tree
        Biopython tree
    gtr : treetime.GTR
        GTR model
    alphabet : dict
        mapping of single-character states to the discrete attribute values
    """
    from treetime import GTR
    from Bio.Align import MultipleSeqAlignment
    from Bio.SeqRecord import SeqRecord
    from Bio.Seq import Seq
    from Bio import Phylo

    T = Phylo.read(tree, 'newick')
    nodes = {n.name:n for n in T.get_terminals()}

    # Determine alphabet only counting tips in the tree
    places = set()
    for name, meta in seq_meta.items():
        if field in meta and name in nodes:
            places.add(meta[field])
    if root_state is not None:
        places.add(root_state)

    # construct GTR (flat for now). Missing data is encoded by an extra character appended to the alphabet
    places = sorted(places)
    nc = len(places)
    if nc>180:
        print("ERROR: geo_inference: can't have more than 180 places!", file=sys.stderr)
        return None,None,None
    elif nc==0:
        print("ERROR: geo_inference: list of places is empty!", file=sys.stderr)
        return None,None,None
    elif nc==1:
        print("WARNING: geo_inference: only one place found -- set every internal node to %s!"%places[0], file=sys.stderr)
        alphabet = {'A':places[0]}
        alphabet_values = ['A']
        gtr = None
        for node in T.find_clades():
            node.sequence=['A']
            node.marginal_profile=np.array([[1.0]])
    else:
        # set up model
        alphabet = {chr(65+i):place for i,place in enumerate(places)}
        model = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                          alphabet = np.array(sorted(alphabet.keys())))

        missing_char = chr(65+nc)
        alphabet[missing_char]=missing
        model.profile_map[missing_char] = np.ones(nc)
        model.ambiguous = missing_char
        alphabet_rev = {v:k for k,v in alphabet.items()}

        # construct pseudo alignment
        pseudo_seqs = []
        for name, meta in seq_meta.items():
            if name in nodes:
                s=alphabet_rev[meta[field]] if field in meta else missing_char
                pseudo_seqs.append(SeqRecord(Seq(s), name=name, id=name))
        aln = MultipleSeqAlignment(pseudo_seqs)

        # set up treetime and infer
        from treetime import TreeAnc
        tt = TreeAnc(tree=tree, aln=aln, gtr=model, convert_upper=False, verbose=0)
        tt.use_mutation_length = False
        tt.infer_ancestral_sequences(infer_gtr=infer_gtr, store_compressed=False, pc=1.0,
                                     marginal=True, normalized_rate=False)

        if sampling_bias_correction:
            tt.gtr.mu *= sampling_bias_correction
            tt.infer_ancestral_sequences(infer_gtr=False, store_compressed=False,
                                         marginal=True, normalized_rate=False)

        T = tt.tree
        gtr = tt.gtr
        alphabet_values = tt.gtr.alphabet


    # attach inferred states as e.g. node.region = 'africa'
    for node in T.find_clades():
        node.__setattr__(field, alphabet[node.sequence[0]])

    # if desired, attach entropy and confidence as e.g. node.region_entropy = 0.03
    if confidence:
        for node in T.find_clades():
            pdis = node.marginal_profile[0]
            S = -np.sum(pdis*np.log(pdis+TINY))

            marginal = [(alphabet[alphabet_values[i]], pdis[i]) for i in range(len(alphabet_values))]
            marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
            marginal = [(a, b) for a, b in marginal if b > 0.001][:4] #only take stuff over .1% and the top 4 elements
            conf = {a:b for a,b in marginal}
            node.__setattr__(field + "_entropy", S)
            node.__setattr__(field + "_confidence", conf)

    return T, gtr, alphabet
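A hypothetical call to this mugration_inference wrapper could look as follows; the tree file name, tip names, and metadata are illustrative only:

# Hypothetical usage sketch for mugration_inference (toy metadata, illustrative names).
seq_meta = {
    'tipA': {'country': 'france'},
    'tipB': {'country': 'france'},
    'tipC': {'country': 'germany'},
    'tipD': {},                     # no 'country' entry -> treated as missing data
}

# "toy_tree.nwk" is a hypothetical newick file whose tips are named tipA..tipD
T, gtr, alphabet = mugration_inference(tree="toy_tree.nwk", seq_meta=seq_meta,
                                        field='country', confidence=True)

for node in T.find_clades():
    # each node now carries node.country and, with confidence=True,
    # node.country_entropy and node.country_confidence
    print(node.name, node.country, getattr(node, 'country_confidence', None))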
Example #19
def mugration_inference(tree=None, seq_meta=None, field='country', confidence=True,
                        infer_gtr=True, root_state=None, missing='?'):
        from treetime import GTR
        from Bio.Align import MultipleSeqAlignment
        from Bio.SeqRecord import SeqRecord
        from Bio.Seq import Seq


        # Determine alphabet
        places = set()
        for meta in seq_meta.values():
            if field in meta:
                places.add(meta[field])
        if root_state is not None:
            places.add(root_state)

        # construct GTR (flat for now). Missing data is encoded by an extra character appended to the alphabet
        places = sorted(places)
        nc = len(places)
        if nc>180:
            print("geo_inference: can't have more than 180 places!")
            return None
        elif nc==1:
            print("geo_inference: only one place found -- set every internal node to %s!"%places[0])
            return None
        elif nc==0:
            print("geo_inference: list of places is empty!")
            return None
        else:
            alphabet = {chr(65+i):place for i,place in enumerate(places)}
            myGeoGTR = GTR.custom(pi = np.ones(nc, dtype=float)/nc, W=np.ones((nc,nc)),
                              alphabet = np.array(sorted(alphabet.keys())))
            missing_char = chr(65+nc)
            alphabet[missing_char]=missing
            myGeoGTR.profile_map[missing_char] = np.ones(nc)
            alphabet_rev = {v:k for k,v in alphabet.items()}

            pseudo_seqs = []
            for name, meta in seq_meta.items():
                s=alphabet_rev[meta[field]] if field in meta else missing_char
                pseudo_seqs.append(SeqRecord(Seq(s), name=name, id=name))
            aln = MultipleSeqAlignment(pseudo_seqs)

            from treetime import TreeAnc
            tt = TreeAnc(tree=tree, aln=aln, gtr=myGeoGTR, convert_upper=False)
            tt.use_mutation_length=False
            tt.infer_ancestral_sequences(infer_gtr=infer_gtr, store_compressed=False, pc=5.0,
                                         marginal=True, normalized_rate=False)

            for node in tt.tree.find_clades():
                node.__setattr__(field, alphabet[node.sequence[0]])

            if confidence:
                for node in tt.tree.find_clades():
                    pdis = node.marginal_profile[0]
                    S = -np.sum(pdis*np.log(pdis+TINY))

                    marginal = [(alphabet[tt.gtr.alphabet[i]], pdis[i]) for i in range(len(tt.gtr.alphabet))]
                    marginal.sort(key=lambda x: x[1], reverse=True) # sort on likelihoods
                    marginal = [(a, b) for a, b in marginal if b > 0.01][:4] #only take stuff over 1% and the top 4 elements
                    conf = {a:b for a,b in marginal}
                    node.__setattr__(field + "_entropy", S)
                    node.__setattr__(field + "_confidence", conf)

            return tt, alphabet