def test_general_node_height_transform_hetero_all(
    ratios, root_height, keep, expected_ratios_root
):
    """Check heights, bounds, branch lengths and log-Jacobian on a 4-taxon tree.

    The tree ``(A:2,(B:1.5,(C:2,D:1):2.5):2.5);`` is built with the per-taxon
    values 5, 3, 0 and 1 for A-D (they reappear as the first four expected
    node heights).  ``ratios``, ``root_height``, ``keep`` and
    ``expected_ratios_root`` are supplied by the caller (presumably a pytest
    parametrization that is not visible here).
    """
    taxon_values = dict(zip('ABCD', [5.0, 3.0, 0.0, 1.0]))
    registry = {}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(A:2,(B:1.5,(C:2,D:1):2.5):2.5);',
        ratios,
        root_height,
        taxon_values,
        keep_branch_lengths=keep,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    expected_heights = torch.tensor([5.0, 3.0, 0.0, 1.0, 2.0, 4.5, 7.0])
    expected_bounds = torch.tensor([5.0, 3.0, 0.0, 1.0, 1.0, 3.0, 5.0])
    expected_branch_lengths = torch.tensor([2.0, 1.5, 2.0, 1.0, 2.5, 2.5])
    # Heights 5 and 6 are paired with bounds 4 and 5 respectively.
    log_det_jacobian = (expected_heights[5:7] - expected_bounds[4:6]).log().sum()

    # The registry is read back under the ids passed to the factory above.
    ratios_and_root = torch.cat(
        (registry['ratios'].tensor, registry['root_height'].tensor)
    )
    assert torch.allclose(torch.tensor(expected_ratios_root), ratios_and_root)
    assert torch.allclose(expected_heights, tree_model.node_heights)
    assert torch.allclose(expected_bounds, tree_model.transform._bounds)
    assert torch.allclose(expected_branch_lengths, tree_model.branch_lengths())
    assert torch.allclose(tree_model(), log_det_jacobian)
def tree_model_node_heights_transformed(ratios_list):
    """Return the factory output for a 4-taxon ladder tree with all-zero taxon values.

    ``ratios_list`` is split by slicing: everything but the last element is
    passed as the ratios and the last element (kept as a one-element slice)
    as the root height.  The value returned is whatever
    ``ReparameterizedTimeTreeModel.json_factory`` produces; ``from_json`` is
    not called here.
    """
    ratios = ratios_list[:-1]
    root_height = ratios_list[-1:]
    taxon_values = {taxon: 0.0 for taxon in 'ABCD'}
    return ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        taxon_values,
    )
def test_general_node_height_transform_hetero(ratios, root_height):
    """Check the model's value against a closed-form log product on a ladder tree.

    The tree ``(((A,B),C),D);`` uses the per-taxon values 0, 1, 4 and 5.  The
    expected value is ``log((h[-2] - 1) * (h[-1] - 4))`` where ``h`` is the
    model's node-height tensor; the comparison uses a relative tolerance of
    1e-4.
    """
    registry = {}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        dict(zip('ABCD', [0.0, 1.0, 4.0, 5.0])),
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    heights = tree_model.node_heights
    expected = torch.log((heights[-2] - 1.0) * (heights[-1] - 4.0)).item()
    assert tree_model().item() == pytest.approx(expected, rel=0.0001)
def test_general_node_height_heights_to_ratios(ratios, root_height):
    """Check that inverting the height transform recovers the ratios and root height.

    The node heights from index 4 onwards are passed through
    ``tree_model.transform.inv`` and compared with ``ratios + root_height``
    (assumes both arguments are lists, so ``+`` concatenates them).
    """
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        dict(zip('ABCD', [0.0, 0.0, 0.0, 0.0])),
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, {})

    recovered = tree_model.transform.inv(tree_model.node_heights[4:])
    # Build the target in the same dtype as the recovered values.
    target = torch.tensor(ratios + root_height, dtype=recovered.dtype)
    assert torch.allclose(recovered, target)
