Пример #1
0
def _strip_extension(path):
    """Return ``path`` with its last '.'-separated suffix removed.

    ``'data/aln.fasta'`` -> ``'data/aln'``; a name without any dot yields ``''``.
    """
    return '.'.join(path.split('.')[:-1])


def run(args):
    """Annotate a tree with TreeTime and write the tree plus node data.

    The tree in ``args.tree`` is read as Newick or Nexus. Depending on the
    arguments, one of three things is run:

    * a time tree (``args.timetree``);
    * an ancestral reconstruction (``args.ancestral`` in 'joint'/'marginal');
    * a bare ``TreeAnc`` instance, used only to name internal nodes.

    Outputs:

    * the tree as Newick (``args.output``);
    * a JSON node-data file (``args.node_data``);
    * for VCF input with a reconstruction, a VCF including ancestral states
      (``args.output_vcf``).

    Output names that are not given are derived from the alignment filename,
    or from the tree filename when no alignment is supplied.

    Returns:
        0 on success, -1 on any reported error.
    """
    is_vcf = False
    ref = None
    tree_meta = {'alignment': args.alignment}
    attributes = ['branch_length']

    # Try each supported tree format in turn; T stays None if none parses.
    T = None
    for fmt in ["newick", "nexus"]:
        try:
            T = Phylo.read(args.tree, fmt)
        except Exception:
            continue
        tree_meta['input_tree'] = args.tree
        break
    if T is None:
        print("ERROR: reading tree from %s failed." % args.tree)
        return -1

    # Stem for output files that are not named explicitly. Without an
    # alignment (args.alignment is falsy) the tree filename is used instead.
    out_stem = _strip_extension(args.alignment or args.tree)

    if not args.alignment:
        # fake alignment to appease treetime when only using it for naming nodes...
        if args.ancestral or args.timetree:
            print(
                "ERROR: alignment is required for ancestral reconstruction or timetree inference"
            )
            return -1
        from Bio import SeqRecord, Seq, Align
        aln = Align.MultipleSeqAlignment([
            SeqRecord.SeqRecord(seq=Seq.Seq('ACGT'),
                                id=n.name,
                                name=n.name,
                                description='')
            for n in T.get_terminals()
        ])
    elif args.alignment.lower().endswith(('.vcf', '.vcf.gz')):
        if not args.vcf_reference:
            print(
                "ERROR: a reference Fasta is required with VCF-format alignments"
            )
            return -1
        compress_seq = read_vcf(args.alignment, args.vcf_reference)
        aln = compress_seq['sequences']
        ref = compress_seq['reference']
        is_vcf = True
    else:
        aln = args.alignment

    tree_fname = args.output or out_stem + '_tt.nwk'

    if args.timetree:
        if args.metadata is None:
            print(
                "ERROR: meta data with dates is required for time tree reconstruction"
            )
            return -1
        metadata, _columns = read_metadata(args.metadata)
        if args.year_limit:
            args.year_limit.sort()
        dates = get_numerical_dates(metadata,
                                    fmt=args.date_fmt,
                                    min_max_year=args.year_limit)
        # Attach the raw date string to tips; 'raw_date' is exported below.
        for n in T.get_terminals():
            if n.name in metadata and 'date' in metadata[n.name]:
                n.raw_date = metadata[n.name]['date']

        # A single root name is passed as a plain value, several as a list.
        if args.root and len(args.root) == 1:
            args.root = args.root[0]

        tt = timetree(
            tree=T,
            aln=aln,
            ref=ref,
            dates=dates,
            confidence=args.date_confidence,
            reroot=args.root or 'best',
            # explicit None check so that a value of 0 can be passed through
            Tc=args.coalescent if args.coalescent is not None else 0.01,
            use_marginal=args.time_marginal or False,
            branch_length_mode=args.branch_length_mode or 'auto',
            clock_rate=args.clock_rate,
            n_iqd=args.n_iqd)

        tree_meta['clock'] = {
            'rate': tt.date2dist.clock_rate,
            'intercept': tt.date2dist.intercept,
            'rtt_Tmrca': -tt.date2dist.intercept / tt.date2dist.clock_rate
        }
        attributes.extend([
            'numdate', 'clock_length', 'mutation_length', 'mutations',
            'raw_date', 'date'
        ])
        if not is_vcf:
            attributes.append('sequence')  # don't add sequences if VCF - huge!
        if args.date_confidence:
            attributes.append('num_date_confidence')
    elif args.ancestral in ['joint', 'marginal']:
        tt = ancestral_sequence_inference(
            tree=T,
            aln=aln,
            ref=ref,
            marginal=args.ancestral,
            optimize_branch_length=args.branchlengths,
            branch_length_mode=args.branch_length_mode)
        attributes.extend(['mutation_length', 'mutations'])
        if not is_vcf:
            attributes.append('sequence')  # don't add sequences if VCF - huge!
    else:
        from treetime import TreeAnc
        # instantiate treetime for the sole reason to name internal nodes
        tt = TreeAnc(tree=T, aln=aln, ref=ref, gtr='JC69', verbose=1)

    if is_vcf:
        # TreeTime overwrites ambig sites on tips during ancestral reconst.
        # Put these back in tip sequences now, to avoid misleading output.
        tt.recover_var_ambigs()

    tree_meta['nodes'] = prep_tree(T, attributes, is_vcf)

    import json
    # Phylo.write returns the number of trees written.
    n_trees_written = Phylo.write(T,
                                  tree_fname,
                                  'newick',
                                  format_branch_length='%1.8f')
    node_data_fname = args.node_data or out_stem + '.node_data'
    with open(node_data_fname, 'w') as ofile:
        # json.dump returns None; serialization/IO failures raise instead.
        json.dump(tree_meta, ofile)

    # If VCF and ancestral reconst. was done, output VCF including new ancestral seqs
    if is_vcf and (args.ancestral or args.timetree):
        vcf_fname = args.output_vcf or out_stem + '.vcf'
        write_vcf(tt.get_tree_dict(keep_var_ambigs=True), vcf_fname)

    return 0 if n_trees_written else -1
