def test_taxa():
    """Check that ``Taxa.from_json`` builds a two-taxon container.

    The JSON-style description lists two taxa, each with a ``date``
    attribute. The parsed object must report a length of 2, and its second
    element must expose the expected ``id`` and ``date``.
    """
    # (taxon id, sampling date) pairs, in the order they appear in the JSON.
    taxon_specs = [
        ("A_Belgium_2_1981", 1981),
        ("A_ChristHospital_231_1982", 1982),
    ]
    taxa_json = {
        "id": "taxa",
        "type": "torchtree.evolution.taxa.Taxa",
        "taxa": [
            {
                "id": taxon_id,
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {"date": year},
            }
            for taxon_id, year in taxon_specs
        ],
    }
    registry = {}

    parsed = Taxa.from_json(taxa_json, registry)

    assert len(parsed) == 2
    assert parsed[1].id == 'A_ChristHospital_231_1982'
    assert parsed[1]['date'] == 1982
def test_site_pattern():
    """Check the tip partials and pattern weights of a four-taxon alignment.

    The alignment has three columns; the first two are identical (all
    ``A``). The test expects weights ``[[2.0, 1.0]]`` and a first tip
    partial of shape ``(4, 2)``.
    """
    labels = 'ABCD'
    residues = ['AAG', 'AAC', 'AAC', 'AAT']

    taxa = Taxa(None, [Taxon(label, {}) for label in labels])
    sequences = [
        Sequence(label, residue) for label, residue in zip(labels, residues)
    ]
    alignment = Alignment(None, sequences, taxa, NucleotideDataType(None))

    partials, weights = SitePattern(None, alignment).compute_tips_partials()

    expected_weights = torch.tensor([[2.0, 1.0]])
    assert torch.all(weights == expected_weights)
    assert partials[0].shape == torch.Size([4, 2])
def _prepare_tiny(tiny_newick_file, tiny_fasta_file):
    """Load the tiny tree/alignment pair and return likelihood inputs.

    Returns:
        A tuple ``(partials, weights_tensor, indices, branch_lengths)``:
        the compressed tip partials padded with ``len(dna) - 1`` ``None``
        entries, the pattern weights, one ``[node, child, ...]`` index list
        per internal node in post-order, and a tensor of edge lengths
        ordered by node index.
    """
    tree, dna = read_tree_and_alignment(tiny_newick_file, tiny_fasta_file, False, False)

    # Drop the last node of the post-order traversal, then order the
    # remaining nodes by their index.
    nodes = list(tree.postorder_node_iter())[:-1]
    nodes.sort(key=lambda node: node.index)
    branch_lengths = torch.tensor([float(node.edge_length) for node in nodes])

    indices = [
        [parent.index] + [child.index for child in parent.child_nodes()]
        for parent in tree.postorder_internal_node_iter()
    ]

    sequences, taxa = [], []
    for entry, characters in dna.items():
        sequences.append(Sequence(entry.label, str(characters)))
        taxa.append(Taxon(entry.label, None))

    alignment = Alignment(None, sequences, Taxa(None, taxa), NucleotideDataType(None))
    partials, weights_tensor = compress_alignment(alignment)
    # Reserve one empty slot per remaining (non-tip) partial.
    partials.extend([None] * (len(dna) - 1))

    return partials, weights_tensor, indices, branch_lengths