def test_general_node_height_transform_hetero_7():
    """Check the model's value on a 7-taxon ladder tree against a summed log-gap.

    All ratios are 0.5 and the root height is 10.  The expected value sums
    ``log(node_heights[i + 1] - bounds[i])`` for ``i`` from the number of taxa
    up to (but excluding) ``2 * n_taxa - 2``.
    """
    taxa = dict(zip('ABCDEFG', [5.0, 3.0, 0.0, 1.0, 0.0, 5.0, 6.0]))
    n_taxa = len(taxa)
    registry = {}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(A,(B,(C,(D,(E,(F,G))))));',
        [0.5] * (n_taxa - 2),
        [10.0],
        taxa,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    # Accumulate in place so the accumulator keeps the dtype it was created with.
    log_det_jacobian = torch.tensor([0.0])
    for bound_index in range(n_taxa, 2 * n_taxa - 2):
        gap = (
            tree_model.node_heights[bound_index + 1]
            - tree_model.transform._bounds[bound_index]
        )
        log_det_jacobian += gap.log()

    assert torch.allclose(tree_model(), log_det_jacobian)
def test_keep_branch_lengths_heights():
    """Check that ``keep_branch_lengths`` reproduces the newick branch lengths.

    The model is created with all-zero ratios and a zero root height; the
    branch lengths it reports are expected to equal the ones written in the
    newick string (tips A_0..E_12 first, then the three internal branches).
    """
    taxon_names = ['A_0', 'B_1', 'C_2', 'D_3', 'E_12']
    taxon_values = [0.0, 1.0, 2.0, 3.0, 12.0]
    registry = {}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '((((A_0:1.5,B_1:0.5):2.5,C_2:2):2,D_3:3):10,E_12:4);',
        [0.0] * 3,
        [0.0],
        dict(zip(taxon_names, taxon_values)),
        keep_branch_lengths=True,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    expected_branch_lengths = torch.tensor(
        [1.5, 0.5, 2.0, 3.0, 4.0, 2.5, 2.0, 10.0]
    )
    assert torch.allclose(tree_model.branch_lengths(), expected_branch_lengths)
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    """Check the tree likelihood with a Weibull site model on the flu A data.

    The likelihood is evaluated once with the initial parameters, then the
    clock rate, Weibull shape and internal heights tensors are repeated so
    that their first dimension is 3 and the likelihood is expected to return
    the same value three times.

    Args:
        flu_a_tree_file: path to the newick tree file.
        flu_a_fasta_file: path to the FASTA alignment; taxon names are the
            header lines and each date is the text after the last underscore.
    """
    expected_log_likelihood = -4618.2062529058

    with open(flu_a_fasta_file) as fasta:
        taxa_list = [
            Taxon(name, {'date': float(name.rsplit('_', 1)[-1])})
            for name in (line[1:].strip() for line in fasta if line.startswith('>'))
        ]
    taxa = Taxa('taxa', taxa_list)

    alignment = {
        "id": "alignment",
        "type": "torchtree.evolution.alignment.Alignment",
        'datatype': 'nucleotide',
        'file': flu_a_fasta_file,
        'taxa': 'taxa',
    }
    site_pattern = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': alignment,
    }
    subst_model = JC69('jc')
    site_model = WeibullSiteModel(
        'site_model', Parameter(None, torch.tensor([[0.1]])), 4
    )

    with open(flu_a_tree_file) as tree_file:
        newick = tree_file.read().strip()

    dic = {'taxa': taxa}
    tree_spec = ReparameterizedTimeTreeModel.json_factory(
        'tree_model',
        newick,
        [0.5] * 67,
        [20.0],
        'taxa',
        keep_branch_lengths=True,
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_spec, dic)
    branch_model = StrictClockModel(
        None, Parameter(None, torch.tensor([[0.001]])), tree_model
    )

    # Register the already-built models so the likelihood spec can refer to
    # them by id.
    dic.update(
        tree_model=tree_model,
        site_model=site_model,
        subst_model=subst_model,
        branch_model=branch_model,
    )
    like_spec = {
        'id': 'like',
        'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
        'tree_model': 'tree_model',
        'site_model': 'site_model',
        'site_pattern': site_pattern,
        'substitution_model': 'subst_model',
        'branch_model': 'branch_model',
    }
    like = likelihood.TreeLikelihoodModel.from_json(like_spec, dic)

    assert torch.allclose(torch.tensor([expected_log_likelihood]), like())

    # Repeat every parameter tensor along a new leading dimension of size 3.
    rates = branch_model._rates
    rates.tensor = rates.tensor.repeat(3, 1)
    shape = site_model._shape
    shape.tensor = shape.tensor.repeat(3, 1)
    heights = tree_model._internal_heights
    heights.tensor = heights.tensor.repeat(3, 68)

    assert torch.allclose(torch.tensor([[expected_log_likelihood] * 3]), like())