Пример #2
0
def _is_informative(ancestral, derived):
    """Return True unless either state is a gap ('-') or an 'N'."""
    return '-' not in (ancestral, derived) and 'N' not in (ancestral, derived)


def _collect_homoplasies(tree, offset):
    """Index the informative mutations found on the branches of ``tree``.

    Every ``node.mutations`` entry is unpacked as ``(a, pos, d)``. Entries
    involving '-' or 'N' are skipped, and positions are shifted by ``offset``.

    Returns a 4-tuple of ``defaultdict(list)``:
      * mutations: ``(a, pos, d)`` -> nodes whose branch carries it
      * positions: ``pos`` -> nodes with any kept mutation at that position
      * terminal_mutations: like ``mutations`` but for leaves only
      * mutation_by_strain: leaf name -> list of ``[(a, pos, d), hits]`` for
        positions mutated on more than one branch, ``hits`` being that count
    """
    from collections import defaultdict

    mutations = defaultdict(list)
    positions = defaultdict(list)
    terminal_mutations = defaultdict(list)
    for node in tree.find_clades():
        if node.up is None:  # the root has no parent branch
            continue
        for a, pos, d in node.mutations:
            if not _is_informative(a, d):
                continue
            key = (a, pos + offset, d)
            mutations[key].append(node)
            positions[pos + offset].append(node)
            if node.is_terminal():
                terminal_mutations[key].append(node)

    mutation_by_strain = defaultdict(list)
    for leaf in tree.get_terminals():
        for a, pos, d in leaf.mutations:
            site = pos + offset
            # positions is keyed by the *shifted* site; the ``in`` test comes
            # first so the defaultdict never grows spurious empty entries.
            if site in positions and len(positions[site]) > 1 and _is_informative(a, d):
                mutation_by_strain[leaf.name].append([(a, site, d), len(positions[site])])
    return mutations, positions, terminal_mutations, mutation_by_strain


def _format_occurrence_histogram(counts):
    """Render bincount ``counts`` as '- <count> occur <k> times' lines, skipping empty bins."""
    return "".join(['\n\t - %d occur %d times' % (n, mi)
                    for mi, n in enumerate(counts) if n])


