Ejemplo n.º 1
0
def calc_dstat_combinations(args):
    """Calculate genome-wide D-statistics for all possible trio
       combinations of samples against each outgroup specified.

       For every trio (i, j, k) of ingroup samples and every outgroup, the
       first allele set of each record is decoded.  Sites are counted when
       all four alleles are in ATGC, exactly two distinct alleles occur,
       and exactly one trio member shares the outgroup allele.  That member
       selects the pattern column [ABBA, BABA, BBAA] (sample i, j or k
       respectively).  Counts are kept per contig.

       For each contig a D value is computed from the two patterns other
       than the strictly most frequent one.  When no pattern is a strict
       maximum, or BBAA is largest, ABBA is compared against BABA.

       Each selection argument (outgroup/sample/contig indices or labels)
       is read as element [0] of the argument, a comma-separated string.

       Returns '' after writing one row per tetrad with data to args.out.
       Raises RuntimeError if no outgroup is specified, or if the sample
       and outgroup column lists overlap.
    """
    mvf = MultiVariantFile(args.mvf, 'read')
    sample_labels = mvf.get_sample_ids()
    if args.outgroup_indices is not None:
        outgroup_indices = [
            int(x) for x in args.outgroup_indices[0].split(",")
        ]
    elif args.outgroup_labels is not None:
        outgroup_indices = mvf.get_sample_indices(
            ids=args.outgroup_labels[0].split(","))
    else:
        # Without an outgroup there is nothing to polarize against.
        raise RuntimeError(
            "At least one outgroup must be specified by index or label.")
    if args.sample_indices is not None:
        sample_indices = [int(x) for x in args.sample_indices[0].split(",")]
    elif args.sample_labels is not None:
        sample_indices = mvf.get_sample_indices(
            ids=args.sample_labels[0].split(","))
    else:
        # No explicit ingroup: use every sample that is not an outgroup.
        # Keeping the outgroups here would always trip the overlap check.
        sample_indices = [x for x in mvf.get_sample_indices()
                          if x not in outgroup_indices]
    if args.contig_ids is not None:
        contig_ids = args.contig_ids[0].split(",")
    elif args.contig_labels is not None:
        contig_ids = mvf.get_contig_ids(
            labels=args.contig_labels[0].split(","))
    else:
        # None means "every contig"; the list is collected while reading.
        contig_ids = None
    if any(x in outgroup_indices for x in sample_indices):
        raise RuntimeError("Sample and Outgroup column lists cannot overlap.")
    contig_filter = None if contig_ids is None else set(contig_ids)
    seen_contigs = {}  # insertion-ordered set of contigs that were read
    # Bitmask of which trio members equal the outgroup allele -> column in
    # [ABBA, BABA, BBAA]; any other mask is not an informative pattern.
    pattern_column = {1: 0, 2: 1, 4: 2}
    trios = list(combinations(sample_indices, 3))
    data = {}
    for contig, _, allelesets in mvf:
        if contig_filter is not None and contig not in contig_filter:
            continue
        seen_contigs.setdefault(contig, None)
        alleles = mvf.decode(allelesets[0])
        for i, j, k in trios:
            for outgroup in outgroup_indices:
                subset = [alleles[x] for x in (i, j, k, outgroup)]
                if any(x not in 'ATGC' for x in subset):
                    continue
                out_allele = subset[3]
                if out_allele not in subset[:3]:
                    continue
                if len(set(subset)) != 2:
                    continue
                mask = ((subset[0] == out_allele)
                        + 2 * (subset[1] == out_allele)
                        + 4 * (subset[2] == out_allele))
                column = pattern_column.get(mask)
                if column is None:
                    continue
                counts = data.setdefault(
                    (i, j, k, outgroup), {}).setdefault(contig, [0, 0, 0])
                counts[column] += 1
    # WRITE OUTPUT
    output_contigs = (list(contig_ids) if contig_ids is not None
                      else list(seen_contigs))
    headers = ['sample0', 'sample1', 'sample2', "outgroup"]
    for xcontig in output_contigs:
        headers.extend([
            '{}:abba'.format(xcontig), '{}:baba'.format(xcontig),
            '{}:bbaa'.format(xcontig), '{}:D'.format(xcontig)
        ])
    outfile = OutputFile(path=args.out, headers=headers)
    for i, j, k in trios:
        for outgroup in outgroup_indices:
            tetrad = (i, j, k, outgroup)
            if tetrad not in data:
                continue
            entry = {'sample{}'.format(n): sample_labels[x]
                     for n, x in enumerate(tetrad[:3])}
            entry['outgroup'] = sample_labels[outgroup]
            for contig in output_contigs:
                keys = ['{}:abba'.format(contig), '{}:baba'.format(contig),
                        '{}:bbaa'.format(contig), '{}:D'.format(contig)]
                if contig not in data[tetrad]:
                    entry.update(dict.fromkeys(keys, '0'))
                    continue
                abba, baba, bbaa = data[tetrad][contig]
                # Treat the strictly most frequent pattern as the species
                # topology and compare the two remaining patterns.
                if abba > baba and abba > bbaa:
                    dstat = zerodiv(baba - bbaa, baba + bbaa)
                elif baba > bbaa and baba > abba:
                    dstat = zerodiv(abba - bbaa, abba + bbaa)
                else:
                    dstat = zerodiv(abba - baba, abba + baba)
                entry.update(zip(keys, (abba, baba, bbaa, dstat)))
            outfile.write_entry(entry)
    return ''
