示例#1
0
def test_general_node_height_transform_hetero_all(ratios, root_height, keep,
                                                  expected_ratios_root):
    """Check a four-taxon tree with heterochronous tips end to end.

    The model is built from the JSON produced by ``json_factory`` and the
    test then compares, in order: the ratio/root-height parameters found in
    the registry dict, the node heights, the transform bounds, the branch
    lengths, and the value returned by calling the model (compared against a
    log-Jacobian computed by hand from the expected heights and bounds).
    """
    registry = {}
    tip_dates = dict(zip('ABCD', [5.0, 3.0, 0.0, 1.0]))
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(A:2,(B:1.5,(C:2,D:1):2.5):2.5);',
        ratios,
        root_height,
        tip_dates,
        keep_branch_lengths=keep,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, registry)

    heights = torch.tensor([5.0, 3.0, 0.0, 1.0, 2.0, 4.5, 7.0])
    bounds = torch.tensor([5.0, 3.0, 0.0, 1.0, 1.0, 3.0, 5.0])
    branch_lengths = torch.tensor([2.0, 1.5, 2.0, 1.0, 2.5, 2.5])
    # Hand-computed reference: one log term per (height, bound) pair used below.
    first_term = torch.log(heights[5] - bounds[4])
    second_term = torch.log(heights[6] - bounds[5])
    log_det_jacobian = first_term + second_term

    # The parameters are looked up under the ids passed to json_factory.
    ratios_and_root = torch.cat(
        (registry['ratios'].tensor, registry['root_height'].tensor))
    assert torch.allclose(torch.tensor(expected_ratios_root), ratios_and_root)
    assert torch.allclose(heights, tree_model.node_heights)
    assert torch.allclose(bounds, tree_model.transform._bounds)
    assert torch.allclose(branch_lengths, tree_model.branch_lengths())
    assert torch.allclose(tree_model(), log_det_jacobian)
示例#2
0
def tree_model_node_heights_transformed(ratios_list):
    """Return the ``json_factory`` output for a four-taxon tree with tips at 0.

    ``ratios_list`` is split by slicing: everything but the last element is
    passed as the ratios and the last element (kept as a length-one slice) as
    the root height.
    """
    ratios = ratios_list[:-1]
    root_height = ratios_list[-1:]
    tip_dates = dict(zip('ABCD', [0.0, 0.0, 0.0, 0.0]))
    return ReparameterizedTimeTreeModel.json_factory(
        'tree', '(((A,B),C),D);', ratios, root_height, tip_dates)
示例#3
0
def test_general_node_height_transform_hetero(ratios, root_height):
    """Compare the model's call value with a hand-written log product.

    The reference is ``log((h[-2] - 1.0) * (h[-1] - 4.0))`` where ``h`` is the
    model's node heights; the comparison uses a relative tolerance of 1e-4.
    """
    registry = {}
    tip_dates = dict(zip('ABCD', [0.0, 1.0, 4.0, 5.0]))
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        tip_dates,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, registry)

    heights = tree_model.node_heights
    product = (heights[-2] - 1.0) * (heights[-1] - 4.0)
    expected = torch.log(product).item()
    assert tree_model().item() == pytest.approx(expected, rel=0.0001)
示例#4
0
def test_general_node_height_heights_to_ratios(ratios, root_height):
    """Check that the inverse transform recovers the input ratios and root height.

    Node heights from index 4 onward are passed through ``transform.inv`` and
    the result is compared with ``ratios + root_height`` (list concatenation)
    converted to a tensor of the same dtype.
    """
    tip_dates = dict(zip('ABCD', [0.0, 0.0, 0.0, 0.0]))
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        tip_dates,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, {})

    recovered = tree_model.transform.inv(tree_model.node_heights[4:])
    reference = torch.tensor(ratios + root_height, dtype=recovered.dtype)
    assert torch.allclose(recovered, reference)
