def test_weibull():
    """Weibull site model with shape=1 and 4 categories.

    Checks the category rates against reference values and that the sum of
    the rates weighted by the category probabilities equals 1.
    """
    shape = Parameter('shape', torch.tensor([1.0]))
    sitemodel = WeibullSiteModel('weibull', shape, 4)

    reference_rates = (0.1457844, 0.5131316, 1.0708310, 2.2702530)
    np.testing.assert_allclose(sitemodel.rates(), reference_rates, rtol=1e-06)

    weighted = sitemodel.rates() * sitemodel.probabilities()
    weighted_total = torch.sum(weighted).item()
    assert weighted_total == pytest.approx(1.0, 1.0e-6)
def test_weibull_batch2():
    """Weibull rates for a batch of three shape values (1.0, 0.1, 1.0).

    Rows 0 and 2 share shape 1.0 and are checked against the same
    reference rates; row 1 uses the shape=0.1 reference rates.
    """
    shapes = torch.tensor([[1.0], [0.1], [1.0]])
    sitemodel = WeibullSiteModel('weibull', Parameter('shape', shapes), 4)

    shape_one = [0.1457844, 0.5131316, 1.0708310, 2.2702530]
    shape_tenth = [4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819]
    rates_expected = torch.tensor([shape_one, shape_tenth, shape_one])

    assert torch.allclose(sitemodel.rates(), rates_expected)
def test_weibull_mu():
    """With shape=1 and mu=2, the Weibull rates are twice the shape=1 rates."""
    shape = Parameter('shape', torch.tensor([1.0]))
    mu = Parameter('mu', torch.tensor([2.0]))
    sitemodel = WeibullSiteModel('weibull', shape, 4, mu=mu)

    unscaled = torch.tensor([0.1457844, 0.5131316, 1.0708310, 2.2702530])
    np.testing.assert_allclose(sitemodel.rates(), unscaled * 2, rtol=1e-06)
def test_weibull_json():
    """Build a WeibullSiteModel from JSON and check it follows shape updates.

    After parsing, the shape parameter is reachable in ``dic`` under its id
    ('shape'); replacing its tensor must change the rates the model returns.
    """
    dic = {}
    weibull_json = {
        'id': 'weibull',
        'type': 'torchtree.evolution.sitemodel.WeibullSiteModel',
        'categories': 4,
        'shape': {
            'id': 'shape',
            'type': 'torchtree.Parameter',
            'tensor': [1.0],
        },
    }
    sitemodel = WeibullSiteModel.from_json(weibull_json, dic)
    np.testing.assert_allclose(
        sitemodel.rates(),
        (0.1457844, 0.5131316, 1.0708310, 2.2702530),
        rtol=1e-06,
    )

    # Swap in a new shape value through the registry (np.array([0.1]) is
    # float64, so the replacement tensor is float64 as well).
    dic['shape'].tensor = torch.tensor(np.array([0.1]))
    np.testing.assert_allclose(
        sitemodel.rates(),
        (4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819),
        rtol=1e-06,
    )
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    """Tree likelihood with a Weibull site model, unbatched and batched.

    Taxa are read from the FASTA headers; each taxon's date is the last
    ``_``-separated field of its name. The likelihood is checked against a
    reference value. The clock rate, Weibull shape and tree parameter
    tensors are then repeated into a batch of 3, and every batch entry must
    give the same reference value.
    """
    with open(flu_a_fasta_file) as fasta:
        names = [line[1:].strip() for line in fasta if line.startswith('>')]
    taxa = Taxa(
        'taxa',
        [Taxon(name, {'date': float(name.split('_')[-1])}) for name in names],
    )

    site_pattern = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': {
            'id': 'alignment',
            'type': 'torchtree.evolution.alignment.Alignment',
            'datatype': 'nucleotide',
            'file': flu_a_fasta_file,
            'taxa': 'taxa',
        },
    }
    subst_model = JC69('jc')
    site_model = WeibullSiteModel(
        'site_model', Parameter(None, torch.tensor([[0.1]])), 4
    )

    with open(flu_a_tree_file) as tree_file:
        newick = tree_file.read().strip()

    dic = {'taxa': taxa}
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree_model',
        newick,
        [0.5] * 67,  # ratios
        [20.0],  # root height
        'taxa',
        keep_branch_lengths=True,
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, dic)
    branch_model = StrictClockModel(
        None, Parameter(None, torch.tensor([[0.001]])), tree_model
    )
    dic.update(
        tree_model=tree_model,
        site_model=site_model,
        subst_model=subst_model,
        branch_model=branch_model,
    )

    like = likelihood.TreeLikelihoodModel.from_json(
        {
            'id': 'like',
            'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
            'tree_model': 'tree_model',
            'site_model': 'site_model',
            'site_pattern': site_pattern,
            'substitution_model': 'subst_model',
            'branch_model': 'branch_model',
        },
        dic,
    )
    reference = -4618.2062529058
    assert torch.allclose(torch.tensor([reference]), like())

    # Repeat the clock rate, Weibull shape and tree parameter into a batch
    # of 3 identical samples.
    branch_model._rates.tensor = branch_model._rates.tensor.repeat(3, 1)
    site_model._shape.tensor = site_model._shape.tensor.repeat(3, 1)
    # NOTE(review): repeat(3, 68) also tiles the last dimension 68 times
    # (e.g. a length-68 vector becomes shape (3, 4624)); repeat(3, 1) may
    # have been intended -- TODO confirm against the tree model's
    # parameter layout.
    heights = tree_model._internal_heights.tensor
    tree_model._internal_heights.tensor = heights.repeat(3, 68)
    assert torch.allclose(torch.tensor([[reference] * 3]), like())
