Example No. 1
def test_skygrid_heterochronous_2_trees():
    sampling_times = [0.0, 1.0, 2.0, 3.0, 12.0]
    thetas_log = torch.tensor(np.array([1.0, 3.0, 6.0, 8.0, 9.0]))
    thetas = Parameter(None, thetas_log.exp())
    grid = Parameter(None, torch.linspace(0, 10.0, steps=5)[1:])
    constant = PiecewiseConstantCoalescentGridModel(
        None,
        thetas,
        grid,
        FakeTreeModel(
            Parameter(
                None,
                torch.tensor(
                    [
                        [sampling_times + [1.5, 4.0, 6.0, 16.0]],
                        [sampling_times + [1.5, 4.0, 6.0, 26.0]],
                    ],
                    dtype=thetas.dtype,
                ),
            ),
        ),
    )
    assert torch.allclose(
        torch.tensor([[-19.594893640219844], [-19.596127738260712]],
                     dtype=thetas.dtype),
        torch.squeeze(constant(), -1),
    )
Example No. 2
def test_gmrf2(thetas, precision):
    thetas = torch.tensor(thetas, requires_grad=True)
    thetas2 = thetas.detach()
    precision = torch.tensor([precision])
    gmrf = GMRF(
        None,
        field=Parameter(None, thetas),
        precision=Parameter(None, precision),
        tree_model=None,
    )
    lp1 = gmrf()
    lp1.backward()

    thetas2.requires_grad = True
    dim = thetas.shape[0]
    Q = torch.zeros((dim, dim))
    Q[range(dim - 1), range(1, dim)] = -1
    Q[range(1, dim), range(dim - 1)] = -1
    Q.fill_diagonal_(2)
    Q[0, 0] = Q[dim - 1, dim - 1] = 1

    Q_scaled = Q * precision

    lp2 = (
        0.5 * (dim - 1) * precision.log()
        - 0.5 * torch.dot(thetas2, Q_scaled @ thetas2)
        - (dim - 1) / 2.0 * 1.8378770664093453  # 1.8378... = log(2 * pi)
    )
    lp2.backward()
    assert lp1.item() == pytest.approx(lp2.item())
    assert torch.allclose(thetas.grad, thetas2.grad)
Example No. 3
def test_gmrfw2():
    gmrf = GMRF(
        None,
        field=Parameter(None, torch.tensor([3.0, 10.0, 4.0]).log()),
        precision=Parameter(None, torch.tensor([0.1])),
        tree_model=None,
    )
    assert -4.254919053937792 == pytest.approx(gmrf().item(), 0.000001)
Example No. 4
def test_gmrf(thetas, precision, expected):
    gmrf = GMRF(
        None,
        field=Parameter(None, torch.tensor(thetas)),
        precision=Parameter(None, torch.tensor([precision])),
        tree_model=None,
    )
    assert expected == pytest.approx(gmrf().item(), 0.000001)
Example No. 5
def test_weibull_mu():
    sitemodel = WeibullSiteModel(
        'weibull',
        Parameter('shape', torch.tensor([1.0])),
        4,
        mu=Parameter('mu', torch.tensor([2.0])),
    )
    rates_expected = (
        torch.tensor([0.1457844, 0.5131316, 1.0708310, 2.2702530]) * 2
    )
    np.testing.assert_allclose(sitemodel.rates(), rates_expected, rtol=1e-06)
Example No. 6
def test_gmrf_batch():
    gmrf = GMRF(
        None,
        field=Parameter(
            None, torch.tensor([[2.0, 30.0, 4.0, 15.0, 6.0], [1.0, 3.0, 6.0, 8.0, 9.0]])
        ),
        precision=Parameter(None, torch.tensor([[2.0], [0.1]])),
        tree_model=None,
    )
    assert torch.allclose(
        gmrf(), torch.tensor([[-1664.2894596388803], [-9.180924185988092]])
    )
Example No. 7
def test_weibull_invariant0():
    key = 'shape'
    sitemodel = WeibullSiteModel(
        'weibull',
        Parameter(key, torch.tensor([1.0])),
        4,
        Parameter('inv', torch.tensor([0.0])),
    )
    rates_expected = (0.1457844, 0.5131316, 1.0708310, 2.2702530)
    np.testing.assert_allclose(sitemodel.rates(), rates_expected, rtol=1e-06)

    assert torch.sum(
        sitemodel.rates() * sitemodel.probabilities()
    ).item() == pytest.approx(1.0, 1.0e-6)
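The shape=1 rates recurring in these tests can be derived by hand: a Weibull with shape 1 is Exponential(1), and the numbers match a quantile-midpoint discretization (an assumption about torchtree's scheme, though one these values confirm) with the rates normalized to mean 1. A minimal sketch:

import torch

shape, categories = torch.tensor(1.0), 4
# category midpoints: 0.125, 0.375, 0.625, 0.875
quantiles = (2.0 * torch.arange(categories) + 1.0) / (2.0 * categories)
# Weibull(shape, scale=1) inverse CDF: (-log(1 - q)) ** (1 / shape)
rates = (-torch.log1p(-quantiles)) ** (1.0 / shape)
rates = rates / rates.mean()  # normalize so the mean rate is 1
print(rates)  # tensor([0.1458, 0.5131, 1.0708, 2.2703])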
Example No. 8
def test_invariant():
    prop_invariant = torch.tensor([0.2])
    site_model = InvariantSiteModel('pinv', Parameter('inv', prop_invariant))
    rates = site_model.rates()
    props = site_model.probabilities()
    assert rates.mul(props).sum() == torch.tensor(np.ones(1))
    assert torch.all(
        torch.cat((prop_invariant, torch.tensor([0.8]))).eq(props))
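The bookkeeping behind this assertion (see also test_invariant_mu in Example No. 21): a zero-rate category gets probability p_inv, and the variable category's rate is inflated by 1/(1 - p_inv), so the weighted mean rate stays at mu. A minimal sketch:

import torch

p_inv = torch.tensor([0.2])
mu = torch.tensor([2.0])
probs = torch.cat((p_inv, 1.0 - p_inv))  # [0.2, 0.8]
rates = torch.cat((torch.zeros_like(p_inv), mu / (1.0 - p_inv)))  # [0.0, 2.5]
assert torch.isclose(rates.mul(probs).sum(), mu)  # weighted mean rate == mu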
Example No. 9
def test_smoothed(rescale, expected):
    tree_model = TimeTreeModel.from_json(
        TimeTreeModel.json_factory(
            'tree',
            '(((A,B),C),D);',
            [3.0, 2.0, 4.0],
            dict(zip('ABCD', [0.0, 0.0, 0.0, 0.0])),
        ),
        {},
    )

    gmrf = GMRF(
        None,
        field=Parameter(None, torch.tensor([3.0, 10.0, 4.0]).log()),
        precision=Parameter(None, torch.tensor([0.1])),
        tree_model=tree_model,
        rescale=rescale,
    )
    assert torch.allclose(torch.tensor(expected), gmrf())
Example No. 10
def test_ctmc_scale_batch(tree_model_dict):
    tree_model_dict['internal_heights']['tensor'] = [
        [1.5, 4.0, 6.0, 16.0],
        [1.5, 4.0, 6.0, 16.0],
    ]
    tree_model = TimeTreeModel.from_json(tree_model_dict, {})
    ctmc_scale = CTMCScale(None,
                           Parameter(None, torch.tensor([[0.001], [0.001]])),
                           tree_model)
    assert torch.allclose(torch.full((2, 1), 4.475351922659342), ctmc_scale())
Example No. 11
def test_weibull_batch2():
    key = 'shape'
    sitemodel = WeibullSiteModel(
        'weibull', Parameter(key, torch.tensor([[1.0], [0.1], [1.0]])), 4)
    rates_expected = torch.tensor([
        [0.1457844, 0.5131316, 1.0708310, 2.2702530],
        [4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819],
        [0.1457844, 0.5131316, 1.0708310, 2.2702530],
    ])
    assert torch.allclose(sitemodel.rates(), rates_expected)
Example No. 12
def test_gmrf_time_aware(thetas, precision, weights, rescale):
    thetas = torch.tensor(thetas, requires_grad=True)

    precision = torch.tensor([precision])
    weights_tensor = torch.tensor(weights) if weights is not None else None
    gmrf = GMRF(
        None,
        field=Parameter(None, thetas),
        precision=Parameter(None, precision),
        tree_model=None,
        weights=weights_tensor,
        rescale=rescale,
    )
    lp1 = gmrf()

    dim = thetas.shape[0]
    if weights is not None:
        times = torch.tensor([0.0, 2.0, 6.0, 12.0, 20.0, 25.0])
        durations = times[..., 1:] - times[..., :-1]
        offdiag = -2.0 / (durations[..., :-1] + durations[..., 1:])
        if rescale:
            offdiag *= times[-1]  # rescale with root height
    else:
        offdiag = torch.full((dim - 1,), -1.0)

    Q = torch.zeros((dim, dim))
    Q[range(dim - 1), range(1, dim)] = offdiag
    Q[range(1, dim), range(dim - 1)] = offdiag

    Q[range(1, dim - 1), range(1, dim - 1)] = -(offdiag[..., :-1] + offdiag[..., 1:])
    Q[0, 0] = -Q[0, 1]
    Q[dim - 1, dim - 1] = -offdiag[-1]

    Q_scaled = Q * precision
    lp2 = (
        0.5 * (dim - 1) * precision.log()
        - 0.5 * torch.dot(thetas, Q_scaled @ thetas)
        - (dim - 1) / 2.0 * 1.8378770664093453  # 1.8378... = log(2 * pi)
    )

    assert lp1.item() == pytest.approx(lp2.item())
Example No. 13
def test_prior_mrbayes():
    dic = {}
    json_tree = UnRootedTreeModel.json_factory(
        'tree',
        '(6:0.02,((5:0.02,2:0.02):0.02,(4:0.02,3:0.02):0.02):0.02,1:0.02);',
        [0.0],
        {str(i): None for i in range(1, 7)},
        **{
            'keep_branch_lengths': True,
            'branch_lengths_id': 'bl',
        },
    )

    tree_model = UnRootedTreeModel.from_json(json_tree, dic)
    prior = CompoundGammaDirichletPrior(
        None,
        tree_model,
        Parameter(None, torch.tensor([1.0])),
        Parameter(None, torch.tensor([1.0])),
        Parameter(None, torch.tensor([1.0])),
        Parameter(None, torch.tensor([0.1])),
    )
    assert torch.allclose(prior(), torch.tensor([22.00240516662597]))
Example No. 14
def test_simple():
    normal = Distribution(
        None,
        torch.distributions.Normal,
        Parameter(None, torch.tensor([1.0])),
        OrderedDict(
            {
                'loc': Parameter(None, torch.tensor([0.0])),
                'scale': Parameter(None, torch.tensor([1.0])),
            }
        ),
    )

    exp = Distribution(
        None,
        torch.distributions.Exponential,
        Parameter(None, torch.tensor([1.0])),
        OrderedDict({'rate': Parameter(None, torch.tensor([1.0]))}),
    )
    joint = JointDistributionModel(None, [normal, exp])
    assert (-1.418939 - 1) == pytest.approx(joint().item())
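The expected value is the sum of the two component log-densities at x = 1: Normal(0, 1).log_prob(1) = -1.418939 and Exponential(1).log_prob(1) = -1. A minimal check in plain torch.distributions:

import torch

log_p = torch.distributions.Normal(0.0, 1.0).log_prob(
    torch.tensor(1.0)
) + torch.distributions.Exponential(1.0).log_prob(torch.tensor(1.0))
print(log_p.item())  # -2.418939... == -1.418939 - 1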
Example No. 15
def test_batch():
    normal = Distribution(
        None,
        torch.distributions.Normal,
        Parameter(None, torch.tensor([[1.0], [2.0]])),
        OrderedDict(
            {
                'loc': Parameter(None, torch.tensor([0.0])),
                'scale': Parameter(None, torch.tensor([1.0])),
            }
        ),
    )

    exp = Distribution(
        None,
        torch.distributions.Exponential,
        Parameter(None, torch.tensor([[1.0], [2.0]])),
        OrderedDict({'rate': Parameter(None, torch.tensor([1.0]))}),
    )
    joint = JointDistributionModel(None, [normal, exp])
    assert torch.allclose(torch.tensor([-1.418939 - 1, -2.918939 - 2]), joint())
Example No. 16
def test_gmrf_covariates_simple():
    gmrf = GMRF(
        None,
        Parameter(None, torch.tensor([1.0, 2.0, 3.0])),
        Parameter(None, torch.tensor([0.1])),
    )
    gmrf_covariate = GMRFCovariate(
        None,
        field=Parameter(None, torch.tensor([1.0, 2.0, 3.0])),
        precision=Parameter(None, torch.tensor([0.1])),
        covariates=Parameter(None, torch.arange(1.0, 7.0).view((3, 2))),
        beta=Parameter(None, torch.tensor([0.0, 0.0])),
    )
    assert gmrf() == gmrf_covariate()
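This equality suggests GMRFCovariate centers the field on the regression mean covariates @ beta (an assumption read off this test, not the implementation): with beta = 0 the mean vanishes and the density reduces to the plain GMRF. A minimal sketch of the residual computation:

import torch

field = torch.tensor([1.0, 2.0, 3.0])
covariates = torch.arange(1.0, 7.0).view((3, 2))
beta = torch.tensor([0.0, 0.0])
residual = field - covariates @ beta  # equals field when beta is zero
assert torch.equal(residual, field)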
Example No. 17
def test_parameterization(parameterization, param):
    loc = torch.tensor([1.0, 2.0])
    x = torch.tensor([1.0, 2.0])
    param_tensor = torch.tensor(param)
    kwargs = {parameterization: Parameter(None, param_tensor)}
    dist_model = MultivariateNormal(None, Parameter(None, x),
                                    Parameter(None, loc), **kwargs)

    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        loc, **{parameterization: param_tensor})
    assert dist.log_prob(x) == dist_model()

    x_tuple = (Parameter(None, x[:-1]), Parameter(None, x[-1:]))
    dist_model = MultivariateNormal(None, x_tuple, Parameter(None, loc),
                                    **kwargs)
    assert dist.log_prob(x) == dist_model()
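For reference, the parameterizations this test is presumably run with (the fixtures are not shown) map onto torch.distributions.MultivariateNormal's covariance_matrix, precision_matrix, and scale_tril keywords, which all specify the same distribution; a minimal sketch:

import torch

loc = torch.tensor([1.0, 2.0])
x = torch.tensor([1.0, 2.0])
cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
d_cov = torch.distributions.MultivariateNormal(loc, covariance_matrix=cov)
d_prec = torch.distributions.MultivariateNormal(loc, precision_matrix=cov.inverse())
d_tril = torch.distributions.MultivariateNormal(loc, scale_tril=torch.linalg.cholesky(cov))
assert torch.allclose(d_cov.log_prob(x), d_prec.log_prob(x))
assert torch.allclose(d_cov.log_prob(x), d_tril.log_prob(x))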
Example No. 18
def test_transformed_parameter():
    t = torch.tensor([1.0, 2.0])
    p1 = Parameter('param', t)
    transformed = torch.distributions.ExpTransform()
    p2 = TransformedParameter('transformed', p1, transformed)
    assert p2.need_update is False
    assert torch.all(p2.tensor.eq(t.exp()))
    assert p2.need_update is False

    p1.tensor = torch.tensor([1.0, 3.0])
    assert p2.need_update is True

    assert torch.all(p2.tensor.eq(p1.tensor.exp()))
    assert p2.need_update is False

    # jacobian: calling p2() returns the log |det Jacobian| of the transform;
    # for ExpTransform this equals the unconstrained value itself
    p1.tensor = t
    assert p2.need_update is True
    assert torch.all(p2().eq(p1.tensor))
    assert p2.need_update is False

    assert torch.all(p2.tensor.eq(p1.tensor.exp()))
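A minimal sketch of the need_update pattern this test exercises, using a hypothetical LazyExp class rather than torchtree's TransformedParameter: the cached tensor is computed eagerly at construction, invalidated when the source changes, and recomputed on the next access.

import torch


class LazyExp:
    # hypothetical stand-in for TransformedParameter with ExpTransform
    def __init__(self, source):
        self._source = source
        self._cache = source.exp()  # computed eagerly, so need_update starts False
        self.need_update = False

    def update(self, tensor):
        self._source = tensor
        self.need_update = True  # invalidate only; recompute lazily on access

    @property
    def tensor(self):
        if self.need_update:
            self._cache = self._source.exp()
            self.need_update = False
        return self._cache


lazy = LazyExp(torch.tensor([1.0, 2.0]))
assert lazy.need_update is False
lazy.update(torch.tensor([1.0, 3.0]))
assert lazy.need_update is True
assert torch.all(lazy.tensor.eq(torch.tensor([1.0, 3.0]).exp()))
assert lazy.need_update is False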
Example No. 19
def test_view_parameter_listener():
    t = torch.tensor([1.0, 2.0, 3.0, 4.0])
    p = Parameter('param', t)

    p1 = ViewParameter('param1', p, 3)

    class FakeListener(ParameterListener):
        def __init__(self):
            self.gotit = False

        def handle_parameter_changed(self, variable, index, event):
            self.gotit = True

    listener = FakeListener()
    p1.add_parameter_listener(listener)

    p.fire_parameter_changed()
    assert listener.gotit is True

    listener.gotit = False
    p1.fire_parameter_changed()
    assert listener.gotit is True
Example No. 20
def test_parameterization_rsample(parameterization, param):
    torch.manual_seed(0)
    loc = torch.tensor([1.0, 2.0])
    x = torch.tensor([1.0, 2.0])
    param_tensor = torch.tensor(param)
    kwargs = {parameterization: Parameter(None, param_tensor)}
    dist_model = MultivariateNormal(None, Parameter(None, x),
                                    Parameter(None, loc), **kwargs)
    dist_model.rsample()

    torch.manual_seed(0)
    dist = torch.distributions.multivariate_normal.MultivariateNormal(
        loc, **{parameterization: param_tensor})
    samples = dist.rsample()
    assert torch.all(samples == dist_model.x.tensor)

    torch.manual_seed(0)
    x_tuple = (Parameter(':-1', x[:-1]), Parameter('-1:', x[-1:]))
    dist_model = MultivariateNormal(None, x_tuple, Parameter(None, loc),
                                    **kwargs)
    dist_model.rsample()
    assert torch.all(samples == dist_model.x.tensor)
Example No. 21
def test_invariant_mu():
    prop_invariant = torch.tensor([0.2])
    site_model = InvariantSiteModel('pinv', Parameter('inv', prop_invariant),
                                    Parameter('mu', torch.tensor([2.0])))
    assert site_model.rates()[1] == 2 / 0.8
Example No. 22
def test_ctmc_scale(tree_model_dict):
    tree_model = TimeTreeModel.from_json(tree_model_dict, {})
    ctmc_scale = CTMCScale(None, Parameter(None, torch.tensor([0.001])),
                           tree_model)
    assert 4.475351922659342 == pytest.approx(ctmc_scale().item(), 0.00001)
Example No. 23
def test_view_parameter():
    t = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])
    p = Parameter('param', t)

    p1 = ViewParameter('param1', p, 3)
    assert torch.all(p1.tensor.eq(t[..., 3]))

    p1 = ViewParameter('param1', p, slice(3))
    assert torch.all(p1.tensor.eq(t[..., :3]))

    p1 = ViewParameter('param1', p, torch.tensor([0, 3]))
    assert torch.all(p1.tensor.eq(t[..., torch.tensor([0, 3])]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': 3,
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(t[..., 3]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': ':3',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(t[..., :3]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': '2:',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(t[..., 2:]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': '1:3',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(t[..., 1:3]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': '1:4:2',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(t[..., 1:4:2]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': '::-1',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(t[..., torch.LongTensor([3, 2, 1, 0])]))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': '2::-1',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(torch.tensor(t.numpy()[Ellipsis, 2::-1].copy())))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': ':0:-1',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(torch.tensor(t.numpy()[..., :0:-1].copy())))

    p1 = ViewParameter.from_json(
        {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': '2:0:-1',
        },
        {'param': p},
    )
    assert torch.all(p1.tensor.eq(torch.tensor(t.numpy()[..., 2:0:-1].copy())))
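The last four cases round-trip through NumPy because PyTorch tensors reject negative slice steps; a reversed view must be expressed with flip() or an explicit index instead. A minimal sketch:

import torch

t = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])
reversed_t = t.flip(-1)  # NumPy's t[..., ::-1], which torch slicing disallows
assert torch.all(reversed_t.eq(t[..., torch.tensor([3, 2, 1, 0])]))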
Example No. 24
def test_parameter_repr():
    p = Parameter('param', torch.tensor([1, 2]))
    assert eval(repr(p)) == p
Example No. 25
import pytest
import torch

from torchtree import Parameter
from torchtree.evolution.site_model import (
    ConstantSiteModel,
    InvariantSiteModel,
    WeibullSiteModel,
)


@pytest.mark.parametrize(
    "mu,expected",
    (
        (None, torch.tensor([1.0])),
        (Parameter('mu', torch.tensor([2.0])), torch.tensor([2.0])),
    ),
)
def test_constant(mu, expected):
    sitemodel = ConstantSiteModel('constant', mu)
    assert torch.all(sitemodel.rates() == expected)
    assert sitemodel.probabilities()[0] == 1.0


def test_weibull_batch():
    key = 'shape'
    sitemodel = WeibullSiteModel('weibull',
                                 Parameter(key, torch.tensor([[1.0], [0.1]])),
                                 4)
    rates_expected = torch.tensor([
        [0.1457844, 0.5131316, 1.0708310, 2.2702530],
        [4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819],
    ])
    assert torch.allclose(sitemodel.rates(), rates_expected)
Example No. 26
def constant_coalescent(args):
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0])
    )
    tree_model = TimeTreeModel("tree", tree, Taxa('taxa', taxa), ratios_root_height)
    tree_model._internal_heights.tensor = heights_from_branch_lengths(tree)
    pop_size = torch.tensor([4.0])

    print('JIT off')

    @benchmark
    def fn(tree_model, pop_size):
        return ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)

    @benchmark
    def fn_grad(tree_model, pop_size):
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    total_time, log_p = fn(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,off,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,off,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.debug:
        tree_model.heights_need_update = True
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        print('gradient ratios: ', ratios_root_height.grad)
        print('gradient pop size: ', pop_size.grad)
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()

    print('JIT on')
    log_prob_script = torch.jit.script(log_prob)

    @benchmark
    def fn_jit(tree_model, pop_size):
        return log_prob_script(tree_model.node_heights, pop_size)

    @benchmark
    def fn_grad_jit(tree_model, pop_size):
        log_p = log_prob_script(tree_model.node_heights, pop_size)
        log_p.backward()
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    ratios_root_height.requires_grad = False
    pop_size.requires_grad_(False)
    total_time, log_p = fn_jit(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad_jit(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,on,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,on,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.all:
        print('make sampling times unique and count them:')

        @benchmark
        def fn3(tree_model, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            return log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )

        @benchmark
        def fn3_grad(tree_model, ratios_root_height, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            log_p = log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )
            log_p.backward()
            ratios_root_height.tensor.grad.data.zero_()
            pop_size.grad.data.zero_()
            return log_p

        # x and counts are closed over by fn3/fn3_grad above, which only run
        # after these assignments
        x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
        counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1))))

        with torch.no_grad():
            total_time, log_p = fn3(args.replicates, tree_model, pop_size)
        print(f'  {args.replicates} evaluations: {total_time} ({log_p})')

        total_time, log_p = fn3_grad(
            args.replicates, tree_model, ratios_root_height, pop_size
        )
        print(f'  {args.replicates} gradient evaluations: {total_time}')
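The @benchmark decorator is defined elsewhere in the benchmark script; judging from the call sites, it takes the replicate count as a leading extra argument and returns (total_time, last_result). A minimal compatible sketch (hypothetical, not the script's actual code):

import time
from functools import wraps


def benchmark(fn):
    @wraps(fn)
    def wrapper(replicates, *args, **kwargs):
        start = time.time()
        for _ in range(replicates):
            result = fn(*args, **kwargs)
        return time.time() - start, result

    return wrapper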
Example No. 27
def ratio_transform(args):
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        return tree_model.transform(
            ratios_root_height,
        )

    @benchmark
    def fn_grad(ratios_root_height):
        heights = tree_model.transform(
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    total_time, heights = fn(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    grad_total_time, heights = fn_grad(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    print('  JIT off')

    @benchmark
    def fn2(ratios_root_height):
        return transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad(ratios_root_height):
        heights = transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('  JIT on')
    transform_script = torch.jit.script(transform)

    @benchmark
    def fn2_jit(ratios_root_height):
        return transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad_jit(ratios_root_height):
        heights = transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT off')

    @benchmark
    def fn3(ratios_root_height):
        return transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad(ratios_root_height):
        heights = transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)

    @benchmark
    def fn3_jit(ratios_root_height):
        return transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad_jit(ratios_root_height):
        heights = transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')
Example No. 28
def ratio_transform_jacobian(args):
    tree = read_tree(args.tree, True, True)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    taxa_count = len(taxa)
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 1) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        return tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )

    @benchmark
    def fn_grad(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )
        log_det_jac.backward()
        ratios_root_height.grad.data.zero_()
        return log_det_jac

    print('  JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f'  {args.replicates} evaluations: {total_time} ({log_det_jac})')

    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f'  {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)
Example No. 29
def test_view_parameter_repr():
    p = Parameter('param', torch.tensor([1, 2]))
    p2 = ViewParameter('p2', p, 1)
    assert eval(repr(p2)) == p2
Example No. 30
def test_invariant_batch():
    prop_invariant = torch.tensor([[0.2], [0.3]])
    site_model = InvariantSiteModel('pinv', Parameter('inv', prop_invariant))
    rates = site_model.rates()
    props = site_model.probabilities()
    assert torch.all(rates.mul(props).sum(-1) == torch.ones(2))