def _print_most_homoplasic(title, mutation_dict, n_top, genome_length, drms=None):
    """Print up to ``n_top`` mutations observed on more than one branch.

    Mutations are ranked by number of branches. The small
    ``0.1*pos/genome_length`` term orders mutations of equal multiplicity by
    position. Printing stops at the first mutation seen only once. When
    ``drms`` is given, a DRM-details column is appended to every row.
    """
    header = title + "\n\tmut\tmultiplicity"
    if drms is not None:
        header += "\tDRM details (gene drug AAmut)"
    print(header)
    ranked = sorted(mutation_dict.items(),
                    key=lambda x: len(x[1]) - 0.1*x[0][1]/genome_length, reverse=True)
    for mut, branches in ranked[:n_top]:
        if len(branches) <= 1:
            break
        row = "\t%s%d%s\t%d" % (mut[0], mut[1], mut[2], len(branches))
        if drms is not None:
            details = ""
            if mut[1] in drms:
                info = drms[mut[1]]
                details = " ".join([info['gene'], info['drug'], info['alt_base'][mut[2]]])
            row += "\t" + details
        print(row)


def _print_homoplasic_strains(mutation_by_strain, n_top, drms=None):
    """Print up to ``n_top`` taxa ranked by their number of homoplasic mutations.

    With ``drms`` given, an extra column counts the taxon's mutations at DRM positions.
    """
    header = ("\n\nTaxons that carry positions that mutated elsewhere in the tree:"
              "\n\ttaxon name\t#of homoplasic mutations")
    if drms is not None:
        header += "\t# DRM"
    print(header)
    ranked = sorted(mutation_by_strain.items(), key=lambda x: len(x[1]), reverse=True)
    for name, hits in ranked[:n_top]:
        if not hits:
            continue
        if drms is None:
            print("\t%s\t%d" % (name, len(hits)))
        else:
            n_drm = len([mut for mut, _ in hits if mut[1] in drms])
            print("\t%s\t%d\t%d" % (name, len(hits), n_drm))


