def test_skygrid_heterochronous_2_trees():
    """Evaluate the piecewise-constant grid coalescent on a batch of two trees.

    The two rows of heights share the same sampling times and differ only in
    their final entry (16.0 vs. 26.0).
    """
    tip_times = [0.0, 1.0, 2.0, 3.0, 12.0]
    log_thetas = torch.tensor(np.array([1.0, 3.0, 6.0, 8.0, 9.0]))
    thetas = Parameter(None, log_thetas.exp())
    # Five evenly spaced points on [0, 10] with the leading 0 dropped.
    grid = Parameter(None, torch.linspace(0, 10.0, steps=5)[1:])

    heights = torch.tensor(
        [
            [tip_times + [1.5, 4.0, 6.0, 16.0]],
            [tip_times + [1.5, 4.0, 6.0, 26.0]],
        ],
        dtype=thetas.dtype,
    )
    tree_model = FakeTreeModel(Parameter(None, heights))
    coalescent = PiecewiseConstantCoalescentGridModel(None, thetas, grid, tree_model)

    expected = torch.tensor(
        [[-19.594893640219844], [-19.596127738260712]], dtype=thetas.dtype
    )
    observed = torch.squeeze(coalescent(), -1)
    assert torch.allclose(expected, observed)
def test_gmrf2(thetas, precision):
    """Check the GMRF value and its gradient against a dense-matrix formula.

    The reference value is built from an explicit tridiagonal matrix scaled by
    the precision and evaluated on a detached view of ``thetas``.
    """
    thetas = torch.tensor(thetas, requires_grad=True)
    thetas_ref = thetas.detach()
    precision = torch.tensor([precision])

    gmrf = GMRF(
        None,
        field=Parameter(None, thetas),
        precision=Parameter(None, precision),
        tree_model=None,
    )
    model_lp = gmrf()
    model_lp.backward()

    # Gradients of the reference expression are tracked on the detached tensor.
    thetas_ref.requires_grad = True
    dim = thetas.shape[0]
    lower, upper = range(dim - 1), range(1, dim)

    # Tridiagonal matrix: -1 off the diagonal, 2 on it, 1 at both corners.
    Q = torch.zeros((dim, dim))
    Q[lower, upper] = -1
    Q[upper, lower] = -1
    Q.fill_diagonal_(2)
    Q[0, 0] = 1
    Q[dim - 1, dim - 1] = 1
    Q_scaled = Q * precision

    log_precision_term = 0.5 * (dim - 1) * precision.log()
    quadratic_term = 0.5 * torch.dot(thetas_ref, Q_scaled @ thetas_ref)
    # 1.8378770664093453 is log(2 * pi).
    constant_term = (dim - 1) / 2.0 * 1.8378770664093453
    reference_lp = log_precision_term - quadratic_term - constant_term
    reference_lp.backward()

    assert model_lp.item() == pytest.approx(reference_lp.item())
    assert torch.allclose(thetas.grad, thetas_ref.grad)
def test_gmrf(thetas, precision, expected):
    """Compare the value returned by GMRF with ``expected`` (relative tol. 1e-6)."""
    field = Parameter(None, torch.tensor(thetas))
    tau = Parameter(None, torch.tensor([precision]))
    gmrf = GMRF(None, field=field, precision=tau, tree_model=None)

    observed = gmrf().item()
    assert expected == pytest.approx(observed, 0.000001)
def test_gmrfw2():
    """GMRF on a log-transformed three-element field with precision 0.1."""
    log_field = torch.tensor([3.0, 10.0, 4.0]).log()
    field = Parameter(None, log_field)
    tau = Parameter(None, torch.tensor([0.1]))
    gmrf = GMRF(None, field=field, precision=tau, tree_model=None)

    observed = gmrf().item()
    assert -4.254919053937792 == pytest.approx(observed, 0.000001)
def test_weibull_mu():
    """Weibull site model rates with ``mu`` set to 2.

    The expected rates are the four reference rates multiplied by 2.
    """
    shape = Parameter('shape', torch.tensor([1.0]))
    mu = Parameter('mu', torch.tensor([2.0]))
    sitemodel = WeibullSiteModel('weibull', shape, 4, mu=mu)

    reference_rates = torch.tensor([0.1457844, 0.5131316, 1.0708310, 2.2702530])
    rates_expected = reference_rates * 2
    np.testing.assert_allclose(sitemodel.rates(), rates_expected, rtol=1e-06)
