Beispiel #1
0
    def test_collapsed_vs_full(self):
        """Compare node states of the module-level ``tree`` with those of a freshly read tree.

        The tree is read again from ``TREE_NWK``, reconstructed with MPPA + EFT, and every
        node name found on the module-level ``tree`` must carry the same state on both trees.
        """
        fresh_tree = read_tree(TREE_NWK)
        acr(fresh_tree, df, prediction_method=MPPA, model=EFT)

        def states_frame(root, column):
            # One row per node name; the node's states are flattened
            # to a sorted, comma-separated string so they compare as plain text.
            name2state = {}
            for n in root.traverse():
                name2state[n.name] = ', '.join(sorted(getattr(n, feature)))
            return pd.DataFrame.from_dict(name2state,
                                          orient='index',
                                          columns=[column])

        full_states = states_frame(fresh_tree, 'full')
        collapsed_states = states_frame(tree, 'collapsed')
        # Left join: keep every node name of the module-level tree.
        joint = collapsed_states.join(full_states, how='left')
        same_state = joint['collapsed'] == joint['full']
        self.assertTrue(
            np.all(same_state),
            msg='All the node states of the collapsed tree should be the same as of the full one.')
 def test_rerooted_values_are_the_same(self):
     """Check that the ACR estimates do not depend on where the tree is rooted.

     Five times, a randomly rerooted tree is reconstructed with MPPA + F81 and compared
     (to 2 decimal places) against the module-level ``acr_result``: state frequencies,
     log-likelihood, changes per average branch, scaling factor, and the marginal
     probabilities of two fixed nodes.
     """
     for _ in range(5):
         rerooted_tree = reroot_tree_randomly()
         rerooted_acr_result = acr(rerooted_tree, df, prediction_method=MPPA, model=F81)[0]
         for (state, freq, refreq) in zip(acr_result[STATES], acr_result[FREQUENCIES],
                                          rerooted_acr_result[FREQUENCIES]):
             # '{:.3f}' (precision 3) for both numbers; the original '{:3f}' meant
             # "minimum width 3" and printed six decimals for the second value.
             self.assertAlmostEqual(freq, refreq, places=2,
                                    msg='Frequency of {} for the original tree and rerooted tree '
                                        'were supposed to be the same, '
                                        'got {:.3f} vs {:.3f}'
                                    .format(state, freq, refreq))
         for label in (LOG_LIKELIHOOD, CHANGES_PER_AVG_BRANCH, SCALING_FACTOR):
             value = acr_result[label]
             rerooted_value = rerooted_acr_result[label]
             self.assertAlmostEqual(value, rerooted_value, places=2,
                                    msg='{} for the original tree and rerooted tree were supposed to be the same, '
                                        'got {:.3f} vs {:.3f}'
                                    .format(label, value, rerooted_value))
         mps = acr_result[MARGINAL_PROBABILITIES]
         remps = rerooted_acr_result[MARGINAL_PROBABILITIES]
         # Spot-check one node of each hard-coded name; rows are looked up by node name.
         for node_name in ('node_4', '02ALAY1660'):
             for loc in acr_result[STATES]:
                 value = mps.loc[node_name, loc]
                 revalue = remps.loc[node_name, loc]
                 self.assertAlmostEqual(value, revalue, places=2,
                                        msg='{}: Marginal probability of {} for the original tree and rerooted tree '
                                            'were supposed to be the same, got {:.3f} vs {:.3f}'
                                        .format(node_name, loc, value, revalue))