def scan_homoplasies(params):
    """
    the function implementing treetime homoplasies

    Reconstructs ancestral sequences on ``params.tree`` by joint ML, then
    reports how often mutations and positions recur across branches. The
    position histogram is compared with a Poisson expectation of the same
    mean. The most homoplasic mutations (and, with ``params.detailed``,
    terminal-branch mutations and taxa) are printed, annotated with DRM
    information when ``params.drms`` is set.

    Returns 0 on success, 1 if the tree or alignment could not be loaded.
    """
    from scipy.stats import poisson

    if assure_tree(params, tmp_dir='homoplasy_tmp'):
        return 1

    gtr = create_gtr(params)

    ###########################################################################
    ### READ IN VCF
    ###########################################################################
    # sets ref and fixed_pi to None if not VCF
    aln, ref, fixed_pi = read_if_vcf(params)
    is_vcf = ref is not None

    ###########################################################################
    ### ANCESTRAL RECONSTRUCTION
    ###########################################################################
    treeanc = TreeAnc(params.tree, aln=aln, ref=ref, gtr=gtr, verbose=1,
                      fill_overhangs=True)
    if treeanc.aln is None:  # if alignment didn't load, exit
        return 1

    if is_vcf:
        L = len(ref) + params.const
    else:
        L = treeanc.data.full_length + params.const

    N_seq = len(treeanc.aln)
    N_tree = treeanc.tree.count_terminals()
    if params.rescale != 1.0:
        for n in treeanc.tree.find_clades():
            n.branch_length *= params.rescale
            n.mutation_length = n.branch_length

    print("read alignment from file %s with %d sequences of length %d"%(params.aln,N_seq,L))
    print("read tree from file %s with %d leaves"%(params.tree,N_tree))
    print("\ninferring ancestral sequences...")

    treeanc.infer_ancestral_sequences('ml', infer_gtr=params.gtr=='infer',
                                      marginal=False, fixed_pi=fixed_pi)
    print("...done.")

    if is_vcf:
        treeanc.recover_var_ambigs()

    ###########################################################################
    ### analysis of reconstruction
    ###########################################################################
    offset = 0 if params.zero_based else 1

    drms = None
    if params.drms:
        drms = read_in_DRMs(params.drms, offset)['DRMs']

    mutations, positions, terminal_mutations, mutation_by_strain = \
        _collect_homoplasies(treeanc.tree, offset)

    # total_branch_length is the expected number of substitutions; the
    # corrected terminal length is the expected number of *observable*
    # substitutions (probability of an odd number of substitutions at a site).
    total_branch_length = treeanc.tree.total_branch_length()
    corrected_terminal_branch_length = np.sum([np.exp(-x.branch_length)*np.sinh(x.branch_length)
                                               for x in treeanc.tree.get_terminals()])

    # make histograms and sum mutations in different categories
    multiplicities = np.bincount([len(x) for x in mutations.values()])
    total_mutations = np.sum([len(x) for x in mutations.values()])

    multiplicities_terminal = np.bincount([len(x) for x in terminal_mutations.values()])
    terminal_mutation_count = np.sum([len(x) for x in terminal_mutations.values()])

    # minlength=1 guarantees bin 0 exists even when no position mutated;
    # bin 0 then holds the number of positions never hit.
    multiplicities_positions = np.bincount([len(x) for x in positions.values()], minlength=1)
    multiplicities_positions[0] = L - np.sum(multiplicities_positions)

    ###########################################################################
    ### Output the distribution of times particular mutations are observed
    ###########################################################################
    print("\nThe TOTAL tree length is %1.3e and %d mutations were observed."
          %(total_branch_length,total_mutations))
    print("Of these %d mutations,"%total_mutations
          + _format_occurrence_histogram(multiplicities))
    # additional optional output for terminal mutations only
    if params.detailed:
        print("\nThe TERMINAL branch length is %1.3e and %d mutations were observed."
              %(corrected_terminal_branch_length,terminal_mutation_count))
        print("Of these %d mutations,"%terminal_mutation_count
              + _format_occurrence_histogram(multiplicities_terminal))

    ###########################################################################
    ### Output the distribution of times mutations at particular positions are observed
    ###########################################################################
    mean_hits = 1.0*total_mutations/L
    print("\nOf the %d positions in the genome,"%L
          + "".join(['\n\t - %d were hit %d times (expected %1.2f)'%(n,mi,L*poisson.pmf(mi,mean_hits))
                     for mi,n in enumerate(multiplicities_positions) if n]))

    # compare that distribution to a Poisson distribution with the same mean
    p = poisson.pmf(np.arange(10*multiplicities_positions.max()), mean_hits)
    print("\nlog-likelihood difference to Poisson distribution with same mean: %1.3e"%(
            - L*np.sum(p*np.log(p+1e-100))
            + np.sum(multiplicities_positions*np.log(p[:len(multiplicities_positions)]+1e-100))))

    ###########################################################################
    ### Output the mutations that are observed most often
    ###########################################################################
    _print_most_homoplasic("\n\nThe ten most homoplasic mutations are:",
                           mutations, params.n, L, drms)

    # optional output specifically for mutations on terminal branches
    if params.detailed:
        _print_most_homoplasic("\n\nThe ten most homoplasic mutation on terminal branches are:",
                               terminal_mutations, params.n, L, drms)

    ###########################################################################
    ### Output strains that have many homoplasic mutations
    ###########################################################################
    # TODO: add statistical criterion
    if params.detailed:
        _print_homoplasic_strains(mutation_by_strain, params.n, drms)

    return 0
Пример #3
0
def _is_informative(ancestral, derived):
    """Return True unless either state is a gap ('-') or an 'N'."""
    return '-' not in (ancestral, derived) and 'N' not in (ancestral, derived)


def _collect_homoplasies(tree, offset):
    """Index the informative mutations found on the branches of ``tree``.

    Every ``node.mutations`` entry is unpacked as ``(a, pos, d)``. Entries
    involving '-' or 'N' are skipped, and positions are shifted by ``offset``.

    Returns a 4-tuple of ``defaultdict(list)``:
      * mutations: ``(a, pos, d)`` -> nodes whose branch carries it
      * positions: ``pos`` -> nodes with any kept mutation at that position
      * terminal_mutations: like ``mutations`` but for leaves only
      * mutation_by_strain: leaf name -> list of ``[(a, pos, d), hits]`` for
        positions mutated on more than one branch, ``hits`` being that count
    """
    from collections import defaultdict

    mutations = defaultdict(list)
    positions = defaultdict(list)
    terminal_mutations = defaultdict(list)
    for node in tree.find_clades():
        if node.up is None:  # the root has no parent branch
            continue
        for a, pos, d in node.mutations:
            if not _is_informative(a, d):
                continue
            key = (a, pos + offset, d)
            mutations[key].append(node)
            positions[pos + offset].append(node)
            if node.is_terminal():
                terminal_mutations[key].append(node)

    mutation_by_strain = defaultdict(list)
    for leaf in tree.get_terminals():
        for a, pos, d in leaf.mutations:
            site = pos + offset
            # positions is keyed by the *shifted* site; the ``in`` test comes
            # first so the defaultdict never grows spurious empty entries.
            if site in positions and len(positions[site]) > 1 and _is_informative(a, d):
                mutation_by_strain[leaf.name].append([(a, site, d), len(positions[site])])
    return mutations, positions, terminal_mutations, mutation_by_strain