def test_gmrf_batch():
    """GMRF evaluated on a batch of two fields, each with its own precision."""
    fields = torch.tensor(
        [[2.0, 30.0, 4.0, 15.0, 6.0], [1.0, 3.0, 6.0, 8.0, 9.0]]
    )
    precisions = torch.tensor([[2.0], [0.1]])
    gmrf = GMRF(
        None,
        field=Parameter(None, fields),
        precision=Parameter(None, precisions),
        tree_model=None,
    )

    observed = gmrf()
    expected = torch.tensor([[-1664.2894596388803], [-9.180924185988092]])
    assert torch.allclose(observed, expected)
def create_branch_model(id_, tree_id, taxa_count, arg):
    """Build the JSON description of a branch (clock) model.

    Args:
        id_: identifier of the model; also the prefix of its parameter ids.
        tree_id: identifier of the tree model the clock refers to.
        taxa_count: number of taxa; per-branch parameters get
            ``2 * taxa_count - 2`` entries.
        arg: options object providing ``rate``, ``clock`` and optionally
            ``rate_init``.

    Returns:
        A dict describing the model for ``arg.clock`` in
        ``('strict', 'horseshoe', 'ucln')``; ``None`` for any other value.
    """
    rate_is_fixed = arg.rate is not None
    if rate_is_fixed:
        initial_rate = arg.rate
    elif 'rate_init' in arg:
        initial_rate = arg.rate_init
    else:
        initial_rate = 0.001

    rate_parameter = Parameter.json_factory(f'{id_}.rate', tensor=[initial_rate])
    rate_parameter['lower'] = 0.0
    if rate_is_fixed:
        # A user-supplied rate gets identical lower and upper bounds.
        rate_parameter['lower'] = arg.rate
        rate_parameter['upper'] = arg.rate

    branch_count = 2 * taxa_count - 2

    if arg.clock == 'strict':
        return {
            'id': id_,
            'type': 'StrictClockModel',
            'tree_model': tree_id,
            'rate': rate_parameter,
        }

    if arg.clock == 'horseshoe':
        unscaled = Parameter.json_factory(
            f'{id_}.rates.unscaled', tensor=1.0, full=[branch_count]
        )
        unscaled['lower'] = 0.0
        rescaled_rates = {
            'id': f'{id_}.rates',
            'type': 'TransformedParameter',
            'transform': 'RescaledRateTransform',
            'x': unscaled,
            'parameters': {
                'tree_model': tree_id,
                'rate': rate_parameter,
            },
        }
        return {
            'id': f'{id_}.simple',
            'type': 'SimpleClockModel',
            'tree_model': tree_id,
            'rate': rescaled_rates,
        }

    if arg.clock == 'ucln':
        branch_rates = Parameter.json_factory(
            f'{id_}.rates', tensor=0.001, full=[branch_count]
        )
        branch_rates['lower'] = 0.0
        return {
            'id': id_,
            'type': 'SimpleClockModel',
            'tree_model': tree_id,
            'rate': branch_rates,
        }

    # Any other clock value yields no model.
    return None
