# Example 1
# File: db.py — Project: ajm/glutton
class GluttonDB(object) :
    """A glutton species database stored as a zip archive.

    The archive contains a JSON manifest, a JSON gene-family data file and
    (once built) one alignment/tree per gene family.  Open an existing
    database with GluttonDB(fname) or create one with GluttonDB().build().
    """

    def __init__(self, fname=None) :
        self.fname       = fname        # path to the .glt zip archive (may be None)
        self.compression = ZIP_DEFLATED
        self.metadata    = None     # manifest dict (see _initialise_db for keys)
        self.data        = None     # dict of famid -> GeneFamily obj (list of Genes)
        self.seq2famid   = None     # dict of geneid -> famid
        self.dirty       = False    # True when in-memory state needs a _write()
        self.lock        = threading.Lock()
        self.complete_jobs = 0
        self.total_jobs = 0

        self.log = get_log()

        if self.fname :
            self._read()

            # NOTE(review): is_complete() is defined outside this excerpt
            if not self.is_complete() :
                self.log.warn("%s is not complete!" % self.fname)

    @property
    def species(self) :
        """Species name recorded in the manifest."""
        return self.metadata['species-name']

    @property
    def release(self) :
        """Ensembl release number recorded in the manifest."""
        return self.metadata['species-release']

    @property
    def nucleotide(self) :
        """True if the database holds nucleotide (vs protein) sequences."""
        return self.metadata['nucleotide']

    @property
    def download_time(self) :
        """Unix timestamp of the original download."""
        return self.metadata['download-time']

    @property
    def version(self) :
        """Glutton version that created this database."""
        return self.metadata['glutton-version']

    @property
    def database(self) :
        """Name of the ensembl database the data came from."""
        return self.metadata['database-name']

    @property
    def filename(self) :
        """Path of the underlying archive file."""
        return self.fname

    @property
    def checksum(self) :
        """MD5 checksum of the archive file on disk."""
        return md5(self.fname)

    def stop(self) :
        """Stop the alignment work queue, if one was started."""
        if hasattr(self, "q") :
            self.q.stop()

    def flush(self) :
        """Write in-memory state back to disk if it has been modified."""
        if self.dirty :
            self._write()

    # read manifest to get data and mapping files
    # read data file to get seqID to gluttonID mapping
    # read mapping file to get gluttonID to file name mapping
    def _read(self) :
        """Load the manifest and gene-family data from the archive.

        Raises GluttonImportantFileNotFoundError if the manifest, or the
        data file it names, is missing from the archive.
        """
        z = ZipFile(self.fname, 'r', compression=self.compression)

        # bug fix: the archive handle used to leak if json parsing (or any
        # other step) raised; close it on every exit path instead
        try :
            def _err(msg) :
                raise GluttonImportantFileNotFoundError(msg)

            # without the manifest all is lost
            # we need this to get the names of the other
            # XML files
            if MANIFEST_FNAME not in z.namelist() :
                _err('manifest not found in %s' % self.fname)

            self.metadata = json.load(z.open(MANIFEST_FNAME))

            self.log.info("read manifest - created on %s using glutton version %.1f" % \
                (time.strftime('%d/%m/%y at %H:%M:%S', time.localtime(self.download_time)), \
                 self.version))

            # the data file is the raw data grouped into gene families
            # when we do a local alignment we need to get the gene id
            # of the best hit and find out which gene family it belongs to
            if self.metadata['data-file'] not in z.namelist() :
                _err('data file (%s) not found in %s' % (self.metadata['data-file'], self.fname))

            self.data = json_to_glutton(json.load(z.open(self.metadata['data-file'])))
            self.seq2famid = self._create_lookup_table(self.data)

            self.log.info("read %d gene families (%d genes)" % (len(self.data), len(self.seq2famid)))

        finally :
            z.close()

    def _create_lookup_table(self, families) :
        """Invert families: return a dict mapping gene id -> family id."""
        return dict((gene.id, fid) for fid in families for gene in families[fid])

    def _valid_manifest(self, m) :
        """Return True iff manifest m contains every required metadata key."""
        return all(k in m for k in metadata_keys)

    def _write_to_archive(self, data, zfile, zname) :
        """JSON-serialise data and store it in zfile under archive name zname."""
        fname = tmpfile()

        try :
            with open(fname, 'w') as f :
                f.write(json.dumps(data))

            zfile.write(fname, arcname=zname)
        finally :
            # remove the temporary file even if the archive write fails
            os.remove(fname)

    def _write(self) :
        """Append the manifest and gene-family data to the archive."""
        assert self._valid_manifest(self.metadata)

        z = ZipFile(self.fname, 'a', compression=self.compression)

        # bug fix: close the archive even if a write fails
        try :
            self._write_to_archive(self.metadata,               z, MANIFEST_FNAME)
            self._write_to_archive(glutton_to_json(self.data),  z, self.metadata['data-file'])
        finally :
            z.close()

        self.dirty = False

    def _default_datafile(self, species, release) :
        """Default archive member name for the gene-family data file."""
        return "%s_%d_data.json" % (species, release)

    def build(self, fname, species, release=None, database_name='ensembl', nucleotide=False, download_only=False) :
        """Download data for species and align every gene family.

        An existing fname is resumed rather than re-downloaded; a missing
        release means "use the latest"; download_only skips the alignment
        step.
        """
        self.fname = fname

        # if the name is specified and the file exists, then that means the 
        # download already took place and we should get:
        #   - database_name
        #   - release
        # from the metadata
        if self.fname and exists(self.fname) :
            self.log.info("%s exists, resuming..." % self.fname)

        else :
            # release not specified
            if not release :
                self.log.info("release not provided, getting latest release...")
                release = EnsemblDownloader().get_latest_release(species, database_name)
                self.log.info("latest release is %d" % release)

            # default name if it was not defined
            if not self.fname :
                self.fname = "%s_%d.glt" % (species, release)
                self.log.info("database filename not specified, using '%s'" % self.fname)

            # are we resuming or starting fresh?
            if not exists(self.fname) :
                self.log.info("%s does not exist, starting from scratch..." % self.fname)
                self._initialise_db(species, release, database_name, nucleotide)

        # either way, read contents into memory
        self._read()

        # not really necessary, but check that the species from cli and in the file
        # are the same + nucleotide
        if self.species != species :
            self.log.warn("species from CLI (%s) and glutton file (%s) do not match!" % (species, self.species))

        if release and (self.release != release) :
            self.log.warn("release from CLI (%d) and glutton file (%d) do not match!" % (release, self.release))

        if self.nucleotide != nucleotide :
            self.log.warn("nucleotide/protein from CLI (%s) and glutton file (%s) do not match!" % \
                ("nucleotide" if nucleotide else "protein", "nucleotide" if self.nucleotide else "protein"))

        # no work to do
        if self.is_complete() :
            self.log.info("%s is already complete!" % self.fname)
            return

        # don't do the analysis, just exit
        if download_only :
            self.log.info("download complete")
            return

        # build db
        self._perform_alignments()

        # write to disk
        self._write()

        self.log.info("finished building %s/%s" % (self.species, self.release))

    def _get_unaligned_families(self) :
        """Return ids of multi-gene families with no .tree file in the archive."""
        z = ZipFile(self.fname, 'r', compression=self.compression)

        try :
            # a family is considered aligned once its '<famid>.tree' exists
            aligned = set([ i.split('.')[0] for i in z.namelist() if i.endswith('.tree') ])
        finally :
            z.close()

        # single-gene families cannot be aligned, so skip them
        unaligned = [ i for i in self.data if (i not in aligned) and (len(self.data[i]) > 1) ]

        self.log.info("found %d unaligned gene families" % len(unaligned))

        return unaligned

    def _perform_alignments(self) :
        """Queue a prank alignment job for every unaligned gene family and wait."""
        unaligned = self._get_unaligned_families()

        if not hasattr(self, "q") :
            self.q = WorkQueue()

        self.total_jobs = len(unaligned)
        # start at -1 so the first _progress() call shows zero complete
        self.complete_jobs = -1
        self._progress()

        for i in unaligned :
            # NOTE(review): job_callback/_progress are defined outside this excerpt
            self.q.enqueue(PrankJob(self.job_callback, self.data[i]))

        self.log.debug("waiting for job queue to drain...")

        self.q.join()

    def _initialise_db(self, species, release, database_name, nucleotide) :
        """Download raw data from ensembl and write an initial archive."""
        e = EnsemblDownloader()
        self.log.info("downloading %s/%d" % (species, release))

        try :
            self.data = ensembl_to_glutton(e.download(species, release, database_name, nucleotide))

        # bug fix: 'except X, e' is python2-only syntax; 'as' works on 2.6+ and 3.x,
        # and str(ede) replaces the deprecated/py2-only 'ede.message'
        except EnsemblDownloadError as ede :
            self.log.fatal(str(ede))
            exit(1)

        # drosophila melanogastor - nucleotide - ensembl-main
        # contains transcripts, but not gene families
        count = sum(1 for famid in self.data if len(self.data[famid]) == 1)

        if count == len(self.data) :
            raise GluttonDBBuildError("downloaded %d gene families composed of a single gene each ('sql' method will do this on some species that do not contain all transcripts (e.g. drosophila_melanogaster in ensembl-main))" % count)

        self.metadata = {}
        self.metadata['download-time']      = time.time()

        # glutton metadata
        self.metadata['glutton-version']    = glutton.__version__
        self.metadata['program-name']       = Prank().name
        self.metadata['program-version']    = Prank().version
        self.metadata['species-name']       = species
        self.metadata['species-release']    = release
        self.metadata['nucleotide']         = nucleotide
        self.metadata['database-name']      = database_name
        self.metadata['download-method']    = get_ensembl_download_method()

        # other xml files
        self.metadata['data-file']          = self._default_datafile(species, release)

        self.dirty = True
        self._write()
