Exemplo n.º 1
0
class IhhhmmmParser(object):
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir, remove_N_nukes=True)
        self.perfplotter = PerformancePlotter(self.germline_seqs, self.args.plotdir, 'ihhhmmm')

        self.details = OrderedDict()
        self.failtails = {}
        self.n_partially_failed = 0

        # get sequence info that was passed to ihhhmmm
        self.siminfo = OrderedDict()
        self.sim_need = []  # list of queries that we still need to find
        with opener('r')(self.args.simfname) as seqfile:
            reader = csv.DictReader(seqfile)
            iline = 0
            for line in reader:
                if self.args.queries != None and line['unique_id'] not in self.args.queries:
                    continue
                self.siminfo[line['unique_id']] = line
                self.sim_need.append(line['unique_id'])
                iline += 1
                if args.n_max_queries > 0 and iline >= args.n_max_queries:
                    break

        fostream_names = glob.glob(self.args.indir + '/*.fostream')
        fostream_names.sort()  # maybe already sorted?
        for infname in fostream_names:
            if len(self.sim_need) == 0:
                break

            # try to get whatever you can for the failures
            unique_ids = self.find_partial_failures(infname)  # returns list of unique ids in this file

            with opener('r')(infname) as infile:
                self.parse_file(infile, unique_ids)

        # now check that we got results for all the queries we wanted
        n_failed = 0
        for unique_id in self.siminfo:
            if unique_id not in self.details and unique_id not in self.failtails:
                print '%-20s  no info' % unique_id
                self.perfplotter.add_fail()
                n_failed += 1

        print ''
        print 'partially failed: %d / %d = %.2f' % (self.n_partially_failed, len(self.siminfo), float(self.n_partially_failed) / len(self.siminfo))
        print 'failed:           %d / %d = %.2f' % (n_failed, len(self.siminfo), float(n_failed) / len(self.siminfo))
        print ''

        self.perfplotter.plot()

    # ----------------------------------------------------------------------------------------
    def parse_file(self, infile, unique_ids):
        fk = FileKeeper(infile.readlines())
        i_id = 0
        while not fk.eof and len(self.sim_need) > 0:
            self.parse_detail(fk, unique_ids[i_id])
            i_id += 1
        
    # ----------------------------------------------------------------------------------------
    def parse_detail(self, fk, unique_id):
        assert fk.iline < len(fk.lines)

        while fk.line[1] != 'Details':
            fk.increment()
            if fk.eof:
                return

        fk.increment()
        info = {}
        info['unique_id'] = unique_id
        for begin_line, column, index, required, default in line_order:
            if fk.line[0].find(begin_line) != 0:
                if required:
                    print 'oop', begin_line, fk.line
                    sys.exit()
                else:
                    info[column] = default
                    continue
            if column != '':
                info[column] = clean_value(column, fk.line[index])
                # if '[' in info[column]:
                #     print 'added', column, clean_value(column, fk.line[index])
                if column.find('_gene') == 1:
                    region = column[0]
                    info[region + '_5p_del'] = int(fk.line[fk.line.index('start:') + 1]) - 1  # NOTE their indices are 1-based
                    gl_length = int(fk.line[fk.line.index('gene:') + 1]) - 1
                    match_end = int(fk.line[fk.line.index('end:') + 1]) - 1
                    assert gl_length >= match_end
                    info[region + '_3p_del'] = gl_length - match_end

            fk.increment()

        if unique_id not in self.sim_need:
            while not fk.eof and fk.line[1] != 'Details':  # skip stuff until start of next Detail block
                fk.increment()
            return

        info['fv_insertion'] = ''
        info['jf_insertion'] = ''
        info['seq'] = info['v_qr_seq'] + info['vd_insertion'] + info['d_qr_seq'] + info['dj_insertion'] + info['j_qr_seq']

        if '-' in info['seq']:
            print 'ERROR found a dash in %s, returning failure' % unique_id
            while not fk.eof and fk.line[1] != 'Details':  # skip stuff until start of next Detail block
                fk.increment()
            return

        if info['seq'] not in self.siminfo[unique_id]['seq']:  # arg. I can't do != because it tacks on v left and j right deletions
            print 'ERROR didn\'t find the right sequence for %s' % unique_id
            print '  ', info['seq']
            print '  ', self.siminfo[unique_id]['seq']
            sys.exit()

        if self.args.debug:
            print unique_id
            utils.print_reco_event(self.germline_seqs, self.siminfo[unique_id], label='true:', extra_str='    ')
            utils.print_reco_event(self.germline_seqs, info, label='inferred:', extra_str='    ')

        for region in utils.regions:
            if info[region + '_gene'] not in self.germline_seqs[region]:
                print 'ERROR %s not in germlines' % info[region + '_gene']
                assert False

            gl_seq = info[region + '_gl_seq']
            if '[' in gl_seq:  # ambiguous
                for nuke in utils.nukes:
                    gl_seq = gl_seq.replace('[', nuke)
                    if gl_seq in self.germline_seqs[region][info[region + '_gene']]:
                        print '  replaced [ with %s' % nuke
                        break
                info[region + '_gl_seq'] = gl_seq

            if info[region + '_gl_seq'] not in self.germline_seqs[region][info[region + '_gene']]:
                print 'ERROR gl match not found for %s in %s' % (info[region + '_gene'], unique_id)
                print '  ', info[region + '_gl_seq']
                print '  ', self.germline_seqs[region][info[region + '_gene']]                
                self.perfplotter.add_partial_fail(self.siminfo[unique_id], info)
                while not fk.eof and fk.line[1] != 'Details':  # skip stuff until start of next Detail block
                    fk.increment()
                return

        self.perfplotter.evaluate(self.siminfo[unique_id], info)
        self.details[unique_id] = info
        self.sim_need.remove(unique_id)

        while not fk.eof and fk.line[1] != 'Details':  # skip stuff until start of next Detail block
            fk.increment()
        
    # ----------------------------------------------------------------------------------------
    def find_partial_failures(self, fostream_name):
        unique_ids = []
        for line in open(fostream_name.replace('.fostream', '')).readlines():
            if len(self.sim_need) == 0:
                return
            if len(line.strip()) == 0:  # skip blank lines
                continue

            line = line.replace('"', '')
            line = line.split(';')

            unique_id = line[0]
            
            if 'NA' not in line:  # skip lines that were ok
                unique_ids.append(unique_id)
                continue
            if unique_id not in self.sim_need:
                continue
            if unique_id not in self.siminfo:
                continue  # not looking for this <unique_id> a.t.m.

            info = {}
            info['unique_id'] = unique_id
            for stuff in line:
                for region in utils.regions:  # add the first instance of IGH[VDJ] (if it's there at all)
                    if 'IGH'+region.upper() in stuff and region+'_gene' not in info:
                        genes = re.findall('IGH' + region.upper() + '[^ ][^ ]*', stuff)
                        if len(genes) == 0:
                            print 'ERROR no %s genes in %s' % (region, stuff)
                        gene = genes[0]
                        if gene not in self.germline_seqs[region]:
                            print 'ERROR bad gene %s for %s' % (gene, unique_id)
                            sys.exit()
                        info[region + '_gene'] = gene
            self.perfplotter.add_partial_fail(self.siminfo[unique_id], info)
            if self.args.debug:
                print '%-20s  partial fail %s %s %s' % (unique_id,
                                                     utils.color_gene(info['v_gene']) if 'v_gene' in info else '',
                                                     utils.color_gene(info['d_gene']) if 'd_gene' in info else '',
                                                     utils.color_gene(info['j_gene']) if 'j_gene' in info else ''),
                print '  (true %s %s %s)' % tuple([self.siminfo[unique_id][region + '_gene'] for region in utils.regions])
            self.failtails[unique_id] = info
            self.n_partially_failed += 1
            self.sim_need.remove(unique_id)

        return unique_ids
Exemplo n.º 2
0
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir)

        perfplotter = PerformancePlotter(self.germline_seqs, self.args.plotdir, 'imgt')

        # get sequence info that was passed to imgt
        self.seqinfo = {}
        with opener('r')(self.args.simfname) as simfile:
            reader = csv.DictReader(simfile)
            iline = 0
            for line in reader:
                if self.args.queries != None and line['unique_id'] not in self.args.queries:
                    continue
                if len(re.findall('_[FP]', line['j_gene'])) > 0:
                    line['j_gene'] = line['j_gene'].replace(re.findall('_[FP]', line['j_gene'])[0], '')
                self.seqinfo[line['unique_id']] = line
                iline += 1
                if self.args.n_queries > 0 and iline >= self.args.n_queries:
                    break

        paragraphs, csv_info = None, None
        if self.args.infname != None and '.html' in self.args.infname:
            print 'reading', self.args.infname
            with opener('r')(self.args.infname) as infile:
                soup = BeautifulSoup(infile)
                paragraphs = soup.find_all('pre')

        summarydir = self.args.indir[ : self.args.indir.rfind('/')]  # one directoy up from <indir>, which has the detailed per-sequence files
        summary_fname = glob.glob(summarydir + '/1_Summary_*.txt')
        assert len(summary_fname) == 1
        summary_fname = summary_fname[0]
        get_genes_to_skip(summary_fname, self.germline_seqs)

        n_failed, n_skipped, n_total, n_not_found, n_found = 0, 0, 0, 0, 0
        for unique_id in self.seqinfo:
            if self.args.debug:
                print unique_id,
            imgtinfo = []
            # print 'true'
            # utils.print_reco_event(self.germline_seqs, self.seqinfo[unique_id])
            if self.args.infname != None and '.html' in self.args.infname:
                for pre in paragraphs:  # NOTE this loops over everything an awful lot of times. Shouldn't really matter for now, though
                    if unique_id in pre.text:
                        imgtinfo.append(pre.text)
            else:
                n_total += 1
                assert self.args.infname == None
                infnames = glob.glob(self.args.indir + '/' + unique_id + '*')
                assert len(infnames) <= 1
                if len(infnames) != 1:
                    if self.args.debug:
                        print ' couldn\'t find it'
                    n_not_found += 1
                    continue
                n_found += 1
                with opener('r')(infnames[0]) as infile:
                    full_text = infile.read()
                    if len(re.findall('[123]. Alignment for [VDJ]-GENE', full_text)) < 3:
                        failregions = re.findall('No [VDJ]-GENE has been identified', full_text)
                        if self.args.debug and len(failregions) > 0:
                            print '    ', failregions
                        n_failed += 1
                        continue

                    # loop over the paragraphs I want
                    position = full_text.find(unique_id)  # don't need this one
                    for ir in range(4):
                        position = full_text.find(unique_id, position+1)
                        pgraph = full_text[position : full_text.find('\n\n', position+1)]
                        if 'insertion(s) and/or deletion(s) which are not dealt in this release' in pgraph:
                            ir -= 1
                            continue
                        imgtinfo.append(pgraph)  # query seq paragraph

            if len(imgtinfo) == 0:
                print '%s no info' % unique_id
                continue
            else:
                if self.args.debug:
                    print ''
            line = self.parse_query_text(unique_id, imgtinfo)
            if 'skip_gene' in line:
                # assert self.args.skip_missing_genes
                n_skipped += 1
                continue
            try:
                assert 'failed' not in line
                joinparser.add_insertions(line, debug=self.args.debug)
                joinparser.resolve_overlapping_matches(line, debug=False, germlines=self.germline_seqs)
            except (AssertionError, KeyError):
                print '    giving up'
                n_failed += 1
                perfplotter.add_partial_fail(self.seqinfo[unique_id], line)
                # print '    perfplotter: not sure what to do with a fail'
                continue
            perfplotter.evaluate(self.seqinfo[unique_id], line)
            if self.args.debug:
                utils.print_reco_event(self.germline_seqs, self.seqinfo[unique_id], label='true:')
                utils.print_reco_event(self.germline_seqs, line, label='inferred:')

        perfplotter.plot()
        print 'failed: %d / %d = %f' % (n_failed, n_total, float(n_failed) / n_total)
        print 'skipped: %d / %d = %f' % (n_skipped, n_total, float(n_skipped) / n_total)
        print '    ',
        for g, n in genes_actually_skipped.items():
            print '  %d %s' % (n, utils.color_gene(g))
        print ''
        if n_not_found > 0:
            print '  not found: %d / %d = %f' % (n_not_found, n_not_found + n_found, n_not_found / float(n_not_found + n_found))