def create_birth_death(birth_death_id, tree_id, arg):
    """Build the JSON description of a ``BDSKModel``.

    Args:
        birth_death_id: identifier of the model; prefix of its parameter ids.
        tree_id: identifier of the tree model; its ``root_height`` is used as
            the ``loc`` of the origin transform.
        arg: options object providing ``grid``, the size of the ``R``,
            ``delta`` and ``s`` parameters.

    Returns:
        A dict describing the model and its nested parameters.
    """

    def grid_parameter(name, value):
        # Parameter filled with `value` and sized by arg.grid.
        return Parameter.json_factory(
            f'{birth_death_id}.{name}', tensor=value, full=[arg.grid]
        )

    R = grid_parameter('R', 3.0)
    R['lower'] = 0.0

    delta = grid_parameter('delta', 3.0)
    delta['lower'] = 0.0

    # Both bounds of s are 0.
    s = grid_parameter('s', 0.0)
    s['lower'] = 0.0
    s['upper'] = 0.0

    rho = Parameter.json_factory(f'{birth_death_id}.rho', tensor=[1.0e-6])
    rho['lower'] = 0.0
    rho['upper'] = 1.0

    unshifted_origin = {
        'id': f'{birth_death_id}.origin.unshifted',
        'type': 'Parameter',
        'tensor': [1.0],
        'lower': 0.0,
    }
    origin = {
        'id': f'{birth_death_id}.origin',
        'type': 'TransformedParameter',
        'transform': 'torch.distributions.AffineTransform',
        'x': unshifted_origin,
        'parameters': {
            'loc': f'{tree_id}.root_height',
            'scale': 1.0,
        },
    }

    return {
        'id': birth_death_id,
        'type': 'BDSKModel',
        'tree_model': tree_id,
        'R': R,
        'delta': delta,
        's': s,
        'rho': rho,
        'origin': origin,
    }
def test_weibull_invariant0():
    """Weibull site model given an ``inv`` parameter of 0: rates and weighted sum."""
    shape = Parameter('shape', torch.tensor([1.0]))
    invariant = Parameter('inv', torch.tensor([0.0]))
    sitemodel = WeibullSiteModel('weibull', shape, 4, invariant)

    rates_expected = (0.1457844, 0.5131316, 1.0708310, 2.2702530)
    np.testing.assert_allclose(sitemodel.rates(), rates_expected, rtol=1e-06)

    # Rates weighted by their probabilities should sum to 1.
    weighted_rates = sitemodel.rates() * sitemodel.probabilities()
    assert torch.sum(weighted_rates).item() == pytest.approx(1.0, 1.0e-6)
def test_invariant():
    """Invariant site model: weighted rates sum to 1, probabilities are (0.2, 0.8)."""
    prop_invariant = torch.tensor([0.2])
    site_model = InvariantSiteModel('pinv', Parameter('inv', prop_invariant))

    rates = site_model.rates()
    props = site_model.probabilities()

    weighted_sum = rates.mul(props).sum()
    assert weighted_sum == torch.tensor(np.ones(1))

    expected_props = torch.cat((prop_invariant, torch.tensor([0.8])))
    assert torch.all(expected_props.eq(props))
def create_clock_horseshoe_prior(branch_model_id, tree_id):
    """Build the list of JSON priors used with the horseshoe clock.

    Args:
        branch_model_id: identifier of the branch model; prefix of all ids.
        tree_id: identifier of the tree model used by the log-difference
            transform.

    Returns:
        A list with, in order: the scale-mixture-normal prior, the CTMC scale
        prior on the rate, and Cauchy(0, 1) priors on the global and local
        scales.
    """
    unscaled_id = f'{branch_model_id}.rates.unscaled'

    log_diff = {
        'id': f'{branch_model_id}.rates.logdiff',
        'type': 'TransformedParameter',
        'transform': 'LogDifferenceRateTransform',
        'x': unscaled_id,
        'parameters': {'tree_model': tree_id},
    }

    global_scale = Parameter.json_factory(
        f'{branch_model_id}.global.scale', tensor=[1.0]
    )
    local_scale = Parameter.json_factory(
        f'{branch_model_id}.local.scales', tensor=1.0, full_like=unscaled_id
    )
    global_scale['lower'] = 0.0
    local_scale['lower'] = 0.0

    prior_list = [
        ScaleMixtureNormal.json_factory(
            f'{branch_model_id}.scale.mixture.prior',
            log_diff,
            0.0,
            global_scale,
            local_scale,
        ),
        CTMCScale.json_factory(
            f'{branch_model_id}.rate.prior', f'{branch_model_id}.rate', 'tree'
        ),
    ]
    prior_list.extend(
        Distribution.json_factory(
            f'{branch_model_id}.{scale_name}.prior',
            'torch.distributions.Cauchy',
            f'{branch_model_id}.{scale_name}',
            {'loc': 0.0, 'scale': 1.0},
        )
        for scale_name in ('global.scale', 'local.scales')
    )
    return prior_list
