def test_collapsed_vs_full(self):
    """Compare the reconstructed states of a freshly read tree with the module-level ``tree``.

    The tree is read again from ``TREE_NWK`` and reconstructed with the same settings
    (MPPA prediction method, EFT model). The states of ``feature`` are then matched
    by node name against those of the module-level ``tree`` and must all be equal.
    """

    def state_label(node):
        # The feature value is iterated and sorted (presumably it is a set of states),
        # so the label does not depend on the iteration order.
        return ', '.join(sorted(getattr(node, feature)))

    def states_frame(root, column):
        # One row per node name, with a single column holding the state label.
        name2state = {}
        for node in root.traverse():
            name2state[node.name] = state_label(node)
        return pd.DataFrame.from_dict(name2state, orient='index', columns=[column])

    full_tree = read_tree(TREE_NWK)
    acr(full_tree, df, prediction_method=MPPA, model=EFT)

    full_states = states_frame(full_tree, 'full')
    collapsed_states = states_frame(tree, 'collapsed')

    # Left join: only the node names present in the module-level tree are compared.
    joint = collapsed_states.join(full_states, how='left')
    states_match = joint['collapsed'] == joint['full']
    self.assertTrue(
        np.all(states_match),
        msg='All the node states of the collapsed tree should be the same as of the full one.')
def reroot_tree_randomly():
    """Read the tree from ``TREE_NWK`` and re-root it on a randomly chosen node.

    The outgroup is drawn with ``np.random.choice`` among the nodes that are neither
    the root nor a child of the root and that have a non-zero branch length.
    Nodes of the resulting tree that have no name are named 'unknown'.

    :return: the parent of the chosen node after re-rooting, used as the root of the result
    """
    tree = read_tree(TREE_NWK)

    def is_candidate(node):
        # The root and its children are excluded, as are zero-length branches.
        if node.is_root() or node.up.is_root():
            return False
        return bool(node.dist)

    candidates = [node for node in tree.traverse() if is_candidate(node)]
    outgroup = np.random.choice(candidates)

    # Drop the old root: its first child takes over and adopts its former siblings,
    # whose branches are extended by the first child's own branch length.
    first_child = tree.children[0]
    first_child_dist = first_child.dist
    siblings = list(tree.children[1:])
    first_child.up = None
    for sibling in siblings:
        sibling.up = None
        first_child.add_child(sibling, dist=first_child_dist + sibling.dist)

    first_child.set_outgroup(outgroup)
    new_root = outgroup.up

    for node in new_root.traverse():
        if node.name:
            continue
        node.name = 'unknown'
    return new_root
def clean(self):
    """Validate the uploaded tree and annotation files.

    * The separator defaults to a tab when it is missing, empty or given as '<tab>'.
    * The 'tree' file is decoded, split on ';' and every ';'-terminated chunk is
      parsed with ``read_tree``. An error is attached to the 'tree' field if
      no tree is found, if there are more than ``MAX_N_TREES`` trees, or if parsing fails.
    * The 'data' file is parsed with pandas. An error is attached to the 'data' field
      if it cannot be parsed or has fewer than 2 columns.

    :return: the cleaned data dictionary obtained from the parent class
    """
    cleaned_data = super().clean()

    sep = cleaned_data.get('data_sep', '\t')
    if not sep or sep == '<tab>':
        self.cleaned_data['data_sep'] = '\t'
        sep = '\t'
    # Human-readable separator for the error messages.
    sep_label = '<tab>' if sep == '\t' else sep

    f = cleaned_data.get('tree', None)
    if f:
        f = f.file  # the file in Memory
        try:
            # str.split never returns an empty list, so emptiness must be checked on the
            # ';'-terminated chunks: whatever follows the last ';' is not a tree.
            nwks = f.read().decode().replace('\n', '').split(';')[:-1]
            if not nwks:
                self.add_error('tree', u'Could not find any trees (in Newick format) in the file.')
            else:
                # Parse every tree: a malformed one raises and is reported below.
                for nwk in nwks:
                    read_tree(nwk + ';')
                n_trees = len(nwks)
                if n_trees > MAX_N_TREES:
                    self.add_error('tree', u'The file contains too many trees ({}), the limit is {}.'
                                   .format(n_trees, MAX_N_TREES))
        except Exception:
            # Any decoding or parsing problem means the upload is not a usable Newick file.
            self.add_error('tree', u'Your tree does not seem to be in Newick format.')

    f = cleaned_data.get('data', None)
    if f:
        f = f.file
        try:
            df = pd.read_table(f, sep=sep, header=0)
            if len(df.columns) < 2:

                def get_col_name(name):
                    # Long column names are cut to 10 characters in the message.
                    if len(name) > 10:
                        return '"{}..."'.format(name[:10])
                    return '"{}"'.format(name)

                self.add_error('data',
                               u'Your annotation table contains {} column: {}, '
                               u'while it must contain at least 2: tip ids and their states. '
                               u'Please check if the separator ("{}") is correct.'
                               .format(len(df.columns), ', '.join(get_col_name(_) for _ in df.columns),
                                       sep_label))
        except Exception:
            self.add_error('data',
                           u'We could not parse your annotation table, '
                           u'please check if the separator ("{}") is correct.'
                           .format(sep_label))
    return cleaned_data
