def test_skygrid_heterochronous_2_trees():
    """Evaluate the piecewise-constant grid coalescent on a batch of two trees.

    The two rows of node heights share the same sampling times and differ
    only in their last height value (16.0 versus 26.0).
    """
    tip_dates = [0.0, 1.0, 2.0, 3.0, 12.0]
    log_thetas = torch.tensor(np.array([1.0, 3.0, 6.0, 8.0, 9.0]))
    thetas = Parameter(None, log_thetas.exp())
    # Five evenly spaced points on [0, 10]; the leading 0 is dropped.
    grid = Parameter(None, torch.linspace(0, 10.0, steps=5)[1:])

    node_heights = torch.tensor(
        [
            [tip_dates + [1.5, 4.0, 6.0, 16.0]],
            [tip_dates + [1.5, 4.0, 6.0, 26.0]],
        ],
        dtype=thetas.dtype,
    )
    tree_model = FakeTreeModel(Parameter(None, node_heights))
    skygrid = PiecewiseConstantCoalescentGridModel(None, thetas, grid, tree_model)

    expected = torch.tensor(
        [[-19.594893640219844], [-19.596127738260712]], dtype=thetas.dtype
    )
    assert torch.allclose(expected, torch.squeeze(skygrid(), -1))
def test_gmrf2(thetas, precision):
    """Compare ``GMRF`` value and gradient with an explicit matrix formula.

    The reference log-density is built from a tridiagonal matrix ``Q`` scaled
    by the precision. Both the value and the gradient with respect to the
    field must agree with what ``GMRF`` produces.

    ``thetas`` and ``precision`` are plain Python values supplied by the
    caller (presumably a pytest parametrization that is not visible here).
    """
    log_two_pi = 1.8378770664093453  # log(2 * pi)

    field = torch.tensor(thetas, requires_grad=True)
    # Detached tensor for the reference computation so that its gradient is
    # accumulated separately from the one produced through GMRF.
    field_ref = field.detach()
    tau = torch.tensor([precision])

    model = GMRF(
        None,
        field=Parameter(None, field),
        precision=Parameter(None, tau),
        tree_model=None,
    )
    lp_model = model()
    lp_model.backward()

    field_ref.requires_grad = True
    dim = field.shape[0]
    upper = range(dim - 1)
    lower = range(1, dim)

    # Tridiagonal matrix: -1 off the diagonal, 2 on it, 1 at both corners.
    Q = torch.zeros((dim, dim))
    Q[upper, lower] = -1
    Q[lower, upper] = -1
    Q.fill_diagonal_(2)
    Q[0, 0] = 1
    Q[dim - 1, dim - 1] = 1
    Q_scaled = Q * tau

    log_det_term = 0.5 * (dim - 1) * tau.log()
    quadratic_term = 0.5 * torch.dot(field_ref, Q_scaled @ field_ref)
    constant_term = (dim - 1) / 2.0 * log_two_pi
    lp_reference = log_det_term - quadratic_term - constant_term
    lp_reference.backward()

    assert lp_model.item() == pytest.approx(lp_reference.item())
    assert torch.allclose(field.grad, field_ref.grad)
def test_gmrfw2():
    """Check the ``GMRF`` value for a fixed three-element log-scale field."""
    log_field = torch.tensor([3.0, 10.0, 4.0]).log()
    tau = torch.tensor([0.1])
    model = GMRF(
        None,
        field=Parameter(None, log_field),
        precision=Parameter(None, tau),
        tree_model=None,
    )
    observed = model().item()
    assert -4.254919053937792 == pytest.approx(observed, rel=0.000001)
def test_gmrf(thetas, precision, expected):
    """Check that ``GMRF`` evaluates to ``expected`` for the given inputs.

    The arguments are plain Python values supplied by the caller (presumably
    a pytest parametrization that is not visible here).
    """
    field_param = Parameter(None, torch.tensor(thetas))
    precision_param = Parameter(None, torch.tensor([precision]))
    model = GMRF(
        None,
        field=field_param,
        precision=precision_param,
        tree_model=None,
    )
    observed = model().item()
    assert expected == pytest.approx(observed, rel=0.000001)