def create_gamma_distribution(var_id, x_unres, json_object, concentration, rate):
    """Build the JSON description of a Gamma distribution over ``x_unres``.

    Both distribution parameters are wrapped in an ``ExpTransform`` over an
    ``.unres`` parameter shaped like ``json_object['id']``.

    Args:
        var_id: prefix of every generated id.
        x_unres: the ``x`` argument passed to the distribution factory.
        json_object: dict whose ``'id'`` entry names the target parameter.
        concentration: initial value of the ``.concentration.unres`` parameter.
        rate: initial value of the ``.rate.unres`` parameter.

    Returns:
        Tuple ``(distribution, concentration_param, rate_param)``.
    """
    base_id = var_id + '.' + json_object['id']

    def exp_transformed(name, value):
        # Transformed parameter backed by an `.unres` parameter.
        return {
            'id': base_id + '.' + name,
            'type': 'TransformedParameter',
            'transform': 'torch.distributions.ExpTransform',
            'x': Parameter.json_factory(
                base_id + '.' + name + '.unres',
                full_like=json_object['id'],
                tensor=value,
            ),
        }

    concentration_param = exp_transformed('concentration', concentration)
    rate_param = exp_transformed('rate', rate)

    distr = Distribution.json_factory(
        base_id,
        'torch.distributions.Gamma',
        x_unres,
        {'concentration': concentration_param, 'rate': rate_param},
    )
    return distr, concentration_param, rate_param
def test_smoothed(rescale, expected):
    """GMRF attached to a four-taxon time tree, with and without rescaling."""
    tree_json = TimeTreeModel.json_factory(
        'tree',
        '(((A,B),C),D);',
        [3.0, 2.0, 4.0],
        dict.fromkeys('ABCD', 0.0),
    )
    tree_model = TimeTreeModel.from_json(tree_json, {})

    log_field = torch.tensor([3.0, 10.0, 4.0]).log()
    gmrf = GMRF(
        None,
        field=Parameter(None, log_field),
        precision=Parameter(None, torch.tensor([0.1])),
        tree_model=tree_model,
        rescale=rescale,
    )

    expected_tensor = torch.tensor(expected)
    assert torch.allclose(expected_tensor, gmrf())
def test_ctmc_scale_batch(tree_model_dict):
    """CTMC scale on a batch of two identical sets of internal heights."""
    heights = [1.5, 4.0, 6.0, 16.0]
    # Two separate but equal rows make a batch of size 2.
    tree_model_dict['internal_heights']['tensor'] = [heights, list(heights)]
    tree_model = TimeTreeModel.from_json(tree_model_dict, {})

    rates = Parameter(None, torch.tensor([[0.001], [0.001]]))
    ctmc_scale = CTMCScale(None, rates, tree_model)

    expected = torch.full((2, 1), 4.475351922659342)
    assert torch.allclose(expected, ctmc_scale())
def test_weibull_batch2():
    """Weibull site model rates for a batch of three shape values."""
    shapes = torch.tensor([[1.0], [0.1], [1.0]])
    sitemodel = WeibullSiteModel('weibull', Parameter('shape', shapes), 4)

    # The first and third batch entries use the same shape (1.0).
    rates_shape_one = [0.1457844, 0.5131316, 1.0708310, 2.2702530]
    rates_shape_tenth = [4.766392e-12, 1.391131e-06, 2.179165e-03, 3.997819]
    rates_expected = torch.tensor(
        [rates_shape_one, rates_shape_tenth, rates_shape_one]
    )
    assert torch.allclose(sitemodel.rates(), rates_expected)
