コード例 #1
0
def test_general_node_height_transform_hetero_all(ratios, root_height, keep,
                                                  expected_ratios_root):
    """Check a four-taxon reparameterized time tree built from JSON.

    The model is created through ``json_factory``/``from_json`` and the test
    verifies, in order: the ratio/root-height parameters registered in the
    lookup dictionary, the node heights, the transform bounds, the branch
    lengths and the value returned by calling the model (compared against a
    hand-computed log-determinant).
    """
    registry = {}
    # Per-taxon values handed to the factory; the first four expected node
    # heights below are the same numbers.
    taxon_values = {'A': 5.0, 'B': 3.0, 'C': 0.0, 'D': 1.0}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(A:2,(B:1.5,(C:2,D:1):2.5):2.5);',
        ratios,
        root_height,
        taxon_values,
        keep_branch_lengths=keep,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    expected = torch.tensor([5.0, 3.0, 0.0, 1.0, 2.0, 4.5, 7.0])
    expected_bounds = torch.tensor([5.0, 3.0, 0.0, 1.0, 1.0, 3.0, 5.0])
    expected_branch_lengths = torch.tensor([2.0, 1.5, 2.0, 1.0, 2.5, 2.5])
    # log(h[5] - bound[4]) + log(h[6] - bound[5])
    log_det_jacobian = torch.log(expected[5:7] - expected_bounds[4:6]).sum()

    registered = torch.cat(
        (registry['ratios'].tensor, registry['root_height'].tensor))
    assert torch.allclose(torch.tensor(expected_ratios_root), registered)
    assert torch.allclose(expected, tree_model.node_heights)
    assert torch.allclose(expected_bounds, tree_model.transform._bounds)
    assert torch.allclose(expected_branch_lengths, tree_model.branch_lengths())
    assert torch.allclose(tree_model(), log_det_jacobian)
コード例 #2
0
def test_general_node_height_transform_hetero(ratios, root_height):
    """Compare the model's call value with a hand-written log product.

    The expected value is ``log((h[-2] - 1) * (h[-1] - 4))`` where ``h`` are
    the node heights of the model built for the tree ``(((A,B),C),D);``.
    """
    registry = {}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        {'A': 0.0, 'B': 1.0, 'C': 4.0, 'D': 5.0},
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    second_to_last_gap = tree_model.node_heights[-2] - 1.0
    last_gap = tree_model.node_heights[-1] - 4.0
    expected = torch.log(second_to_last_gap * last_gap).item()
    # Relative tolerance of 1e-4.
    assert tree_model().item() == pytest.approx(expected, rel=0.0001)
コード例 #3
0
def test_general_node_height_heights_to_ratios(ratios, root_height):
    """Check that the inverse transform recovers the ratios and root height.

    The node heights from index 4 onwards are passed through
    ``transform.inv`` and compared with ``ratios + root_height``.
    """
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        ratios,
        root_height,
        {'A': 0.0, 'B': 0.0, 'C': 0.0, 'D': 0.0},
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, {})

    recovered = tree_model.transform.inv(tree_model.node_heights[4:])
    # Build the reference in the dtype produced by the inverse transform.
    reference = torch.tensor(ratios + root_height, dtype=recovered.dtype)
    assert torch.allclose(recovered, reference)
コード例 #4
0
def test_general_node_height_transform_hetero_7():
    """Check the model's call value on a seven-taxon ladder tree.

    The reference value is accumulated as the sum over ``i`` in
    ``[n, 2n - 2)`` of ``log(node_heights[i + 1] - bounds[i])`` with ``n``
    the number of taxa.
    """
    taxa = {
        'A': 5.0,
        'B': 3.0,
        'C': 0.0,
        'D': 1.0,
        'E': 0.0,
        'F': 5.0,
        'G': 6.0,
    }
    n_taxa = len(taxa)
    registry = {}
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '(A,(B,(C,(D,(E,(F,G))))));',
        [0.5] * (n_taxa - 2),
        [10.0],
        taxa,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    log_det_jacobian = torch.tensor([0.0])
    for node in range(n_taxa, 2 * n_taxa - 2):
        gap = tree_model.node_heights[node + 1] - tree_model.transform._bounds[node]
        log_det_jacobian += gap.log()
    assert torch.allclose(tree_model(), log_det_jacobian)
