Ejemplo n.º 1
0
    def test_collapsed_vs_full(self):
        """Compare the predicted states of ``tree`` with those of a freshly reconstructed tree.

        The tree in ``TREE_NWK`` is read again and reconstructed with ``acr`` (MPPA, EFT). Every
        node of ``tree`` must then have the same predicted ``feature`` states as the same-named
        node of this fresh reconstruction.
        """
        full_tree = read_tree(TREE_NWK)
        acr(full_tree, df, prediction_method=MPPA, model=EFT)

        def state_str(n):
            # A sorted join gives an order-independent textual form of the node's state set.
            return ', '.join(sorted(getattr(n, feature)))

        full_states = {n.name: state_str(n) for n in full_tree.traverse()}
        collapsed_states = {n.name: state_str(n) for n in tree.traverse()}

        full_frame = pd.DataFrame.from_dict(full_states, orient='index', columns=['full'])
        collapsed_frame = pd.DataFrame.from_dict(collapsed_states, orient='index', columns=['collapsed'])
        # Left join: every node of the collapsed tree must find an equal state in the full one.
        joined = collapsed_frame.join(full_frame, how='left')
        states_match = np.all(joined['collapsed'] == joined['full'])

        self.assertTrue(
            states_match,
            msg='All the node states of the collapsed tree should be the same as of the full one.'
        )
def reroot_tree_randomly():
    """Read the tree from ``TREE_NWK`` and return it rerooted at a randomly chosen branch.

    The original root node is dropped. Its first child becomes a temporary root and adopts the
    remaining root children; each adopted child's branch is extended by the first child's branch
    length. The tree is then rerooted with ``set_outgroup`` on a node chosen with
    ``np.random.choice`` among the nodes that:

    - are not the root;
    - are not children of the root;
    - have a non-zero branch length.

    All unnamed nodes of the result are named ``'unknown'``.

    :return: the root node of the rerooted tree
    """
    tr = read_tree(TREE_NWK)
    candidates = [n for n in tr.traverse()
                  if not n.is_root() and not n.up.is_root() and n.dist]
    outgroup = np.random.choice(candidates)

    first_child, *siblings = tr.children
    extra_dist = first_child.dist
    # Detach the first child from the old root so that it can act as the temporary root.
    first_child.up = None
    for sib in siblings:
        sib.up = None
        first_child.add_child(sib, dist=extra_dist + sib.dist)

    first_child.set_outgroup(outgroup)
    # After set_outgroup the outgroup hangs directly below the new root.
    result = outgroup.up
    for n in result.traverse():
        if not n.name:
            n.name = 'unknown'
    return result