Beispiel #3
0
    def test_marginal_probs_internal_nodes(self):
        """Check that JTT and CUSTOM_RATES fed with the JTT matrix give identical results.

        States are simulated on the tree, copied to two tip features, and reconstructed
        once with the JTT model and once with CUSTOM_RATES using the parameters serialized
        from the JTT run plus the JTT rate matrix. Log-likelihood and marginal
        probabilities must be exactly equal.

        The working directory ``WD`` and the rate-matrix file ``RM`` are removed in a
        ``finally`` block so that a failing assertion does not leave them on disk.
        """
        _set_up_pastml_logger(True)
        tree = Tree(TREE_NWK, format=3)
        simulate_states(tree,
                        JTT,
                        JTT_FREQUENCIES,
                        kappa=0,
                        tau=0,
                        sf=1,
                        character='jtt',
                        rate_matrix=None,
                        n_repetitions=1)
        for tip in tree:
            # Both features get the same simulated state, each as its own set object.
            tip.add_feature('state1', {JTT_STATES[getattr(tip, 'jtt')][0]})
            tip.add_feature('state2', {JTT_STATES[getattr(tip, 'jtt')][0]})

        acr_result_jtt = acr(tree,
                             columns=['state1'],
                             column2states={'state1': JTT_STATES},
                             prediction_method=MPPA,
                             model=JTT)[0]
        os.makedirs(WD, exist_ok=True)
        try:
            _serialize_acr((acr_result_jtt, WD))
            params = os.path.join(WD,
                                  get_pastml_parameter_file(MPPA, JTT, 'state1'))

            save_custom_rates(JTT_STATES, JTT_RATE_MATRIX, RM)
            acr_result_cr = acr(tree,
                                columns=['state2'],
                                prediction_method=MPPA,
                                model=CUSTOM_RATES,
                                column2parameters={'state2': params},
                                column2rates={'state2': RM},
                                column2states={'state2': JTT_STATES})[0]
            self.assertEqual(
                acr_result_jtt[LOG_LIKELIHOOD],
                acr_result_cr[LOG_LIKELIHOOD],
                msg='Likelihood should be the same for JTT and CR with JTT matrix')

            mps_jtt = acr_result_jtt[MARGINAL_PROBABILITIES]
            mps_cr = acr_result_cr[MARGINAL_PROBABILITIES]
            self.assertTrue(
                np.all(mps_jtt == mps_cr),
                msg=
                'Marginal probabilities be the same for JTT and CR with JTT matrix'
            )
        finally:
            # Always clean up, including when an assertion or acr() call above fails.
            shutil.rmtree(WD, ignore_errors=True)
            if os.path.exists(RM):
                os.remove(RM)
Beispiel #4
0
import pandas as pd
import numpy as np

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.ml import MPPA, EFT

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Name of the single annotation column that is reconstructed and later read from the nodes.
feature = 'Country'
# First CSV column is the index; keep only the `feature` column (as a one-column frame).
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
tree = read_tree(TREE_NWK)
# Runs at import time; the return value is discarded. The tests below read
# getattr(node, feature) from this tree, so acr() is relied on to annotate it.
acr(tree, df, prediction_method=MPPA, model=EFT)


