Example #1
def get_tree_likelihood_computation(
    tree: TensorflowRootedTree,
    input: Alignment,
    dtype: tf.DType,
    bito_instance: tp.Optional[object] = None
) -> tp.Tuple[tp.Callable[[tf.Tensor], tf.Tensor], tf.Tensor]:
    base_unrooted_tree = tree.get_unrooted_tree()
    subst_model = JC()
    if bito_instance is not None:
        from treeflow.acceleration.bito.beagle import (
            phylogenetic_likelihood as beagle_likelihood, )
        log_prob, _ = beagle_likelihood(input.fasta_file,
                                        subst_model,
                                        subst_model.frequencies(dtype=dtype),
                                        inst=bito_instance)
    else:
        compressed_alignment = input.get_compressed_alignment()
        encoded_sequences = compressed_alignment.get_encoded_sequence_tensor(
            tree.taxon_set, dtype=dtype)
        weights = compressed_alignment.get_weights_tensor(dtype=dtype)
        site_count = compressed_alignment.site_count
        frequencies = subst_model.frequencies(dtype=dtype)

        def log_prob(branch_lengths: tf.Tensor):
            tree = base_unrooted_tree.with_branch_lengths(branch_lengths)
            transition_probs_tree = get_transition_probabilities_tree(
                tree, subst_model, dtype=dtype)
            dist = SampleWeighted(
                LeafCTMC(transition_probs_tree, frequencies),
                weights=weights,
                sample_shape=(site_count, ),
            )
            return dist.log_prob(encoded_sequences)

    return log_prob, base_unrooted_tree.branch_lengths
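
A minimal usage sketch for the function above (file paths are placeholders; the treeflow names used elsewhere in these examples, parse_newick, convert_tree_to_tensor, Alignment and DEFAULT_FLOAT_DTYPE_TF, are assumed to be imported, since the listing omits import lines):

# Hypothetical inputs: a Newick tree and a matching FASTA alignment.
tree = convert_tree_to_tensor(parse_newick("tree.newick"))
alignment = Alignment("alignment.fasta")
log_prob, branch_lengths = get_tree_likelihood_computation(
    tree, alignment, dtype=DEFAULT_FLOAT_DTYPE_TF)
# Log-likelihood evaluated at the tree's current branch lengths.
print(log_prob(branch_lengths))
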
Example #2
def test_LeafCTMC_log_prob_over_sites(
    hello_tensor_tree: TensorflowRootedTree,
    hello_alignment: Alignment,
    hky_params,
    hello_hky_log_likelihood,
    function_mode: bool,
):
    """Integration-style test"""

    def log_prob_fn(tree, sequences):
        transition_prob_tree = get_transition_probabilities_tree(
            tree.get_unrooted_tree(), HKY(), **hky_params
        )  # TODO: Better solution for batch dimensions
        dist = Sample(
            LeafCTMC(transition_prob_tree, hky_params["frequencies"]),
            sample_shape=hello_alignment.site_count,
        )
        return dist.log_prob(sequences)

    sequences = hello_alignment.get_encoded_sequence_tensor(hello_tensor_tree.taxon_set)

    if function_mode:
        log_prob_fn = tf.function(log_prob_fn)
    res = log_prob_fn(hello_tensor_tree, sequences)
    assert_allclose(res, hello_hky_log_likelihood)
Example #3
def test_seqio_parse_fasta(hello_fasta_file):
    alignment = Alignment(hello_fasta_file)
    expected_keys = {"mars", "saturn", "jupiter"}
    expected_len = 31
    assert set(alignment.sequence_mapping.keys()) == expected_keys
    for key in expected_keys:
        assert len(alignment.sequence_mapping[key]) == expected_len
Example #4
def test_phylo_likelihood_hky_beast(
    hello_tensor_tree: TensorflowRootedTree,
    hello_alignment: Alignment,
    function_mode: bool,
    hky_params,
    hello_hky_log_likelihood: float,
):
    subst_model = HKY()
    eigen = subst_model.eigen(**hky_params)
    probs = tf.expand_dims(
        get_transition_probabilities_eigen(eigen,
                                           hello_tensor_tree.branch_lengths),
        0)
    encoded_sequences = hello_alignment.get_encoded_sequence_tensor(
        hello_tensor_tree.taxon_set)
    if function_mode:
        func = tf.function(phylogenetic_likelihood)
    else:
        func = phylogenetic_likelihood
    site_partials = func(
        encoded_sequences,
        probs,
        hky_params["frequencies"],
        hello_tensor_tree.topology.postorder_node_indices,
        hello_tensor_tree.topology.node_child_indices,
        batch_shape=tf.shape(encoded_sequences)[:1],
    )
    res = tf.reduce_sum(tf.math.log(site_partials))
    expected = hello_hky_log_likelihood
    assert_allclose(res.numpy(), expected)
def test_sample_weighted(wnv_newick_file, wnv_fasta_file, hky_params):
    alignment = Alignment(wnv_fasta_file)
    numpy_tree = parse_newick(wnv_newick_file)
    tensor_tree = convert_tree_to_tensor(numpy_tree)
    transition_prob_tree = get_transition_probabilities_tree(
        tensor_tree.get_unrooted_tree(), HKY(), **hky_params)

    log_prob_uncompressed = compute_log_prob_uncompressed(
        alignment, transition_prob_tree, hky_params["frequencies"])
    log_prob_compressed = compute_log_prob_compressed(
        alignment, transition_prob_tree, hky_params["frequencies"])
    assert_allclose(log_prob_uncompressed.numpy(), log_prob_compressed.numpy())
def compute_log_prob_uncompressed(
    alignment: Alignment,
    transition_probs_tree: TensorflowUnrootedTree,
    frequencies,
):
    sequences_encoded = alignment.get_encoded_sequence_tensor(
        transition_probs_tree.taxon_set)
    dist = Sample(
        LeafCTMC(transition_probs_tree, frequencies),
        sample_shape=(alignment.site_count, ),
    )
    return dist.log_prob(sequences_encoded)
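
The compressed counterpart called in test_sample_weighted above is not included in this listing; a sketch consistent with the SampleWeighted pattern from Example #1 could look like the following (the actual helper may differ):

def compute_log_prob_compressed(
    alignment: Alignment,
    transition_probs_tree: TensorflowUnrootedTree,
    frequencies,
):
    # Collapse identical site patterns and weight each unique pattern by its
    # count, mirroring the SampleWeighted branch of Example #1.
    compressed_alignment = alignment.get_compressed_alignment()
    sequences_encoded = compressed_alignment.get_encoded_sequence_tensor(
        transition_probs_tree.taxon_set)
    weights = compressed_alignment.get_weights_tensor()
    dist = SampleWeighted(
        LeafCTMC(transition_probs_tree, frequencies),
        weights=weights,
        sample_shape=(compressed_alignment.site_count, ),
    )
    return dist.log_prob(sequences_encoded)
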
Example #7
def test_log_prob_conditioned_hky(hky_params, newick_fasta_file_dated):
    from treeflow.acceleration.bito.beagle import (
        phylogenetic_likelihood as beagle_likelihood, )

    newick_file, fasta_file, dated = newick_fasta_file_dated
    subst_model = HKY()
    tensor_tree = convert_tree_to_tensor(
        parse_newick(newick_file)).get_unrooted_tree()
    alignment = Alignment(fasta_file)
    sequences = alignment.get_encoded_sequence_tensor(tensor_tree.taxon_set)
    treeflow_func = lambda blens: Sample(
        LeafCTMC(
            get_transition_probabilities_tree(
                tensor_tree.with_branch_lengths(blens), subst_model, **
                hky_params),
            hky_params["frequencies"],
        ),
        sample_shape=alignment.site_count,
    ).log_prob(sequences)

    beagle_func, _ = beagle_likelihood(fasta_file,
                                       subst_model,
                                       newick_file=newick_file,
                                       dated=dated,
                                       **hky_params)

    blens = tensor_tree.branch_lengths
    with tf.GradientTape() as tf_t:
        tf_t.watch(blens)
        tf_ll = treeflow_func(blens)
    tf_gradient = tf_t.gradient(tf_ll, blens)

    with tf.GradientTape() as bito_t:
        bito_t.watch(blens)
        bito_ll = beagle_func(blens)
    bito_gradient = bito_t.gradient(bito_ll, blens)

    assert_allclose(tf_ll.numpy(), bito_ll.numpy())
    assert_allclose(tf_gradient.numpy(), bito_gradient.numpy())
Example #8
def treeflow_vi(
    input,
    topology,
    num_steps,
    optimizer,
    substitution_model,
    clock_model,
    tree_prior,
    learning_rate,
    init_values,
    approx_output,
    trace_output,
):
    optimizer = optimizer_classes[optimizer](learning_rate=learning_rate)

    print(f"Parsing topology {topology}")
    tree = convert_tree_to_tensor(parse_newick(topology))

    print(f"Parsing alignment {input}")
    alignment = Alignment(input).get_compressed_alignment()
    encoded_sequences = alignment.get_encoded_sequence_tensor(tree.taxon_set)
    pattern_counts = alignment.get_weights_tensor()

    print(f"Parsing initial values...")
    init_values_dict = (None if init_values is None else {
        key: tf.constant(value, dtype=DEFAULT_FLOAT_DTYPE_TF)
        for key, value in parse_init_values(init_values).items()
    })

    if (clock_model == _FIXED_STRICT and substitution_model == _JC_KEY
            and tree_prior == _CONSTANT_COALESCENT):
        model = get_example_phylo_model(
            taxon_count=tree.taxon_count,
            site_count=alignment.site_count,
            sampling_times=tree.sampling_times,
            pattern_counts=pattern_counts,
            init_values=init_values_dict,
        )
    else:
        raise ValueError(
            "Only example with JC, fixed-rate strict clock, constant coalescent supported for now"
        )
    pinned_model = model.experimental_pin(alignment=encoded_sequences)
    model_names = set(pinned_model._flat_resolve_names())

    init_loc = {
        key: value
        for key, value in init_values_dict.items() if key in model_names
    }
    init_loc["tree"] = tree

    print(f"Running VI for {num_steps} iterations...")
    approx, trace = fit_fixed_topology_variational_approximation(
        model=pinned_model,
        topologies={DEFAULT_TREE_NAME: tree.topology},
        init_loc=init_loc,
        optimizer=optimizer,
        num_steps=num_steps,
    )
    print("Inference complete")
    print("Approx sample:")
    print(tf.nest.map_structure(lambda x: x.numpy(), approx.sample()))

    if approx_output is not None:
        print(f"Saving approximation to {approx_output}...")
        with open(approx_output, "wb") as f:
            pickle.dump(approx, f)  # TODO: Support saving approximation

    if trace_output is not None:
        print(f"Saving trace to {trace_output}...")
        with open(trace_output, "wb") as f:
            pickle.dump(trace, f)

    print("Exiting...")
Example #9
def treeflow_benchmark(input: str, tree: str, replicates: int,
                       output: tp.Optional[io.StringIO], dtype: str,
                       scaler: float, precompile: bool, eager: bool,
                       memory: bool, use_bito: bool):

    print("Parsing input...")
    alignment = Alignment(input)
    numpy_tree = parse_newick(tree, remove_zero_edges=True)
    tf_dtype = DTYPE_MAPPING[dtype]
    tensor_tree = convert_tree_to_tensor(numpy_tree, height_dtype=tf_dtype)
    scaler_tensor = tf.constant(scaler, dtype=tf_dtype)

    computations = [
        "treelikelihood",
        "ratio_transform",
        "ratio_transform_jacobian",
        "constant_coalescent",
    ]
    tasks = ["gradient", "evaluation"]
    jits = [True, False] if eager else [True]

    if output:
        output.write("function,mode,JIT,time,logprob")
        if memory:
            output.write(",max_mem")
        output.write("\n")

    print("Starting benchmark...")
    if use_bito:
        from treeflow.acceleration.bito.instance import get_instance
        dated = not np.allclose(numpy_tree.sampling_times, 0.0)
        bito_instance = get_instance(tree, dated=dated)
    else:
        bito_instance = None
    for (computation, task, jit) in product(computations, tasks, jits):
        print(
            f"Benchmarking {computation} {task}{' in function mode' if jit else ''}..."
        )
        benchmark_args = (
            tensor_tree,
            alignment,
            tf_dtype,
            scaler_tensor,
            computation,
            task,
            jit,
            precompile,
            replicates,
        )
        max_mem: tp.Optional[float]
        if memory:
            from memory_profiler import memory_usage
            max_mem, (time, value) = memory_usage(
                (benchmark, benchmark_args, dict(bito_instance=bito_instance)),
                retval=True,
                max_usage=True,
                max_iterations=1,
            )
        else:
            time, value = benchmark(*benchmark_args,
                                    bito_instance=bito_instance)
            max_mem = None
        if output:
            jit_str = "on" if jit else "off"
            output.write(
                f"{computation},{task},{jit_str},{time},{value if value else ''}"
            )
            if max_mem is not None:
                output.write(f",{max_mem}")
            output.write("\n")
        if max_mem is not None:
            print(f"Max memory usage: {max_mem}")
        print("\n")
    print("Benchmark complete")