def test_weibull_mu():
    """Check Weibull site-model rates when a ``mu`` of 2 is supplied.

    The expected values are the reference rates multiplied by two.
    """
    shape = Parameter('shape', torch.tensor([1.0]))
    mu = Parameter('mu', torch.tensor([2.0]))
    sitemodel = WeibullSiteModel('weibull', shape, 4, mu=mu)

    reference_rates = torch.tensor([0.1457844, 0.5131316, 1.0708310, 2.2702530])
    rates_expected = reference_rates * 2
    np.testing.assert_allclose(sitemodel.rates(), rates_expected, rtol=1e-06)
def test_gmrf_batch():
    """Evaluate ``GMRF`` on two fields at once, each with its own precision."""
    fields = torch.tensor(
        [[2.0, 30.0, 4.0, 15.0, 6.0], [1.0, 3.0, 6.0, 8.0, 9.0]]
    )
    precisions = torch.tensor([[2.0], [0.1]])
    model = GMRF(
        None,
        field=Parameter(None, fields),
        precision=Parameter(None, precisions),
        tree_model=None,
    )
    expected = torch.tensor([[-1664.2894596388803], [-9.180924185988092]])
    assert torch.allclose(model(), expected)
def test_weibull_invariant0():
    """Check Weibull rates when the invariant parameter is zero.

    Also verifies that the probability-weighted sum of the rates is 1.
    """
    shape = Parameter('shape', torch.tensor([1.0]))
    invariant = Parameter('inv', torch.tensor([0.0]))
    sitemodel = WeibullSiteModel('weibull', shape, 4, invariant)

    rates_expected = (0.1457844, 0.5131316, 1.0708310, 2.2702530)
    np.testing.assert_allclose(sitemodel.rates(), rates_expected, rtol=1e-06)

    weighted = sitemodel.rates() * sitemodel.probabilities()
    assert torch.sum(weighted).item() == pytest.approx(1.0, rel=1.0e-6)
def test_invariant():
    """Check rates and category probabilities of the invariant site model.

    The probability-weighted sum of rates must equal one, and the
    probabilities must be the invariant proportion followed by 0.8.
    """
    prop_invariant = torch.tensor([0.2])
    site_model = InvariantSiteModel('pinv', Parameter('inv', prop_invariant))
    rates = site_model.rates()
    props = site_model.probabilities()

    weighted_sum = rates.mul(props).sum()
    assert weighted_sum == torch.tensor(np.ones(1))

    expected_props = torch.cat((prop_invariant, torch.tensor([0.8])))
    assert torch.all(expected_props.eq(props))
def test_smoothed(rescale, expected):
    """Evaluate ``GMRF`` attached to a four-taxon time tree.

    ``rescale`` is forwarded to ``GMRF`` and ``expected`` is the value the
    model must return (both presumably come from a pytest parametrization
    that is not visible here).
    """
    # All four tips are given the same value, 0.0.
    tip_dates = dict.fromkeys('ABCD', 0.0)
    tree_json = TimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        [3.0, 2.0, 4.0],
        tip_dates,
    )
    tree_model = TimeTreeModel.from_json(tree_json, {})

    log_field = torch.tensor([3.0, 10.0, 4.0]).log()
    model = GMRF(
        None,
        field=Parameter(None, log_field),
        precision=Parameter(None, torch.tensor([0.1])),
        tree_model=tree_model,
        rescale=rescale,
    )
    assert torch.allclose(torch.tensor(expected), model())
def test_ctmc_scale_batch(tree_model_dict):
    """Evaluate ``CTMCScale`` on a batch of two identical height vectors.

    ``tree_model_dict`` is modified in place: its internal-heights tensor
    entry is replaced by a two-row batch.
    """
    heights = [1.5, 4.0, 6.0, 16.0]
    # Two independent row lists holding the same values.
    tree_model_dict['internal_heights']['tensor'] = [list(heights) for _ in range(2)]
    tree_model = TimeTreeModel.from_json(tree_model_dict, {})

    rates = Parameter(None, torch.tensor([[0.001], [0.001]]))
    ctmc_scale = CTMCScale(None, rates, tree_model)

    expected = torch.full((2, 1), 4.475351922659342)
    assert torch.allclose(expected, ctmc_scale())
def test_weibull_batch2():
    """Check Weibull rates for a batch of three shape values (1.0, 0.1, 1.0)."""
    shapes = torch.tensor([[1.0], [0.1], [1.0]])
    sitemodel = WeibullSiteModel('weibull', Parameter('shape', shapes), 4)

    rates_shape_one = [0.1457844, 0.5131316, 1.0708310, 2.2702530]
    rates_shape_tenth = [4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819]
    rates_expected = torch.tensor(
        [rates_shape_one, rates_shape_tenth, rates_shape_one]
    )
    assert torch.allclose(sitemodel.rates(), rates_expected)
