Example no. 1
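A unit test of Taxa.from_json: two Taxon entries carrying a 'date' attribute are deserialized from a JSON-style dictionary, and the test checks the collection length, the second taxon's id, and attribute access by key.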
from torchtree.evolution.taxa import Taxa


def test_taxa():
    taxa = {
        "id": "taxa",
        "type": "torchtree.evolution.taxa.Taxa",
        "taxa": [
            {
                "id": "A_Belgium_2_1981",
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {
                    "date": 1981
                },
            },
            {
                "id": "A_ChristHospital_231_1982",
                "type": "torchtree.evolution.taxa.Taxon",
                "attributes": {
                    "date": 1982
                },
            },
        ],
    }
    dic = {}
    taxa = Taxa.from_json(taxa, dic)
    assert len(taxa) == 2
    assert taxa[1].id == 'A_ChristHospital_231_1982'
    assert taxa[1]['date'] == 1982
Example no. 2
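A test of site-pattern compression. The four three-character sequences contain only two unique alignment columns (AAAA twice, GCCT once), so the pattern weights are [[2.0, 1.0]] and each tip partial has shape [4 states, 2 patterns].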
# Imports inferred from the type strings used elsewhere in these examples;
# the Sequence and NucleotideDataType paths follow torchtree's module layout.
import torch
from torchtree.evolution.alignment import Alignment, Sequence
from torchtree.evolution.datatype import NucleotideDataType
from torchtree.evolution.site_pattern import SitePattern
from torchtree.evolution.taxa import Taxa, Taxon


def test_site_pattern():
    taxa = Taxa(None, [Taxon(taxon, {}) for taxon in 'ABCD'])
    sequences = [
        Sequence(taxon, seq)
        for taxon, seq in zip('ABCD', ['AAG', 'AAC', 'AAC', 'AAT'])
    ]
    alignment = Alignment(None, sequences, taxa, NucleotideDataType(None))

    site_pattern = SitePattern(None, alignment)
    partials, weights = site_pattern.compute_tips_partials()
    assert torch.all(weights == torch.tensor([[2.0, 1.0]]))

    assert partials[0].shape == torch.Size([4, 2])
Example no. 3
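A fixture helper for a tiny dataset: it reads the tree and alignment, collects the branch lengths of all non-root nodes sorted by node index, records [parent, child, child] index triples in postorder for the likelihood traversal, and compresses the alignment into tip partials, padding the list with None placeholders for the internal-node entries.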
def _prepare_tiny(tiny_newick_file, tiny_fasta_file):
    tree, dna = read_tree_and_alignment(tiny_newick_file, tiny_fasta_file,
                                        False, False)
    branch_lengths = torch.tensor([
        float(node.edge_length)
        for node in sorted(list(tree.postorder_node_iter())[:-1],
                           key=lambda x: x.index)
    ])
    indices = []
    for node in tree.postorder_internal_node_iter():
        indices.append([node.index] +
                       [child.index for child in node.child_nodes()])

    sequences = []
    taxa = []
    for taxon, seq in dna.items():
        sequences.append(Sequence(taxon.label, str(seq)))
        taxa.append(Taxon(taxon.label, None))

    partials, weights_tensor = compress_alignment(
        Alignment(None, sequences, Taxa(None, taxa), NucleotideDataType(None)))
    partials.extend([None] * (len(dna) - 1))
    return partials, weights_tensor, indices, branch_lengths
Example no. 4
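A benchmark of the constant-population coalescent log-density and its gradient, with and without torch.jit.script. The @benchmark decorator and the log_prob/log_prob_squashed functions are defined elsewhere in the benchmark script; judging by the call sites, the decorator runs the wrapped function for a given number of replicates and returns (total_time, result). Under --all, sampling times are deduplicated with torch.unique so the squashed variant can weight repeated times by their counts.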
def constant_coalescent(args):
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0])
    )
    tree_model = TimeTreeModel("tree", tree, Taxa('taxa', taxa), ratios_root_height)
    tree_model._internal_heights.tensor = heights_from_branch_lengths(tree)
    pop_size = torch.tensor([4.0])

    print('JIT off')

    @benchmark
    def fn(tree_model, pop_size):
        return ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)

    @benchmark
    def fn_grad(tree_model, pop_size):
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    total_time, log_p = fn(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,off,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,off,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.debug:
        tree_model.heights_need_update = True
        log_p = ConstantCoalescent(pop_size).log_prob(tree_model.node_heights)
        log_p.backward()
        print('gradient ratios: ', ratios_root_height.grad)
        print('gradient pop size: ', pop_size.grad)
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()

    print('JIT on')
    log_prob_script = torch.jit.script(log_prob)

    @benchmark
    def fn_jit(tree_model, pop_size):
        return log_prob_script(tree_model.node_heights, pop_size)

    @benchmark
    def fn_grad_jit(tree_model, pop_size):
        log_p = log_prob_script(tree_model.node_heights, pop_size)
        log_p.backward()
        ratios_root_height.tensor.grad.data.zero_()
        pop_size.grad.data.zero_()
        return log_p

    ratios_root_height.requires_grad = False
    pop_size.requires_grad_(False)
    total_time, log_p = fn_jit(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} evaluations: {total_time} {log_p}')

    ratios_root_height.requires_grad = True
    pop_size.requires_grad_(True)
    grad_total_time, grad_log_p = fn_grad_jit(args.replicates, tree_model, pop_size)
    print(f'  {args.replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(
            f"coalescent,evaluation,on,{total_time},{log_p.squeeze().item()}\n"
        )
        args.output.write(
            f"coalescent,gradient,on,{grad_total_time},{grad_log_p.squeeze().item()}\n"
        )

    if args.all:
        print('make sampling times unique and count them:')

        @benchmark
        def fn3(tree_model, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            return log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )

        @benchmark
        def fn3_grad(tree_model, ratios_root_height, pop_size):
            tree_model.heights_need_update = True
            node_heights = torch.cat(
                (x, tree_model.node_heights[..., tree_model.taxa_count :])
            )
            log_p = log_prob_squashed(
                pop_size, node_heights, counts, tree_model.taxa_count
            )
            log_p.backward()
            ratios_root_height.tensor.grad.data.zero_()
            pop_size.grad.data.zero_()
            return log_p

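        # x and counts are closed over by fn3/fn3_grad above; they must be
        # defined before those benchmarks run below.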
        x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
        counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1))))

        with torch.no_grad():
            total_time, log_p = fn3(args.replicates, tree_model, pop_size)
        print(f'  {args.replicates} evaluations: {total_time} ({log_p})')

        total_time, log_p = fn3_grad(
            args.replicates, tree_model, ratios_root_height, pop_size
        )
        print(f'  {args.replicates} gradient evaluations: {total_time}')
Example no. 5
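A benchmark of the node-height ratio transform, comparing the tree model's own transform against two standalone implementations, transform and transform2 (defined elsewhere in the script), each with JIT off and on; transform2 takes the index buffer as a Python list instead of a tensor.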
def ratio_transform(args):
    replicates = args.replicates
    tree = read_tree(args.tree, True, True)
    taxa_count = len(tree.taxon_namespace)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        return tree_model.transform(
            ratios_root_height,
        )

    @benchmark
    def fn_grad(ratios_root_height):
        heights = tree_model.transform(
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    total_time, heights = fn(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    grad_total_time, heights = fn_grad(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {grad_total_time}')

    if args.output:
        args.output.write(f"ratio_transform,evaluation,off,{total_time},\n")
        args.output.write(f"ratio_transform,gradient,off,{grad_total_time},\n")

    print('  JIT off')

    @benchmark
    def fn2(ratios_root_height):
        return transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad(ratios_root_height):
        heights = transform(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('  JIT on')
    transform_script = torch.jit.script(transform)

    @benchmark
    def fn2_jit(ratios_root_height):
        return transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn2_grad_jit(ratios_root_height):
        heights = transform_script(
            tree_model.transform._forward_indices,
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn2_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn2_grad_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT off')

    @benchmark
    def fn3(ratios_root_height):
        return transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad(ratios_root_height):
        heights = transform2(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')

    print('ratio_transform v2 JIT on')
    transform2_script = torch.jit.script(transform2)

    @benchmark
    def fn3_jit(ratios_root_height):
        return transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )

    @benchmark
    def fn3_grad_jit(ratios_root_height):
        heights = transform2_script(
            tree_model.transform._forward_indices.tolist(),
            tree_model.transform._bounds,
            ratios_root_height,
        )
        heights.backward(torch.ones_like(ratios_root_height))
        ratios_root_height.grad.data.zero_()
        return heights

    ratios_root_height.requires_grad = False
    total_time, heights = fn3_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} evaluations: {total_time}')

    ratios_root_height.requires_grad = True
    total_time, heights = fn3_grad_jit(args.replicates, ratios_root_height.tensor)
    print(f'  {replicates} gradient evaluations: {total_time}')
Example no. 6
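A benchmark of the ratio transform's log |det Jacobian| term and its gradient, evaluated at the parameter values recovered from the input tree's branch lengths.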
def ratio_transform_jacobian(args):
    tree = read_tree(args.tree, True, True)
    taxa = []
    for node in tree.leaf_node_iter():
        taxa.append(Taxon(node.label, {'date': node.date}))
    taxa_count = len(taxa)
    ratios_root_height = Parameter(
        "internal_heights", torch.tensor([0.5] * (taxa_count - 1) + [20])
    )
    tree_model = ReparameterizedTimeTreeModel(
        "tree", tree, Taxa('taxa', taxa), ratios_root_height
    )

    ratios_root_height.tensor = tree_model.transform.inv(
        heights_from_branch_lengths(tree)
    )

    @benchmark
    def fn(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        return tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )

    @benchmark
    def fn_grad(ratios_root_height):
        internal_heights = tree_model.transform(ratios_root_height)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height, internal_heights
        )
        log_det_jac.backward()
        ratios_root_height.grad.data.zero_()
        return log_det_jac

    print('  JIT off')
    total_time, log_det_jac = fn(args.replicates, ratios_root_height.tensor)
    print(f'  {args.replicates} evaluations: {total_time} ({log_det_jac})')

    ratios_root_height.requires_grad = True
    grad_total_time, grad_log_det_jac = fn_grad(
        args.replicates, ratios_root_height.tensor
    )
    print(
        f'  {args.replicates} gradient evaluations: {grad_total_time}'
        f' ({grad_log_det_jac})'
    )

    if args.output:
        args.output.write(
            f"ratio_transform_jacobian,evaluation,off,{total_time},"
            f"{log_det_jac.squeeze().item()}\n"
        )
        args.output.write(
            f"ratio_transform_jacobian,gradient,off,{grad_total_time},"
            f"{grad_log_det_jac.squeeze().item()}\n"
        )

    if args.debug:
        internal_heights = tree_model.transform(ratios_root_height.tensor)
        log_det_jac = tree_model.transform.log_abs_det_jacobian(
            ratios_root_height.tensor, internal_heights
        )
        log_det_jac.backward()
        print(ratios_root_height.grad)
Example no. 7
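A benchmark of the unrooted-tree likelihood. The v1 path uses a 'safe' likelihood with a dtype-dependent rescaling threshold; under --all with a JC69 model it also times split- and concatenation-based variants (v2 and v3), each with JIT off and on. As in Example no. 4, p_t and the calculate_treelikelihood_* functions are defined in the surrounding script.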
def unrooted_treelikelihood(args, subst_model):
    tree, dna = read_tree_and_alignment(args.tree, args.input, True, True)
    branch_lengths = torch.tensor(
        [
            float(node.edge_length) * args.scaler
            for node in sorted(
                list(tree.postorder_node_iter())[:-1], key=lambda x: x.index
            )
        ],
    )
    branch_lengths = torch.clamp(branch_lengths, min=1.0e-6)
    indices = []
    for node in tree.postorder_internal_node_iter():
        indices.append([node.index] + [child.index for child in node.child_nodes()])

    sequences = []
    taxa = []
    for taxon, seq in dna.items():
        sequences.append(Sequence(taxon.label, str(seq)))
        taxa.append(Taxon(taxon.label, None))

    partials, weights_tensor = compress_alignment(
        Alignment(None, sequences, Taxa(None, taxa), NucleotideDataType('nuc'))
    )
    partials.extend([None] * (len(dna) - 1))
    freqs = subst_model.frequencies
    proportions = torch.tensor([[[1.0]]])
    threshold = 1.0e-20 if args.dtype == 'float32' else 1.0e-40

    print('treelikelihood v1')

    @benchmark
    def fn_safe(bls):
        mats = subst_model.p_t(bls)
        return calculate_treelikelihood_discrete_safe(
            partials, weights_tensor, indices, mats, freqs, proportions, threshold
        )

    @benchmark
    def fn_safe_grad(bls):
        mats = subst_model.p_t(bls)
        log_prob = calculate_treelikelihood_discrete_safe(
            partials, weights_tensor, indices, mats, freqs, proportions, threshold
        )
        log_prob.backward()
        return log_prob

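    # Reshape to [1, branch_count, 1]: add batch and trailing singleton
    # dimensions before calling subst_model.p_t.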
    blens = branch_lengths.unsqueeze(0).unsqueeze(-1)

    total_time, log_prob = fn_safe(args.replicates, blens)
    print(f'  {args.replicates} evaluations: {total_time} ({log_prob})')

    blens.requires_grad = True
    for p in subst_model.parameters():
        p.requires_grad = True

    grad_total_time, grad_log_prob = fn_safe_grad(args.replicates, blens)
    print(
        f'  {args.replicates} gradient evaluations: {grad_total_time} ({grad_log_prob})'
    )

    if args.output:
        name = '' if isinstance(subst_model, JC69) else type(subst_model).__name__
        args.output.write(
            f"treelikelihood{name},evaluation,off,{total_time},"
            f"{log_prob.squeeze().item()}\n"
        )
        args.output.write(
            f"treelikelihood{name},gradient,off,{grad_total_time},"
            f"{grad_log_prob.squeeze().item()}\n"
        )

    if args.all and isinstance(subst_model, JC69):
        tip_partials = torch.stack(partials[: len(sequences)])
        indices = torch.tensor(indices)

        print('treelikelihood v2 JIT off')

        @benchmark
        def fn2(bls):
            mats = p_t(bls)
            return calculate_treelikelihood_discrete_split(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn2_grad(bls):
            mats = p_t(bls)
            log_prob = calculate_treelikelihood_discrete_split(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn2(args.replicates, blens)
        print(f'  {args.replicates} evaluations: {total_time} ({log_prob})')

        total_time, log_prob = fn2_grad(args.replicates, blens)
        print(f'  {args.replicates} gradient evaluations: {total_time}')

        print('treelikelihood v2 JIT on')
        like2_jit = torch.jit.script(calculate_treelikelihood_discrete_split)

        @benchmark
        def fn2_jit(bls):
            mats = p_t(bls)
            return like2_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn2_grad_jit(bls):
            mats = p_t(bls)
            log_prob = like2_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn2_jit(args.replicates, blens)
        print(f'  {args.replicates} evaluations: {total_time}')

        total_time, log_prob = fn2_grad_jit(args.replicates, blens)
        print(f'  {args.replicates} gradient evaluations: {total_time}')

        print('treelikelihood v3 JIT off')

        @benchmark
        def fn3(bls):
            mats = p_t(bls)
            return calculate_treelikelihood_discrete_cat(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn3_grad(bls):
            mats = p_t(bls)
            log_prob = calculate_treelikelihood_discrete_cat(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn3(args.replicates, blens)
        print(f'  {args.replicates} evaluations: {total_time} ({log_prob})')

        total_time, log_prob = fn3_grad(args.replicates, blens)
        print(f'  {args.replicates} gradient evaluations: {total_time}')

        print('treelikelihood v3 JIT on')
        like3_jit = torch.jit.script(calculate_treelikelihood_discrete_cat)

        @benchmark
        def fn3_jit(bls):
            mats = p_t(bls)
            return like3_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )

        @benchmark
        def fn3_grad_jit(bls):
            mats = p_t(bls)
            log_prob = like3_jit(
                tip_partials, weights_tensor, indices, mats, freqs, proportions
            )
            log_prob.backward()
            return log_prob

        with torch.no_grad():
            total_time, log_prob = fn3_jit(args.replicates, blens)
        print(f'  {args.replicates} evaluations: {total_time} ({log_prob})')

        total_time, log_prob = fn3_grad_jit(args.replicates, blens)
        print(f'  {args.replicates} gradient evaluations: {total_time}')
Example no. 8
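An end-to-end likelihood test (imports omitted in this listing): taxa and their dates are parsed from FASTA headers, a ReparameterizedTimeTreeModel is built through its JSON factory with branch lengths preserved, and a Weibull site model, JC69 substitution model, and strict clock are attached. The final asserts check the log-likelihood value and that batching every parameter (batch size 3) reproduces it.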
def test_treelikelihood_weibull(flu_a_tree_file, flu_a_fasta_file):
    taxa_list = []
    with open(flu_a_fasta_file) as fp:
        for line in fp:
            if line.startswith('>'):
                taxon = line[1:].strip()
                date = float(taxon.split('_')[-1])
                taxa_list.append(Taxon(taxon, {'date': date}))
    taxa = Taxa('taxa', taxa_list)

    site_pattern = {
        'id': 'sp',
        'type': 'torchtree.evolution.site_pattern.SitePattern',
        'alignment': {
            "id": "alignment",
            "type": "torchtree.evolution.alignment.Alignment",
            'datatype': 'nucleotide',
            'file': flu_a_fasta_file,
            'taxa': 'taxa',
        },
    }
    subst_model = JC69('jc')
    site_model = WeibullSiteModel('site_model',
                                  Parameter(None, torch.tensor([[0.1]])), 4)
    ratios = [0.5] * 67
    root_height = [20.0]
    with open(flu_a_tree_file) as fp:
        newick = fp.read().strip()

    dic = {'taxa': taxa}
    tree_model = ReparameterizedTimeTreeModel.from_json(
        ReparameterizedTimeTreeModel.json_factory(
            'tree_model', newick, ratios, root_height, 'taxa',
            **{'keep_branch_lengths': True}),
        dic,
    )
    branch_model = StrictClockModel(None,
                                    Parameter(None, torch.tensor([[0.001]])),
                                    tree_model)
    dic['tree_model'] = tree_model
    dic['site_model'] = site_model
    dic['subst_model'] = subst_model
    dic['branch_model'] = branch_model

    like = likelihood.TreeLikelihoodModel.from_json(
        {
            'id': 'like',
            'type': 'torchtree.tree_likelihood.TreeLikelihoodModel',
            'tree_model': 'tree_model',
            'site_model': 'site_model',
            'site_pattern': site_pattern,
            'substitution_model': 'subst_model',
            'branch_model': 'branch_model',
        },
        dic,
    )
    assert torch.allclose(torch.tensor([-4618.2062529058]), like())

    branch_model._rates.tensor = branch_model._rates.tensor.repeat(3, 1)
    site_model._shape.tensor = site_model._shape.tensor.repeat(3, 1)
    tree_model._internal_heights.tensor = tree_model._internal_heights.tensor.repeat(
        3, 68)
    assert torch.allclose(torch.tensor([[-4618.2062529058] * 3]), like())
Example no. 9
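A small factory that reads FASTA sequences and wraps them, together with a freshly built Taxa collection, in an Alignment.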
def build_alignment(file_name, data_type):
    sequences = read_fasta_sequences(file_name)
    taxa = Taxa('taxa', [Taxon(sequence.taxon, {}) for sequence in sequences])
    return Alignment('alignment', sequences, taxa, data_type)
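
A minimal usage sketch: NucleotideDataType('nuc') mirrors its use in Example no. 7, while the FASTA path is a hypothetical placeholder.

# Hypothetical usage; 'flu.fasta' is a placeholder path, and
# NucleotideDataType('nuc') mirrors Example no. 7.
alignment = build_alignment('flu.fasta', NucleotideDataType('nuc'))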