Ejemplo n.º 2
0
def calc_pairwise_distances(args):
    """Count the pairwise nucleotide distance between
       combinations of samples in a window.

       For every pair of selected samples, allele-pair counts are
       accumulated per window.  Window boundaries are decided by
       ``same_window`` with ``args.windowsize``.  Counts are converted to
       (differences, total) with the function from
       ``get_pairwise_function(args.data_type, args.ambig)``.

       Records with a single-character allele string are tallied once in a
       shared table that is added to every pair (presumably sites where
       all samples share one allele).  One output row per window is written
       to ``args.out``.  With ``args.emit_counts``, each window's raw
       allele-pair counts are also written to
       ``args.out + ".pairwisecounts"``.

       At the end, windowsize 0 labels the final window 'TOTAL'/0 and
       windowsize -1 labels it with position 0.

       Returns '' once the output has been written.
       Raises RuntimeError if the MVF flavor is not 'dna', 'prot' or
       'codon'.
    """
    args.qprint("Running CalcPairwiseDistances")
    mvf = MultiVariantFile(args.mvf, 'read')
    args.qprint("Input MVF: Read")
    data = {}
    data_order = []
    if args.sample_indices is not None:
        sample_indices = [int(x) for x in args.sample_indices[0].split(",")]
    elif args.sample_labels is not None:
        sample_indices = mvf.get_sample_indices(
            ids=args.sample_labels[0].split(","))
    else:
        sample_indices = mvf.get_sample_indices()
    # Pair keys hold absolute column indices, so labels are looked up in
    # the full sample list, not in a list restricted to sample_indices.
    sample_labels = mvf.get_sample_ids()
    args.qprint("Calculating for sample columns: {}".format(
        list(sample_indices)))
    sample_set = set(sample_indices)
    # Sorted so every pair key is (low, high).  That matches the pairs
    # built from decoded sites (combinations of ascending positions).
    sample_pairs = list(combinations(sorted(sample_set), 2))
    if mvf.flavor in ('dna', 'prot'):
        allele_frames = (0, )
        args.data_type = mvf.flavor
    elif mvf.flavor == 'codon':
        if args.data_type == 'prot':
            allele_frames = (0, )
        else:
            allele_frames = (1, 2, 3)
            args.data_type = 'dna'
    else:
        raise RuntimeError(
            "Unsupported MVF flavor: {}".format(mvf.flavor))
    args.qprint("MVF flavor is: {}".format(mvf.flavor))
    args.qprint("Data type is: {}".format(args.data_type))
    args.qprint("Ambiguous mode: {}".format(args.ambig))
    args.qprint("Processing MVF Records")
    pwdistance_function = get_pairwise_function(args.data_type, args.ambig)

    def window_start(start, pos):
        """Step `start` forward by whole windows until `pos` lies inside."""
        if args.windowsize > 0:
            while pos > start + args.windowsize - 1:
                start += args.windowsize
        return start

    def record_window(contig, position, pair_counts, shared_counts):
        """Store one output row summarizing the buffered window counts."""
        row = {'contig': contig, 'position': position}
        all_diff, all_total = pwdistance_function(shared_counts)
        for (p0, p1), counts in pair_counts.items():
            ndiff, ntotal = pwdistance_function(counts)
            taxa = "{};{}".format(sample_labels[p0], sample_labels[p1])
            row.update({
                '{};ndiff'.format(taxa): ndiff + all_diff,
                '{};ntotal'.format(taxa): ntotal + all_total,
                '{};dist'.format(taxa): zerodiv(ndiff + all_diff,
                                                ntotal + all_total),
            })
        data[(contig, position)] = row
        data_order.append((contig, position))

    def write_counts(handle, contig, position, pair_counts, shared_counts):
        """Write the raw allele-pair count table of one window."""
        args.qprint("Writing Full Count Table")
        for (p0, p1), counts in pair_counts.items():
            table = "\n".join(
                "{} {}".format(x, counts.get(x, 0) + shared_counts.get(x, 0))
                for x in set(counts).union(shared_counts))
            handle.write("#{}\t{}\t{}\t{}\n{}\n".format(
                p0, p1, position, contig, table))

    outfile_emitcounts = (open(args.out + ".pairwisecounts", 'w')
                          if args.emit_counts else None)
    try:
        current_contig = None
        current_position = 0
        data_in_buffer = False
        base_matches = {x: {} for x in sample_pairs}
        all_match = {}
        for contig, pos, allelesets in mvf.iterentries(decode=None):
            # Check Minimum Site Coverage
            if check_mincoverage(args.mincoverage, allelesets[0]) is False:
                continue
            # Establish first contig
            if current_contig is None:
                current_contig = contig[:]
                current_position = window_start(0, pos)
            # Flush the buffer when this site starts a new window.
            if not same_window((current_contig, current_position),
                               (contig, pos), args.windowsize):
                record_window(current_contig, current_position,
                              base_matches, all_match)
                if outfile_emitcounts is not None:
                    # Written before the window advances so the table is
                    # labeled with the window it actually describes.
                    write_counts(outfile_emitcounts, current_contig,
                                 current_position, base_matches, all_match)
                if contig != current_contig:
                    current_contig = contig[:]
                    current_position = window_start(0, pos)
                else:
                    # Skip any empty windows instead of stepping only once,
                    # so this site is not filed under the wrong window.
                    current_position = window_start(
                        current_position + args.windowsize, pos)
                base_matches = {x: {} for x in sample_pairs}
                all_match = {}
                data_in_buffer = False
            for iframe in allele_frames:
                alleles = allelesets[iframe]
                if len(alleles) == 1:
                    pair = "{0}{0}".format(alleles)
                    all_match[pair] = all_match.get(pair, 0) + 1
                    data_in_buffer = True
                    continue
                if alleles[1] == '+':
                    if alleles[2] in 'X-':
                        continue
                    samplepair = (0, int(alleles[3:]))
                    # Only pairs of selected samples have a bucket.
                    if samplepair not in base_matches:
                        continue
                    basepair = "{0}{1}".format(alleles[0], alleles[2])
                    counts = base_matches[samplepair]
                    counts[basepair] = counts.get(basepair, 0) + 1
                    data_in_buffer = True
                    continue
                alleles = mvf.decode(alleles)
                valid_positions = [
                    i for i, x in enumerate(alleles)
                    if x not in 'X-' and i in sample_set
                ]
                for i, j in combinations(valid_positions, 2):
                    basepair = "{0}{1}".format(alleles[i], alleles[j])
                    counts = base_matches[(i, j)]
                    counts[basepair] = counts.get(basepair, 0) + 1
                data_in_buffer = True
        if data_in_buffer:
            # Check whether, windows, contigs, or total
            if args.windowsize == 0:
                current_contig = 'TOTAL'
                current_position = 0
            elif args.windowsize == -1:
                current_position = 0
            record_window(current_contig, current_position,
                          base_matches, all_match)
            if outfile_emitcounts is not None:
                write_counts(outfile_emitcounts, current_contig,
                             current_position, base_matches, all_match)
    finally:
        if outfile_emitcounts is not None:
            outfile_emitcounts.close()
    args.qprint("Writing Output")
    headers = ['contig', 'position']
    for p0, p1 in sample_pairs:
        headers.extend([
            '{};{};{}'.format(sample_labels[p0], sample_labels[p1], x)
            for x in ('ndiff', 'ntotal', 'dist')
        ])
    outfile = OutputFile(path=args.out, headers=headers)
    for okey in data_order:
        outfile.write_entry(data[okey])
    return ''