Ejemplo n.º 3
0
    def clean(self):
        """Validate the uploaded tree and annotation table on top of the default form cleaning.

        * An empty or ``'<tab>'`` ``data_sep`` is replaced by a tab character in ``cleaned_data``.
        * The ``tree`` upload, if present, must contain at least one ';'-terminated Newick tree.
          Every tree must be readable by ``read_tree``, and there may be at most ``MAX_N_TREES``.
        * The ``data`` upload, if present, must be parsable by ``pd.read_table`` with that
          separator and must have at least two columns.

        Problems are reported with ``add_error`` on the corresponding field rather than raised.
        Both uploaded files are rewound after being read.

        :return: the cleaned data
        """
        cleaned_data = super().clean()

        sep = cleaned_data.get('data_sep', '\t')
        if not sep or sep == '<tab>':
            self.cleaned_data['data_sep'] = '\t'
            sep = '\t'
        # How the separator is shown to the user in error messages.
        sep_name = '<tab>' if sep == '\t' else sep

        f = cleaned_data.get('tree', None)
        if f:
            f = f.file  # the file in Memory
            try:
                # Drop '\r' too, so that files with Windows line endings are handled.
                text = f.read().decode().replace('\r', '').replace('\n', '')
                # Only the pieces followed by a ';' can be trees: whatever follows the last ';'
                # is not a complete tree. Without this, the emptiness check could never fire,
                # because str.split always returns at least one element.
                nwks = text.split(';')[:-1]
                if not nwks:
                    self.add_error('tree',
                                   u'Could not find any trees (in Newick format) in the file.')
                else:
                    n_trees = len([read_tree(nwk + ';') for nwk in nwks])
                    if n_trees > MAX_N_TREES:
                        self.add_error('tree',
                                       u'The file contains too many trees ({}), the limit is {}.'
                                       .format(n_trees, MAX_N_TREES))
            except Exception:
                self.add_error('tree',
                               u'Your tree does not seem to be in Newick format.')
            f.seek(0)  # rewind so that the upload can be read again later

        f = cleaned_data.get('data', None)
        if f:
            f = f.file
            try:
                df = pd.read_table(f, sep=sep, header=0)
            except Exception:
                self.add_error('data', u'We could not parse your annotation table, '
                                       u'please check if the separator ("{}") is correct.'
                               .format(sep_name))
            else:
                if len(df.columns) < 2:
                    def get_col_name(col):
                        # Quote the column name, shortening long ones to their first 10 characters.
                        col = str(col)
                        return '"{}..."'.format(col[:10]) if len(col) > 10 else '"{}"'.format(col)

                    self.add_error('data', u'Your annotation table contains {} column: {}, '
                                           u'while it must contain at least 2: tip ids and their states. '
                                           u'Please check if the separator ("{}") is correct.'
                                   .format(len(df.columns), ', '.join(get_col_name(_) for _ in df.columns),
                                           sep_name))
            f.seek(0)  # rewind so that the upload can be read again later
        return cleaned_data
Ejemplo n.º 4
0
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--input_tree', required=True, type=str)
    parser.add_argument('--output_tree', required=True, type=str)
    parser.add_argument('--threshold', required=True, type=str)
    parser.add_argument('--feature', required=True, type=str)
    parser.add_argument('--strict', action='store_true', default=False)
    params = parser.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s: %(message)s',
                        datefmt="%Y-%m-%d %H:%M:%S")

    tr = read_tree(params.input_tree)

    try:
        threshold = float(params.threshold)
    except:
        # may be it's a string threshold then
        threshold = params.threshold

    num_collapsed, num_set_zero_tip, num_set_zero_root = 0, 0, 0
    for n in list(tr.traverse('postorder')):
        children = list(n.children)
        for child in children:
            if getattr(child, params.feature) < threshold \
                    or not params.strict and getattr(child, params.feature) == threshold:
                if child.is_leaf():
                    child.dist = 0
Ejemplo n.º 5
0
from collections import Counter

import pandas as pd
import numpy as np

from pastml.tree import read_tree, collapse_zero_branches
from pastml.acr import acr
from pastml.ml import MPPA, EFT

# Input files are looked up in the 'data' directory next to this module.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre')
STATES_INPUT = os.path.join(DATA_DIR, 'data.txt')

# Annotation column whose ancestral states are reconstructed.
feature = 'Country'
# Annotation table indexed by its first column, restricted to the `feature` column.
df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]]
# Reconstruction run once at import time (MPPA prediction, EFT model). acr annotates the tree
# nodes themselves: the tests read the predicted states via getattr(node, feature).
tree = read_tree(TREE_NWK)
acr(tree, df, prediction_method=MPPA, model=EFT)