def test_gmrf_time_aware(thetas, precision, weights, rescale):
    """Compare the GMRF value with a dense-matrix reference, weighted or not.

    Without weights the off-diagonal entries are all -1. With weights they are
    computed from a hard-coded time grid.
    NOTE(review): ``times`` below is not derived from ``weights``; presumably
    it matches the parametrized inputs -- confirm against the fixtures.
    """
    thetas = torch.tensor(thetas, requires_grad=True)
    precision = torch.tensor([precision])
    weights_tensor = None if weights is None else torch.tensor(weights)

    gmrf = GMRF(
        None,
        field=Parameter(None, thetas),
        precision=Parameter(None, precision),
        tree_model=None,
        weights=weights_tensor,
        rescale=rescale,
    )
    model_lp = gmrf()

    dim = thetas.shape[0]
    if weights is None:
        offdiag = torch.full((dim - 1,), -1.0)
    else:
        times = torch.tensor([0.0, 2.0, 6.0, 12.0, 20.0, 25.0])
        durations = times[1:] - times[:-1]
        offdiag = -2.0 / (durations[:-1] + durations[1:])
        if rescale:
            offdiag *= times[-1]  # rescale with root height

    lower, upper = range(dim - 1), range(1, dim)
    interior = range(1, dim - 1)

    Q = torch.zeros((dim, dim))
    Q[lower, upper] = offdiag
    Q[upper, lower] = offdiag
    # Each diagonal entry cancels the off-diagonal entries of its row.
    Q[interior, interior] = -(offdiag[:-1] + offdiag[1:])
    Q[0, 0] = -Q[0, 1]
    Q[dim - 1, dim - 1] = -offdiag[-1]
    Q_scaled = Q * precision

    log_precision_term = 0.5 * (dim - 1) * precision.log()
    quadratic_term = 0.5 * torch.dot(thetas, Q_scaled @ thetas)
    # 1.8378770664093453 is log(2 * pi).
    constant_term = (dim - 1) / 2.0 * 1.8378770664093453
    reference_lp = log_precision_term - quadratic_term - constant_term

    assert model_lp.item() == pytest.approx(reference_lp.item())
def create_normal_distribution(var_id, x_unres, json_object, loc, scale):
    """Build the JSON description of a Normal distribution over ``x_unres``.

    ``loc`` is a plain parameter; ``scale`` is an ``ExpTransform`` of an
    ``.unres`` parameter. When an initial value is given as a list, the
    ``full_like`` entry is removed from the corresponding parameter.

    Args:
        var_id: prefix of every generated id.
        x_unres: the ``x`` argument of the distribution, also used as the
            ``full_like`` target.
        json_object: dict whose ``'id'`` entry names the target parameter.
        loc: initial value of the location parameter.
        scale: initial value of the ``.scale.unres`` parameter.

    Returns:
        Tuple ``(distribution, loc_param, scale_param)``.
    """
    base_id = var_id + '.' + json_object['id']

    loc_param = Parameter.json_factory(
        base_id + '.loc', full_like=x_unres, tensor=loc
    )
    if isinstance(loc, list):
        del loc_param['full_like']

    scale_unres = Parameter.json_factory(
        base_id + '.scale.unres', full_like=x_unres, tensor=scale
    )
    scale_param = {
        'id': base_id + '.scale',
        'type': 'TransformedParameter',
        'transform': 'torch.distributions.ExpTransform',
        'x': scale_unres,
    }
    if isinstance(scale, list):
        del scale_param['x']['full_like']

    distr = Distribution.json_factory(
        base_id,
        'torch.distributions.Normal',
        x_unres,
        {'loc': loc_param, 'scale': scale_param},
    )
    return distr, loc_param, scale_param
def test_batch():
    """Joint of a Normal and an Exponential evaluated on a batch of two values."""
    normal_params = OrderedDict(
        [
            ('loc', Parameter(None, torch.tensor([0.0]))),
            ('scale', Parameter(None, torch.tensor([1.0]))),
        ]
    )
    normal = Distribution(
        None,
        torch.distributions.Normal,
        Parameter(None, torch.tensor([[1.0], [2.0]])),
        normal_params,
    )

    exp_params = OrderedDict([('rate', Parameter(None, torch.tensor([1.0])))])
    exp = Distribution(
        None,
        torch.distributions.Exponential,
        Parameter(None, torch.tensor([[1.0], [2.0]])),
        exp_params,
    )

    joint = JointDistributionModel(None, [normal, exp])
    # Each entry is written as the sum of the two contributions.
    expected = torch.tensor([-1.418939 - 1, -2.918939 - 2])
    assert torch.allclose(expected, joint())