def _read_tree_as_newick(tree_path):
    """Return the tree stored at *tree_path* as a newick string.

    A file whose first line starts with ``#NEXUS`` (case-insensitive) is parsed
    with ``Tree.get`` and converted with ``str(tree) + ';'``; any other file is
    returned verbatim with surrounding whitespace stripped.

    Raises:
        ValueError: if a non-nexus file is empty or contains only whitespace.
    """
    with open(tree_path) as fp:
        # The default keeps an empty file from leaking a bare StopIteration.
        first_line = next(fp, '')
        is_nexus = first_line.upper().startswith('#NEXUS')
        # Plain newick: finish reading from the handle that is already open
        # instead of opening the file a second time.
        remainder = '' if is_nexus else fp.read()

    if is_nexus:
        tree = Tree.get(
            path=tree_path,
            schema='nexus',
            tree_offset=0,
            preserve_underscores=True,
        )
        return str(tree) + ';'

    newick = (first_line + remainder).strip()
    if not newick:
        raise ValueError(f'tree file {tree_path!r} does not contain a tree')
    return newick


def create_tree_model(id_: str, taxa: dict, arg):
    """Build the tree-model description selected by the command-line options.

    Args:
        id_: identifier of the tree model; also used as the prefix of the ids
            of the parameters created here (``.ratios``, ``.root_height``,
            ``.shifts``, ``.heights``, ``.blens``).
        taxa: dictionary whose ``'taxa'`` entry is a sequence of taxon
            dictionaries; when a clock is used each one must provide
            ``['attributes']['date']``.
        arg: options object.  Reads ``tree``, ``keep``, ``clock`` and, when a
            clock is set, ``heights`` plus (for ``'ratio'``) the optional
            ``root_height_init`` and ``coalescent``/``cutoff``.

    Returns:
        The output of the ``json_factory`` of ``ReparameterizedTimeTreeModel``
        (clock, ``heights == 'ratio'``), ``FlexibleTimeTreeModel`` (clock,
        ``heights == 'shift'``) or ``UnRootedTreeModel`` (no clock).

    Raises:
        ValueError: if the tree file holds no tree, or if a clock is requested
            with a ``heights`` value other than ``'ratio'`` or ``'shift'``.
    """
    newick = _read_tree_as_newick(arg.tree)

    kwargs = {}
    if arg.keep:
        kwargs['keep_branch_lengths'] = True

    if arg.clock is None:
        branch_lengths = Parameter.json_factory(
            f'{id_}.blens', tensor=0.1, full=[len(taxa['taxa']) * 2 - 3]
        )
        branch_lengths['lower'] = 0.0
        return UnRootedTreeModel.json_factory(
            id_, newick, branch_lengths, 'taxa', **kwargs
        )

    dates = [taxon['attributes']['date'] for taxon in taxa['taxa']]
    offset = max(dates) - min(dates)

    if arg.heights == 'ratio':
        ratios = Parameter.json_factory(
            f'{id_}.ratios', tensor=0.1, full=[len(dates) - 2]
        )
        ratios['lower'] = 0.0
        ratios['upper'] = 1.0

        root_height = Parameter.json_factory(
            f'{id_}.root_height', tensor=[offset + 1.0]
        )
        if 'root_height_init' in arg:
            root_height['tensor'] = [max(dates) - arg.root_height_init]
        elif arg.coalescent == 'skygrid':
            # NOTE(review): unlike the other initial values this one is not
            # wrapped in a list -- confirm that a bare cutoff is intended.
            root_height['tensor'] = arg.cutoff
        # The root cannot be younger than the span of the sampling dates.
        root_height['lower'] = offset

        return ReparameterizedTimeTreeModel.json_factory(
            id_, newick, ratios, root_height, 'taxa', **kwargs
        )

    if arg.heights == 'shift':
        shifts = Parameter.json_factory(
            f'{id_}.shifts', tensor=0.1, full=[len(dates) - 1]
        )
        shifts['lower'] = 0.0
        node_heights = {
            'id': f'{id_}.heights',
            'type': 'TransformedParameter',
            'transform': 'torchtree.evolution.tree_height_transform'
            '.DifferenceNodeHeightTransform',
            'x': shifts,
            'parameters': {'tree_model': id_},
        }
        return FlexibleTimeTreeModel.json_factory(
            id_, newick, node_heights, 'taxa', **kwargs
        )

    # Previously this fell through to an unbound local at the return statement.
    raise ValueError(
        f"unknown node height parameterization {arg.heights!r}; "
        "expected 'ratio' or 'shift'"
    )