class ACRStateMPPAEFTTest(unittest.TestCase):
    def test_collapsed_vs_full(self):
        tree_uncollapsed = read_tree(TREE_NWK)
        acr(tree_uncollapsed, df, prediction_method=MPPA, model=EFT)

        def get_state(node):
            return ', '.join(sorted(getattr(node, feature)))

        df_full = pd.DataFrame.from_dict(
            {
                node.name: get_state(node)
                for node in tree_uncollapsed.traverse()
Ejemplo n.º 6
0
    parser.add_argument('--tree_tt_acr', required=True, type=str)
    parser.add_argument('--tree_lsd2_nex', required=True, type=str)
    parser.add_argument('--tree_lsd2_acr', required=True, type=str)
    parser.add_argument('--tab', required=True, type=str)
    params = parser.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s: %(message)s',
                        datefmt="%Y-%m-%d %H:%M:%S")

    timetree_lsd2 = parse_nexus(params.tree_lsd2_nex)[0]
    annotate_dates([timetree_lsd2])
    timetree_tt = parse_nexus(params.tree_tt_nex)[0]
    annotate_dates([timetree_tt])

    acrtree_lsd2 = read_tree(params.tree_lsd2_acr, columns=['country'])
    acrtree_tt = read_tree(params.tree_tt_acr, columns=['country'])

    date_df = pd.read_csv(
        params.dates_tt, sep='\t', index_col=0,
        skiprows=[0])[['numeric date', 'lower bound', 'upper bound']]

    get_ci_lsd2 = lambda _: getattr(_, DATE_CI)
    get_ci_tt = lambda _: date_df.loc[_.name, ['lower bound', 'upper bound']]

    def get_dates(acrtree, timetree, get_ci):
        c2dates = {}
        c2size = defaultdict(lambda: 0)
        for n in acrtree.traverse('postorder'):
            if not n.is_root() and getattr(n, 'country') != getattr(
                    n.up, 'country') and not n.is_leaf():
Ejemplo n.º 7
0
from pastml import get_personalized_feature_name, STATES
from pastml.acr import acr
from pastml.models.hky import KAPPA, HKY
from pastml.ml import LH, LH_SF, MPPA, LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR, \
    CHANGES_PER_AVG_BRANCH, SCALING_FACTOR, MARGINAL_PROBABILITIES
from pastml.models.f81_like import JC
from pastml.tree import read_tree

# Input files are looked up in the 'data' directory next to this module.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR,
                        'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.nwk')