def test_simple():
    """Joint of a Normal and an Exponential, each evaluated at 1.0."""
    normal_params = OrderedDict(
        [
            ('loc', Parameter(None, torch.tensor([0.0]))),
            ('scale', Parameter(None, torch.tensor([1.0]))),
        ]
    )
    normal = Distribution(
        None,
        torch.distributions.Normal,
        Parameter(None, torch.tensor([1.0])),
        normal_params,
    )

    exp_params = OrderedDict([('rate', Parameter(None, torch.tensor([1.0])))])
    exp = Distribution(
        None,
        torch.distributions.Exponential,
        Parameter(None, torch.tensor([1.0])),
        exp_params,
    )

    joint = JointDistributionModel(None, [normal, exp])
    observed = joint().item()
    assert (-1.418939 - 1) == pytest.approx(observed)
def test_prior_mrbayes():
    """Compound gamma-Dirichlet prior on a six-taxon unrooted tree."""
    newick = '(6:0.02,((5:0.02,2:0.02):0.02,(4:0.02,3:0.02):0.02):0.02,1:0.02);'
    # Taxa are labelled '1' to '6'; each maps to None.
    taxa = dict.fromkeys(map(str, range(1, 7)))
    json_tree = UnRootedTreeModel.json_factory(
        'tree',
        newick,
        [0.0],
        taxa,
        keep_branch_lengths=True,
        branch_lengths_id='bl',
    )
    dic = {}
    tree_model = UnRootedTreeModel.from_json(json_tree, dic)

    hyper_parameters = [
        Parameter(None, torch.tensor([value])) for value in (1.0, 1.0, 1.0, 0.1)
    ]
    prior = CompoundGammaDirichletPrior(None, tree_model, *hyper_parameters)

    assert torch.allclose(prior(), torch.tensor([22.00240516662597]))
def create_site_model_srd06_mus(id_):
    """Build the JSON description of the SRD06 ``mu`` transformed parameter.

    Args:
        id_: identifier of the transformed parameter.

    Returns:
        A dict describing a ``ConvexCombinationTransform`` with weights
        ``[2/3, 1/3]`` applied to the ``srd06.mu`` parameter.
    """
    simplex_parameter = Parameter.json_factory('srd06.mu', tensor=[0.5, 0.5])
    simplex_parameter['simplex'] = True
    return {
        'id': id_,
        'type': 'TransformedParameter',
        'transform': 'ConvexCombinationTransform',
        'x': simplex_parameter,
        'parameters': {'weights': [2 / 3, 1 / 3]},
    }
def create_ucln_prior(branch_model_id):
    """Build the list of JSON priors used with the ``ucln`` clock.

    Args:
        branch_model_id: identifier of the branch model; prefix of all ids.

    Returns:
        A list with, in order: the LogNormal prior on the rates, the CTMC
        scale prior on the LogNormal mean, and a Gamma prior on the LogNormal
        scale.
    """
    mean_id = f'{branch_model_id}.rates.prior.mean'
    scale_id = f'{branch_model_id}.rates.prior.scale'

    mean = Parameter.json_factory(mean_id, tensor=[0.001])
    scale = Parameter.json_factory(scale_id, tensor=[1.0])
    mean['lower'] = 0.0
    scale['lower'] = 0.0

    rates_prior = Distribution.json_factory(
        f'{branch_model_id}.rates.prior',
        'LogNormal',
        f'{branch_model_id}.rates',
        {'mean': mean, 'scale': scale},
    )
    mean_prior = CTMCScale.json_factory(
        f'{branch_model_id}.mean.prior', mean_id, 'tree'
    )
    scale_prior = Distribution.json_factory(
        f'{branch_model_id}.rates.scale.prior',
        'torch.distributions.Gamma',
        scale_id,
        {'concentration': 0.5396, 'rate': 2.6184},
    )
    return [rates_prior, mean_prior, scale_prior]
def test_view_parameter_listener():
    """A listener on a ViewParameter is notified by both the base and the view."""

    class RecordingListener(ParameterListener):
        """Remembers whether a change notification was received."""

        def __init__(self):
            self.gotit = False

        def handle_parameter_changed(self, variable, index, event):
            self.gotit = True

    base = Parameter('param', torch.tensor([1.0, 2.0, 3.0, 4.0]))
    view = ViewParameter('param1', base, 3)

    listener = RecordingListener()
    view.add_parameter_listener(listener)

    # A change fired on the underlying parameter reaches the view's listener.
    base.fire_parameter_changed()
    assert listener.gotit is True

    # Reset, then fire directly on the view.
    listener.gotit = False
    view.fire_parameter_changed()
    assert listener.gotit is True