Exemplo n.º 3
0
class Waterer(object):
    """ Run smith-waterman on the query sequences in <infname> """
    def __init__(self, args, input_info, reco_info, germline_seqs, parameter_dir, write_parameters=False):
        self.parameter_dir = parameter_dir
        self.args = args
        self.debug = self.args.debug if self.args.sw_debug is None else self.args.sw_debug

        self.input_info = input_info
        self.remaining_queries = [query for query in self.input_info.keys()]  # we remove queries from this list when we're satisfied with the current output (in general we may have to rerun some queries with different match/mismatch scores)
        self.new_indels = 0  # number of new indels that were kicked up this time through

        self.reco_info = reco_info
        self.germline_seqs = germline_seqs
        self.pcounter, self.true_pcounter, self.perfplotter = None, None, None
        if write_parameters:
            self.pcounter = ParameterCounter(self.germline_seqs)
            if not self.args.is_data:
                self.true_pcounter = ParameterCounter(self.germline_seqs)
        if self.args.plot_performance:
            self.perfplotter = PerformancePlotter(self.germline_seqs, 'sw')
        self.info = {}
        self.info['queries'] = []
        self.info['all_best_matches'] = set()  # set of all the matches we found (for *all* queries)
        self.info['skipped_unproductive_queries'] = []  # list of unproductive queries
        # self.info['skipped_indel_queries'] = []  # list of queries that had indels
        self.info['skipped_unknown_queries'] = []
        self.info['indels'] = {}
        if self.args.apply_choice_probs_in_sw:
            if self.debug:
                print '  reading gene choice probs from', parameter_dir
            self.gene_choice_probs = utils.read_overall_gene_probs(parameter_dir)

        with opener('r')(self.args.datadir + '/v-meta.json') as json_file:  # get location of <begin> cysteine in each v region
            self.cyst_positions = json.load(json_file)
        with opener('r')(self.args.datadir + '/j_tryp.csv') as csv_file:  # get location of <end> tryptophan in each j region (TGG)
            tryp_reader = csv.reader(csv_file)
            self.tryp_positions = {row[0]:row[1] for row in tryp_reader}  # WARNING: this doesn't filter out the header line

        self.outfile = None
        if self.args.outfname is not None:
            self.outfile = open(self.args.outfname, 'a')

        self.n_unproductive = 0
        self.n_total = 0

        print 'smith-waterman'

    # ----------------------------------------------------------------------------------------
    def __del__(self):
        if self.args.outfname is not None:
            self.outfile.close()

    # ----------------------------------------------------------------------------------------
    def clean(self):
        if self.pcounter is not None:
            self.pcounter.clean()
        if self.true_pcounter is not None:
            self.true_pcounter.clean()

    # ----------------------------------------------------------------------------------------
    def run(self):
        # start = time.time()
        base_infname = 'query-seqs.fa'
        base_outfname = 'query-seqs.bam'
        sys.stdout.flush()

        n_tries = 0
        while len(self.remaining_queries) > 0:  # we remove queries from <self.remaining_queries> as we're satisfied with their output
            self.write_vdjalign_input(base_infname, n_procs=self.args.n_fewer_procs)
            self.execute_command(base_infname, base_outfname, self.args.n_fewer_procs)
            self.read_output(base_outfname, n_procs=self.args.n_fewer_procs)
            n_tries += 1
            if n_tries > 2:
                self.info['skipped_unknown_queries'] += self.remaining_queries
                break

        self.finalize()

    # ----------------------------------------------------------------------------------------
    def finalize(self):
        if self.perfplotter is not None:
            self.perfplotter.plot(self.args.plotdir + '/sw/performance')
        # print '    sw time: %.3f' % (time.time()-start)
        if self.n_unproductive > 0:
            print '      unproductive skipped %d / %d = %.2f' % (self.n_unproductive, self.n_total, float(self.n_unproductive) / self.n_total)
        # if len(self.info['skipped_indel_queries']) > 0:
        #     print '      indels skipped %d / %d = %.2f' % (len(self.info['skipped_indel_queries']), self.n_total, float(len(self.info['skipped_indel_queries'])) / self.n_total)
        if len(self.info['indels']) > 0:
            print '      indels: %s' % ':'.join(self.info['indels'].keys())
        if self.pcounter is not None:
            self.pcounter.write(self.parameter_dir)
            if self.args.plotdir is not None:
                self.pcounter.plot(self.args.plotdir + '/sw', subset_by_gene=True, cyst_positions=self.cyst_positions, tryp_positions=self.tryp_positions)
                if self.true_pcounter is not None:
                    self.true_pcounter.plot(self.args.plotdir + 'sw/true', subset_by_gene=True, cyst_positions=self.cyst_positions, tryp_positions=self.tryp_positions)

    # ----------------------------------------------------------------------------------------
    def execute_command(self, base_infname, base_outfname, n_procs):
        """ Run vdjalign on the input file(s) written by write_vdjalign_input() and hand each process's stdout/stderr to utils.process_out_err.

        With one process everything happens in <args.workdir>; otherwise process i works in <args.workdir>/sw-i. The processes are started
        0.1 s apart and then waited on in order. Unless <args.no_clean> is set, the input file(s) are deleted afterwards.
        """
        if n_procs == 1:
            cmd_str = self.get_vdjalign_cmd_str(self.args.workdir, base_infname, base_outfname)
            out, err = Popen(cmd_str.split(), stdout=PIPE, stderr=PIPE).communicate()
            utils.process_out_err(out, err)
            if not self.args.no_clean:
                os.remove(self.args.workdir + '/' + base_infname)
        else:
            subdirs = [self.args.workdir + '/sw-' + str(iproc) for iproc in range(n_procs)]
            running = []
            for subdir in subdirs:
                cmd_str = self.get_vdjalign_cmd_str(subdir, base_infname, base_outfname)
                running.append(Popen(cmd_str.split(), stdout=PIPE, stderr=PIPE))
                time.sleep(0.1)
            for iproc, proc in enumerate(running):
                out, err = proc.communicate()
                utils.process_out_err(out, err, extra_str=str(iproc))
            if not self.args.no_clean:
                for subdir in subdirs:
                    os.remove(subdir + '/' + base_infname)

        sys.stdout.flush()

    # ----------------------------------------------------------------------------------------
    def write_vdjalign_input(self, base_infname, n_procs):
        """ Split <self.remaining_queries> into <n_procs> contiguous chunks and write each one as a fasta file named <base_infname>.

        With one process the file goes in <args.workdir>; otherwise chunk i goes in <args.workdir>/sw-i, which is prepared first. Queries with
        a recorded indel are written with their indel-reversed sequence instead of the original one.
        """
        n_queries_per_proc = int(math.ceil(float(len(self.remaining_queries)) / n_procs))
        if n_procs == 1:  # double check for rounding problems or whatnot
            assert n_queries_per_proc == len(self.remaining_queries)
        for iproc in range(n_procs):
            if n_procs > 1:
                workdir = self.args.workdir + '/sw-' + str(iproc)
                utils.prep_dir(workdir)
            else:
                workdir = self.args.workdir
            chunk = self.remaining_queries[iproc * n_queries_per_proc : (iproc + 1) * n_queries_per_proc]
            with opener('w')(workdir + '/' + base_infname) as sub_infile:
                for query_name in chunk:
                    sub_infile.write('>' + query_name + ' NUKES\n')
                    seq = self.input_info[query_name]['seq']
                    if query_name in self.info['indels']:  # use the query sequence with shm insertions and deletions reversed
                        seq = self.info['indels'][query_name]['reversed_seq']
                    sub_infile.write(seq + '\n')

    # ----------------------------------------------------------------------------------------
    def get_vdjalign_cmd_str(self, workdir, base_infname, base_outfname):
        """
        Run smith-waterman alignment (from Connor's ighutils package) on the seqs in <base_infname>, and toss all the top matches into <base_outfname>.
        """
        # large gap-opening penalty: we want *no* gaps in the middle of the alignments
        # match score larger than (negative) mismatch score: we want to *encourage* some level of shm. If they're equal, we tend to end up with short unmutated alignments, which screws everything up
        os.environ['PATH'] = os.getenv('PWD') + '/packages/samtools:' + os.getenv('PATH')
        check_output(['which', 'samtools'])
        if not os.path.exists(self.args.ighutil_dir + '/bin/vdjalign'):
            raise Exception('ERROR ighutil path d.n.e: ' + self.args.ighutil_dir + '/bin/vdjalign')
        cmd_str = self.args.ighutil_dir + '/bin/vdjalign align-fastq -q'
        if self.args.slurm:
            cmd_str = 'srun ' + cmd_str
        cmd_str += ' --max-drop 50'
        match, mismatch = self.args.match_mismatch
        cmd_str += ' --match ' + str(match) + ' --mismatch ' + str(mismatch)
        cmd_str += ' --gap-open ' + str(self.args.gap_open_penalty)  #1000'  #50'
        cmd_str += ' --vdj-dir ' + self.args.datadir
        cmd_str += ' ' + workdir + '/' + base_infname + ' ' + workdir + '/' + base_outfname

        return cmd_str

    # ----------------------------------------------------------------------------------------
    def read_output(self, base_outfname, n_procs=1):
        n_processed = 0
        for iproc in range(n_procs):
            workdir = self.args.workdir
            if n_procs > 1:
                workdir += '/sw-' + str(iproc)
            outfname = workdir + '/' + base_outfname
            with contextlib.closing(pysam.Samfile(outfname)) as bam:
                grouped = itertools.groupby(iter(bam), operator.attrgetter('qname'))
                for _, reads in grouped:  # loop over query sequences
                    self.n_total += 1
                    self.process_query(bam, list(reads))
                    n_processed += 1

            if not self.args.no_clean:
                os.remove(outfname)
                if n_procs > 1:  # still need the top-level workdir
                    os.rmdir(workdir)

        print '    processed %d queries' % n_processed

        if len(self.remaining_queries) > 0:
            if self.new_indels > 0:  # if we skipped some events, and if none of those were because they were indels, then increase mismatch score
                print '      skipped %d queries (%d indels), rerunning them' % (len(self.remaining_queries), self.new_indels)
                self.new_indels = 0
            else:
                print '      skipped %d queries (%d indels), increasing mismatch score (%d --> %d) and rerunning them' % (len(self.remaining_queries), self.new_indels, self.args.match_mismatch[1], self.args.match_mismatch[1] + 1)
                self.args.match_mismatch[1] += 1
                self.new_indels = 0

    # ----------------------------------------------------------------------------------------
    def get_choice_prob(self, region, gene):
        choice_prob = 1.0
        if gene in self.gene_choice_probs[region]:
            choice_prob = self.gene_choice_probs[region][gene]
        else:
            choice_prob = 0.0  # NOTE would it make sense to use something else here?
        return choice_prob

    # ----------------------------------------------------------------------------------------
    def get_indel_info(self, query_name, cigarstr, qrseq, glseq, gene):
        """ Work out the insertions and deletions in one alignment from its cigar string, and print the alignment with them highlighted.

        <qrseq> and <glseq> are the aligned stretches of the query and of germline <gene>; process_query passes
        query_seq[qstart : qend] and germline[pos : aend]. Returns a dict with:
          'reversed_seq': the aligned query with inserted bases dropped and the germline bases filled in at deletions
          'indels': one dict per I or D cigar operation, with keys 'type' ('insertion' or 'deletion'), 'pos', 'len',
                    and 'seqstr' (the inserted query bases, or the deleted germline bases)
        The alignment and the indel list are always printed, not only in debug mode.
        Raises Exception on any parsed cigar code other than M, S, I or D.
        """
        cigars = re.findall('[0-9][0-9]*[A-Z]', cigarstr)  # split cigar string into its parts (NOTE(review): '=' ops don't match [A-Z] and would be silently dropped -- confirm vdjalign never emits them)
        cigars = [(cstr[-1], int(cstr[:-1])) for cstr in cigars]  # split each part into the code and the length

        codestr = ''  # the cigar expanded to one code character per alignment column
        qpos = 0  # running offset along the cigar: the length of every op (M, S, I and D) is added, so after a deletion this is not a pure query coordinate -- NOTE(review): confirm what consumers of 'pos' expect
        indelfo = {'reversed_seq' : '', 'indels' : []}  # reversed_seq: query seq with insertions removed and germline bases inserted at the position of deletions
        tmp_indices = []  # for each column of <codestr>: index in indelfo['indels'] of the indel covering that column, or None
        for code, length in cigars:
            codestr += length * code
            if code == 'I':  # advance qr seq but not gl seq
                indelfo['indels'].append({'type' : 'insertion', 'pos' : qpos, 'len' : length, 'seqstr' : ''})  # insertion begins at <pos>
                tmp_indices += [len(indelfo['indels']) - 1  for _ in range(length)]# indel index corresponding to this position in the alignment
            elif code == 'D':  # advance gl seq but not qr seq
                indelfo['indels'].append({'type' : 'deletion', 'pos' : qpos, 'len' : length, 'seqstr' : ''})  # first deleted base is <pos> (well, first base which is in the position of the first deleted base)
                tmp_indices += [len(indelfo['indels']) - 1  for _ in range(length)]# indel index corresponding to this position in the alignment
            else:
                tmp_indices += [None  for _ in range(length)]  # indel index corresponding to this position in the alignment
            qpos += length

        # walk the alignment column by column, building colored print strings, the indel-reversed sequence, and each indel's bases
        qrprintstr, glprintstr = '', ''
        iqr, igl = 0, 0  # current positions in <qrseq> and <glseq>
        for icode in range(len(codestr)):
            code = codestr[icode]
            if code == 'M':
                qrbase = qrseq[iqr]
                if qrbase != glseq[igl]:
                    qrbase = utils.color('red', qrbase)
                qrprintstr += qrbase
                glprintstr += glseq[igl]
                indelfo['reversed_seq'] += qrseq[iqr]  # add the base to the overall sequence with all indels reversed
            elif code == 'S':  # soft clip: the continue skips the increments below, so neither <iqr> nor <igl> advances
                continue
            elif code == 'I':
                qrprintstr += utils.color('light_blue', qrseq[iqr])
                glprintstr += utils.color('light_blue', '*')
                indelfo['indels'][tmp_indices[icode]]['seqstr'] += qrseq[iqr]  # and to the sequence of just this indel
                igl -= 1  # cancels the igl += 1 below: an insertion consumes no germline base
            elif code == 'D':
                qrprintstr += utils.color('light_blue', '*')
                glprintstr += utils.color('light_blue', glseq[igl])
                indelfo['reversed_seq'] += glseq[igl]  # add the base to the overall sequence with all indels reversed
                indelfo['indels'][tmp_indices[icode]]['seqstr'] += glseq[igl]  # and to the sequence of just this indel
                iqr -= 1  # cancels the iqr += 1 below: a deletion consumes no query base
            else:
                raise Exception('unhandled code %s' % code)

            iqr += 1
            igl += 1

        print '\n      indels in %s' % query_name
        print '          %20s %s' % (gene, glprintstr)
        print '          %20s %s' % ('query', qrprintstr)
        for idl in indelfo['indels']:
            print '          %10s: %d bases at %d (%s)' % (idl['type'], idl['len'], idl['pos'], idl['seqstr'])
        # utils.undo_indels(indelfo)
        # print '                       %s' % self.input_info[query_name]['seq']

        return indelfo

    # ----------------------------------------------------------------------------------------
    def process_query(self, bam, reads):
        primary = next((r for r in reads if not r.is_secondary), None)
        query_seq = primary.seq
        query_name = primary.qname
        first_match_query_bounds = None  # since sw excises its favorite v match, we have to know this match's boundaries in order to calculate k_d for all the other matches
        all_match_names = {}
        warnings = {}  # ick, this is a messy way to pass stuff around
        for region in utils.regions:
            all_match_names[region] = []
        all_query_bounds, all_germline_bounds = {}, {}
        n_skipped_invalid_cpos = 0
        for read in reads:  # loop over the matches found for each query sequence
            # set this match's values
            read.seq = query_seq  # only the first one has read.seq set by default, so we need to set the rest by hand
            gene = bam.references[read.tid]
            region = utils.get_region(gene)
            raw_score = read.tags[0][1]  # raw because they don't include the gene choice probs
            score = raw_score
            if self.args.apply_choice_probs_in_sw:  # NOTE I stopped applying the gene choice probs here because the smith-waterman scores don't correspond to log-probs, so throwing on the gene choice probs was dubious (and didn't seem to work that well)
                score = self.get_choice_prob(region, gene) * raw_score  # multiply by the probability to choose this gene
            qrbounds = (read.qstart, read.qend)
            glbounds = (read.pos, read.aend)
            if region == 'v' and first_match_query_bounds is None:
                first_match_query_bounds = qrbounds

            # perform a few checks and see if we want to skip this match
            if region == 'v':  # skip matches with cpos past the end of the query seq (i.e. eroded a ton on the right side of the v)
                cpos = utils.get_conserved_codon_position(self.cyst_positions, self.tryp_positions, 'v', gene, glbounds, qrbounds, assert_on_fail=False)
                if not utils.check_conserved_cysteine(self.germline_seqs['v'][gene], self.cyst_positions[gene]['cysteine-position'], assert_on_fail=False):  # some of the damn cysteine positions in the json file were wrong, so now we check
                    raise Exception('bad cysteine in %s: %d %s' % (gene, self.cyst_positions[gene]['cysteine-position'], self.germline_seqs['v'][gene]))
                if cpos < 0 or cpos >= len(query_seq):
                    n_skipped_invalid_cpos += 1
                    continue

            if 'I' in read.cigarstring or 'D' in read.cigarstring:  # skip indels, and tell the HMM to skip indels (you won't see any unless you decrease the <self.args.gap_open_penalty>)
                if len(all_match_names[region]) == 0:  # if this is the first (best) match for this region, allow indels (otherwise skip the match)
                    if query_name not in self.info['indels']:
                        self.info['indels'][query_name] = self.get_indel_info(query_name, read.cigarstring, query_seq[qrbounds[0] : qrbounds[1]], self.germline_seqs[region][gene][glbounds[0] : glbounds[1]], gene)
                        self.info['indels'][query_name]['reversed_seq'] = query_seq[ : qrbounds[0]] + self.info['indels'][query_name]['reversed_seq'] + query_seq[qrbounds[1] : ]
                        self.new_indels += 1
                        # print ' query seq  %s' % query_seq
                        # print 'indelfo seq %s' % self.info['indels'][query_name]['reversed_seq']
                        # self.info['skipped_indel_queries'].append(query_name)
                        # self.info[query_name] = {'indels'}
                    else:
                        print '     multiple indels for %s' % query_name
                    return
                else:
                    continue

            if qrbounds[1]-qrbounds[0] != glbounds[1]-glbounds[0]:
                raise Exception('germline match (%d %d) not same length as query match (%d %d)' % (qrbounds[0], qrbounds[1], glbounds[0], glbounds[1]))

            assert qrbounds[1] <= len(query_seq)
            if glbounds[1] > len(self.germline_seqs[region][gene]):
                print '  ', gene
                print '  ', glbounds[1], len(self.germline_seqs[region][gene])
                print '  ', self.germline_seqs[region][gene]
            assert glbounds[1] <= len(self.germline_seqs[region][gene])
            assert qrbounds[1]-qrbounds[0] == glbounds[1]-glbounds[0]

            # and finally add this match's information
            warnings[gene] = ''
            all_match_names[region].append((score, gene))  # NOTE it is important that this is ordered such that the best match is first
            all_query_bounds[gene] = qrbounds
            all_germline_bounds[gene] = glbounds

        # if n_skipped_invalid_cpos > 0:
        #     print '      skipped %d invalid cpos values for %s' % (n_skipped_invalid_cpos, query_name)
        self.summarize_query(query_name, query_seq, all_match_names, all_query_bounds, all_germline_bounds, warnings, first_match_query_bounds)

    # ----------------------------------------------------------------------------------------
    def print_match(self, region, gene, query_seq, score, glbounds, qrbounds, codon_pos, warnings, skipping=False):
        """ Emit a debug description of one s-w match; does nothing unless <self.debug> is at least 2.

        The text goes to stdout, or is appended (newline-terminated) to <self.outfile> when <args.outfname> is set.
        With <args.apply_choice_probs_in_sw> the header shows 'choice prob * raw score = score'; otherwise just the score.
        The germline match is followed by the query bounds with mutated bases highlighted via utils.color_mutants,
        then the conserved codon position (v and j only), any warning for <gene>, and 'skipping!' if <skipping>.
        """
        if self.debug < 2:
            return

        padding = (20 - len(gene)) * ' '
        if self.args.apply_choice_probs_in_sw:
            choice_prob = self.get_choice_prob(region, gene)
            # recover the score before process_query multiplied in the choice prob (can't divide by a zero prob)
            raw_score = score / choice_prob if choice_prob != 0.0 else score
            header = '%8s%s%s%9.1e * %3.0f = %-6.1f' % (' ', utils.color_gene(gene), padding, choice_prob, raw_score, score)
        else:
            header = '%8s%s%s%9s%3s %6.0f        ' % (' ', utils.color_gene(gene), '', '', padding, score)

        gl_seq = self.germline_seqs[region][gene][glbounds[0]:glbounds[1]]
        pieces = [
            header,
            '%4d%4d   %s\n' % (glbounds[0], glbounds[1], gl_seq),
            '%46s  %4d%4d' % ('', qrbounds[0], qrbounds[1]),
            '   %s ' % (utils.color_mutants(gl_seq, query_seq[qrbounds[0]:qrbounds[1]])),
        ]
        if region != 'd':  # d has no conserved codon
            pieces.append('(%s %d)' % (utils.conserved_codon_names[region], codon_pos))
        if warnings[gene] != '':
            pieces.append('WARNING ' + warnings[gene])
        if skipping:
            pieces.append('skipping!')

        if self.args.outfname is not None:
            pieces.append('\n')
            self.outfile.write(''.join(pieces))
        else:
            print(''.join(pieces))

    # ----------------------------------------------------------------------------------------
    def shift_overlapping_boundaries(self, qrbounds, glbounds, query_name, query_seq, best):
        """ s-w allows d and j matches (and v and d matches) to overlap... which makes no sense, so apportion the disputed territory between the two regions.

        NOTE this does pretty much the same thing as resolve_overlapping_matches in joinparser.py

        <qrbounds> and <glbounds> map each gene to its (start, end) query/germline match bounds, and <best> maps each region to
        its best gene. All three are modified in place: the overlapping best matches are trimmed (the left one at its end, the
        right one at its start), and the '<region>_gl_seq'/'<region>_qr_seq' entries of <best> are recomputed for both regions.
        Disputed bases are handed out one at a time, alternating left, right, left, ... while both matches are longer than one
        base; after that they all go to whichever match still is.

        Raises AssertionError if the overlap can't be resolved without eroding both matches down to a single base.
        """
        for l_reg, r_reg in (('v', 'd'), ('d', 'j')):
            l_gene = best[l_reg]
            r_gene = best[r_reg]
            overlap = qrbounds[l_gene][1] - qrbounds[r_gene][0]
            if overlap <= 0:
                continue

            l_length = qrbounds[l_gene][1] - qrbounds[l_gene][0]
            r_length = qrbounds[r_gene][1] - qrbounds[r_gene][0]
            l_portion, r_portion = 0, 0
            while l_portion + r_portion < overlap:
                if l_length <= 1 and r_length <= 1:  # don't want to erode match (in practice it'll be the d match) all the way to zero
                    print('      ERROR both lengths went to zero')
                    # raise explicitly: a bare 'assert False' is stripped under 'python -O', which would leave this loop spinning forever
                    raise AssertionError('%s: both lengths went to zero apportioning %s-%s overlap' % (query_name, l_reg, r_reg))
                elif l_length > 1 and r_length > 1:  # if both have length left, alternate back and forth
                    if (l_portion + r_portion) % 2 == 0:
                        l_portion += 1  # give one base to the left
                        l_length -= 1
                    else:
                        r_portion += 1  # and one to the right
                        r_length -= 1
                elif l_length > 1:
                    l_portion += 1
                    l_length -= 1
                else:  # only the right match has length left
                    r_portion += 1
                    r_length -= 1

            if self.debug:
                print('      WARNING %s apportioning %d bases between %s (%d) match and %s (%d) match' % (query_name, overlap, l_reg, l_portion, r_reg, r_portion))
            assert l_portion + r_portion == overlap
            qrbounds[l_gene] = (qrbounds[l_gene][0], qrbounds[l_gene][1] - l_portion)
            glbounds[l_gene] = (glbounds[l_gene][0], glbounds[l_gene][1] - l_portion)
            qrbounds[r_gene] = (qrbounds[r_gene][0] + r_portion, qrbounds[r_gene][1])
            glbounds[r_gene] = (glbounds[r_gene][0] + r_portion, glbounds[r_gene][1])

            best[l_reg + '_gl_seq'] = self.germline_seqs[l_reg][l_gene][glbounds[l_gene][0] : glbounds[l_gene][1]]
            best[l_reg + '_qr_seq'] = query_seq[qrbounds[l_gene][0]:qrbounds[l_gene][1]]
            best[r_reg + '_gl_seq'] = self.germline_seqs[r_reg][r_gene][glbounds[r_gene][0] : glbounds[r_gene][1]]
            best[r_reg + '_qr_seq'] = query_seq[qrbounds[r_gene][0]:qrbounds[r_gene][1]]

    # ----------------------------------------------------------------------------------------
    def add_to_info(self, query_name, query_seq, kvals, match_names, best, all_germline_bounds, all_query_bounds, codon_positions):
        """ Store the annotation for <query_name> in self.info[query_name], feed it to the counters/plotter, and drop the query from <self.remaining_queries>.

        <kvals> holds the 'v' and 'd' k-space dicts, <match_names> the list of used genes per region, <best> the best gene
        (plus its '<region>_gl_seq'/'<region>_qr_seq') per region, and <codon_positions> the cysteine ('v') and tryptophan
        ('j') positions in <query_seq>. Raises Exception if either codon position falls outside <query_seq>.
        """
        assert query_name not in self.info
        self.info['queries'].append(query_name)
        annotation = self.info[query_name] = {}  # filled in place, so the (partial) entry exists even if a check below raises
        annotation['unique_id'] = query_name  # redundant, but used somewhere down the line
        annotation['k_v'] = kvals['v']
        annotation['k_d'] = kvals['d']
        annotation['all'] = ':'.join(match_names['v'] + match_names['d'] + match_names['j'])

        cyst_pos, tryp_pos = codon_positions['v'], codon_positions['j']
        annotation['cdr3_length'] = tryp_pos - cyst_pos + 3
        annotation['cyst_position'] = cyst_pos
        annotation['tryp_position'] = tryp_pos
        if not 0 <= cyst_pos < len(query_seq):
            raise Exception('cpos %d invalid for %s (%s)' % (cyst_pos, query_name, query_seq))
        if not 0 <= tryp_pos < len(query_seq):
            raise Exception('tpos %d invalid for %s (%s)' % (tryp_pos, query_name, query_seq))

        # deletions for the best match: 5' is where the germline match starts, 3' is the germline left over after it ends
        for region in ('v', 'd', 'j'):
            gene = best[region]
            annotation[region + '_5p_del'] = all_germline_bounds[gene][0]
            annotation[region + '_3p_del'] = len(self.germline_seqs[region][gene]) - all_germline_bounds[gene][1]

        # insertions: the parts of the query before, between, and after the best matches
        v_qr, d_qr, j_qr = all_query_bounds[best['v']], all_query_bounds[best['d']], all_query_bounds[best['j']]
        annotation['fv_insertion'] = query_seq[:v_qr[0]]
        annotation['vd_insertion'] = query_seq[v_qr[1]:d_qr[0]]
        annotation['dj_insertion'] = query_seq[d_qr[1]:j_qr[0]]
        annotation['jf_insertion'] = query_seq[j_qr[1]:]

        for region in utils.regions:
            annotation[region + '_gene'] = best[region]
            annotation[region + '_gl_seq'] = best[region + '_gl_seq']
            annotation[region + '_qr_seq'] = best[region + '_qr_seq']
            self.info['all_best_matches'].add(best[region])

        annotation['seq'] = query_seq  # NOTE this is the seq output by vdjalign, i.e. if we reversed any indels it is the reversed sequence
        if self.debug:
            if not self.args.is_data:
                utils.print_reco_event(self.germline_seqs, self.reco_info[query_name], extra_str='      ', label='true:', indelfo=self.reco_info[query_name]['indels'])
            utils.print_reco_event(self.germline_seqs, annotation, extra_str='      ', label='inferred:', indelfo=self.info['indels'].get(query_name, None))

        if self.pcounter is not None:
            self.pcounter.increment_reco_params(annotation)
            self.pcounter.increment_mutation_params(annotation)
        if self.true_pcounter is not None:
            true_event = self.reco_info[query_name]
            self.true_pcounter.increment_reco_params(true_event)
            self.true_pcounter.increment_mutation_params(true_event)
        if self.perfplotter is not None:
            self.perfplotter.evaluate(self.reco_info[query_name], annotation)  #, subtract_unphysical_erosions=True)

        self.remaining_queries.remove(query_name)

    # ----------------------------------------------------------------------------------------
    def summarize_query(self, query_name, query_seq, all_match_names, all_query_bounds, all_germline_bounds, warnings, first_match_query_bounds):
        if self.debug:
            print '%s' % query_name

        best, match_names, n_matches = {}, {}, {}
        n_used = {'v':0, 'd':0, 'j':0}
        k_v_min, k_d_min = 999, 999
        k_v_max, k_d_max = 0, 0
        for region in utils.regions:
            all_match_names[region] = sorted(all_match_names[region], reverse=True)
            match_names[region] = []
        codon_positions = {'v':-1, 'd':-1, 'j':-1}  # conserved codon positions (v:cysteine, d:dummy, j:tryptophan)
        for region in utils.regions:
            n_matches[region] = len(all_match_names[region])
            n_skipped = 0
            for score, gene in all_match_names[region]:
                glbounds = all_germline_bounds[gene]
                qrbounds = all_query_bounds[gene]
                assert qrbounds[1] <= len(query_seq)  # NOTE I'm putting these up avove as well (in process_query), so in time I should remove them from here
                assert glbounds[1] <= len(self.germline_seqs[region][gene])
                assert qrbounds[0] >= 0
                assert glbounds[0] >= 0
                glmatchseq = self.germline_seqs[region][gene][glbounds[0]:glbounds[1]]

                # TODO since I'm no longer skipping the genes after the first <args.n_max_per_region>, the OR of k-space below is overly conservative

                # only use a specified set of genes
                if self.args.only_genes is not None and gene not in self.args.only_genes:
                    n_skipped += 1
                    continue

                # add match to the list
                n_used[region] += 1
                match_names[region].append(gene)

                self.print_match(region, gene, query_seq, score, glbounds, qrbounds, -1, warnings, skipping=False)

                # if the germline match and the query match aren't the same length, s-w likely added an insert, which we shouldn't get since the gap-open penalty is jacked up so high
                if len(glmatchseq) != len(query_seq[qrbounds[0]:qrbounds[1]]):  # neurotic double check (um, I think) EDIT hey this totally saved my ass
                    print 'ERROR %d not same length' % query_name
                    print glmatchseq, glbounds[0], glbounds[1]
                    print query_seq[qrbounds[0]:qrbounds[1]]
                    assert False

                if region == 'v':
                    this_k_v = all_query_bounds[gene][1]  # NOTE even if the v match doesn't start at the left hand edge of the query sequence, we still measure k_v from there.
                                                          # In other words, sw doesn't tell the hmm about it
                    k_v_min = min(this_k_v, k_v_min)
                    k_v_max = max(this_k_v, k_v_max)
                if region == 'd':
                    this_k_d = all_query_bounds[gene][1] - first_match_query_bounds[1]  # end of d minus end of v
                    k_d_min = min(this_k_d, k_d_min)
                    k_d_max = max(this_k_d, k_d_max)

                # check consistency with best match (since the best match is excised in s-w code, and because ham is run with *one* k_v k_d set)
                if region not in best:
                    best[region] = gene
                    best[region + '_gl_seq'] = self.germline_seqs[region][gene][glbounds[0]:glbounds[1]]
                    best[region + '_qr_seq'] = query_seq[qrbounds[0]:qrbounds[1]]
                    best[region + '_score'] = score

            if self.debug and n_skipped > 0:
                print '%8s skipped %d %s genes' % ('', n_skipped, region)

        for region in utils.regions:
            if region not in best:
                print '      no', region, 'match found for', query_name  # NOTE if no d match found, we should really just assume entire d was eroded
                return

        # s-w allows d and j matches to overlap, so we need to apportion the disputed bases
        try:
            self.shift_overlapping_boundaries(all_query_bounds, all_germline_bounds, query_name, query_seq, best)
        except AssertionError:
            print '%s: apportionment failed' % query_name
            return

        # check for unproductive rearrangements
        for region in utils.regions:
            codon_positions[region] = utils.get_conserved_codon_position(self.cyst_positions, self.tryp_positions, region, best[region], all_germline_bounds[best[region]], all_query_bounds[best[region]], assert_on_fail=False)  # position in the query sequence, that is
        codons_ok = utils.check_both_conserved_codons(query_seq, codon_positions['v'], codon_positions['j'], debug=self.debug, extra_str='      ', assert_on_fail=False)
        cdr3_length = codon_positions['j'] - codon_positions['v'] + 3
        in_frame_cdr3 = (cdr3_length % 3 == 0)
        if self.debug and not in_frame_cdr3:
                print '      out of frame cdr3: %d %% 3 = %d' % (cdr3_length, cdr3_length % 3)
        no_stop_codon = utils.stop_codon_check(query_seq, codon_positions['v'], debug=self.debug)
        if not codons_ok or not in_frame_cdr3 or not no_stop_codon:
            if self.debug:
                print '       unproductive rearrangement in waterer codons_ok: %s   in_frame_cdr3: %s   no_stop_codon: %s' % (codons_ok, in_frame_cdr3, no_stop_codon)
            if self.args.skip_unproductive:
                if self.debug:
                    print '            ...skipping'
                self.n_unproductive += 1
                self.info['skipped_unproductive_queries'].append(query_name)
                return

        # best k_v, k_d:
        k_v = all_query_bounds[best['v']][1]  # end of v match
        k_d = all_query_bounds[best['d']][1] - all_query_bounds[best['v']][1]  # end of d minus end of v

        if k_d_max < 5:  # since the s-w step matches to the longest possible j and then excises it, this sometimes gobbles up the d, resulting in a very short d alignment.
            if self.debug:
                print '  expanding k_d'
            k_d_max = max(8, k_d_max)

        if 'IGHJ4*' in best['j'] and self.germline_seqs['d'][best['d']][-5:] == 'ACTAC':  # the end of some d versions is the same as the start of some j versions, so the s-w frequently kicks out the 'wrong' alignment
            if self.debug:
                print '  doubly expanding k_d'
            if k_d_max-k_d_min < 8:
                k_d_min -= 5
                k_d_max += 2

        k_v_min = max(0, k_v_min - self.args.default_v_fuzz)  # ok, so I don't *actually* want it to be zero... oh, well
        k_v_max += self.args.default_v_fuzz
        k_d_min = max(1, k_d_min - self.args.default_d_fuzz)
        k_d_max += self.args.default_d_fuzz
        assert k_v_min > 0 and k_d_min > 0 and k_v_max > 0 and k_d_max > 0

        if self.debug:
            print '         k_v: %d [%d-%d)' % (k_v, k_v_min, k_v_max)
            print '         k_d: %d [%d-%d)' % (k_d, k_d_min, k_d_max)
            print '         used',
            for region in utils.regions:
                print ' %s: %d/%d' % (region, n_used[region], n_matches[region]),
            print ''


        kvals = {}
        kvals['v'] = {'best':k_v, 'min':k_v_min, 'max':k_v_max}
        kvals['d'] = {'best':k_d, 'min':k_d_min, 'max':k_d_max}
        self.add_to_info(query_name, query_seq, kvals, match_names, best, all_germline_bounds, all_query_bounds, codon_positions=codon_positions)
Exemplo n.º 4
0
            trueDictionary[unique_id] = {}
            trueDictionary[unique_id]['v_gene'] = row1['v_gene']
            trueDictionary[unique_id]['d_gene'] = row1['d_gene']
            trueDictionary[unique_id]['j_gene'] = row1['v_gene']
            #print trueDictionary[unique_id]
            iDictionary[unique_id] = {}
            iDictionary[unique_id]['v_gene'] = row2['Best V hit']
            iDictionary[unique_id]['d_gene'] = row2['Best D hit']
            iDictionary[unique_id]['j_gene'] = row2['Best J hit']
            #print iDictionary[unique_id]

# run the evaluate() function from performanceplotter.py on each query, comparing its true gene calls to the inferred ones
# (iDictionary must have an entry for every key in trueDictionary, otherwise this raises KeyError)
for unique_id in trueDictionary:
    true_genes, inferred_genes = trueDictionary[unique_id], iDictionary[unique_id]
    perfplotter.evaluate(true_genes, inferred_genes)
print('COMPLETED EVALUATE')