STATES_INPUT = os.path.join(
    DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.pastml.tab')

# Annotation column whose ancestral states are reconstructed.
feature = 'ACR'
# MPPA reconstruction under the JC model on its own copy of the tree; only the first element
# of acr's result is kept (presumably the result for the single `feature` column — confirm).
acr_result_f81 = acr(read_tree(TREE_NWK),
                     pd.read_csv(STATES_INPUT, index_col=0, header=0,
                                 sep='\t')[[feature]],
                     prediction_method=MPPA,
                     model=JC)[0]

# Tab-separated annotation table (first column as index) and the tree used by the HKY
# reconstruction that follows.
df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]]
tree = read_tree(TREE_NWK)
acr_result_hky = acr(tree,
                     df,
                     prediction_method=MPPA,
                     model=HKY,
                     column2parameters={
                         feature: {
                             KAPPA: 1,
                             'A': .25,
Ejemplo n.º 8
0
def _validate_input(columns, data, data_sep, date_column, html, html_compressed, id_index, name_column, tree_nwk,
                    copy_only):
    """
    Reads the input tree and annotation table and checks that they are consistent.

    Negative branch lengths of the tree are set to zero (with a warning). If a visualisation is
    requested (html or html_compressed is truthy), tip dates are also extracted. They come from
    date_column if given, parsed with pandas.to_datetime and converted with date2years. Otherwise,
    or when none of the tree tips has a date, the tip depths are used instead.

    :param columns: annotation column name (str) or list of column names to keep; falsy to keep all of them
    :param data: annotation table, anything accepted by pandas.read_csv
    :param data_sep: column separator of the annotation table
    :param date_column: (optional) annotation column with tip dates, only validated if a visualisation is requested
    :param html: visualisation output; only its truthiness is used here
    :param html_compressed: compressed visualisation output; only its truthiness is used here
    :param id_index: index of the annotation table column containing the ids matching the tree node names
    :param name_column: (optional) annotation column, presumably used to label the compressed visualisation
        nodes — confirm with the caller
    :param tree_nwk: input tree, in a form accepted by read_tree
    :param copy_only: if True, a column is rejected only if 100% (instead of >= 90%) of its tip annotations are unknown
    :return: tuple (root, df, years, tip2date, name_column):
        root is the tree with negative branches set to zero;
        df is the annotation table (string values, string index), restricted to columns if given,
        with column names passed through col_name2cat;
        years is the sorted list of distinct values among the minimum, quartiles, median and
        maximum tip date (empty if no visualisation is requested);
        tip2date maps tip names to dates (rounded to 6 decimals) or to depths (empty if no visualisation);
        name_column is the (possibly updated) name column.
    :raises ValueError: if the date/name/annotation columns are not found, the dates cannot be parsed,
        the tip names do not match the annotation ids, or the annotations are too sparse or not categorical
    """
    logger = logging.getLogger('pastml')
    logger.debug('\n=============INPUT DATA VALIDATION=============')
    root = read_tree(tree_nwk)
    num_neg = 0
    for _ in root.traverse():
        if _.dist < 0:
            num_neg += 1
            _.dist = 0
    if num_neg:
        logger.warning('Input tree contained {} negative branches: we put them to zero.'.format(num_neg))
    logger.debug('Read the tree {}.'.format(tree_nwk))

    df = pd.read_csv(data, sep=data_sep, index_col=id_index, header=0, dtype=str)
    df.index = df.index.map(str)
    logger.debug('Read the annotation file {}.'.format(data))

    # As the date column is only used for visualisation if there is no visualisation we are not gonna validate it
    years, tip2date = [], {}
    if html_compressed or html:
        if date_column:
            if date_column not in df.columns:
                raise ValueError('The date column "{}" not found among the annotation columns: {}.'
                                 .format(date_column, _quote(df.columns)))
            try:
                df[date_column] = pd.to_datetime(df[date_column], infer_datetime_format=True)
            except ValueError:
                try:
                    df[date_column] = pd.to_datetime(df[date_column], format='%Y.0')
                except ValueError:
                    raise ValueError('Could not infer the date format for column "{}", please check it.'
                                     .format(date_column))

            # Only look up the tips present in the annotation table: .loc with missing labels
            # raises a KeyError, which would make the emptiness check below unreachable.
            annotated_ids = set(df.index)
            annotated_tips = [_.name for _ in root if _.name in annotated_ids]
            tip2date = df.loc[annotated_tips, date_column].apply(date2years).to_dict()
            if not tip2date:
                raise ValueError('Could not find any dates for the tree tips in column {}, please check it.'
                                 .format(date_column))
        annotate_depth(root)
        if not date_column:
            tip2date = {tip.name: round(getattr(tip, DEPTH), 6) for tip in root}
        else:
            tip2date = {t: round(d, 6) if d is not None else None for (t, d) in tip2date.items()}

        dates = [_ for _ in tip2date.values() if _ is not None]
        if not dates:
            tip2date = {tip.name: round(getattr(tip, DEPTH), 6) for tip in root}
            dates = [_ for _ in tip2date.values() if _ is not None]
            date_column = None
            logger.warning('The date column does not contains dates for any of the tree tips, '
                           'therefore we will ignore it')

        min_date = min(dates)
        max_date = max(dates)
        dates = sorted(dates)
        # Minimum, first quartile, median, third quartile and maximum (duplicates removed).
        years = sorted({dates[0], dates[len(dates) // 2],
                        dates[1 * len(dates) // 4], dates[3 * len(dates) // 4], dates[-1]})
        logger.debug("Extracted tip {}: they vary between {} and {}."
                     .format('dates' if date_column else 'distances', min_date, max_date))

    if columns:
        if isinstance(columns, str):
            columns = [columns]
        unknown_columns = set(columns) - set(df.columns)
        if unknown_columns:
            raise ValueError('{} of the specified columns ({}) {} not found among the annotation columns: {}.'
                             .format('One' if len(unknown_columns) == 1 else 'Some',
                                     _quote(unknown_columns),
                                     'is' if len(unknown_columns) == 1 else 'are',
                                     _quote(df.columns)))
        df = df[columns]

    df.columns = [col_name2cat(column) for column in df.columns]

    node_names = {n.name for n in root.traverse() if n.name}
    df_index_names = set(df.index)
    # pandas >= 2.0 rejects a set as a .loc indexer, hence the (deterministically ordered) list.
    filtered_df = df.loc[sorted(node_names & df_index_names), :]
    if not filtered_df.shape[0]:
        tip_name_representatives = []
        for _ in root.iter_leaves():
            if len(tip_name_representatives) < 3:
                tip_name_representatives.append(_.name)
            else:
                break
        raise ValueError('Your tree tip names (e.g. {}) do not correspond to annotation id column values (e.g. {}). '
                         'Check your annotation file.'
                         .format(', '.join(tip_name_representatives),
                                 ', '.join(list(df_index_names)[:3])))
    logger.debug('Checked that tip names correspond to annotation file index.')

    if html_compressed and name_column:
        name_column = col_name2cat(name_column)
        if name_column not in df.columns:
            raise ValueError('The name column ("{}") should be one of those specified as columns ({}).'
                             .format(name_column, _quote(df.columns)))
    elif len(df.columns) == 1:
        name_column = df.columns[0]

    # Fraction of annotated tree nodes with an unknown value, per column.
    percentage_unknown = filtered_df.isnull().sum(axis=0) / filtered_df.shape[0]
    max_unknown_percentage = percentage_unknown.max()
    if max_unknown_percentage >= (.9 if not copy_only else 1):
        raise ValueError('{:.1f}% of tip annotations for column "{}" are unknown, '
                         'not enough data to infer ancestral states. '
                         'Check your annotation file and if its id column corresponds to the tree tip names.'
                         .format(max_unknown_percentage * 100, percentage_unknown.idxmax()))
    percentage_unique = filtered_df.nunique() / filtered_df.count()
    max_unique_percentage = percentage_unique.max()
    # .iloc: count() is indexed by column names, so [0] relied on deprecated positional fallback.
    if filtered_df.count().iloc[0] > 100 and max_unique_percentage > .5:
        raise ValueError('The column "{}" seem to contain non-categorical data: {:.1f}% of values are unique. '
                         'PASTML cannot infer ancestral states for a tree with too many tip states.'
                         .format(percentage_unique.idxmax(), 100 * max_unique_percentage))
    logger.debug('Finished input validation.')
    return root, df, years, tip2date, name_column
Ejemplo n.º 9
0
from pastml.acr import acr
from pastml.models.hky import KAPPA, HKY
from pastml.ml import LH, LH_SF, MPPA, LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR, \
    CHANGES_PER_AVG_BRANCH, SCALING_FACTOR, FREQUENCIES, MARGINAL_PROBABILITIES
from pastml.models.f81_like import F81
from pastml.tree import read_tree

# Input files are looked up in the 'data' directory next to this module.
DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
TREE_NWK = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.6.C_0.15.G_0.2.T_0.05.nwk')
STATES_INPUT = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.6.C_0.15.G_0.2.T_0.05.pastml.tab')
TREE_NWK_JC = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.nwk')
STATES_INPUT_JC = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.pastml.tab')

# Annotation column whose ancestral states are reconstructed; the same table is reused by all
# three reconstructions below.
feature = 'ACR'
df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]]
# MPPA under F81 on its own copy of the tree; only the first element of acr's result is kept.
acr_result_f81 = acr(read_tree(TREE_NWK), df, prediction_method=MPPA, model=F81)[0]

# MPPA under HKY with kappa fixed to 1 (via column2parameters) ...
tree = read_tree(TREE_NWK)
acr_result_hky = acr(tree, df, prediction_method=MPPA, model=HKY, column2parameters={feature: {KAPPA: 1}})[0]
# ... and under HKY without fixed parameters, on another copy of the tree.
acr_result_hky_free = acr(read_tree(TREE_NWK), df, prediction_method=MPPA, model=HKY)[0]

print("Log lh for HKY-kappa-fixed {}, HKY {}"
      .format(acr_result_hky[LOG_LIKELIHOOD], acr_result_hky_free[LOG_LIKELIHOOD]))


class HKYF81Test(unittest.TestCase):

    def test_params(self):
        for param in (LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR.format(MPPA), CHANGES_PER_AVG_BRANCH,
                      SCALING_FACTOR):
            self.assertAlmostEqual(acr_result_hky[param], acr_result_f81[param], places=3,
Ejemplo n.º 10
0
        type=str)  # metadata which includes IDs and locations
    parser.add_argument(
        '--input_counts',
        default=
        'C:/Users/Andrew/PycharmProjects/masters_project/25April2020_perstate.csv',
        type=str)  # state case counts since April 25, 2020
    parser.add_argument(
        '--output_stats',
        required=False,
        type=str,
        default='C:/Users/Andrew/Dropbox/masters_project/mutations.table')

    params = parser.parse_args()

    # phylogenetic tree read from the newick file
    tree_mutations = read_tree(params.input_tree_mutations)

    def tip_mutations(tip, length):
        """Sum of the branch lengths from ``tip`` up to the root, each multiplied by ``length``.

        The root's own branch is not counted, so a node without a parent gets 0 (previously
        this raised an AttributeError). ``length`` is presumably the alignment length, turning
        per-site distances into mutation counts — confirm with the caller.
        """
        mutations = 0
        node = tip
        # Every node that has a parent contributes the branch leading to it.
        while node.up is not None:
            mutations += node.dist * length
            node = node.up
        return mutations

    def num_mutations(tree, length):
        """Map each tip name of ``tree`` (iterating over ``tree`` yields its tips) to its tip_mutations."""
        return {tip.name: tip_mutations(tip, length) for tip in tree}
Ejemplo n.º 11
0
import pandas as pd
import numpy as np

from pastml.tree import read_tree, remove_certain_leaves

if '__main__' == __name__:
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--in_tree', required=True, type=str)
    parser.add_argument('--metadata', required=True, type=str)
    parser.add_argument('--out_tree_pattern', required=True, type=str)
    parser.add_argument('--drm', required=True, type=str)
    parser.add_argument('--date_column', required=True, type=str)
    params = parser.parse_args()

    df = pd.read_csv(params.metadata, header=0, index_col=0, sep='\t')
    df.index = df.index.map(str)
    tree = read_tree(params.in_tree)
    # Keep only the metadata of the tree tips.
    df = df[df.index.isin([n.name for n in tree.iter_leaves()])]
    res_df = df[df[params.drm] == 'resistant']
    res_years = [_ for _ in res_df[params.date_column].unique() if not pd.isnull(_)]
    if not res_years:
        raise ValueError('No dated tips with {} == "resistant" found in {}.'.format(params.drm, params.metadata))
    first_year = min(res_years)
    last_year = max([_ for _ in df[params.date_column].unique() if not pd.isnull(_)])

    for year, label in zip((last_year - 10, first_year), ('mid', 'first')):
        # Prune a fresh copy of the tree for each threshold. Pruning the already pruned tree would
        # wrongly lose the tips dated between last_year - 10 and first_year when
        # first_year > last_year - 10. Tips without metadata have no date, so they are removed too.
        pruned_tree = remove_certain_leaves(
            read_tree(params.in_tree),
            to_remove=lambda node: node.name not in df.index
                                   or pd.isnull(df.loc[node.name, params.date_column])
                                   or df.loc[node.name, params.date_column] > year)
        pruned_tree.write(outfile=params.out_tree_pattern.format(label), format=3)
Ejemplo n.º 12
0
def generate_map(data,
                 country,
                 location,
                 html,
                 tree=None,
                 data_sep='\t',
                 id_index=0,
                 colours=None):
    """
    Renders the 'geo_map.html' template of the pastml package into an HTML map of sample locations.

    Each country is assigned a location: the first non-missing value of the location column
    for that country, with locations sorted in ascending order. Countries are coloured by
    location, and the tooltips show how many samples each country and its location contain.

    :param data: annotation table, anything accepted by pandas.read_csv
    :param country: name of the annotation column containing countries
    :param location: name of the annotation column containing locations
    :param html: path of the HTML file to write (its directory is created if needed)
    :param tree: (optional) tree file; if given, the sample counts only use the annotation rows
        whose ids are tip names of this tree
    :param data_sep: column separator of the annotation table
    :param id_index: index of the annotation table column containing sample ids
    :param colours: (optional) colour specification passed to parse_colours together with the locations
    :raises ValueError: if the country or location column is not found in the annotation table
    """
    df = pd.read_csv(data, sep=data_sep, header=0, index_col=id_index)
    if country not in df.columns:
        raise ValueError(
            'The country column {} not found among the annotation columns: {}.'
            .format(country, df.columns))
    if location not in df.columns:
        raise ValueError(
            'The location column {} not found among the annotation columns: {}.'
            .format(location, df.columns))
    # With missing locations sorted last, keeping the first row per country picks its first
    # non-missing location (if any).
    df.sort_values(by=[location], inplace=True, na_position='last')
    ddf = df.drop_duplicates(subset=[country], inplace=False, keep='first')
    country2location = {
        c: l
        for c, l in zip(ddf[country], ddf[location])
        if not pd.isnull(c) and not pd.isnull(l)
    }
    if tree:
        # np.str no longer exists (NumPy >= 1.24): compare string ids with the tip names.
        tip_names = [_.name for _ in read_tree(tree)]
        df = df[df.index.astype(str).isin(tip_names)]
    unique_countries = {_ for _ in df[country].unique() if not pd.isnull(_)}
    if ISO_EXISTS:
        country2iso = {
            _: Country.get_iso2_from_iso3(iso)
            for (_, iso) in ((_, Country.get_iso3_country_code_fuzzy(_)[0])
                             for _ in country2location.keys())
            if iso and _ in unique_countries
        }
    else:
        country2iso = {
            _: escape(_)
            for _ in country2location.keys() if _ in unique_countries
        }
    iso2num = {
        iso: len(df[df[country] == c])
        for c, iso in country2iso.items()
    }
    iso2loc = {iso: country2location[c] for c, iso in country2iso.items()}
    iso2loc_num = {
        iso: len(df[df[location] == loc])
        for iso, loc in iso2loc.items()
    }
    iso2tooltip = {
        iso: escape('{}: {} samples (out of {} in {})'.format(
            c, iso2num[iso], iso2loc_num[iso], iso2loc[iso]))
        for (c, iso) in country2iso.items()
    }
    locations = sorted([_ for _ in df[location].unique() if not pd.isnull(_)])
    num_unique_values = len(locations)
    if colours:
        colours = parse_colours(colours, locations)
    else:
        colours = get_enough_colours(num_unique_values)
    iso2colour = {
        iso: colours[locations.index(loc)]
        for iso, loc in iso2loc.items()
    }

    env = Environment(loader=PackageLoader('pastml'))
    template = env.get_template('geo_map.html')
    page = template.render(iso2colour=iso2colour,
                           colours=colours,
                           iso2tooltip=iso2tooltip)
    os.makedirs(os.path.abspath(os.path.dirname(html)), exist_ok=True)
    # Explicit encoding: non-ASCII country names must not depend on the platform default encoding.
    with open(html, 'w', encoding='utf-8') as fp:
        fp.write(page)
Ejemplo n.º 13
0
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--tree', required=True, type=str)
    parser.add_argument('--tree_acr', required=True, type=str)
    parser.add_argument('--tab', required=True, type=str)
    parser.add_argument('--ntab', required=True, type=str)
    params = parser.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s: %(message)s',
                        datefmt="%Y-%m-%d %H:%M:%S")

    bs_tree = Tree(params.tree, format=2)
    acr_tree = read_tree(params.tree_acr, columns=['country'])

    c2bs = {}
    c2size = defaultdict(lambda: 0)
    for n in acr_tree.traverse('postorder'):
        if not n.is_root() and getattr(n, 'country') != getattr(
                n.up, 'country') and not n.is_leaf():
            c = getattr(n, 'country')
            if len(c) == 1:
                c = c.pop()
                if c in COUNTRIES and c2size[c] < len(n):
                    bs_n = bs_tree.get_common_ancestor(*(_.name for _ in n))
                    c2size[c] = len(n)
                    c2bs[c] = bs_n.support

    with open(params.tab, 'w+') as f:
Ejemplo n.º 14
0

if '__main__' == __name__:
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--tree', required=True, type=str)
    parser.add_argument('--acr', required=True, type=str)
    parser.add_argument('--columns', default=None, type=str, nargs='*')
    parser.add_argument('--out_log', required=True, type=str)
    params = parser.parse_args()

    df = pd.read_csv(params.acr, header=0, index_col=0, sep='\t')
    df.index = df.index.map(str)
    tree = read_tree(params.tree)
    preannotate_forest(df=df, forest=[tree])

    if params.columns is None or len(params.columns) == 0:
        params.columns = df.columns

    for n in tree.traverse():
        n.add_feature('state', set(state_combinations(n, params.columns)))
    
    from_to_count = Counter()
    for n in tree.traverse('preorder'):
        if n.is_root():
            continue
        states = getattr(n, 'state')
        up_states = getattr(n.up, 'state')
        for s in states:
Ejemplo n.º 15
0
    def test_counts(self):
        """Compare the transition counts from ``marginal_counts`` with simulation-based estimates.

        Character states are simulated on the tree with ``simulate_states``. Only the simulations
        whose tip states all equal the observed tip states are kept. In those, the parent->child
        state transitions are counted per node and averaged, and each entry must match the
        calculated counts to 2 decimal places.
        """
        tree = read_tree(TREE_NWK)
        character = 'Country'
        df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[character]]
        preannotate_forest([tree], df=df)
        states = np.array(
            [_ for _ in df[character].unique() if not pd.isna(_) and '' != _])
        avg_len, num_nodes, num_tips, tree_len = get_forest_stats([tree])

        freqs, sf, kappa, _ = _parse_pastml_parameters(PARAMS_INPUT,
                                                       states,
                                                       num_tips=num_tips,
                                                       reoptimise=False)
        tau = 0

        model = JC
        n_repetitions = 25_000
        counts = marginal_counts([tree],
                                 character,
                                 model,
                                 states,
                                 num_nodes,
                                 tree_len,
                                 freqs,
                                 sf,
                                 kappa,
                                 tau,
                                 n_repetitions=n_repetitions)

        sim_character = character + '.simulated'
        n_sim_repetitions = n_repetitions * 300
        simulate_states(tree,
                        model,
                        freqs,
                        kappa,
                        tau,
                        sf,
                        sim_character,
                        n_repetitions=n_sim_repetitions)
        n_states = len(states)
        state2id = dict(zip(states, range(n_states)))
        # Boolean mask of the simulations reproducing every observed tip state. It is computed
        # once here instead of re-deriving it (as good_indices > 0) for every node and child.
        good = np.ones(n_sim_repetitions, dtype=bool)
        for tip in tree:
            state_id = state2id[next(iter(getattr(tip, character)))]
            good &= getattr(tip, sim_character) == state_id
        num_good_simulations = np.count_nonzero(good)
        print('Simulated {} good configurations'.format(num_good_simulations))
        # Without any good simulation, the averages below would be NaN (division by zero).
        self.assertGreater(num_good_simulations, 0,
                           'None of the simulations reproduced the observed tip states.')

        sim_counts = np.zeros((n_states, n_states), dtype=float)
        for n in tree.traverse('levelorder'):
            from_states = getattr(n, sim_character)[good]
            state_nums = Counter(from_states)
            children_transition_counts = Counter()
            for c in n.children:
                transition_counts = Counter(
                    zip(from_states,
                        getattr(c, sim_character)[good]))
                for (i, j), num in transition_counts.items():
                    sim_counts[i, j] += num
                children_transition_counts.update(transition_counts)
            # Diagonal correction: for each state i, remove at most the number of simulations
            # where this node is in state i from the i->i counts accumulated over its children.
            for i, num in state_nums.items():
                sim_counts[i, i] -= min(num,
                                        children_transition_counts[(i, i)])
        sim_counts /= num_good_simulations
        print(np.round(counts, 2))
        print(np.round(sim_counts, 2))

        for i in range(n_states):
            for j in range(n_states):
                self.assertAlmostEqual(
                    counts[i, j], sim_counts[i, j], 2,
                    'Counts are different for {}->{}: {} (calculated) vs {} (simulated).'
                    .format(states[i], states[j], counts[i, j], sim_counts[i,
                                                                           j]))