def test_gmrf_time_aware(thetas, precision, weights, rescale):
    """Compare ``GMRF`` with optional weights against an explicit formula.

    When ``weights`` is given, the reference off-diagonal entries are derived
    from a fixed vector of times (and multiplied by the last time when
    ``rescale`` is truthy); otherwise every off-diagonal entry is -1.

    The arguments are plain Python values supplied by the caller (presumably
    a pytest parametrization that is not visible here).
    """
    log_two_pi = 1.8378770664093453  # log(2 * pi)

    field = torch.tensor(thetas, requires_grad=True)
    tau = torch.tensor([precision])
    if weights is not None:
        weights_tensor = torch.tensor(weights)
    else:
        weights_tensor = None

    model = GMRF(
        None,
        field=Parameter(None, field),
        precision=Parameter(None, tau),
        tree_model=None,
        weights=weights_tensor,
        rescale=rescale,
    )
    lp_model = model()

    dim = field.shape[0]
    if weights is not None:
        times = torch.tensor([0.0, 2.0, 6.0, 12.0, 20.0, 25.0])
        durations = times[..., 1:] - times[..., :-1]
        offdiag = -2.0 / (durations[..., :-1] + durations[..., 1:])
        if rescale:
            offdiag *= times[-1]  # rescale with root height
    else:
        offdiag = torch.full((dim - 1,), -1.0)

    upper = range(dim - 1)
    lower = range(1, dim)
    interior = range(1, dim - 1)

    Q = torch.zeros((dim, dim))
    Q[upper, lower] = offdiag
    Q[lower, upper] = offdiag
    # Each interior diagonal entry cancels its two neighbouring off-diagonals.
    Q[interior, interior] = -(offdiag[..., :-1] + offdiag[..., 1:])
    Q[0, 0] = -Q[0, 1]
    Q[dim - 1, dim - 1] = -offdiag[-1]
    Q_scaled = Q * tau

    log_det_term = 0.5 * (dim - 1) * tau.log()
    quadratic_term = 0.5 * torch.dot(field, Q_scaled @ field)
    constant_term = (dim - 1) / 2.0 * log_two_pi
    lp_reference = log_det_term - quadratic_term - constant_term

    assert lp_model.item() == pytest.approx(lp_reference.item())
def test_prior_mrbayes():
    """Evaluate ``CompoundGammaDirichletPrior`` on a six-taxon unrooted tree.

    Every branch length in the Newick string is 0.02 and the branch lengths
    are kept when the tree model is built.
    """
    newick = '(6:0.02,((5:0.02,2:0.02):0.02,(4:0.02,3:0.02):0.02):0.02,1:0.02);'
    # Taxa are named '1' .. '6', each mapped to None.
    taxa = dict.fromkeys(map(str, range(1, 7)))
    json_tree = UnRootedTreeModel.json_factory(
        'tree',
        newick,
        [0.0],
        taxa,
        keep_branch_lengths=True,
        branch_lengths_id='bl',
    )
    registry = {}
    tree_model = UnRootedTreeModel.from_json(json_tree, registry)

    prior = CompoundGammaDirichletPrior(
        None,
        tree_model,
        Parameter(None, torch.tensor([1.0])),
        Parameter(None, torch.tensor([1.0])),
        Parameter(None, torch.tensor([1.0])),
        Parameter(None, torch.tensor([0.1])),
    )
    expected = torch.tensor([22.00240516662597])
    assert torch.allclose(prior(), expected)
def test_simple():
    """Joint log-density of a Normal(0, 1) and an Exponential(1), both at 1.0."""
    normal_args = OrderedDict(
        [
            ('loc', Parameter(None, torch.tensor([0.0]))),
            ('scale', Parameter(None, torch.tensor([1.0]))),
        ]
    )
    normal = Distribution(
        None,
        torch.distributions.Normal,
        Parameter(None, torch.tensor([1.0])),
        normal_args,
    )

    exp_args = OrderedDict([('rate', Parameter(None, torch.tensor([1.0])))])
    exp = Distribution(
        None,
        torch.distributions.Exponential,
        Parameter(None, torch.tensor([1.0])),
        exp_args,
    )

    joint = JointDistributionModel(None, [normal, exp])
    expected = -1.418939 - 1
    assert expected == pytest.approx(joint().item())