Ejemplo n.º 3
0
def calc_pairwise_distances(args):
    """Count the pairwise nucleotide distance between
       combinations of samples in a window.

       For every pair of selected samples, allele-pair counts from the
       first allele set of each record are accumulated per window.  Window
       boundaries are decided by ``same_window`` with ``args.windowsize``.
       Counts are converted to (differences, total) with
       ``pairwise_distance_nuc`` for 'dna' files or
       ``pairwise_distance_prot`` for 'prot' files.

       Records with a single-character allele string are tallied once in a
       shared table that is added to every pair.  At the end, windowsize 0
       labels the final window 'TOTAL'/0 and windowsize -1 labels it with
       position 0.  Rows are written to args.out sorted by
       (contig, position).

       Returns ''.
       Raises RuntimeError if the MVF flavor is neither 'dna' nor 'prot'.
    """
    mvf = MultiVariantFile(args.mvf, 'read')
    if mvf.flavor == 'dna':
        distance_function = pairwise_distance_nuc
    elif mvf.flavor == 'prot':
        distance_function = pairwise_distance_prot
    else:
        # Previously this surfaced later as an unbound-variable NameError.
        raise RuntimeError(
            "Unsupported MVF flavor: {}".format(mvf.flavor))
    data = {}
    sample_labels = mvf.get_sample_labels()
    if args.sample_indices is not None:
        sample_indices = [int(x) for x in
                          args.sample_indices[0].split(",")]
    elif args.sample_labels is not None:
        sample_indices = mvf.get_sample_indices(
            labels=args.sample_labels[0].split(","))
    else:
        sample_indices = mvf.get_sample_indices()
    sample_set = set(sample_indices)
    # Sorted so every pair key is (low, high).  That matches the pairs
    # built from decoded sites (combinations of ascending positions).
    sample_pairs = list(combinations(sorted(sample_set), 2))

    def window_start(start, pos):
        """Step `start` forward by whole windows until `pos` lies inside.

        Only applies to positive window sizes.  For 0 (total) or -1, the
        old unguarded loop never terminated.
        """
        if args.windowsize > 0:
            while pos > start + args.windowsize - 1:
                start += args.windowsize
        return start

    def record_window(contig, position, pair_counts, shared_counts):
        """Store one output row summarizing the buffered window counts."""
        row = {'contig': contig, 'position': position}
        all_diff, all_total = distance_function(shared_counts)
        for (p0, p1), counts in pair_counts.items():
            ndiff, ntotal = distance_function(counts)
            taxa = "{};{}".format(sample_labels[p0], sample_labels[p1])
            row.update({
                '{};ndiff'.format(taxa): ndiff + all_diff,
                '{};ntotal'.format(taxa): ntotal + all_total,
                '{};dist'.format(taxa): zerodiv(ndiff + all_diff,
                                                ntotal + all_total)})
        data[(contig, position)] = row

    current_contig = None
    current_position = 0
    data_in_buffer = False
    base_matches = {x: {} for x in sample_pairs}
    all_match = {}
    for contig, pos, allelesets in mvf:
        # Check Minimum Site Coverage
        if check_mincoverage(args.mincoverage, allelesets[0]) is False:
            continue
        # Establish first contig
        if current_contig is None:
            current_contig = contig[:]
            current_position = window_start(0, pos)
        # Flush the buffer when this site starts a new window.
        if not same_window((current_contig, current_position),
                           (contig, pos), args.windowsize):
            record_window(current_contig, current_position,
                          base_matches, all_match)
            if contig != current_contig:
                current_contig = contig[:]
                current_position = window_start(0, pos)
            else:
                # Skip any empty windows instead of stepping only once,
                # so this site is not filed under the wrong window.
                current_position = window_start(
                    current_position + args.windowsize, pos)
            base_matches = {x: {} for x in sample_pairs}
            all_match = {}
            data_in_buffer = False
        alleles = allelesets[0]
        if len(alleles) == 1:
            pair = "{}{}".format(alleles, alleles)
            all_match[pair] = all_match.get(pair, 0) + 1
            data_in_buffer = True
            continue
        if alleles[1] == '+':
            if 'X' in alleles or '-' in alleles:
                continue
            samplepair = (0, int(alleles[3:]))
            # Only pairs of selected samples have a bucket.
            if samplepair not in base_matches:
                continue
            basepair = "{}{}".format(alleles[0], alleles[2])
            counts = base_matches[samplepair]
            counts[basepair] = counts.get(basepair, 0) + 1
            data_in_buffer = True
            continue
        alleles = mvf.decode(alleles)
        valid_positions = [i for i, x in enumerate(alleles)
                           if x not in 'X-' and i in sample_set]
        for i, j in combinations(valid_positions, 2):
            basepair = "{}{}".format(alleles[i], alleles[j])
            counts = base_matches[(i, j)]
            counts[basepair] = counts.get(basepair, 0) + 1
        data_in_buffer = True
    if data_in_buffer:
        # Check whether, windows, contigs, or total
        if args.windowsize == 0:
            current_contig = 'TOTAL'
            current_position = 0
        elif args.windowsize == -1:
            current_position = 0
        record_window(current_contig, current_position,
                      base_matches, all_match)
    headers = ['contig', 'position']
    for p0, p1 in sample_pairs:
        headers.extend(['{};{};{}'.format(
            sample_labels[p0], sample_labels[p1], x)
            for x in ('ndiff', 'ntotal', 'dist')])
    outfile = OutputFile(path=args.out, headers=headers)
    # Keys are (contig, position), so sorting keys orders rows the same way.
    for key in sorted(data):
        outfile.write_entry(data[key])
    return ''