Esempio n. 1
0
def data_to_tensor_tree(tree_test_data: TreeTestData) -> TensorflowRootedTree:
    """Build a tensor tree from a ``TreeTestData`` fixture.

    The fixture's ``node_heights``, ``sampling_times`` and ``parent_indices``
    are packed into a ``NumpyRootedTree``, which is then passed through
    ``convert_tree_to_tensor``.
    """
    tree_fields = dict(
        node_heights=tree_test_data.node_heights,
        sampling_times=tree_test_data.sampling_times,
        parent_indices=tree_test_data.parent_indices,
    )
    return convert_tree_to_tensor(NumpyRootedTree(**tree_fields))
def test_BirthDeathContemporarySampling_log_prob():
    """Compare ``BirthDeathContemporarySampling.log_prob`` with a reference value."""
    # Reference value, from BEAST 2 BirthDeathGerhard08ModelTest
    reference_log_prob = 1.2661341104158121
    newick_string = "((((human:0.024003,(chimp:0.010772,bonobo:0.010772):0.013231):0.012035,gorilla:0.036038):0.033087000000000005,orangutan:0.069125):0.030456999999999998,siamang:0.099582);"
    tensor_tree = convert_tree_to_tensor(parse_newick(newick_string))

    prior = BirthDeathContemporarySampling(
        tensor_tree.taxon_count,
        tf.constant(1.0, dtype=DEFAULT_FLOAT_DTYPE_TF),  # birth_diff_rate
        tf.constant(0.5, dtype=DEFAULT_FLOAT_DTYPE_TF),  # relative_death_rate
    )
    log_prob = prior.log_prob(tensor_tree)

    assert_allclose(log_prob.numpy(), reference_log_prob)
def test_sample_weighted(wnv_newick_file, wnv_fasta_file, hky_params):
    """Check that the compressed and uncompressed log-prob computations agree.

    Both computations receive the same alignment, the same transition
    probability tree (built with an ``HKY`` model and ``hky_params``) and the
    same ``hky_params["frequencies"]``.
    """
    sequences = Alignment(wnv_fasta_file)
    unrooted_tree = convert_tree_to_tensor(
        parse_newick(wnv_newick_file)
    ).get_unrooted_tree()
    probs_tree = get_transition_probabilities_tree(
        unrooted_tree, HKY(), **hky_params
    )
    frequencies = hky_params["frequencies"]

    uncompressed = compute_log_prob_uncompressed(sequences, probs_tree, frequencies)
    compressed = compute_log_prob_compressed(sequences, probs_tree, frequencies)

    assert_allclose(uncompressed.numpy(), compressed.numpy())
Esempio n. 4
0
def transition_prob_tree(flat_tree_test_data):
    """Return an unrooted tree carrying constant transition matrices.

    The tree is built from the fixture's ``heights`` and ``parent_indices``.
    Its ``branch_lengths`` field is then replaced by a tensor shaped like the
    branch lengths with two trailing axes of size 5, every entry set to 1/5.
    """
    n_states = 5
    numpy_tree = NumpyRootedTree(
        heights=flat_tree_test_data.heights,
        parent_indices=flat_tree_test_data.parent_indices,
    )
    unrooted = convert_tree_to_tensor(numpy_tree).get_unrooted_tree()

    # One (n_states x n_states) matrix per branch-length entry.
    probs_shape = unrooted.branch_lengths.shape + (n_states, n_states)
    uniform_probs = tf.fill(probs_shape, 1.0 / n_states)

    return TensorflowUnrootedTree(
        branch_lengths=uniform_probs,
        topology=numpy_topology_to_tensor(unrooted.topology),
    )
Esempio n. 5
0
def test_log_prob_conditioned_hky(hky_params, newick_fasta_file_dated):
    """Compare the TensorFlow and bito/BEAGLE HKY likelihoods and gradients.

    Both the log-likelihood value and its gradient with respect to the branch
    lengths must match between the two implementations.
    """
    # Imported locally, as in the rest of this test, so the bito dependency is
    # only needed when this test actually runs.
    from treeflow.acceleration.bito.beagle import (
        phylogenetic_likelihood as beagle_likelihood,
    )

    newick_file, fasta_file, dated = newick_fasta_file_dated
    subst_model = HKY()
    unrooted_tree = convert_tree_to_tensor(
        parse_newick(newick_file)
    ).get_unrooted_tree()
    alignment = Alignment(fasta_file)
    encoded = alignment.get_encoded_sequence_tensor(unrooted_tree.taxon_set)

    def treeflow_log_prob(branch_lengths):
        """Log-probability of the encoded sequences for the given branch lengths."""
        probs_tree = get_transition_probabilities_tree(
            unrooted_tree.with_branch_lengths(branch_lengths),
            subst_model,
            **hky_params,
        )
        site_distribution = Sample(
            LeafCTMC(probs_tree, hky_params["frequencies"]),
            sample_shape=alignment.site_count,
        )
        return site_distribution.log_prob(encoded)

    beagle_log_prob, _ = beagle_likelihood(
        fasta_file,
        subst_model,
        newick_file=newick_file,
        dated=dated,
        **hky_params,
    )

    def value_and_gradient(log_prob_fn, point):
        """Evaluate ``log_prob_fn`` at ``point`` and differentiate w.r.t. it."""
        with tf.GradientTape() as tape:
            tape.watch(point)
            value = log_prob_fn(point)
        return value, tape.gradient(value, point)

    blens = unrooted_tree.branch_lengths
    tf_ll, tf_gradient = value_and_gradient(treeflow_log_prob, blens)
    bito_ll, bito_gradient = value_and_gradient(beagle_log_prob, blens)

    assert_allclose(tf_ll.numpy(), bito_ll.numpy())
    assert_allclose(tf_gradient.numpy(), bito_gradient.numpy())