def constant_coalescent(args):
    """Benchmark the constant-size coalescent log-probability and its gradient.

    Reads a tree from ``args.tree``, builds a ``TimeTreeModel`` and times,
    with ``args.replicates`` passed to each benchmarked function:

    * the ``ConstantCoalescent`` class ("JIT off"),
    * ``log_prob`` compiled with ``torch.jit.script`` ("JIT on"),
    * when ``args.all`` is truthy, ``log_prob_squashed`` evaluated on the
      unique sampling times and their counts.

    Args:
        args: namespace-like object. This function reads ``tree``,
            ``replicates``, ``output``, ``debug`` and ``all``. When
            ``output`` is truthy, comma-separated result rows are written
            to it with ``write``.

    NOTE(review): ``benchmark``, ``log_prob`` and ``log_prob_squashed`` are
    not defined in this function (presumably module level). Judging from the
    call sites, a ``@benchmark``-decorated function takes the replicate count
    as an extra first argument and returns ``(total_time, result)``.
    """
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    # Placeholder values: (taxa_count - 2) entries of 0.5 plus a final 20.0.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0])
    )
    tree_model = TimeTreeModel("tree", tree, Taxa('taxa', taxa), ratios_root_height)
    # Replace the placeholder with heights derived from the tree's branch lengths.
    tree_model._internal_heights.tensor = heights_from_branch_lengths(tree)
    pop_size = torch.tensor([4.0])

    print('JIT off')

    @benchmark
    def fn(tree_model, pop_size):
        return ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)

    @benchmark
    def fn_grad(tree_model, pop_size):
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        # Reset gradients so they do not accumulate across calls.
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    total_time, log_p = fn(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} evaluations: {total_time} {log_p}')

    # Gradients are enabled only after the plain evaluation has been timed.
    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,off,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,off,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.debug:
        # One extra backward pass, only to print the gradients.
        tree_model.heights_need_update = True
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        print('gradient ratios: ', ratios_root_height.grad)
        print('gradient pop size: ', pop_size.grad)
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()

    print('JIT on')
    log_prob_script = torch.jit.script(log_prob)

    @benchmark
    def fn_jit(tree_model, pop_size):
        return log_prob_script(tree_model.node_heights, pop_size)

    @benchmark
    def fn_grad_jit(tree_model, pop_size):
        log_p = log_prob_script(tree_model.node_heights, pop_size)
        log_p.backward()
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    # Disable gradients again for the evaluation-only timing.
    ratios_root_height.requires_grad = False
    pop_size.requires_grad_(False)
    total_time, log_p = fn_jit(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad_jit(args.replicates, tree_model, pop_size)
    print(f' {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,on,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,on,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.all:
        print('make sampling times unique and count them:')

        # ``x`` and ``counts`` are assigned further down, before fn3 and
        # fn3_grad are first called; the closures look them up at call time.
        @benchmark
        def fn3(tree_model, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            return log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )

        @benchmark
        def fn3_grad(tree_model, ratios_root_height, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            log_p = log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )
            log_p.backward()
            ratios_root_height.tensor.grad.data.zero_()
            pop_size.grad.data.zero_()
            return log_p

        x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
        # Pad the counts with -1 for (taxa_count - 1) further entries
        # (presumably one per internal node; confirm against log_prob_squashed).
        counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1))))

        with torch.no_grad():
            total_time, log_p = fn3(args.replicates, tree_model, pop_size)
        print(f' {args.replicates} evaluations: {total_time} ({log_p})')

        total_time, log_p = fn3_grad(
            args.replicates, tree_model, ratios_root_height, pop_size
        )
        print(f' {args.replicates} gradient evaluations: {total_time}')
