Пример #1
0
def test_weibull():
    """Check rates() for a shape of 1.0 and that the weighted rates sum to one.

    The model is built with ``4`` as its third argument and is compared
    against four reference rates.
    """
    shape = Parameter('shape', torch.tensor([1.0]))
    model = WeibullSiteModel('weibull', shape, 4)

    expected = (0.1457844, 0.5131316, 1.0708310, 2.2702530)
    np.testing.assert_allclose(model.rates(), expected, rtol=1e-06)

    # Rates weighted by their probabilities should add up to 1 (rel. tol 1e-6).
    weighted_sum = torch.sum(model.rates() * model.probabilities()).item()
    assert weighted_sum == pytest.approx(1.0, 1.0e-6)
Пример #2
0
def test_weibull_batch2():
    """Compare rates() with reference rows for a (3, 1) tensor of shapes."""
    shapes = torch.tensor([[1.0], [0.1], [1.0]])
    model = WeibullSiteModel('weibull', Parameter('shape', shapes), 4)

    # The first and third shape values are identical, so they share a row.
    row_for_one = [0.1457844, 0.5131316, 1.0708310, 2.2702530]
    row_for_tenth = [4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819]
    expected = torch.tensor([row_for_one, row_for_tenth, row_for_one])

    assert torch.allclose(model.rates(), expected)
Пример #3
0
def test_weibull_mu():
    """Check that passing ``mu=2.0`` yields the shape-1.0 reference rates times 2."""
    shape = Parameter('shape', torch.tensor([1.0]))
    mu = Parameter('mu', torch.tensor([2.0]))
    model = WeibullSiteModel('weibull', shape, 4, mu=mu)

    # Reference rates without mu, then scaled by the mu value used above.
    unscaled = torch.tensor([0.1457844, 0.5131316, 1.0708310, 2.2702530])
    expected = unscaled * 2

    np.testing.assert_allclose(model.rates(), expected, rtol=1e-06)
Пример #4
0
def test_weibull_json():
    """Build the model with from_json, then swap the shape tensor and re-check rates.

    The shape parameter is reached through the dict handed to ``from_json``
    (under the key ``'shape'``) rather than through the model itself.
    """
    registry = {}
    shape_spec = {
        'id': 'shape',
        'type': 'torchtree.Parameter',
        'tensor': [1.0],
    }
    model_spec = {
        'id': 'weibull',
        'type': 'torchtree.evolution.sitemodel.WeibullSiteModel',
        'categories': 4,
        'shape': shape_spec,
    }
    model = WeibullSiteModel.from_json(model_spec, registry)

    # Rates for the initial shape of 1.0.
    initial_expected = (0.1457844, 0.5131316, 1.0708310, 2.2702530)
    np.testing.assert_allclose(model.rates(), initial_expected, rtol=1e-06)

    # Replace the shape value with 0.1; rates() must reflect the new tensor.
    registry['shape'].tensor = torch.tensor(np.array([0.1]))
    updated_expected = (4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819)
    np.testing.assert_allclose(model.rates(), updated_expected, rtol=1e-06)
Пример #5
0
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    """Check the tree likelihood with a Weibull site model, single and batched.

    Args:
        flu_a_tree_file: path to a file whose content is read as a newick string.
        flu_a_fasta_file: path to a FASTA file; used both to derive the taxa
            (from the ``>`` header lines) and as the alignment file.
    """
    # One taxon per FASTA header; the text after the last '_' is parsed as
    # the taxon's date.
    with open(flu_a_fasta_file) as fasta:
        names = [row[1:].strip() for row in fasta if row.startswith('>')]
    taxa = Taxa(
        'taxa',
        [Taxon(name, {'date': float(name.split('_')[-1])}) for name in names],
    )

    alignment_spec = {
        "id": "alignment",
        "type": "torchtree.evolution.alignment.Alignment",
        'datatype': 'nucleotide',
        'file': flu_a_fasta_file,
        'taxa': 'taxa',
    }
    site_pattern_spec = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': alignment_spec,
    }

    subst_model = JC69('jc')
    shape = Parameter(None, torch.tensor([[0.1]]))
    site_model = WeibullSiteModel('site_model', shape, 4)

    ratios = [0.5] * 67
    root_height = [20.0]
    with open(flu_a_tree_file) as tree_file:
        newick = tree_file.read().strip()

    # 'taxa' must be in the registry first: the tree spec refers to it by id.
    registry = {'taxa': taxa}
    tree_spec = ReparameterizedTimeTreeModel.json_factory(
        'tree_model',
        newick,
        ratios,
        root_height,
        'taxa',
        keep_branch_lengths=True,
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_spec, registry)

    clock_rate = Parameter(None, torch.tensor([[0.001]]))
    branch_model = StrictClockModel(None, clock_rate, tree_model)

    # Register the models under the ids the likelihood spec refers to.
    registry.update(
        tree_model=tree_model,
        site_model=site_model,
        subst_model=subst_model,
        branch_model=branch_model,
    )

    likelihood_spec = {
        'id': 'like',
        'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
        'tree_model': 'tree_model',
        'site_model': 'site_model',
        'site_pattern': site_pattern_spec,
        'substitution_model': 'subst_model',
        'branch_model': 'branch_model',
    }
    like = likelihood.TreeLikelihoodModel.from_json(likelihood_spec, registry)

    expected_single = torch.tensor([-4618.2062529058])
    assert torch.allclose(expected_single, like())

    # Tile the parameters into three identical copies and expect the same
    # value three times. The rate and shape tensors go from (1, 1) to (3, 1).
    # NOTE(review): the (3, 68) repeat sizes depend on the shape of
    # _internal_heights, which is not visible in this test - confirm.
    branch_model._rates.tensor = branch_model._rates.tensor.repeat(3, 1)
    site_model._shape.tensor = site_model._shape.tensor.repeat(3, 1)
    heights = tree_model._internal_heights
    heights.tensor = heights.tensor.repeat(3, 68)

    expected_batch = torch.tensor([[-4618.2062529058] * 3])
    assert torch.allclose(expected_batch, like())
