Example #1
File: SplitSeq.py, Project: avilella/presto
    def _read_file(value_file, field):
        field_list = []
        try:
            with open(value_file, 'rt') as handle:
                reader_dict = csv.DictReader(handle, dialect='excel-tab')
                for row in reader_dict:
                    field_list.append(row[field])
        except IOError:
            printError('File %s cannot be read.' % value_file)
        except:
            printError('File %s is invalid.' % value_file)

        return field_list
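
A minimal standalone sketch of the same csv.DictReader pattern, using a hypothetical tab-delimited file 'values.tsv' with a 'BARCODE' column; it mirrors _read_file without the presto error handling.

import csv

field_list = []
with open('values.tsv', 'rt') as handle:
    reader_dict = csv.DictReader(handle, dialect='excel-tab')
    for row in reader_dict:
        field_list.append(row['BARCODE'])
print(field_list)
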
Example #2
def processSeqQueue(alive,
                    data_queue,
                    result_queue,
                    process_func,
                    process_args={}):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
      alive : multiprocessing.Value boolean controlling whether processing
              continues; when False function returns
      data_queue : multiprocessing.Queue holding data to process
      result_queue : multiprocessing.Queue to hold processed results
      process_func : function to use for processing sequences
      process_args : Dictionary of arguments to pass to process_func

    Returns:
      None
    """
    try:
        # Iterate over the data queue until the sentinel object is reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break

            # Perform work
            result = process_func(data, **process_args)

            #import cProfile
            #prof = cProfile.Profile()
            #result = prof.runcall(process_func, data, **process_args)
            #prof.dump_stats('worker-%d.prof' % os.getpid())

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s> Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        printError('Error processing sequence with ID: %s.' % data.id,
                   exit=False)
        raise

    return None
Example #3
def offsetSeqSet(seq_list, offset_dict, field=default_primer_field, 
                 mode='pad', delimiter=default_delimiter):
    """
    Pads the head of a set of sequences with gaps according to an offset list

    Arguments: 
      seq_list : a list of SeqRecord objects to offset
      offset_dict : a dictionary of {set ID: offset values}
      field : the field in sequence description containing set IDs
      mode : defines the action taken; one of 'pad','cut'
      delimiter : a tuple of delimiters for (annotations, field/values, value lists)
        
    Returns: 
      Bio.Align.MultipleSeqAlignment: object containing the alignment.
    """
    ann_list = [parseAnnotation(s.description, delimiter=delimiter) for s in seq_list]
    tag_list = [a[field] for a in ann_list]

    # Pad sequences with offsets
    align_list = []
    if mode == 'pad':
        max_len = max([len(s) + offset_dict[t] 
                  for s, t in zip(seq_list, tag_list)])
        for rec, tag in zip(seq_list, tag_list):
            new_rec = rec[:]
            new_rec.letter_annotations = {}
            new_rec.seq = '-' * offset_dict[tag] + new_rec.seq
            new_rec.seq += '-' * (max_len - len(new_rec.seq))
            align_list.append(new_rec)
    # Cut sequences to common start position
    elif mode == 'cut':
        max_offset = max(offset_dict.values())
        cut_dict = {k:(max_offset - v) for k, v in offset_dict.items()}
        max_len = max([len(s) - cut_dict[t] 
                  for s, t in zip(seq_list, tag_list)])
        for rec, tag in zip(seq_list, tag_list):
            new_rec = rec[:]
            new_rec.letter_annotations = {}
            new_rec.seq = new_rec.seq[cut_dict[tag]:]
            new_rec.seq += '-' * (max_len - len(new_rec.seq))
            align_list.append(new_rec)
    else:
        printError('Invalid offset mode.')

    # Convert list to MultipleSeqAlignment object
    align = MultipleSeqAlignment(align_list)
    
    return align
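
A plain-string sketch of the 'pad' and 'cut' modes above, using a hypothetical two-sequence set and offset dictionary (no Biopython objects); the padding and trimming arithmetic follows offsetSeqSet.

seqs = {'A': 'ACGTACGT', 'B': 'GTACGT'}
offsets = {'A': 0, 'B': 2}

# mode='pad': prepend gaps per offset, then pad tails to a common length
pad_len = max(len(s) + offsets[k] for k, s in seqs.items())
padded = {k: ('-' * offsets[k] + s).ljust(pad_len, '-') for k, s in seqs.items()}
print(padded)  # {'A': 'ACGTACGT', 'B': '--GTACGT'}

# mode='cut': trim each sequence to the common start position instead
max_offset = max(offsets.values())
cut_dict = {k: max_offset - v for k, v in offsets.items()}
cut_len = max(len(s) - cut_dict[k] for k, s in seqs.items())
cut = {k: s[cut_dict[k]:].ljust(cut_len, '-') for k, s in seqs.items()}
print(cut)  # {'A': 'GTACGT', 'B': 'GTACGT'}
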
Example #4
def makeBlastnDb(ref_file, db_exec=default_blastdb_exec):
    """
    Makes a blastn database file

    Arguments:
      ref_file : the path to the reference database file
      db_exec : the path to the makeblastdb executable

    Returns:
      tuple : (name and location of the database, handle of the tempfile.TemporaryDirectory)
    """
    # Open temporary files
    seq_handle = tempfile.NamedTemporaryFile(suffix='.fasta', mode='w+t', encoding='utf-8')
    db_handle = tempfile.TemporaryDirectory()

    # Write temporary ungapped reference file
    ref_dict = readReferenceFile(ref_file)

    #writer = SeqIO.FastaIO.FastaWriter(seq_handle, wrap=None)
    #writer.write_file(ref_dict.values())
    SeqIO.write(ref_dict.values(), seq_handle, format="fasta")

    seq_handle.seek(0)

    # Define makeblastdb command
    cmd = [db_exec,
           '-in', seq_handle.name,
           '-out', os.path.join(db_handle.name, 'reference'),
           '-dbtype', 'nucl',
           '-title', 'reference',
           '-parse_seqids']
    try:
        stdout_str = check_output(cmd, stderr=STDOUT, shell=False,
                                  universal_newlines=True)
    except:
        seq_handle.close()
        printError('Failed to make blastn database.')

    # Close temporary sequence file
    seq_handle.close()

    return (os.path.join(db_handle.name, 'reference'), db_handle)
Example #5
def makeUBlastDb(ref_file, db_exec=default_usearch_exec):
    """
    Makes a ublast database file

    Arguments:
      ref_file : path to the reference database file.
      db_exec : path to the usearch executable.

    Returns:
      tuple : (location of the database, handle of the tempfile.NamedTemporaryFile)
    """
    # Open temporary files
    seq_handle = tempfile.NamedTemporaryFile(suffix='.fasta', mode='w+t', encoding='utf-8')
    db_handle = tempfile.NamedTemporaryFile(suffix='.udb')

    # Write temporary ungapped reference file
    ref_dict = readReferenceFile(ref_file)
    writer = SeqIO.FastaIO.FastaWriter(seq_handle, wrap=None)
    writer.write_file(ref_dict.values())
    seq_handle.seek(0)

    # Define usearch command
    cmd = [db_exec,
           '-makeudb_ublast', seq_handle.name,
           '-wordlength', '9',
           '-output', db_handle.name,
           '-dbmask', 'none']
    try:
        stdout_str = check_output(cmd, stderr=STDOUT, shell=False,
                                  universal_newlines=True)
    except:
        seq_handle.close()
        printError('Failed to make usearch database.')

    # Close temporary sequence file
    seq_handle.close()

    return (db_handle.name, db_handle)
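
A standalone sketch of the pattern shared by makeBlastnDb and makeUBlastDb: write sequences to a named temporary file, rewind it, and run an external command on it with check_output. The external tool is replaced here by 'cat' so the sketch runs without usearch or BLAST installed (POSIX systems); the file contents are illustrative.

import tempfile
from subprocess import check_output, STDOUT, CalledProcessError

records = {'seq1': 'ACGT', 'seq2': 'TTGCA'}

with tempfile.NamedTemporaryFile(suffix='.fasta', mode='w+t', encoding='utf-8') as seq_handle:
    # Write the FASTA input; seek(0) flushes it so the child process sees the data
    for name, seq in records.items():
        seq_handle.write('>%s\n%s\n' % (name, seq))
    seq_handle.seek(0)

    # Substitute the real database/aligner command here
    cmd = ['cat', seq_handle.name]
    try:
        stdout_str = check_output(cmd, stderr=STDOUT, shell=False,
                                  universal_newlines=True)
    except CalledProcessError as e:
        raise RuntimeError('Command failed: %s\n%s' % (' '.join(cmd), e.output))

    print(stdout_str)
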
Example #6
def parseIHMM(aligner_file, seq_file, repo, cellranger_file=None, partial=False, asis_id=True,
              extended=False, format=default_format, out_file=None, out_args=default_out_args):
    """
    Main for iHMMuneAlign aligned sample sequences.

    Arguments:
      aligner_file : iHMMune-Align output file to process.
      seq_file : fasta file input to iHMMuneAlign (from which to get sequence).
      repo : folder with germline repertoire files.
      cellranger_file : table of Cell Ranger annotations; if None do not add supplementary annotations.
      partial : If True put incomplete alignments in the pass file.
      asis_id : if ID is to be parsed for pRESTO output with default delimiters.
      extended : if True parse alignment scores, FWR and CDR region fields.
      format : output format. One of 'changeo' or 'airr'.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
      dict : names of the 'pass' and 'fail' output files.
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'MakeDb'
    log['COMMAND'] = 'ihmm'
    log['ALIGNER_FILE'] = os.path.basename(aligner_file)
    log['SEQ_FILE'] = os.path.basename(seq_file)
    log['ASIS_ID'] = asis_id
    log['PARTIAL'] = partial
    log['EXTENDED'] = extended
    printLog(log)

    start_time = time()
    printMessage('Loading files', start_time=start_time, width=20)

    # Count records in sequence file
    total_count = countSeqFile(seq_file)

    # Get input sequence dictionary
    seq_dict = getSeqDict(seq_file)

    # Create germline repo dictionary
    references = readGermlines(repo)

    # Load supplementary annotation table
    if cellranger_file is not None:
        f = cellranger_extended if extended else cellranger_base
        annotations = readCellRanger(cellranger_file, fields=f)
    else:
        annotations = None

    printMessage('Done', start_time=start_time, end=True, width=20)

    # Check for IMGT-gaps in germlines
    if all('...' not in x for x in references.values()):
        printWarning('Germline reference sequences do not appear to contain IMGT-numbering spacers. Results may be incorrect.')

    # Define format operators
    try:
        __, writer, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)
    out_args['out_type'] = schema.out_type

    # Define output fields
    fields = list(schema.required)
    if extended:
        custom = IHMMuneReader.customFields(scores=True, regions=True, schema=schema)
        fields.extend(custom)

    # Parse and write output
    with open(aligner_file, 'r') as f:
        parse_iter = IHMMuneReader(f, seq_dict, references)
        germ_iter = (addGermline(x, references) for x in parse_iter)
        output = writeDb(germ_iter, fields=fields, aligner_file=aligner_file, total_count=total_count, 
                        annotations=annotations, asis_id=asis_id, partial=partial,
                        writer=writer, out_file=out_file, out_args=out_args)

    return output
Example #7
File: SplitSeq.py, Project: avilella/presto
def groupSeqFile(seq_file, field, threshold=None, out_args=default_out_args):
    """
    Divides a sequence file into segments by description tags

    Arguments: 
      seq_file : filename of the sequence file to split
      field : the annotation field to split seq_file by
      threshold : the numerical threshold to group sequences by;
                  if None treat the field as textual
      out_args : common output argument dictionary from parseCommonArgs

    Returns: 
      list: output file names
    """
    log = OrderedDict()
    log['START'] = 'SplitSeq'
    log['COMMAND'] = 'group'
    log['FILE'] = os.path.basename(seq_file)
    log['FIELD'] = field
    log['THRESHOLD'] = threshold
    printLog(log)

    # Open file handles
    in_type = getFileType(seq_file)
    seq_iter = readSeqFile(seq_file)
    if out_args['out_type'] is None: out_args['out_type'] = in_type

    # Determine total numbers of records
    rec_count = countSeqFile(seq_file)

    # Process sequences
    start_time = time()
    seq_count = 0
    if threshold is None:
        # Sort records into files based on textual field
        # Create set of unique field tags
        temp_iter = readSeqFile(seq_file)
        tag_list = getAnnotationValues(temp_iter,
                                       field,
                                       unique=True,
                                       delimiter=out_args['delimiter'])

        if sys.platform != 'win32':
            import resource
            # Increase open file handle limit if needed
            file_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
            file_count = len(tag_list) + 256
            if file_limit < file_count and file_count <= 8192:
                #print file_limit, file_count
                resource.setrlimit(resource.RLIMIT_NOFILE,
                                   (file_count, file_count))
            elif file_count > 8192:
                e = '''OS file limit would need to be set to %i.
                    If you are sure you want to do this, then increase the 
                    file limit in the OS (via ulimit) and rerun this tool.
                    ''' % file_count
                printError(dedent(e))

        # Create output handles
        # out_label = '%s=%s' % (field, tag)
        handles_dict = {
            tag: getOutputHandle(seq_file,
                                 '%s-%s' % (field, tag),
                                 out_dir=out_args['out_dir'],
                                 out_name=out_args['out_name'],
                                 out_type=out_args['out_type'])
            for tag in tag_list
        }

        # Iterate over sequences
        for seq in seq_iter:
            printProgress(seq_count, rec_count, 0.05, start_time=start_time)
            seq_count += 1
            # Write sequences
            tag = parseAnnotation(seq.description,
                                  delimiter=out_args['delimiter'])[field]
            SeqIO.write(seq, handles_dict[tag], out_args['out_type'])
    else:
        # Sort records into files based on numeric threshold
        threshold = float(threshold)
        # Create output handles
        handles_dict = {
            'under':
            getOutputHandle(seq_file,
                            'under-%.1g' % threshold,
                            out_dir=out_args['out_dir'],
                            out_name=out_args['out_name'],
                            out_type=out_args['out_type']),
            'atleast':
            getOutputHandle(seq_file,
                            'atleast-%.1g' % threshold,
                            out_dir=out_args['out_dir'],
                            out_name=out_args['out_name'],
                            out_type=out_args['out_type'])
        }

        # Iterate over sequences
        for seq in seq_iter:
            printProgress(seq_count, rec_count, 0.05, start_time=start_time)
            seq_count += 1
            # Write sequences
            tag = parseAnnotation(seq.description,
                                  delimiter=out_args['delimiter'])[field]
            tag = 'under' if float(tag) < threshold else 'atleast'
            SeqIO.write(seq, handles_dict[tag], out_args['out_type'])

    # Print log
    printProgress(seq_count, rec_count, 0.05, start_time=start_time)
    log = OrderedDict()
    for i, k in enumerate(handles_dict):
        log['OUTPUT%i' % (i + 1)] = os.path.basename(handles_dict[k].name)
    log['SEQUENCES'] = rec_count
    log['PARTS'] = len(handles_dict)
    log['END'] = 'SplitSeq'
    printLog(log)

    # Close output file handles
    for k in handles_dict:
        handles_dict[k].close()

    return [handles_dict[k].name for k in handles_dict]
Example #8
def filterSeq(seq_file,
              filter_func,
              filter_args={},
              out_file=None,
              out_args=default_out_args,
              nproc=None,
              queue_size=None):
    """
    Filters sequences using the provided filter function
    
    Arguments: 
      seq_file : the sequence file to filter.
      filter_func : the function to use for filtering sequences.
      filter_args : a dictionary of arguments to pass to filter_func.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.
      nproc : the number of processQueue processes;
              if None defaults to the number of CPUs.
      queue_size : maximum size of the argument queue;
                   if None defaults to 2*nproc.
                 
    Returns:
      list: a list of successful output file names
    """
    # Define output file label dictionary
    cmd_dict = {
        filterLength: 'length',
        filterMissing: 'missing',
        filterRepeats: 'repeats',
        filterQuality: 'quality',
        maskQuality: 'maskqual',
        trimQuality: 'trimqual'
    }

    # Print parameter info
    log = OrderedDict()
    log['START'] = 'FilterSeq'
    log['COMMAND'] = cmd_dict.get(filter_func, filter_func.__name__)
    log['FILE'] = os.path.basename(seq_file)
    for k in sorted(filter_args):
        log[k.upper()] = filter_args[k]
    log['NPROC'] = nproc
    printLog(log)

    # Check input type
    in_type = getFileType(seq_file)
    if in_type != 'fastq' and filter_func in (filterQuality, maskQuality,
                                              trimQuality):
        printError('Input file must be FASTQ for %s mode.' %
                   cmd_dict[filter_func])

    # Define feeder function and arguments
    feed_func = feedSeqQueue
    feed_args = {'seq_file': seq_file}
    # Define worker function and arguments
    work_func = processSeqQueue
    work_args = {'process_func': filter_func, 'process_args': filter_args}
    # Define collector function and arguments
    collect_func = collectSeqQueue
    collect_args = {
        'seq_file': seq_file,
        'label': cmd_dict[filter_func],
        'out_file': out_file,
        'out_args': out_args
    }

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func, feed_args,
                             work_args, collect_args, nproc, queue_size)

    # Print log
    result['log']['END'] = 'FilterSeq'
    printLog(result['log'])

    return result['out_files']
Example #9
def manageProcesses(feed_func,
                    work_func,
                    collect_func,
                    feed_args={},
                    work_args={},
                    collect_args={},
                    nproc=None,
                    queue_size=None):
    """
    Manages feeder, worker and collector processes

    Arguments:
      feed_func (function): Data Queue feeder function.
      work_func (function): Worker function.
      collect_func (function): Result Queue collector function.
      feed_args (dict): Dictionary of arguments to pass to feed_func.
      work_args (dict): Dictionary of arguments to pass to work_func.
      collect_args (dict): Dictionary of arguments to pass to collect_func.
      nproc (int): Number of processQueue processes;
                   if None defaults to the number of CPUs
      queue_size (int): Maximum size of the argument queue;
                        if None defaults to 2*nproc

    Returns:
      dict: Dictionary of collector results
    """

    # Define signal handler that raises SystemExit
    def _signalHandler(s, f):
        raise SystemExit

    # Define function to terminate child processes
    def _terminate():
        sys.stderr.write('NOTICE> Terminating child processes...  ')
        # Terminate feeders
        feeder.terminate()
        feeder.join()
        # Terminate workers
        for w in workers:
            w.terminate()
            w.join()
        # Terminate collector
        collector.terminate()
        collector.join()
        sys.stderr.write('Done.\n')

    # Raise SystemExit upon termination signal
    signal.signal(signal.SIGTERM, _signalHandler)

    # Define number of processes and queue size
    if nproc is None: nproc = mp.cpu_count()
    if queue_size is None: queue_size = nproc * 2

    # Define shared child process keep alive flag
    alive = mp.Value(ctypes.c_bool, True)

    # Define shared data queues
    data_queue = mp.Queue(queue_size)
    result_queue = mp.Queue(queue_size)
    # TODO:  find out what's up with this context shenanigans
    ctx = mp.get_context()
    collect_queue = ctx.SimpleQueue()
    # Initiate manager and define shared data objects

    try:
        # Initiate feeder process
        feeder = mp.Process(target=feed_func,
                            args=(alive, data_queue),
                            kwargs=feed_args)
        feeder.start()

        # Initiate worker processes
        workers = []
        for __ in range(nproc):
            w = mp.Process(target=work_func,
                           args=(alive, data_queue, result_queue),
                           kwargs=work_args)
            w.start()
            workers.append(w)

        # Initiate collector process
        collector = mp.Process(target=collect_func,
                               args=(alive, result_queue, collect_queue),
                               kwargs=collect_args)
        collector.start()

        # Wait for feeder to finish and add sentinel objects to data_queue
        feeder.join()
        for __ in range(nproc):
            data_queue.put(None)

        # Wait for worker processes to finish and add sentinel to result_queue
        for w in workers:
            w.join()
        result_queue.put(None)

        # Wait for collector process to finish and add sentinel to collect_queue
        collector.join()
        collect_queue.put(None)

        # Get collector return values
        collected = collect_queue.get()
    except (KeyboardInterrupt, SystemExit):
        sys.stderr.write('NOTICE> Exit signal received.\n')
        _terminate()
        sys.exit()
    except Exception as e:
        printError('%s.' % e, exit=False)
        _terminate()
        sys.exit()
    except:
        printError('Exiting with unknown exception.', exit=False)
        _terminate()
        sys.exit()
    else:
        if not alive.value:
            printError('Exiting due to child process error.', exit=False)
            _terminate()
            sys.exit()

    return collected
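
A minimal, standalone sketch of the feeder/worker/collector pattern orchestrated by manageProcesses, using only the standard library. The function names, data, and sentinel handling here are illustrative, not part of presto's API.

import ctypes
import multiprocessing as mp

def feeder(alive, data_queue):
    # Push work items onto the data queue
    for i in range(10):
        if not alive.value:
            break
        data_queue.put(i)

def worker(alive, data_queue, result_queue):
    # Pull items until the None sentinel arrives
    while alive.value:
        if data_queue.empty():
            continue
        data = data_queue.get()
        if data is None:
            break
        result_queue.put(data * data)

def collector(alive, result_queue, collect_queue):
    # Accumulate results until the None sentinel arrives
    collected = []
    while alive.value:
        if result_queue.empty():
            continue
        result = result_queue.get()
        if result is None:
            break
        collected.append(result)
    collect_queue.put(collected)

if __name__ == '__main__':
    nproc = 2
    alive = mp.Value(ctypes.c_bool, True)
    data_queue, result_queue = mp.Queue(), mp.Queue()
    collect_queue = mp.get_context().SimpleQueue()

    feed = mp.Process(target=feeder, args=(alive, data_queue))
    feed.start()
    workers = [mp.Process(target=worker, args=(alive, data_queue, result_queue))
               for _ in range(nproc)]
    for w in workers:
        w.start()
    coll = mp.Process(target=collector, args=(alive, result_queue, collect_queue))
    coll.start()

    # Sentinel objects mark the end of each stage, mirroring manageProcesses
    feed.join()
    for _ in range(nproc):
        data_queue.put(None)
    for w in workers:
        w.join()
    result_queue.put(None)
    coll.join()
    print(collect_queue.get())
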
Example #10
def processQueue(alive,
                 data_queue,
                 result_queue,
                 cluster_func,
                 cluster_args={},
                 cluster_field=default_cluster_field,
                 cluster_prefix=default_cluster_prefix,
                 delimiter=default_delimiter):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
      alive : a multiprocessing.Value boolean controlling whether processing
              continues; when False function returns.
      data_queue : a multiprocessing.Queue holding data to process.
      result_queue : a multiprocessing.Queue to hold processed results.
      cluster_func : the function to use for clustering.
      cluster_args : a dictionary of optional arguments for the clustering function.
      cluster_field : string defining the output cluster field name.
      cluster_prefix : string defining a prefix for the cluster identifier.
      delimiter : a tuple of delimiters for (annotations, field/values, value lists).

    Returns:
      None
    """
    try:
        # Iterate over the data queue until the sentinel object is reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break

            # Define result object
            result = SeqResult(data.id, data.data)
            result.log['BARCODE'] = data.id
            result.log['SEQCOUNT'] = len(data)

            # Perform clustering
            cluster_dict = cluster_func(data.data, **cluster_args)

            # Process failed result
            if cluster_dict is None:
                # Update log
                result.log['CLUSTERS'] = 0
                for i, seq in enumerate(data.data, start=1):
                    result.log['CLUST0-%i' % i] = str(seq.seq)

                # Feed results queue and continue
                result_queue.put(result)
                continue

            # Get number of clusters
            result.log['CLUSTERS'] = len(cluster_dict)

            # Update sequence annotations with cluster assignments
            results = list()
            seq_dict = {s.id: s for s in data.data}
            for cluster, id_list in cluster_dict.items():
                for i, seq_id in enumerate(id_list, start=1):
                    # Add cluster annotation
                    seq = seq_dict[seq_id]
                    label = '%s%i' % (cluster_prefix, cluster)
                    header = parseAnnotation(seq.description,
                                             delimiter=delimiter)
                    header = mergeAnnotation(header, {cluster_field: label},
                                             delimiter=delimiter)
                    seq.id = seq.name = flattenAnnotation(header,
                                                          delimiter=delimiter)
                    seq.description = ''

                    # Update log and results
                    result.log['CLUST%i-%i' % (cluster, i)] = str(seq.seq)
                    results.append(seq)

            # Check results
            result.results = results
            result.valid = (len(results) == len(seq_dict))
            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s> Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())

            return None
    except:
        alive.value = False
        printError('Error processing sequence set with ID: %s.' % data.id,
                   exit=False)
        raise

    return None
Example #11
def parseIMGT(aligner_file, seq_file=None, repo=None, cellranger_file=None, partial=False, asis_id=True,
              extended=False, format=default_format, out_file=None, out_args=default_out_args):
    """
    Main for IMGT aligned sample sequences.

    Arguments:
      aligner_file : zipped file or unzipped folder output by IMGT.
      seq_file : FASTA file input to IMGT (from which to get seqID).
      repo : folder with germline repertoire files.
      cellranger_file : table of Cell Ranger annotations; if None do not add supplementary annotations.
      partial : If True put incomplete alignments in the pass file.
      asis_id : if ID is to be parsed for pRESTO output with default delimiters.
      extended : if True add alignment score, FWR, CDR and junction fields to output file.
      format : output format. One of 'changeo' or 'airr'.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
      dict : names of the 'pass' and 'fail' output files.
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'MakeDb'
    log['COMMAND'] = 'imgt'
    log['ALIGNER_FILE'] = aligner_file
    log['SEQ_FILE'] = os.path.basename(seq_file) if seq_file else ''
    log['ASIS_ID'] = asis_id
    log['PARTIAL'] = partial
    log['EXTENDED'] = extended
    printLog(log)

    start_time = time()
    printMessage('Loading files', start_time=start_time, width=20)

    # Extract IMGT files
    temp_dir, imgt_files = extractIMGT(aligner_file)

    # Count records in IMGT files
    total_count = countDbFile(imgt_files['summary'])

    # Get (parsed) IDs from fasta file submitted to IMGT
    id_dict = getIDforIMGT(seq_file) if seq_file else {}

    # Load supplementary annotation table
    if cellranger_file is not None:
        f = cellranger_extended if extended else cellranger_base
        annotations = readCellRanger(cellranger_file, fields=f)
    else:
        annotations = None

    printMessage('Done', start_time=start_time, end=True, width=20)

    # Define format operators
    try:
        __, writer, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)
    out_args['out_type'] = schema.out_type

    # Define output fields
    fields = list(schema.required)
    if extended:
        custom = IMGTReader.customFields(scores=True, regions=True, junction=True, schema=schema)
        fields.extend(custom)

    # Parse IMGT output and write db
    with open(imgt_files['summary'], 'r') as summary_handle, \
            open(imgt_files['gapped'], 'r') as gapped_handle, \
            open(imgt_files['ntseq'], 'r') as ntseq_handle, \
            open(imgt_files['junction'], 'r') as junction_handle:

        # Open parser
        parse_iter = IMGTReader(summary_handle, gapped_handle, ntseq_handle, junction_handle)

        # Add germline sequence
        if repo is None:
            germ_iter = parse_iter
        else:
            references = readGermlines(repo)
            # Check for IMGT-gaps in germlines
            if all('...' not in x for x in references.values()):
                printWarning('Germline reference sequences do not appear to contain IMGT-numbering spacers. Results may be incorrect.')
            germ_iter = (addGermline(x, references) for x in parse_iter)

        # Write db
        output = writeDb(germ_iter, fields=fields, aligner_file=aligner_file, total_count=total_count, 
                         annotations=annotations, id_dict=id_dict, asis_id=asis_id, partial=partial,
                         writer=writer, out_file=out_file, out_args=out_args)

    # Cleanup temp directory
    temp_dir.cleanup()

    return output
Example #12
def runUClust(seq_list, ident=default_cluster_ident, length_ratio=default_length_ratio,
              seq_start=0, seq_end=None,
              threads=1, cluster_exec=default_usearch_exec):
    """
    Cluster a set of sequences using the UCLUST algorithm from USEARCH

    Arguments:
      seq_list (list): a list of SeqRecord objects to align.
      ident (float): the sequence identity cutoff to be passed to usearch.
      length_ratio (float): usearch parameter defining the minimum short/long length
                            ratio allowed within a cluster.
      seq_start (int): the start position to trim sequences at before clustering.
      seq_end (int): the end position to trim sequences at before clustering.
      threads (int): number of threads for usearch.
      cluster_exec (str): the path to the usearch executable.

    Returns:
      dict: {cluster id: list of sequence ids}.
    """
    # Function to trim and mask sequences
    gap_trans = str.maketrans({'-': 'N', '.': 'N'})
    def _clean(rec, i, j):
        seq = str(rec.seq[i:j])
        seq = seq.translate(gap_trans)
        return SeqRecord(Seq(seq), id=rec.id, name=rec.name, description=rec.description)

    # Make a trimmed and masked copy of each sequence so we don't mess up originals
    seq_trimmed = [_clean(x, seq_start, seq_end) for x in seq_list]

    # Return the single sequence if only one sequence is in seq_list
    if len(seq_trimmed) < 2:
        return {1:[seq_trimmed[0].id]}

    # If there are any empty sequences after trimming return None
    if any([len(x.seq) == 0 for x in seq_trimmed]):
        return None

    # Open temporary files
    in_handle = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8')
    out_handle = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8')

    # Define usearch command
    cmd = [cluster_exec,
           '-cluster_fast', in_handle.name,
           '-uc', out_handle.name,
           '-id', str(ident),
           '-minsl', str(length_ratio),
           '-qmask', 'none',
           '-minseqlength', '1',
           '-threads', str(threads)]

    # Write usearch input fasta file
    SeqIO.write(seq_trimmed, in_handle, 'fasta')
    in_handle.seek(0)

    # Run usearch uclust algorithm
    try:
        stdout_str = check_output(cmd, stderr=STDOUT, shell=False,
                                  universal_newlines=True)
    except CalledProcessError as e:
        printError('Running command: %s\n%s' % (' '.join(cmd), e.output))

    # Parse the results of usearch
    # Output columns for the usearch 'uc' output format
    #   0 = entry type -- S: centroid seq, H: hit, C: cluster record (redundant with S)
    #   1 = group the sequence is assigned to
    #   8 = the id of the sequence
    #   9 = id of the centroid for cluster
    cluster_dict = {}
    for row in csv.reader(out_handle, delimiter='\t'):
        if row[0] in ('H', 'S'):
            # Use one-based cluster numbering
            key = int(row[1]) + 1
            # Trim sequence label to portion before space for usearch v9 compatibility
            hit = row[8].split()[0]
            # Update cluster dictionary
            cluster = cluster_dict.setdefault(key, [])
            cluster.append(hit)

    return cluster_dict if cluster_dict else None
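
A standalone sketch of parsing the usearch '.uc' cluster format described in the comments above: keep the centroid (S) and hit (H) rows, use column 1 as the cluster number and column 8 as the sequence label. The sample rows are illustrative only.

import csv
import io

uc_text = (
    "S\t0\t17\t*\t*\t*\t*\t*\tS01|BARCODE=CTAAGTGACTGGAGTTC\t*\n"
    "H\t0\t17\t100.0\t+\t0\t0\t=\tS02|BARCODE=CTAAGTGACTGGAGTTC\tS01|BARCODE=CTAAGTGACTGGAGTTC\n"
    "S\t1\t17\t*\t*\t*\t*\t*\tS12|BARCODE=TTTTTTTTTTTTTTTTT\t*\n"
)

cluster_dict = {}
for row in csv.reader(io.StringIO(uc_text), delimiter='\t'):
    if row[0] in ('H', 'S'):
        key = int(row[1]) + 1
        hit = row[8].split()[0]
        cluster_dict.setdefault(key, []).append(hit)

print(cluster_dict)
# {1: ['S01|BARCODE=CTAAGTGACTGGAGTTC', 'S02|BARCODE=CTAAGTGACTGGAGTTC'],
#  2: ['S12|BARCODE=TTTTTTTTTTTTTTTTT']}
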
Example #13
def runCDHit(seq_list, ident=default_cluster_ident, length_ratio=default_length_ratio,
             seq_start=0, seq_end=None, max_memory=default_max_memory,
             threads=1, cluster_exec=default_cdhit_exec):
    """
    Cluster a set of sequences using CD-HIT

    Arguments:
      seq_list (list): a list of SeqRecord objects to align.
      ident (float): the sequence identity cutoff to be passed to cd-hit-est.
      length_ratio (float): cd-hit-est parameter defining the minimum short/long length
                            ratio allowed within a cluster.
      seq_start (int): the start position to trim sequences at before clustering.
      seq_end (int): the end position to trim sequences at before clustering.
      max_memory (int): cd-hit-est max memory limit (Mb)
      threads (int): number of threads for cd-hit-est.
      cluster_exec (str): the path to the cd-hit-est executable.

    Returns:
      dict: {cluster id: list of sequence ids}.
    """
    # Function to trim and mask sequences
    gap_trans = str.maketrans({'-': 'N', '.': 'N'})
    def _clean(rec, i, j):
        seq = str(rec.seq[i:j])
        seq = seq.translate(gap_trans)
        return SeqRecord(Seq(seq), id=rec.id, name=rec.name, description=rec.description)

    # Make a trimmed and masked copy of each sequence so we don't mess up originals
    seq_trimmed = [_clean(x, seq_start, seq_end) for x in seq_list]

    # Return the single sequence if only one sequence is in seq_list
    if len(seq_trimmed) < 2:
        return {1:[seq_trimmed[0].id]}

    # If there are any empty sequences after trimming return None
    if any([len(x.seq) == 0 for x in seq_trimmed]):
        return None

    # Open temporary files
    in_handle = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8')
    out_handle = tempfile.NamedTemporaryFile(mode='w+t', encoding='utf-8')

    # Define cd-hit-est command
    cmd = [cluster_exec,
           '-i', in_handle.name,
           '-o', out_handle.name,
           '-c', str(ident),
           '-s', str(length_ratio),
           '-n', '3',
           '-d', '0',
           '-M', str(max_memory),
           '-T', str(threads)]

    # Write cd-hit-est input fasta file
    SeqIO.write(seq_trimmed, in_handle, 'fasta')
    in_handle.seek(0)

    # Run CD-HIT
    try:
        stdout_str = check_output(cmd, stderr=STDOUT, shell=False,
                                  universal_newlines=True)
    except CalledProcessError as e:
        printError('Running command: %s\n%s' % (' '.join(cmd), e.output))

    # Parse the results of CD-HIT
    # Output of the .clstr file
    #   >Cluster 0
    #   0	17nt, >S01|BARCODE=CTAAGTGACTGGAGTTC... *
    #   1	17nt, >S02|BARCODE=CTAAGTGACTGGAGTTC... at +/100.00%
    #   2	17nt, >S07|BARCODE=CTAAGTGACTGGACTTC... at +/94.12%
    #   >Cluster 1
    #   0	17nt, >S12|BARCODE=TTTTTTTTTTTTTTTTT... *
    # Parsing regex
    block_regex = re.compile('>Cluster [0-9]+')
    id_regex = re.compile(r'([0-9]+\t[0-9]+nt, \>)(.+)(\.\.\.)')

    # Parse .clstr file
    cluster_dict = {}
    cluster_file = '%s.clstr' % out_handle.name
    with open(cluster_file, 'r') as cluster_handle:
        # Define parsing blocks
        clusters = groupby(cluster_handle, key=lambda x: block_regex.match(x))
        # Iterate over clusters and update return dict
        count = 1
        for key, group in clusters:
            if key is not None:
                __, block = next(clusters)
                cluster_dict[count] = [id_regex.match(x).group(2) for x in block]
                count += 1

    # Delete temp file
    os.remove(cluster_file)

    return cluster_dict if cluster_dict else None
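
A standalone demo of the groupby-based '.clstr' parsing used above, run on the sample output shown in the comments. groupby alternates between header groups (the '>Cluster N' line) and member groups (key None), so advancing the iterator after each header yields that cluster's member lines.

import re
from io import StringIO
from itertools import groupby

clstr_text = """>Cluster 0
0\t17nt, >S01|BARCODE=CTAAGTGACTGGAGTTC... *
1\t17nt, >S02|BARCODE=CTAAGTGACTGGAGTTC... at +/100.00%
2\t17nt, >S07|BARCODE=CTAAGTGACTGGACTTC... at +/94.12%
>Cluster 1
0\t17nt, >S12|BARCODE=TTTTTTTTTTTTTTTTT... *
"""

block_regex = re.compile('>Cluster [0-9]+')
id_regex = re.compile(r'([0-9]+\t[0-9]+nt, \>)(.+)(\.\.\.)')

cluster_dict = {}
clusters = groupby(StringIO(clstr_text), key=lambda x: block_regex.match(x))
count = 1
for key, group in clusters:
    if key is not None:
        __, block = next(clusters)
        cluster_dict[count] = [id_regex.match(x).group(2) for x in block]
        count += 1

print(cluster_dict)
# {1: ['S01|BARCODE=CTAAGTGACTGGAGTTC', 'S02|BARCODE=CTAAGTGACTGGAGTTC', 'S07|BARCODE=CTAAGTGACTGGACTTC'],
#  2: ['S12|BARCODE=TTTTTTTTTTTTTTTTT']}
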
Example #14
def processEEQueue(alive,
                   data_queue,
                   result_queue,
                   cons_func,
                   cons_args={},
                   min_count=default_min_count,
                   max_diversity=None):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments: 
      alive : a multiprocessing.Value boolean controlling whether processing 
              continues; when False function returns
      data_queue : a multiprocessing.Queue holding data to process
      result_queue : a multiprocessing.Queue to hold processed results
      cons_func : the function to use for consensus generation 
      cons_args : a dictionary of optional arguments for the consensus function
      min_count : threshold number of sequences to retain a set
      max_diversity : the maximum average pairwise error rate allowed to retain a set;
                    if None do not calculate diversity
                        
    Returns: 
      None
    """
    try:
        # Iterate over the data queue until the sentinel object is reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break

            # Define result dictionary for iteration
            result = SeqResult(data.id, data.data)
            result.results = {
                'pos': None,
                'nuc': None,
                'qual': None,
                'set': None
            }

            # Define sequences set
            seq_list = data.data
            seq_count = len(data)

            # Update log
            result.log['SET'] = data.id
            result.log['SEQCOUNT'] = seq_count
            for i, s in enumerate(seq_list):
                result.log['SEQ%i' % (i + 1)] = str(s.seq)

            # Check count threshold and continue if failed
            if len(data) < min_count:
                result_queue.put(result)
                continue

            # Check all sequences in group have the same length and continue if failed
            init_len = len(seq_list[0])
            if any(len(seq) != init_len for seq in seq_list):
                result_queue.put(result)
                continue

            # Calculate average pairwise error rate
            if max_diversity is not None:
                diversity = calculateDiversity(seq_list)
                result.log['DIVERSITY'] = diversity
                # Check diversity threshold and continue if failed
                if diversity > max_diversity:
                    result_queue.put(result)
                    continue

            # Define reference sequence by consensus
            ref_seq = cons_func(seq_list, **cons_args)

            # Count mismatches against consensus
            mismatch = countMismatches(seq_list, ref_seq)

            # Calculate average reported and observed error
            reported_q = mismatch['set']['q_sum'][len(
                seq_list)] / mismatch['set']['total'][len(seq_list)]
            error_rate = mismatch['set']['mismatch'][len(
                seq_list)] / mismatch['set']['total'][len(seq_list)]

            # Update log
            result.log['REFERENCE'] = str(ref_seq.seq)
            result.log['MISMATCH'] = ''.join(['*' if x > 0 else ' ' \
                                              for x in mismatch['pos']['mismatch']])
            result.log['ERROR'] = '%.6f' % error_rate
            result.log['REPORTED_Q'] = '%.2f' % reported_q
            result.log['EMPIRICAL_Q'] = '%.2f' % (
                -10 * np.log10(max(error_rate, 1e-9)))

            # Update results and feed result queue
            result.valid = True
            result.results.update(mismatch)
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s> Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        printError('Error processing sequence set with ID: %s.' % data.id,
                   exit=False)
        raise

    return None
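
A quick worked example of the empirical quality score computed above, Q = -10 * log10(error_rate) with the error rate floored at 1e-9; math.log10 is used here in place of numpy for a dependency-free check.

import math

for error_rate in (0.1, 0.01, 0.001, 0.0):
    q = -10 * math.log10(max(error_rate, 1e-9))
    print('error_rate=%.4f -> empirical Q=%.2f' % (error_rate, q))
# error_rate=0.1000 -> empirical Q=10.00
# error_rate=0.0100 -> empirical Q=20.00
# error_rate=0.0010 -> empirical Q=30.00
# error_rate=0.0000 -> empirical Q=90.00
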
Example #15
def estimateSets(seq_file,
                 cons_func=frequencyConsensus,
                 cons_args={},
                 set_field=default_barcode_field,
                 min_count=default_min_count,
                 max_diversity=None,
                 out_args=default_out_args,
                 nproc=None,
                 queue_size=None):
    """
    Calculates error rates of sequence sets

    Arguments: 
      seq_file : the sample sequence file name
      cons_func : the function to use for consensus generation 
      cons_args : a dictionary of arguments for the consensus function
      set_field : the annotation field containing set IDs
      min_count : threshold number of sequences to consider a set
      max_diversity : a threshold defining the average pairwise error rate required to retain a read group;
                    if None do not calculate diversity
      out_args : common output argument dictionary from parseCommonArgs
      nproc : the number of processQueue processes;
            if None defaults to the number of CPUs
      queue_size : maximum size of the argument queue;
                 if None defaults to 2*nproc
                    
    Returns: 
      tuple : (position error, quality error, nucleotide pairwise error) output file names
    """
    # Define subcommand label dictionary
    cmd_dict = {frequencyConsensus: 'freq', qualityConsensus: 'qual'}

    # Print parameter info
    log = OrderedDict()
    log['START'] = 'EstimateError'
    log['FILE'] = os.path.basename(seq_file)
    log['MODE'] = cmd_dict.get(cons_func, cons_func.__name__)
    log['SET_FIELD'] = set_field
    log['MIN_COUNT'] = min_count
    log['MAX_DIVERSITY'] = max_diversity
    log['NPROC'] = nproc
    printLog(log)

    # Check input file type
    in_type = getFileType(seq_file)
    if in_type != 'fastq':
        printError('Input file must be FASTQ.')

    # Define feeder function and arguments
    index_args = {'field': set_field, 'delimiter': out_args['delimiter']}
    feed_func = feedSeqQueue
    feed_args = {
        'seq_file': seq_file,
        'index_func': indexSeqSets,
        'index_args': index_args
    }
    # Define worker function and arguments
    work_func = processEEQueue
    work_args = {
        'cons_func': cons_func,
        'cons_args': cons_args,
        'min_count': min_count,
        'max_diversity': max_diversity
    }
    # Define collector function and arguments
    collect_func = collectEEQueue
    collect_args = {
        'seq_file': seq_file,
        'out_args': out_args,
        'set_field': set_field
    }

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func, feed_args,
                             work_args, collect_args, nproc, queue_size)

    # Print log
    result['log']['END'] = 'EstimateError'
    printLog(result['log'])

    return result['out_files']
Example #16
def createGermlines(db_file, references, seq_field=default_seq_field, v_field=default_v_field,
                    d_field=default_d_field, j_field=default_j_field,
                    cloned=False, clone_field=default_clone_field, germ_types=default_germ_types,
                    format=default_format, out_file=None, out_args=default_out_args):
    """
    Write germline sequences to tab-delimited database file

    Arguments:
      db_file : input tab-delimited database file.
      references : folders and/or files containing germline repertoire data in FASTA format.
      seq_field : field in which to look for sequence.
      v_field : field in which to look for V call.
      d_field : field in which to look for D call.
      j_field : field in which to look for J call.
      cloned : if True build germlines by clone, otherwise build individual germlines.
      clone_field : field containing clone identifiers; ignored if cloned=False.
      germ_types : list of germline sequence types to be output from the set of 'full', 'dmask', 'vonly', 'regions'
      format : input and output format.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : arguments for output preferences.

    Returns:
      dict: names of the 'pass' and 'fail' output files.
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'CreateGermlines'
    log['FILE'] = os.path.basename(db_file)
    log['GERM_TYPES'] = ','.join(germ_types)
    log['SEQ_FIELD'] = seq_field
    log['V_FIELD'] = v_field
    log['D_FIELD'] = d_field
    log['J_FIELD'] = j_field
    log['CLONED'] = cloned
    if cloned:  log['CLONE_FIELD'] = clone_field
    printLog(log)

    # Define format operators
    try:
        reader, writer, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s' % format)
    out_args['out_type'] = schema.out_type

    # TODO: this won't work for AIRR necessarily
    # Define output germline fields
    germline_fields = OrderedDict()
    seq_type = seq_field.split('_')[-1]
    if 'full' in germ_types:  germline_fields['full'] = 'germline_' + seq_type
    if 'dmask' in germ_types:  germline_fields['dmask'] = 'germline_' + seq_type + '_d_mask'
    if 'vonly' in germ_types:  germline_fields['vonly'] = 'germline_' + seq_type + '_v_region'
    if 'regions' in germ_types:  germline_fields['regions'] = 'germline_regions'
    if cloned:
        germline_fields['v'] = 'germline_v_call'
        germline_fields['d'] = 'germline_d_call'
        germline_fields['j'] = 'germline_j_call'
    out_fields = getDbFields(db_file,
                             add=[schema.fromReceptor(f) for f in germline_fields.values()],
                             reader=reader)

    # Get repertoire and open Db reader
    reference_dict = readGermlines(references)
    db_handle = open(db_file, 'rt')
    db_iter = reader(db_handle)

    # Check for required columns
    try:
        required = ['v_germ_start_imgt', 'd_germ_start', 'j_germ_start',
                    'np1_length', 'np2_length']
        checkFields(required, db_iter.fields, schema=schema)
    except LookupError as e:
        printError(e)

    # Check for IMGT-gaps in germlines
    if all('...' not in x for x in reference_dict.values()):
        printWarning('Germline reference sequences do not appear to contain IMGT-numbering spacers. Results may be incorrect.')

    # Count input
    total_count = countDbFile(db_file)

    # Check for existence of fields
    for f in [v_field, d_field, j_field, seq_field]:
        if f not in db_iter.fields:
            printError('%s field does not exist in input database file.' % f)

    # Translate to Receptor attribute names
    v_field = schema.toReceptor(v_field)
    d_field = schema.toReceptor(d_field)
    j_field = schema.toReceptor(j_field)
    seq_field = schema.toReceptor(seq_field)
    clone_field = schema.toReceptor(clone_field)

    # Define Receptor iterator
    if cloned:
        start_time = time()
        printMessage('Sorting by clone', start_time=start_time, width=20)
        sorted_records = sorted(db_iter, key=lambda x: x.getField(clone_field))
        printMessage('Done', start_time=start_time, end=True, width=20)
        receptor_iter = groupby(sorted_records, lambda x: x.getField(clone_field))
    else:
        receptor_iter = ((x.sequence_id, [x]) for x in db_iter)

    # Define log handle
    if out_args['log_file'] is None:
        log_handle = None
    else:
        log_handle = open(out_args['log_file'], 'w')

    # Initialize handles, writers and counters
    pass_handle, pass_writer = None, None
    fail_handle, fail_writer = None, None
    rec_count, pass_count, fail_count = 0, 0, 0
    start_time = time()

    # Iterate over rows
    for key, records in receptor_iter:
        # Print progress
        printProgress(rec_count, total_count, 0.05, start_time=start_time)

        # Define iteration variables
        records = list(records)
        rec_log = OrderedDict([('ID', key)])
        rec_count += len(records)

        # Build germline for records
        if len(records) == 1:
            germ_log, germlines, genes = buildGermline(records[0], reference_dict, seq_field=seq_field, v_field=v_field,
                                                       d_field=d_field, j_field=j_field)
        else:
            germ_log, germlines, genes = buildClonalGermline(records, reference_dict, seq_field=seq_field, v_field=v_field,
                                                             d_field=d_field, j_field=j_field)
        rec_log.update(germ_log)

        # Write row to pass or fail file
        if germlines is not None:
            pass_count += len(records)

            # Add germlines to Receptor record
            annotations = {}
            if 'full' in germ_types:  annotations[germline_fields['full']] = germlines['full']
            if 'dmask' in germ_types:  annotations[germline_fields['dmask']] = germlines['dmask']
            if 'vonly' in germ_types:  annotations[germline_fields['vonly']] = germlines['vonly']
            if 'regions' in germ_types:  annotations[germline_fields['regions']] = germlines['regions']
            if cloned:
                annotations[germline_fields['v']] = genes['v']
                annotations[germline_fields['d']] = genes['d']
                annotations[germline_fields['j']] = genes['j']

            # Write records
            try:
                for r in records:
                    r.setDict(annotations)
                    pass_writer.writeReceptor(r)
            except AttributeError:
                # Create output file handle and writer
                if out_file is not None:
                    pass_handle = open(out_file, 'w')
                else:
                    pass_handle = getOutputHandle(db_file,
                                                  out_label='germ-pass',
                                                  out_dir=out_args['out_dir'],
                                                  out_name=out_args['out_name'],
                                                  out_type=out_args['out_type'])
                pass_writer = writer(pass_handle, fields=out_fields)
                for r in records:
                    r.setDict(annotations)
                    pass_writer.writeReceptor(r)
        else:
            fail_count += len(records)
            if out_args['failed']:
                try:
                    fail_writer.writeReceptor(records)
                except AttributeError:
                    fail_handle = getOutputHandle(db_file,
                                                  out_label='germ-fail',
                                                  out_dir=out_args['out_dir'],
                                                  out_name=out_args['out_name'],
                                                  out_type=out_args['out_type'])
                    fail_writer = writer(fail_handle, fields=out_fields)
                    fail_writer.writeReceptor(records)

        # Write log
        printLog(rec_log, handle=log_handle)

    # Print log
    printProgress(rec_count, total_count, 0.05, start_time=start_time)
    log = OrderedDict()
    log['OUTPUT'] = os.path.basename(pass_handle.name) if pass_handle is not None else None
    log['RECORDS'] = rec_count
    log['PASS'] = pass_count
    log['FAIL'] = fail_count
    log['END'] = 'CreateGermlines'
    printLog(log)

    # Close file handles
    db_handle.close()
    output = {'pass': None, 'fail': None}
    if pass_handle is not None:
        output['pass'] = pass_handle.name
        pass_handle.close()
    if fail_handle is not None:
        output['fail'] = fail_handle.name
        fail_handle.close()
    if log_handle is not None:
        log_handle.close()

    return output
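
A quick standalone check of the germline output field naming logic above, for a hypothetical AIRR-style seq_field and a 'full'/'dmask' germ_types selection.

from collections import OrderedDict

seq_field, germ_types = 'sequence_alignment', ['full', 'dmask']
seq_type = seq_field.split('_')[-1]
germline_fields = OrderedDict()
if 'full' in germ_types:  germline_fields['full'] = 'germline_' + seq_type
if 'dmask' in germ_types:  germline_fields['dmask'] = 'germline_' + seq_type + '_d_mask'
print(germline_fields)
# OrderedDict([('full', 'germline_alignment'), ('dmask', 'germline_alignment_d_mask')])
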
Example #17
def assemblePairs(head_file,
                  tail_file,
                  assemble_func,
                  assemble_args={},
                  coord_type=default_coord,
                  rc='tail',
                  head_fields=None,
                  tail_fields=None,
                  out_file=None,
                  out_args=default_out_args,
                  nproc=None,
                  queue_size=None):
    """
    Assembles paired-end reads into a single sequence

    Arguments: 
      head_file : the head sequence file name
      tail_file : the tail sequence file name
      assemble_func : the function to use to assemble paired ends
      assemble_args : a dictionary of arguments to pass to the assembly function
      coord_type : the sequence header format
      rc : Defines which sequences ('head', 'tail', 'both', 'none') to reverse complement before assembly;
           if 'none' do not reverse complement sequences
      head_fields : list of annotations in head_file records to copy to assembled record;
                    if None do not copy an annotation
      tail_fields : list of annotations in tail_file records to copy to assembled record;
                    if None do not copy an annotation
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs
      nproc : the number of processQueue processes;
              if None defaults to the number of CPUs
      queue_size : maximum size of the argument queue;
                   if None defaults to 2*nproc
                 
    Returns: 
      list: a list of successful output file names.
    """
    # Define subcommand label dictionary
    cmd_dict = {
        alignAssembly: 'align',
        joinAssembly: 'join',
        referenceAssembly: 'reference',
        sequentialAssembly: 'sequential'
    }
    cmd_name = cmd_dict.get(assemble_func, assemble_func.__name__)

    # Print parameter info
    log = OrderedDict()
    log['START'] = 'AssemblePairs'
    log['COMMAND'] = cmd_name
    log['FILE1'] = os.path.basename(head_file)
    log['FILE2'] = os.path.basename(tail_file)
    log['COORD_TYPE'] = coord_type
    if 'ref_file' in assemble_args: log['REFFILE'] = assemble_args['ref_file']
    if 'alpha' in assemble_args: log['ALPHA'] = assemble_args['alpha']
    if 'max_error' in assemble_args:
        log['MAX_ERROR'] = assemble_args['max_error']
    if 'min_len' in assemble_args: log['MIN_LEN'] = assemble_args['min_len']
    if 'max_len' in assemble_args: log['MAX_LEN'] = assemble_args['max_len']
    if 'scan_reverse' in assemble_args:
        log['SCAN_REVERSE'] = assemble_args['scan_reverse']
    if 'gap' in assemble_args: log['GAP'] = assemble_args['gap']
    if 'min_ident' in assemble_args:
        log['MIN_IDENT'] = assemble_args['min_ident']
    if 'evalue' in assemble_args: log['EVALUE'] = assemble_args['evalue']
    if 'max_hits' in assemble_args: log['MAX_HITS'] = assemble_args['max_hits']
    if 'fill' in assemble_args: log['FILL'] = assemble_args['fill']
    if 'aligner' in assemble_args: log['ALIGNER'] = assemble_args['aligner']
    log['NPROC'] = nproc
    printLog(log)

    # Count input files
    head_count = countSeqFile(head_file)
    tail_count = countSeqFile(tail_file)
    if head_count != tail_count:
        printError('FILE1 (n=%i) and FILE2 (n=%i) must have the same number of records.' \
                 % (head_count, tail_count))

    # Setup for reference alignment
    if cmd_name in ('reference', 'sequential'):
        ref_file = assemble_args.pop('ref_file')
        db_exec = assemble_args.pop('db_exec')

        # Build reference sequence dictionary
        assemble_args['ref_dict'] = readReferenceFile(ref_file)

        # Build reference database files
        try:
            db_func = {
                'blastn': makeBlastnDb,
                'usearch': makeUBlastDb
            }[assemble_args['aligner']]
            ref_db, db_handle = db_func(ref_file, db_exec)
            assemble_args['ref_db'] = ref_db
        except:
            printError('Error building reference database for aligner %s with executable %s.' \
                       % (assemble_args['aligner'], db_exec))

    # Define feeder function and arguments
    feed_func = feedPairQueue
    feed_args = {
        'seq_file_1': head_file,
        'seq_file_2': tail_file,
        'coord_type': coord_type,
        'delimiter': out_args['delimiter']
    }
    # Define worker function and arguments
    process_args = {
        'assemble_func': assemble_func,
        'assemble_args': assemble_args,
        'rc': rc,
        'fields_1': head_fields,
        'fields_2': tail_fields,
        'delimiter': out_args['delimiter']
    }
    work_func = processSeqQueue
    work_args = {'process_func': assemblyWorker, 'process_args': process_args}
    # Define collector function and arguments
    collect_func = collectPairQueue
    collect_args = {
        'seq_file_1': head_file,
        'seq_file_2': tail_file,
        'label': 'assemble',
        'out_file': out_file,
        'out_args': out_args
    }

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func, feed_args,
                             work_args, collect_args, nproc, queue_size)

    # Close reference database handle
    if cmd_name in ('reference', 'sequential'):
        try:
            db_handle.close()
        except AttributeError:
            db_handle.cleanup()
        except:
            printError('Cannot close reference database file.')

    # Print log
    log = OrderedDict()
    log['OUTPUT'] = result['log'].pop('OUTPUT')
    for k, v in result['log'].items():
        log[k] = v
    log['END'] = 'AssemblePairs'
    printLog(log)

    return result['out_files']
Example #18
0
def clusterSets(seq_file,
                ident=default_cluster_ident,
                length_ratio=default_length_ratio,
                seq_start=0,
                seq_end=None,
                set_field=default_barcode_field,
                cluster_field=default_cluster_field,
                cluster_prefix=default_cluster_prefix,
                cluster_tool=default_cluster_tool,
                cluster_exec=default_cluster_exec,
                out_file=None,
                out_args=default_out_args,
                nproc=None,
                queue_size=None):
    """
    Performs clustering on sets of sequences

    Arguments:
      seq_file : the sample sequence file name.
      ident : the identity threshold for clustering sequences.
      length_ratio : minimum short/long length ratio allowed within a cluster.
      seq_start : the start position to trim sequences at before clustering.
      seq_end : the end position to trim sequences at before clustering.
      set_field : the annotation containing set IDs.
      cluster_field : the name of the output cluster field.
      cluster_prefix : string defining a prefix for the cluster identifier.
      cluster_tool : the clustering tool to use; one of cd-hit or usearch.
      cluster_exec : the path to the clustering executable.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.
      nproc : the number of processQueue processes;
              if None defaults to the number of CPUs.
      queue_size : maximum size of the argument queue;
                   if None defaults to 2*nproc.

    Returns:
      list: a list of successful output file names.
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'ClusterSets'
    log['COMMAND'] = 'set'
    log['FILE'] = os.path.basename(seq_file)
    log['IDENTITY'] = ident
    log['SEQUENCE_START'] = seq_start
    log['SEQUENCE_END'] = seq_end
    log['SET_FIELD'] = set_field
    log['CLUSTER_FIELD'] = cluster_field
    log['CLUSTER_PREFIX'] = cluster_prefix
    log['CLUSTER_TOOL'] = cluster_tool
    log['NPROC'] = nproc
    printLog(log)

    # Set cluster tool
    try:
        cluster_func = map_cluster_tool.get(cluster_tool)
    except:
        printError('Invalid clustering tool %s.' % cluster_tool)

    # Check the minimum identity
    if ident < min_cluster_ident[cluster_tool]:
        printError('Minimum identity %s too low for clustering tool %s.' %
                   (str(ident), cluster_tool))

    # Define cluster function parameters
    cluster_args = {
        'cluster_exec': cluster_exec,
        'ident': ident,
        'length_ratio': length_ratio,
        'seq_start': seq_start,
        'seq_end': seq_end
    }

    # Define feeder function and arguments
    index_args = {'field': set_field, 'delimiter': out_args['delimiter']}
    feed_func = feedSeqQueue
    feed_args = {
        'seq_file': seq_file,
        'index_func': indexSeqSets,
        'index_args': index_args
    }
    # Define worker function and arguments
    work_func = processQueue
    work_args = {
        'cluster_func': cluster_func,
        'cluster_args': cluster_args,
        'cluster_field': cluster_field,
        'cluster_prefix': cluster_prefix,
        'delimiter': out_args['delimiter']
    }
    # Define collector function and arguments
    collect_func = collectSeqQueue
    collect_args = {
        'seq_file': seq_file,
        'label': 'cluster',
        'out_file': out_file,
        'out_args': out_args,
        'index_field': set_field
    }

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func, feed_args,
                             work_args, collect_args, nproc, queue_size)

    # Print log
    log = OrderedDict()
    log['OUTPUT'] = result['log'].pop('OUTPUT')
    for k, v in result['log'].items():
        log[k] = v
    log['END'] = 'ClusterSets'
    printLog(log)

    return result['out_files']
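
A minimal usage sketch for clusterSets, assuming the function is importable from the installed pRESTO package; the file name, identity threshold, and annotation field names below are illustrative only.

# Hypothetical invocation; 'sample_barcoded.fastq' and the field names are examples.
out_files = clusterSets('sample_barcoded.fastq',
                        ident=0.9,
                        set_field='BARCODE',
                        cluster_field='CLUSTER',
                        cluster_tool='usearch',
                        nproc=4)
print(out_files)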
Example #19
0
File: SplitSeq.py Project: avilella/presto
def selectSeqFile(seq_file,
                  field,
                  value_list=None,
                  value_file=None,
                  negate=False,
                  out_file=None,
                  out_args=default_out_args):
    """
    Select from a sequence file

    Arguments:
      seq_file : filename of the sequence file to sample from.
      field : the annotation field to check for required values.
      value_list : a list of annotation values that a sample must contain one of.
      value_file : a tab delimited file containing values to select.
      negate : if True select entries that do not contain the specified values.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
      str: output file name.
    """

    # Reads value_file
    def _read_file(value_file, field):
        field_list = []
        try:
            with open(value_file, 'rt') as handle:
                reader_dict = csv.DictReader(handle, dialect='excel-tab')
                for row in reader_dict:
                    field_list.append(row[field])
        except IOError:
            printError('File %s cannot be read.' % value_file)
        except:
            printError('File %s is invalid.' % value_file)

        return field_list

    # Print console log
    log = OrderedDict()
    log['START'] = 'SplitSeq'
    log['COMMAND'] = 'select'
    log['FILE'] = os.path.basename(seq_file)
    log['FIELD'] = field
    if value_list is not None:
        log['VALUE_LIST'] = ','.join([str(x) for x in value_list])
    if value_file is not None:
        log['VALUE_FILE'] = os.path.basename(value_file)
    log['NOT'] = negate
    printLog(log)

    # Read value_file
    if value_list is not None and value_file is not None:
        printError('Specify only one of value_list and value_file.')
    elif value_file is not None:
        value_list = _read_file(value_file, field)

    # Read sequence file
    in_type = getFileType(seq_file)
    seq_iter = readSeqFile(seq_file)
    if out_args['out_type'] is None: out_args['out_type'] = in_type

    # Open output handle
    if out_file is not None:
        out_handle = open(out_file, 'w')
    else:
        out_handle = getOutputHandle(seq_file,
                                     'selected',
                                     out_dir=out_args['out_dir'],
                                     out_name=out_args['out_name'],
                                     out_type=out_args['out_type'])

    # Generate subset of records
    start_time = time()
    pass_count, fail_count, rec_count = 0, 0, 0
    value_set = set(value_list)
    for rec in seq_iter:
        printCount(rec_count, 1e5, start_time=start_time)
        rec_count += 1

        # Parse annotations into a list of values
        ann = parseAnnotation(rec.description,
                              delimiter=out_args['delimiter'])[field]
        ann = ann.split(out_args['delimiter'][2])

        # Write records passing the selection filter
        if xor(negate, not value_set.isdisjoint(ann)):
            SeqIO.write(rec, out_handle, out_args['out_type'])
            pass_count += 1
        else:
            fail_count += 1

    printCount(rec_count, 1e5, start_time=start_time, end=True)

    # Print log
    log = OrderedDict()
    log['OUTPUT'] = os.path.basename(out_handle.name)
    log['PASS'] = pass_count
    log['FAIL'] = fail_count
    log['END'] = 'SplitSeq'
    printLog(log)

    return out_handle.name
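
A hedged usage sketch for selectSeqFile; the input file, annotation field, and values are hypothetical stand-ins.

# Keep only records whose SAMPLE annotation contains one of the listed values;
# pass negate=True to keep everything else instead.
selected_file = selectSeqFile('reads.fastq',
                              field='SAMPLE',
                              value_list=['S1', 'S2'])
print(selected_file)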
Example #20
0
def convertToAIRR(db_file, format=default_format,
                  out_file=None, out_args=default_out_args):
    """
    Converts a Change-O formatted file into an AIRR formatted file

    Arguments:
      db_file : the database file name.
      format : input format.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
     str : output file name
    """
    log = OrderedDict()
    log['START'] = 'ConvertDb'
    log['COMMAND'] = 'airr'
    log['FILE'] = os.path.basename(db_file)
    printLog(log)

    # Define format operators
    try:
        reader, __, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)

    # Open input
    db_handle = open(db_file, 'rt')
    db_iter = reader(db_handle)

    # Set output fields replacing length with end fields
    in_fields = [schema.toReceptor(f) for f in db_iter.fields]
    out_fields = []
    for f in in_fields:
        if f in ReceptorData.length_fields and ReceptorData.length_fields[f][0] in in_fields:
            out_fields.append(ReceptorData.length_fields[f][1])
        out_fields.append(f)
    out_fields = list(OrderedDict.fromkeys(out_fields))
    out_fields = [AIRRSchema.fromReceptor(f) for f in out_fields]

    # Open output writer
    if out_file is not None:
        pass_handle = open(out_file, 'w')
    else:
        pass_handle = getOutputHandle(db_file, out_label='airr', out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'], out_type=AIRRSchema.out_type)
    pass_writer = AIRRWriter(pass_handle, fields=out_fields)

    # Count records
    result_count = countDbFile(db_file)

    # Iterate over records
    start_time = time()
    rec_count = 0
    for rec in db_iter:
        # Print progress for previous iteration
        printProgress(rec_count, result_count, 0.05, start_time=start_time)
        rec_count += 1
        # Write records
        pass_writer.writeReceptor(rec)

    # Print counts
    printProgress(rec_count, result_count, 0.05, start_time=start_time)
    log = OrderedDict()
    log['OUTPUT'] = os.path.basename(pass_handle.name)
    log['RECORDS'] = rec_count
    log['END'] = 'ConvertDb'
    printLog(log)

    # Close file handles
    pass_handle.close()
    db_handle.close()

    return pass_handle.name
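
A brief usage sketch for convertToAIRR, assuming a Change-O tab-delimited database as input; the file name is illustrative.

# Convert a Change-O style table to an AIRR-formatted file alongside the input.
airr_file = convertToAIRR('sample_db-pass.tab', format='changeo')
print(airr_file)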
Example #21
0
def parseIgBLAST(aligner_file, seq_file, repo, amino_acid=False, cellranger_file=None, partial=False,
                 asis_id=True, asis_calls=False, extended=False, regions='default',
                 format='changeo', out_file=None, out_args=default_out_args):
    """
    Main for IgBLAST aligned sample sequences.

    Arguments:
      aligner_file (str): IgBLAST output file to process.
      seq_file (str): fasta file input to IgBlast (from which to get sequence).
      repo (str): folder with germline repertoire files.
      amino_acid (bool): if True then the IgBLAST output files are results from igblastp. igblastn is assumed if False.
      cellranger_file (str): path to a supplementary cellranger annotation table to merge into the output; if None no additional annotations are added.
      partial (bool): if True put incomplete alignments in the pass file.
      asis_id (bool): if ID is to be parsed for pRESTO output with default delimiters.
      asis_calls (bool): if True do not parse gene calls for allele names.
      extended (bool): if True add alignment scores, FWR regions, and CDR regions to the output.
      regions (str): name of the IMGT FWR/CDR region definitions to use.
      format (str): output format. one of 'changeo' or 'airr'.
      out_file (str): output file name. Automatically generated from the input file if None.
      out_args (dict): common output argument dictionary from parseCommonArgs.

    Returns:
      dict : names of the 'pass' and 'fail' output files.
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'MakeDB'
    log['COMMAND'] = 'igblast-aa' if amino_acid else 'igblast'
    log['ALIGNER_FILE'] = os.path.basename(aligner_file)
    log['SEQ_FILE'] = os.path.basename(seq_file)
    log['ASIS_ID'] = asis_id
    log['ASIS_CALLS'] = asis_calls
    log['PARTIAL'] = partial
    log['EXTENDED'] = extended
    printLog(log)

    # Set amino acid conditions
    if amino_acid:
        format = '%s-aa' % format
        parser = IgBLASTReaderAA
    else:
        parser = IgBLASTReader

    # Start
    start_time = time()
    printMessage('Loading files', start_time=start_time, width=20)

    # Count records in sequence file
    total_count = countSeqFile(seq_file)

    # Get input sequence dictionary
    seq_dict = getSeqDict(seq_file)

    # Create germline repo dictionary
    references = readGermlines(repo, asis=asis_calls)

    # Load supplementary annotation table
    if cellranger_file is not None:
        f = cellranger_extended if extended else cellranger_base
        annotations = readCellRanger(cellranger_file, fields=f)
    else:
        annotations = None

    printMessage('Done', start_time=start_time, end=True, width=20)

    # Check for IMGT-gaps in germlines
    if all('...' not in x for x in references.values()):
        printWarning('Germline reference sequences do not appear to contain IMGT-numbering spacers. Results may be incorrect.')

    # Define format operators
    try:
        __, writer, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)
    out_args['out_type'] = schema.out_type

    # Define output fields
    fields = list(schema.required)
    if extended:
        custom = parser.customFields(schema=schema)
        fields.extend(custom)

    # Parse and write output
    with open(aligner_file, 'r') as f:
        parse_iter = parser(f, seq_dict, references, regions=regions, asis_calls=asis_calls)
        germ_iter = (addGermline(x, references, amino_acid=amino_acid) for x in parse_iter)
        output = writeDb(germ_iter, fields=fields, aligner_file=aligner_file, total_count=total_count, 
                         annotations=annotations, amino_acid=amino_acid, partial=partial, asis_id=asis_id,
                         regions=regions, writer=writer, out_file=out_file, out_args=out_args)

    return output
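
A hedged invocation sketch for parseIgBLAST; the IgBLAST output, query FASTA, and germline directory paths are placeholders for whatever an actual pipeline produced.

# Parse igblastn output into an AIRR database, adding extended alignment columns.
output = parseIgBLAST('sample_igblast.fmt7',
                      seq_file='sample.fasta',
                      repo='germlines/imgt/human/vdj',
                      extended=True,
                      format='airr')
print(output)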
Example #22
0
def alignRecords(db_file,
                 seq_fields,
                 group_func,
                 align_func,
                 group_args={},
                 align_args={},
                 format='changeo',
                 out_file=None,
                 out_args=default_out_args,
                 nproc=None,
                 queue_size=None):
    """
    Performs a multiple alignment on sets of sequences

    Arguments: 
      db_file : filename of the input database.
      seq_fields : the sequence fields to multiple align.
      group_func : function to use to group records.
      align_func : function to use to multiple align sequence groups.
      group_args : dictionary of arguments to pass to group_func.
      align_args : dictionary of arguments to pass to align_func.
      format : output format. One of 'changeo' or 'airr'.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.
      nproc : the number of processQueue processes.
              if None defaults to the number of CPUs.
      queue_size : maximum size of the argument queue.
                   if None defaults to 2*nproc.
                      
    Returns: 
      dict : names of the 'pass' and 'fail' output files.
    """
    # Define subcommand label dictionary
    cmd_dict = {
        alignAcross: 'across',
        alignWithin: 'within',
        alignBlocks: 'block'
    }

    # Print parameter info
    log = OrderedDict()
    log['START'] = 'AlignRecords'
    log['COMMAND'] = cmd_dict.get(align_func, align_func.__name__)
    log['FILE'] = os.path.basename(db_file)
    log['SEQ_FIELDS'] = ','.join(seq_fields)
    if 'group_fields' in group_args:
        log['GROUP_FIELDS'] = ','.join(group_args['group_fields'])
    if 'mode' in group_args: log['MODE'] = group_args['mode']
    if 'action' in group_args: log['ACTION'] = group_args['action']
    log['NPROC'] = nproc
    printLog(log)

    # Define format operators
    try:
        reader, writer, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)

    # Define feeder function and arguments
    if 'group_fields' in group_args and group_args['group_fields'] is not None:
        group_args['group_fields'] = [
            schema.toReceptor(f) for f in group_args['group_fields']
        ]
    feed_func = feedDbQueue
    feed_args = {
        'db_file': db_file,
        'reader': reader,
        'group_func': group_func,
        'group_args': group_args
    }
    # Define worker function and arguments
    field_map = OrderedDict([(schema.toReceptor(f), '%s_align' % f)
                             for f in seq_fields])
    align_args['field_map'] = field_map
    work_func = processDbQueue
    work_args = {'process_func': align_func, 'process_args': align_args}
    # Define collector function and arguments
    out_fields = getDbFields(db_file,
                             add=list(field_map.values()),
                             reader=reader)
    out_args['out_type'] = schema.out_type
    collect_func = collectDbQueue
    collect_args = {
        'db_file': db_file,
        'label': 'align',
        'fields': out_fields,
        'writer': writer,
        'out_file': out_file,
        'out_args': out_args
    }

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func, feed_args,
                             work_args, collect_args, nproc, queue_size)

    # Print log
    result['log']['END'] = 'AlignRecords'
    printLog(result['log'])
    output = {k: v for k, v in result.items() if k in ('pass', 'fail')}

    return output
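
A usage sketch for alignRecords. alignWithin is one of the alignment functions named in cmd_dict above; groupByFields is a hypothetical stand-in for whatever record-grouping helper the module provides, and the column names assume AIRR naming.

# Hypothetical call; 'groupByFields' is an assumed name for the grouping helper.
output = alignRecords('sample_db-pass.tsv',
                      seq_fields=['sequence_alignment'],
                      group_func=groupByFields,
                      align_func=alignWithin,
                      group_args={'group_fields': ['v_call']},
                      align_args={},
                      format='airr',
                      nproc=2)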
Example #23
0
def buildConsensus(seq_file,
                   barcode_field=default_barcode_field,
                   min_count=default_consensus_min_count,
                   min_freq=default_consensus_min_freq,
                   min_qual=default_consensus_min_qual,
                   primer_field=None,
                   primer_freq=None,
                   max_gap=None,
                   max_error=None,
                   max_diversity=None,
                   copy_fields=None,
                   copy_actions=None,
                   dependent=False,
                   out_file=None,
                   out_args=default_out_args,
                   nproc=None,
                   queue_size=None):
    """
    Generates consensus sequences

    Arguments: 
      seq_file : the sample sequence file name
      barcode_field : the annotation field containing set IDs
      min_count : threshold number of sequences to define a consensus
      min_freq : the frequency cutoff to assign a base
      min_qual : the quality cutoff to assign a base
      primer_field : the annotation field containing primer tags;
                     if None do not annotate with primer tags
      primer_freq : the minimum primer tag frequency that must be met to build a consensus;
                    if None do not filter by primer frequency
      max_gap : the maximum frequency of (., -) characters allowed before
                deleting a position; if None do not delete positions
      max_error : a threshold defining the maximum allowed error rate to retain a read group;
                  if None do not calculate error rate
      max_diversity : a threshold defining the average pairwise error rate required to retain a read group;
                      if None do not calculate diversity
      dependent : if False treat barcode group sequences as independent data
      copy_fields : a list of annotations to copy into consensus sequence annotations;
                    if None no additional annotations will be copied
      copy_actions : the list of actions to take for each copy_fields;
                     one of ['set', 'majority', 'min', 'max', 'sum']
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.
      nproc : the number of processQueue processes;
            if None defaults to the number of CPUs
      queue_size : maximum size of the argument queue;
                 if None defaults to 2*nproc
                    
    Returns: 
      list : a list of successful output file names.
    """
    # Print parameter info
    log = OrderedDict()
    log['START'] = 'BuildConsensus'
    log['FILE'] = os.path.basename(seq_file)
    log['BARCODE_FIELD'] = barcode_field
    log['MIN_COUNT'] = min_count
    log['MIN_FREQUENCY'] = min_freq
    log['MIN_QUALITY'] = min_qual
    log['MAX_GAP'] = max_gap
    log['PRIMER_FIELD'] = primer_field
    log['PRIMER_FREQUENCY'] = primer_freq
    log['MAX_ERROR'] = max_error
    log['MAX_DIVERSITY'] = max_diversity
    log['DEPENDENT'] = dependent
    log['COPY_FIELDS'] = ','.join(
        copy_fields) if copy_fields is not None else None
    log['COPY_ACTIONS'] = ','.join(
        copy_actions) if copy_actions is not None else None
    log['NPROC'] = nproc
    printLog(log)

    # Set consensus building function
    in_type = getFileType(seq_file)
    if in_type == 'fastq':
        cons_func = qualityConsensus
        cons_args = {
            'min_qual': min_qual,
            'min_freq': min_freq,
            'dependent': dependent
        }
    elif in_type == 'fasta':
        cons_func = frequencyConsensus
        cons_args = {'min_freq': min_freq}
    else:
        printError('Input file must be FASTA or FASTQ.')

    # Define feeder function and arguments
    index_args = {'field': barcode_field, 'delimiter': out_args['delimiter']}
    feed_func = feedSeqQueue
    feed_args = {
        'seq_file': seq_file,
        'index_func': indexSeqSets,
        'index_args': index_args
    }
    # Define worker function and arguments
    work_func = processQueue
    work_args = {
        'cons_func': cons_func,
        'cons_args': cons_args,
        'min_count': min_count,
        'primer_field': primer_field,
        'primer_freq': primer_freq,
        'max_gap': max_gap,
        'max_error': max_error,
        'max_diversity': max_diversity,
        'copy_fields': copy_fields,
        'copy_actions': copy_actions,
        'delimiter': out_args['delimiter']
    }
    # Define collector function and arguments
    collect_func = collectSeqQueue
    collect_args = {
        'seq_file': seq_file,
        'label': 'consensus',
        'out_file': out_file,
        'out_args': out_args,
        'index_field': barcode_field
    }

    # Call process manager
    result = manageProcesses(feed_func, work_func, collect_func, feed_args,
                             work_args, collect_args, nproc, queue_size)

    # Print log
    result['log']['END'] = 'BuildConsensus'
    printLog(result['log'])

    return result['out_files']
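
A minimal usage sketch for buildConsensus; the file name and thresholds are illustrative.

# Build UMI consensus sequences, requiring at least two reads per barcode group
# and discarding groups whose error rate against the consensus exceeds 0.1.
out_files = buildConsensus('sample_barcoded.fastq',
                           barcode_field='BARCODE',
                           min_count=2,
                           max_error=0.1,
                           nproc=4)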
Example #24
0
def clusterBarcodes(seq_file,
                    ident=default_cluster_ident,
                    length_ratio=default_length_ratio,
                    barcode_field=default_barcode_field,
                    cluster_field=default_cluster_field,
                    cluster_prefix=default_cluster_prefix,
                    cluster_tool=default_cluster_tool,
                    cluster_exec=default_cluster_exec,
                    out_file=None,
                    out_args=default_out_args,
                    nproc=None):
    """
    Performs clustering on sets of sequences

    Arguments:
      seq_file : the sample sequence file name.
      ident : the identity threshold for clustering sequences.
      length_ratio : minimum short/long length ratio allowed within a cluster.
      barcode_field : the annotation field containing barcode sequences.
      cluster_field : the name of the output cluster field.
      cluster_prefix : string defining a prefix for the cluster identifier.
      cluster_tool : the clustering tool to use; one of cd-hit or usearch.
      cluster_exec : the path to the clustering executable.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : output arguments.
      nproc : the number of processQueue processes;
              if None defaults to the number of CPUs.

    Returns:
      str: the clustered output file name
    """

    # Function to modify SeqRecord header with cluster identifier
    def _header(seq,
                cluster,
                field=cluster_field,
                prefix=cluster_prefix,
                delimiter=out_args['delimiter']):
        label = '%s%i' % (prefix, cluster)
        header = parseAnnotation(seq.description, delimiter=delimiter)
        header = mergeAnnotation(header, {field: label}, delimiter=delimiter)
        seq.id = seq.name = flattenAnnotation(header, delimiter=delimiter)
        seq.description = ''
        return seq

    # Function to make a SeqRecord object from a barcode annotation
    def _barcode(seq, field=barcode_field, delimiter=out_args['delimiter']):
        header = parseAnnotation(seq.description, delimiter=delimiter)
        return SeqRecord(Seq(header[field]), id=seq.id)

    # Print parameter info
    log = OrderedDict()
    log['START'] = 'ClusterSets'
    log['COMMAND'] = 'barcode'
    log['FILE'] = os.path.basename(seq_file)
    log['IDENTITY'] = ident
    log['BARCODE_FIELD'] = barcode_field
    log['CLUSTER_FIELD'] = cluster_field
    log['CLUSTER_PREFIX'] = cluster_prefix
    log['CLUSTER_TOOL'] = cluster_tool
    log['NPROC'] = nproc
    printLog(log)

    # Set cluster tool
    try:
        cluster_func = map_cluster_tool.get(cluster_tool)
    except:
        printError('Invalid clustering tool %s.' % cluster_tool)

    # Check the minimum identity
    if ident < min_cluster_ident[cluster_tool]:
        printError('Minimum identity %s too low for clustering tool %s.' %
                   (str(ident), cluster_tool))

    # Count sequence file and parse into a list of SeqRecords
    result_count = countSeqFile(seq_file)
    barcode_iter = (_barcode(x) for x in readSeqFile(seq_file))

    # Perform clustering
    start_time = time()
    printMessage('Running %s' % cluster_tool, start_time=start_time, width=25)
    cluster_dict = cluster_func(barcode_iter,
                                ident=ident,
                                length_ratio=length_ratio,
                                seq_start=0,
                                seq_end=None,
                                threads=nproc,
                                cluster_exec=cluster_exec)
    printMessage('Done', start_time=start_time, end=True, width=25)

    # Determine file type
    if out_args['out_type'] is None:
        out_args['out_type'] = getFileType(seq_file)

    # Open output file handles
    if out_file is not None:
        pass_handle = open(out_file, 'w')
    else:
        pass_handle = getOutputHandle(seq_file,
                                      'cluster-pass',
                                      out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'],
                                      out_type=out_args['out_type'])

    # Open indexed sequence file
    seq_dict = readSeqFile(seq_file, index=True)

    # Iterate over sequence records and update header with cluster annotation
    start_time = time()
    rec_count = pass_count = 0
    for cluster, id_list in cluster_dict.items():
        printProgress(rec_count, result_count, 0.05, start_time=start_time)
        rec_count += len(id_list)

        # TODO:  make a generator. Figure out how to get pass_count updated
        # Define output sequences
        seq_output = [_header(seq_dict[x], cluster) for x in id_list]

        # Write output
        pass_count += len(seq_output)
        SeqIO.write(seq_output, pass_handle, out_args['out_type'])

    # Update progress
    printProgress(rec_count, result_count, 0.05, start_time=start_time)

    # Print log
    log = OrderedDict()
    log['OUTPUT'] = os.path.basename(pass_handle.name)
    log['CLUSTERS'] = len(cluster_dict)
    log['SEQUENCES'] = result_count
    log['PASS'] = pass_count
    log['FAIL'] = rec_count - pass_count
    log['END'] = 'ClusterSets'
    printLog(log)

    # Close handles
    pass_handle.close()

    return pass_handle.name
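
A hedged usage sketch for clusterBarcodes; the path and annotation names are examples only.

# Cluster barcode (UMI) annotations at 90% identity and record the cluster ID
# in a new CLUSTER annotation field.
clustered_file = clusterBarcodes('sample_barcoded.fastq',
                                 ident=0.9,
                                 barcode_field='BARCODE',
                                 cluster_field='CLUSTER',
                                 cluster_tool='cd-hit',
                                 nproc=4)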
Example #25
0
def insertGaps(db_file, references=None, format=default_format,
               out_file=None, out_args=default_out_args):
    """
    Inserts IMGT numbering into V fields

    Arguments:
      db_file : the database file name.
      references : folder with germline repertoire files. If None, do not update alignment columns with IMGT gaps.
      format : input format.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
     str : output file name
    """
    log = OrderedDict()
    log['START'] = 'ConvertDb'
    log['COMMAND'] = 'imgt'
    log['FILE'] = os.path.basename(db_file)
    printLog(log)

    # Define format operators
    try:
        reader, writer, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)

    # Open input
    db_handle = open(db_file, 'rt')
    db_iter = reader(db_handle)

    # Check for required columns
    try:
        required = ['sequence_imgt', 'v_germ_start_imgt']
        checkFields(required, db_iter.fields, schema=schema)
    except LookupError as e:
        printError(e)

    # Load references
    reference_dict = readGermlines(references)

    # Check for IMGT-gaps in germlines
    if all('...' not in x for x in reference_dict.values()):
        printWarning('Germline reference sequences do not appear to contain IMGT-numbering spacers. Results may be incorrect.')

    # Open output writer
    if out_file is not None:
        pass_handle = open(out_file, 'w')
    else:
        pass_handle = getOutputHandle(db_file, out_label='gap', out_dir=out_args['out_dir'],
                                      out_name=out_args['out_name'], out_type=schema.out_type)
    pass_writer = writer(pass_handle, fields=db_iter.fields)

    # Count records
    result_count = countDbFile(db_file)

    # Iterate over records
    start_time = time()
    rec_count = pass_count = 0
    for rec in db_iter:
        # Print progress for previous iteration
        printProgress(rec_count, result_count, 0.05, start_time=start_time)
        rec_count += 1
        # Update IMGT fields
        imgt_dict = correctIMGTFields(rec, reference_dict)
        # Write records
        if imgt_dict is not None:
            pass_count += 1
            rec.setDict(imgt_dict, parse=False)
            pass_writer.writeReceptor(rec)

    # Print counts
    printProgress(rec_count, result_count, 0.05, start_time=start_time)
    log = OrderedDict()
    log['OUTPUT'] = os.path.basename(pass_handle.name)
    log['RECORDS'] = rec_count
    log['PASS'] = pass_count
    log['FAIL'] = rec_count - pass_count
    log['END'] = 'ConvertDb'
    printLog(log)

    # Close file handles
    pass_handle.close()
    db_handle.close()

    return pass_handle.name
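
A short usage sketch for insertGaps, assuming a Change-O database and a directory of IMGT-gapped germline FASTA files; both paths are placeholders.

# Re-insert IMGT numbering gaps using a local germline repertoire.
gapped_file = insertGaps('sample_db-pass.tab',
                         references='germlines/imgt/human/vdj',
                         format='changeo')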
Example #26
0
def processQueue(alive,
                 data_queue,
                 result_queue,
                 cons_func,
                 cons_args={},
                 min_count=default_consensus_min_count,
                 primer_field=None,
                 primer_freq=None,
                 max_gap=None,
                 max_error=None,
                 max_diversity=None,
                 copy_fields=None,
                 copy_actions=None,
                 delimiter=default_delimiter):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments:
      alive : a multiprocessing.Value boolean controlling whether processing
              continues; when False function returns.
      data_queue : a multiprocessing.Queue holding data to process.
      result_queue : a multiprocessing.Queue to hold processed results.
      cons_func : the function to use for consensus generation.
      cons_args : a dictionary of optional arguments for the consensus function.
      min_count : threshold number of sequences to define a consensus.
      primer_field : the annotation field containing primer names;
                     if None do not annotate with primer names.
      primer_freq : the minimum primer frequency that must be met to build a consensus;
                    if None do not filter by primer frequency.
      max_gap : the maximum frequency of (., -) characters allowed before
                deleting a position; if None do not delete positions.
      max_error : the maximum allowed error rate to retain a set;
                  if None do not calculate error rate.
      max_diversity : a threshold defining the average pairwise error rate required to retain a read group;
                      if None do not calculate diversity.
      copy_fields : a list of annotations to copy into consensus sequence annotations;
                    if None no additional annotations will be copied.
      copy_actions : the list of actions to take for each copy_fields;
                     one of ['set', 'majority', 'min', 'max', 'sum'].
      delimiter : a tuple of delimiters for (annotations, field/values, value lists).

    Returns: 
      None
    """
    try:
        # Iterator over data queue until sentinel object reached
        while alive.value:
            # Get data from queue
            if data_queue.empty(): continue
            else: data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None: break

            # Define result dictionary for iteration
            result = SeqResult(data.id, data.data)
            result.log['BARCODE'] = data.id
            result.log['SEQCOUNT'] = len(data)

            # Define primer annotations and consensus primer if applicable
            if primer_field is None:
                primer_ann = None
                seq_list = data.data
            else:
                # Calculate consensus primer
                primer_ann = OrderedDict()
                prcons = annotationConsensus(data.data,
                                             primer_field,
                                             delimiter=delimiter)
                result.log['PRIMER'] = ','.join(prcons['set'])
                result.log['PRCOUNT'] = ','.join(
                    [str(c) for c in prcons['count']])
                result.log['PRCONS'] = prcons['cons']
                result.log['PRFREQ'] = prcons['freq']
                if primer_freq is None:
                    # Retain full sequence set if not in primer consensus mode
                    seq_list = data.data
                    primer_ann = mergeAnnotation(primer_ann,
                                                 {'PRIMER': prcons['set']},
                                                 delimiter=delimiter)
                    primer_ann = mergeAnnotation(primer_ann,
                                                 {'PRCOUNT': prcons['count']},
                                                 delimiter=delimiter)
                elif prcons['freq'] >= primer_freq:
                    # Define consensus subset
                    seq_list = subsetSeqSet(data.data,
                                            primer_field,
                                            prcons['cons'],
                                            delimiter=delimiter)
                    primer_ann = mergeAnnotation(primer_ann,
                                                 {'PRCONS': prcons['cons']},
                                                 delimiter=delimiter)
                    primer_ann = mergeAnnotation(primer_ann,
                                                 {'PRFREQ': prcons['freq']},
                                                 delimiter=delimiter)
                else:
                    # If set fails primer consensus, feed result queue and continue
                    result_queue.put(result)
                    continue

            # Check count threshold
            cons_count = len(seq_list)
            result.log['CONSCOUNT'] = cons_count
            if cons_count < min_count:
                #print(cons_count, min_count)
                # If set fails count threshold, feed result queue and continue
                result_queue.put(result)
                continue

            # Update log with input sequences
            for i, s in enumerate(seq_list):
                result.log['INSEQ%i' % (i + 1)] = str(s.seq)

            # If primer and count filters pass, generate consensus sequence
            consensus = cons_func(seq_list, **cons_args)

            # Delete positions with gap frequency over max_gap and update log with consensus
            if max_gap is not None:
                gap_positions = set(findGapPositions(seq_list, max_gap))
                result.log['CONSENSUS'] = ''.join([' ' if i in gap_positions else x \
                                                   for i, x in enumerate(consensus.seq)])
                if 'phred_quality' in consensus.letter_annotations:
                    result.log['QUALITY'] = ''.join([' ' if i in gap_positions else chr(q + 33) \
                                                     for i, q in enumerate(consensus.letter_annotations['phred_quality'])])
                consensus = deleteSeqPositions(consensus, gap_positions)
            else:
                gap_positions = None
                result.log['CONSENSUS'] = str(consensus.seq)
                if 'phred_quality' in consensus.letter_annotations:
                    result.log['QUALITY'] = ''.join([
                        chr(q + 33)
                        for q in consensus.letter_annotations['phred_quality']
                    ])

            # Calculate nucleotide diversity
            if max_diversity is not None:
                diversity = calculateDiversity(seq_list)
                result.log['DIVERSITY'] = diversity

                # If diversity exceeds threshold, feed result queue and continue
                if diversity > max_diversity:
                    result_queue.put(result)
                    continue

            # Calculate set error against consensus
            if max_error is not None:
                # Delete positions if required and calculate error
                if gap_positions is not None:
                    seq_check = [
                        deleteSeqPositions(s, gap_positions) for s in seq_list
                    ]
                else:
                    seq_check = seq_list
                error = calculateSetError(seq_check, consensus)
                result.log['ERROR'] = error

                # If error exceeds threshold, feed result queue and continue
                if error > max_error:
                    result_queue.put(result)
                    continue

            # TODO:  should move this into an improved annotationConsensus function with an action argument
            # Parse copy_field annotations and define consensus annotations
            if copy_fields is not None and copy_actions is not None:
                copy_ann = OrderedDict()
                for f, act in zip(copy_fields, copy_actions):
                    # Numeric operations
                    if act == 'min':
                        vals = getAnnotationValues(seq_list,
                                                   f,
                                                   delimiter=delimiter)
                        copy_ann[f] = '%.12g' % min(
                            [float(x or 0) for x in vals])
                    elif act == 'max':
                        vals = getAnnotationValues(seq_list,
                                                   f,
                                                   delimiter=delimiter)
                        copy_ann[f] = '%.12g' % max(
                            [float(x or 0) for x in vals])
                    elif act == 'sum':
                        vals = getAnnotationValues(seq_list,
                                                   f,
                                                   delimiter=delimiter)
                        copy_ann[f] = '%.12g' % sum(
                            [float(x or 0) for x in vals])
                    elif act == 'set':
                        vals = annotationConsensus(seq_list,
                                                   f,
                                                   delimiter=delimiter)
                        copy_ann[f] = vals['set']
                        copy_ann['%s_COUNT' % f] = vals['count']
                    elif act == 'majority':
                        vals = annotationConsensus(seq_list,
                                                   f,
                                                   delimiter=delimiter)
                        copy_ann[f] = vals['cons']
                        copy_ann['%s_FREQ' % f] = vals['freq']
            else:
                copy_ann = None

            # Define annotation for output sequence
            cons_ann = OrderedDict([('ID', data.id),
                                    ('CONSCOUNT', cons_count)])

            # Merge additional consensus annotations into the output sequence annotations
            if primer_ann is not None:
                cons_ann = mergeAnnotation(cons_ann,
                                           primer_ann,
                                           delimiter=delimiter)
            if copy_ann is not None:
                cons_ann = mergeAnnotation(cons_ann,
                                           copy_ann,
                                           delimiter=delimiter)

            # Add output sequence annotations to consensus sequence
            consensus.id = consensus.name = flattenAnnotation(
                cons_ann, delimiter=delimiter)
            consensus.description = ''
            result.results = consensus
            result.valid = True

            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s> Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        printError('Error processing sequence set with ID: %s.' % data.id, exit=False)
        raise

    return None
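
The worker above relies on the sentinel protocol used throughout these examples: a feeder puts sequence sets on data_queue and terminates each worker with a None sentinel, while alive is a shared flag any process can clear to signal failure. A minimal wiring sketch, assuming seq_sets is an iterable of SeqData-like objects (with .id and .data) such as the indexing feeder would produce; normally manageProcesses does all of this.

import multiprocessing as mp

alive = mp.Value('b', True)                 # shared flag; workers stop when False
data_queue, result_queue = mp.Queue(), mp.Queue()

worker = mp.Process(target=processQueue,
                    args=(alive, data_queue, result_queue, qualityConsensus),
                    kwargs={'cons_args': {'min_qual': 20}, 'min_count': 2})
worker.start()

for seq_set in seq_sets:                    # assumed SeqData-like barcode groups
    data_queue.put(seq_set)
data_queue.put(None)                        # sentinel: tells the worker to exit
worker.join()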
Example #27
0
def convertToGenbank(db_file, inference=None, db_xref=None, molecule=default_molecule,
                     product=default_product, features=None, c_field=None, label=None,
                     count_field=None, index_field=None, allow_stop=False,
                     asis_id=False, asis_calls=False, allele_delim=default_allele_delim,
                     build_asn=False, asn_template=None, tbl2asn_exec=default_tbl2asn_exec,
                     format=default_format, out_file=None,
                     out_args=default_out_args):
    """
    Builds GenBank submission fasta and table files

    Arguments:
      db_file : the database file name.
      inference : reference alignment tool.
      db_xref : reference database link.
      molecule : source molecule (e.g., "mRNA", "genomic DNA").
      product : Product (protein) name.
      features : dictionary of sample features (BioSample attributes) to add to the description of each record.
      c_field : column containing the C region gene call.
      label : a string to use as a label for the ID; if None do not add a field label.
      count_field : field name to populate the AIRR_READ_COUNT note.
      index_field : field name to populate the AIRR_CELL_INDEX note.
      allow_stop : if True retain records with junctions having stop codons.
      asis_id : if True use the original sequence ID for the output IDs.
      asis_calls : if True do not parse gene calls for IMGT nomenclature.
      allele_delim : delimiter separating the gene name from the allele number when asis_calls=True.
      build_asn : if True run tbl2asn on the generated .tbl and .fsa files.
      asn_template : template file (.sbt) to pass to tbl2asn.
      tbl2asn_exec : name of or path to the tbl2asn executable.
      format : input and output format.
      out_file : output file name without extension. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
      tuple : the output (feature table, fasta) file names.
    """
    log = OrderedDict()
    log['START'] = 'ConvertDb'
    log['COMMAND'] = 'genbank'
    log['FILE'] = os.path.basename(db_file)
    printLog(log)

    # Define format operators
    try:
        reader, __, schema = getFormatOperators(format)
    except ValueError:
        printError('Invalid format %s.' % format)

    # Open input
    db_handle = open(db_file, 'rt')
    db_iter = reader(db_handle)

    # Check for required columns
    try:
        required = ['sequence_input',
                    'v_call', 'd_call', 'j_call',
                    'v_seq_start', 'd_seq_start', 'j_seq_start']
        checkFields(required, db_iter.fields, schema=schema)
    except LookupError as e:
        printError(e)

    # Open output
    if out_file is not None:
        out_name, __ = os.path.splitext(out_file)
        fsa_handle = open('%s.fsa' % out_name, 'w')
        tbl_handle = open('%s.tbl' % out_name, 'w')
    else:
        fsa_handle = getOutputHandle(db_file, out_label='genbank', out_dir=out_args['out_dir'],
                                     out_name=out_args['out_name'], out_type='fsa')
        tbl_handle = getOutputHandle(db_file, out_label='genbank', out_dir=out_args['out_dir'],
                                     out_name=out_args['out_name'], out_type='tbl')

    # Count records
    result_count = countDbFile(db_file)

    # Define writer
    writer = csv.writer(tbl_handle, delimiter='\t', quoting=csv.QUOTE_NONE)

    # Iterate over records
    start_time = time()
    rec_count, pass_count, fail_count = 0, 0, 0
    for rec in db_iter:
        # Print progress for previous iteration
        printProgress(rec_count, result_count, 0.05, start_time=start_time)
        rec_count += 1

        # Extract table dictionary
        name = None if asis_id else rec_count
        seq = makeGenbankSequence(rec, name=name, label=label, count_field=count_field, index_field=index_field,
                                  molecule=molecule, features=features)
        tbl = makeGenbankFeatures(rec, start=seq['start'], end=seq['end'], product=product,
                                  db_xref=db_xref, inference=inference, c_field=c_field,
                                  allow_stop=allow_stop, asis_calls=asis_calls, allele_delim=allele_delim)

        if tbl is not None:
            pass_count += 1
            # Write table
            writer.writerow(['>Features', seq['record'].id])
            for feature, qualifiers in tbl.items():
                writer.writerow(feature)
                if qualifiers:
                    for x in qualifiers:
                        writer.writerow(list(chain(['', '', ''], x)))

            # Write sequence
            SeqIO.write(seq['record'], fsa_handle, 'fasta')
        else:
            fail_count += 1

    # Final progress bar
    printProgress(rec_count, result_count, 0.05, start_time=start_time)

    # Run tbl2asn
    if build_asn:
        start_time = time()
        printMessage('Running tbl2asn', start_time=start_time, width=25)
        result = runASN(fsa_handle.name, template=asn_template, exec=tbl2asn_exec)
        printMessage('Done', start_time=start_time, end=True, width=25)

    # Print ending console log
    log = OrderedDict()
    log['OUTPUT_TBL'] = os.path.basename(tbl_handle.name)
    log['OUTPUT_FSA'] = os.path.basename(fsa_handle.name)
    log['RECORDS'] = rec_count
    log['PASS'] = pass_count
    log['FAIL'] = fail_count
    log['END'] = 'ConvertDb'
    printLog(log)

    # Close file handles
    tbl_handle.close()
    fsa_handle.close()
    db_handle.close()

    return (tbl_handle.name, fsa_handle.name)
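
A hedged usage sketch for convertToGenbank; the inference string, db_xref, and product values are illustrative and should match the actual submission metadata.

# Build .tbl/.fsa GenBank submission files from an AIRR database.
tbl_file, fsa_file = convertToGenbank('sample_db-pass.tsv',
                                      inference='IgBLAST:1.17',
                                      db_xref='IMGT/GENE-DB',
                                      molecule='mRNA',
                                      product='immunoglobulin heavy chain variable region',
                                      format='airr')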
Example #28
0
def processQueue(alive, data_queue, result_queue, align_func, align_args={},
                 calc_div=False, delimiter=default_delimiter):
    """
    Pulls from data queue, performs calculations, and feeds results queue

    Arguments: 
      alive : a multiprocessing.Value boolean controlling whether processing
              continues; when False function returns
      data_queue : a multiprocessing.Queue holding data to process
      result_queue : a multiprocessing.Queue to hold processed results
      align_func : the function to use for alignment
      align_args : a dictionary of optional arguments for the alignment function
      calc_div : if True perform diversity calculation
      delimiter : a tuple of delimiters for (annotations, field/values, value lists)

    Returns:
      None
    """
    try:
        # Iterator over data queue until sentinel object reached
        while alive.value:
            # Get data from queue
            if data_queue.empty():  continue
            else:  data = data_queue.get()
            # Exit upon reaching sentinel
            if data is None:  break
            
            # Define result object
            result = SeqResult(data.id, data.data)
            result.log['BARCODE'] = data.id
            result.log['SEQCOUNT'] = len(data)
    
            # Perform alignment
            seq_list = data.data
            align_list = align_func(seq_list, **align_args)
    
            # Process alignment
            if align_list is not None:
                # Calculate diversity
                if calc_div:
                    diversity = calculateDiversity(align_list)
                    result.log['DIVERSITY'] = diversity
                
                # Restore quality scores
                has_quality = hasattr(seq_list[0], 'letter_annotations') and \
                              'phred_quality' in seq_list[0].letter_annotations
                if has_quality:
                    qual_dict = {seq.id:seq.letter_annotations['phred_quality'] \
                                 for seq in seq_list}
                    for seq in align_list:
                        qual = deque(qual_dict[seq.id])
                        qual_new = [0 if c == '-' else qual.popleft() for c in seq.seq]
                        seq.letter_annotations['phred_quality'] = qual_new
    
                # Add alignment to log
                if 'field' in align_args:
                    for i, seq in enumerate(align_list):
                        ann = parseAnnotation(seq.description, delimiter=delimiter)
                        primer = ann[align_args['field']]
                        result.log['ALIGN%i:%s' % (i + 1, primer)] = seq.seq
                else:
                    for i, seq in enumerate(align_list):  
                        result.log['ALIGN%i' % (i + 1)] = seq.seq
                
                # Add alignment to results
                result.results = align_list
                result.valid = True
                        
            # Feed results to result queue
            result_queue.put(result)
        else:
            sys.stderr.write('PID %s> Error in sibling process detected. Cleaning up.\n' \
                             % os.getpid())
            return None
    except:
        alive.value = False
        printError('Error processing sequence set with ID: %s.' % data.id, exit=False)
        raise
    
    return None
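
The quality-restoration step in the worker above is a small trick worth isolating: original phred scores are consumed left to right with a deque, and alignment gap columns receive a quality of 0. A standalone sketch of the same idea for Biopython SeqRecord objects; restore_quality is a hypothetical helper name.

from collections import deque

def restore_quality(original, aligned):
    # Copy phred scores from the ungapped record onto its gapped alignment;
    # gap characters get a placeholder quality of 0.
    qual = deque(original.letter_annotations['phred_quality'])
    aligned.letter_annotations['phred_quality'] = [
        0 if c == '-' else qual.popleft() for c in str(aligned.seq)
    ]
    return aligned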
Example #29
0
def parseCommonArgs(args, in_arg=None, in_types=None):
    """
    Checks common arguments from getCommonArgParser and transforms output options to a dictionary

    Arguments: 
      args : Argument Namespace defined by ArgumentParser.parse_args
      in_arg : String defining a non-standard input file argument to verify;
               by default ['db_files', 'seq_files', 'seq_files_1', 'seq_files_2', 'primer_file']
               are supported in that order
      in_types : List of types (file extensions as strings) to allow for files in file_arg
                 if None do not check type
                    
    Returns:
      dict : Dictionary copy of args with output arguments embedded in the dictionary out_args
    """ 
    db_types = ['tab']
    seq_types = ['fasta', 'fastq']
    primer_types = ['fasta']
    if in_types is not None:  in_types = [f.lower() for f in in_types]
    args_dict = args.__dict__.copy()
    
    # Count input files
    if 'seq_files' in args_dict:
        input_count = len(args_dict['seq_files']  or [])
        input_files = args_dict['seq_files']
    elif all([k in args_dict for k in ('seq_files_1', 'seq_files_2')]):
        input_count = len(args_dict['seq_files_1']  or [])
        input_files = args_dict['seq_files_1'] + args_dict['seq_files_2']
    elif 'db_files' in args_dict:
        input_count = len(args_dict['db_files'] or [])
        input_files = args_dict['db_files']
    elif 'primer_file' in args_dict:
        input_count = 1
        input_files = args_dict['primer_file']
    elif in_arg is not None and in_arg in args_dict: 
        input_count = len(args_dict[in_arg] or [])
        input_files = args_dict[in_arg]
    else:
        printError('Cannot determine input file argument.')

    # Exit if output names or log files are specified with multiple input files    
    if args_dict.get('out_name', None) is not None and input_count > 1:
        printError('The --outname argument may not be specified with multiple input files.')
    if args_dict.get('log_file', None) is not None and input_count > 1:
        printError('The --log argument may not be specified with multiple input files.')

    # Verify single-end sequence files
    if 'seq_files' in args_dict and args_dict['seq_files']:
        for f in args_dict['seq_files']:
            if not os.path.isfile(f):
                printError('Sequence file %s does not exist.' % f)
            if getFileType(f) not in seq_types:
                printError('Sequence file %s is not a supported type. Must be one of: %s.' \
                           % (f, ', '.join(seq_types)))
    
    # Verify paired-end sequence files
    if all([k in args_dict and args_dict[k] for k in ('seq_files_1', 'seq_files_2')]):
        if len(args_dict['seq_files_1']) != len(args_dict['seq_files_2']):
            printError('The -1 and -2 arguments must contain the same number of files.')
        for f1, f2 in zip(args_dict['seq_files_1'], args_dict['seq_files_2']):
            if getFileType(f1) != getFileType(f2):
                printError('Each pair of files in the -1 and -2 arguments must be the same file type.')
        for f in (args_dict['seq_files_1'] + args_dict['seq_files_2']):
            if not os.path.isfile(f):
                printError('Sequence file %s does not exist.' % f)
            if getFileType(f) not in seq_types:
                printError('Sequence file %s is not a supported type. Must be one of: %s.' \
                           % (f, ', '.join(seq_types)))

    # Verify database files
    if 'db_files' in args_dict and args_dict['db_files']:
        for f in args_dict['db_files']:
            if not os.path.isfile(f):
                printError('Database file %s does not exist.' % f)
            if getFileType(f) not in db_types:
                printError('Database file %s is not a supported type. Must be one of: %s.' \
                           % (f, ', '.join(db_types)))

    # Verify primer file
    if 'primer_file' in args_dict and args_dict['primer_file']:
        primer_file = args_dict['primer_file']
        if not os.path.isfile(primer_file):
            printError('Primer file %s does not exist.' % primer_file)
        if getFileType(primer_file) not in primer_types:
            printError('Primer file %s is not a supported type. Must be one of: %s.' \
                       % (primer_file, ', '.join(primer_types)))

    # Verify non-standard input files
    if in_arg is not None and in_arg in args_dict and args_dict[in_arg]:
        files = args_dict[in_arg] if isinstance(args_dict[in_arg], list) \
                else [args_dict[in_arg]]
        for f in files:
            if not os.path.exists(f):
                printError('Input %s does not exist.' % f)
            if in_types is not None and getFileType(f) not in in_types:
                printError('Input %s is not a supported type. Must be one of: %s.' \
                           % (f, ', '.join(in_types)))
    
    # Verify output file arguments and exit if anything is hinky
    if args_dict.get('out_files', None) is not None \
            or args_dict.get('out_file', None) is not None:
        if args_dict.get('out_dir', None) is not None:
            printError('The -o argument may not be specified with the --outdir argument.')
        if args_dict.get('out_name', None) is not None:
            printError('The -o argument may not be specified with the --outname argument.')
        if args_dict.get('failed', False):
            printError('The -o argument may not be specified with the --failed argument.')
    if args_dict.get('out_files', None) is not None:
        if len(args_dict['out_files']) != input_count:
            printError('The -o argument requires one output file name per input file.')
        for f in args_dict['out_files']:
            if f in input_files:
                printError('Output files and input files cannot have the same names.')
        for f in args_dict['out_files']:
            if os.path.isfile(f):
                printWarning('Output file %s already exists and will be overwritten.' % f)
    if args_dict.get('out_file', None) is not None:
        if args_dict['out_file'] in input_files:
            printError('Output files and input files cannot have the same names.')
        if os.path.isfile(args_dict['out_file']):
            printWarning('Output file %s already exists and will be overwritten.' % args_dict['out_file'])

    # Verify output directory
    if 'out_dir' in args_dict and args_dict['out_dir']:
        if os.path.exists(args_dict['out_dir']) and not os.path.isdir(args_dict['out_dir']):
            printError('Output path %s exists but is not a directory.' % args_dict['out_dir'])

    # Redefine common output options as out_args dictionary
    out_args = ['log_file', 'delimiter', 'separator', 
                'out_dir', 'out_name', 'out_type', 'failed']
    args_dict['out_args'] = {k:args_dict.setdefault(k, None) for k in out_args}
    for k in out_args: del args_dict[k]
    
    return args_dict
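
For reference, a minimal sketch (using a hypothetical args_dict, not taken from the presto source) of how the final step above folds the common output options into a single 'out_args' dictionary and strips them from the top-level arguments:

# Hypothetical parsed-argument dictionary; only 'seq_files' is tool specific.
hypothetical_args = {
    'seq_files': ['reads.fastq'],
    'log_file': None,
    'delimiter': ('|', '=', ','),
    'separator': ',',
    'out_dir': '.',
    'out_name': None,
    'out_type': 'fastq',
    'failed': False,
}

# Same consolidation as above: collect the common output options ...
out_args = ['log_file', 'delimiter', 'separator',
            'out_dir', 'out_name', 'out_type', 'failed']
hypothetical_args['out_args'] = {k: hypothetical_args.setdefault(k, None) for k in out_args}
# ... and remove them from the top level.
for k in out_args:
    del hypothetical_args[k]

# hypothetical_args is now:
# {'seq_files': ['reads.fastq'],
#  'out_args': {'log_file': None, 'delimiter': ('|', '=', ','), 'separator': ',',
#               'out_dir': '.', 'out_name': None, 'out_type': 'fastq', 'failed': False}}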
Example #30
0
def writeDb(records, fields, aligner_file, total_count, id_dict=None, annotations=None,
            amino_acid=False, partial=False, asis_id=True, regions='default',
            writer=AIRRWriter, out_file=None, out_args=default_out_args):
    """
    Writes parsed records to an output file
    
    Arguments: 
      records : an iterator of Receptor objects containing alignment data.
      fields : a list of ordered field names to write.
      aligner_file : input file name.
      total_count : number of records (for progress bar).
      id_dict : a dictionary of the truncated sequence ID mapped to the full sequence ID.
      annotations : additional annotation dictionary.
      amino_acid : if True, perform validation on amino acid fields.
      partial : if True, include incomplete alignments in the pass file.
      asis_id : if True, use the sequence ID as-is; if False, parse pRESTO-style annotations from the ID using the default delimiters.
      regions (str): name of the IMGT FWR/CDR region definitions to use.
      writer : writer class.
      out_file : output file name. Automatically generated from the input file if None.
      out_args : common output argument dictionary from parseCommonArgs.

    Returns:
      None
    """
    # Wrapper for opening handles and writers
    def _open(x, f, writer=writer, out_file=out_file):
        if out_file is not None and x == 'pass':
            handle = open(out_file, 'w')
        else:
            handle = getOutputHandle(aligner_file,
                                     out_label='db-%s' % x,
                                     out_dir=out_args['out_dir'],
                                     out_name=out_args['out_name'],
                                     out_type=out_args['out_type'])
        return handle, writer(handle, fields=f)
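    # Note: when out_file is not given, getOutputHandle builds the output name
    # from the aligner_file basename plus the 'db-pass'/'db-fail' label
    # (e.g. roughly 'sample_db-pass.tsv', depending on out_name and out_type).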

    # Functions to convert FASTA header annotations to Change-O or AIRR columns
    def _changeo(f, header):
        h = [ChangeoSchema.fromReceptor(x) for x in header if x.upper() not in f]
        f.extend(h)
        return f

    def _airr(f, header):
        h = [AIRRSchema.fromReceptor(x) for x in header if x.lower() not in f]
        f.extend(h)
        return f

    # Function to verify IMGT-gapped sequence and junction concur
    def _imgt_check(rec):
        try:
            if amino_acid:
                rd = RegionDefinition(rec.junction_aa_length, amino_acid=amino_acid, definition=regions)
                x, y = rd.positions['junction']
                check = (rec.junction_aa == rec.sequence_aa_imgt[x:y])
            else:
                rd = RegionDefinition(rec.junction_length, amino_acid=amino_acid, definition=regions)
                x, y = rd.positions['junction']
                check = (rec.junction == rec.sequence_imgt[x:y])
        except (TypeError, AttributeError):
            check = False
        return check
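    # With the default IMGT region definitions the junction begins at position
    # 310 of the IMGT-numbered nucleotide sequence, which is the position cited
    # in the log message below when this check fails.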

    # Function to check for valid records strictly
    def _strict(rec):
        if amino_acid:
            valid = [rec.v_call and rec.v_call != 'None',
                     rec.j_call and rec.j_call != 'None',
                     rec.functional is not None,
                     rec.sequence_aa_imgt,
                     rec.junction_aa,
                     _imgt_check(rec)]
        else:
            valid = [rec.v_call and rec.v_call != 'None',
                     rec.j_call and rec.j_call != 'None',
                     rec.functional is not None,
                     rec.sequence_imgt,
                     rec.junction,
                     _imgt_check(rec)]
        return all(valid)

    # Function to check for valid records loosely
    def _gentle(rec):
        valid = [rec.v_call and rec.v_call != 'None',
                 rec.d_call and rec.d_call != 'None',
                 rec.j_call and rec.j_call != 'None']
        return any(valid)
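    # In short: _strict requires V and J calls, a functionality flag, an
    # IMGT-gapped sequence, a junction, and agreement between the junction and
    # the IMGT-gapped sequence; _gentle requires only one of the V, D, or J calls.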

    # Set writer class and annotation conversion function
    if writer == ChangeoWriter:
        _annotate = _changeo
    elif writer == AIRRWriter:
        _annotate = _airr
    else:
        printError('Invalid output writer.')

    # Additional annotation (e.g. 10X cell calls)
    # _append_table = None
    # if cellranger_file is not None:
    #     with open(cellranger_file) as csv_file:
    #         # Read in annotation file (use Sniffer to discover file delimiters)
    #         dialect = csv.Sniffer().sniff(csv_file.readline())
    #         csv_file.seek(0)
    #         csv_reader = csv.DictReader(csv_file, dialect = dialect)
    #
    #         # Generate annotation dictionary
    #         anntab_dict = {entry['contig_id']: {cellranger_map[field]: entry[field] \
    #                        for field in cellranger_map.keys()} for entry in csv_reader}
    #
    #     fields = _annotate(fields, cellranger_map.values())
    #     _append_table = lambda sequence_id: anntab_dict[sequence_id]

    # Set pass criteria
    _pass = _gentle if partial else _strict

    # Define log handle
    if out_args['log_file'] is None:
        log_handle = None
    else:
        log_handle = open(out_args['log_file'], 'w')

    # Initialize handles, writers and counters
    pass_handle, pass_writer = None, None
    fail_handle, fail_writer = None, None
    pass_count, fail_count = 0, 0
    start_time = time()

    # Validate and write output
    printProgress(0, total_count, 0.05, start_time=start_time)
    for i, record in enumerate(records, start=1):
        # Replace sequence description with full string, if required
        if id_dict is not None and record.sequence_id in id_dict:
            record.sequence_id = id_dict[record.sequence_id]

        # Parse sequence description into new columns
        if not asis_id:
            try:
                ann_raw = parseAnnotation(record.sequence_id)
                record.sequence_id = ann_raw.pop('ID')

                # Convert to Receptor fields
                ann_parsed = OrderedDict()
                for k, v in ann_raw.items():
                    ann_parsed[ChangeoSchema.toReceptor(k)] = v

                # Add annotations to Receptor and update field list
                record.setDict(ann_parsed, parse=True)
                if i == 1:  fields = _annotate(fields, ann_parsed.keys())
            except IndexError:
                # Could not parse pRESTO-style annotations so fall back to no parse
                asis_id = True
                printWarning('Sequence annotation format not recognized. Sequence headers will not be parsed.')

        # Add supplemental annotation fields
        # if _append_table is not None:
        #     record.setDict(_append_table(record.sequence_id), parse=True)
        if annotations is not None:
            record.setDict(annotations[record.sequence_id], parse=True)
            if i == 1:  fields = _annotate(fields, annotations[record.sequence_id].keys())

        # Count pass or fail and write to appropriate file
        if _pass(record):
            pass_count += 1
            # Write row to pass file
            try:
                pass_writer.writeReceptor(record)
            except AttributeError:
                # Open pass file and writer
                pass_handle, pass_writer = _open('pass', fields)
                pass_writer.writeReceptor(record)
        else:
            fail_count += 1
            # Write row to fail file if specified
            if out_args['failed']:
                try:
                    fail_writer.writeReceptor(record)
                except AttributeError:
                    # Open fail file and writer
                    fail_handle, fail_writer = _open('fail', fields)
                    fail_writer.writeReceptor(record)

        # Write log
        if log_handle is not None:
            log = OrderedDict([('ID', record.sequence_id),
                               ('V_CALL', record.v_call),
                               ('D_CALL', record.d_call),
                               ('J_CALL', record.j_call),
                               ('PRODUCTIVE', record.functional)])
            if not _imgt_check(record) and not amino_acid:
                log['ERROR'] = 'Junction does not match the sequence starting at position 310 in the IMGT numbered V(D)J sequence.'
            printLog(log, log_handle)

        # Print progress
        printProgress(i, total_count, 0.05, start_time=start_time)

    # Print console log
    log = OrderedDict()
    log['OUTPUT'] = os.path.basename(pass_handle.name) if pass_handle is not None else None
    log['PASS'] = pass_count
    log['FAIL'] = fail_count
    log['END'] = 'MakeDb'
    printLog(log)

    # Close file handles
    output = {'pass': None, 'fail': None}
    if pass_handle is not None:
        output['pass'] = pass_handle.name
        pass_handle.close()
    if fail_handle is not None:
        output['fail'] = fail_handle.name
        fail_handle.close()

    return output
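
The pass/fail bookkeeping above opens output files lazily: both writers start as None, and the AttributeError raised on first use triggers opening the corresponding handle, so an empty pass or fail file is never created. A self-contained sketch of that pattern, using hypothetical names unrelated to the changeo API:

class LineWriter:
    """Hypothetical stand-in for a record writer."""
    def __init__(self, handle):
        self.handle = handle

    def write(self, text):
        self.handle.write(text + '\n')


def split_records(records, predicate, pass_name='pass.txt', fail_name='fail.txt'):
    """Write records to pass_name or fail_name, creating each file only when needed."""
    pass_handle, pass_writer = None, None
    fail_handle, fail_writer = None, None
    for rec in records:
        if predicate(rec):
            try:
                pass_writer.write(rec)
            except AttributeError:
                # First passing record: open the pass file, then retry the write
                pass_handle = open(pass_name, 'w')
                pass_writer = LineWriter(pass_handle)
                pass_writer.write(rec)
        else:
            try:
                fail_writer.write(rec)
            except AttributeError:
                # First failing record: open the fail file, then retry the write
                fail_handle = open(fail_name, 'w')
                fail_writer = LineWriter(fail_handle)
                fail_writer.write(rec)
    for handle in (pass_handle, fail_handle):
        if handle is not None:
            handle.close()


# Example: only pass.txt is created because every record passes.
split_records(['a', 'b'], predicate=lambda r: True)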