Exemple #1
0
def test_LeafCTMC_event_shape_tensor(transition_prob_tree, hky_params):
    """Check ``LeafCTMC.event_shape_tensor`` against ``(taxon_count, state_count)``.

    The state count is read from the last axis of the fixture tree's
    ``branch_lengths`` tensor.
    """
    branch_array = transition_prob_tree.branch_lengths.numpy()
    n_states = branch_array.shape[-1]

    leaf_ctmc = LeafCTMC(transition_prob_tree, hky_params["frequencies"])
    observed_shape = tuple(leaf_ctmc.event_shape_tensor().numpy())

    expected_shape = (transition_prob_tree.taxon_count, n_states)
    assert observed_shape == expected_shape
Exemple #2
0
 def log_prob_fn(tree, sequences):
     """Return the log-probability of ``sequences`` given ``tree``.

     The tree is converted to its unrooted form, HKY transition
     probabilities are attached, and a ``LeafCTMC`` is wrapped in ``Sample``
     with ``hello_alignment.site_count`` as the sample shape.
     ``hky_params`` and ``hello_alignment`` are free variables resolved from
     the enclosing scope.
     """
     unrooted = tree.get_unrooted_tree()
     # TODO: Better solution for batch dimensions
     probs_tree = get_transition_probabilities_tree(
         unrooted, HKY(), **hky_params
     )
     leaf_ctmc = LeafCTMC(probs_tree, hky_params["frequencies"])
     site_dist = Sample(leaf_ctmc, sample_shape=hello_alignment.site_count)
     return site_dist.log_prob(sequences)
Exemple #3
0
 def log_prob(branch_lengths: tf.Tensor):
     """Return the weighted log-probability of ``encoded_sequences``.

     ``branch_lengths`` replaces the branch lengths of
     ``base_unrooted_tree``; the resulting tree is used to build a
     ``LeafCTMC`` wrapped in ``SampleWeighted``. All other names
     (``subst_model``, ``dtype``, ``frequencies``, ``weights``,
     ``site_count``, ``encoded_sequences``) are free variables resolved from
     the enclosing scope.
     """
     tree_with_lengths = base_unrooted_tree.with_branch_lengths(branch_lengths)
     probs_tree = get_transition_probabilities_tree(
         tree_with_lengths, subst_model, dtype=dtype
     )
     leaf_ctmc = LeafCTMC(probs_tree, frequencies)
     weighted_sites = SampleWeighted(
         leaf_ctmc,
         weights=weights,
         sample_shape=(site_count,),
     )
     return weighted_sites.log_prob(encoded_sequences)
def compute_log_prob_uncompressed(
    alignment: Alignment,
    transition_probs_tree: TensorflowUnrootedTree,
    frequencies,
):
    """Return the log-probability of an alignment's encoded sequences.

    Args:
        alignment: Source of the encoded sequence tensor and of
            ``site_count``.
        transition_probs_tree: Tree passed to ``LeafCTMC``; its
            ``taxon_set`` is used when encoding the sequences.
        frequencies: Second argument passed to ``LeafCTMC``.

    Returns:
        The result of ``log_prob`` on a ``Sample`` of ``LeafCTMC`` with
        sample shape ``(alignment.site_count,)``.
    """
    taxa = transition_probs_tree.taxon_set
    encoded = alignment.get_encoded_sequence_tensor(taxa)

    leaf_ctmc = LeafCTMC(transition_probs_tree, frequencies)
    per_site = Sample(leaf_ctmc, sample_shape=(alignment.site_count,))
    return per_site.log_prob(encoded)
Exemple #5
0
def test_log_prob_conditioned_hky(hky_params, newick_fasta_file_dated):
    """Compare the treeflow HKY log-likelihood and its gradient with bito/BEAGLE.

    Both implementations are evaluated at the parsed tree's branch lengths;
    the values and the gradients with respect to those branch lengths must
    agree under ``assert_allclose``.
    """
    # Imported locally, as in the original test body.
    from treeflow.acceleration.bito.beagle import (
        phylogenetic_likelihood as beagle_likelihood,
    )

    newick_file, fasta_file, dated = newick_fasta_file_dated
    subst_model = HKY()
    rooted_tree = convert_tree_to_tensor(parse_newick(newick_file))
    tensor_tree = rooted_tree.get_unrooted_tree()
    alignment = Alignment(fasta_file)
    sequences = alignment.get_encoded_sequence_tensor(tensor_tree.taxon_set)

    def treeflow_func(blens):
        # Rebuild the whole distribution so the result depends on ``blens``.
        probs_tree = get_transition_probabilities_tree(
            tensor_tree.with_branch_lengths(blens), subst_model, **hky_params
        )
        leaf_ctmc = LeafCTMC(probs_tree, hky_params["frequencies"])
        site_dist = Sample(leaf_ctmc, sample_shape=alignment.site_count)
        return site_dist.log_prob(sequences)

    beagle_func, _ = beagle_likelihood(
        fasta_file,
        subst_model,
        newick_file=newick_file,
        dated=dated,
        **hky_params,
    )

    def value_and_gradient(func, point):
        # Evaluate ``func`` at ``point`` under a tape that watches ``point``.
        with tf.GradientTape() as tape:
            tape.watch(point)
            value = func(point)
        return value, tape.gradient(value, point)

    blens = tensor_tree.branch_lengths
    tf_ll, tf_gradient = value_and_gradient(treeflow_func, blens)
    bito_ll, bito_gradient = value_and_gradient(beagle_func, blens)

    assert_allclose(tf_ll.numpy(), bito_ll.numpy())
    assert_allclose(tf_gradient.numpy(), bito_gradient.numpy())
Exemple #6
0
def alignment_func(
    tree: TensorflowRootedTree,
    rates: tf.Tensor,
    site_count: int,
    weights: tp.Optional[tf.Tensor] = None,
):
    """Build a per-site sequence distribution under a JC substitution model.

    Args:
        tree: Rooted tree; its unrooted form supplies the branch lengths.
        rates: Multiplied with the unrooted tree's branch lengths to give
            the branch lengths used for the transition probabilities.
        site_count: Used as the one-element sample shape.
        weights: Optional weights. When given, the result is a
            ``SampleWeighted``; otherwise a ``tfd.Sample``.

    Returns:
        A ``LeafCTMC`` wrapped in ``tfd.Sample`` or ``SampleWeighted`` with
        sample shape ``(site_count,)``.
    """
    time_tree = tree.get_unrooted_tree()
    scaled_lengths = time_tree.branch_lengths * rates
    distance_tree = time_tree.with_branch_lengths(scaled_lengths)

    jc_model = JC()
    # Frequencies take the dtype of ``rates``.
    freqs = jc_model.frequencies(dtype=rates.dtype)
    probs_tree = get_transition_probabilities_tree(
        distance_tree,
        jc_model,
        frequencies=freqs,
    )
    site_dist = LeafCTMC(probs_tree, frequencies=freqs)
    shape = (site_count,)

    if weights is not None:
        return SampleWeighted(site_dist, weights=weights, sample_shape=shape)
    return tfd.Sample(site_dist, sample_shape=shape)
Exemple #7
0
 def event_shape_function(transition_prob_tree, frequencies):
     """Return the ``event_shape`` attribute of a ``LeafCTMC`` built from the arguments."""
     return LeafCTMC(transition_prob_tree, frequencies).event_shape