示例#5
0
def test_general_node_height_transform_hetero_7():
    """Check the model's call value on a seven-taxon tree with dated tips.

    The reference value is accumulated in place, one log term per index in
    ``range(n, 2 * n - 2)`` (``n`` = number of taxa), each term being
    ``log(node_heights[i + 1] - bounds[i])``.
    """
    tip_dates = dict(zip('ABCDEFG', [5.0, 3.0, 0.0, 1.0, 0.0, 5.0, 6.0]))
    n_tips = len(tip_dates)
    registry = {}
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(A,(B,(C,(D,(E,(F,G))))));',
        [0.5] * (n_tips - 2),
        [10.0],
        tip_dates,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, registry)

    # Length-one accumulator so the result keeps the shape the loop builds up.
    log_det_jacobian = torch.zeros(1)
    for node in range(n_tips, 2 * n_tips - 2):
        gap = (tree_model.node_heights[node + 1] -
               tree_model.transform._bounds[node])
        log_det_jacobian += gap.log()
    assert torch.allclose(tree_model(), log_det_jacobian)
示例#6
0
def test_keep_branch_lengths_heights():
    """Check branch lengths when ``keep_branch_lengths`` is enabled.

    The model is created with all-zero ratios and root height; the assertion
    expects ``branch_lengths()`` to match the lengths written in the newick
    string (listed here in the order the model reports them).
    """
    registry = {}
    taxon_names = ['A_0', 'B_1', 'C_2', 'D_3', 'E_12']
    tip_dates = dict(zip(taxon_names, [0.0, 1.0, 2.0, 3.0, 12.0]))
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '((((A_0:1.5,B_1:0.5):2.5,C_2:2):2,D_3:3):10,E_12:4);',
        [0.0] * 3,
        [0.0],
        tip_dates,
        keep_branch_lengths=True,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, registry)

    expected = torch.tensor([1.5, 0.5, 2.0, 3.0, 4.0, 2.5, 2.0, 10.0])
    assert torch.allclose(tree_model.branch_lengths(), expected)
示例#7
0
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    """Check the tree likelihood with a Weibull site model, single and batched.

    Taxa are read from the FASTA headers (the date is the text after the last
    underscore of each name). The likelihood is evaluated once with the
    initial parameters and once more after the rate, shape and internal-height
    tensors have been tiled with ``repeat``; both evaluations are compared
    with the same reference log-likelihood.
    """
    expected_log_likelihood = -4618.2062529058

    # Header lines start with '>'; the rest of the line is the taxon name.
    with open(flu_a_fasta_file) as fp:
        names = [row[1:].strip() for row in fp if row.startswith('>')]
    taxa = Taxa(
        'taxa',
        [Taxon(name, {'date': float(name.split('_')[-1])}) for name in names],
    )

    alignment = {
        'id': 'alignment',
        'type': 'torchtree.evolution.alignment.Alignment',
        'datatype': 'nucleotide',
        'file': flu_a_fasta_file,
        'taxa': 'taxa',
    }
    site_pattern = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': alignment,
    }
    subst_model = JC69('jc')
    shape = Parameter(None, torch.tensor([[0.1]]))
    site_model = WeibullSiteModel('site_model', shape, 4)
    ratios = [0.5] * 67
    root_height = [20.0]
    with open(flu_a_tree_file) as fp:
        newick = fp.read().strip()

    # The registry starts with the taxa so the 'taxa' reference can be resolved.
    dic = {'taxa': taxa}
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree_model', newick, ratios, root_height, 'taxa',
        keep_branch_lengths=True)
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, dic)
    rate = Parameter(None, torch.tensor([[0.001]]))
    branch_model = StrictClockModel(None, rate, tree_model)
    dic['tree_model'] = tree_model
    dic['site_model'] = site_model
    dic['subst_model'] = subst_model
    dic['branch_model'] = branch_model

    likelihood_json = {
        'id': 'like',
        'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
        'tree_model': 'tree_model',
        'site_model': 'site_model',
        'site_pattern': site_pattern,
        'substitution_model': 'subst_model',
        'branch_model': 'branch_model',
    }
    like = likelihood.TreeLikelihoodModel.from_json(likelihood_json, dic)
    assert torch.allclose(torch.tensor([expected_log_likelihood]), like())

    # Tile every parameter tensor, then evaluate the likelihood again.
    rates = branch_model._rates
    rates.tensor = rates.tensor.repeat(3, 1)
    weibull_shape = site_model._shape
    weibull_shape.tensor = weibull_shape.tensor.repeat(3, 1)
    internal_heights = tree_model._internal_heights
    internal_heights.tensor = internal_heights.tensor.repeat(3, 68)
    assert torch.allclose(
        torch.tensor([[expected_log_likelihood] * 3]), like())