def tree_from_ratio_test_data(
    ratio_test_data: RatioTestData,
) -> TensorflowRootedTree:
    """Convert a ``RatioTestData`` fixture to a tensor tree via its NumPy tree."""
    numpy_tree = numpy_tree_from_ratio_test_data(ratio_test_data)
    return convert_tree_to_tensor(numpy_tree)
Esempio n. 7
0
def test_get_anchor_heights_tensor(ratio_test_data: RatioTestData):
    """Anchor heights computed from the tensor tree must match the fixture."""
    numpy_tree = numpy_tree_from_ratio_test_data(ratio_test_data)
    tensor_tree = convert_tree_to_tensor(numpy_tree)
    anchor_heights = get_anchor_heights_tensor(
        tensor_tree.topology, tensor_tree.sampling_times
    )
    assert_allclose(anchor_heights.numpy(), ratio_test_data.anchor_heights)
Esempio n. 8
0
def treeflow_vi(
    input,
    topology,
    num_steps,
    optimizer,
    substitution_model,
    clock_model,
    tree_prior,
    learning_rate,
    init_values,
    approx_output,
    trace_output,
):
    """Run fixed-topology variational inference and optionally pickle the results.

    Args:
        input: Alignment source handed to ``Alignment``.
        topology: Newick source handed to ``parse_newick``.
        num_steps: Number of steps passed to the fitting routine.
        optimizer: Key into ``optimizer_classes`` selecting the optimizer class.
        substitution_model: Must equal ``_JC_KEY``.
        clock_model: Must equal ``_FIXED_STRICT``.
        tree_prior: Must equal ``_CONSTANT_COALESCENT``.
        learning_rate: Forwarded to the optimizer constructor.
        init_values: Initial values in the form accepted by
            ``parse_init_values``, or ``None`` for no initial values.
        approx_output: Path the fitted approximation is pickled to, or
            ``None`` to skip saving it.
        trace_output: Path the optimisation trace is pickled to, or ``None``
            to skip saving it.

    Raises:
        ValueError: If the model selectors are not the single supported
            combination.
    """
    optimizer = optimizer_classes[optimizer](learning_rate=learning_rate)

    print(f"Parsing topology {topology}")
    tree = convert_tree_to_tensor(parse_newick(topology))

    print(f"Parsing alignment {input}")
    alignment = Alignment(input).get_compressed_alignment()
    encoded_sequences = alignment.get_encoded_sequence_tensor(tree.taxon_set)
    pattern_counts = alignment.get_weights_tensor()

    print("Parsing initial values...")
    if init_values is None:
        init_values_dict = None
    else:
        init_values_dict = {
            key: tf.constant(value, dtype=DEFAULT_FLOAT_DTYPE_TF)
            for key, value in parse_init_values(init_values).items()
        }

    is_supported_model = (
        clock_model == _FIXED_STRICT
        and substitution_model == _JC_KEY
        and tree_prior == _CONSTANT_COALESCENT
    )
    if not is_supported_model:
        raise ValueError(
            "Only example with JC, fixed-rate strict clock, constant coalescent supported for now"
        )
    model = get_example_phylo_model(
        taxon_count=tree.taxon_count,
        site_count=alignment.site_count,
        sampling_times=tree.sampling_times,
        pattern_counts=pattern_counts,
        init_values=init_values_dict,
    )
    pinned_model = model.experimental_pin(alignment=encoded_sequences)
    model_names = set(pinned_model._flat_resolve_names())

    # init_values_dict is None when no initial values were supplied; iterate an
    # empty mapping in that case instead of failing on ``None.items()``.
    supplied_init_values = {} if init_values_dict is None else init_values_dict
    init_loc = {
        key: value
        for key, value in supplied_init_values.items()
        if key in model_names
    }
    # NOTE(review): the key here is the literal "tree" while ``topologies``
    # below is keyed by DEFAULT_TREE_NAME — confirm the two agree.
    init_loc["tree"] = tree

    print(f"Running VI for {num_steps} iterations...")
    approx, trace = fit_fixed_topology_variational_approximation(
        model=pinned_model,
        topologies={DEFAULT_TREE_NAME: tree.topology},
        init_loc=init_loc,
        optimizer=optimizer,
        num_steps=num_steps,
    )
    print("Inference complete")
    print("Approx sample:")
    print(tf.nest.map_structure(lambda x: x.numpy(), approx.sample()))

    if approx_output is not None:
        print(f"Saving approximation to {approx_output}...")
        with open(approx_output, "wb") as f:
            pickle.dump(approx, f)  # TODO: Support saving approximation

    if trace_output is not None:
        print(f"Saving trace to {trace_output}...")
        with open(trace_output, "wb") as f:
            pickle.dump(trace, f)

    print("Exiting...")