def test_transformed_parameter():
    """Track ``need_update`` on a TransformedParameter as its input changes."""
    initial = torch.tensor([1.0, 2.0])
    source = Parameter('param', initial)
    exp_transform = torch.distributions.ExpTransform()
    target = TransformedParameter('transformed', source, exp_transform)

    # Reading the tensor leaves the flag cleared.
    assert target.need_update is False
    assert torch.all(target.tensor == initial.exp())
    assert target.need_update is False

    # Replacing the source tensor raises the flag until the next read.
    source.tensor = torch.tensor([1.0, 3.0])
    assert target.need_update is True
    assert torch.all(target.tensor == source.tensor.exp())
    assert target.need_update is False

    # jacobian: calling the parameter is expected to return the source values
    source.tensor = initial
    assert target.need_update is True
    assert torch.all(target() == source.tensor)
    assert target.need_update is False
    assert torch.all(target.tensor == source.tensor.exp())
def create_site_model(id_, arg, w=None):
    """Build the JSON description of a site model.

    Args:
        id_: identifier of the site model.
        arg: options object providing ``categories`` and ``model``.
        w: value stored under ``'mu'`` when ``arg.model`` is ``'SRD06'``.

    Returns:
        A ``ConstantSiteModel`` dict when ``arg.categories`` is 1, otherwise a
        ``WeibullSiteModel`` dict with a ``shape`` parameter.
    """
    single_category = arg.categories == 1
    site_model = {
        'id': id_,
        'type': 'ConstantSiteModel' if single_category else 'WeibullSiteModel',
    }

    if not single_category:
        shape = Parameter.json_factory(f'{id_}.shape', tensor=[0.1])
        shape['lower'] = 0.0
        site_model['categories'] = arg.categories
        site_model['shape'] = shape

    if arg.model == 'SRD06':
        site_model['mu'] = w
    return site_model
def test_gmrf_covariates_simple():
    """With all-zero ``beta``, GMRFCovariate gives the same value as GMRF."""
    gmrf = GMRF(
        None,
        Parameter(None, torch.tensor([1.0, 2.0, 3.0])),
        Parameter(None, torch.tensor([0.1])),
    )

    covariates = torch.arange(1.0, 7.0).view((3, 2))
    gmrf_covariate = GMRFCovariate(
        None,
        field=Parameter(None, torch.tensor([1.0, 2.0, 3.0])),
        precision=Parameter(None, torch.tensor([0.1])),
        covariates=Parameter(None, covariates),
        beta=Parameter(None, torch.tensor([0.0, 0.0])),
    )

    plain_value = gmrf()
    covariate_value = gmrf_covariate()
    assert plain_value == covariate_value
def test_parameterization(parameterization, param):
    """The MultivariateNormal model matches torch's log_prob.

    ``x`` is supplied first as a single Parameter and then as a tuple of two
    Parameters holding its pieces.
    """
    loc = torch.tensor([1.0, 2.0])
    x = torch.tensor([1.0, 2.0])
    param_tensor = torch.tensor(param)
    kwargs = {parameterization: Parameter(None, param_tensor)}

    single_model = MultivariateNormal(
        None, Parameter(None, x), Parameter(None, loc), **kwargs
    )
    reference = torch.distributions.multivariate_normal.MultivariateNormal(
        loc, **{parameterization: param_tensor}
    )
    assert reference.log_prob(x) == single_model()

    # Same values, split into "all but last" and "last".
    pieces = (Parameter(None, x[:-1]), Parameter(None, x[-1:]))
    split_model = MultivariateNormal(None, pieces, Parameter(None, loc), **kwargs)
    assert reference.log_prob(x) == split_model()