import argparse parser = argparse.ArgumentParser() parser.add_argument('--input_tree', required=True, type=str) parser.add_argument('--output_tree', required=True, type=str) parser.add_argument('--threshold', required=True, type=str) parser.add_argument('--feature', required=True, type=str) parser.add_argument('--strict', action='store_true', default=False) params = parser.parse_args() logging.basicConfig(level=logging.INFO, format='%(asctime)s: %(message)s', datefmt="%Y-%m-%d %H:%M:%S") tr = read_tree(params.input_tree) try: threshold = float(params.threshold) except: # may be it's a string threshold then threshold = params.threshold num_collapsed, num_set_zero_tip, num_set_zero_root = 0, 0, 0 for n in list(tr.traverse('postorder')): children = list(n.children) for child in children: if getattr(child, params.feature) < threshold \ or not params.strict and getattr(child, params.feature) == threshold: if child.is_leaf(): child.dist = 0
from collections import Counter import pandas as pd import numpy as np from pastml.tree import read_tree, collapse_zero_branches from pastml.acr import acr from pastml.ml import MPPA, EFT DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') TREE_NWK = os.path.join(DATA_DIR, 'Albanian.tree.152tax.tre') STATES_INPUT = os.path.join(DATA_DIR, 'data.txt') feature = 'Country' df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[feature]] tree = read_tree(TREE_NWK) acr(tree, df, prediction_method=MPPA, model=EFT) class ACRStateMPPAEFTTest(unittest.TestCase): def test_collapsed_vs_full(self): tree_uncollapsed = read_tree(TREE_NWK) acr(tree_uncollapsed, df, prediction_method=MPPA, model=EFT) def get_state(node): return ', '.join(sorted(getattr(node, feature))) df_full = pd.DataFrame.from_dict( { node.name: get_state(node) for node in tree_uncollapsed.traverse()
parser.add_argument('--tree_tt_acr', required=True, type=str) parser.add_argument('--tree_lsd2_nex', required=True, type=str) parser.add_argument('--tree_lsd2_acr', required=True, type=str) parser.add_argument('--tab', required=True, type=str) params = parser.parse_args() logging.basicConfig(level=logging.INFO, format='%(asctime)s: %(message)s', datefmt="%Y-%m-%d %H:%M:%S") timetree_lsd2 = parse_nexus(params.tree_lsd2_nex)[0] annotate_dates([timetree_lsd2]) timetree_tt = parse_nexus(params.tree_tt_nex)[0] annotate_dates([timetree_tt]) acrtree_lsd2 = read_tree(params.tree_lsd2_acr, columns=['country']) acrtree_tt = read_tree(params.tree_tt_acr, columns=['country']) date_df = pd.read_csv( params.dates_tt, sep='\t', index_col=0, skiprows=[0])[['numeric date', 'lower bound', 'upper bound']] get_ci_lsd2 = lambda _: getattr(_, DATE_CI) get_ci_tt = lambda _: date_df.loc[_.name, ['lower bound', 'upper bound']] def get_dates(acrtree, timetree, get_ci): c2dates = {} c2size = defaultdict(lambda: 0) for n in acrtree.traverse('postorder'): if not n.is_root() and getattr(n, 'country') != getattr( n.up, 'country') and not n.is_leaf():
from pastml import get_personalized_feature_name, STATES from pastml.acr import acr from pastml.models.hky import KAPPA, HKY from pastml.ml import LH, LH_SF, MPPA, LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR, \ CHANGES_PER_AVG_BRANCH, SCALING_FACTOR, MARGINAL_PROBABILITIES from pastml.models.f81_like import JC from pastml.tree import read_tree DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') TREE_NWK = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.nwk') STATES_INPUT = os.path.join( DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.pastml.tab') feature = 'ACR' acr_result_f81 = acr(read_tree(TREE_NWK), pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]], prediction_method=MPPA, model=JC)[0] df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]] tree = read_tree(TREE_NWK) acr_result_hky = acr(tree, df, prediction_method=MPPA, model=HKY, column2parameters={ feature: { KAPPA: 1, 'A': .25,
def _validate_input(columns, data, data_sep, date_column, html, html_compressed, id_index, name_column, tree_nwk,
                    copy_only):
    """
    Reads the input tree and the annotation table and checks that they fit together.

    :param columns: name (or list of names) of the annotation columns to keep; if empty, all the columns are kept
    :param data: annotation table, passed to pandas.read_csv
    :param data_sep: column separator of the annotation table
    :param date_column: name of the annotation column containing tip dates
        (only examined when a visualisation is requested, i.e. html or html_compressed is set)
    :param html: visualisation output setting, only tested for truthiness here
    :param html_compressed: compressed visualisation output setting, only tested for truthiness here
    :param id_index: index of the annotation column containing the tip ids
    :param name_column: name of the column to be used for node naming (only checked if html_compressed is set)
    :param tree_nwk: input tree, passed to read_tree
    :param copy_only: if set, the unknown-annotation check only fails when all the annotations are unknown
    :return: tuple (root, df, years, tip2date, name_column)
    :raises ValueError: if the date column or the specified columns are not found, if the dates cannot be parsed,
        if the tree names do not match the annotation ids, or if the annotations are mostly unknown
        or look non-categorical
    """
    logger = logging.getLogger('pastml')
    logger.debug('\n=============INPUT DATA VALIDATION=============')
    root = read_tree(tree_nwk)
    num_neg = 0
    for _ in root.traverse():
        if _.dist < 0:
            num_neg += 1
            _.dist = 0
    if num_neg:
        logger.warning('Input tree contained {} negative branches: we put them to zero.'.format(num_neg))
    logger.debug('Read the tree {}.'.format(tree_nwk))

    df = pd.read_csv(data, sep=data_sep, index_col=id_index, header=0, dtype=str)
    df.index = df.index.map(str)
    logger.debug('Read the annotation file {}.'.format(data))

    # As the date column is only used for visualisation if there is no visualisation we are not gonna validate it
    years, tip2date = [], {}
    if html_compressed or html:
        if date_column:
            if date_column not in df.columns:
                raise ValueError('The date column "{}" not found among the annotation columns: {}.'
                                 .format(date_column, _quote(df.columns)))
            try:
                df[date_column] = pd.to_datetime(df[date_column], infer_datetime_format=True)
            except ValueError:
                # Fall back to years that were written as floats, e.g. "2010.0".
                try:
                    df[date_column] = pd.to_datetime(df[date_column], format='%Y.0')
                except ValueError:
                    raise ValueError('Could not infer the date format for column "{}", please check it.'
                                     .format(date_column))
            tip2date = df.loc[[_.name for _ in root], date_column].apply(date2years).to_dict()
            if not tip2date:
                raise ValueError('Could not find any dates for the tree tips in column {}, please check it.'
                                 .format(date_column))
        annotate_depth(root)
        if not date_column:
            # Without dates, the depth annotation of the tips is used instead.
            tip2date = {tip.name: round(getattr(tip, DEPTH), 6) for tip in root}
        else:
            tip2date = {t: round(d, 6) if d is not None else None for (t, d) in tip2date.items()}

        dates = [_ for _ in tip2date.values() if _ is not None]
        if not dates:
            tip2date = {tip.name: round(getattr(tip, DEPTH), 6) for tip in root}
            dates = [_ for _ in tip2date.values() if _ is not None]
            date_column = None
            logger.warning('The date column does not contains dates for any of the tree tips, '
                           'therefore we will ignore it')
        min_date = min(dates)
        max_date = max(dates)
        dates = sorted(dates)
        # Minimum, quartiles and maximum of the tip dates (duplicates are merged by the set).
        years = sorted({dates[0], dates[len(dates) // 2],
                        dates[1 * len(dates) // 4], dates[3 * len(dates) // 4], dates[-1]})
        logger.debug("Extracted tip {}: they vary between {} and {}."
                     .format('dates' if date_column else 'distances', min_date, max_date))

    if columns:
        if isinstance(columns, str):
            columns = [columns]
        unknown_columns = set(columns) - set(df.columns)
        if unknown_columns:
            raise ValueError('{} of the specified columns ({}) {} not found among the annotation columns: {}.'
                             .format('One' if len(unknown_columns) == 1 else 'Some',
                                     _quote(unknown_columns),
                                     'is' if len(unknown_columns) == 1 else 'are',
                                     _quote(df.columns)))
        df = df[columns]

    df.columns = [col_name2cat(column) for column in df.columns]

    node_names = {n.name for n in root.traverse() if n.name}
    df_index_names = set(df.index)
    # A boolean mask is used as pandas (>= 2.0) does not accept a set as a .loc indexer.
    filtered_df = df.loc[df.index.isin(node_names), :]
    if not filtered_df.shape[0]:
        tip_name_representatives = []
        for _ in root.iter_leaves():
            if len(tip_name_representatives) >= 3:
                break
            tip_name_representatives.append(_.name)
        raise ValueError('Your tree tip names (e.g. {}) do not correspond to annotation id column values (e.g. {}). '
                         'Check your annotation file.'
                         .format(', '.join(tip_name_representatives),
                                 ', '.join(list(df_index_names)[:3])))
    logger.debug('Checked that tip names correspond to annotation file index.')

    if html_compressed and name_column:
        name_column = col_name2cat(name_column)
        if name_column not in df.columns:
            raise ValueError('The name column ("{}") should be one of those specified as columns ({}).'
                             .format(name_column, _quote(df.columns)))
    elif len(df.columns) == 1:
        name_column = df.columns[0]

    percentage_unknown = filtered_df.isnull().sum(axis=0) / filtered_df.shape[0]
    max_unknown_percentage = percentage_unknown.max()
    if max_unknown_percentage >= (.9 if not copy_only else 1):
        raise ValueError('{:.1f}% of tip annotations for column "{}" are unknown, '
                         'not enough data to infer ancestral states. '
                         'Check your annotation file and if its id column corresponds to the tree tip names.'
                         .format(max_unknown_percentage * 100, percentage_unknown.idxmax()))
    percentage_unique = filtered_df.nunique() / filtered_df.count()
    max_unique_percentage = percentage_unique.max()
    # .iloc[0]: the count of the first column is addressed by position, not by label.
    if filtered_df.count().iloc[0] > 100 and max_unique_percentage > .5:
        raise ValueError('The column "{}" seem to contain non-categorical data: {:.1f}% of values are unique. '
                         'PASTML cannot infer ancestral states for a tree with too many tip states.'
                         .format(percentage_unique.idxmax(), 100 * max_unique_percentage))
    logger.debug('Finished input validation.')
    return root, df, years, tip2date, name_column
from pastml.acr import acr from pastml.models.hky import KAPPA, HKY from pastml.ml import LH, LH_SF, MPPA, LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR, \ CHANGES_PER_AVG_BRANCH, SCALING_FACTOR, FREQUENCIES, MARGINAL_PROBABILITIES from pastml.models.f81_like import F81 from pastml.tree import read_tree DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data') TREE_NWK = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.6.C_0.15.G_0.2.T_0.05.nwk') STATES_INPUT = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.6.C_0.15.G_0.2.T_0.05.pastml.tab') TREE_NWK_JC = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.nwk') STATES_INPUT_JC = os.path.join(DATA_DIR, 'tree.152taxa.sf_0.5.A_0.25.C_0.25.G_0.25.T_0.25.pastml.tab') feature = 'ACR' df = pd.read_csv(STATES_INPUT, index_col=0, header=0, sep='\t')[[feature]] acr_result_f81 = acr(read_tree(TREE_NWK), df, prediction_method=MPPA, model=F81)[0] tree = read_tree(TREE_NWK) acr_result_hky = acr(tree, df, prediction_method=MPPA, model=HKY, column2parameters={feature: {KAPPA: 1}})[0] acr_result_hky_free = acr(read_tree(TREE_NWK), df, prediction_method=MPPA, model=HKY)[0] print("Log lh for HKY-kappa-fixed {}, HKY {}" .format(acr_result_hky[LOG_LIKELIHOOD], acr_result_hky_free[LOG_LIKELIHOOD])) class HKYF81Test(unittest.TestCase): def test_params(self): for param in (LOG_LIKELIHOOD, RESTRICTED_LOG_LIKELIHOOD_FORMAT_STR.format(MPPA), CHANGES_PER_AVG_BRANCH, SCALING_FACTOR): self.assertAlmostEqual(acr_result_hky[param], acr_result_f81[param], places=3,
type=str) # metadata which includes IDs and locations parser.add_argument( '--input_counts', default= 'C:/Users/Andrew/PycharmProjects/masters_project/25April2020_perstate.csv', type=str) # state case counts since April 25, 2020 parser.add_argument( '--output_stats', required=False, type=str, default='C:/Users/Andrew/Dropbox/masters_project/mutations.table') params = parser.parse_args() # phylogenetic tree read from the newick file tree_mutations = read_tree(params.input_tree_mutations) def tip_mutations(tip, length): mutations = tip.dist * length parent = tip.up while parent.up: #if there is no parent, none will be treated as false, any non value is treated as true mutations += parent.dist * length parent = parent.up return mutations def num_mutations(tree, length): mutations = {} for tip in tree: mutations[tip.name] = tip_mutations(tip, length) return mutations
import pandas as pd
import numpy as np

from pastml.tree import read_tree, remove_certain_leaves

# Prunes the input tree to the tips sampled up to two cut-off years and writes
# one tree per cut-off: 'mid' (ten years before the latest date in the metadata)
# and 'first' (the earliest date among the tips marked 'resistant').
if '__main__' == __name__:
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--in_tree', required=True, type=str)
    parser.add_argument('--metadata', required=True, type=str)
    parser.add_argument('--out_tree_pattern', required=True, type=str)
    parser.add_argument('--drm', required=True, type=str)
    parser.add_argument('--date_column', required=True, type=str)
    params = parser.parse_args()

    df = pd.read_csv(params.metadata, header=0, index_col=0, sep='\t')
    df.index = df.index.map(str)

    tree = read_tree(params.in_tree)
    # Keep only the metadata of the tree tips
    # (Index.isin replaces numpy.in1d, which is deprecated since NumPy 2.0).
    df = df[df.index.isin([n.name for n in tree.iter_leaves()])]

    res_df = df[df[params.drm] == 'resistant']
    first_year = min(_ for _ in res_df[params.date_column].unique() if not pd.isnull(_))
    last_year = max(_ for _ in df[params.date_column].unique() if not pd.isnull(_))

    # The cut-offs are applied one after another to the same tree,
    # so the 'first' tree is pruned from the 'mid' one.
    for year, label in zip((last_year - 10, first_year), ('mid', 'first')):

        def is_undated_or_later(node):
            # Tips without a date or sampled after the cut-off year are removed.
            date = df.loc[node.name, params.date_column]
            return pd.isnull(date) or date > year

        tree = remove_certain_leaves(tree, to_remove=is_undated_or_later)
        tree.write(outfile=params.out_tree_pattern.format(label), format=3)
def generate_map(data, country, location, html, tree=None, data_sep='\t', id_index=0, colours=None):
    """
    Renders the 'geo_map.html' template of the pastml package for the given annotation table
    and writes the resulting page to a file.

    Each country is assigned the location of its first row once the table is sorted by location.
    Countries are coloured by location, and their tooltips report the number of samples
    in the country and in its location.

    :param data: annotation table, passed to pandas.read_csv
    :param country: name of the annotation column containing countries
    :param location: name of the annotation column containing locations
    :param html: path of the output HTML file (its directory is created if needed)
    :param tree: (optional) tree, passed to read_tree; if given, only the rows whose ids are among
        its tip names are counted
    :param data_sep: column separator of the annotation table
    :param id_index: index of the annotation column containing the ids
    :param colours: (optional) colours specification, passed to parse_colours together with the locations;
        if not given, colours are obtained from get_enough_colours
    :return: void
    :raises ValueError: if the country or the location column is not found in the annotation table
    """
    df = pd.read_csv(data, sep=data_sep, header=0, index_col=id_index)
    if country not in df.columns:
        raise ValueError('The country column {} not found among the annotation columns: {}.'
                         .format(country, df.columns))
    if location not in df.columns:
        raise ValueError('The location column {} not found among the annotation columns: {}.'
                         .format(location, df.columns))

    # After sorting, the first row of each country carries a known location whenever there is one.
    df.sort_values(by=[location], inplace=True, na_position='last')
    ddf = df.drop_duplicates(subset=[country], inplace=False, keep='first')
    country2location = {c: loc for c, loc in zip(ddf[country], ddf[location])
                        if not pd.isnull(c) and not pd.isnull(loc)}

    if tree:
        # The builtin str is used as the numpy.str alias was removed in NumPy 1.24,
        # and Index.isin replaces numpy.in1d, which is deprecated since NumPy 2.0.
        tip_names = [_.name for _ in read_tree(tree)]
        df = df[df.index.astype(str).isin(tip_names)]

    unique_countries = {_ for _ in df[country].unique() if not pd.isnull(_)}
    if ISO_EXISTS:
        country2iso = {}
        for c in country2location.keys():
            iso3 = Country.get_iso3_country_code_fuzzy(c)[0]
            # Countries without an ISO3 match or absent from the (filtered) table are skipped.
            if iso3 and c in unique_countries:
                country2iso[c] = Country.get_iso2_from_iso3(iso3)
    else:
        # No ISO codes available: the escaped country name is used as the key.
        country2iso = {_: escape(_) for _ in country2location.keys() if _ in unique_countries}

    iso2num = {iso: len(df[df[country] == c]) for c, iso in country2iso.items()}
    iso2loc = {iso: country2location[c] for c, iso in country2iso.items()}
    iso2loc_num = {iso: len(df[df[location] == loc]) for iso, loc in iso2loc.items()}
    iso2tooltip = {iso: escape('{}: {} samples (out of {} in {})'
                               .format(c, iso2num[iso], iso2loc_num[iso], iso2loc[iso]))
                   for (c, iso) in country2iso.items()}

    locations = sorted([_ for _ in df[location].unique() if not pd.isnull(_)])
    num_unique_values = len(locations)
    if colours:
        colours = parse_colours(colours, locations)
    else:
        colours = get_enough_colours(num_unique_values)
    # The colours are indexed by the position of the location in the sorted list.
    iso2colour = {iso: colours[locations.index(loc)] for iso, loc in iso2loc.items()}

    env = Environment(loader=PackageLoader('pastml'))
    template = env.get_template('geo_map.html')
    page = template.render(iso2colour=iso2colour, colours=colours, iso2tooltip=iso2tooltip)

    os.makedirs(os.path.abspath(os.path.dirname(html)), exist_ok=True)
    with open(html, 'w+') as fp:
        fp.write(page)
import argparse parser = argparse.ArgumentParser() parser.add_argument('--tree', required=True, type=str) parser.add_argument('--tree_acr', required=True, type=str) parser.add_argument('--tab', required=True, type=str) parser.add_argument('--ntab', required=True, type=str) params = parser.parse_args() logging.basicConfig(level=logging.INFO, format='%(asctime)s: %(message)s', datefmt="%Y-%m-%d %H:%M:%S") bs_tree = Tree(params.tree, format=2) acr_tree = read_tree(params.tree_acr, columns=['country']) c2bs = {} c2size = defaultdict(lambda: 0) for n in acr_tree.traverse('postorder'): if not n.is_root() and getattr(n, 'country') != getattr( n.up, 'country') and not n.is_leaf(): c = getattr(n, 'country') if len(c) == 1: c = c.pop() if c in COUNTRIES and c2size[c] < len(n): bs_n = bs_tree.get_common_ancestor(*(_.name for _ in n)) c2size[c] = len(n) c2bs[c] = bs_n.support with open(params.tab, 'w+') as f:
if '__main__' == __name__: import argparse parser = argparse.ArgumentParser() parser.add_argument('--tree', required=True, type=str) parser.add_argument('--acr', required=True, type=str) parser.add_argument('--columns', default=None, type=str, nargs='*') parser.add_argument('--out_log', required=True, type=str) params = parser.parse_args() df = pd.read_csv(params.acr, header=0, index_col=0, sep='\t') df.index = df.index.map(str) tree = read_tree(params.tree) preannotate_forest(df=df, forest=[tree]) if params.columns is None or len(params.columns) == 0: params.columns = df.columns for n in tree.traverse(): n.add_feature('state', set(state_combinations(n, params.columns))) from_to_count = Counter() for n in tree.traverse('preorder'): if n.is_root(): continue states = getattr(n, 'state') up_states = getattr(n.up, 'state') for s in states:
def test_counts(self):
    """Compare the counts returned by ``marginal_counts`` with counts obtained by simulation.

    States are simulated along the tree with ``simulate_states``, and only the
    repetitions in which every tip got its annotated state are kept.
    State-to-state changes along the branches are then counted over those repetitions
    and averaged, and each entry must match the calculated one to 2 decimal places.
    """
    tree = read_tree(TREE_NWK)
    character = 'Country'
    df = pd.read_csv(STATES_INPUT, index_col=0, header=0)[[character]]
    preannotate_forest([tree], df=df)

    known_values = [value for value in df[character].unique()
                    if not pd.isna(value) and '' != value]
    states = np.array(known_values)
    n_states = len(states)

    _, num_nodes, num_tips, tree_len = get_forest_stats([tree])
    freqs, sf, kappa, _ = _parse_pastml_parameters(PARAMS_INPUT, states, num_tips=num_tips, reoptimise=False)
    tau = 0
    model = JC

    n_repetitions = 25_000
    counts = marginal_counts([tree], character, model, states, num_nodes, tree_len, freqs, sf, kappa, tau,
                             n_repetitions=n_repetitions)

    sim_character = character + '.simulated'
    n_sim_repetitions = n_repetitions * 300
    simulate_states(tree, model, freqs, kappa, tau, sf, sim_character, n_repetitions=n_sim_repetitions)

    # A repetition stays "good" (1) only while every tip seen so far has its annotated state.
    state2id = {state: index for index, state in enumerate(states)}
    good_indices = np.ones(n_sim_repetitions, dtype=int)
    for tip in tree:
        tip_state = next(iter(getattr(tip, character)))
        tip_matches = getattr(tip, sim_character) == state2id[tip_state]
        good_indices *= tip_matches.astype(int)
    num_good_simulations = np.count_nonzero(good_indices)
    print('Simulated {} good configurations'.format(num_good_simulations))

    kept = good_indices > 0
    sim_counts = np.zeros((n_states, n_states), dtype=float)
    for node in tree.traverse('levelorder'):
        parent_states = getattr(node, sim_character)[kept]
        children_transition_counts = Counter()
        for child in node.children:
            child_states = getattr(child, sim_character)[kept]
            transition_counts = Counter(zip(parent_states, child_states))
            for (source, target), num in transition_counts.items():
                sim_counts[source, target] += num
            children_transition_counts.update(transition_counts)
        # Per state, subtract from the diagonal the smaller of the number of repetitions in which
        # the node has this state and the number of same-state transitions to its children.
        for state_id, num in Counter(parent_states).items():
            unchanged = children_transition_counts[(state_id, state_id)]
            sim_counts[state_id, state_id] -= min(num, unchanged)
    sim_counts /= num_good_simulations

    print(np.round(counts, 2))
    print(np.round(sim_counts, 2))

    for i, j in np.ndindex(n_states, n_states):
        message = 'Counts are different for {}->{}: {} (calculated) vs {} (simulated).' \
            .format(states[i], states[j], counts[i, j], sim_counts[i, j])
        self.assertAlmostEqual(counts[i, j], sim_counts[i, j], 2, message)