def test_batch():
    """Joint log-density of a Normal and an Exponential over a batch of two."""
    normal_args = OrderedDict(
        [
            ('loc', Parameter(None, torch.tensor([0.0]))),
            ('scale', Parameter(None, torch.tensor([1.0]))),
        ]
    )
    normal = Distribution(
        None,
        torch.distributions.Normal,
        Parameter(None, torch.tensor([[1.0], [2.0]])),
        normal_args,
    )

    exp_args = OrderedDict([('rate', Parameter(None, torch.tensor([1.0])))])
    exp = Distribution(
        None,
        torch.distributions.Exponential,
        Parameter(None, torch.tensor([[1.0], [2.0]])),
        exp_args,
    )

    joint = JointDistributionModel(None, [normal, exp])
    expected = torch.tensor([-1.418939 - 1, -2.918939 - 2])
    assert torch.allclose(expected, joint())
def test_gmrf_covariates_simple():
    """With all-zero ``beta``, ``GMRFCovariate`` must equal plain ``GMRF``."""
    plain = GMRF(
        None,
        Parameter(None, torch.tensor([1.0, 2.0, 3.0])),
        Parameter(None, torch.tensor([0.1])),
    )

    # The values 1..6 arranged as a 3 x 2 covariate matrix.
    covariates = torch.arange(1.0, 7.0).view((3, 2))
    with_covariates = GMRFCovariate(
        None,
        field=Parameter(None, torch.tensor([1.0, 2.0, 3.0])),
        precision=Parameter(None, torch.tensor([0.1])),
        covariates=Parameter(None, covariates),
        beta=Parameter(None, torch.tensor([0.0, 0.0])),
    )
    assert plain() == with_covariates()
def test_parameterization(parameterization, param):
    """Compare the ``MultivariateNormal`` model with torch's distribution.

    ``parameterization`` is the keyword name under which ``param`` is passed
    to both the model and ``torch.distributions``. The comparison is made
    once with ``x`` as a single parameter and once with ``x`` split into a
    tuple of two parameters.
    """
    loc = torch.tensor([1.0, 2.0])
    x = torch.tensor([1.0, 2.0])
    param_tensor = torch.tensor(param)
    model_kwargs = {parameterization: Parameter(None, param_tensor)}

    model = MultivariateNormal(
        None, Parameter(None, x), Parameter(None, loc), **model_kwargs
    )
    reference = torch.distributions.multivariate_normal.MultivariateNormal(
        loc, **{parameterization: param_tensor}
    )
    assert reference.log_prob(x) == model()

    # Same check with x provided as two separate pieces.
    x_pieces = (Parameter(None, x[:-1]), Parameter(None, x[-1:]))
    model = MultivariateNormal(None, x_pieces, Parameter(None, loc), **model_kwargs)
    assert reference.log_prob(x) == model()
def test_transformed_parameter():
    """Track ``need_update`` and values of an exp-transformed parameter.

    The statement order matters: each assertion on ``need_update`` depends
    on whether the source tensor was reassigned or the transformed value was
    read immediately before it.
    """
    base = torch.tensor([1.0, 2.0])
    source = Parameter('param', base)
    exp_transform = torch.distributions.ExpTransform()
    derived = TransformedParameter('transformed', source, exp_transform)

    assert derived.need_update is False
    assert torch.all(derived.tensor.eq(base.exp()))
    assert derived.need_update is False

    # Reassigning the source tensor must flag the derived parameter.
    source.tensor = torch.tensor([1.0, 3.0])
    assert derived.need_update is True
    assert torch.all(derived.tensor.eq(source.tensor.exp()))
    assert derived.need_update is False

    # jacobian
    source.tensor = base
    assert derived.need_update is True
    assert torch.all(derived().eq(source.tensor))
    assert derived.need_update is False
    assert torch.all(derived.tensor.eq(source.tensor.exp()))
def test_view_parameter_listener():
    """A listener on a ``ViewParameter`` is notified by both parameters.

    The notification must arrive when the underlying parameter fires a
    change as well as when the view itself does.
    """

    class RecordingListener(ParameterListener):
        """Listener that only remembers whether it has been called."""

        def __init__(self):
            self.notified = False

        def handle_parameter_changed(self, variable, index, event):
            self.notified = True

    values = torch.tensor([1.0, 2.0, 3.0, 4.0])
    parent = Parameter('param', values)
    view = ViewParameter('param1', parent, 3)

    listener = RecordingListener()
    view.add_parameter_listener(listener)

    parent.fire_parameter_changed()
    assert listener.notified is True

    listener.notified = False
    view.fire_parameter_changed()
    assert listener.notified is True