def test_parameterization_rsample(parameterization, param):
    """``rsample`` on the model reproduces torch's draw under the same seed.

    The seed is reset to 0 before every draw so the three samples can be
    compared.
    """
    loc = torch.tensor([1.0, 2.0])
    x = torch.tensor([1.0, 2.0])

    torch.manual_seed(0)
    param_tensor = torch.tensor(param)
    kwargs = {parameterization: Parameter(None, param_tensor)}
    single_model = MultivariateNormal(
        None, Parameter(None, x), Parameter(None, loc), **kwargs
    )
    single_model.rsample()

    torch.manual_seed(0)
    reference = torch.distributions.multivariate_normal.MultivariateNormal(
        loc, **{parameterization: param_tensor}
    )
    samples = reference.rsample()
    assert torch.all(samples == single_model.x.tensor)

    torch.manual_seed(0)
    pieces = (Parameter(':-1', x[:-1]), Parameter('-1:', x[-1:]))
    split_model = MultivariateNormal(None, pieces, Parameter(None, loc), **kwargs)
    split_model.rsample()
    assert torch.all(samples == split_model.x.tensor)
def test_ctmc_scale(tree_model_dict):
    """CTMC scale for a rate of 0.001 on the fixture tree."""
    tree_model = TimeTreeModel.from_json(tree_model_dict, {})
    rate = Parameter(None, torch.tensor([0.001]))
    ctmc_scale = CTMCScale(None, rate, tree_model)

    observed = ctmc_scale().item()
    assert 4.475351922659342 == pytest.approx(observed, 0.00001)
def constant_coalescent(args):
    """Benchmark constant-coalescent evaluations and gradients.

    Runs first with the ``ConstantCoalescent`` class ("JIT off"), then with a
    ``torch.jit.script`` version of ``log_prob`` ("JIT on"), and optionally
    with a variant using unique sampling times and their counts.

    Args:
        args: options object providing ``tree`` (input passed to
            ``read_tree``), ``replicates`` (first argument of every
            benchmarked call), ``output`` (writable object or falsy),
            ``debug`` and ``all`` flags.
    """
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = [Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()]

    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0])
    )
    tree_model = TimeTreeModel("tree", tree, Taxa('taxa', taxa), ratios_root_height)
    # Overwrite the placeholder values with heights taken from the tree.
    tree_model._internal_heights.tensor = heights_from_branch_lengths(tree)
    pop_size = torch.tensor([4.0])

    def zero_grads():
        # Reset accumulated gradients so repeated backward passes do not add up.
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()

    print('JIT off')

    @benchmark
    def fn(tree_model, pop_size):
        return ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)

    @benchmark
    def fn_grad(tree_model, pop_size):
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        zero_grads()
        return log_p

    total_time, log_p = fn(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} evaluations: {total_time} {log_p}')

    # Gradient tracking is enabled only after the plain evaluations.
    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,off,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,off,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.debug:
        tree_model.heights_need_update = True
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        print('gradient ratios: ', ratios_root_height.grad)
        print('gradient pop size: ', pop_size.grad)
        zero_grads()

    print('JIT on')
    log_prob_script = torch.jit.script(log_prob)

    @benchmark
    def fn_jit(tree_model, pop_size):
        return log_prob_script(tree_model.node_heights, pop_size)

    @benchmark
    def fn_grad_jit(tree_model, pop_size):
        log_p = log_prob_script(tree_model.node_heights, pop_size)
        log_p.backward()
        zero_grads()
        return log_p

    # Plain evaluations again run without gradient tracking.
    ratios_root_height.requires_grad = False
    pop_size.requires_grad_(False)
    total_time, log_p = fn_jit(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad_jit(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,on,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,on,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.all:
        print('make sampling times unique and count them:')

        # `x` and `counts` are assigned below, before either function is called.
        @benchmark
        def fn3(tree_model, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            return log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )

        @benchmark
        def fn3_grad(tree_model, ratios_root_height, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            log_p = log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )
            log_p.backward()
            zero_grads()
            return log_p

        x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
        # Pad the counts with -1, one entry per remaining (taxa_count - 1) slot.
        counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1))))

        with torch.no_grad():
            total_time, log_p = fn3(args.replicates, tree_model, pop_size)
        print(f' {args.replicates} evaluations: {total_time} ({log_p})')

        total_time, log_p = fn3_grad(
            args.replicates, tree_model, ratios_root_height, pop_size
        )
        print(f' {args.replicates} gradient evaluations: {total_time}')