Esempio n. 9
0
def treeflow_benchmark(input: str, tree: str, replicates: int,
                       output: tp.Optional[io.StringIO], dtype: str,
                       scaler: float, precompile: bool, eager: bool,
                       memory: bool, use_bito: bool):
    """Benchmark a fixed set of computations and optionally write CSV rows.

    Every combination of computation, task (gradient/evaluation) and JIT
    setting is passed to ``benchmark``; one CSV row per combination is written
    to ``output`` when it is provided.

    Args:
        input: Alignment source handed to ``Alignment``.
        tree: Newick source handed to ``parse_newick`` (with
            ``remove_zero_edges=True``) and, when ``use_bito`` is set, to the
            bito ``get_instance``.
        replicates: Forwarded to ``benchmark``.
        output: Writable text stream for the CSV results, or a falsy value to
            skip writing.
        dtype: Key into ``DTYPE_MAPPING`` selecting the TensorFlow dtype.
        scaler: Scalar converted to a constant tensor and forwarded to
            ``benchmark``.
        precompile: Forwarded to ``benchmark``.
        eager: When true, each combination is also run with ``jit=False``.
        memory: When true, run each benchmark under
            ``memory_profiler.memory_usage`` and report the peak usage.
        use_bito: When true, build a bito instance and forward it to
            ``benchmark``; otherwise ``None`` is forwarded.
    """
    print("Parsing input...")
    alignment = Alignment(input)
    numpy_tree = parse_newick(tree, remove_zero_edges=True)
    tf_dtype = DTYPE_MAPPING[dtype]
    tensor_tree = convert_tree_to_tensor(numpy_tree, height_dtype=tf_dtype)
    scaler_tensor = tf.constant(scaler, dtype=tf_dtype)

    computations = [
        "treelikelihood",
        "ratio_transform",
        "ratio_transform_jacobian",
        "constant_coalescent",
    ]
    tasks = ["gradient", "evaluation"]
    jits = [True, False] if eager else [True]

    if output:
        output.write("function,mode,JIT,time,logprob")
        if memory:
            output.write(",max_mem")
        output.write("\n")

    print("Starting benchmark...")
    if use_bito:
        from treeflow.acceleration.bito.instance import get_instance
        # The tree counts as dated when any sampling time is not (close to) zero.
        dated = not np.allclose(numpy_tree.sampling_times, 0.0)
        bito_instance = get_instance(tree, dated=dated)
    else:
        bito_instance = None

    for (computation, task, jit) in product(computations, tasks, jits):
        print(
            f"Benchmarking {computation} {task}{' in function mode' if jit else ''}..."
        )
        benchmark_args = (
            tensor_tree,
            alignment,
            tf_dtype,
            scaler_tensor,
            computation,
            task,
            jit,
            precompile,
            replicates,
        )
        max_mem: tp.Optional[float]
        if memory:
            from memory_profiler import memory_usage
            max_mem, (elapsed, value) = memory_usage(
                (benchmark, benchmark_args, dict(bito_instance=bito_instance)),
                retval=True,
                max_usage=True,
                max_iterations=1,
            )
        else:
            elapsed, value = benchmark(*benchmark_args,
                                       bito_instance=bito_instance)
            max_mem = None
        if output:
            jit_str = "on" if jit else "off"
            # Only a missing value is written as blank; testing truthiness
            # would also blank out a legitimate result of 0.0.
            value_str = "" if value is None else value
            output.write(
                f"{computation},{task},{jit_str},{elapsed},{value_str}"
            )
            if max_mem is not None:
                output.write(f",{max_mem}")
            output.write("\n")
        if max_mem is not None:
            print(f"Max memory usage: {max_mem}")
        print("\n")
    print("Benchmark complete")
def get_treeflow_forward_func(numpy_tree, tensor_constant):
    """Return the ``forward`` method of a ``NodeHeightRatioBijector`` for the tree.

    ``tensor_constant`` is applied to the anchor heights computed from
    ``numpy_tree`` before they are handed to the bijector.
    """
    anchors = tensor_constant(get_anchor_heights(numpy_tree))
    topology = convert_tree_to_tensor(numpy_tree).topology
    bijector = NodeHeightRatioBijector(topology, anchors)
    return bijector.forward