def _format_occurrence_histogram(counts):
    """Render bincount ``counts`` as '- <count> occur <k> times' lines, skipping empty bins."""
    return "".join(['\n\t - %d occur %d times' % (n, mi)
                    for mi, n in enumerate(counts) if n])


def _print_most_homoplasic(title, mutation_dict, n_top, genome_length, drms=None):
    """Print up to ``n_top`` mutations observed on more than one branch.

    Mutations are ranked by number of branches. The small
    ``0.1*pos/genome_length`` term orders mutations of equal multiplicity by
    position. Printing stops at the first mutation seen only once. When
    ``drms`` is given, a DRM-details column is appended to every row.
    """
    header = title + "\n\tmut\tmultiplicity"
    if drms is not None:
        header += "\tDRM details (gene drug AAmut)"
    print(header)
    ranked = sorted(mutation_dict.items(),
                    key=lambda x: len(x[1]) - 0.1*x[0][1]/genome_length, reverse=True)
    for mut, branches in ranked[:n_top]:
        if len(branches) <= 1:
            break
        row = "\t%s%d%s\t%d" % (mut[0], mut[1], mut[2], len(branches))
        if drms is not None:
            details = ""
            if mut[1] in drms:
                info = drms[mut[1]]
                details = " ".join([info['gene'], info['drug'], info['alt_base'][mut[2]]])
            row += "\t" + details
        print(row)


def _print_homoplasic_strains(mutation_by_strain, n_top, drms=None):
    """Print up to ``n_top`` taxa ranked by their number of homoplasic mutations.

    With ``drms`` given, an extra column counts the taxon's mutations at DRM positions.
    """
    header = ("\n\nTaxons that carry positions that mutated elsewhere in the tree:"
              "\n\ttaxon name\t#of homoplasic mutations")
    if drms is not None:
        header += "\t# DRM"
    print(header)
    ranked = sorted(mutation_by_strain.items(), key=lambda x: len(x[1]), reverse=True)
    for name, hits in ranked[:n_top]:
        if not hits:
            continue
        if drms is None:
            print("\t%s\t%d" % (name, len(hits)))
        else:
            n_drm = len([mut for mut, _ in hits if mut[1] in drms])
            print("\t%s\t%d\t%d" % (name, len(hits), n_drm))