# plot whatever the evaluate() calls accumulated
perfplotter.plot(mixcrPlotDir)
print(mixcrPlotDir)
print('COMPLETED PLOTTING')
#----------------------------
#Code from previous development
'''	
with open("simu-10-leaves-1-mutate.csv") as inFile1:
	with open('edited_output_file.txt') as inFile2:
		reader1 = csv.DictReader(inFile1)
		reader2 = csv.DictReader(inFile2, delimiter='\t')
		for i1, i2 in zip(reader1, reader2):
			#gets the unique id number from the dictionary in the first id
Exemplo n.º 5
0
class Waterer(object):
    """ Run smith-waterman on the query sequences in <infname> """
    def __init__(self, args, input_info, reco_info, germline_seqs, parameter_dir, write_parameters=False):
        """ Set up the bookkeeping for a smith-waterman run over all queries in <input_info>.

        Every query in <input_info> starts out in <self.remaining_queries>. With <write_parameters>, parameter counters
        are created (a second one for the true events unless <args.is_data>). The conserved cysteine/tryptophan positions
        are read from <args.datadir>, and <args.outfname> (if set) is opened for appending.
        """
        self.parameter_dir = parameter_dir
        self.args = args
        self.debug = self.args.debug if self.args.sw_debug is None else self.args.sw_debug

        self.input_info = input_info
        # we remove queries from this list when we're satisfied with the current output (in general we may have to rerun some queries with different match/mismatch scores)
        self.remaining_queries = list(self.input_info.keys())
        self.new_indels = 0  # number of new indels that were kicked up this time through

        self.reco_info = reco_info
        self.germline_seqs = germline_seqs
        self.pcounter, self.true_pcounter, self.perfplotter = None, None, None
        if write_parameters:
            self.pcounter = ParameterCounter(self.germline_seqs)
            if not self.args.is_data:
                self.true_pcounter = ParameterCounter(self.germline_seqs)
        if self.args.plot_performance:
            self.perfplotter = PerformancePlotter(self.germline_seqs, 'sw')

        self.info = {
            'queries': [],
            'all_best_matches': set(),  # all the matches we found (for *all* queries)
            'skipped_unproductive_queries': [],  # unproductive queries
            'skipped_unknown_queries': [],
            'indels': {},
        }
        if self.args.apply_choice_probs_in_sw:
            if self.debug:
                print('  reading gene choice probs from %s' % (parameter_dir, ))
            self.gene_choice_probs = utils.read_overall_gene_probs(parameter_dir)

        with opener('r')(self.args.datadir + '/v-meta.json') as json_file:  # location of the <begin> cysteine in each v region
            self.cyst_positions = json.load(json_file)
        with opener('r')(self.args.datadir + '/j_tryp.csv') as csv_file:  # location of the <end> tryptophan (TGG) in each j region
            # WARNING: this doesn't filter out the header line
            self.tryp_positions = dict((row[0], row[1]) for row in csv.reader(csv_file))

        self.outfile = open(self.args.outfname, 'a') if self.args.outfname is not None else None

        self.n_unproductive = 0
        self.n_total = 0

        print('smith-waterman')

    # ----------------------------------------------------------------------------------------
    def __del__(self):
        if self.args.outfname is not None:
            self.outfile.close()

    # ----------------------------------------------------------------------------------------
    def clean(self):
        if self.pcounter is not None:
            self.pcounter.clean()
        if self.true_pcounter is not None:
            self.true_pcounter.clean()

    # ----------------------------------------------------------------------------------------
    def run(self):
        # start = time.time()
        base_infname = 'query-seqs.fa'
        base_outfname = 'query-seqs.bam'
        sys.stdout.flush()

        n_tries = 0
        while len(
                self.remaining_queries
        ) > 0:  # we remove queries from <self.remaining_queries> as we're satisfied with their output
            self.write_vdjalign_input(base_infname,
                                      n_procs=self.args.n_fewer_procs)
            self.execute_command(base_infname, base_outfname,
                                 self.args.n_fewer_procs)
            self.read_output(base_outfname, n_procs=self.args.n_fewer_procs)
            n_tries += 1
            if n_tries > 2:
                self.info['skipped_unknown_queries'] += self.remaining_queries
                break

        self.finalize()

    # ----------------------------------------------------------------------------------------
    def finalize(self):
        if self.perfplotter is not None:
            self.perfplotter.plot(self.args.plotdir + '/sw/performance')
        # print '    sw time: %.3f' % (time.time()-start)
        if self.n_unproductive > 0:
            print '      unproductive skipped %d / %d = %.2f' % (
                self.n_unproductive, self.n_total,
                float(self.n_unproductive) / self.n_total)
        # if len(self.info['skipped_indel_queries']) > 0:
        #     print '      indels skipped %d / %d = %.2f' % (len(self.info['skipped_indel_queries']), self.n_total, float(len(self.info['skipped_indel_queries'])) / self.n_total)
        if len(self.info['indels']) > 0:
            print '      indels: %s' % ':'.join(self.info['indels'].keys())
        if self.pcounter is not None:
            self.pcounter.write(self.parameter_dir)
            if self.args.plotdir is not None:
                self.pcounter.plot(self.args.plotdir + '/sw',
                                   subset_by_gene=True,
                                   cyst_positions=self.cyst_positions,
                                   tryp_positions=self.tryp_positions)
                if self.true_pcounter is not None:
                    self.true_pcounter.plot(self.args.plotdir + 'sw/true',
                                            subset_by_gene=True,
                                            cyst_positions=self.cyst_positions,
                                            tryp_positions=self.tryp_positions)

    # ----------------------------------------------------------------------------------------
    def execute_command(self, base_infname, base_outfname, n_procs):
        """ Run vdjalign on <base_infname> and wait for it (or for all of them) to finish.

        With one process everything happens in <args.workdir>. Otherwise one vdjalign runs in each <args.workdir>/sw-<iproc>,
        with a 0.1 s pause between launches. Output/error streams are handed to utils.process_out_err. Input files are removed
        afterwards unless <args.no_clean> is set.
        """
        if n_procs == 1:
            cmd_str = self.get_vdjalign_cmd_str(self.args.workdir, base_infname, base_outfname)
            out, err = Popen(cmd_str.split(), stdout=PIPE, stderr=PIPE).communicate()
            utils.process_out_err(out, err)
            if not self.args.no_clean:
                os.remove(self.args.workdir + '/' + base_infname)
        else:
            subdirs = [self.args.workdir + '/sw-' + str(iproc) for iproc in range(n_procs)]
            procs = []
            for subdir in subdirs:
                cmd_str = self.get_vdjalign_cmd_str(subdir, base_infname, base_outfname)
                procs.append(Popen(cmd_str.split(), stdout=PIPE, stderr=PIPE))
                time.sleep(0.1)
            for iproc, proc in enumerate(procs):
                out, err = proc.communicate()
                utils.process_out_err(out, err, extra_str=str(iproc))
            if not self.args.no_clean:
                for subdir in subdirs:
                    os.remove(subdir + '/' + base_infname)

        sys.stdout.flush()

    # ----------------------------------------------------------------------------------------
    def write_vdjalign_input(self, base_infname, n_procs):
        """ Split <self.remaining_queries> into <n_procs> contiguous chunks and write each one to a fasta file for vdjalign.

        With one process the file goes in <args.workdir>. Otherwise chunk <iproc> goes in <args.workdir>/sw-<iproc>, which is
        prepared with utils.prep_dir. Each record's header line is '<query name> NUKES'. Queries with an entry in
        self.info['indels'] are written with their 'reversed_seq' instead of the input sequence.
        """
        n_queries_per_proc = int(math.ceil(float(len(self.remaining_queries)) / n_procs))
        if n_procs == 1:  # double check for rounding problems or whatnot
            assert n_queries_per_proc == len(self.remaining_queries)
        for iproc in range(n_procs):
            workdir = self.args.workdir
            if n_procs > 1:
                workdir += '/sw-' + str(iproc)
                utils.prep_dir(workdir)
            chunk = self.remaining_queries[iproc * n_queries_per_proc : (iproc + 1) * n_queries_per_proc]
            with opener('w')(workdir + '/' + base_infname) as sub_infile:
                for query_name in chunk:
                    seq = self.input_info[query_name]['seq']
                    if query_name in self.info['indels']:
                        seq = self.info['indels'][query_name]['reversed_seq']  # query sequence with shm insertions and deletions reversed
                    sub_infile.write('>' + query_name + ' NUKES\n')
                    sub_infile.write(seq + '\n')

    # ----------------------------------------------------------------------------------------
    def get_vdjalign_cmd_str(self, workdir, base_infname, base_outfname):
        """
        Run smith-waterman alignment (from Connor's ighutils package) on the seqs in <base_infname>, and toss all the top matches into <base_outfname>.
        """
        # large gap-opening penalty: we want *no* gaps in the middle of the alignments
        # match score larger than (negative) mismatch score: we want to *encourage* some level of shm. If they're equal, we tend to end up with short unmutated alignments, which screws everything up
        os.environ['PATH'] = os.getenv(
            'PWD') + '/packages/samtools:' + os.getenv('PATH')
        check_output(['which', 'samtools'])
        if not os.path.exists(self.args.ighutil_dir + '/bin/vdjalign'):
            raise Exception('ERROR ighutil path d.n.e: ' +
                            self.args.ighutil_dir + '/bin/vdjalign')
        cmd_str = self.args.ighutil_dir + '/bin/vdjalign align-fastq -q'
        if self.args.slurm:
            cmd_str = 'srun ' + cmd_str
        cmd_str += ' --max-drop 50'
        match, mismatch = self.args.match_mismatch
        cmd_str += ' --match ' + str(match) + ' --mismatch ' + str(mismatch)
        cmd_str += ' --gap-open ' + str(
            self.args.gap_open_penalty)  #1000'  #50'
        cmd_str += ' --vdj-dir ' + self.args.datadir
        cmd_str += ' ' + workdir + '/' + base_infname + ' ' + workdir + '/' + base_outfname

        return cmd_str

    # ----------------------------------------------------------------------------------------
    def read_output(self, base_outfname, n_procs=1):
        """ Read vdjalign's bam output(s), run process_query on each query's group of reads, and set up the next round.

        Output files (and, with several processes, their sw-<iproc> subdirs) are removed unless <args.no_clean> is set.
        If queries remain and this round turned up new indels, they're just rerun. Otherwise the mismatch score
        (args.match_mismatch[1]) is bumped by one first. <self.new_indels> is reset to zero in both cases.
        """
        n_processed = 0
        for iproc in range(n_procs):
            workdir = self.args.workdir + '/sw-' + str(iproc) if n_procs > 1 else self.args.workdir
            outfname = workdir + '/' + base_outfname
            with contextlib.closing(pysam.Samfile(outfname)) as bam:
                for _, reads in itertools.groupby(iter(bam), operator.attrgetter('qname')):  # one group of reads per query sequence
                    self.n_total += 1
                    self.process_query(bam, list(reads))
                    n_processed += 1

            if not self.args.no_clean:
                os.remove(outfname)
                if n_procs > 1:  # still need the top-level workdir
                    os.rmdir(workdir)

        print('    processed %d queries' % n_processed)

        if len(self.remaining_queries) > 0:
            if self.new_indels > 0:  # some of the skipped queries had newly-found indels, so just rerun them
                print('      skipped %d queries (%d indels), rerunning them' % (len(self.remaining_queries), self.new_indels))
            else:  # none of the skipped queries were because of indels, so increase the mismatch score before rerunning
                print('      skipped %d queries (%d indels), increasing mismatch score (%d --> %d) and rerunning them' % (len(self.remaining_queries), self.new_indels, self.args.match_mismatch[1], self.args.match_mismatch[1] + 1))
                self.args.match_mismatch[1] += 1
            self.new_indels = 0

    # ----------------------------------------------------------------------------------------
    def get_choice_prob(self, region, gene):
        """Return the gene choice probability of <gene> within <region>.

        Genes that have no entry in self.gene_choice_probs[region] get 0.0.
        """
        region_probs = self.gene_choice_probs[region]
        # NOTE would it make sense to use something other than 0.0 for genes with no recorded probability?
        return region_probs[gene] if gene in region_probs else 0.0

    # ----------------------------------------------------------------------------------------
    def get_indel_info(self, query_name, cigarstr, qrseq, glseq, gene):
        """Extract (and print) the insertions and deletions in one alignment.

        <cigarstr> is the alignment's CIGAR string; <qrseq> and <glseq> are the
        matched stretches of the query and of germline gene <gene>. Soft-clip
        ('S') columns are passed over without advancing through <qrseq> or
        <glseq>, i.e. soft-clipped bases are assumed to be absent from <qrseq>.

        Returns a dict with keys
          'reversed_seq': the matched query with inserted bases removed and deleted
                          germline bases put back (i.e. all indels reversed)
          'indels': one dict per indel with keys 'type' ('insertion'/'deletion'),
                    'pos' (number of CIGAR columns, of any code, preceding it),
                    'len', and 'seqstr' (inserted query bases / deleted germline bases)
        Raises Exception for any CIGAR code other than M, S, I and D.
        """
        # split the cigar string into (code, length) pairs, e.g. '3S10M2I' --> [('S', 3), ('M', 10), ('I', 2)]
        cigars = [(code, int(length)) for length, code in re.findall(r'([0-9]+)([A-Z])', cigarstr)]

        codestr = ''  # the cigar expanded to one character per alignment column
        qpos = 0  # NOTE(review): advances over every code (including D and S), so 'pos' counts alignment columns, not strictly query bases
        indelfo = {'reversed_seq': '', 'indels': []}
        tmp_indices = []  # per alignment column: index into indelfo['indels'] of the indel it belongs to, or None
        for code, length in cigars:
            codestr += length * code
            if code in ('I', 'D'):  # I: advance qr seq but not gl seq;  D: advance gl seq but not qr seq
                indelfo['indels'].append({
                    'type': 'insertion' if code == 'I' else 'deletion',
                    'pos': qpos,  # the indel begins at <pos>
                    'len': length,
                    'seqstr': ''
                })
                tmp_indices += [len(indelfo['indels']) - 1] * length
            else:
                tmp_indices += [None] * length
            qpos += length

        qrprintstr, glprintstr = '', ''
        iqr, igl = 0, 0  # current positions in <qrseq> and <glseq>
        for icode, code in enumerate(codestr):
            if code == 'M':
                qrbase = qrseq[iqr]
                if qrbase != glseq[igl]:
                    qrbase = utils.color('red', qrbase)
                qrprintstr += qrbase
                glprintstr += glseq[igl]
                indelfo['reversed_seq'] += qrseq[iqr]  # add the base to the overall sequence with all indels reversed
            elif code == 'S':
                continue  # don't advance iqr/igl (see docstring)
            elif code == 'I':
                qrprintstr += utils.color('light_blue', qrseq[iqr])
                glprintstr += utils.color('light_blue', '*')
                # inserted bases stay out of reversed_seq and only go into this indel's sequence
                indelfo['indels'][tmp_indices[icode]]['seqstr'] += qrseq[iqr]
                igl -= 1  # cancels the increment below: an insertion doesn't consume germline
            elif code == 'D':
                qrprintstr += utils.color('light_blue', '*')
                glprintstr += utils.color('light_blue', glseq[igl])
                indelfo['reversed_seq'] += glseq[igl]  # put the deleted germline base back into the reversed sequence
                indelfo['indels'][tmp_indices[icode]]['seqstr'] += glseq[igl]  # and into the sequence of just this indel
                iqr -= 1  # cancels the increment below: a deletion doesn't consume query
            else:
                raise Exception('unhandled code %s' % code)

            iqr += 1
            igl += 1

        print('\n      indels in %s' % query_name)
        print('          %20s %s' % (gene, glprintstr))
        print('          %20s %s' % ('query', qrprintstr))
        for idl in indelfo['indels']:
            print('          %10s: %d bases at %d (%s)' % (idl['type'], idl['len'], idl['pos'], idl['seqstr']))

        return indelfo

    # ----------------------------------------------------------------------------------------
    def process_query(self, bam, reads):
        """Collect the smith-waterman matches for one query and hand them to summarize_query().

        <reads> are the alignment records from <bam> that share one query name; the
        first non-secondary record supplies the query sequence and name.
        A match is dropped when
          - (v only) its conserved cysteine position falls outside the query, or
          - it contains an indel and an earlier match for its region was already kept.
        If the first kept candidate for a region contains an indel, its indel info
        (with the indel-reversed query in 'reversed_seq') is stored in
        self.info['indels'], self.new_indels is incremented, and this query is
        abandoned without being summarized (it is not removed from
        self.remaining_queries).

        Raises Exception if a germline match and its query match differ in length,
        or if a v gene's stored cysteine position fails utils.check_conserved_cysteine.
        """
        primary = next((r for r in reads if not r.is_secondary), None)
        query_seq = primary.seq
        query_name = primary.qname
        first_match_query_bounds = None  # since sw excises its favorite v match, we have to know this match's boundaries in order to calculate k_d for all the other matches
        all_match_names = {region: [] for region in utils.regions}
        warnings = {}  # ick, this is a messy way to pass stuff around
        all_query_bounds, all_germline_bounds = {}, {}
        for read in reads:  # loop over the matches found for each query sequence
            # set this match's values
            read.seq = query_seq  # only the first one has read.seq set by default, so we need to set the rest by hand
            gene = bam.references[read.tid]
            region = utils.get_region(gene)
            raw_score = read.tags[0][1]  # raw because they don't include the gene choice probs (NOTE(review): assumes the first tag is the alignment score)
            score = raw_score
            if self.args.apply_choice_probs_in_sw:  # NOTE I stopped applying the gene choice probs here because the smith-waterman scores don't correspond to log-probs, so throwing on the gene choice probs was dubious (and didn't seem to work that well)
                score = self.get_choice_prob(region, gene) * raw_score  # multiply by the probability to choose this gene
            qrbounds = (read.qstart, read.qend)
            glbounds = (read.pos, read.aend)
            if region == 'v' and first_match_query_bounds is None:
                first_match_query_bounds = qrbounds

            # skip v matches with cpos past the end of the query seq (i.e. eroded a ton on the right side of the v)
            if region == 'v':
                cpos = utils.get_conserved_codon_position(self.cyst_positions, self.tryp_positions, 'v', gene, glbounds, qrbounds, assert_on_fail=False)
                if not utils.check_conserved_cysteine(self.germline_seqs['v'][gene], self.cyst_positions[gene]['cysteine-position'], assert_on_fail=False):  # some of the damn cysteine positions in the json file were wrong, so now we check
                    raise Exception('bad cysteine in %s: %d %s' % (gene, self.cyst_positions[gene]['cysteine-position'], self.germline_seqs['v'][gene]))
                if cpos < 0 or cpos >= len(query_seq):
                    continue

            # skip indels, and tell the HMM to skip indels (you won't see any unless you decrease the <self.args.gap_open_penalty>)
            if 'I' in read.cigarstring or 'D' in read.cigarstring:
                if all_match_names[region]:  # not the first (best) match for this region: just skip this match
                    continue
                # first (best) match for this region has an indel: record it and give up on this query for now
                if query_name in self.info['indels']:
                    print('     multiple indels for %s' % query_name)
                    return
                indelfo = self.get_indel_info(query_name, read.cigarstring,
                                              query_seq[qrbounds[0]:qrbounds[1]],
                                              self.germline_seqs[region][gene][glbounds[0]:glbounds[1]], gene)
                # extend the reversed match region out to the full query sequence
                indelfo['reversed_seq'] = query_seq[:qrbounds[0]] + indelfo['reversed_seq'] + query_seq[qrbounds[1]:]
                self.info['indels'][query_name] = indelfo
                self.new_indels += 1
                return

            if qrbounds[1] - qrbounds[0] != glbounds[1] - glbounds[0]:
                raise Exception('germline match (%d %d) not same length as query match (%d %d)'
                                % (glbounds[0], glbounds[1], qrbounds[0], qrbounds[1]))

            assert qrbounds[1] <= len(query_seq)
            if glbounds[1] > len(self.germline_seqs[region][gene]):
                print('   %s' % gene)
                print('   %s %s' % (glbounds[1], len(self.germline_seqs[region][gene])))
                print('   %s' % self.germline_seqs[region][gene])
            assert glbounds[1] <= len(self.germline_seqs[region][gene])

            # and finally add this match's information
            warnings[gene] = ''
            all_match_names[region].append((score, gene))  # NOTE it is important that this is ordered such that the best match is first
            all_query_bounds[gene] = qrbounds
            all_germline_bounds[gene] = glbounds

        self.summarize_query(query_name, query_seq, all_match_names,
                             all_query_bounds, all_germline_bounds, warnings,
                             first_match_query_bounds)

    # ----------------------------------------------------------------------------------------
    def print_match(self,
                    region,
                    gene,
                    query_seq,
                    score,
                    glbounds,
                    qrbounds,
                    codon_pos,
                    warnings,
                    skipping=False):
        """Report one match: gene and score, then its germline and query stretches.

        Does nothing unless self.debug is at least 2. The first line holds the gene,
        its score (broken down by choice probability when
        args.apply_choice_probs_in_sw is set), the germline bounds and the matched
        germline bases. The second line holds the query bounds and the query stretch
        (run through utils.color_mutants against the germline). For v and j the
        conserved codon position is appended, followed by any warnings[gene] text
        and 'skipping!' if <skipping>. The result is printed when args.outfname is
        None, otherwise written (newline-terminated) to self.outfile.
        """
        if self.debug < 2:
            return

        padding = ' ' * (20 - len(gene))
        if self.args.apply_choice_probs_in_sw:
            choice_prob = self.get_choice_prob(region, gene)
            # undo the choice-prob weighting to recover the raw score (unless that would divide by zero)
            unweighted = score if choice_prob == 0.0 else score / choice_prob
            header = '%8s%s%s%9.1e * %3.0f = %-6.1f' % (' ', utils.color_gene(gene), padding, choice_prob, unweighted, score)
        else:
            header = '%8s%s%s%9s%3s %6.0f        ' % (' ', utils.color_gene(gene), '', '', padding, score)

        glmatch = self.germline_seqs[region][gene][glbounds[0]:glbounds[1]]
        pieces = [
            header,
            '%4d%4d   %s\n' % (glbounds[0], glbounds[1], glmatch),
            '%46s  %4d%4d' % ('', qrbounds[0], qrbounds[1]),
            '   %s ' % utils.color_mutants(glmatch, query_seq[qrbounds[0]:qrbounds[1]]),
        ]
        if region != 'd':  # d has no conserved codon
            pieces.append('(%s %d)' % (utils.conserved_codon_names[region], codon_pos))
        if warnings[gene] != '':
            pieces.append('WARNING ' + warnings[gene])
        if skipping:
            pieces.append('skipping!')

        if self.args.outfname is not None:
            pieces.append('\n')
            self.outfile.write(''.join(pieces))
        else:
            print(''.join(pieces))

    # ----------------------------------------------------------------------------------------
    def shift_overlapping_boundaries(self, qrbounds, glbounds, query_name,
                                     query_seq, best):
        """ s-w allows d and j matches (and v and d matches) to overlap... which makes no sense, so apportion the disputed territory between the two regions

        For each neighboring pair of best matches (v/d, then d/j) whose query bounds
        overlap, the overlapping bases are taken away one at a time, alternating
        left, right, left, ... while both matches are longer than one base; once
        one of them is down to one base the remainder comes out of the other.

        <qrbounds> and <glbounds> map gene -> (start, end) and are updated in place,
        as are best['<region>_gl_seq'] and best['<region>_qr_seq'] for each
        region whose bounds changed.

        Raises AssertionError if both matches would have to shrink below one base.
        """
        # NOTE this does pretty much the same thing as resolve_overlapping_matches in joinparser.py
        for l_reg, r_reg in (('v', 'd'), ('d', 'j')):
            l_gene = best[l_reg]
            r_gene = best[r_reg]
            overlap = qrbounds[l_gene][1] - qrbounds[r_gene][0]
            if overlap <= 0:
                continue

            l_length = qrbounds[l_gene][1] - qrbounds[l_gene][0]
            r_length = qrbounds[r_gene][1] - qrbounds[r_gene][0]
            l_portion, r_portion = 0, 0
            while l_portion + r_portion < overlap:
                if l_length > 1 and r_length > 1:  # if both have length left, alternate back and forth
                    if (l_portion + r_portion) % 2 == 0:
                        l_portion += 1  # give one base to the left
                        l_length -= 1
                    else:
                        r_portion += 1  # and one to the right
                        r_length -= 1
                elif l_length > 1:
                    l_portion += 1
                    l_length -= 1
                elif r_length > 1:
                    r_portion += 1
                    r_length -= 1
                else:  # don't want to erode match (in practice it'll be the d match) all the way to zero
                    print('      ERROR both lengths went to zero')
                    # explicit raise rather than 'assert False': python -O strips asserts, which would leave this loop spinning forever
                    raise AssertionError('both lengths went to zero for %s' % query_name)

            if self.debug:
                print('      WARNING %s apportioning %d bases between %s (%d) match and %s (%d) match' % (
                    query_name, overlap, l_reg, l_portion, r_reg, r_portion))
            assert l_portion + r_portion == overlap
            qrbounds[l_gene] = (qrbounds[l_gene][0], qrbounds[l_gene][1] - l_portion)
            glbounds[l_gene] = (glbounds[l_gene][0], glbounds[l_gene][1] - l_portion)
            qrbounds[r_gene] = (qrbounds[r_gene][0] + r_portion, qrbounds[r_gene][1])
            glbounds[r_gene] = (glbounds[r_gene][0] + r_portion, glbounds[r_gene][1])

            best[l_reg + '_gl_seq'] = self.germline_seqs[l_reg][l_gene][glbounds[l_gene][0]:glbounds[l_gene][1]]
            best[l_reg + '_qr_seq'] = query_seq[qrbounds[l_gene][0]:qrbounds[l_gene][1]]
            best[r_reg + '_gl_seq'] = self.germline_seqs[r_reg][r_gene][glbounds[r_gene][0]:glbounds[r_gene][1]]
            best[r_reg + '_qr_seq'] = query_seq[qrbounds[r_gene][0]:qrbounds[r_gene][1]]

    # ----------------------------------------------------------------------------------------
    def add_to_info(self, query_name, query_seq, kvals, match_names, best,
                    all_germline_bounds, all_query_bounds, codon_positions):
        """Store the final annotation for <query_name> in self.info[query_name].

        Records the k_v/k_d dicts, the colon-joined list of all match names, the
        cysteine/tryptophan positions and cdr3 length, the 5'/3' deletion lengths
        and the four non-templated insertions implied by the best v, d and j
        matches, the best genes with their matched germline/query stretches, and
        the query sequence. Then, depending on configuration, prints the true and
        inferred events, feeds the parameter counters and the performance plotter,
        and finally removes <query_name> from self.remaining_queries.

        Raises Exception if the cysteine or tryptophan position is outside the query.
        """
        assert query_name not in self.info
        self.info['queries'].append(query_name)
        info = self.info[query_name] = {}
        info['unique_id'] = query_name  # redundant, but used somewhere down the line
        info['k_v'] = kvals['v']
        info['k_d'] = kvals['d']
        info['all'] = ':'.join(match_names['v'] + match_names['d'] + match_names['j'])

        cyst_pos = codon_positions['v']
        tryp_pos = codon_positions['j']
        info['cdr3_length'] = tryp_pos - cyst_pos + 3
        info['cyst_position'] = cyst_pos
        info['tryp_position'] = tryp_pos
        if not 0 <= cyst_pos < len(query_seq):
            raise Exception('cpos %d invalid for %s (%s)' % (cyst_pos, query_name, query_seq))
        if not 0 <= tryp_pos < len(query_seq):
            raise Exception('tpos %d invalid for %s (%s)' % (tryp_pos, query_name, query_seq))

        # erosion info for best match: 5p is the germline match start, 3p is len(germline) - germline match end
        for region in ('v', 'd', 'j'):
            gl_start, gl_end = all_germline_bounds[best[region]]
            info[region + '_5p_del'] = gl_start
            info[region + '_3p_del'] = len(self.germline_seqs[region][best[region]]) - gl_end

        # insertions are whatever query bases lie outside/between the best matches
        v_qr = all_query_bounds[best['v']]
        d_qr = all_query_bounds[best['d']]
        j_qr = all_query_bounds[best['j']]
        info['fv_insertion'] = query_seq[:v_qr[0]]
        info['vd_insertion'] = query_seq[v_qr[1]:d_qr[0]]
        info['dj_insertion'] = query_seq[d_qr[1]:j_qr[0]]
        info['jf_insertion'] = query_seq[j_qr[1]:]

        for region in utils.regions:
            info[region + '_gene'] = best[region]
            for suffix in ('_gl_seq', '_qr_seq'):
                info[region + suffix] = best[region + suffix]
            self.info['all_best_matches'].add(best[region])

        info['seq'] = query_seq  # NOTE this is the seq output by vdjalign, i.e. if we reversed any indels it is the reversed sequence

        if self.debug:
            if not self.args.is_data:
                utils.print_reco_event(self.germline_seqs, self.reco_info[query_name], extra_str='      ',
                                       label='true:', indelfo=self.reco_info[query_name]['indels'])
            utils.print_reco_event(self.germline_seqs, info, extra_str='      ',
                                   label='inferred:', indelfo=self.info['indels'].get(query_name, None))

        if self.pcounter is not None:
            self.pcounter.increment_reco_params(info)
            self.pcounter.increment_mutation_params(info)
        if self.true_pcounter is not None:
            true_info = self.reco_info[query_name]
            self.true_pcounter.increment_reco_params(true_info)
            self.true_pcounter.increment_mutation_params(true_info)
        if self.perfplotter is not None:
            self.perfplotter.evaluate(self.reco_info[query_name], info)  #, subtract_unphysical_erosions=True)

        self.remaining_queries.remove(query_name)

    # ----------------------------------------------------------------------------------------
    def summarize_query(self, query_name, query_seq, all_match_names,
                        all_query_bounds, all_germline_bounds, warnings,
                        first_match_query_bounds):
        """Choose the best v, d and j matches for one query, compute its k_v/k_d ranges, and pass everything to add_to_info().

        <all_match_names> maps region -> list of (score, gene); each list is replaced
        by a copy sorted best-first. <all_query_bounds>/<all_germline_bounds> map
        gene -> (start, end). <first_match_query_bounds> are the query bounds of
        the first v match, from which k_d is measured.

        Returns early (without recording the query) if some region has no match, if
        overlapping matches can't be apportioned, or if the rearrangement looks
        unproductive and args.skip_unproductive is set.
        """
        if self.debug:
            print('%s' % query_name)

        best, match_names, n_matches = {}, {}, {}
        n_used = {'v': 0, 'd': 0, 'j': 0}
        k_v_min, k_d_min = 999, 999
        k_v_max, k_d_max = 0, 0
        for region in utils.regions:
            all_match_names[region] = sorted(all_match_names[region], reverse=True)
            match_names[region] = []
        codon_positions = {'v': -1, 'd': -1, 'j': -1}  # conserved codon positions (v:cysteine, d:dummy, j:tryptophan)
        for region in utils.regions:
            n_matches[region] = len(all_match_names[region])
            n_skipped = 0
            for score, gene in all_match_names[region]:
                glbounds = all_germline_bounds[gene]
                qrbounds = all_query_bounds[gene]
                assert qrbounds[1] <= len(query_seq)  # NOTE I'm putting these up above as well (in process_query), so in time I should remove them from here
                assert glbounds[1] <= len(self.germline_seqs[region][gene])
                assert qrbounds[0] >= 0
                assert glbounds[0] >= 0
                glmatchseq = self.germline_seqs[region][gene][glbounds[0]:glbounds[1]]
                qrmatchseq = query_seq[qrbounds[0]:qrbounds[1]]

                # TODO since I'm no longer skipping the genes after the first <args.n_max_per_region>, the OR of k-space below is overly conservative

                # only use a specified set of genes
                if self.args.only_genes is not None and gene not in self.args.only_genes:
                    n_skipped += 1
                    continue

                # add match to the list
                n_used[region] += 1
                match_names[region].append(gene)

                self.print_match(region, gene, query_seq, score, glbounds, qrbounds, -1, warnings, skipping=False)

                # if the germline match and the query match aren't the same length, s-w likely added an insert, which we shouldn't get since the gap-open penalty is jacked up so high
                if len(glmatchseq) != len(qrmatchseq):  # neurotic double check (um, I think) EDIT hey this totally saved my ass
                    print('ERROR %s not same length' % query_name)
                    print('%s %s %s' % (glmatchseq, glbounds[0], glbounds[1]))
                    print(qrmatchseq)
                    raise AssertionError('germline and query matches not same length for %s' % query_name)  # explicit, so python -O can't strip it

                if region == 'v':
                    # NOTE even if the v match doesn't start at the left hand edge of the query sequence, we still measure k_v from there.
                    # In other words, sw doesn't tell the hmm about it
                    this_k_v = qrbounds[1]
                    k_v_min = min(this_k_v, k_v_min)
                    k_v_max = max(this_k_v, k_v_max)
                if region == 'd' and first_match_query_bounds is not None:  # no v reads at all --> 'no v match' below
                    this_k_d = qrbounds[1] - first_match_query_bounds[1]  # end of d minus end of v
                    k_d_min = min(this_k_d, k_d_min)
                    k_d_max = max(this_k_d, k_d_max)

                # check consistency with best match (since the best match is excised in s-w code, and because ham is run with *one* k_v k_d set)
                if region not in best:
                    best[region] = gene
                    best[region + '_gl_seq'] = glmatchseq
                    best[region + '_qr_seq'] = qrmatchseq
                    best[region + '_score'] = score

            if self.debug and n_skipped > 0:
                print('%8s skipped %d %s genes' % ('', n_skipped, region))

        for region in utils.regions:
            if region not in best:
                print('      no %s match found for %s' % (region, query_name))  # NOTE if no d match found, we should really just assume entire d was eroded
                return

        # s-w allows d and j matches to overlap, so we need to apportion the disputed bases
        try:
            self.shift_overlapping_boundaries(all_query_bounds, all_germline_bounds, query_name, query_seq, best)
        except AssertionError:
            print('%s: apportionment failed' % query_name)
            return

        # check for unproductive rearrangements
        for region in utils.regions:
            codon_positions[region] = utils.get_conserved_codon_position(  # position in the query sequence, that is
                self.cyst_positions, self.tryp_positions, region, best[region],
                all_germline_bounds[best[region]], all_query_bounds[best[region]], assert_on_fail=False)
        codons_ok = utils.check_both_conserved_codons(query_seq, codon_positions['v'], codon_positions['j'],
                                                      debug=self.debug, extra_str='      ', assert_on_fail=False)
        cdr3_length = codon_positions['j'] - codon_positions['v'] + 3
        in_frame_cdr3 = (cdr3_length % 3 == 0)
        if self.debug and not in_frame_cdr3:
            print('      out of frame cdr3: %d %% 3 = %d' % (cdr3_length, cdr3_length % 3))
        no_stop_codon = utils.stop_codon_check(query_seq, codon_positions['v'], debug=self.debug)
        if not codons_ok or not in_frame_cdr3 or not no_stop_codon:
            if self.debug:
                print('       unproductive rearrangement in waterer codons_ok: %s   in_frame_cdr3: %s   no_stop_codon: %s' % (
                    codons_ok, in_frame_cdr3, no_stop_codon))
            if self.args.skip_unproductive:
                if self.debug:
                    print('            ...skipping')
                self.n_unproductive += 1
                self.info['skipped_unproductive_queries'].append(query_name)
                return

        # best k_v, k_d:
        k_v = all_query_bounds[best['v']][1]  # end of v match
        k_d = all_query_bounds[best['d']][1] - all_query_bounds[best['v']][1]  # end of d minus end of v

        if k_d_max < 5:  # since the s-w step matches to the longest possible j and then excises it, this sometimes gobbles up the d, resulting in a very short d alignment.
            if self.debug:
                print('  expanding k_d')
            k_d_max = max(8, k_d_max)

        if 'IGHJ4*' in best['j'] and self.germline_seqs['d'][best['d']][-5:] == 'ACTAC':  # the end of some d versions is the same as the start of some j versions, so the s-w frequently kicks out the 'wrong' alignment
            if self.debug:
                print('  doubly expanding k_d')
            if k_d_max - k_d_min < 8:
                k_d_min -= 5
                k_d_max += 2

        # lower bounds are clamped at 1 rather than 0: zero is not wanted (and the assert below rejects it)
        k_v_min = max(1, k_v_min - self.args.default_v_fuzz)
        k_v_max += self.args.default_v_fuzz
        k_d_min = max(1, k_d_min - self.args.default_d_fuzz)
        k_d_max += self.args.default_d_fuzz
        assert k_v_min > 0 and k_d_min > 0 and k_v_max > 0 and k_d_max > 0

        if self.debug:
            print('         k_v: %d [%d-%d)' % (k_v, k_v_min, k_v_max))
            print('         k_d: %d [%d-%d)' % (k_d, k_d_min, k_d_max))
            print('         used' + ''.join(['  %s: %d/%d' % (region, n_used[region], n_matches[region])
                                             for region in utils.regions]))

        kvals = {}
        kvals['v'] = {'best': k_v, 'min': k_v_min, 'max': k_v_max}
        kvals['d'] = {'best': k_d, 'min': k_d_min, 'max': k_d_max}
        self.add_to_info(query_name, query_seq, kvals, match_names, best,
                         all_germline_bounds, all_query_bounds,
                         codon_positions=codon_positions)
Exemplo n.º 6
0
class IhhhmmmParser(object):
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir, remove_N_nukes=True)
        self.perfplotter = PerformancePlotter(self.germline_seqs, self.args.plotdir, "ihhhmmm")

        self.details = OrderedDict()
        self.failtails = {}
        self.n_partially_failed = 0

        # get sequence info that was passed to ihhhmmm
        self.siminfo = OrderedDict()
        self.sim_need = []  # list of queries that we still need to find
        with opener("r")(self.args.simfname) as seqfile:
            reader = csv.DictReader(seqfile)
            iline = 0
            for line in reader:
                if self.args.queries != None and line["unique_id"] not in self.args.queries:
                    continue
                self.siminfo[line["unique_id"]] = line
                self.sim_need.append(line["unique_id"])
                iline += 1
                if args.n_queries > 0 and iline >= args.n_queries:
                    break

        fostream_names = glob.glob(self.args.indir + "/*.fostream")
        if len(fostream_names) == 0:
            raise Exception("no fostreams found in %s" % args.indir)
        fostream_names.sort()  # maybe already sorted?
        for infname in fostream_names:
            if len(self.sim_need) == 0:
                break

            # try to get whatever you can for the failures
            unique_ids = self.find_partial_failures(infname)  # returns list of unique ids in this file

            with opener("r")(infname) as infile:
                self.parse_file(infile, unique_ids)

        # now check that we got results for all the queries we wanted
        n_failed = 0
        for unique_id in self.siminfo:
            if unique_id not in self.details and unique_id not in self.failtails:
                print "%-20s  no info" % unique_id
                self.perfplotter.add_fail()
                n_failed += 1

        print ""
        print "partially failed: %d / %d = %.2f" % (
            self.n_partially_failed,
            len(self.siminfo),
            float(self.n_partially_failed) / len(self.siminfo),
        )
        print "failed:           %d / %d = %.2f" % (n_failed, len(self.siminfo), float(n_failed) / len(self.siminfo))
        print ""

        self.perfplotter.plot()

    # ----------------------------------------------------------------------------------------
    def parse_file(self, infile, unique_ids):
        fk = FileKeeper(infile.readlines())
        i_id = 0
        while not fk.eof and len(self.sim_need) > 0:
            self.parse_detail(fk, unique_ids[i_id])
            i_id += 1

    # ----------------------------------------------------------------------------------------
    def parse_detail(self, fk, unique_id):
        assert fk.iline < len(fk.lines)

        while fk.line[1] != "Details":
            fk.increment()
            if fk.eof:
                return

        fk.increment()
        info = {}
        info["unique_id"] = unique_id
        for begin_line, column, index, required, default in line_order:
            if fk.line[0].find(begin_line) != 0:
                if required:
                    print "oop", begin_line, fk.line
                    sys.exit()
                else:
                    info[column] = default
                    continue
            if column != "":
                info[column] = clean_value(column, fk.line[index])
                # if '[' in info[column]:
                #     print 'added', column, clean_value(column, fk.line[index])
                if column.find("_gene") == 1:
                    region = column[0]
                    info[region + "_5p_del"] = (
                        int(fk.line[fk.line.index("start:") + 1]) - 1
                    )  # NOTE their indices are 1-based
                    gl_length = int(fk.line[fk.line.index("gene:") + 1]) - 1
                    match_end = int(fk.line[fk.line.index("end:") + 1]) - 1
                    assert gl_length >= match_end
                    info[region + "_3p_del"] = gl_length - match_end

            fk.increment()

        if unique_id not in self.sim_need:
            while not fk.eof and fk.line[1] != "Details":  # skip stuff until start of next Detail block
                fk.increment()
            return

        info["fv_insertion"] = ""
        info["jf_insertion"] = ""
        info["seq"] = (
            info["v_qr_seq"] + info["vd_insertion"] + info["d_qr_seq"] + info["dj_insertion"] + info["j_qr_seq"]
        )

        if "-" in info["seq"]:
            print "ERROR found a dash in %s, returning failure" % unique_id
            while not fk.eof and fk.line[1] != "Details":  # skip stuff until start of next Detail block
                fk.increment()
            return

        if (
            info["seq"] not in self.siminfo[unique_id]["seq"]
        ):  # arg. I can't do != because it tacks on v left and j right deletions
            print "ERROR didn't find the right sequence for %s" % unique_id
            print "  ", info["seq"]
            print "  ", self.siminfo[unique_id]["seq"]
            sys.exit()

        if self.args.debug:
            print unique_id
            for region in utils.regions:
                infer_gene = info[region + "_gene"]
                true_gene = self.siminfo[unique_id][region + "_gene"]
                if utils.are_alleles(infer_gene, true_gene):
                    regionstr = utils.color("bold", utils.color("blue", region))
                    truestr = ""  #'(originally %s)' % match_name
                else:
                    regionstr = utils.color("bold", utils.color("red", region))
                    truestr = "(true: %s)" % utils.color_gene(true_gene).replace(region, "")
                print "  %s %s %s" % (regionstr, utils.color_gene(infer_gene).replace(region, ""), truestr)

            utils.print_reco_event(self.germline_seqs, self.siminfo[unique_id], label="true:", extra_str="    ")
            utils.print_reco_event(self.germline_seqs, info, label="inferred:", extra_str="    ")

        for region in utils.regions:
            if info[region + "_gene"] not in self.germline_seqs[region]:
                print "ERROR %s not in germlines" % info[region + "_gene"]
                assert False

            gl_seq = info[region + "_gl_seq"]
            if "[" in gl_seq:  # ambiguous
                for nuke in utils.nukes:
                    gl_seq = gl_seq.replace("[", nuke)
                    if gl_seq in self.germline_seqs[region][info[region + "_gene"]]:
                        print "  replaced [ with %s" % nuke
                        break
                info[region + "_gl_seq"] = gl_seq

            if info[region + "_gl_seq"] not in self.germline_seqs[region][info[region + "_gene"]]:
                print "ERROR gl match not found for %s in %s" % (info[region + "_gene"], unique_id)
                print "  ", info[region + "_gl_seq"]
                print "  ", self.germline_seqs[region][info[region + "_gene"]]
                self.perfplotter.add_partial_fail(self.siminfo[unique_id], info)
                while not fk.eof and fk.line[1] != "Details":  # skip stuff until start of next Detail block
                    fk.increment()
                return

        self.perfplotter.evaluate(self.siminfo[unique_id], info)
        self.details[unique_id] = info
        self.sim_need.remove(unique_id)

        while not fk.eof and fk.line[1] != "Details":  # skip stuff until start of next Detail block
            fk.increment()

    # ----------------------------------------------------------------------------------------
    def find_partial_failures(self, fostream_name):
        unique_ids = []
        for line in open(fostream_name.replace(".fostream", "")).readlines():
            if len(self.sim_need) == 0:
                return
            if len(line.strip()) == 0:  # skip blank lines
                continue

            line = line.replace('"', "")
            line = line.split(";")

            unique_id = line[0]

            if "NA" not in line:  # skip lines that were ok
                unique_ids.append(unique_id)
                continue
            if unique_id not in self.sim_need:
                continue
            if unique_id not in self.siminfo:
                continue  # not looking for this <unique_id> a.t.m.

            info = {}
            info["unique_id"] = unique_id
            for stuff in line:
                for region in utils.regions:  # add the first instance of IGH[VDJ] (if it's there at all)
                    if "IGH" + region.upper() in stuff and region + "_gene" not in info:
                        genes = re.findall("IGH" + region.upper() + "[^ ][^ ]*", stuff)
                        if len(genes) == 0:
                            print "ERROR no %s genes in %s" % (region, stuff)
                        gene = genes[0]
                        if gene not in self.germline_seqs[region]:
                            print "ERROR bad gene %s for %s" % (gene, unique_id)
                            sys.exit()
                        info[region + "_gene"] = gene
            self.perfplotter.add_partial_fail(self.siminfo[unique_id], info)
            if self.args.debug:
                print "%-20s  partial fail %s %s %s" % (
                    unique_id,
                    utils.color_gene(info["v_gene"]) if "v_gene" in info else "",
                    utils.color_gene(info["d_gene"]) if "d_gene" in info else "",
                    utils.color_gene(info["j_gene"]) if "j_gene" in info else "",
                ),
                print "  (true %s %s %s)" % tuple(
                    [self.siminfo[unique_id][region + "_gene"] for region in utils.regions]
                )
            self.failtails[unique_id] = info
            self.n_partially_failed += 1
            self.sim_need.remove(unique_id)

        return unique_ids
Exemplo n.º 7
0
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir)

        perfplotter = PerformancePlotter(self.germline_seqs, self.args.plotdir,
                                         'imgt')

        # get sequence info that was passed to imgt
        self.seqinfo = {}
        with opener('r')(self.args.simfname) as simfile:
            reader = csv.DictReader(simfile)
            iline = 0
            for line in reader:
                if self.args.queries != None and line[
                        'unique_id'] not in self.args.queries:
                    continue
                if len(re.findall('_[FP]', line['j_gene'])) > 0:
                    line['j_gene'] = line['j_gene'].replace(
                        re.findall('_[FP]', line['j_gene'])[0], '')
                self.seqinfo[line['unique_id']] = line
                iline += 1
                if self.args.n_queries > 0 and iline >= self.args.n_queries:
                    break

        paragraphs, csv_info = None, None
        if self.args.infname != None and '.html' in self.args.infname:
            print 'reading', self.args.infname
            with opener('r')(self.args.infname) as infile:
                soup = BeautifulSoup(infile)
                paragraphs = soup.find_all('pre')

        summarydir = self.args.indir[:self.args.indir.rfind(
            '/'
        )]  # one directoy up from <indir>, which has the detailed per-sequence files
        summary_fname = glob.glob(summarydir + '/1_Summary_*.txt')
        assert len(summary_fname) == 1
        summary_fname = summary_fname[0]
        get_genes_to_skip(summary_fname, self.germline_seqs)

        n_failed, n_skipped, n_total, n_not_found, n_found = 0, 0, 0, 0, 0
        for unique_id in self.seqinfo:
            if self.args.debug:
                print unique_id,
            imgtinfo = []
            # print 'true'
            # utils.print_reco_event(self.germline_seqs, self.seqinfo[unique_id])
            if self.args.infname != None and '.html' in self.args.infname:
                for pre in paragraphs:  # NOTE this loops over everything an awful lot of times. Shouldn't really matter for now, though
                    if unique_id in pre.text:
                        imgtinfo.append(pre.text)
            else:
                n_total += 1
                assert self.args.infname == None
                infnames = glob.glob(self.args.indir + '/' + unique_id + '*')
                assert len(infnames) <= 1
                if len(infnames) != 1:
                    if self.args.debug:
                        print ' couldn\'t find it'
                    n_not_found += 1
                    continue
                n_found += 1
                with opener('r')(infnames[0]) as infile:
                    full_text = infile.read()
                    if len(
                            re.findall('[123]. Alignment for [VDJ]-GENE',
                                       full_text)) < 3:
                        failregions = re.findall(
                            'No [VDJ]-GENE has been identified', full_text)
                        if self.args.debug and len(failregions) > 0:
                            print '    ', failregions
                        n_failed += 1
                        continue

                    # loop over the paragraphs I want
                    position = full_text.find(unique_id)  # don't need this one
                    for ir in range(4):
                        position = full_text.find(unique_id, position + 1)
                        pgraph = full_text[position:full_text.
                                           find('\n\n', position + 1)]
                        if 'insertion(s) and/or deletion(s) which are not dealt in this release' in pgraph:
                            ir -= 1
                            continue
                        imgtinfo.append(pgraph)  # query seq paragraph

            if len(imgtinfo) == 0:
                print '%s no info' % unique_id
                continue
            else:
                if self.args.debug:
                    print ''
            line = self.parse_query_text(unique_id, imgtinfo)
            if 'skip_gene' in line:
                # assert self.args.skip_missing_genes
                n_skipped += 1
                continue
            try:
                assert 'failed' not in line
                joinparser.add_insertions(line, debug=self.args.debug)
                joinparser.resolve_overlapping_matches(
                    line, debug=False, germlines=self.germline_seqs)
            except (AssertionError, KeyError):
                print '    giving up'
                n_failed += 1
                perfplotter.add_partial_fail(self.seqinfo[unique_id], line)
                # print '    perfplotter: not sure what to do with a fail'
                continue
            perfplotter.evaluate(self.seqinfo[unique_id], line)
            if self.args.debug:
                utils.print_reco_event(self.germline_seqs,
                                       self.seqinfo[unique_id],
                                       label='true:')
                utils.print_reco_event(self.germline_seqs,
                                       line,
                                       label='inferred:')

        perfplotter.plot()
        print 'failed: %d / %d = %f' % (n_failed, n_total,
                                        float(n_failed) / n_total)
        print 'skipped: %d / %d = %f' % (n_skipped, n_total,
                                         float(n_skipped) / n_total)
        print '    ',
        for g, n in genes_actually_skipped.items():
            print '  %d %s' % (n, utils.color_gene(g))
        print ''
        if n_not_found > 0:
            print '  not found: %d / %d = %f' % (n_not_found, n_not_found +
                                                 n_found, n_not_found /
                                                 float(n_not_found + n_found))
    def read_hmm_output(self,
                        algorithm,
                        hmm_csv_outfname,
                        make_clusters=True,
                        count_parameters=False,
                        parameter_out_dir=None,
                        plotdir=None):
        print '    read output'
        if count_parameters:
            assert parameter_out_dir is not None
            assert plotdir is not None
        pcounter = ParameterCounter(
            self.germline_seqs) if count_parameters else None
        true_pcounter = ParameterCounter(self.germline_seqs) if (
            count_parameters and not self.args.is_data) else None
        perfplotter = PerformancePlotter(
            self.germline_seqs, plotdir +
            '/hmm/performance', 'hmm') if self.args.plot_performance else None

        n_processed = 0
        hmminfo = []
        with opener('r')(hmm_csv_outfname) as hmm_csv_outfile:
            reader = csv.DictReader(hmm_csv_outfile)
            last_key = None
            boundary_error_queries = []
            for line in reader:
                utils.intify(line, splitargs=('unique_ids', 'seqs'))
                ids = line['unique_ids']
                this_key = utils.get_key(ids)
                same_event = from_same_event(self.args.is_data, True,
                                             self.reco_info, ids)
                id_str = ''.join(['%20s ' % i for i in ids])

                # check for errors
                if last_key != this_key:  # if this is the first line for this set of ids (i.e. the best viterbi path or only forward score)
                    if line['errors'] != None and 'boundary' in line[
                            'errors'].split(':'):
                        boundary_error_queries.append(':'.join(
                            [str(uid) for uid in ids]))
                    else:
                        assert len(line['errors']) == 0

                if algorithm == 'viterbi':
                    line['seq'] = line['seqs'][
                        0]  # add info for the best match as 'seq'
                    line['unique_id'] = ids[0]
                    utils.add_match_info(self.germline_seqs,
                                         line,
                                         self.cyst_positions,
                                         self.tryp_positions,
                                         debug=(self.args.debug > 0))

                    if last_key != this_key or self.args.plot_all_best_events:  # if this is the first line (i.e. the best viterbi path) for this query (or query pair), print the true event
                        n_processed += 1
                        if self.args.debug:
                            print '%s   %d' % (id_str, same_event)
                        if line['cdr3_length'] != -1 or not self.args.skip_unproductive:  # if it's productive, or if we're not skipping unproductive rearrangements
                            hmminfo.append(
                                dict([
                                    ('unique_id', line['unique_ids'][0]),
                                ] + line.items()))
                            if pcounter is not None:  # increment counters (but only for the best [first] match)
                                pcounter.increment(line)
                            if true_pcounter is not None:  # increment true counters
                                true_pcounter.increment(self.reco_info[ids[0]])
                            if perfplotter is not None:
                                perfplotter.evaluate(self.reco_info[ids[0]],
                                                     line)

                    if self.args.debug:
                        self.print_hmm_output(
                            line,
                            print_true=(last_key != this_key),
                            perfplotter=perfplotter)
                    line['seq'] = None
                    line['unique_id'] = None

                else:  # for forward, write the pair scores to file to be read by the clusterer
                    if not make_clusters:  # self.args.debug or
                        print '%3d %10.3f    %s' % (
                            same_event, float(line['score']), id_str)
                    if line['score'] == '-nan':
                        print '    WARNING encountered -nan, setting to -999999.0'
                        score = -999999.0
                    else:
                        score = float(line['score'])
                    if len(ids) == 2:
                        hmminfo.append({
                            'id_a': line['unique_ids'][0],
                            'id_b': line['unique_ids'][1],
                            'score': score
                        })
                    n_processed += 1

                last_key = utils.get_key(ids)

        if pcounter is not None:
            pcounter.write(parameter_out_dir)
            if not self.args.no_plot:
                pcounter.plot(plotdir,
                              subset_by_gene=True,
                              cyst_positions=self.cyst_positions,
                              tryp_positions=self.tryp_positions)
        if true_pcounter is not None:
            true_pcounter.write(parameter_out_dir + '/true')
            if not self.args.no_plot:
                true_pcounter.plot(plotdir + '/true',
                                   subset_by_gene=True,
                                   cyst_positions=self.cyst_positions,
                                   tryp_positions=self.tryp_positions)
        if perfplotter is not None:
            perfplotter.plot()

        print '  processed %d queries' % n_processed
        if len(boundary_error_queries) > 0:
            print '    %d boundary errors (%s)' % (
                len(boundary_error_queries), ', '.join(boundary_error_queries))

        return hmminfo
Exemplo n.º 9
0
class IgblastParser(object):
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir,
                                                  remove_N_nukes=True)

        self.perfplotter = PerformancePlotter(self.germline_seqs,
                                              self.args.plotdir, 'igblast')
        self.n_total, self.n_partially_failed = 0, 0

        # get sequence info that was passed to igblast
        self.seqinfo = {}
        with opener('r')(self.args.simfname) as simfile:
            reader = csv.DictReader(simfile)
            iline = 0
            for line in reader:
                if self.args.n_max_queries > 0 and iline >= self.args.n_max_queries:
                    break
                iline += 1
                if self.args.queries != None and int(
                        line['unique_id']) not in self.args.queries:
                    continue
                if len(re.findall('_[FP]', line['j_gene'])) > 0:
                    line['j_gene'] = line['j_gene'].replace(
                        re.findall('_[FP]', line['j_gene'])[0], '')
                self.seqinfo[int(line['unique_id'])] = line

        paragraphs = None
        print 'reading', self.args.infname
        info = {}
        with opener('r')(self.args.infname) as infile:
            line = infile.readline()
            # first find the start of the next query's section
            while line.find('<b>Query=') != 0:
                line = infile.readline()
            # then keep going till eof
            iquery = 0
            while line != '':
                if self.args.n_max_queries > 0 and iquery >= self.args.n_max_queries:
                    break
                # first find the query name
                query_name = int(line.split()[1])
                # and collect the lines for this query
                query_lines = []
                line = infile.readline()
                while line.find('<b>Query=') != 0:
                    query_lines.append(line.strip())
                    line = infile.readline()
                    if line == '':
                        break
                iquery += 1
                # then see if we want this query
                if self.args.queries != None and query_name not in self.args.queries:
                    continue
                if query_name not in self.seqinfo:
                    print 'ERROR %d not in reco info' % query_name
                    sys.exit()
                if self.args.debug:
                    print query_name
                # and finally add the query to <info[query_name]>
                info[query_name] = {'unique_id': query_name}
                self.n_total += 1
                self.process_query(info[query_name], query_name, query_lines)

        self.perfplotter.plot()
        print 'partially failed: %d / %d = %f' % (
            self.n_partially_failed, self.n_total,
            float(self.n_partially_failed) / self.n_total)

    # ----------------------------------------------------------------------------------------
    def process_query(self, qr_info, query_name, query_lines):
        # split query_lines up into blocks
        blocks = []
        for line in query_lines:
            if line.find('Query_') == 0:
                blocks.append([])
            if len(line) == 0:
                continue
            if len(re.findall('<a name=#_[0-9][0-9]*_IGH',
                              line)) == 0 and line.find('Query_') != 0:
                continue
            if len(blocks) == 0:
                print 'wtf? %s' % query_name  # it's probably kicking a reverse match
                self.perfplotter.add_partial_fail(
                    self.seqinfo[query_name],
                    qr_info)  # NOTE that's really a total failure
                self.n_partially_failed += 1
                return
            blocks[-1].append(line)

        # then process each block
        for block in blocks:
            self.process_single_block(block, query_name, qr_info)
            if 'fail' in qr_info:
                self.perfplotter.add_partial_fail(self.seqinfo[query_name],
                                                  qr_info)
                self.n_partially_failed += 1
                return

        for region in utils.regions:
            if region + '_gene' not in qr_info:
                print '  ERROR no %s match for %d' % (region, query_name)
                self.perfplotter.add_partial_fail(self.seqinfo[query_name],
                                                  qr_info)
                self.n_partially_failed += 1
                return

        # expand v match to left end and j match to right end
        qr_info['v_5p_del'] = 0
        qr_info['fv_insertion'] = ''
        if qr_info['match_start'] > 0:
            if self.args.debug:
                print '    add to v left:', self.seqinfo[query_name][
                    'seq'][:qr_info['match_start']]
            qr_info['seq'] = self.seqinfo[query_name][
                'seq'][:qr_info['match_start']] + qr_info['seq']

        qr_info['j_3p_del'] = 0
        qr_info['jf_insertion'] = ''
        if len(self.seqinfo[query_name]['seq']) > qr_info['match_end']:
            if self.args.debug:
                print '    add to j right:', self.seqinfo[query_name][
                    'seq'][qr_info['match_end'] -
                           len(self.seqinfo[query_name]['seq']):]
            qr_info['seq'] = qr_info['seq'] + self.seqinfo[query_name]['seq'][
                qr_info['match_end'] - len(self.seqinfo[query_name]['seq']):]

        for boundary in utils.boundaries:
            start = qr_info[boundary[0] + '_qr_bounds'][1]
            end = qr_info[boundary[1] + '_qr_bounds'][0]
            qr_info[boundary + '_insertion'] = qr_info['seq'][start:end]

        for region in utils.regions:
            start = qr_info[region + '_qr_bounds'][0]
            end = qr_info[region + '_qr_bounds'][1]
            qr_info[region + '_qr_seq'] = qr_info['seq'][start:end]

        try:
            resolve_overlapping_matches(qr_info, self.args.debug,
                                        self.germline_seqs)
        except AssertionError:
            print 'ERROR apportionment failed on %s' % query_name
            self.perfplotter.add_partial_fail(self.seqinfo[query_name],
                                              qr_info)
            self.n_partially_failed += 1
            return

        if self.args.debug:
            print '  query seq:', qr_info['seq']
            for region in utils.regions:
                print '    %s %3d %3d %s %s' % (
                    region, qr_info[region + '_qr_bounds'][0],
                    qr_info[region + '_qr_bounds'][1],
                    utils.color_gene(qr_info[region + '_gene']),
                    qr_info[region + '_gl_seq'])
        for boundary in utils.boundaries:
            start = qr_info[boundary[0] + '_qr_bounds'][1]
            end = qr_info[boundary[1] + '_qr_bounds'][0]
            qr_info[boundary + '_insertion'] = qr_info['seq'][start:end]
            if self.args.debug:
                print '   ', boundary, qr_info[boundary + '_insertion']

        self.perfplotter.evaluate(self.seqinfo[query_name], qr_info)
        # for key, val in qr_info.items():
        #     print key, val
        if self.args.debug:
            utils.print_reco_event(self.germline_seqs,
                                   self.seqinfo[query_name],
                                   label='true:',
                                   extra_str='  ')
            utils.print_reco_event(self.germline_seqs, qr_info, extra_str=' ')

    # ----------------------------------------------------------------------------------------
    def process_single_block(self, block, query_name, qr_info):
        assert block[0].find('Query_') == 0
        vals = block[0].split()
        qr_start = int(
            vals[1]) - 1  # converting from one-indexed to zero-indexed
        qr_seq = vals[2]
        qr_end = int(
            vals[3]
        )  # ...and from inclusive of both bounds to normal programming conventions
        if qr_seq not in self.seqinfo[query_name]['seq']:
            if '-' in qr_seq:
                print '  WARNING insertion inside query seq for %s, treating as partial failure' % query_name
                qr_info['fail'] = True
                return
            else:
                print '  ERROR query seq from igblast info not found in original query seq for %d' % query_name
                print '    %s' % qr_seq
                print '    %s' % self.seqinfo[query_name]['seq']
                sys.exit()

        if 'seq' in qr_info:
            qr_info['seq'] += qr_seq
        else:
            qr_info['seq'] = qr_seq

        # keep track of the absolute first and absolute last bases matched so we can later work out the fv and jf insertions
        if 'match_start' not in qr_info or qr_start < qr_info['match_start']:
            qr_info['match_start'] = qr_start
        if 'match_end' not in qr_info or qr_end > qr_info['match_end']:
            qr_info['match_end'] = qr_end

        if self.args.debug:
            print '      query: %3d %3d %s' % (qr_start, qr_end, qr_seq)
        for line in block[1:]:
            gene = line[line.rfind('IGH'):line.rfind('</a>')]
            region = utils.get_region(gene)
            if gene not in self.germline_seqs[region]:
                print '  ERROR %s not found in germlines' % gene
                qr_info['fail'] = True
                return

            vals = line.split()
            gl_start = int(
                vals[-3]) - 1  # converting from one-indexed to zero-indexed
            gl_seq = vals[-2]
            gl_end = int(
                vals[-1]
            )  # ...and from inclusive of both bounds to normal programming conventions

            if region + '_gene' in qr_info:
                if qr_info[region + '_gene'] == gene:
                    if self.args.debug:
                        print '        %s match: %s' % (
                            region, clean_alignment_crap(qr_seq, gl_seq))
                    qr_info[region + '_gl_seq'] = qr_info[
                        region + '_gl_seq'] + clean_alignment_crap(
                            qr_seq, gl_seq)
                    assert gl_end <= len(self.germline_seqs[region][gene])
                    qr_info[region + '_3p_del'] = len(
                        self.germline_seqs[region][gene]) - gl_end
                    qr_info[region +
                            '_qr_bounds'] = (qr_info[region + '_qr_bounds'][0],
                                             find_qr_bounds(
                                                 qr_start, qr_end, gl_seq)[1])
                else:
                    continue
            else:
                qr_info[region + '_gene'] = gene
                qr_info[region + '_gl_seq'] = clean_alignment_crap(
                    qr_seq, gl_seq)
                # deletions
                qr_info[region + '_5p_del'] = gl_start
                assert gl_end <= len(self.germline_seqs[region][gene])
                qr_info[region + '_3p_del'] = len(
                    self.germline_seqs[region][gene]) - gl_end
                # bounds
                qr_info[region + '_qr_bounds'] = find_qr_bounds(
                    qr_start, qr_end, gl_seq)
                if self.args.debug:
                    print '        %s match: %s' % (
                        region, clean_alignment_crap(qr_seq, gl_seq))
Exemplo n.º 10
0
    def read_hmm_output(self, algorithm, hmm_csv_outfname, make_clusters=True, count_parameters=False, parameter_out_dir=None, plotdir=None):
        print '    read output'
        if count_parameters:
            assert parameter_out_dir is not None
            assert plotdir is not None
        pcounter = ParameterCounter(self.germline_seqs) if count_parameters else None
        true_pcounter = ParameterCounter(self.germline_seqs) if (count_parameters and not self.args.is_data) else None
        perfplotter = PerformancePlotter(self.germline_seqs, plotdir + '/hmm/performance', 'hmm') if self.args.plot_performance else None

        n_processed = 0
        hmminfo = []
        with opener('r')(hmm_csv_outfname) as hmm_csv_outfile:
            reader = csv.DictReader(hmm_csv_outfile)
            last_key = None
            boundary_error_queries = []
            for line in reader:
                utils.intify(line, splitargs=('unique_ids', 'seqs'))
                ids = line['unique_ids']
                this_key = utils.get_key(ids)
                same_event = from_same_event(self.args.is_data, True, self.reco_info, ids)
                id_str = ''.join(['%20s ' % i for i in ids])

                # check for errors
                if last_key != this_key:  # if this is the first line for this set of ids (i.e. the best viterbi path or only forward score)
                    if line['errors'] != None and 'boundary' in line['errors'].split(':'):
                        boundary_error_queries.append(':'.join([str(uid) for uid in ids]))
                    else:
                        assert len(line['errors']) == 0

                if algorithm == 'viterbi':
                    line['seq'] = line['seqs'][0]  # add info for the best match as 'seq'
                    line['unique_id'] = ids[0]
                    utils.add_match_info(self.germline_seqs, line, self.cyst_positions, self.tryp_positions, debug=(self.args.debug > 0))

                    if last_key != this_key or self.args.plot_all_best_events:  # if this is the first line (i.e. the best viterbi path) for this query (or query pair), print the true event
                        n_processed += 1
                        if self.args.debug:
                            print '%s   %d' % (id_str, same_event)
                        if line['cdr3_length'] != -1 or not self.args.skip_unproductive:  # if it's productive, or if we're not skipping unproductive rearrangements
                            hmminfo.append(dict([('unique_id', line['unique_ids'][0]), ] + line.items()))
                            if pcounter is not None:  # increment counters (but only for the best [first] match)
                                pcounter.increment(line)
                            if true_pcounter is not None:  # increment true counters
                                true_pcounter.increment(self.reco_info[ids[0]])
                            if perfplotter is not None:
                                perfplotter.evaluate(self.reco_info[ids[0]], line)

                    if self.args.debug:
                        self.print_hmm_output(line, print_true=(last_key != this_key), perfplotter=perfplotter)
                    line['seq'] = None
                    line['unique_id'] = None

                else:  # for forward, write the pair scores to file to be read by the clusterer
                    if not make_clusters:  # self.args.debug or 
                        print '%3d %10.3f    %s' % (same_event, float(line['score']), id_str)
                    if line['score'] == '-nan':
                        print '    WARNING encountered -nan, setting to -999999.0'
                        score = -999999.0
                    else:
                        score = float(line['score'])
                    if len(ids) == 2:
                        hmminfo.append({'id_a':line['unique_ids'][0], 'id_b':line['unique_ids'][1], 'score':score})
                    n_processed += 1

                last_key = utils.get_key(ids)

        if pcounter is not None:
            pcounter.write(parameter_out_dir)
            if not self.args.no_plot:
                pcounter.plot(plotdir, subset_by_gene=True, cyst_positions=self.cyst_positions, tryp_positions=self.tryp_positions)
        if true_pcounter is not None:
            true_pcounter.write(parameter_out_dir + '/true')
            if not self.args.no_plot:
                true_pcounter.plot(plotdir + '/true', subset_by_gene=True, cyst_positions=self.cyst_positions, tryp_positions=self.tryp_positions)
        if perfplotter is not None:
            perfplotter.plot()

        print '  processed %d queries' % n_processed
        if len(boundary_error_queries) > 0:
            print '    %d boundary errors (%s)' % (len(boundary_error_queries), ', '.join(boundary_error_queries))

        return hmminfo
Exemplo n.º 11
0
class IgblastParser(object):
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir, remove_N_nukes=True)

        self.perfplotter = PerformancePlotter(self.germline_seqs, self.args.plotdir, 'igblast')
        self.n_total, self.n_partially_failed, self.n_skipped = 0, 0, 0

        # get sequence info that was passed to igblast
        self.seqinfo = {}
        with opener('r')(self.args.simfname) as simfile:
            reader = csv.DictReader(simfile)
            iline = 0
            for line in reader:
                if self.args.n_queries > 0 and iline >= self.args.n_queries:
                    break
                iline += 1
                if self.args.queries != None and int(line['unique_id']) not in self.args.queries:
                    continue
                if len(re.findall('_[FP]', line['j_gene'])) > 0:
                    line['j_gene'] = line['j_gene'].replace(re.findall('_[FP]', line['j_gene'])[0], '')
                self.seqinfo[int(line['unique_id'])] = line

        print 'reading', self.args.infname

        get_genes_to_skip(self.args.infname, self.germline_seqs, method='igblast', debug=False)

        paragraphs = None
        info = {}
        with opener('r')(self.args.infname) as infile:
            line = infile.readline()
            # first find the start of the next query's section
            while line.find('<b>Query=') != 0:
                line = infile.readline()
            # then keep going till eof
            iquery = 0
            while line != '':
                if self.args.n_queries > 0 and iquery >= self.args.n_queries:
                    break
                # first find the query name
                query_name = int(line.split()[1])
                # and collect the lines for this query
                query_lines = []
                line = infile.readline()
                while line.find('<b>Query=') != 0:
                    query_lines.append(line.strip())
                    line = infile.readline()
                    if line == '':
                        break
                iquery += 1
                # then see if we want this query
                if self.args.queries != None and query_name not in self.args.queries:
                    continue
                if query_name not in self.seqinfo:
                    print 'ERROR %d not in reco info' % query_name
                    sys.exit()
                if self.args.debug:
                    print query_name
                # and finally add the query to <info[query_name]>
                info[query_name] = {'unique_id':query_name}
                self.n_total += 1
                self.process_query(info[query_name], query_name, query_lines)

        self.perfplotter.plot()
        print 'partially failed: %d / %d = %f' % (self.n_partially_failed, self.n_total, float(self.n_partially_failed) / self.n_total)
        print 'skipped: %d / %d = %f' % (self.n_skipped, self.n_total, float(self.n_skipped) / self.n_total)
        for g, n in genes_actually_skipped.items():
            print '  %d %s' % (n, utils.color_gene(g))

    # ----------------------------------------------------------------------------------------
    def process_query(self, qr_info, query_name, query_lines):
        # split query_lines up into blocks
        blocks = []
        for line in query_lines:
            if line.find('Query_') == 0:
                blocks.append([])
            if len(line) == 0:
                continue
            if len(re.findall('<a name=#_[0-9][0-9]*_IGH', line)) == 0 and line.find('Query_') != 0:
                continue
            if len(blocks) == 0:
                print 'wtf? %s' % query_name  # it's probably kicking a reverse match
                self.perfplotter.add_partial_fail(self.seqinfo[query_name], qr_info)  # NOTE that's really a total failure
                self.n_partially_failed += 1
                return
            blocks[-1].append(line)

        # then process each block
        for block in blocks:
            self.process_single_block(block, query_name, qr_info)
            if 'skip_gene' in qr_info:
                self.n_skipped += 1
                return
            if 'fail' in qr_info:
                self.perfplotter.add_partial_fail(self.seqinfo[query_name], qr_info)
                self.n_partially_failed += 1
                return

        for region in utils.regions:
            if region + '_gene' not in qr_info:
                print '    %d: no %s match' % (query_name, region)
                self.perfplotter.add_partial_fail(self.seqinfo[query_name], qr_info)
                self.n_partially_failed += 1
                return

        # expand v match to left end and j match to right end
        qr_info['v_5p_del'] = 0
        qr_info['fv_insertion'] = ''
        if qr_info['match_start'] > 0:
            if self.args.debug:
                print '    add to v left:', self.seqinfo[query_name]['seq'][ : qr_info['match_start']]
            qr_info['seq'] = self.seqinfo[query_name]['seq'][ : qr_info['match_start']] + qr_info['seq']

        qr_info['j_3p_del'] = 0
        qr_info['jf_insertion'] = ''
        if len(self.seqinfo[query_name]['seq']) > qr_info['match_end']:
            if self.args.debug:
                print '    add to j right:', self.seqinfo[query_name]['seq'][ qr_info['match_end'] - len(self.seqinfo[query_name]['seq']) : ]
            qr_info['seq'] = qr_info['seq'] + self.seqinfo[query_name]['seq'][ qr_info['match_end'] - len(self.seqinfo[query_name]['seq']) : ]

        for boundary in utils.boundaries:
            start = qr_info[boundary[0] + '_qr_bounds'][1]
            end = qr_info[boundary[1] + '_qr_bounds'][0]
            qr_info[boundary + '_insertion'] = qr_info['seq'][start : end]

        for region in utils.regions:
            start = qr_info[region + '_qr_bounds'][0]
            end = qr_info[region + '_qr_bounds'][1]
            qr_info[region + '_qr_seq'] = qr_info['seq'][start : end]

        try:
            resolve_overlapping_matches(qr_info, self.args.debug, self.germline_seqs)
        except AssertionError:
            print '    %s: apportionment failed' % query_name
            self.perfplotter.add_partial_fail(self.seqinfo[query_name], qr_info)
            self.n_partially_failed += 1
            return

        if self.args.debug:
            print '  query seq:', qr_info['seq']
            for region in utils.regions:
                true_gene = self.seqinfo[query_name][region + '_gene']
                infer_gene = qr_info[region + '_gene']
                if utils.are_alleles(infer_gene, true_gene):
                    regionstr = utils.color('bold', utils.color('blue', region))
                    truestr = ''  #'(originally %s)' % match_name
                else:
                    regionstr = utils.color('bold', utils.color('red', region))
                    truestr = '(true: %s)' % utils.color_gene(true_gene).replace(region, '')
                # print '  %s %s %s' % (regionstr, utils.color_gene(infer_gene).replace(region, ''), truestr)

                print '    %s %3d %3d %s %s %s' % (regionstr, qr_info[region + '_qr_bounds'][0], qr_info[region + '_qr_bounds'][1], utils.color_gene(infer_gene).replace(region, ''), truestr, qr_info[region + '_gl_seq'])
        for boundary in utils.boundaries:
            start = qr_info[boundary[0] + '_qr_bounds'][1]
            end = qr_info[boundary[1] + '_qr_bounds'][0]
            qr_info[boundary + '_insertion'] = qr_info['seq'][start : end]
            if self.args.debug:
                print '   ', boundary, qr_info[boundary + '_insertion']

        self.perfplotter.evaluate(self.seqinfo[query_name], qr_info)
        # for key, val in qr_info.items():
        #     print key, val
        if self.args.debug:
            utils.print_reco_event(self.germline_seqs, self.seqinfo[query_name], label='true:', extra_str='  ')
            utils.print_reco_event(self.germline_seqs, qr_info, extra_str=' ')
            
    # ----------------------------------------------------------------------------------------
    def process_single_block(self, block, query_name, qr_info):
        assert block[0].find('Query_') == 0
        vals = block[0].split()
        qr_start = int(vals[1]) - 1  # converting from one-indexed to zero-indexed
        qr_seq = vals[2]
        qr_end = int(vals[3])  # ...and from inclusive of both bounds to normal programming conventions
        if qr_seq not in self.seqinfo[query_name]['seq']:
            if '-' in qr_seq:
                print '    %s: insertion inside query seq, treating as partial failure' % query_name
                qr_info['fail'] = True
                return
            else:
                print '  ERROR query seq from igblast info not found in original query seq for %d' % query_name
                print '    %s' % qr_seq
                print '    %s' % self.seqinfo[query_name]['seq']
                sys.exit()

        if 'seq' in qr_info:
            qr_info['seq'] += qr_seq
        else:
            qr_info['seq'] = qr_seq


        # keep track of the absolute first and absolute last bases matched so we can later work out the fv and jf insertions
        if 'match_start' not in qr_info or qr_start < qr_info['match_start']:
            qr_info['match_start'] = qr_start
        if 'match_end' not in qr_info or qr_end > qr_info['match_end']:
            qr_info['match_end'] = qr_end

        # ----------------------------------------------------------------------------------------
        # skipping bullshit
        def skip_gene(gene):
            if self.args.debug:
                print '    %s in list of genes to skip' % utils.color_gene(gene)
            if gene not in genes_actually_skipped:
                genes_actually_skipped[gene] = 0
            genes_actually_skipped[gene] += 1
            qr_info['skip_gene'] = True

        if self.args.debug:
            print '      query: %3d %3d %s' % (qr_start, qr_end, qr_seq)
        for line in block[1:]:
            gene = line[line.rfind('IGH') : line.rfind('</a>')]
            region = utils.get_region(gene)
            true_gene = self.seqinfo[query_name][region + '_gene']

            for gset in equivalent_genes:
                if gene in gset and true_gene in gset and gene != true_gene:  # if the true gene and the inferred gene are in the same equivalence set, treat it as correct, i.e. just pretend it inferred the right name
                    if self.args.debug:
                        print '   %s: replacing name %s with true name %s' % (query_name, gene, true_gene)
                    gene = true_gene

            if gene in just_always_friggin_skip:
                continue  # go on to the next match

            if not self.args.dont_skip_or15_genes and '/OR1' in true_gene:
                skip_gene(true_gene)
                return

            if self.args.skip_missing_genes:
                if gene in genes_to_skip:
                    continue  # go on to the next match
                    # skip_gene(gene)
                    # return
                if true_gene in genes_to_skip:
                    skip_gene(true_gene)
                    return

            if gene not in self.germline_seqs[region]:
                print '    %s: %s not in germlines (skipping)' % (query_name, gene)
                skip_gene(gene)
                return
                
            vals = line.split()
            gl_start = int(vals[-3]) - 1  # converting from one-indexed to zero-indexed
            gl_seq = vals[-2]
            gl_end = int(vals[-1])  # ...and from inclusive of both bounds to normal programming conventions

            if region + '_gene' in qr_info:
                if qr_info[region + '_gene'] == gene:
                    if self.args.debug:
                        print '        %s match: %s' % (region, clean_alignment_crap(qr_seq, gl_seq))
                    qr_info[region + '_gl_seq'] = qr_info[region + '_gl_seq'] + clean_alignment_crap(qr_seq, gl_seq)
                    if gl_end > len(self.germline_seqs[region][gene]):  # not really sure what's wrong... but it seems to be rare
                        qr_info['fail'] = True
                        return
                    qr_info[region + '_3p_del'] = len(self.germline_seqs[region][gene]) - gl_end
                    qr_info[region + '_qr_bounds'] = (qr_info[region + '_qr_bounds'][0], find_qr_bounds(qr_start, qr_end, gl_seq)[1])
                else:
                    continue
            else:
                qr_info[region + '_gene'] = gene
                qr_info[region + '_gl_seq'] = clean_alignment_crap(qr_seq, gl_seq)
                # deletions
                qr_info[region + '_5p_del'] = gl_start
                assert gl_end <= len(self.germline_seqs[region][gene])
                qr_info[region + '_3p_del'] = len(self.germline_seqs[region][gene]) - gl_end
                # bounds
                qr_info[region + '_qr_bounds'] = find_qr_bounds(qr_start, qr_end, gl_seq)
                if self.args.debug:
                    print '        %s match: %s' % (region, clean_alignment_crap(qr_seq, gl_seq))
Exemplo n.º 12
0
class IhhhmmmParser(object):
    def __init__(self, args):
        self.args = args

        self.germline_seqs = utils.read_germlines(self.args.datadir,
                                                  remove_N_nukes=True)
        self.perfplotter = PerformancePlotter(self.germline_seqs,
                                              self.args.plotdir, 'ihhhmmm')

        self.details = OrderedDict()
        self.failtails = {}
        self.n_partially_failed = 0

        # get sequence info that was passed to ihhhmmm
        self.siminfo = OrderedDict()
        self.sim_need = []  # list of queries that we still need to find
        with opener('r')(self.args.simfname) as seqfile:
            reader = csv.DictReader(seqfile)
            iline = 0
            for line in reader:
                if self.args.queries != None and line[
                        'unique_id'] not in self.args.queries:
                    continue
                self.siminfo[line['unique_id']] = line
                self.sim_need.append(line['unique_id'])
                iline += 1
                if args.n_max_queries > 0 and iline >= args.n_max_queries:
                    break

        fostream_names = glob.glob(self.args.indir + '/*.fostream')
        fostream_names.sort()  # maybe already sorted?
        for infname in fostream_names:
            if len(self.sim_need) == 0:
                break

            # try to get whatever you can for the failures
            unique_ids = self.find_partial_failures(
                infname)  # returns list of unique ids in this file

            with opener('r')(infname) as infile:
                self.parse_file(infile, unique_ids)

        # now check that we got results for all the queries we wanted
        n_failed = 0
        for unique_id in self.siminfo:
            if unique_id not in self.details and unique_id not in self.failtails:
                print '%-20s  no info' % unique_id
                self.perfplotter.add_fail()
                n_failed += 1

        print ''
        print 'partially failed: %d / %d = %.2f' % (
            self.n_partially_failed, len(self.siminfo),
            float(self.n_partially_failed) / len(self.siminfo))
        print 'failed:           %d / %d = %.2f' % (n_failed, len(
            self.siminfo), float(n_failed) / len(self.siminfo))
        print ''

        self.perfplotter.plot()

    # ----------------------------------------------------------------------------------------
    def parse_file(self, infile, unique_ids):
        fk = FileKeeper(infile.readlines())
        i_id = 0
        while not fk.eof and len(self.sim_need) > 0:
            self.parse_detail(fk, unique_ids[i_id])
            i_id += 1

    # ----------------------------------------------------------------------------------------
    def parse_detail(self, fk, unique_id):
        assert fk.iline < len(fk.lines)

        while fk.line[1] != 'Details':
            fk.increment()
            if fk.eof:
                return

        fk.increment()
        info = {}
        info['unique_id'] = unique_id
        for begin_line, column, index, required, default in line_order:
            if fk.line[0].find(begin_line) != 0:
                if required:
                    print 'oop', begin_line, fk.line
                    sys.exit()
                else:
                    info[column] = default
                    continue
            if column != '':
                info[column] = clean_value(column, fk.line[index])
                # if '[' in info[column]:
                #     print 'added', column, clean_value(column, fk.line[index])
                if column.find('_gene') == 1:
                    region = column[0]
                    info[region + '_5p_del'] = int(
                        fk.line[fk.line.index('start:') +
                                1]) - 1  # NOTE their indices are 1-based
                    gl_length = int(fk.line[fk.line.index('gene:') + 1]) - 1
                    match_end = int(fk.line[fk.line.index('end:') + 1]) - 1
                    assert gl_length >= match_end
                    info[region + '_3p_del'] = gl_length - match_end

            fk.increment()

        if unique_id not in self.sim_need:
            while not fk.eof and fk.line[
                    1] != 'Details':  # skip stuff until start of next Detail block
                fk.increment()
            return

        info['fv_insertion'] = ''
        info['jf_insertion'] = ''
        info['seq'] = info['v_qr_seq'] + info['vd_insertion'] + info[
            'd_qr_seq'] + info['dj_insertion'] + info['j_qr_seq']

        if '-' in info['seq']:
            print 'ERROR found a dash in %s, returning failure' % unique_id
            while not fk.eof and fk.line[
                    1] != 'Details':  # skip stuff until start of next Detail block
                fk.increment()
            return

        if info['seq'] not in self.siminfo[unique_id][
                'seq']:  # arg. I can't do != because it tacks on v left and j right deletions
            print 'ERROR didn\'t find the right sequence for %s' % unique_id
            print '  ', info['seq']
            print '  ', self.siminfo[unique_id]['seq']
            sys.exit()

        if self.args.debug:
            print unique_id
            utils.print_reco_event(self.germline_seqs,
                                   self.siminfo[unique_id],
                                   label='true:',
                                   extra_str='    ')
            utils.print_reco_event(self.germline_seqs,
                                   info,
                                   label='inferred:',
                                   extra_str='    ')

        for region in utils.regions:
            if info[region + '_gene'] not in self.germline_seqs[region]:
                print 'ERROR %s not in germlines' % info[region + '_gene']
                assert False

            gl_seq = info[region + '_gl_seq']
            if '[' in gl_seq:  # ambiguous
                for nuke in utils.nukes:
                    gl_seq = gl_seq.replace('[', nuke)
                    if gl_seq in self.germline_seqs[region][info[region +
                                                                 '_gene']]:
                        print '  replaced [ with %s' % nuke
                        break
                info[region + '_gl_seq'] = gl_seq

            if info[region + '_gl_seq'] not in self.germline_seqs[region][info[
                    region + '_gene']]:
                print 'ERROR gl match not found for %s in %s' % (
                    info[region + '_gene'], unique_id)
                print '  ', info[region + '_gl_seq']
                print '  ', self.germline_seqs[region][info[region + '_gene']]
                self.perfplotter.add_partial_fail(self.siminfo[unique_id],
                                                  info)
                while not fk.eof and fk.line[
                        1] != 'Details':  # skip stuff until start of next Detail block
                    fk.increment()
                return

        self.perfplotter.evaluate(self.siminfo[unique_id], info)
        self.details[unique_id] = info
        self.sim_need.remove(unique_id)

        while not fk.eof and fk.line[
                1] != 'Details':  # skip stuff until start of next Detail block
            fk.increment()

    # ----------------------------------------------------------------------------------------
    def find_partial_failures(self, fostream_name):
        unique_ids = []
        for line in open(fostream_name.replace('.fostream', '')).readlines():
            if len(self.sim_need) == 0:
                return
            if len(line.strip()) == 0:  # skip blank lines
                continue

            line = line.replace('"', '')
            line = line.split(';')

            unique_id = line[0]

            if 'NA' not in line:  # skip lines that were ok
                unique_ids.append(unique_id)
                continue
            if unique_id not in self.sim_need:
                continue
            if unique_id not in self.siminfo:
                continue  # not looking for this <unique_id> a.t.m.

            info = {}
            info['unique_id'] = unique_id
            for stuff in line:
                for region in utils.regions:  # add the first instance of IGH[VDJ] (if it's there at all)
                    if 'IGH' + region.upper(
                    ) in stuff and region + '_gene' not in info:
                        genes = re.findall(
                            'IGH' + region.upper() + '[^ ][^ ]*', stuff)
                        if len(genes) == 0:
                            print 'ERROR no %s genes in %s' % (region, stuff)
                        gene = genes[0]
                        if gene not in self.germline_seqs[region]:
                            print 'ERROR bad gene %s for %s' % (gene,
                                                                unique_id)
                            sys.exit()
                        info[region + '_gene'] = gene
            self.perfplotter.add_partial_fail(self.siminfo[unique_id], info)
            if self.args.debug:
                print '%-20s  partial fail %s %s %s' % (
                    unique_id, utils.color_gene(info['v_gene']) if 'v_gene'
                    in info else '', utils.color_gene(info['d_gene']) if
                    'd_gene' in info else '', utils.color_gene(info['j_gene'])
                    if 'j_gene' in info else ''),
                print '  (true %s %s %s)' % tuple([
                    self.siminfo[unique_id][region + '_gene']
                    for region in utils.regions
                ])
            self.failtails[unique_id] = info
            self.n_partially_failed += 1
            self.sim_need.remove(unique_id)

        return unique_ids
Exemplo n.º 13
0
class Waterer(object):
    """ Run smith-waterman on the query sequences in <infname> """
    def __init__(self, args, input_info, reco_info, glfo, my_datadir, parameter_dir, write_parameters=False, find_new_alleles=False):
        """
        Store the inputs and set up per-run state for smith-waterman (ig-sw) alignment of the queries in <input_info>.

        <write_parameters> and <find_new_alleles> decide whether a ParameterCounter (plus a second one for the true
        annotations when not running on data) and an AlleleFinder get created -- these are *not* the same as
        <args.cache_parameters> and <args.find_new_alleles>. A PerformancePlotter is created when args.plot_performance is set.
        Raises an Exception if the ig-sw binary <args.ig_sw_dir> + 'ig-sw' does not exist.
        """
        self.args = args
        self.parameter_dir = parameter_dir.rstrip('/')
        if self.args.sw_debug is None:
            self.debug = self.args.debug
        else:
            self.debug = self.args.sw_debug

        self.max_insertion_length = 35  # if vdjalign reports an insertion longer than this, rerun the query (typically with different match/mismatch ratio)
        self.absolute_max_insertion_length = 200  # ignore insertions longer than this altogether

        self.input_info = input_info
        self.reco_info = reco_info
        self.glfo = glfo
        self.my_datadir = my_datadir  # germline directory handed to ig-sw (-p) and to pcounter.write()

        # queries get removed from this set once we're happy with their output (some need reruns with different match/mismatch scores)
        self.remaining_queries = set(self.input_info.keys())
        self.new_indels = 0  # number of new indels kicked up on the current pass
        self.nth_try = 1
        self.unproductive_queries = set()

        self.match_mismatch = copy.deepcopy(self.args.initial_match_mismatch)  # private copy, since the mismatch score gets incremented on reruns
        self.gap_open_penalty = self.args.gap_open_penalty  # kept separately so we never end up modifying args

        self.info = {
            'queries' : [],  # queries that *passed* sw, i.e. for which we have information
            'all_best_matches' : set(),  # every gene that was a best match for at least one query
            'all_matches' : dict((region, set()) for region in utils.regions),  # every gene that was *any* match for at least one query
            'indels' : {},
        }

        self.alfinder = None
        if find_new_alleles:  # NOTE *not* the same as <self.args.find_new_alleles>
            self.alfinder = AlleleFinder(self.glfo, self.args)
        self.pcounter, self.true_pcounter = None, None
        if write_parameters:  # NOTE *not* the same as <self.args.cache_parameters>
            self.pcounter = ParameterCounter(self.glfo, self.args)
            if not self.args.is_data:
                self.true_pcounter = ParameterCounter(self.glfo, self.args)
        self.perfplotter = PerformancePlotter(self.glfo, 'sw') if self.args.plot_performance else None

        ig_sw_binary = self.args.ig_sw_dir + 'ig-sw'
        if not os.path.exists(ig_sw_binary):
            raise Exception('ERROR ig-sw path d.n.e: ' + ig_sw_binary)

    # ----------------------------------------------------------------------------------------
    def run(self):
        # start = time.time()
        base_infname = 'query-seqs.fa'
        base_outfname = 'query-seqs.sam'
        sys.stdout.flush()

        n_procs = self.args.n_fewer_procs
        initial_queries_per_proc = float(len(self.remaining_queries)) / n_procs
        while len(self.remaining_queries) > 0:  # we remove queries from <self.remaining_queries> as we're satisfied with their output
            if self.nth_try > 1 and float(len(self.remaining_queries)) / n_procs < initial_queries_per_proc:
                n_procs = int(max(1., float(len(self.remaining_queries)) / initial_queries_per_proc))
            self.write_vdjalign_input(base_infname, n_procs)
            self.execute_commands(base_infname, base_outfname, n_procs)
            self.read_output(base_outfname, n_procs)
            if self.nth_try > 3:
                break
            self.nth_try += 1  # it's set to 1 before we begin the first try, and increases to 2 just before we start the second try

        self.finalize()

    # ----------------------------------------------------------------------------------------
    def finalize(self):
        """
        Wrap up once run() is done with its tries.

        Prints a summary of how many queries we have sw info for, and how many were skipped as unproductive or are still in
        <self.remaining_queries> (i.e. never got an acceptable annotation). Then, depending on which helper objects were
        created in __init__, it plots performance, finalizes/plots new-allele finding, pads the sequences, and plots/writes
        parameters. Finally stores <self.remaining_queries> in self.info['remaining_queries'].
        """
        if self.perfplotter is not None:
            # NOTE(review): unlike the alfinder/pcounter plots below, this doesn't check <self.args.plotdir> for None -- presumably plot_performance implies a plotdir; confirm
            self.perfplotter.plot(self.args.plotdir + '/sw', only_csv=self.args.only_csv_plots)
        # print '    sw time: %.3f' % (time.time()-start)
        print '      info for %d' % len(self.info['queries']),  # trailing comma: the skip counts below continue on the same line
        skipped_unproductive = len(self.unproductive_queries)
        n_remaining = len(self.remaining_queries)
        if skipped_unproductive > 0 or n_remaining > 0:
            print '     (skipped',
            print '%d / %d = %.2f unproductive' % (skipped_unproductive, len(self.input_info), float(skipped_unproductive) / len(self.input_info)),
            if n_remaining > 0:
                print '   %d / %d = %.2f other' % (n_remaining, len(self.input_info), float(n_remaining) / len(self.input_info)),
            print ')',
        print ''  # terminate the summary line
        sys.stdout.flush()
        if n_remaining > 0:
            printstr = '   %s %d missing %s' % (utils.color('red', 'warning'), n_remaining, utils.plural_str('annotation', n_remaining))
            if n_remaining < 15:  # only list the query names when there are few enough to read
                printstr += ' (' + ':'.join(self.remaining_queries) + ')'
            print printstr
        if self.debug and len(self.info['indels']) > 0:
            print '      indels: %s' % ':'.join(self.info['indels'].keys())
        # every input query must be accounted for exactly once: annotated, skipped as unproductive, or still remaining
        assert len(self.info['queries']) + skipped_unproductive + n_remaining == len(self.input_info)
        if self.debug and not self.args.is_data and n_remaining > 0:  # on simulation, show the true events for the queries we failed on
            print 'true annotations for remaining events:'
            for qry in self.remaining_queries:
                utils.print_reco_event(self.glfo['seqs'], self.reco_info[qry], extra_str='      ', label='true:')
        if self.alfinder is not None:
            self.alfinder.finalize(debug=self.args.debug_new_allele_finding)
            self.info['new-alleles'] = self.alfinder.new_allele_info
            if self.args.plotdir is not None:
                self.alfinder.plot(self.args.plotdir + '/sw', only_csv=self.args.only_csv_plots)

        # add padded info to self.info (returns if stuff has already been padded)
        self.pad_seqs_to_same_length()  # NOTE this uses *all the gene matches (not just the best ones), so it has to come before we call pcounter.write(), since that fcn rewrites the germlines removing genes that weren't best matches. But NOTE also that I'm not sure what but that the padding actually *needs* all matches (rather than just all *best* matches)

        if self.pcounter is not None:
            if self.args.plotdir is not None:
                self.pcounter.plot(self.args.plotdir + '/sw', subset_by_gene=True, cyst_positions=self.glfo['cyst-positions'], tryp_positions=self.glfo['tryp-positions'], only_csv=self.args.only_csv_plots)
                if self.true_pcounter is not None:
                    self.true_pcounter.plot(self.args.plotdir + '/sw-true', subset_by_gene=True, cyst_positions=self.glfo['cyst-positions'], tryp_positions=self.glfo['tryp-positions'], only_csv=self.args.only_csv_plots)
            self.pcounter.write(self.parameter_dir, self.my_datadir)
            if self.true_pcounter is not None:  # true-annotation parameters go to a sibling '<parameter_dir>-true' directory
                self.true_pcounter.write(self.parameter_dir + '-true')

        self.info['remaining_queries'] = self.remaining_queries

    # ----------------------------------------------------------------------------------------
    def subworkdir(self, iproc, n_procs):
        if n_procs == 1:
            return self.args.workdir
        else:
            return self.args.workdir + '/sw-' + str(iproc)

    # ----------------------------------------------------------------------------------------
    def execute_commands(self, base_infname, base_outfname, n_procs):
        """
        Launch one ig-sw process per subworkdir and poll once a second until they have all finished, then delete each process's input file.

        Each finished process is handed to utils.finish_process() together with the per-process try counts. That call is
        responsible for setting the process's slot in the list to None once it is really done, which is what ends the polling loop.
        """
        def outfname_for(iproc):
            return self.subworkdir(iproc, n_procs) + '/' + base_outfname

        def cmd_for(iproc):
            return self.get_vdjalign_cmd_str(self.subworkdir(iproc, n_procs), base_infname, base_outfname, n_procs)

        # launch every process once
        procs = []
        n_tries = []
        for iproc in range(n_procs):
            procs.append(utils.run_cmd(cmd_for(iproc), self.subworkdir(iproc, n_procs)))
            n_tries.append(1)
            time.sleep(0.1)

        # poll until every slot has been set to None
        while any(proc is not None for proc in procs):
            for iproc in range(n_procs):
                proc = procs[iproc]
                if proc is None or proc.poll() is None:  # already done, or still running
                    continue
                utils.finish_process(iproc, procs, n_tries, self.subworkdir(iproc, n_procs), outfname_for(iproc), cmd_for(iproc))
            sys.stdout.flush()
            time.sleep(1)

        for iproc in range(n_procs):
            os.remove(self.subworkdir(iproc, n_procs) + '/' + base_infname)

        sys.stdout.flush()

    # ----------------------------------------------------------------------------------------
    def write_vdjalign_input(self, base_infname, n_procs):
        """
        Write the remaining queries as fasta input for ig-sw, split across <n_procs> processes.

        The queries are split into contiguous chunks of ceil(n_remaining / n_procs), one chunk per subworkdir (the
        subworkdirs get prepared when n_procs > 1), and written to <subworkdir>/<base_infname>. Queries with an entry in
        self.info['indels'] are written with their indel-reversed sequence instead of the original one.
        Raises an Exception if any remaining query was not written.
        """
        n_remaining = len(self.remaining_queries)
        n_queries_per_proc = int(math.ceil(float(n_remaining) / n_procs))
        if n_procs == 1:  # double check for rounding problems or whatnot
            assert n_queries_per_proc == n_remaining
        # fix the iteration order once and slice it per process, rather than re-looping over the whole set for every process
        ordered_queries = list(self.remaining_queries)
        written_queries = set()  # make sure we actually write each query TODO remove this when you work out where they're disappearing to
        for iproc in range(n_procs):
            workdir = self.subworkdir(iproc, n_procs)
            if n_procs > 1:
                utils.prep_dir(workdir)
            with opener('w')(workdir + '/' + base_infname) as sub_infile:
                for query_name in ordered_queries[iproc * n_queries_per_proc : (iproc + 1) * n_queries_per_proc]:
                    sub_infile.write('>' + query_name + ' NUKES\n')
                    seq = self.input_info[query_name]['seq']
                    if query_name in self.info['indels']:
                        seq = self.info['indels'][query_name]['reversed_seq']  # use the query sequence with shm insertions and deletions reversed
                    sub_infile.write(seq + '\n')
                    written_queries.add(query_name)
        not_written = self.remaining_queries - written_queries
        if len(not_written) > 0:
            raise Exception('didn\'t write %s to %s' % (':'.join(not_written), self.args.workdir))

    # ----------------------------------------------------------------------------------------
    def get_vdjalign_cmd_str(self, workdir, base_infname, base_outfname, n_procs=None):
        """
        Run smith-waterman alignment (from Connor's ighutils package) on the seqs in <base_infname>, and toss all the top matches into <base_outfname>.
        """
        # large gap-opening penalty: we want *no* gaps in the middle of the alignments
        # match score larger than (negative) mismatch score: we want to *encourage* some level of shm. If they're equal, we tend to end up with short unmutated alignments, which screws everything up
        cmd_str = '/partis/packages/ig-sw/src/ig_align/ig-sw'
        if self.args.slurm or utils.auto_slurm(n_procs):
            cmd_str = 'srun ' + cmd_str
        cmd_str += ' -l ' + 'IG' + self.args.chain.upper()
        cmd_str += ' -d 50'
        match, mismatch = self.match_mismatch
        cmd_str += ' -m ' + str(match) + ' -u ' + str(mismatch)
        cmd_str += ' -o ' + str(self.gap_open_penalty)
        cmd_str += ' -p ' + self.my_datadir + '/' + self.args.chain + '/'
        cmd_str += ' ' + workdir + '/' + base_infname + ' ' + workdir + '/' + base_outfname
        return cmd_str

    # ----------------------------------------------------------------------------------------
    def read_output(self, base_outfname, n_procs=1):
        """
        Read the ig-sw sam output for this try and decide what to do with the queries that are still left.

        Reads <base_outfname> from each process's subworkdir and passes each query's group of alignments to
        process_query(). Then prints a progress line (with a header on the first try). If queries remain, it either
        increments the mismatch score (first try, or no new indels) or leaves the parameters alone so that queries with
        newly found indels get rerun with the indel reversed. Finally removes the sam files, plus the (presumably
        now-empty) per-process subworkdirs when n_procs > 1.
        Raises an Exception if a remaining query wasn't found in the output, or if the rerun counts don't add up.
        """
        queries_to_rerun = OrderedDict()  # This is to keep track of every query that we don't add to self.info (i.e. it does *not* include unproductive queries that we ignore/skip entirely because we were told to by a command line argument)
                                          # ...whereas <self.unproductive_queries> is to keep track of the queries that were definitively unproductive (i.e. we removed them from self.remaining_queries) when we were told to skip unproductives by a command line argument
        for reason in ['unproductive', 'no-match', 'weird-annot.', 'nonsense-bounds', 'invalid-codon']:  # one set per rerun reason; presumably filled in by summarize_query() via process_query()
            queries_to_rerun[reason] = set()

        self.new_indels = 0  # process_query() increments this for each newly found indel
        n_processed = 0
        self.tmp_queries_read_from_file = set()  # TODO remove this
        for iproc in range(n_procs):
            outfname = self.subworkdir(iproc, n_procs) + '/' + base_outfname
            with contextlib.closing(pysam.Samfile(outfname)) as sam: #changed bam to sam because ig-sw outputs sam files
                # groupby only merges *consecutive* reads, so this relies on ig-sw writing each query's alignments together
                grouped = itertools.groupby(iter(sam), operator.attrgetter('qname'))
                for _, reads in grouped:  # loop over query sequences
                    self.process_query(sam.references, list(reads), queries_to_rerun)
                    n_processed += 1

        not_read = self.remaining_queries - self.tmp_queries_read_from_file
        if len(not_read) > 0:
            raise Exception('didn\'t read %s from %s' % (':'.join(not_read), self.args.workdir))

        if self.nth_try == 1:
            print '        processed       remaining      new-indels          rerun: ' + '      '.join([reason for reason in queries_to_rerun])
        print '      %8d' % n_processed,
        if len(self.remaining_queries) > 0:
            printstr = '       %8d' % len(self.remaining_queries)
            printstr += '       %8d' % self.new_indels
            printstr += '            '
            n_to_rerun = 0
            for reason in queries_to_rerun:
                printstr += '        %8d' % len(queries_to_rerun[reason])
                n_to_rerun += len(queries_to_rerun[reason])
            print printstr,
            # every query still remaining must either be marked for rerun or have a newly found indel
            if n_to_rerun + self.new_indels != len(self.remaining_queries):
                print ''
                raise Exception('numbers don\'t add up in sw output reader (n_to_rerun + new_indels != remaining_queries): %d + %d != %d   (look in %s)' % (n_to_rerun, self.new_indels, len(self.remaining_queries), self.args.workdir))
            if self.nth_try < 2 or self.new_indels == 0:  # increase the mismatch score if it's the first try, or if there's no new indels
                print '            increasing mismatch score (%d --> %d) and rerunning them' % (self.match_mismatch[1], self.match_mismatch[1] + 1)
                self.match_mismatch[1] += 1
            elif self.new_indels > 0:  # if there were some indels, rerun with the same parameters (but when the input is written the indel will be "reversed' in the sequences that's passed to ighutil)
                print '            rerunning for indels'
                self.new_indels = 0
            else:  # shouldn't get here
                assert False
        else:
            print '        all done'

        for iproc in range(n_procs):
            workdir = self.subworkdir(iproc, n_procs)
            os.remove(workdir + '/' + base_outfname)
            if n_procs > 1:  # still need the top-level workdir
                os.rmdir(workdir)

    # ----------------------------------------------------------------------------------------
    def get_indel_info(self, query_name, cigarstr, qrseq, glseq, gene):
        """
        Parse the cigar string <cigarstr> of one sw alignment and return indel info for it.

        <qrseq> and <glseq> should be the aligned stretches of the query and of germline <gene>. process_query() passes
        the slices given by the match's query and germline bounds. Soft-clipped ('S') columns advance neither sequence.
        Returns the dict from utils.get_empty_indel(), filled in with:
          - 'indels': one entry per I/D cigar operation, each with 'type' ('insertion' or 'deletion'), 'pos', 'len', and
            'seqstr' (the inserted query bases, or the deleted germline bases)
          - 'reversed_seq': the aligned query bases with insertions dropped and deleted germline bases filled back in,
            i.e. the aligned stretch with its indels reversed
        Prints the alignment with indels highlighted when self.debug is set. Raises an Exception on any uppercase cigar
        code other than M, S, I, D.
        NOTE(review): the regex below only matches letter codes, so '=' / 'X' operations would be silently dropped.
        """
        cigars = re.findall('[0-9][0-9]*[A-Z]', cigarstr)  # split cigar string into its parts
        cigars = [(cstr[-1], int(cstr[:-1])) for cstr in cigars]  # split each part into the code and the length

        codestr = ''  # cigar expanded to one code character per alignment column
        qpos = 0  # NOTE(review): advanced for *every* cigar code (including D and S), so this is an alignment-column offset rather than strictly a query position -- confirm that's what consumers of 'pos' expect
        indelfo = utils.get_empty_indel()  # 'reversed_seq' gets the query seq with insertions removed and germline bases inserted at the position of deletions
        tmp_indices = []  # for each alignment column, the index in indelfo['indels'] of the indel it belongs to (None if not in an indel)
        for code, length in cigars:
            codestr += length * code
            if code == 'I':  # advance qr seq but not gl seq
                indelfo['indels'].append({'type' : 'insertion', 'pos' : qpos, 'len' : length, 'seqstr' : ''})  # insertion begins at <pos>
                tmp_indices += [len(indelfo['indels']) - 1  for _ in range(length)]  # indel index corresponding to this position in the alignment
            elif code == 'D':  # advance gl seq but not qr seq
                indelfo['indels'].append({'type' : 'deletion', 'pos' : qpos, 'len' : length, 'seqstr' : ''})  # first deleted base is <pos> (well, first base which is in the position of the first deleted base)
                tmp_indices += [len(indelfo['indels']) - 1  for _ in range(length)]  # indel index corresponding to this position in the alignment
            else:
                tmp_indices += [None  for _ in range(length)]  # indel index corresponding to this position in the alignment
            qpos += length

        # walk the alignment column by column, building the debug strings, the reversed sequence, and each indel's bases
        qrprintstr, glprintstr = '', ''
        iqr, igl = 0, 0  # current positions in qrseq and glseq
        for icode in range(len(codestr)):
            code = codestr[icode]
            if code == 'M':
                qrbase = qrseq[iqr]
                if qrbase != glseq[igl]:
                    qrbase = utils.color('red', qrbase)
                qrprintstr += qrbase
                glprintstr += glseq[igl]
                indelfo['reversed_seq'] += qrseq[iqr]  # add the base to the overall sequence with all indels reversed
            elif code == 'S':
                continue  # skips the increments below, so soft-clipped columns don't advance either sequence
            elif code == 'I':
                qrprintstr += utils.color('light_blue', qrseq[iqr])
                glprintstr += utils.color('light_blue', '*')
                indelfo['indels'][tmp_indices[icode]]['seqstr'] += qrseq[iqr]  # and to the sequence of just this indel
                igl -= 1  # cancels the increment below: insertions don't consume germline
            elif code == 'D':
                qrprintstr += utils.color('light_blue', '*')
                glprintstr += utils.color('light_blue', glseq[igl])
                indelfo['reversed_seq'] += glseq[igl]  # add the base to the overall sequence with all indels reversed
                indelfo['indels'][tmp_indices[icode]]['seqstr'] += glseq[igl]  # and to the sequence of just this indel
                iqr -= 1  # cancels the increment below: deletions don't consume query
            else:
                raise Exception('unhandled code %s' % code)

            iqr += 1
            igl += 1

        if self.debug:
            print '\n      indels in %s' % query_name
            print '          %20s %s' % (gene, glprintstr)
            print '          %20s %s' % ('query', qrprintstr)
            for idl in indelfo['indels']:
                print '          %10s: %d bases at %d (%s)' % (idl['type'], idl['len'], idl['pos'], idl['seqstr'])
        # utils.undo_indels(indelfo)
        # print '                       %s' % self.input_info[query_name]['seq']

        return indelfo

    # ----------------------------------------------------------------------------------------
    def process_query(self, references, reads, queries_to_rerun):
        primary = next((r for r in reads if not r.is_secondary), None)
        query_seq = primary.seq
        query_name = primary.qname
        self.tmp_queries_read_from_file.add(query_name)
        first_match_query_bounds = None  # since sw excises its favorite v match, we have to know this match's boundaries in order to calculate k_d for all the other matches
        all_match_names = {}
        warnings = {}  # ick, this is a messy way to pass stuff around
        for region in utils.regions:
            all_match_names[region] = []
        all_query_bounds, all_germline_bounds = {}, {}
        for read in reads:  # loop over the matches found for each query sequence
            # set this match's values
            read.seq = query_seq  # only the first one has read.seq set by default, so we need to set the rest by hand
            gene = references[read.tid]
            region = utils.get_region(gene)
            raw_score = read.tags[0][1]  # raw because they don't include the gene choice probs
            score = raw_score
            qrbounds = (read.qstart, read.qend)
            glbounds = (read.pos, read.aend)
            if region == 'v' and first_match_query_bounds is None:
                first_match_query_bounds = qrbounds

            # perform a few checks and see if we want to skip this match
            # TODO I wish this wasn't here and I suspect I don't really need it (any more) UPDATE I dunno, this definitely eliminates some stupid (albeit rare) matches
            if region == 'v':  # skip matches with cpos past the end of the query seq (i.e. eroded a ton on the right side of the v)
                cpos = self.glfo['cyst-positions'][gene] - glbounds[0] + qrbounds[0]  # position within original germline gene, minus the position in that germline gene at which the match starts, plus the position in the query sequence at which the match starts
                if cpos < 0 or cpos >= len(query_seq):
                    continue

            if 'I' in read.cigarstring or 'D' in read.cigarstring:  # skip indels, and tell the HMM to skip indels (you won't see any unless you decrease the <self.gap_open_penalty>)
                if self.args.no_indels:  # you can forbid indels on the command line
                    continue
                if self.nth_try < 2:  # we also forbid indels on the first try (we want to increase the mismatch score before we conclude it's "really" an indel)
                    continue
                if len(all_match_names[region]) == 0:  # if this is the first (best) match for this region, allow indels (otherwise skip the match)
                    if query_name not in self.info['indels']:
                        self.info['indels'][query_name] = self.get_indel_info(query_name, read.cigarstring, query_seq[qrbounds[0] : qrbounds[1]], self.glfo['seqs'][region][gene][glbounds[0] : glbounds[1]], gene)
                        self.info['indels'][query_name]['reversed_seq'] = query_seq[ : qrbounds[0]] + self.info['indels'][query_name]['reversed_seq'] + query_seq[qrbounds[1] : ]
                        self.new_indels += 1
                        # TODO this 'return' used to be after and indented from the else below, and that continue wasn't there. I should make sure this is how I want it
                        return  # don't process this query any further -- since it's now in the indel info it'll get run next time through
                    else:
                        if self.debug:
                            print '     ignoring subsequent indels for %s' % query_name
                        continue  # hopefully there's a later match without indels
                else:
                    continue

            if qrbounds[1]-qrbounds[0] != glbounds[1]-glbounds[0]:
                raise Exception('germline match (%d %d) not same length as query match (%d %d)' % (qrbounds[0], qrbounds[1], glbounds[0], glbounds[1]))

            assert qrbounds[1] <= len(query_seq)
            if glbounds[1] > len(self.glfo['seqs'][region][gene]):
                print '  ', gene
                print '  ', glbounds[1], len(self.glfo['seqs'][region][gene])
                print '  ', self.glfo['seqs'][region][gene]
            assert glbounds[1] <= len(self.glfo['seqs'][region][gene])
            assert qrbounds[1]-qrbounds[0] == glbounds[1]-glbounds[0]

            # and finally add this match's information
            warnings[gene] = ''
            all_match_names[region].append((score, gene))  # NOTE it is important that this is ordered such that the best match is first
            all_query_bounds[gene] = qrbounds
            all_germline_bounds[gene] = glbounds

        self.summarize_query(query_name, query_seq, all_match_names, all_query_bounds, all_germline_bounds, warnings, first_match_query_bounds, queries_to_rerun)

    # ----------------------------------------------------------------------------------------
    def print_match(self, region, gene, query_seq, score, glbounds, qrbounds, codon_pos, warnings, skipping=False):
        out_str_list = []
        buff_str = (20 - len(gene)) * ' '
        out_str_list.append('%8s%s%s%9s%3s %6.0f        ' % (' ', utils.color_gene(gene), '', '', buff_str, score))
        out_str_list.append('%4d%4d   %s\n' % (glbounds[0], glbounds[1], self.glfo['seqs'][region][gene][glbounds[0]:glbounds[1]]))
        out_str_list.append('%46s  %4d%4d' % ('', qrbounds[0], qrbounds[1]))
        out_str_list.append('   %s ' % (utils.color_mutants(self.glfo['seqs'][region][gene][glbounds[0]:glbounds[1]], query_seq[qrbounds[0]:qrbounds[1]])))
        if region != 'd':
            out_str_list.append('(%s %d)' % (utils.conserved_codons[region], codon_pos))
        if warnings[gene] != '':
            out_str_list.append('WARNING ' + warnings[gene])
        if skipping:
            out_str_list.append('skipping!')

        print ''.join(out_str_list)

    # ----------------------------------------------------------------------------------------
    def get_overlap_and_available_space(self, rpair, best, qrbounds):
        """
        Return (overlap, available_space) for the best matches of regions rpair['left'] and rpair['right'].

        overlap is how far the left match's query end extends past the right match's query start (positive means they
        overlap). available_space is the distance from the left match's query start to the right match's query end.
        """
        left_bounds = qrbounds[best[rpair['left']]]
        right_bounds = qrbounds[best[rpair['right']]]
        return left_bounds[1] - right_bounds[0], right_bounds[1] - left_bounds[0]

    # ----------------------------------------------------------------------------------------
    def check_boundaries(self, rpair, qrbounds, glbounds, query_name, query_seq, best, recursed=False, debug=False):
        # NOTE this duplicates code in shift_overlapping_boundaries(), which makes me cranky, but this setup avoids other things I dislike more
        l_reg = rpair['left']
        r_reg = rpair['right']
        l_gene = best[l_reg]
        r_gene = best[r_reg]

        overlap, available_space = self.get_overlap_and_available_space(rpair, best, qrbounds)

        if debug:
            print '  %s %s    overlap %d    available space %d' % (l_reg, r_reg, overlap, available_space)

        status = 'ok'
        if overlap > 0:  # positive overlap means they actually overlap
            status = 'overlap'
        if overlap > available_space or overlap == 1 and available_space == 1:  # call it nonsense if the boundaries are really whack (i.e. there isn't enough space to resolve the overlap) -- we'll presumably either toss the query or rerun with different match/mismatch
            status = 'nonsense'

        if debug:
            print '  overlap status: %s' % status

        if not recursed and status == 'nonsense' and l_reg == 'd' and self.nth_try > 2:  # on rare occasions with very high mutation, vdjalign refuses to give us a j match that's at all to the right of the d match
            assert l_reg == 'd' and r_reg == 'j'
            if debug:
                print '  %s: synthesizing d match' % query_name
            leftmost_position = min(qrbounds[l_gene][0], qrbounds[r_gene][0])
            qrbounds[l_gene] = (leftmost_position, leftmost_position + 1)  # swap whatever crummy nonsense d match we have now for a one-base match at the left end of things (things in practice should be left end of j match)
            glbounds[l_gene] = (0, 1)
            status = self.check_boundaries(rpair, qrbounds, glbounds, query_name, query_seq, best, recursed=True, debug=debug)
            if status == 'overlap':
                if debug:
                    print '  \'overlap\' status after synthesizing d match. Setting to \'nonsense\', I can\'t deal with this bullshit'
                status = 'nonsense'

        return status

    # ----------------------------------------------------------------------------------------
    def shift_overlapping_boundaries(self, rpair, qrbounds, glbounds, query_name, query_seq, best, debug=False):
        # NOTE this does pretty much the same thing as resolve_overlapping_matches in joinparser.py
        """
        s-w allows d and j matches (and v and d matches) to overlap... which makes no sense, so apportion the disputed territory between the two regions.
        Note that this still works if, say, v is the entire sequence, i.e. one match is entirely subsumed by another.
        """
        l_reg = rpair['left']
        r_reg = rpair['right']
        l_gene = best[l_reg]
        r_gene = best[r_reg]

        overlap, available_space = self.get_overlap_and_available_space(rpair, best, qrbounds)

        if overlap <= 0:  # nothing to do, they're already consistent
            print 'shouldn\'t get here any more if there\'s no overlap'
            return

        if overlap > available_space:
            raise Exception('overlap %d bigger than available space %d between %s and %s for %s' % (overlap, available_space, l_reg, r_reg, query_name))

        if debug:
            print '%s%s:  %d-%d overlaps with %d-%d by %d' % (l_reg, r_reg, qrbounds[l_gene][0], qrbounds[l_gene][1], qrbounds[r_gene][0], qrbounds[r_gene][1], overlap)

        l_length = qrbounds[l_gene][1] - qrbounds[l_gene][0]  # initial length of lefthand gene match
        r_length = qrbounds[r_gene][1] - qrbounds[r_gene][0]  # and same for the righthand one
        l_portion, r_portion = 0, 0  # portion of the initial overlap that we give to each side
        if debug:
            print '    lengths        portions     '
        while l_portion + r_portion < overlap:
            if debug:
                print '  %4d %4d      %4d %4d' % (l_length, r_length, l_portion, r_portion)
            if l_length <= 1 and r_length <= 1:  # don't want to erode match (in practice it'll be the d match) all the way to zero
                raise Exception('both lengths went to one without resolving overlap for %s: %s %s' % (query_name, qrbounds[l_gene], qrbounds[r_gene]))
            elif l_length > 1 and r_length > 1:  # if both have length left, alternate back and forth
                if (l_portion + r_portion) % 2 == 0:
                    l_portion += 1  # give one base to the left
                    l_length -= 1
                else:
                    r_portion += 1  # and one to the right
                    r_length -= 1
            elif l_length > 1:
                l_portion += 1
                l_length -= 1
            elif r_length > 1:
                r_portion += 1
                r_length -= 1

        if debug:
            print '  %4d %4d    %4d %4d      %s %s' % (l_length, r_length, l_portion, r_portion, '', '')
            print '      %s apportioning %d bases between %s (%d) match and %s (%d) match' % (query_name, overlap, l_reg, l_portion, r_reg, r_portion)
        assert l_portion + r_portion == overlap
        qrbounds[l_gene] = (qrbounds[l_gene][0], qrbounds[l_gene][1] - l_portion)
        glbounds[l_gene] = (glbounds[l_gene][0], glbounds[l_gene][1] - l_portion)
        qrbounds[r_gene] = (qrbounds[r_gene][0] + r_portion, qrbounds[r_gene][1])
        glbounds[r_gene] = (glbounds[r_gene][0] + r_portion, glbounds[r_gene][1])

        best[l_reg + '_gl_seq'] = self.glfo['seqs'][l_reg][l_gene][glbounds[l_gene][0] : glbounds[l_gene][1]]
        best[l_reg + '_qr_seq'] = query_seq[qrbounds[l_gene][0]:qrbounds[l_gene][1]]
        best[r_reg + '_gl_seq'] = self.glfo['seqs'][r_reg][r_gene][glbounds[r_gene][0] : glbounds[r_gene][1]]
        best[r_reg + '_qr_seq'] = query_seq[qrbounds[r_gene][0]:qrbounds[r_gene][1]]

    # ----------------------------------------------------------------------------------------
    def add_to_info(self, query_name, query_seq, kvals, match_names, best, all_germline_bounds, all_query_bounds, codon_positions):
        """
        Record the final smith-waterman annotation for <query_name> in self.info, then hand it to any attached counters/plotters.

        Stores k_v/k_d, conserved codon positions, the deletions and insertions implied by the <best> gene in each region,
        and updates self.info's running 'all_best_matches' and 'all_matches' sets. utils.add_implicit_info() then fills
        in the remaining derived keys. Finally <query_name> is removed from self.remaining_queries.
        """
        assert query_name not in self.info
        self.info['queries'].append(query_name)
        swfo = {}
        self.info[query_name] = swfo  # everything below writes through this alias

        swfo['unique_id'] = query_name  # redundant, but used somewhere down the line
        swfo['k_v'] = kvals['v']
        swfo['k_d'] = kvals['d']
        swfo['all'] = ':'.join(match_names['v'] + match_names['d'] + match_names['j'])  # all gene matches for this query

        cyst_pos, tryp_pos = codon_positions['v'], codon_positions['j']
        swfo['cdr3_length'] = tryp_pos - cyst_pos + 3
        swfo['cyst_position'] = cyst_pos
        swfo['tryp_position'] = tryp_pos

        # deletions for the best match: 5' deletion is where the germline match starts, 3' deletion is what's left after it ends
        for region in ('v', 'd', 'j'):
            gene = best[region]
            gl_bounds = all_germline_bounds[gene]
            swfo[region + '_5p_del'] = gl_bounds[0]
            swfo[region + '_3p_del'] = len(self.glfo['seqs'][region][gene]) - gl_bounds[1]

        # insertions are whatever query sequence lies outside/between the best v, d and j matches
        v_qr, d_qr, j_qr = [all_query_bounds[best[region]] for region in ('v', 'd', 'j')]
        swfo['fv_insertion'] = query_seq[:v_qr[0]]
        swfo['vd_insertion'] = query_seq[v_qr[1]:d_qr[0]]
        swfo['dj_insertion'] = query_seq[d_qr[1]:j_qr[0]]
        swfo['jf_insertion'] = query_seq[j_qr[1]:]

        swfo['indelfo'] = self.info['indels'].get(query_name, utils.get_empty_indel())

        for region in utils.regions:
            gene = best[region]
            swfo[region + '_gene'] = gene
            self.info['all_best_matches'].add(gene)
            self.info['all_matches'][region] |= set(match_names[region])

        swfo['seq'] = query_seq  # NOTE this is the seq output by vdjalign, i.e. if we reversed any indels it is the reversed sequence

        utils.add_implicit_info(self.glfo, swfo, multi_seq=False, existing_implicit_keys=('cdr3_length', 'cyst_position', 'tryp_position'))

        if self.debug:
            if not self.args.is_data:
                utils.print_reco_event(self.glfo['seqs'], self.reco_info[query_name], extra_str='      ', label='true:')
            utils.print_reco_event(self.glfo['seqs'], swfo, extra_str='      ', label='inferred:')

        if self.alfinder is not None:
            self.alfinder.increment(swfo)
        if self.pcounter is not None:
            self.pcounter.increment_all_params(swfo)
            if self.true_pcounter is not None:
                self.true_pcounter.increment_all_params(self.reco_info[query_name])
        if self.perfplotter is not None:
            if query_name not in self.info['indels']:
                self.perfplotter.evaluate(self.reco_info[query_name], swfo)
            else:  # no idea how to handle naive hamming fraction when there's indels
                print('    skipping performance evaluation of %s because of indels' % query_name)

        self.remaining_queries.remove(query_name)

    # ----------------------------------------------------------------------------------------
    def summarize_query(self, query_name, query_seq, all_match_names, all_query_bounds, all_germline_bounds, warnings, first_match_query_bounds, queries_to_rerun):
        """
        Choose the best s-w match in each region for <query_name>, sanity check the resulting annotation, and either record it or flag the query.

        <all_match_names>: region -> list of (score, gene); each list is replaced, in the caller's dict, by a copy sorted by descending score.
        <all_query_bounds> / <all_germline_bounds>: gene -> (start, end) of the match within the query / within the germline gene.
        <first_match_query_bounds>: query bounds from whose end k_d is measured (presumably the best v match -- confirm against caller).

        On success the annotation is passed to add_to_info(). Otherwise <query_name> is added to one of the <queries_to_rerun>
        categories ('no-match', 'nonsense-bounds', 'weird-annot.', 'invalid-codon', 'unproductive'), or, for unproductive queries
        with args.skip_unproductive set, moved to self.unproductive_queries; in both cases the method returns without recording anything.

        Raises AssertionError if a germline match and its query match have different lengths.
        """
        best, match_names = {}, {}
        k_v_min, k_d_min = 999, 999
        k_v_max, k_d_max = 0, 0
        for region in utils.regions:
            all_match_names[region] = sorted(all_match_names[region], reverse=True)  # highest score first
            match_names[region] = []
        if self.debug >= 2:
            print(query_name)
        for region in utils.regions:
            for score, gene in all_match_names[region]:
                glbounds = all_germline_bounds[gene]
                qrbounds = all_query_bounds[gene]
                assert qrbounds[1] <= len(query_seq)  # NOTE I'm putting these up above as well (in process_query), so in time I should remove them from here
                assert glbounds[1] <= len(self.glfo['seqs'][region][gene])
                assert qrbounds[0] >= 0
                assert glbounds[0] >= 0
                glmatchseq = self.glfo['seqs'][region][gene][glbounds[0]:glbounds[1]]
                qrmatchseq = query_seq[qrbounds[0]:qrbounds[1]]

                match_names[region].append(gene)

                if self.debug >= 2:
                    self.print_match(region, gene, query_seq, score, glbounds, qrbounds, -1, warnings, skipping=False)

                # if the germline match and the query match aren't the same length, s-w likely added an insert, which we shouldn't get since the gap-open penalty is jacked up so high
                if len(glmatchseq) != len(qrmatchseq):  # neurotic double check (um, I think) EDIT hey this totally saved my ass
                    print('ERROR %s not same length' % query_name)  # %s: query_name is a string ('%d' raised TypeError before anything useful was printed)
                    print('%s %s %s' % (glmatchseq, glbounds[0], glbounds[1]))
                    print(qrmatchseq)
                    raise AssertionError('germline and query matches have different lengths for %s (%s)' % (query_name, gene))  # explicit raise so it survives python -O

                # NOTE since I'm no longer skipping the genes after the first <args.n_max_per_region>, the OR of k-space below is overly conservative. UPDATE not sure if this is still relevant
                if region == 'v':
                    this_k_v = qrbounds[1]  # NOTE even if the v match doesn't start at the left hand edge of the query sequence, we still measure k_v from there (i.e. sw doesn't tell the hmm about it)
                    k_v_min = min(this_k_v, k_v_min)
                    k_v_max = max(this_k_v, k_v_max)
                elif region == 'd':
                    this_k_d = qrbounds[1] - first_match_query_bounds[1]  # end of d minus end of v
                    k_d_min = min(this_k_d, k_d_min)
                    k_d_max = max(this_k_d, k_d_max)

                # the first (highest-scoring) match in each region is the best one (the best match is excised in s-w code, and ham is run with *one* k_v k_d set)
                if region not in best:
                    best[region] = gene
                    best[region + '_gl_seq'] = glmatchseq
                    best[region + '_qr_seq'] = qrmatchseq
                    best[region + '_score'] = score

        for region in utils.regions:
            if region not in best:
                if self.debug:
                    print('      no %s match found for %s' % (region, query_name))  # NOTE if no d match found, we should really just assume entire d was eroded
                queries_to_rerun['no-match'].add(query_name)
                return

        # s-w allows adjacent matches (v and d, d and j) to overlap, so we need to apportion the disputed bases
        region_pairs = ({'left':'v', 'right':'d'}, {'left':'d', 'right':'j'})
        for rpair in region_pairs:
            overlap_status = self.check_boundaries(rpair, all_query_bounds, all_germline_bounds, query_name, query_seq, best)
            if overlap_status == 'overlap':
                self.shift_overlapping_boundaries(rpair, all_query_bounds, all_germline_bounds, query_name, query_seq, best)
            elif overlap_status == 'nonsense':
                queries_to_rerun['nonsense-bounds'].add(query_name)
                return
            else:
                assert overlap_status == 'ok'

        # check for suspiciously bad annotations: the absolute limit always applies, the stricter one only on the first tries
        vd_insertion = query_seq[all_query_bounds[best['v']][1] : all_query_bounds[best['d']][0]]
        dj_insertion = query_seq[all_query_bounds[best['d']][1] : all_query_bounds[best['j']][0]]
        insertion_limit = self.absolute_max_insertion_length
        if self.nth_try < 2:
            insertion_limit = min(insertion_limit, self.max_insertion_length)
        if max(len(vd_insertion), len(dj_insertion)) > insertion_limit:
            if self.debug:
                print('      suspiciously long insertion in %s, rerunning' % query_name)
            queries_to_rerun['weird-annot.'].add(query_name)
            return

        if self.debug:
            print(query_name)

        # set and check conserved codon positions
        tmp_gl_positions = {'v' : self.glfo['cyst-positions'], 'j' : self.glfo['tryp-positions']}  # hack hack hack
        codon_positions = {}
        for region in ['v', 'j']:
            # position within original germline gene, minus the position in that germline gene at which the match starts, plus the position in the query sequence at which the match starts
            pos = tmp_gl_positions[region][best[region]] - all_germline_bounds[best[region]][0] + all_query_bounds[best[region]][0]
            if pos < 0 or pos >= len(query_seq):
                if self.debug:
                    print('      invalid %s codon position (%d in seq of length %d), rerunning' % (region, pos, len(query_seq)))
                queries_to_rerun['invalid-codon'].add(query_name)
                return
            codon_positions[region] = pos

        # check for unproductive rearrangements
        codons_ok = utils.check_both_conserved_codons(query_seq, codon_positions['v'], codon_positions['j'], assert_on_fail=False)
        cdr3_length = codon_positions['j'] - codon_positions['v'] + 3

        if cdr3_length < 6:  # NOTE six is also hardcoded in utils
            if self.debug:
                print('      cdr3 length %d too short (less than 6)' % cdr3_length)
            queries_to_rerun['invalid-codon'].add(query_name)
            return

        in_frame_cdr3 = (cdr3_length % 3 == 0)
        no_stop_codon = utils.stop_codon_check(query_seq, codon_positions['v'])
        if not codons_ok or not in_frame_cdr3 or not no_stop_codon:
            if self.debug:
                reasons = []
                if not codons_ok:
                    reasons.append('bad codons')
                if not in_frame_cdr3:
                    reasons.append('out of frame cdr3')
                if not no_stop_codon:
                    reasons.append('stop codon')
                print('       unproductive rearrangement:   %s' % '   '.join(reasons))

            # rerun with higher mismatch score (sometimes unproductiveness is the result of a really screwed up annotation rather than an actual unproductive sequence). Stop codons aren't really indicative of screwed up annotations, so they don't count.
            if self.nth_try < 2 and (not codons_ok or not in_frame_cdr3):
                if self.debug:
                    print('            ...rerunning')
                queries_to_rerun['unproductive'].add(query_name)
                return
            elif self.args.skip_unproductive:
                if self.debug:
                    print('            ...skipping')
                self.unproductive_queries.add(query_name)
                self.remaining_queries.remove(query_name)
                return
            # otherwise fall through and add the (unproductive) query to self.info

        # best k_v, k_d:
        k_v = all_query_bounds[best['v']][1]  # end of v match
        k_d = all_query_bounds[best['d']][1] - all_query_bounds[best['v']][1]  # end of d minus end of v

        if k_d_max < 5:  # since the s-w step matches to the longest possible j and then excises it, this sometimes gobbles up the d, resulting in a very short d alignment.
            if self.debug:
                print('  expanding k_d')
            k_d_max = max(8, k_d_max)

        if 'IGHJ4*' in best['j'] and self.glfo['seqs']['d'][best['d']][-5:] == 'ACTAC':  # the end of some d versions is the same as the start of some j versions, so the s-w frequently kicks out the 'wrong' alignment
            if self.debug:
                print('  doubly expanding k_d')
            if k_d_max - k_d_min < 8:
                k_d_min -= 5
                k_d_max += 2

        k_v_min = max(1, k_v_min - self.args.default_v_fuzz)  # ok, so I don't *actually* want it to be zero... oh, well
        k_v_max += self.args.default_v_fuzz
        k_d_min = max(1, k_d_min - self.args.default_d_fuzz)
        k_d_max += self.args.default_d_fuzz
        assert k_v_min > 0 and k_d_min > 0 and k_v_max > 0 and k_d_max > 0

        if self.debug:
            print('         k_v: %d [%d-%d)' % (k_v, k_v_min, k_v_max))
            print('         k_d: %d [%d-%d)' % (k_d, k_d_min, k_d_max))

        kvals = {}
        kvals['v'] = {'best':k_v, 'min':k_v_min, 'max':k_v_max}
        kvals['d'] = {'best':k_d, 'min':k_d_min, 'max':k_d_max}
        self.add_to_info(query_name, query_seq, kvals, match_names, best, all_germline_bounds, all_query_bounds, codon_positions=codon_positions)

    # ----------------------------------------------------------------------------------------
    def get_padding_parameters(self, debug=False):
        """
        Compute the padding extents needed to line every query up on its conserved cysteine.

        Returns a dict with
          'gl_cpos': largest germline cyst position over all v matches, plus the largest amount by which any query's
                     fv insertion exceeds its v_5p_del (so padding always reaches the start of the germline v)
          'gl_cpos_to_j_end': largest distance, over all queries, from the cyst position to the (padded) end of the j
        A value stays None if there are no queries (or no v / j matches).
        """
        maxima = {'gl_cpos' : None, 'gl_cpos_to_j_end' : None}
        queries = self.info['queries']
        v_matches = self.info['all_matches']['v']  # NOTE all gl matches, even ones for other sequences, because we want bcrham to be able to compare any sequence to any other
        j_matches = self.info['all_matches']['j']

        # the v-match term doesn't depend on the query, so take its maximum once rather than re-scanning every v match for every query
        max_v_cpos = None
        if queries and v_matches:
            max_v_cpos = max(self.glfo['cyst-positions'][v_match] for v_match in v_matches)

        for query in queries:
            swfo = self.info[query]
            fvstuff = max(0, len(swfo['fv_insertion']) - swfo['v_5p_del'])  # we always want to pad out to the entire germline sequence, so don't let this go negative
            jfstuff = max(0, len(swfo['jf_insertion']) - swfo['j_3p_del'])

            if max_v_cpos is not None:
                gl_cpos = max_v_cpos + fvstuff
                if maxima['gl_cpos'] is None or gl_cpos > maxima['gl_cpos']:
                    maxima['gl_cpos'] = gl_cpos

            if j_matches:
                # TODO this is totally wrong -- only j_3p_del for the best j match is stored, so every j match gives this same value... but hopefully it's enough padding for the moment
                gl_cpos_to_j_end = len(swfo['seq']) - swfo['cyst_position'] + swfo['j_3p_del'] + jfstuff  # cyst_position is in the query sequence, not the germline
                if maxima['gl_cpos_to_j_end'] is None or gl_cpos_to_j_end > maxima['gl_cpos_to_j_end']:
                    maxima['gl_cpos_to_j_end'] = gl_cpos_to_j_end

        if debug:
            # %s rather than %d, so an unset (None) maximum doesn't raise TypeError
            print('    maxima: ' + '    '.join('%s %s' % (key, val) for key, val in maxima.items()))
        return maxima

    # ----------------------------------------------------------------------------------------
    def pad_seqs_to_same_length(self, debug=False):
        """
        Pad every query with ambiguous bases so that all queries line up on their conserved cysteine.

        Left/right amounts come from get_padding_parameters(): each query is padded on the left out to the largest germline
        cyst position and on the right out to the largest cyst-to-j-end distance, which also removes v_5p and j_3p deletions.
        The result is stored under self.info[query]['padded'] ('seq', shifted 'k_v' min/max and 'cyst_position', 'padleft',
        'padright'); a reversed-indel sequence in self.info['indels'] is padded in place by the same amounts.
        Stops (returns) as soon as it meets a query that already has 'padded' info.
        Raises Exception if a query would need negative padding.
        """
        maxima = self.get_padding_parameters(debug=debug)
        queries = self.info['queries']

        for query in queries:
            swfo = self.info[query]
            if 'padded' in swfo:  # already added padded information (we're probably partitioning, and this is not the first step)
                return
            seq, cpos = swfo['seq'], swfo['cyst_position']
            seqlen = len(seq)
            if not 0 <= cpos < seqlen:
                print('hm now what do I want to do here?')
            k_v = swfo['k_v']

            padleft = maxima['gl_cpos'] - cpos  # biggest germline cpos minus cpos in this sequence
            padright = maxima['gl_cpos_to_j_end'] - (seqlen - cpos)
            if min(padleft, padright) < 0:
                raise Exception('bad padding %d %d for %s' % (padleft, padright, query))

            assert len(utils.ambiguous_bases) == 1  # could allow more than one, but it's not implemented a.t.m.
            left_pad = utils.ambiguous_bases[0] * padleft
            right_pad = utils.ambiguous_bases[0] * padright

            if query in self.info['indels']:
                if debug:
                    print('    also padding reversed sequence')
                indelfo = self.info['indels'][query]
                indelfo['reversed_seq'] = left_pad + indelfo['reversed_seq'] + right_pad

            padfo = {
                'seq' : left_pad + seq + right_pad,
                'k_v' : {'min' : k_v['min'] + padleft, 'max' : k_v['max'] + padleft},
                'cyst_position' : cpos + padleft,
                'padleft' : padleft,
                'padright' : padright,
            }
            if debug:
                print('      pad %d %d   %s' % (padleft, padright, query))
                print('     %d --> %d (%d-%d --> %d-%d)' % (seqlen, len(padfo['seq']), k_v['min'], k_v['max'], padfo['k_v']['min'], padfo['k_v']['max']))
            swfo['padded'] = padfo

        if debug:
            for query in queries:
                print('%20s %s' % (query, self.info[query]['padded']['seq']))
Exemplo n.º 14
0
class JoinParser(object):
    """
    Parse joinsolver xml output and score its annotations against the true ones with a PerformancePlotter.

    <seqfname>: csv file that was the input to joinsolver (needed because joinsolver's output doesn't contain the full query seq)
    <joinfnames>: list of joinsolver xml output files
    <datadir>: directory from which the germline sequences are read
    Requires the environment variable $www to point to an existing directory (plots go under $www/partis/joinsolver_performance).
    """
    def __init__(
        self, seqfname, joinfnames, datadir
    ):  # <seqfname>: input to joinsolver, <joinfname> output from joinsolver (I only need both because they don't seem to put the full query seq in the output)
        self.debug = 0
        self.n_max_queries = -1  # if > 0, only look at this many (selected) queries
        self.queries = []  # if non-empty, only look at these queries

        self.germline_seqs = utils.read_germlines(datadir, remove_N_nukes=False)
        www = os.getenv('www')
        assert www is not None and os.path.exists(www), 'environment variable $www must point to an existing directory (got %s)' % www
        self.perfplotter = PerformancePlotter(self.germline_seqs, www + '/partis/joinsolver_performance', 'js')

        # get info that was passed to joinsolver
        self.seqinfo = {}
        with opener('r')(seqfname) as seqfile:
            reader = csv.DictReader(seqfile)
            n_kept = 0
            for line in reader:
                if self.queries and line['unique_id'] not in self.queries:
                    continue
                self.seqinfo[line['unique_id']] = line
                n_kept += 1
                if self.n_max_queries > 0 and n_kept >= self.n_max_queries:
                    break

        self.n_failed, self.n_total = 0, 0
        for joinfname in joinfnames:
            self.parse_file(joinfname)

        self.perfplotter.plot()
        fail_fraction = float(self.n_failed) / self.n_total if self.n_total > 0 else 0.  # don't divide by zero if no queries were seen
        print('failed: %d / %d = %f' % (self.n_failed, self.n_total, fail_fraction))

    # ----------------------------------------------------------------------------------------
    def parse_file(self, infname):
        """
        Parse one joinsolver xml file, evaluating each selected query's annotation with self.perfplotter.

        Queries excluded by self.queries are skipped without being counted (as in __init__). Once self.n_total reaches
        self.n_max_queries (if that is > 0) no further queries are looked at, in this or any later file.
        Queries that can't be annotated are counted in self.n_failed.
        """
        tree = ET.parse(infname)
        root = tree.getroot()

        for query in root:
            unique_id = query.attrib['id'].replace('>', '').replace(' ', '')
            if self.queries and unique_id not in self.queries:  # filter *before* counting, so n_total only counts queries we actually look at
                continue
            if self.n_max_queries > 0 and self.n_total >= self.n_max_queries:  # check *before* incrementing, so the limit isn't overshot by one
                break
            self.n_total += 1
            if self.debug:
                print('%s %s' % (self.n_total, unique_id))
            if unique_id not in self.seqinfo:  # e.g. the query was beyond n_max_queries in the joinsolver input file
                print('  ERROR %s not found in joinsolver input' % unique_id)
                self.n_failed += 1
                continue

            line = {}
            line['unique_id'] = unique_id
            line['seq'] = self.seqinfo[unique_id]['seq']
            for region in utils.regions:
                if self.debug:
                    print('  %s' % region)
                self.get_region_matches(region, query, line)
            if 'v_gene' not in line or 'd_gene' not in line or 'j_gene' not in line:
                print('  ERROR giving up on %s' % unique_id)
                self.n_failed += 1
                continue

            add_insertions(line)
            try:
                resolve_overlapping_matches(line, self.debug)
            except Exception:  # not a bare except, so KeyboardInterrupt/SystemExit still get through
                print('ERROR apportionment failed on %s' % unique_id)
                self.n_failed += 1
                continue

            self.perfplotter.evaluate(self.seqinfo[unique_id], line)

            if self.debug:
                utils.print_reco_event(self.germline_seqs, line)

    # ----------------------------------------------------------------------------------------
    def parse_match_seqs(self, match, region_query_seq):
        """
        Align the germline text of <match> with <region_query_seq>, returning (query seq, germline seq) with spaces and gaps removed.

        In joinsolver's germline text '.' means "same base as the query" and '-' is a gap. Returns (None, None) if <match> has no text.
        """
        gl_match_seq = match.text
        if gl_match_seq is None:
            return (None, None)
        if self.debug > 1:
            print('     query %s' % region_query_seq)
            print('        gl %s' % gl_match_seq)

        # if gl match extends outside of query seq, strip off that part
        region_query_seq, gl_match_seq = cut_matches(region_query_seq, gl_match_seq)
        # and if region_query_seq extends to left of gl match, strip off that stuff as well
        gl_match_seq, region_query_seq = cut_matches(gl_match_seq, region_query_seq)
        # and remove the rest of the spaces
        region_query_seq = region_query_seq.replace(' ', '')
        gl_match_seq = gl_match_seq.replace(' ', '')

        # then replace dots in gl_match_seq with the query base, and just remove dashes
        assert len(gl_match_seq) == len(region_query_seq)
        new_glseq = []
        for gl_nuke, qr_nuke in zip(gl_match_seq, region_query_seq):
            if gl_nuke == '.':
                new_glseq.append(qr_nuke)
            elif gl_nuke != '-':
                assert gl_nuke in utils.nukes
                new_glseq.append(gl_nuke)
        gl_match_seq = ''.join(new_glseq)

        if '-' in region_query_seq:
            print('    WARNING removing gaps in query seq')
            region_query_seq = region_query_seq.replace('-', '')

        if self.debug > 1:
            print('     query %s' % region_query_seq)
            print('        gl %s' % gl_match_seq)

        return (region_query_seq, gl_match_seq)

    # ----------------------------------------------------------------------------------------
    def get_region_matches(self, region, query, line):
        """
        Get info for the best <region> match of <query> and add it to <line>.

        On success sets <region>_gene, _5p_del, _3p_del, _gl_seq and _qr_seq in <line>; on any failure prints an error and leaves <line> unchanged.
        May add (or overwrite) entries in self.germline_seqs when joinsolver's germline doesn't match ours.
        """
        region_matches = query.find(region + 'matches')
        if region_matches is None or region_matches.find('nomatches') is not None:  # no matches!
            print('  ERROR no %s matches' % region)
            return

        region_query_seq = region_matches.find('userSeq').find('bases').text
        matches = [match for match in region_matches if match.tag == 'germline']
        if len(matches) == 0:
            print('ERROR no matches for %s' % line['unique_id'])
            return

        match = matches[0]  # just take the first (best) match
        region_query_seq, gl_match_seq = self.parse_match_seqs(match, region_query_seq)
        if region_query_seq is None or gl_match_seq is None:
            return

        raw_match_name = match.attrib['id'].replace(' ', '')
        try:
            match_name = figure_out_which_damn_gene(self.germline_seqs, raw_match_name, gl_match_seq, debug=self.debug)
        except AssertionError:  # couldn't find a decent one...
            # well, ok, I guess I'll just *add* the damn thing to <germline_seqs>. NOTE that is so, so, dirty
            self.germline_seqs[region][raw_match_name] = gl_match_seq
            print('   WARNING adding %s to <germline_seqs>' % raw_match_name)
            match_name = figure_out_which_damn_gene(self.germline_seqs, raw_match_name, gl_match_seq, debug=self.debug)

        if self.debug > 1:
            print('      %s' % match_name)
        if region_query_seq not in line['seq']:
            print(region_query_seq)
            print(line['seq'])
        assert region_query_seq in line['seq']  # they're not coming from the same file, so may as well make sure

        del_5p = self.germline_seqs[region][match_name].find(gl_match_seq)
        del_3p = len(self.germline_seqs[region][match_name]) - del_5p - len(gl_match_seq)
        if region == 'j' and del_5p < 0:
            print('    WARNING adding to right side of germline j')
            self.germline_seqs[region][match_name] = gl_match_seq
            del_5p = self.germline_seqs[region][match_name].find(gl_match_seq)
            del_3p = len(self.germline_seqs[region][match_name]) - del_5p - len(gl_match_seq)

        if del_5p < 0 or del_3p < 0:
            print('ERROR couldn\'t figure out deletions in %s' % region)
            print('    germline %s %s' % (self.germline_seqs[region][match_name], match_name))
            print('    germline match %s' % gl_match_seq)
            print('    %s %s' % (del_5p, del_3p))
            return

        line[region + '_5p_del'] = del_5p
        line[region + '_3p_del'] = del_3p
        line[region + '_gene'] = match_name
        line[region + '_gl_seq'] = gl_match_seq
        line[region + '_qr_seq'] = region_query_seq
Exemplo n.º 15
0
			trueDictionary[unique_id] = {}
			trueDictionary[unique_id]['v_gene'] = row1['v_gene']	
			trueDictionary[unique_id]['d_gene'] = row1['d_gene']	
			trueDictionary[unique_id]['j_gene'] = row1['v_gene']	
			#print trueDictionary[unique_id]
			iDictionary[unique_id] = {}
			iDictionary[unique_id]['v_gene'] = row2['Best V hit']
			iDictionary[unique_id]['d_gene'] = row2['Best D hit']
			iDictionary[unique_id]['j_gene'] = row2['Best J hit']
			#print iDictionary[unique_id]

#run the evaluate function from performanceplotter.py on every true annotation, comparing it to the inferred one with the same id
for key in trueDictionary:
	perfplotter.evaluate(trueDictionary[key], iDictionary[key])
print('COMPLETED EVALUATE')
#plot the information gathered by the 'evaluate' calls
perfplotter.plot(mixcrPlotDir)
print(mixcrPlotDir)
print('COMPLETED PLOTTING')
#----------------------------
#Code from previous development
'''	
with open("simu-10-leaves-1-mutate.csv") as inFile1:
	with open('edited_output_file.txt') as inFile2:
		reader1 = csv.DictReader(inFile1)
		reader2 = csv.DictReader(inFile2, delimiter='\t')
		for i1, i2 in zip(reader1, reader2):
			#gets the unique id number from the dictionary in the first id
