def test_LeafCTMC_event_shape_tensor(transition_prob_tree, hky_params):
    """Check that LeafCTMC's dynamic event shape is (taxon_count, state_count).

    The state count is taken from the trailing dimension of the tree's
    ``branch_lengths`` tensor (presumably per-branch transition matrices for
    this fixture -- TODO confirm).
    """
    expected_states = transition_prob_tree.branch_lengths.numpy().shape[-1]
    leaf_ctmc = LeafCTMC(transition_prob_tree, hky_params["frequencies"])
    observed_shape = tuple(leaf_ctmc.event_shape_tensor().numpy())
    expected_shape = (transition_prob_tree.taxon_count, expected_states)
    assert observed_shape == expected_shape
def log_prob_fn(tree, sequences):
    """Return the HKY log-likelihood of ``sequences`` on ``tree``.

    ``tree`` is converted to its unrooted form before transition
    probabilities are computed. Relies on ``hky_params`` and
    ``hello_alignment`` from the enclosing scope (not visible here).
    """
    unrooted = tree.get_unrooted_tree()
    probs_tree = get_transition_probabilities_tree(unrooted, HKY(), **hky_params)
    per_site = LeafCTMC(probs_tree, hky_params["frequencies"])
    # TODO: Better solution for batch dimensions
    all_sites = Sample(per_site, sample_shape=hello_alignment.site_count)
    return all_sites.log_prob(sequences)
def log_prob(branch_lengths: tf.Tensor):
    """Return the weighted log-likelihood of the encoded sequences.

    ``branch_lengths`` replace those of ``base_unrooted_tree``. Sites are
    combined with ``SampleWeighted`` using ``weights``. Also reads
    ``subst_model``, ``dtype``, ``frequencies``, ``site_count`` and
    ``encoded_sequences`` from the enclosing scope.
    """
    transition_probs_tree = get_transition_probabilities_tree(
        base_unrooted_tree.with_branch_lengths(branch_lengths),
        subst_model,
        dtype=dtype,
    )
    return SampleWeighted(
        LeafCTMC(transition_probs_tree, frequencies),
        weights=weights,
        sample_shape=(site_count,),
    ).log_prob(encoded_sequences)
def compute_log_prob_uncompressed(
    alignment: Alignment,
    transition_probs_tree: TensorflowUnrootedTree,
    frequencies,
):
    """Return the log-likelihood of ``alignment`` with every site modelled individually.

    Sequences are encoded in the order of ``transition_probs_tree.taxon_set``.
    Each of the ``alignment.site_count`` sites is an unweighted draw from
    ``LeafCTMC`` via ``Sample``.
    """
    encoded = alignment.get_encoded_sequence_tensor(transition_probs_tree.taxon_set)
    per_site = LeafCTMC(transition_probs_tree, frequencies)
    all_sites = Sample(per_site, sample_shape=(alignment.site_count,))
    return all_sites.log_prob(encoded)
def test_log_prob_conditioned_hky(hky_params, newick_fasta_file_dated):
    """Check that TreeFlow and bito/BEAGLE agree on the HKY likelihood and gradient.

    Both implementations are evaluated at the tree's own branch lengths.
    The log-likelihoods and their gradients with respect to branch lengths
    are compared with ``assert_allclose`` using its default tolerances.
    """
    from treeflow.acceleration.bito.beagle import (
        phylogenetic_likelihood as beagle_likelihood,
    )

    newick_file, fasta_file, dated = newick_fasta_file_dated
    subst_model = HKY()
    unrooted_tree = convert_tree_to_tensor(
        parse_newick(newick_file)
    ).get_unrooted_tree()
    alignment = Alignment(fasta_file)
    encoded_sequences = alignment.get_encoded_sequence_tensor(
        unrooted_tree.taxon_set
    )

    def treeflow_log_prob(branch_lengths):
        transition_probs = get_transition_probabilities_tree(
            unrooted_tree.with_branch_lengths(branch_lengths),
            subst_model,
            **hky_params,
        )
        site_dist = Sample(
            LeafCTMC(transition_probs, hky_params["frequencies"]),
            sample_shape=alignment.site_count,
        )
        return site_dist.log_prob(encoded_sequences)

    beagle_log_prob, _ = beagle_likelihood(
        fasta_file,
        subst_model,
        newick_file=newick_file,
        dated=dated,
        **hky_params,
    )

    def value_and_gradient(func, x):
        # One fresh tape per implementation. ``x`` is watched explicitly so
        # gradients are recorded even if it is not a tf.Variable.
        with tf.GradientTape() as tape:
            tape.watch(x)
            value = func(x)
        return value, tape.gradient(value, x)

    branch_lengths = unrooted_tree.branch_lengths
    tf_ll, tf_gradient = value_and_gradient(treeflow_log_prob, branch_lengths)
    bito_ll, bito_gradient = value_and_gradient(beagle_log_prob, branch_lengths)

    assert_allclose(tf_ll.numpy(), bito_ll.numpy())
    assert_allclose(tf_gradient.numpy(), bito_gradient.numpy())
def alignment_func(
    tree: TensorflowRootedTree,
    rates: tf.Tensor,
    site_count: int,
    weights: tp.Optional[tf.Tensor] = None,
):
    """Build the Jukes-Cantor site distribution for ``tree``.

    The branch lengths of ``tree``'s unrooted form are multiplied by
    ``rates`` before transition probabilities are computed under ``JC()``.
    Returns a ``tfd.Sample`` over ``site_count`` sites. When ``weights`` is
    given, returns a ``SampleWeighted`` carrying those weights instead.
    """
    time_tree = tree.get_unrooted_tree()
    scaled_tree = time_tree.with_branch_lengths(time_tree.branch_lengths * rates)
    jc_model = JC()
    base_freqs = jc_model.frequencies(dtype=rates.dtype)
    probs_tree = get_transition_probabilities_tree(
        scaled_tree,
        jc_model,
        frequencies=base_freqs,
    )
    per_site = LeafCTMC(probs_tree, frequencies=base_freqs)
    shape = (site_count,)
    if weights is not None:
        return SampleWeighted(per_site, weights=weights, sample_shape=shape)
    return tfd.Sample(per_site, sample_shape=shape)
def event_shape_function(transition_prob_tree, frequencies):
    """Return the static ``event_shape`` of a LeafCTMC built from the arguments."""
    return LeafCTMC(transition_prob_tree, frequencies).event_shape