コード例 #5
0
ファイル: test_coalescent.py プロジェクト: 4ment/phylotorch
def tree_model_node_heights_transformed(ratios_list):
    """Return the ``json_factory`` description of a four-taxon tree model.

    ``ratios_list`` is split into everything but the last element (passed as
    the ratios) and the last element kept as a one-item sequence (passed as
    the root height). All four taxa are given the value 0.0.
    """
    ratios = ratios_list[:-1]
    root_height = ratios_list[-1:]
    taxon_values = {'A': 0.0, 'B': 0.0, 'C': 0.0, 'D': 0.0}
    return ReparameterizedTimeTreeModel.json_factory(
        'tree', '(((A,B),C),D);', ratios, root_height, taxon_values
    )
コード例 #6
0
def test_keep_branch_lengths_heights():
    """Check branch lengths when ``keep_branch_lengths`` is enabled.

    The ratios and root height passed to the factory are all zero; the test
    expects ``branch_lengths()`` to return the lengths written in the newick
    string.
    """
    registry = {}
    taxon_values = {
        'A_0': 0.0,
        'B_1': 1.0,
        'C_2': 2.0,
        'D_3': 3.0,
        'E_12': 12.0,
    }
    spec = ReparameterizedTimeTreeModel.json_factory(
        'tree',
        '((((A_0:1.5,B_1:0.5):2.5,C_2:2):2,D_3:3):10,E_12:4);',
        [0.0] * 3,
        [0.0],
        taxon_values,
        keep_branch_lengths=True,
        ratios_id='ratios',
        root_height_id='root_height',
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(spec, registry)

    expected_lengths = torch.tensor([1.5, 0.5, 2.0, 3.0, 4.0, 2.5, 2.0, 10.0])
    assert torch.allclose(tree_model.branch_lengths(), expected_lengths)
コード例 #7
0
def ratio_transform(args):
    """Benchmark several implementations of the ratio -> height transform.

    Reads the tree at ``args.tree``, builds a reparameterized time tree model
    and times, for ``args.replicates`` repetitions each, the evaluation and
    the evaluation + backward pass of:

    * ``tree_model.transform``,
    * the free function ``transform`` (plain and ``torch.jit.script``-ed),
    * the free function ``transform2`` (plain and ``torch.jit.script``-ed).

    ``benchmark`` is used as a decorator whose wrapped function is called as
    ``f(replicates, tensor)`` and returns ``(total_time, result)``. Timings
    are printed; the first pair is also written to ``args.output`` when set.
    """
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = [Taxon(leaf.label, {'date': leaf.date}) for leaf in tree.leaf_node_iter()]
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    # Replace the placeholder values with the ones implied by the tree's own
    # branch lengths.
    tree_heights = heights_from_branch_lengths(tree)
    ratios_root_height.tensor = tree_model.transform.inv(tree_heights)

    @benchmark
    def fn(values):
        return tree_model.transform(values)

    @benchmark
    def fn_grad(values):
        out = tree_model.transform(values)
        out.backward(torch.ones_like(values))
        # Reset the accumulated gradient so replicates do not add up.
        values.grad.data.zero_()
        return out

    total_time, _ = fn(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    grad_total_time, _ = fn_grad(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    print('  JIT off')

    @benchmark
    def fn2(values):
        return transform(
            tree_model.transform._forward_indices, tree_model.transform._bounds, values
        )

    @benchmark
    def fn2_grad(values):
        out = transform(
            tree_model.transform._forward_indices, tree_model.transform._bounds, values
        )
        out.backward(torch.ones_like(values))
        values.grad.data.zero_()
        return out

    ratios_root_height.requires_grad = False
    elapsed, _ = fn2(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {elapsed}')

    ratios_root_height.requires_grad = True
    elapsed, _ = fn2_grad(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {elapsed}')

    print('  JIT on')
    transform_script = torch.jit.script(transform)

    @benchmark
    def fn2_jit(values):
        return transform_script(
            tree_model.transform._forward_indices, tree_model.transform._bounds, values
        )

    @benchmark
    def fn2_grad_jit(values):
        out = transform_script(
            tree_model.transform._forward_indices, tree_model.transform._bounds, values
        )
        out.backward(torch.ones_like(values))
        values.grad.data.zero_()
        return out

    ratios_root_height.requires_grad = False
    elapsed, _ = fn2_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {elapsed}')

    ratios_root_height.requires_grad = True
    elapsed, _ = fn2_grad_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {elapsed}')

    print('ratio_transform v2 JIT off')

    # The v2 variants receive the forward indices as a Python list; the
    # conversion stays inside the timed function, as in the other variants'
    # attribute look-ups.
    @benchmark
    def fn3(values):
        return transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            values,
        )

    @benchmark
    def fn3_grad(values):
        out = transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            values,
        )
        out.backward(torch.ones_like(values))
        values.grad.data.zero_()
        return out

    ratios_root_height.requires_grad = False
    elapsed, _ = fn3(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {elapsed}')

    ratios_root_height.requires_grad = True
    elapsed, _ = fn3_grad(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {elapsed}')

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)

    @benchmark
    def fn3_jit(values):
        return transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            values,
        )

    @benchmark
    def fn3_grad_jit(values):
        out = transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            values,
        )
        out.backward(torch.ones_like(values))
        values.grad.data.zero_()
        return out

    ratios_root_height.requires_grad = False
    elapsed, _ = fn3_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {elapsed}')

    ratios_root_height.requires_grad = True
    elapsed, _ = fn3_grad_jit(replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {elapsed}')
コード例 #8
0
def ratio_transform_jacobian(args):
    """Benchmark the log-determinant of the ratio transform's Jacobian.

    Reads the tree at ``args.tree``, builds a reparameterized time tree model
    and times, for ``args.replicates`` repetitions each, the evaluation of
    ``transform.log_abs_det_jacobian`` and the same evaluation followed by a
    backward pass. Timings are printed and, when ``args.output`` is set,
    written as CSV rows. With ``args.debug`` the gradient of one extra
    evaluation is printed.

    ``benchmark`` is used as a decorator whose wrapped function is called as
    ``f(replicates, tensor)`` and returns ``(total_time, result)``.
    """
    tree = read_tree(args.tree, True, True)
    taxa = [Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()]
    taxa_count = len(taxa)
    # Placeholder values: taxa_count - 2 ratios followed by one root height,
    # the same layout ratio_transform uses. (The previous taxa_count - 1
    # produced one element too many.) They are overwritten just below.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    # Start from the ratios/root height implied by the tree's branch lengths.
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        return tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )

    @benchmark
    def fn_grad(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )
        log_det_jac.backward()
        # Reset the accumulated gradient so replicates do not add up.
        ratios_root_height.grad.data.zero_()
        return log_det_jac

    print('  JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f'  {args.replicates} evaluations: {total_time} ({log_det_jac})')

    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f'  {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        # One more forward/backward pass, this time keeping the gradient so
        # it can be inspected.
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)
コード例 #9
0
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    """Check the tree likelihood with a JC69 model and a Weibull site model.

    Taxa are read from the FASTA headers (the text after the last ``_`` is
    parsed as the date), the tree comes from the newick file with its branch
    lengths kept, and the likelihood is compared with a reference value,
    first for a single set of parameters and then after the parameter
    tensors have been repeated.
    """
    with open(flu_a_fasta_file) as fp:
        names = [line[1:].strip() for line in fp if line.startswith('>')]
    taxa = Taxa(
        'taxa',
        [Taxon(name, {'date': float(name.split('_')[-1])}) for name in names],
    )

    alignment = {
        "id": "alignment",
        "type": "torchtree.evolution.alignment.Alignment",
        'datatype': 'nucleotide',
        'file': flu_a_fasta_file,
        'taxa': 'taxa',
    }
    site_pattern = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': alignment,
    }
    subst_model = JC69('jc')
    site_model = WeibullSiteModel(
        'site_model', Parameter(None, torch.tensor([[0.1]])), 4
    )

    with open(flu_a_tree_file) as fp:
        newick = fp.read().strip()

    dic = {'taxa': taxa}
    tree_spec = ReparameterizedTimeTreeModel.json_factory(
        'tree_model',
        newick,
        [0.5] * 67,
        [20.0],
        'taxa',
        keep_branch_lengths=True,
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_spec, dic)
    branch_model = StrictClockModel(
        None, Parameter(None, torch.tensor([[0.001]])), tree_model
    )
    dic['tree_model'] = tree_model
    dic['site_model'] = site_model
    dic['subst_model'] = subst_model
    dic['branch_model'] = branch_model

    like_spec = {
        'id': 'like',
        'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
        'tree_model': 'tree_model',
        'site_model': 'site_model',
        'site_pattern': site_pattern,
        'substitution_model': 'subst_model',
        'branch_model': 'branch_model',
    }
    like = likelihood.TreeLikelihoodModel.from_json(like_spec, dic)
    reference = -4618.2062529058
    assert torch.allclose(torch.tensor([reference]), like())

    # Repeat every parameter tensor and evaluate again.
    branch_model._rates.tensor = branch_model._rates.tensor.repeat(3, 1)
    site_model._shape.tensor = site_model._shape.tensor.repeat(3, 1)
    # NOTE(review): repeat(3, 68) tiles the last dimension 68 times, unlike
    # the repeat(3, 1) used for the two parameters above -- confirm that this
    # is intended.
    heights = tree_model._internal_heights.tensor
    tree_model._internal_heights.tensor = heights.repeat(3, 68)
    assert torch.allclose(torch.tensor([[reference] * 3]), like())
コード例 #10
0
def create_tree_model(id_: str, taxa: dict, arg):
    """Build the ``json_factory`` description of a tree model from options.

    The tree is read from ``arg.tree``. A file whose first line starts with
    ``#NEXUS`` (case-insensitive) is parsed with ``Tree.get`` and converted
    to a newick string; any other file is used verbatim after stripping.

    * ``arg.clock is None``: an unrooted tree model with one branch-length
      parameter per branch (``2 * n_taxa - 3`` values).
    * ``arg.heights == 'ratio'``: a reparameterized time tree model with
      ratio and root-height parameters.
    * ``arg.heights == 'shift'``: a flexible time tree model whose node
      heights are a transformed parameter of shifts.

    Args:
        id_: identifier of the tree model; also used as the prefix of the
            identifiers of its parameters.
        taxa: dictionary whose ``'taxa'`` entry is a list of taxon
            descriptions; with a clock, each needs ``['attributes']['date']``.
        arg: options object providing ``tree``, ``keep``, ``clock`` and,
            with a clock, ``heights``, ``coalescent`` and, where applicable,
            ``root_height_init`` and ``cutoff``.

    Returns:
        Whatever the selected ``json_factory`` returns.

    Raises:
        ValueError: if a clock is requested and ``arg.heights`` is neither
            ``'ratio'`` nor ``'shift'``.
    """
    tree_format = 'newick'
    with open(arg.tree) as fp:
        if next(fp).upper().startswith('#NEXUS'):
            tree_format = 'nexus'
    if tree_format == 'nexus':
        tree = Tree.get(
            path=arg.tree,
            schema=tree_format,
            tree_offset=0,
            preserve_underscores=True,
        )
        newick = str(tree) + ';'
    else:
        with open(arg.tree) as fp:
            newick = fp.read()
            newick = newick.strip()

    kwargs = {}
    if arg.keep:
        kwargs['keep_branch_lengths'] = True

    if arg.clock is not None:
        dates = [taxon['attributes']['date'] for taxon in taxa['taxa']]
        # Span between the largest and the smallest date; used as the lower
        # bound of the root height.
        offset = max(dates) - min(dates)

        if arg.heights == 'ratio':
            ratios = Parameter.json_factory(
                f'{id_}.ratios', **{'tensor': 0.1, 'full': [len(dates) - 2]}
            )
            ratios['lower'] = 0.0
            ratios['upper'] = 1.0

            root_height = Parameter.json_factory(
                f'{id_}.root_height', **{'tensor': [offset + 1.0]}
            )
            if 'root_height_init' in arg:
                root_height['tensor'] = [max(dates) - arg.root_height_init]
            elif arg.coalescent == 'skygrid':
                root_height['tensor'] = arg.cutoff

            root_height['lower'] = offset
            tree_model = ReparameterizedTimeTreeModel.json_factory(
                id_, newick, ratios, root_height, 'taxa', **kwargs
            )
        elif arg.heights == 'shift':
            shifts = Parameter.json_factory(
                f'{id_}.shifts', **{'tensor': 0.1, 'full': [len(dates) - 1]}
            )
            shifts['lower'] = 0.0
            node_heights = {
                'id': f'{id_}.heights',
                'type': 'TransformedParameter',
                'transform': 'torchtree.evolution.tree_height_transform'
                '.DifferenceNodeHeightTransform',
                'x': shifts,
                'parameters': {'tree_model': id_},
            }
            tree_model = FlexibleTimeTreeModel.json_factory(
                id_, newick, node_heights, 'taxa', **kwargs
            )
        else:
            # Without this branch ``tree_model`` would be unbound and the
            # function would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"unsupported node height parameterization {arg.heights!r}: "
                "expected 'ratio' or 'shift'"
            )
    else:
        branch_lengths = Parameter.json_factory(
            f'{id_}.blens', **{'tensor': 0.1, 'full': [len(taxa['taxa']) * 2 - 3]}
        )
        branch_lengths['lower'] = 0.0
        tree_model = UnRootedTreeModel.json_factory(
            id_, newick, branch_lengths, 'taxa', **kwargs
        )
    return tree_model