# Example 2
class All_vs_all_search(object):
    """Assign each query gene to a database gene family by running blastx
    searches in batches on a work queue.

    Results accumulate in gene_assignments: query id -> (subject id,
    strand) for accepted hits, or None for queries with no acceptable hit.
    """

    def __init__(self, batch_size=100):
        self.nucleotide = False
        self.min_hitidentity = None     # minimum percent identity to accept a hit
        self.min_hitlength = None       # minimum alignment length to accept a hit
        self.max_evalue = None          # maximum e-value to accept a hit
        self.batch_size = batch_size    # queries per blast job
        self.log = get_log()
        self.cleanup_files = []         # blast index files to delete when done
        self.gene_assignments = {}
        self.lock = threading.Lock()    # guards gene_assignments and progress
        self.q = None

        self.total_jobs = 0
        self.complete_jobs = 0

    def _batch(self, x):
        """Yield items of iterable x in lists of at most self.batch_size.

        The final partial batch is yielded too; an empty input yields
        nothing.
        """
        tmp = []

        for i in x:
            tmp.append(i)

            if len(tmp) == self.batch_size:
                yield tmp
                tmp = []

        # bug fix: this used 'raise StopIteration' to suppress the final
        # yield, which PEP 479 (python 3.7+) turns into a RuntimeError
        # inside a generator -- simply not yielding an empty tail batch
        # is equivalent and works on every version
        if tmp:
            yield tmp

    def process(self, db, queries, nucleotide, min_hitidentity, min_hitlength, max_evalue):
        """Blast all queries against db and return the gene assignments.

        Builds the protein blast database, enqueues one blastx job per
        batch of queries, blocks until the queue drains, then removes the
        blast index files.
        """
        self.nucleotide = nucleotide
        self.min_hitidentity = min_hitidentity
        self.min_hitlength = min_hitlength
        self.max_evalue = max_evalue

        # we need to deal with the index files here because
        # all of the blastx jobs need them
        self.cleanup_files += [db + i for i in [".phr", ".pin", ".psq"]]

        # creates db + {phr,pin,psq} in same dir as db
        self.log.info("creating blast db...")
        Blast.makedb(db)  # XXX THIS IS ALWAYS PROTEIN, BECAUSE WE WANT TO RUN BLASTX

        # queue up the jobs
        self.log.info("starting local alignments...")
        self.q = WorkQueue()

        self.total_jobs = len(queries)
        # start one batch "behind" so the first _progress() call shows zero
        self.complete_jobs = -self.batch_size
        self._progress()

        for query in self._batch(queries):
            self.q.enqueue(BlastJob(self.job_callback, db, query, "blastx"))

        self.log.debug("waiting for job queue to drain...")
        self.q.join()

        rm_f(self.cleanup_files)

        return self.gene_assignments

    def stop(self):
        """Stop the work queue and delete the temporary blast index files."""
        if self.q:
            self.q.stop()

        rm_f(self.cleanup_files)

    def get_intermediate_results(self):
        """Return whatever assignments have been made so far."""
        return self.gene_assignments

    def _progress(self):
        """Advance the progress counter by one batch and redraw the line."""
        self.complete_jobs += self.batch_size

        # the final batch may be smaller than batch_size
        if self.complete_jobs > self.total_jobs:
            self.complete_jobs = self.total_jobs

        sys.stderr.write("\rProgress: %d / %d blastx alignments " % (self.complete_jobs, self.total_jobs))

        if self.complete_jobs == self.total_jobs:
            sys.stderr.write("\n")
            sys.stderr.flush()

    def job_callback(self, job):
        """WorkQueue callback: record the first acceptable hit per query.

        Queries with no acceptable hit are mapped to None so they are
        never re-examined.
        """
        self.log.debug("%d blast results returned" % len(job.results))

        # bug fix: acquire the lock as a context manager so it is released
        # even if result processing raises (was a bare acquire/release pair)
        with self.lock:
            self._progress()

            if job.success():
                for br in job.results:
                    # length = max(br.qstart, br.qend) - min(br.qstart, br.qend)
                    strand = "+" if br.qstart < br.qend else "-"

                    # keep only the first hit per query that passes the
                    # evalue / identity / length thresholds
                    if (
                        (br.qseqid in self.gene_assignments)
                        or (self.max_evalue < br.evalue)
                        or (self.min_hitidentity > br.pident)
                        or (self.min_hitlength > br.length)
                    ):
                        continue

                    self.gene_assignments[br.qseqid] = (br.sseqid, strand)

            for q in job.input:
                if q.id not in self.gene_assignments:
                    self.gene_assignments[q.id] = None