Пример #6
0
def test_treelikelihood_batch():
    """Check that a batched likelihood matches the single-sample likelihood.

    A likelihood is evaluated with two sets of heights/rates at once; its
    second entry is compared with a likelihood built from only the second
    set of values.
    """
    # (taxon id, date) pairs, in the order they appear in the taxa spec.
    taxon_dates = (("A", 0.0), ("B", 1.0), ("C", 4.0), ("D", 5.0))
    taxa_spec = {
        'id': 'taxa',
        'type': 'torchtree.evolution.taxa.Taxa',
        'taxa': [
            {
                "id": taxon_id,
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {"date": date},
            }
            for taxon_id, date in taxon_dates
        ],
    }

    def make_tree_spec(heights):
        # Both tree specs differ only in the heights tensor; they share the
        # same taxa spec object.
        return {
            'id': 'tree',
            'type': 'torchtree.evolution.tree_model.TimeTreeModel',
            'newick': '(((A,B),C),D);',
            'internal_heights': {
                'id': 'heights',
                'type': 'torchtree.Parameter',
                'tensor': heights,
            },
            'taxa': taxa_spec,
        }

    batch_tree_spec = make_tree_spec(
        [[10.0, 20.0, 30.0], [100.0, 200.0, 300.0]])
    single_tree_spec = make_tree_spec([100.0, 200.0, 300.0])

    # (taxon id, sequence) pairs for the alignment.
    taxon_sequences = (("A", "AAG"), ("B", "AAC"), ("C", "AAC"), ("D", "AAT"))
    site_pattern_spec = {
        "id": "sites",
        "type": "torchtree.evolution.site_pattern.SitePattern",
        "alignment": {
            "id": "alignment",
            "type": "torchtree.evolution.alignment.Alignment",
            'datatype': 'nucleotide',
            "taxa": 'taxa',
            "sequences": [
                {"taxon": taxon_id, "sequence": sequence}
                for taxon_id, sequence in taxon_sequences
            ],
        },
    }
    subst_model = JC69('jc')

    # Likelihood over a batch of 2 samples.
    registry = {}
    batch_tree = TimeTreeModel.from_json(batch_tree_spec, registry)
    site_pattern = SitePattern.from_json(site_pattern_spec, registry)
    batch_shape = Parameter(None, torch.tensor([[1.0], [1.0]]))
    batch_site_model = WeibullSiteModel(None, batch_shape, 4)
    batch_rate = Parameter(None, torch.tensor([[0.01], [0.001]]))
    batch_clock = StrictClockModel(None, batch_rate, batch_tree)
    like_batch = likelihood.TreeLikelihoodModel(
        None, site_pattern, batch_tree, subst_model, batch_site_model,
        batch_clock)

    # Likelihood over a single sample: the second sample of the batch above.
    single_tree = TimeTreeModel.from_json(single_tree_spec, {})
    single_shape = Parameter(None, torch.tensor([1.0]))
    single_site_model = WeibullSiteModel(None, single_shape, 4)
    single_rate = Parameter(None, torch.tensor([0.001]))
    single_clock = StrictClockModel(None, single_rate, single_tree)
    like = likelihood.TreeLikelihoodModel(
        None, site_pattern, single_tree, subst_model, single_site_model,
        single_clock)

    single_value = like()
    batch_values = like_batch()
    assert single_value == batch_values[1]