def ratio_transform(args):
    """Benchmark the ratio/root-height to node-height transform and its gradient.

    Reads a tree from ``args.tree``, builds a
    ``ReparameterizedTimeTreeModel`` and times several implementations of
    the same transform, each as an evaluation and as an
    evaluation-plus-backward pass:

    * ``tree_model.transform`` (the model's own transform object),
    * ``transform`` called directly, then compiled with ``torch.jit.script``,
    * ``transform2`` (takes the forward indices as a Python list), then
      compiled with ``torch.jit.script``.

    Args:
        args: namespace-like object. This function reads ``tree``,
            ``replicates`` and ``output``. When ``output`` is truthy, two
            comma-separated rows for the first implementation are written
            to it with ``write``.

    NOTE(review): ``benchmark``, ``transform`` and ``transform2`` are not
    defined in this function (presumably module level). Judging from the call
    sites, a ``@benchmark``-decorated function takes the replicate count as
    an extra first argument and returns ``(total_time, result)``.
    """
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    # Placeholder values; overwritten below with the inverse-transformed heights.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    # In every closure below the parameter ``ratios_root_height`` shadows the
    # outer Parameter: the closures are called with ``ratios_root_height.tensor``.
    @benchmark
    def fn(ratios_root_height):
        return tree_model.transform(
            ratios_root_height,
        )

    @benchmark
    def fn_grad(ratios_root_height):
        heights = tree_model.transform(
            ratios_root_height,
        )
        # The output is not a scalar, so backward() is given an explicit
        # upstream gradient.
        heights.backward(torch.ones_like(ratios_root_height))
        # Reset the gradient so it does not accumulate across calls.
        ratios_root_height.grad.data.zero_()
        return heights

    total_time, heights = fn(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    grad_total_time, heights = fn_grad(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    print(' JIT off')

    @benchmark
    def fn2(ratios_root_height):
        return transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad(ratios_root_height):
        heights = transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    # Gradients are switched off for each evaluation-only timing and back
    # on for the matching gradient timing.
    ratios_root_height.requires_grad = False
    total_time, heights = fn2(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} gradient evaluations: {total_time}')

    print(' JIT on')
    transform_script = torch.jit.script(transform)

    @benchmark
    def fn2_jit(ratios_root_height):
        return transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad_jit(ratios_root_height):
        heights = transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2_jit(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad_jit(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT off')

    # Note: ``.tolist()`` is evaluated inside the timed closures, so the
    # conversion is part of what is measured.
    @benchmark
    def fn3(ratios_root_height):
        return transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad(ratios_root_height):
        heights = transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)

    @benchmark
    def fn3_jit(ratios_root_height):
        return transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad_jit(ratios_root_height):
        heights = transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3_jit(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad_jit(args.replicates, ratios_root_height.tensor)
    print(f' {replicates} gradient evaluations: {total_time}')
def ratio_transform_jacobian(args):
    """Benchmark the log-abs-det-Jacobian of the node-height transform.

    Reads a tree from ``args.tree``, builds a
    ``ReparameterizedTimeTreeModel`` and times
    ``tree_model.transform.log_abs_det_jacobian``, once as a plain
    evaluation and once followed by a backward pass.

    Args:
        args: namespace-like object. This function reads ``tree``,
            ``replicates``, ``output`` and ``debug``. When ``output`` is
            truthy, two comma-separated result rows are written to it with
            ``write``. When ``debug`` is truthy, one extra backward pass is
            run and the resulting gradient is printed.

    NOTE(review): ``benchmark`` is not defined in this function (presumably
    module level). Judging from the call sites, a decorated function takes
    the replicate count as an extra first argument and returns
    ``(total_time, result)``.
    """
    tree = read_tree(args.tree, True, True)
    taxa = [Taxon(node.label, {'date': node.date}) for node in tree.leaf_node_iter()]
    taxa_count = len(taxa)
    # Placeholder with (taxa_count - 2) ratios plus one root height, the same
    # layout ``ratio_transform`` uses for this model. The values are
    # overwritten just below with the inverse-transformed heights.
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )
    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    # In both closures the parameter ``ratios_root_height`` shadows the outer
    # Parameter: they are called with ``ratios_root_height.tensor``.
    @benchmark
    def fn(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        return tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )

    @benchmark
    def fn_grad(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )
        log_det_jac.backward()
        # Reset the gradient so it does not accumulate across calls.
        ratios_root_height.grad.data.zero_()
        return log_det_jac

    print(' JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f' {args.replicates} evaluations: {total_time} ({log_det_jac})')

    # Gradients are enabled only after the plain evaluation has been timed.
    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f' {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        # One extra backward pass, only to print the gradient.
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)
def unrooted_treelikelihood(args, subst_model):
    """Benchmark tree-likelihood evaluation and gradient for a given substitution model.

    Reads a tree and alignment from ``args.tree`` / ``args.input``, compresses
    the alignment into partials and weights, and times
    ``calculate_treelikelihood_discrete_safe`` ("treelikelihood v1"). When
    ``args.all`` is truthy and ``subst_model`` is a ``JC69``, it also times
    ``calculate_treelikelihood_discrete_split`` (v2) and
    ``calculate_treelikelihood_discrete_cat`` (v3), each as plain Python and
    compiled with ``torch.jit.script``.

    Args:
        args: namespace-like object. This function reads ``tree``, ``input``,
            ``scaler``, ``dtype``, ``replicates``, ``output`` and ``all``.
            When ``output`` is truthy, two comma-separated v1 result rows are
            written to it with ``write``.
        subst_model: substitution model providing ``frequencies``,
            ``p_t(branch_lengths)`` and ``parameters()``.

    NOTE(review): ``benchmark`` and the free function ``p_t`` used by the
    v2/v3 sections are not defined in this function (presumably module
    level). Judging from the call sites, a ``@benchmark``-decorated function
    takes the replicate count as an extra first argument and returns
    ``(total_time, result)``.
    """
    tree, dna = read_tree_and_alignment(args.tree, args.input, True, True)
    # Edge lengths of every node except the last post-order node, ordered by
    # node index and multiplied by ``args.scaler``.
    branch_lengths = torch.tensor(
        [
            float(node.edge_length) * args.scaler
            for node in sorted(
                list(tree.postorder_node_iter())[:-1], key=lambda x: x.index
            )
        ],
    )
    # Enforce a lower bound of 1e-6 on the branch lengths.
    branch_lengths = torch.clamp(branch_lengths, min=1.0e-6)

    # One [node, child, ...] index list per internal node, in post-order.
    indices = []
    for node in tree.postorder_internal_node_iter():
        indices.append([node.index] + [child.index for child in node.child_nodes()])

    sequences = []
    taxa = []
    for taxon, seq in dna.items():
        sequences.append(Sequence(taxon.label, str(seq)))
        taxa.append(Taxon(taxon.label, None))

    partials, weights_tensor = compress_alignment(
        Alignment(None, sequences, Taxa(None, taxa), NucleotideDataType('nuc'))
    )
    # Reserve one empty slot per remaining (non-tip) partial.
    partials.extend([None] * (len(dna) - 1))
    freqs = subst_model.frequencies
    proportions = torch.tensor([[[1.0]]])
    # Smaller threshold when not running in float32 (presumably the rescaling
    # threshold of the "safe" likelihood; confirm against its definition).
    threshold = 1.0e-20 if args.dtype == 'float32' else 1.0e-40

    print('treelikelihood v1')

    @benchmark
    def fn_safe(bls):
        mats = subst_model.p_t(bls)
        return calculate_treelikelihood_discrete_safe(
            partials, weights_tensor, indices, mats, freqs, proportions, threshold
        )

    @benchmark
    def fn_safe_grad(bls):
        mats = subst_model.p_t(bls)
        log_prob = calculate_treelikelihood_discrete_safe(
            partials, weights_tensor, indices, mats, freqs, proportions, threshold
        )
        log_prob.backward()
        return log_prob

    # Add a leading and a trailing singleton dimension to the branch lengths.
    blens = branch_lengths.unsqueeze(0).unsqueeze(-1)

    total_time, log_prob = fn_safe(args.replicates, blens)
    print(f' {args.replicates} evaluations: {total_time} ({log_prob})')

    # Gradients are enabled only after the plain evaluation has been timed.
    blens.requires_grad = True
    for p in subst_model.parameters():
        p.requires_grad = True

    grad_total_time, grad_log_prob = fn_safe_grad(args.replicates, blens)
    print(
        f' {args.replicates} gradient evaluations: {grad_total_time} ({grad_log_prob})'
    )

    if args.output:
        # JC69 rows use the bare "treelikelihood" label; other models append
        # their class name.
        name = '' if isinstance(subst_model, JC69) else type(subst_model).__name__
        args.output.write(
            f"treelikelihood{name},evaluation,off,{total_time},"
            f"{log_prob.squeeze().item()}\n"
        )
        args.output.write(
            f"treelikelihood{name},gradient,off,{grad_total_time},"
            f"{grad_log_prob.squeeze().item()}\n"
        )

    if args.all and isinstance(subst_model, JC69):
        # The v2/v3 implementations take the tip partials as one stacked
        # tensor and the node indices as a tensor instead of nested lists.
        tip_partials = torch.stack(partials[: len(sequences)])
        indices = torch.tensor(indices)

        print('treelikelihood v2 JIT off')

        @benchmark
        def fn2(bls):
            mats = p_t(bls)
            return calculate_treelikelihood_discrete_split(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn2_grad(bls):
            mats = p_t(bls)
            log_prob = calculate_treelikelihood_discrete_split(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        # ``blens`` still requires grad here, so evaluation-only timings run
        # under ``no_grad``.
        with torch.no_grad():
            total_time, log_prob = fn2(args.replicates, blens)
        print(f' {args.replicates} evaluations: {total_time} ({log_prob})')

        total_time, log_prob = fn2_grad(args.replicates, blens)
        print(f' {args.replicates} gradient evaluations: {total_time}')

        print('treelikelihood v2 JIT on')
        like2_jit = torch.jit.script(calculate_treelikelihood_discrete_split)

        @benchmark
        def fn2_jit(bls):
            mats = p_t(bls)
            return like2_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn2_grad_jit(bls):
            mats = p_t(bls)
            log_prob = like2_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn2_jit(args.replicates, blens)
        print(f' {args.replicates} evaluations: {total_time}')

        total_time, log_prob = fn2_grad_jit(args.replicates, blens)
        print(f' {args.replicates} gradient evaluations: {total_time}')

        print('treelikelihood v3 JIT off')

        @benchmark
        def fn3(bls):
            mats = p_t(bls)
            return calculate_treelikelihood_discrete_cat(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn3_grad(bls):
            mats = p_t(bls)
            log_prob = calculate_treelikelihood_discrete_cat(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn3(args.replicates, blens)
        print(f' {args.replicates} evaluations: {total_time} ({log_prob})')

        total_time, log_prob = fn3_grad(args.replicates, blens)
        print(f' {args.replicates} gradient evaluations: {total_time}')

        print('treelikelihood v3 JIT on')
        like3_jit = torch.jit.script(calculate_treelikelihood_discrete_cat)

        @benchmark
        def fn3_jit(bls):
            mats = p_t(bls)
            return like3_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn3_grad_jit(bls):
            mats = p_t(bls)
            log_prob = like3_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn3_jit(args.replicates, blens)
        print(f' {args.replicates} evaluations: {total_time} ({log_prob})')

        total_time, log_prob = fn3_grad_jit(args.replicates, blens)
        print(f' {args.replicates} gradient evaluations: {total_time}')
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    """Check the tree likelihood under JC69 with a Weibull site model.

    Builds the likelihood from JSON-style descriptions plus pre-built
    models, compares it with a reference log-likelihood, then repeats the
    rate, shape and internal-height tensors three times and expects the same
    value for each repeat.
    """
    # Taxon names are the FASTA headers; the date is the number after the
    # last underscore of the name.
    with open(flu_a_fasta_file) as fasta:
        names = [line[1:].strip() for line in fasta if line.startswith('>')]
    taxa = Taxa(
        'taxa',
        [Taxon(name, {'date': float(name.split('_')[-1])}) for name in names],
    )

    alignment_spec = {
        'id': 'alignment',
        'type': 'torchtree.evolution.alignment.Alignment',
        'datatype': 'nucleotide',
        'file': flu_a_fasta_file,
        'taxa': 'taxa',
    }
    site_pattern = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': alignment_spec,
    }

    subst_model = JC69('jc')
    shape = Parameter(None, torch.tensor([[0.1]]))
    site_model = WeibullSiteModel('site_model', shape, 4)

    with open(flu_a_tree_file) as tree_file:
        newick = tree_file.read().strip()

    # Registry of already-built objects that the JSON descriptions refer to by id.
    dic = {'taxa': taxa}
    tree_json = ReparameterizedTimeTreeModel.json_factory(
        'tree_model', newick, [0.5] * 67, [20.0], 'taxa', keep_branch_lengths=True
    )
    tree_model = ReparameterizedTimeTreeModel.from_json(tree_json, dic)

    rate = Parameter(None, torch.tensor([[0.001]]))
    branch_model = StrictClockModel(None, rate, tree_model)

    dic.update(
        tree_model=tree_model,
        site_model=site_model,
        subst_model=subst_model,
        branch_model=branch_model,
    )

    like_json = {
        'id': 'like',
        'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
        'tree_model': 'tree_model',
        'site_model': 'site_model',
        'site_pattern': site_pattern,
        'substitution_model': 'subst_model',
        'branch_model': 'branch_model',
    }
    like = likelihood.TreeLikelihoodModel.from_json(like_json, dic)

    expected = -4618.2062529058
    assert torch.allclose(torch.tensor([expected]), like())

    # Repeat each parameter tensor three times along the first dimension.
    branch_model._rates.tensor = branch_model._rates.tensor.repeat(3, 1)
    site_model._shape.tensor = site_model._shape.tensor.repeat(3, 1)
    tree_model._internal_heights.tensor = tree_model._internal_heights.tensor.repeat(
        3, 68
    )
    assert torch.allclose(torch.tensor([[expected] * 3]), like())
def build_alignment(file_name, data_type):
    """Read a FASTA file and wrap its sequences in an ``Alignment``.

    Args:
        file_name: path handed to ``read_fasta_sequences``.
        data_type: data type object passed through to ``Alignment``.

    Returns:
        An ``Alignment`` with id ``'alignment'`` whose taxa (id ``'taxa'``)
        are created from the ``taxon`` attribute of each sequence, each with
        an empty attribute dict.
    """
    records = read_fasta_sequences(file_name)
    taxon_set = Taxa('taxa', [Taxon(record.taxon, {}) for record in records])
    return Alignment('alignment', records, taxon_set, data_type)