def test_parameterization_rsample(parameterization, param):
    """``rsample`` of the model must reproduce torch's draw for equal seeds.

    The random seed is reset to 0 before each of the three draws, so the
    order of the statements below is significant.
    """
    torch.manual_seed(0)
    loc = torch.tensor([1.0, 2.0])
    x = torch.tensor([1.0, 2.0])
    param_tensor = torch.tensor(param)
    model_kwargs = {parameterization: Parameter(None, param_tensor)}
    model = MultivariateNormal(
        None, Parameter(None, x), Parameter(None, loc), **model_kwargs
    )
    model.rsample()

    torch.manual_seed(0)
    reference = torch.distributions.multivariate_normal.MultivariateNormal(
        loc, **{parameterization: param_tensor}
    )
    samples = reference.rsample()
    assert torch.all(samples == model.x.tensor)

    # Same check with x provided as two separate pieces.
    torch.manual_seed(0)
    x_pieces = (Parameter(':-1', x[:-1]), Parameter('-1:', x[-1:]))
    model = MultivariateNormal(None, x_pieces, Parameter(None, loc), **model_kwargs)
    model.rsample()
    assert torch.all(samples == model.x.tensor)
def test_invariant_mu():
    """With ``mu`` = 2 and invariant proportion 0.2, the second rate is 2 / 0.8."""
    prop_invariant = torch.tensor([0.2])
    invariant = Parameter('inv', prop_invariant)
    mu = Parameter('mu', torch.tensor([2.0]))
    site_model = InvariantSiteModel('pinv', invariant, mu)

    variable_rate = site_model.rates()[1]
    assert variable_rate == 2 / 0.8
def test_ctmc_scale(tree_model_dict):
    """Evaluate ``CTMCScale`` for a single rate of 0.001 on the fixture tree."""
    tree_model = TimeTreeModel.from_json(tree_model_dict, {})
    rate = Parameter(None, torch.tensor([0.001]))
    ctmc_scale = CTMCScale(None, rate, tree_model)

    observed = ctmc_scale().item()
    assert 4.475351922659342 == pytest.approx(observed, rel=0.00001)
def test_view_parameter():
    """Check ``ViewParameter`` indexing, built directly and through JSON.

    Every case pairs an ``indices`` specification with the tensor the view
    is expected to expose. Views are taken along the last dimension of a
    2 x 4 tensor.
    """
    t = torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])
    p = Parameter('param', t)

    # Views created through the constructor.
    direct_cases = [
        (3, t[..., 3]),
        (slice(3), t[..., :3]),
        (torch.tensor([0, 3]), t[..., torch.tensor([0, 3])]),
    ]
    for indices, expected in direct_cases:
        view = ViewParameter('param1', p, indices)
        assert torch.all(view.tensor.eq(expected))

    # Views created from a JSON description. Negative-step expectations are
    # computed with an index tensor or through NumPy followed by a copy.
    json_cases = [
        (3, t[..., 3]),
        (':3', t[..., :3]),
        ('2:', t[..., 2:]),
        ('1:3', t[..., 1:3]),
        ('1:4:2', t[..., 1:4:2]),
        ('::-1', t[..., torch.LongTensor([3, 2, 1, 0])]),
        ('2::-1', torch.tensor(t.numpy()[Ellipsis, 2::-1].copy())),
        (':0:-1', torch.tensor(t.numpy()[..., :0:-1].copy())),
        ('2:0:-1', torch.tensor(t.numpy()[..., 2:0:-1].copy())),
    ]
    for indices, expected in json_cases:
        description = {
            'id': 'a',
            'type': 'torchtree.ViewParameter',
            'parameter': 'param',
            'indices': indices,
        }
        view = ViewParameter.from_json(description, {'param': p})
        assert torch.all(view.tensor.eq(expected))
def test_parameter_repr():
    """``repr`` of a ``Parameter`` must evaluate back to an equal object."""
    original = Parameter('param', torch.tensor([1, 2]))
    # eval is applied only to the repr generated on the line above.
    rebuilt = eval(repr(original))
    assert rebuilt == original