def test_treelikelihood_batch():
    """A batched likelihood entry matches the same sample evaluated alone.

    The likelihood is first computed for a batch of two samples (node
    heights, Weibull shapes and clock rates stacked along the first
    dimension), then for the second sample on its own; both must agree.
    """
    taxa_dict = {
        'id': 'taxa',
        'type': 'torchtree.evolution.taxa.Taxa',
        'taxa': [
            {
                "id": "A",
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {"date": 0.0},
            },
            {
                "id": "B",
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {"date": 1.0},
            },
            {
                "id": "C",
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {"date": 4.0},
            },
            {
                "id": "D",
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {"date": 5.0},
            },
        ],
    }
    tree_model_dict = {
        'id': 'tree',
        'type': 'torchtree.evolution.tree_model.TimeTreeModel',
        'newick': '(((A,B),C),D);',
        'internal_heights': {
            'id': 'heights',
            'type': 'torchtree.Parameter',
            'tensor': [[10.0, 20.0, 30.0], [100.0, 200.0, 300.0]],
        },
        'taxa': taxa_dict,
    }
    # Same tree, but holding only the second sample's heights (unbatched).
    tree_model_dict2 = {
        **tree_model_dict,
        'internal_heights': {
            'id': 'heights',
            'type': 'torchtree.Parameter',
            'tensor': [100.0, 200.0, 300.0],
        },
    }
    site_pattern_dict = {
        "id": "sites",
        "type": "torchtree.evolution.site_pattern.SitePattern",
        "alignment": {
            "id": "alignment",
            "type": "torchtree.evolution.alignment.Alignment",
            'datatype': 'nucleotide',
            "taxa": 'taxa',
            "sequences": [
                {"taxon": "A", "sequence": "AAG"},
                {"taxon": "B", "sequence": "AAC"},
                {"taxon": "C", "sequence": "AAC"},
                {"taxon": "D", "sequence": "AAT"},
            ],
        },
    }
    subst_model = JC69('jc')

    # compute using a batch of 2 samples
    dic = {}
    tree_model = TimeTreeModel.from_json(tree_model_dict, dic)
    site_pattern = SitePattern.from_json(site_pattern_dict, dic)
    site_model = WeibullSiteModel(
        None, Parameter(None, torch.tensor([[1.0], [1.0]])), 4
    )
    clock_model = StrictClockModel(
        None, Parameter(None, torch.tensor([[0.01], [0.001]])), tree_model
    )
    like_batch = likelihood.TreeLikelihoodModel(
        None, site_pattern, tree_model, subst_model, site_model, clock_model
    )

    # compute using a batch of 1 sample
    # (the second sample from the previous computation)
    tree_model2 = TimeTreeModel.from_json(tree_model_dict2, {})
    site_model2 = WeibullSiteModel(None, Parameter(None, torch.tensor([1.0])), 4)
    clock_model2 = StrictClockModel(
        None, Parameter(None, torch.tensor([0.001])), tree_model2
    )
    like = likelihood.TreeLikelihoodModel(
        None, site_pattern, tree_model2, subst_model, site_model2, clock_model2
    )

    # Batched and unbatched evaluations may take different numerical paths,
    # so compare with a tolerance instead of exact float equality (which
    # would also raise on a multi-element tensor inside `assert`).
    assert torch.allclose(like(), like_batch()[1])