def scan_homoplasies(params):
    """
    the function implementing treetime homoplasies

    Reconstructs ancestral sequences on ``params.tree`` by joint ML, then
    reports how often mutations and positions recur across branches. The
    position histogram is compared with a Poisson expectation of the same
    mean. The most homoplasic mutations (and, with ``params.detailed``,
    terminal-branch mutations and taxa) are printed, annotated with DRM
    information when ``params.drms`` is set.

    Returns 0 on success, 1 if the tree or alignment could not be loaded or
    the reconstruction reported ``ttconf.ERROR``.
    """
    from scipy.stats import poisson

    if assure_tree(params, tmp_dir='homoplasy_tmp'):
        return 1

    gtr = create_gtr(params)

    ###########################################################################
    ### READ IN VCF
    ###########################################################################
    # sets ref and fixed_pi to None if not VCF
    aln, ref, fixed_pi = read_if_vcf(params)
    is_vcf = ref is not None

    ###########################################################################
    ### ANCESTRAL RECONSTRUCTION
    ###########################################################################
    treeanc = TreeAnc(params.tree, aln=aln, ref=ref, gtr=gtr, verbose=1,
                      fill_overhangs=True)
    if treeanc.aln is None:  # if alignment didn't load, exit
        return 1

    if is_vcf:
        L = len(ref) + params.const
    else:
        L = treeanc.aln.get_alignment_length() + params.const

    N_seq = len(treeanc.aln)
    N_tree = treeanc.tree.count_terminals()
    if params.rescale != 1.0:
        for n in treeanc.tree.find_clades():
            n.branch_length *= params.rescale
            n.mutation_length = n.branch_length

    print("read alignment from file %s with %d sequences of length %d"%(params.aln,N_seq,L))
    print("read tree from file %s with %d leaves"%(params.tree,N_tree))
    print("\ninferring ancestral sequences...")

    ndiff = treeanc.infer_ancestral_sequences('ml', infer_gtr=params.gtr=='infer',
                                              marginal=False, fixed_pi=fixed_pi)
    if ndiff == ttconf.ERROR:  # if reconstruction failed, exit
        print("Something went wrong during ancestral reconstruction, please check your input files.", file=sys.stderr)
        return 1
    # report completion only once, and only after success
    print("...done.")

    if is_vcf:
        treeanc.recover_var_ambigs()

    ###########################################################################
    ### analysis of reconstruction
    ###########################################################################
    offset = 0 if params.zero_based else 1

    drms = None
    if params.drms:
        drms = read_in_DRMs(params.drms, offset)['DRMs']

    mutations, positions, terminal_mutations, mutation_by_strain = \
        _collect_homoplasies(treeanc.tree, offset)

    # total_branch_length is the expected number of substitutions; the
    # corrected terminal length is the expected number of *observable*
    # substitutions (probability of an odd number of substitutions at a site).
    total_branch_length = treeanc.tree.total_branch_length()
    corrected_terminal_branch_length = np.sum([np.exp(-x.branch_length)*np.sinh(x.branch_length)
                                               for x in treeanc.tree.get_terminals()])

    # make histograms and sum mutations in different categories
    multiplicities = np.bincount([len(x) for x in mutations.values()])
    total_mutations = np.sum([len(x) for x in mutations.values()])

    multiplicities_terminal = np.bincount([len(x) for x in terminal_mutations.values()])
    terminal_mutation_count = np.sum([len(x) for x in terminal_mutations.values()])

    # minlength=1 guarantees bin 0 exists even when no position mutated;
    # bin 0 then holds the number of positions never hit.
    multiplicities_positions = np.bincount([len(x) for x in positions.values()], minlength=1)
    multiplicities_positions[0] = L - np.sum(multiplicities_positions)

    ###########################################################################
    ### Output the distribution of times particular mutations are observed
    ###########################################################################
    print("\nThe TOTAL tree length is %1.3e and %d mutations were observed."
          %(total_branch_length,total_mutations))
    print("Of these %d mutations,"%total_mutations
          + _format_occurrence_histogram(multiplicities))
    # additional optional output for terminal mutations only
    if params.detailed:
        print("\nThe TERMINAL branch length is %1.3e and %d mutations were observed."
              %(corrected_terminal_branch_length,terminal_mutation_count))
        print("Of these %d mutations,"%terminal_mutation_count
              + _format_occurrence_histogram(multiplicities_terminal))

    ###########################################################################
    ### Output the distribution of times mutations at particular positions are observed
    ###########################################################################
    mean_hits = 1.0*total_mutations/L
    print("\nOf the %d positions in the genome,"%L
          + "".join(['\n\t - %d were hit %d times (expected %1.2f)'%(n,mi,L*poisson.pmf(mi,mean_hits))
                     for mi,n in enumerate(multiplicities_positions) if n]))

    # compare that distribution to a Poisson distribution with the same mean
    p = poisson.pmf(np.arange(10*multiplicities_positions.max()), mean_hits)
    print("\nlog-likelihood difference to Poisson distribution with same mean: %1.3e"%(
            - L*np.sum(p*np.log(p+1e-100))
            + np.sum(multiplicities_positions*np.log(p[:len(multiplicities_positions)]+1e-100))))

    ###########################################################################
    ### Output the mutations that are observed most often
    ###########################################################################
    _print_most_homoplasic("\n\nThe ten most homoplasic mutations are:",
                           mutations, params.n, L, drms)

    # optional output specifically for mutations on terminal branches
    if params.detailed:
        _print_most_homoplasic("\n\nThe ten most homoplasic mutation on terminal branches are:",
                               terminal_mutations, params.n, L, drms)

    ###########################################################################
    ### Output strains that have many homoplasic mutations
    ###########################################################################
    # TODO: add statistical criterion
    if params.detailed:
        _print_homoplasic_strains(mutation_by_strain, params.n, drms)

    return 0