示例#8
0
def create_tree_model(id_: str, taxa: dict, arg):
    """Build the ``json_factory`` description of a tree model from CLI options.

    The tree is read from ``arg.tree``. A file whose first line starts with
    ``#NEXUS`` (case-insensitive) is parsed with ``Tree.get`` and converted
    back to a newick string; any other file is used verbatim after stripping
    surrounding whitespace.

    Which model is described depends on the options:

    * ``arg.clock`` is ``None``: an ``UnRootedTreeModel`` with one branch
      length per branch (``2 * n_taxa - 3`` values, lower bound 0).
    * ``arg.clock`` set and ``arg.heights == 'ratio'``: a
      ``ReparameterizedTimeTreeModel`` with ratio and root-height parameters.
    * ``arg.clock`` set and ``arg.heights == 'shift'``: a
      ``FlexibleTimeTreeModel`` whose heights are a transformed parameter.

    Args:
        id_: Identifier of the tree model; also the prefix of parameter ids.
        taxa: Mapping with a ``'taxa'`` list; with a clock every entry must
            provide ``['attributes']['date']``.
        arg: Parsed options. Reads ``tree``, ``keep``, ``clock`` and,
            depending on the branch, ``heights``, ``coalescent``, ``cutoff``
            and the optional ``root_height_init``.

    Returns:
        Whatever the selected ``json_factory`` call returns.

    Raises:
        ValueError: If the tree file contains no tree (non-nexus input), or
            if a clock is requested with an unsupported ``arg.heights``.
    """
    # Read the file once; only its first line decides the format.
    with open(arg.tree) as fp:
        content = fp.read()

    if content.partition('\n')[0].upper().startswith('#NEXUS'):
        tree = Tree.get(
            path=arg.tree,
            schema='nexus',
            tree_offset=0,
            preserve_underscores=True,
        )
        newick = str(tree) + ';'
    else:
        newick = content.strip()
        if not newick:
            raise ValueError(f'tree file {arg.tree!r} does not contain a tree')

    kwargs = {}
    if arg.keep:
        kwargs['keep_branch_lengths'] = True

    if arg.clock is None:
        branch_lengths = Parameter.json_factory(
            f'{id_}.blens', tensor=0.1, full=[len(taxa['taxa']) * 2 - 3]
        )
        branch_lengths['lower'] = 0.0
        return UnRootedTreeModel.json_factory(
            id_, newick, branch_lengths, 'taxa', **kwargs
        )

    dates = [taxon['attributes']['date'] for taxon in taxa['taxa']]
    offset = max(dates) - min(dates)

    if arg.heights == 'ratio':
        ratios = Parameter.json_factory(
            f'{id_}.ratios', tensor=0.1, full=[len(dates) - 2]
        )
        ratios['lower'] = 0.0
        ratios['upper'] = 1.0

        root_height = Parameter.json_factory(
            f'{id_}.root_height', tensor=[offset + 1.0]
        )
        # `in` checks whether the attribute exists on the options object.
        if 'root_height_init' in arg:
            root_height['tensor'] = [max(dates) - arg.root_height_init]
        elif arg.coalescent == 'skygrid':
            # NOTE(review): the cutoff is stored unwrapped, unlike the list
            # used in the branch above -- confirm this is intended.
            root_height['tensor'] = arg.cutoff

        # The root cannot be younger than the span of the sampling dates.
        root_height['lower'] = offset
        return ReparameterizedTimeTreeModel.json_factory(
            id_, newick, ratios, root_height, 'taxa', **kwargs
        )

    if arg.heights == 'shift':
        shifts = Parameter.json_factory(
            f'{id_}.shifts', tensor=0.1, full=[len(dates) - 1]
        )
        shifts['lower'] = 0.0
        node_heights = {
            'id': f'{id_}.heights',
            'type': 'TransformedParameter',
            'transform': 'torchtree.evolution.tree_height_transform'
            '.DifferenceNodeHeightTransform',
            'x': shifts,
            'parameters': {'tree_model': id_},
        }
        return FlexibleTimeTreeModel.json_factory(
            id_, newick, node_heights, 'taxa', **kwargs
        )

    # Previously this fell through to an unbound local at the return statement.
    raise ValueError(
        f'unsupported node height parameterization: {arg.heights!r}'
    )