Exemplo n.º 16
0
class JoinParser(object):
    def __init__(self, seqfname, joinfnames, datadir):  # <seqfname>: input to joinsolver, <joinfname> output from joinsolver (I only need both because they don't seem to put the full query seq in the output)
        self.debug = 0
        self.n_max_queries = -1
        self.queries = []

        self.germline_seqs = utils.read_germline_set(datadir, remove_N_nukes=False)['seqs']
        assert os.path.exists(os.getenv('www'))
        self.perfplotter = PerformancePlotter(self.germline_seqs, os.getenv('www') + '/partis/joinsolver_performance', 'js')

        # get info that was passed to joinsolver
        self.seqinfo = {}
        with opener('r')(seqfname) as seqfile:
            reader = csv.DictReader(seqfile)
            iline = 0
            for line in reader:
                if len(self.queries) > 0 and line['unique_id'] not in self.queries:
                    continue
                self.seqinfo[line['unique_id']] = line
                iline += 1
                if self.n_max_queries > 0 and iline >= self.n_max_queries:
                    break

        self.n_failed, self.n_total = 0, 0
        for joinfname in joinfnames:
            self.parse_file(joinfname)

        self.perfplotter.plot()
        print 'failed: %d / %d = %f' % (self.n_failed, self.n_total, float(self.n_failed) / self.n_total)

    # ----------------------------------------------------------------------------------------
    def parse_file(self, infname):
        tree = ET.parse(infname)
        root = tree.getroot()

        for query in root:
            self.n_total += 1
            if self.n_max_queries > 0 and self.n_total > self.n_max_queries:
                break

            unique_id = query.attrib['id'].replace('>', '').replace(' ', '')
            if len(self.queries) > 0 and  unique_id not in self.queries:
                continue
            if self.debug:
                print self.n_total, unique_id
            line = {}
            line['unique_id'] = unique_id
            line['seq'] = self.seqinfo[unique_id]['seq']
            for region in utils.regions:
                if self.debug:
                    print ' ', region
                self.get_region_matches(region, query, line)
            if 'v_gene' not in line or 'd_gene' not in line or 'j_gene' not in line:
                print '  ERROR giving up on %s' % unique_id
                self.n_failed += 1
                continue

            add_insertions(line)
            try:
                resolve_overlapping_matches(line, self.debug)
            except:
                print 'ERROR apportionment failed on %s' % unique_id
                self.n_failed += 1
                continue

            self.perfplotter.evaluate(self.seqinfo[unique_id], line)

            if self.debug:
                utils.print_reco_event(self.germline_seqs, line)

    # ----------------------------------------------------------------------------------------
    def parse_match_seqs(self, match, region_query_seq):
        gl_match_seq = match.text
        if gl_match_seq == None:
            return (None, None)
        if self.debug > 1:
            print '     query', region_query_seq
            print '        gl', gl_match_seq

        # if gl match extends outside of query seq, strip off that part
        region_query_seq, gl_match_seq = cut_matches(region_query_seq, gl_match_seq)
        # and if region_query_seq extends to left of gl match, strip off that stuff as well
        gl_match_seq, region_query_seq = cut_matches(gl_match_seq, region_query_seq)
        # and remove the rest of the spaces
        region_query_seq = region_query_seq.replace(' ', '')
        gl_match_seq = gl_match_seq.replace(' ', '')
    
        # then replace dots in gl_match_seq, and just remove dashes
        assert len(gl_match_seq) == len(region_query_seq)
        new_glseq = []
        for inuke in range(len(region_query_seq)):
            if gl_match_seq[inuke] == '.':
                new_glseq.append(region_query_seq[inuke])
            elif gl_match_seq[inuke] == '-':
                pass
            else:
                assert gl_match_seq[inuke] in utils.nukes
                new_glseq.append(gl_match_seq[inuke])
        gl_match_seq = ''.join(new_glseq)
        
        if '-' in region_query_seq:
            print '    WARNING removing gaps in query seq'
            region_query_seq = region_query_seq.replace('-', '')

        if self.debug > 1:
            print '     query', region_query_seq
            print '        gl', gl_match_seq

        return (region_query_seq, gl_match_seq)

    # ----------------------------------------------------------------------------------------
    def get_region_matches(self, region, query, line):
        """ get info for <region> and add it to <line> """
        if query.find(region + 'matches').find('nomatches') != None:  # no matches!
            print '  ERROR no %s matches' % region
            return

        region_query_seq = query.find(region + 'matches').find('userSeq').find('bases').text
        matches = [ match for match in query.find(region + 'matches') if match.tag == 'germline' ]
        if len(matches) == 0:
            print 'ERROR no matches for',line['unique_id']
            return

        imatch = 0
        match = matches[imatch]
        # just take the first (best) match
        region_query_seq, gl_match_seq = self.parse_match_seqs(match, region_query_seq)
        if region_query_seq == None or gl_match_seq == None:
            return

        match_name = match.attrib['id'].replace(' ', '')
        try:
            match_name = figure_out_which_damn_gene(self.germline_seqs, match_name, gl_match_seq, debug=self.debug)
        except AssertionError:  # couldn't find a decent one, so try again with the second match
            # well, ok, I guess I'll just *add* the damn thing to <germline_seqs>. NOTE that is so, so, dirty
            self.germline_seqs[region][match.attrib['id'].replace(' ', '')] = gl_match_seq
            print '   WARNING adding %s to <germline_seqs>' % match.attrib['id'].replace(' ', '')
            match_name = figure_out_which_damn_gene(self.germline_seqs, match.attrib['id'].replace(' ', ''), gl_match_seq, debug=self.debug)

        if self.debug > 1:
            print '     ', match_name
        if region_query_seq not in line['seq']:
            print region_query_seq
            print line['seq']
        assert region_query_seq in line['seq']  # they're not coming from the same file, so may as well make sure
            
        del_5p = self.germline_seqs[region][match_name].find(gl_match_seq)
        del_3p = len(self.germline_seqs[region][match_name]) - del_5p - len(gl_match_seq)
        if region == 'j' and del_5p < 0:
            print '    WARNING adding to right side of germline j'
            # assert len(gl_match_seq) > len(self.germline_seqs[region][match_name])
            self.germline_seqs[region][match_name] = gl_match_seq
            del_5p = self.germline_seqs[region][match_name].find(gl_match_seq)
            del_3p = len(self.germline_seqs[region][match_name]) - del_5p - len(gl_match_seq)
            
        if del_5p < 0 or del_3p < 0:
            print 'ERROR couldn\'t figure out deletions in', region
            print '    germline', self.germline_seqs[region][match_name], match_name
            print '    germline match', gl_match_seq
            print '   ', del_5p, del_3p
            return

        line[region + '_5p_del'] = del_5p
        line[region + '_3p_del'] = del_3p
        line[region + '_gene'] = match_name
        line[region + '_gl_seq'] = gl_match_seq
        line[region + '_qr_seq'] = region_query_seq