import pytest import torch from torchtree import Parameter from torchtree.evolution.site_model import ( ConstantSiteModel, InvariantSiteModel, WeibullSiteModel, ) @pytest.mark.parametrize( "mu,expected", ( (None, torch.tensor([1.0])), (Parameter('mu', torch.tensor([2.0])), torch.tensor([2.0])), ), ) def test_constant(mu, expected): sitemodel = ConstantSiteModel('constant', mu) assert torch.all(sitemodel.rates() == expected) assert sitemodel.probabilities()[0] == 1.0 def test_weibull_batch(): key = 'shape' sitemodel = WeibullSiteModel('weibull', Parameter(key, torch.tensor([[1.0], [0.1]])), 4) rates_expected = torch.tensor([ [0.1457844, 0.5131316, 1.0708310, 2.2702530],
def constant_coalescent(args):
    """Benchmark constant-population coalescent evaluations and gradients.

    Runs the timings first through ``ConstantCoalescent`` ("JIT off"), then
    through a ``torch.jit.script`` version of ``log_prob`` ("JIT on") and,
    when ``args.all`` is set, through ``log_prob_squashed`` on unique
    sampling times.

    ``args`` is read for: ``tree``, ``replicates``, ``output`` (written to
    when truthy), ``debug`` and ``all``.
    """
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = [Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()]

    # Placeholder values; the heights are replaced right after the model is
    # built with the ones derived from the tree's branch lengths.
    initial_values = torch.tensor([0.5] * (taxa_count - 2) + [20.0])
    ratios_root_height = Parameter("internal_heights", initial_values)
    tree_model = TimeTreeModel("tree", tree, Taxa('taxa', taxa), ratios_root_height)
    tree_model._internal_heights.tensor = heights_from_branch_lengths(tree)

    pop_size = torch.tensor([4.0])

    print('JIT off')

    @benchmark
    def fn(tree_model, pop_size):
        return ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)

    @benchmark
    def fn_grad(tree_model, pop_size):
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        # Reset accumulated gradients so replicates do not add up.
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    total_time, log_p = fn(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,off,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,off,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.debug:
        # One extra evaluation whose gradients are printed, then cleared.
        tree_model.heights_need_update = True
        debug_log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        debug_log_p.backward()
        print('gradient ratios: ', ratios_root_height.grad)
        print('gradient pop size: ', pop_size.grad)
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()

    print('JIT on')
    log_prob_script = torch.jit.script(log_prob)

    @benchmark
    def fn_jit(tree_model, pop_size):
        return log_prob_script(tree_model.node_heights, pop_size)

    @benchmark
    def fn_grad_jit(tree_model, pop_size):
        log_p = log_prob_script(tree_model.node_heights, pop_size)
        log_p.backward()
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    ratios_root_height.requires_grad = False
    pop_size.requires_grad_(False)
    total_time, log_p = fn_jit(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad_jit(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,on,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,on,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.all:
        print('make sampling times unique and count them:')

        # unique_times and lineage_counts are bound below, before the first
        # call to either closure.
        @benchmark
        def fn3(tree_model, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (unique_times, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            return log_prob_squashed(
                pop_size, node_heights, lineage_counts, tree_model.taxa_count
            )

        @benchmark
        def fn3_grad(tree_model, ratios_root_height, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (unique_times, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            log_p = log_prob_squashed(
                pop_size, node_heights, lineage_counts, tree_model.taxa_count
            )
            log_p.backward()
            ratios_root_height.tensor.grad.data.zero_()
            pop_size.grad.data.zero_()
            return log_p

        unique_times, tip_counts = torch.unique(
            tree_model.sampling_times, return_counts=True
        )
        # Append one -1 entry per internal node after the per-time tip counts.
        lineage_counts = torch.cat(
            (tip_counts, torch.tensor([-1] * (taxa_count - 1)))
        )

        with torch.no_grad():
            total_time, log_p = fn3(args.replicates, tree_model, pop_size)
        print(f' {args.replicates} evaluations: {total_time} ({log_p})')

        total_time, log_p = fn3_grad(
            args.replicates, tree_model, ratios_root_height, pop_size
        )
        print(f' {args.replicates} gradient evaluations: {total_time}')
def ratio_transform(args):
    """Benchmark the ratio/root-height to node-height transform.

    Five variants are timed, each for plain evaluation and for evaluation
    followed by a backward pass: the tree model's own transform, the
    ``transform`` function with and without ``torch.jit.script``, and the
    ``transform2`` function with and without ``torch.jit.script``.

    ``args`` is read for: ``tree``, ``replicates`` and ``output`` (written
    to when truthy).
    """
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = [Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()]

    # Placeholder values; replaced below by the inverse transform of the
    # heights derived from the tree's branch lengths.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    def _time_pair(eval_fn, grad_fn):
        """Time ``eval_fn`` with gradients disabled, then ``grad_fn`` with them on."""
        ratios_root_height.requires_grad = False
        total_time, _ = eval_fn(replicates, ratios_root_height.tensor)
        print(f' {replicates} evaluations: {total_time}')

        ratios_root_height.requires_grad = True
        total_time, _ = grad_fn(replicates, ratios_root_height.tensor)
        print(f' {replicates} gradient evaluations: {total_time}')

    # Inside the timed closures below, the argument named ratios_root_height
    # is the tensor passed by the caller, not the Parameter defined above.
    @benchmark
    def fn(ratios_root_height):
        return tree_model.transform(
            ratios_root_height,
        )

    @benchmark
    def fn_grad(ratios_root_height):
        heights = tree_model.transform(
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    total_time, _ = fn(replicates, ratios_root_height.tensor)
    print(f' {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    grad_total_time, _ = fn_grad(replicates, ratios_root_height.tensor)
    print(f' {replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    print(' JIT off')

    @benchmark
    def fn2(ratios_root_height):
        return transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad(ratios_root_height):
        heights = transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    _time_pair(fn2, fn2_grad)

    print(' JIT on')
    transform_script = torch.jit.script(transform)

    @benchmark
    def fn2_jit(ratios_root_height):
        return transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad_jit(ratios_root_height):
        heights = transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    _time_pair(fn2_jit, fn2_grad_jit)

    print('ratio_transform v2 JIT off')

    @benchmark
    def fn3(ratios_root_height):
        return transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad(ratios_root_height):
        heights = transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    _time_pair(fn3, fn3_grad)

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)

    @benchmark
    def fn3_jit(ratios_root_height):
        return transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad_jit(ratios_root_height):
        heights = transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    _time_pair(fn3_jit, fn3_grad_jit)
def ratio_transform_jacobian(args):
    """Benchmark the log-determinant of the ratio transform's Jacobian.

    Times plain evaluation and evaluation followed by a backward pass of
    ``tree_model.transform.log_abs_det_jacobian``.

    ``args`` is read for: ``tree``, ``replicates``, ``output`` (written to
    when truthy) and ``debug`` (prints the gradient of one extra
    evaluation).
    """
    tree = read_tree(args.tree, True, True)
    taxa = [Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()]
    taxa_count = len(taxa)

    # Placeholder of taxa_count - 2 ratios plus one root height, the same
    # layout ratio_transform uses for this model class. The previous
    # taxa_count - 1 ratios made the placeholder one element too long. The
    # values are replaced just below by the inverse transform of the heights
    # derived from the tree's branch lengths.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    # Inside the timed closures, the argument named ratios_root_height is the
    # tensor passed by the caller, not the Parameter defined above.
    @benchmark
    def fn(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        return tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )

    @benchmark
    def fn_grad(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )
        log_det_jac.backward()
        # Reset the accumulated gradient so replicates do not add up.
        ratios_root_height.grad.data.zero_()
        return log_det_jac

    print(' JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f' {args.replicates} evaluations: {total_time} ({log_det_jac})')

    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f' {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)
def test_view_parameter_repr():
    """``repr`` of a ``ViewParameter`` must evaluate back to an equal object."""
    p = Parameter('param', torch.tensor([1, 2]))
    view = ViewParameter('p2', p, 1)
    # eval is applied only to the repr generated here. The local name p is
    # kept because it is in scope while the repr text is evaluated.
    rebuilt = eval(repr(view))
    assert rebuilt == view
def test_invariant_batch():
    """For a batch of two invariant proportions, weighted rates sum to one."""
    prop_invariant = torch.tensor([[0.2], [0.3]])
    site_model = InvariantSiteModel('pinv', Parameter('inv', prop_invariant))
    rates = site_model.rates()
    props = site_model.probabilities()

    # Sum over the last dimension, giving one value per batch entry.
    weighted_sums = rates.mul(props).sum(-1)
    assert torch.all(weighted_sums == torch.ones(2))