class ACRStateMPPAEFTTest(unittest.TestCase):
    def test_collapsed_vs_full(self):
        tree_uncollapsed = read_tree(TREE_NWK)
        acr(tree_uncollapsed, df, prediction_method=MPPA, model=EFT)

        def get_state(node):
            return ', '.join(sorted(getattr(node, feature)))

        df_full = pd.DataFrame.from_dict(
            {
                node.name: get_state(node)
                for node in tree_uncollapsed.traverse()
            },
def get_binary_trait_subtrees(tre,
                              csv,
                              tiplabel_in_csv=None,
                              elements=1,
                              trait_column="adm2",
                              trait_value="NORFOLK",
                              n_threads=4,
                              method="DOWNPASS"):
    '''
    Reconstruct a binary ("yes"/"no") version of a trait and return the top-most matching subtrees.

    Instead of reconstructing all states, we ask if the ancestral state is one of `trait_value` or not.
    Still we allow for `elements` > 1 (2 in practice), so that we accept "yes and no" ancestral nodes.
    You can group trait values by passing a list as `trait_value`.

    Parameters:
        tre: tree whose nodes are annotated by `acr` (ete3-style API: `traverse`, `get_cached_content`,
            `get_monophyletic`).
        csv: pandas DataFrame with the trait; indexed by tip label unless `tiplabel_in_csv` is given.
        tiplabel_in_csv: name of the column holding tip labels; when given it becomes the index.
        elements: maximum number of states allowed on a matching internal node (values < 1 become 1).
        trait_column: column of `csv` with the original trait.
        trait_value: one value (str) or a list of values mapped to "yes".
        n_threads: passed to `acr` as `threads`.
        method: prediction method name; unknown names silently fall back to "DOWNPASS".

    Returns:
        (subtrees, mono, result, new_trait): list of non-overlapping internal nodes found in preorder,
        the generator returned by `tre.get_monophyletic`, the value returned by `acr`, and the name of
        the binary trait column/node attribute.
    '''
    if elements < 1:
        elements = 1  ## inferred state cardinality (1 makes much more sense, but you can set "2" as well)
    if method not in ("MPPA", "MAP", "JOINT", "ACCTRAN", "DELTRAN", "DOWNPASS"):
        method = "DOWNPASS"
    # .copy(): the frame is modified in place below, which must neither touch the caller's
    # data nor trigger pandas' SettingWithCopyWarning.
    if tiplabel_in_csv:  # o.w. we assume input csv already has it as index
        csv_column = csv[[tiplabel_in_csv, trait_column]].copy()
        # index by the column the caller named (was hard-coded to "sequence_name");
        # acr needs the index to map onto the tree leaves
        csv_column.set_index(tiplabel_in_csv, drop=True, inplace=True)
    else:
        csv_column = csv[[trait_column]].copy()
    if isinstance(trait_value, str):  # anything else is assumed to be a list of values
        trait_value = [trait_value]

    # transform variable into binary
    new_trait = str(trait_column) + "_is_" + "_or_".join(trait_value)
    csv_column[new_trait] = csv_column[trait_column].map(
        lambda a: "yes" if a in trait_value else "no")
    csv_column.drop(labels=[trait_column], axis=1, inplace=True)

    ## Ancestral state reconstruction of given trait; annotates tree nodes with states
    result = acr(tre,
                 csv_column,
                 prediction_method=method,
                 force_joint=False,
                 threads=n_threads)

    # leaf names below every node
    node2leaves = tre.get_cached_content(store_attr="name")
    stored_leaves = set()  # leaf names already covered by a selected subtree
    subtrees = []  # list of non-overlapping nodes
    ## Preorder visits ancestors first, so the first match on any lineage is kept and every
    ## matching descendant (which shares leaves with it) is skipped.
    for xnode in tre.traverse("preorder"):
        if xnode.is_leaf():
            continue
        states = getattr(xnode, new_trait)  # traits are sets (not lists)
        if "yes" not in states or len(states) > elements:
            continue
        if stored_leaves.isdisjoint(node2leaves[xnode]):
            stored_leaves.update(node2leaves[xnode])
            subtrees.append(xnode)

    # ete3 documents `values` as a set of values; a bare string would be split into characters.
    mono = tre.get_monophyletic(values={"yes"}, target_attr=new_trait)
    return subtrees, mono, result, new_trait
def get_ancestral_trait_subtrees(tre,
                                 csv,
                                 tiplabel_in_csv=None,
                                 elements=1,
                                 trait_column="adm2",
                                 trait_value="NORFOLK",
                                 n_threads=4,
                                 method="DOWNPASS"):
    '''
    Returns ancestral nodes predicted to have a given trait_value (e.g. "NORFOLK") for a given trait column (e.g. "adm2").
    Also returns nodes strictly monophyletic regarding value.
    If using parsimony then it's better to store only nodes with a single state (or create a binary state?) otherwise
    we may end up choosing a node too close to the root...  MAP works well to find almost monophyletic nodes, but is
    quite slow.
    max likelihood methods: pastml.ml.MPPA, pastml.ml.MAP, pastml.ml.JOINT,
    max parsimony methods: pastml.parsimony.ACCTRAN, pastml.parsimony.DELTRAN, pastml.parsimony.DOWNPASS

    Parameters:
        tre: tree whose nodes are annotated by `acr` (ete3-style API: `traverse`, `get_cached_content`,
            `get_monophyletic`).
        csv: pandas DataFrame with the trait; indexed by tip label unless `tiplabel_in_csv` is given.
        tiplabel_in_csv: name of the column holding tip labels; when given it becomes the index.
        elements: maximum number of states allowed on a matching internal node (values < 1 become 1).
        trait_column: column of `csv` (and node attribute) holding the trait.
        trait_value: the value to look for; if a list is given only its first element is used.
        n_threads: passed to `acr` as `threads`.
        method: prediction method name; unknown names silently fall back to "DOWNPASS".

    Returns:
        (subtrees, mono, result): list of non-overlapping internal nodes found in preorder, the generator
        returned by `tre.get_monophyletic`, and the value returned by `acr`.
    '''
    if elements < 1:
        elements = 1  ## inferred state cardinality (how many values are allowed on matching internal node?)
    if method not in ("MPPA", "MAP", "JOINT", "ACCTRAN", "DELTRAN", "DOWNPASS"):
        method = "DOWNPASS"
    # .copy(): set_index below works in place and must not act on a selection of the caller's frame.
    if tiplabel_in_csv:  # o.w. we assume input csv already has it as index
        csv_column = csv[[tiplabel_in_csv, trait_column]].copy()
        # index by the column the caller named (was hard-coded to "sequence_name");
        # acr needs the index to map onto the tree leaves
        csv_column.set_index(tiplabel_in_csv, drop=True, inplace=True)
    else:
        csv_column = csv[[trait_column]].copy()
    if isinstance(trait_value, list):
        ## This function can handle only one value; for grouping try the get_binary version
        trait_value = trait_value[0]

    ## Ancestral state reconstruction of given trait; annotates tree nodes with states
    result = acr(tre,
                 csv_column,
                 prediction_method=method,
                 force_joint=False,
                 threads=n_threads)

    # leaf names below every node
    node2leaves = tre.get_cached_content(store_attr="name")
    stored_leaves = set()  # leaf names already covered by a selected subtree
    subtrees = []  # list of non-overlapping nodes
    ## Preorder visits ancestors first, so the first match on any lineage is kept and every
    ## matching descendant (which shares leaves with it) is skipped.
    for xnode in tre.traverse("preorder"):
        if xnode.is_leaf():
            continue
        states = getattr(xnode, trait_column)
        if trait_value not in states or len(states) > elements:
            continue
        if stored_leaves.isdisjoint(node2leaves[xnode]):
            stored_leaves.update(node2leaves[xnode])
            subtrees.append(xnode)

    # ete3 documents `values` as a set of values; a bare string would be split into characters.
    mono = tre.get_monophyletic(values={trait_value}, target_attr=trait_column)
    return subtrees, mono, result
Beispiel #7
0
import pandas as pd
import numpy as np

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.ml import JOINT, F81

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Name of the single annotation column that is reconstructed and later read from the nodes.
feature = 'Country'
# First CSV column is the index; keep only the `feature` column (as a one-column frame).
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
tree = read_tree(TREE_NWK)
# Runs at import time; the return value is discarded. The tests below read
# getattr(node, feature) from this tree, so acr() is relied on to annotate it.
acr(tree, df, prediction_method=JOINT, model=F81)


class ACRStateJointF81Test(unittest.TestCase):
    def test_collapsed_vs_full(self):
        tree_uncollapsed = read_tree(TREE_NWK)
        acr(tree_uncollapsed, df, prediction_method=JOINT, model=F81)

        def get_state(node):
            state = getattr(node, feature)
            return state if not isinstance(state, list) else ', '.join(
                sorted(state))

        df_full = pd.DataFrame.from_dict(
            {
                node.name: get_state(node)
Beispiel #8
0
import pandas as pd

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.parsimony import ACCTRAN, STEPS

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Name of the single annotation column that is reconstructed and later read from the nodes.
feature = 'Country'
# First CSV column is the index; keep only the `feature` column (as a one-column frame).
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
tree = read_tree(TREE_NWK)
# Return value unused, so this presumably modifies `tree` in place -- verify against pastml.tree.
collapse_zero_branches([tree])
# Runs at import time; only the first element of acr()'s result is kept and shared by the tests.
acr_result = acr(tree, df, prediction_method=ACCTRAN)[0]


class ACRStateAcctranTest(unittest.TestCase):
    def test_num_steps(self):
        self.assertEqual(
            32,
            acr_result[STEPS],
            msg='Was supposed to have {} parsimonious steps, got {}.'.format(
                32, acr_result[STEPS]))

    def test_num_nodes(self):
        state2num = Counter()
        for node in tree.traverse():
            state = getattr(node, feature)
            if len(state) > 1:
Beispiel #9
0
from pastml.models.hky import KAPPA, HKY
from pastml.ml import LH, LH_SF, MPPA, LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR, \
    CHANGES_PER_AVG_BRANCH, SCALING_FACTOR, MARGINAL_PROBABILITIES
from pastml.models.f81_like import JC
from pastml.tree import read_tree

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR,
                        'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.nwk')
STATES_INPUT = os.path.join(
    DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.pastml.tab')

# Name of the single annotation column that is reconstructed.
feature = 'ACR'
# Reference reconstruction with the JC model on its own freshly read tree and table
# (tab-separated, first column as index); only the first element of acr()'s result is kept.
# NOTE(review): the variable is named *_f81 although the model passed is JC.
acr_result_f81 = acr(read_tree(TREE_NWK),
                     pd.read_csv(STATES_INPUT, index_col=0, header=0,
                                 sep='\t')[[feature]],
                     prediction_method=MPPA,
                     model=JC)[0]

# Second, independent read of the same table and tree for the HKY reconstruction below.
df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]]
tree = read_tree(TREE_NWK)
acr_result_hky = acr(tree,
                     df,
                     prediction_method=MPPA,
                     model=HKY,
                     column2parameters={
                         feature: {
                             KAPPA: 1,
                             'A': .25,
                             'T': .25,
                             'C': .25,
Beispiel #10
0
    old_root_child.up = None
    for child in other_children:
        child.up = None
        old_root_child.add_child(child, dist=old_root_child_dist + child.dist)
    old_root_child.set_outgroup(new_root)
    new_root = new_root.up
    for _ in new_root.traverse():
        if not _.name:
            _.name = 'unknown'
    return new_root


# Name of the single annotation column that is reconstructed.
feature = 'Country'
# First CSV column is the index; keep only the `feature` column (as a one-column frame).
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
tree = read_tree(TREE_NWK)
# Reference result computed once at import time on the tree as read from file;
# only the first element of acr()'s result is kept.
acr_result = acr(tree, df, prediction_method=MPPA, model=EFT)[0]


class ACRParameterOptimisationMPPAEFTTest(unittest.TestCase):
    def test_rerooted_values_are_the_same(self):
        for _ in range(5):
            rerooted_tree = reroot_tree_randomly()
            rerooted_acr_result = acr(rerooted_tree,
                                      df,
                                      prediction_method=MPPA,
                                      model=EFT)[0]
            for (state, freq, refreq) in zip(acr_result[STATES],
                                             acr_result[FREQUENCIES],
                                             rerooted_acr_result[FREQUENCIES]):
                self.assertAlmostEqual(
                    freq,
Beispiel #11
0
from collections import Counter

import pandas as pd

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr, COPY

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'copy_states.tab')

# Name of the single annotation column that is used and later read from the nodes.
feature = 'Country'
# Tab-separated table, first column is the index; keep only the `feature` column.
df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]]
tree = read_tree(TREE_NWK)
# Return value unused, so this presumably modifies `tree` in place -- verify against pastml.tree.
collapse_zero_branches([tree])
# Runs at import time; only the first element of acr()'s result is kept and shared by the tests.
acr_result = acr(tree, df, prediction_method=COPY)[0]


class ACRStateDownpassTest(unittest.TestCase):
    def test_num_nodes(self):
        state2num = Counter()
        for node in tree.traverse():
            state = getattr(node, feature)
            if len(state) > 1:
                state2num['unresolved'] += 1
            else:
                state2num[next(iter(state))] += 1
        expected_state2num = {
            'unresolved': 5,
            'Africa': 114,
            'Albania': 50,
Beispiel #12
0
import pandas as pd

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.parsimony import DOWNPASS, STEPS

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Name of the single annotation column that is reconstructed and later read from the nodes.
feature = 'Country'
# First CSV column is the index; keep only the `feature` column (as a one-column frame).
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
tree = read_tree(TREE_NWK)
# Return value unused, so this presumably modifies `tree` in place -- verify against pastml.tree.
collapse_zero_branches([tree])
# Runs at import time; only the first element of acr()'s result is kept and shared by the tests.
acr_result = acr(tree, df, prediction_method=DOWNPASS)[0]


class ACRStateDownpassTest(unittest.TestCase):
    def test_num_steps(self):
        self.assertEqual(
            32,
            acr_result[STEPS],
            msg='Was supposed to have {} parsimonious steps, got {}.'.format(
                32, acr_result[STEPS]))

    def test_num_nodes(self):
        state2num = Counter()
        for node in tree.traverse():
            state = getattr(node, feature)
            if len(state) > 1:
Beispiel #13
0
from pastml.acr import acr
from pastml.models.hky import KAPPA, HKY
from pastml.ml import LH, LH_SF, MPPA, LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR, \
    CHANGES_PER_AVG_BRANCH, SCALING_FACTOR, FREQUENCIES, MARGINAL_PROBABILITIES
from pastml.models.f81_like import F81
from pastml.tree import read_tree

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.6.C_0.15.G_0.2.T_0.05.nwk')
STATES_INPUT = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.6.C_0.15.G_0.2.T_0.05.pastml.tab')
# Second tree/table pair; not used by the statements in this setup block.
TREE_NWK_JC = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.nwk')
STATES_INPUT_JC = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.pastml.tab')

# Name of the single annotation column that is reconstructed.
feature = 'ACR'
# Tab-separated table, first column is the index; keep only the `feature` column.
df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]]
# Three reconstructions run at import time, each keeping only the first element of acr()'s result:
# F81 on its own freshly read tree ...
acr_result_f81 = acr(read_tree(TREE_NWK), df, prediction_method=MPPA, model=F81)[0]

# ... HKY with KAPPA supplied as 1 via column2parameters, on the module-level `tree` ...
tree = read_tree(TREE_NWK)
acr_result_hky = acr(tree, df, prediction_method=MPPA, model=HKY, column2parameters={feature: {KAPPA: 1}})[0]
# ... and HKY without any supplied parameters, again on its own freshly read tree.
acr_result_hky_free = acr(read_tree(TREE_NWK), df, prediction_method=MPPA, model=HKY)[0]

# Printed on import for manual comparison of the two HKY log-likelihoods.
print("Log lh for HKY-kappa-fixed {}, HKY {}"
      .format(acr_result_hky[LOG_LIKELIHOOD], acr_result_hky_free[LOG_LIKELIHOOD]))


class HKYF81Test(unittest.TestCase):

    def test_params(self):
        for param in (LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR.format(MPPA), CHANGES_PER_AVG_BRANCH,
                      SCALING_FACTOR):
            self.assertAlmostEqual(acr_result_hky[param], acr_result_f81[param], places=3,
Beispiel #14
0
import pandas as pd
import numpy as np

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.ml import MAP
from pastml.models.f81_like import JC

# Input files live in the 'data' directory next to this file.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Name of the single annotation column that is reconstructed and later read from the nodes.
feature = 'Country'
# First CSV column is the index; keep only the `feature` column (as a one-column frame).
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
tree = read_tree(TREE_NWK)
# Runs at import time; the return value is discarded. The tests below read
# getattr(node, feature) from this tree, so acr() is relied on to annotate it.
acr(tree, df, prediction_method=MAP, model=JC)


class ACRStateMAPJCTest(unittest.TestCase):
    def test_collapsed_vs_full(self):
        tree_uncollapsed = read_tree(TREE_NWK)
        acr(tree_uncollapsed, df, prediction_method=MAP, model=JC)

        def get_state(node):
            state = getattr(node, feature)
            return state if not isinstance(state, list) else ', '.join(
                sorted(state))

        df_full = pd.DataFrame.from_dict(
            {
                node.name: get_state(node)
def get_binary_trait_subtrees(tre,
                              csv,
                              tiplabel_in_csv=None,
                              elements=1,
                              trait_column="adm2",
                              trait_value="NORFOLK",
                              n_threads=4,
                              method="DOWNPASS",
                              extended_mode=0):
    '''
    Reconstruct a binary ("yes"/"no") version of a trait and return the top-most matching subtrees.

    Instead of reconstructing all states, we ask if the ancestral state is one of `trait_value` or not.
    Still we allow for `elements` > 1 (2 in practice), so that we accept "yes and no" ancestral nodes.
    You can group trait values by passing a list as `trait_value`.

    Parameters:
        tre: tree whose nodes are annotated by `acr` (ete3-style API: `traverse`, `iter_leaves`,
            `get_cached_content`, `get_monophyletic`).
        csv: pandas DataFrame with the trait; indexed by tip label unless `tiplabel_in_csv` is given.
        tiplabel_in_csv: name of the column holding tip labels; when given it becomes the index.
        elements: maximum number of states allowed on a matching node (values < 1 become 1).
        trait_column: column of `csv` with the original trait.
        trait_value: one value (str) or a list of values mapped to "yes".
        n_threads: passed to `acr` as `threads`.
        method: prediction method name; unknown names silently fall back to "DOWNPASS".
        extended_mode: 0 selects matching internal nodes themselves. Any other value also accepts
            leaves, allows one more state (`elements + 1`), and stores the parent of each match;
            2 stores the grandparent when there is one.

    Returns:
        (subtrees, mono, result, new_trait): list of selected nodes (unique, in the order found), the
        generator returned by `tre.get_monophyletic`, the value returned by `acr`, and the name of the
        binary trait column/node attribute.
    '''
    if elements < 1:
        elements = 1  ## inferred state cardinality (1 makes much more sense, but you can set "2" as well)
    if method not in ("MPPA", "MAP", "JOINT", "ACCTRAN", "DELTRAN", "DOWNPASS"):
        method = "DOWNPASS"
    # .copy(): the frame is modified in place below, which must neither touch the caller's
    # data nor trigger pandas' SettingWithCopyWarning.
    if tiplabel_in_csv:  # o.w. we assume input csv already has it as index
        csv_column = csv[[tiplabel_in_csv, trait_column]].copy()
        # index by the column the caller named (was hard-coded to "sequence_name");
        # acr needs the index to map onto the tree leaves
        csv_column.set_index(tiplabel_in_csv, drop=True, inplace=True)
    else:
        csv_column = csv[[trait_column]].copy()
    if isinstance(trait_value, str):  # anything else is assumed to be a list of values
        trait_value = [trait_value]

    # transform variable into binary
    new_trait = str(trait_column) + "_is_" + "_or_".join(trait_value)
    csv_column[new_trait] = csv_column[trait_column].map(
        lambda a: "yes" if a in trait_value else "no")
    csv_column.drop(labels=[trait_column], axis=1, inplace=True)

    # node:leafname, not the other way around as usual...
    tree_leaf_nodes = {leaf: leaf.name for leaf in tre.iter_leaves()}

    ## Ancestral state reconstruction of given trait; annotates tree nodes with states
    result = acr(tre,
                 csv_column,
                 prediction_method=method,
                 force_joint=False,
                 threads=n_threads)

    # our tree may have dups, and pastml (correctly) replaces them by a placeholder "t123";
    # here we revert back to the original (possibly duplicated) names
    for leafnode, leafname in tree_leaf_nodes.items():
        leafnode.name = leafname

    logger.debug("Finished lowlevel binary_trait_subtrees()")

    # leaf names below every node
    node2leaves = tre.get_cached_content(store_attr="name")
    stored_leaves = set()  # leaf names already covered by a selected match
    subtrees = []

    extended = extended_mode != 0
    max_states = elements + 1 if extended else elements
    ## Preorder visits ancestors first, so the first match on any lineage is kept and every
    ## matching descendant (which shares leaves with it) is skipped.
    for xnode in tre.traverse("preorder"):
        if not extended and xnode.is_leaf():
            continue
        states = getattr(xnode, new_trait)  # traits are sets (not lists)
        if "yes" not in states or len(states) > max_states:
            continue
        if not stored_leaves.isdisjoint(node2leaves[xnode]):
            continue
        stored_leaves.update(node2leaves[xnode])

        # Walk up at most two levels in extended mode. The root has no parent, so each step
        # is guarded (the unguarded `xnode.up.up` raised AttributeError for a matching root).
        parent = xnode.up if extended else None
        grandparent = parent.up if parent is not None else None
        if extended_mode == 2 and grandparent is not None:
            subtrees.append(grandparent)
        elif parent is not None:  # extended_mode 1 or 2
            subtrees.append(parent)
        else:
            subtrees.append(xnode)

    # ete3 documents `values` as a set of values; a bare string would be split into characters.
    mono = tre.get_monophyletic(values={"yes"}, target_attr=new_trait)
    ## different matches may share a parent/grandparent; drop repeats but keep discovery order
    subtrees = list(dict.fromkeys(subtrees))
    return subtrees, mono, result, new_trait
def ASR_subtrees(metadata0, tree, extended_mode=0, reroot=True, method=None):
    """ Ancestral state reconstruction, first of a binary "locality" trait and then of other columns.

    Clusters are based on "locality": patient is from Norfolk *or* was sequenced here in NORW
    (only the NORW criterion is used when `extended_mode` is not 0).
    One column at a time, so the n_threads doesn't help.

    Parameters:
        metadata0: pandas DataFrame indexed like the tree leaves; a copy is used, the input is not modified here.
        tree: tree to analyse; it is rerooted (if `reroot`) and annotated in place.
        extended_mode: forwarded to `get_binary_trait_subtrees`; also selects the locality definition
            and the description text.
        reroot: if true, root the tree at its midpoint outgroup first.
        method: None (defaults to DOWNPASS for locality and ACCTRAN for the other columns), a single
            method name used for both, or a sequence `[locality_method, other_method]`. A one-element
            sequence is used for both; an empty one is treated like None.

    Returns:
        (submetadata, subtree, md_description, metadata): one metadata slice per selected subtree, the
        selected nodes, a markdown description, and the metadata as returned by `save_metadata_inferred`.
    """
    default_methods = ["DOWNPASS", "ACCTRAN"]
    if method is None:
        method = default_methods
    elif isinstance(method, str):
        method = [method, method]
    else:
        # accept any sequence; previously a one-element list failed later with IndexError
        method = list(method) or default_methods
        if len(method) == 1:
            method = method * 2

    metadata = metadata0.copy()  # work with copies
    csv_cols = [x for x in common.asr_cols if x in metadata.columns]
    if reroot:  ## this should be A,B
        R = tree.get_midpoint_outgroup()
        tree.set_outgroup(R)

    md_description = """
Sequence clusters are based on **locality**: patients from Norfolk (field `adm2`) *or* patients that were sequenced here
(submission org = `NORW`).
This definition is not equivalent to the `UK lineages`, which relies on sequence properties and is estimated at the
national level. 
Our method does not take the `UK lineages` explicitly into account and as result we may find clusters spanning several
lineages, and we may also see samples from the same lineage scattered across clusters (reflecting their locality).
The **locality** allows us to focus on the local scale, by "zooming in" into geographycally connected lineages.
<br>
    """
    if extended_mode == 0:
        yesno = ((metadata["adm2"].str.contains(
            "Norfolk", case=False, na=False))
                 | (metadata["submission_org_code"].str.contains(
                     "NORW", case=False, na=False)))
    else:
        md_description = "Extended mode of **peroba**, for non-COGUK analyses<br><br>"
        yesno = (metadata["submission_org_code"].str.contains("NORW",
                                                              case=False,
                                                              na=False))

    logger.info(
        "Start estimating ancestral states by %s for locality and %s for others",
        method[0], method[1])
    # booleans become the strings "True"/"False"; "True" is the value searched for below
    df = pd.DataFrame(list(yesno.astype(str)),
                      columns=["local"],
                      index=metadata.index.copy())
    # only the list of selected nodes is needed from the 4-tuple
    subtree, _mono, _result, _trait_name = get_binary_trait_subtrees(
        tree,
        df,
        trait_column="local",
        trait_value="True",
        elements=1,
        method=method[0],
        extended_mode=extended_mode)
    logger.info(
        "Finished estimating ancestral state for 'locality', which defines clusters"
    )

    ## decorate the tree with ancestral states (csv informs which columns should be on tree as node attribute)
    csv, csv_cols = prepare_csv_columns_for_asr(
        metadata,
        csv_cols)  # csv is used only in pastml (imputation go to tree)
    if csv_cols:
        logger.info("Will now estimate ancestral states for %s",
                    " ".join(csv_cols))
        # remember original names, in case we have duplicated names
        tree_leaf_nodes = {leaf: leaf.name for leaf in tree.iter_leaves()}
        ## annotates tree nodes with states (e.g. tre2.adm2); the returned value is not needed
        acr(tree, csv, prediction_method=method[1], force_joint=False)
        # pastml (correctly) replaces duplicated names by a placeholder like "t123";
        # revert back to the original (possibly duplicated) names
        for leafnode, leafname in tree_leaf_nodes.items():
            leafnode.name = leafname
    # adds new peroba_ columns with imputed and original values:
    metadata = save_metadata_inferred(metadata, tree, csv_cols)

    node2leaves = tree.get_cached_content(
        store_attr="name")  # dict node:[leaves]
    submetadata = [
        metadata.loc[metadata.index.isin(node2leaves[x])] for x in subtree
    ]
    # NOTE: the per-parent "supermetadata" slices were computed here but never used or returned,
    # so that extra pass over the metadata was dropped.

    return submetadata, subtree, md_description, metadata  ## metadata now has tip reconstruction