def fetch_ncbi_wiki_map(num_threads, batch_size, taxid_list,
                        taxid2wikidict):
    '''Fetch the taxonid -> wikipedia page mapping, one thread per batch.

    Reads taxids line by line from the file at `taxid_list` (a line equal
    to 'taxid' is treated as the header and skipped), groups them into
    batches of `batch_size`, and hands every batch to
    PipelineStepFetchTaxInfo.get_taxid_mapping_for_batch on its own thread.
    A semaphore initialised to `num_threads` is acquired before each thread
    is started and passed to the worker together with a shared lock and
    `taxid2wikidict`; the worker is presumably responsible for releasing
    it (not visible here).  Returns only after every thread has been joined.
    '''
    slots = threading.Semaphore(num_threads)
    mutex = TraceLock("fetch_ncbi_wiki_map", threading.RLock())
    workers = []

    def dispatch(taxids):
        # Blocks here until a slot is available.
        slots.acquire()
        worker = threading.Thread(
            target=PipelineStepFetchTaxInfo.get_taxid_mapping_for_batch,
            args=[taxids, taxid2wikidict, mutex, slots])
        worker.start()
        workers.append(worker)

    pending = []
    with open(taxid_list, 'r') as taxf:
        for raw_line in taxf:
            taxid = raw_line.rstrip()
            if taxid == 'taxid':
                continue  # header
            pending.append(taxid)
            if len(pending) >= batch_size:
                dispatch(pending)
                pending = []  # fresh list; the worker keeps the old one
    if pending:
        # Final, partially filled batch.
        dispatch(pending)
    for worker in workers:
        worker.join()
Exemple #2
0
def make_space(done={}, mutex=TraceLock("make_space", multiprocessing.RLock())):  # pylint: disable=dangerous-default-value
    """Call really_make_space() once, on a best-effort basis.

    The mutable defaults are deliberate: `done` persists between calls and
    records (under key 'time') when the attempt finished, so later calls
    return immediately; `mutex` serialises concurrent callers.

    Errors raised by really_make_space() are logged and swallowed so the
    caller is never failed by this housekeeping step.  Only `Exception`
    subclasses are swallowed: KeyboardInterrupt / SystemExit propagate and
    leave `done` unset.
    """
    with mutex:
        if done:
            return
        try:
            really_make_space()
        except Exception:  # best effort: log and carry on
            log.write("Error making space.  Please attend to this before instance storage fills up.")
            log.write(traceback.format_exc())
        # Recorded even after a logged failure, so the attempt is not repeated.
        done['time'] = time.time()
    def test_aquire_lock(mock_log_event):
        '''Test the events logged when a thread has to wait for a held lock'''
        lock = TraceLock('lock2')  # object under test
        holder_has_lock = threading.Event()

        def hold_lock_briefly():
            with lock:  # lock acquired by the helper thread
                holder_has_lock.set()
                time.sleep(0.5)  # keep the main thread waiting for a bit

        def trace(thread_name, state):
            # Expected log_event call for one lock state transition.
            return call('trace_lock',
                        values={
                            'lock_name': 'lock2',
                            'thread_name': thread_name,
                            'state': state
                        },
                        debug=True)

        holder = threading.Thread(target=hold_lock_briefly)
        holder.name = "Thread-1"
        holder.start()

        holder_has_lock.wait()  # wait until the helper thread owns the lock

        with lock:  # has to wait until the helper thread releases the lock
            pass

        mock_log_event.assert_has_calls([
            trace('Thread-1', 'acquired'),
            trace('MainThread', 'waiting'),
            trace('Thread-1', 'released'),
            trace('MainThread', 'acquired_after_wait'),
            trace('MainThread', 'released'),
        ])
Exemple #4
0
def install_s3mi(installed={}, mutex=TraceLock("install_s3mi", multiprocessing.RLock())):  # pylint: disable=dangerous-default-value
    """Make sure s3mi is installed and tuned, attempting it only once.

    The mutable defaults are deliberate: `installed` persists between calls
    and is stamped (key 'time') after the first attempt, so later calls
    return immediately; `mutex` serialises concurrent callers.

    The stamp is written in a `finally`, so a failed attempt is not retried;
    the exception from command.execute still propagates to the caller.

    The install URL uses https because GitHub no longer serves the
    unauthenticated git:// protocol.
    """
    with mutex:
        if installed:  # Mutable default value persists
            return
        try:
            # This is typically a no-op.
            command.execute(
                "which s3mi || pip install git+https://github.com/chanzuckerberg/s3mi.git"
            )
            command.execute(
                "s3mi tweak-vm || echo s3mi tweak-vm sometimes fails under docker. Continuing..."
            )
        finally:
            installed['time'] = time.time()
Exemple #5
0
    def start(self):
        """Plan the stage, start input downloads, run all steps and wait for them.

        Background threads are started (and not joined here) for prefetching
        large files and for fetching every step input flagged as
        's3_downloadable'.  Each planned step is then instantiated from its
        module/class name, started, and finally waited on.  If waiting on a
        step raises, all step instances are told to stop waiting and the
        exception is re-raised; an InvalidInputFileError additionally has its
        json written via write_invalid_input_json.
        """
        # Come up with the plan
        planned_steps, self.large_file_list, target_infos = self.plan()

        self.manage_reference_downloads_cache()
        threading.Thread(target=self.prefetch_large_files).start()

        # Download step inputs from s3 when necessary
        for planned in planned_steps:
            for input_target in planned["in"]:
                if not target_infos[input_target]['s3_downloadable']:
                    continue
                downloader = threading.Thread(target=self.fetch_target_from_s3,
                                              args=(input_target, ))
                downloader.start()

        self.create_status_json_file()

        # Initialize and start every step; all of them share one status lock
        status_lock = TraceLock(
            f"Step-level status updates for stage {self.name}",
            threading.RLock())
        running = []
        for planned in planned_steps:
            log.write("Initializing step %s" % planned["out"])
            step_module = importlib.import_module(planned["module"])
            step_cls = getattr(step_module, planned["class"])
            output_target = self.targets[planned["out"]]
            input_targets = [self.targets[name] for name in planned["in"]]
            instance = step_cls(planned["out"], input_targets, output_target,
                                self.output_dir_local,
                                self.output_dir_s3, self.ref_dir_local,
                                planned["additional_files"],
                                planned["additional_attributes"],
                                self.step_status_local, status_lock)
            instance.start()
            running.append(instance)

        # Wait for the steps in the order they were started
        for instance in running:
            try:
                instance.wait_until_all_done()
            except Exception as e:
                # Some exception thrown by one of the steps
                if isinstance(e, InvalidInputFileError):
                    self.write_invalid_input_json(e.json)
                traceback.print_exc()
                # notify the waiting step instances to self destruct
                for other in running:
                    other.stop_waiting()
                log.write("An exception was thrown. Stage failed.")
                raise e
        log.write("all steps are done")
Exemple #6
0
    def run_remotely(self, input_fas, output_m8):
        """Align `input_fas` chunk by chunk and concatenate the results into `output_m8`.

        The input is split with chunk_input (chunk size depends on
        self.alignment_algorithm), the chunks are processed in random order
        on one thread each, and self.chunks_in_flight_semaphore is acquired
        before every thread is started.  check_for_errors is called before
        each launch and once more after all threads have been joined.
        """
        # Split files into chunks for performance
        if self.alignment_algorithm == "gsnap":
            chunk_size = GSNAP_CHUNK_SIZE
        else:
            chunk_size = RAPSEARCH_CHUNK_SIZE
        part_suffix, input_chunks = self.chunk_input(input_fas, chunk_size)
        self.chunk_count = len(input_chunks)

        # One result slot per chunk; None means "not produced yet"
        results = [None] * self.chunk_count
        workers = []
        mutex = TraceLock("run_remotely", threading.RLock())

        # Randomize execution order for performance
        shuffled = list(enumerate(input_chunks))
        random.shuffle(shuffled)

        try:
            for index, chunk_files in shuffled:
                self.chunks_in_flight_semaphore.acquire()
                self.check_for_errors(mutex, results, input_chunks,
                                      self.alignment_algorithm)
                wrapper_kwargs = {
                    'chunks_in_flight_semaphore': self.chunks_in_flight_semaphore,
                    'chunk_output_files': results,
                    'n': index,
                    'mutex': mutex,
                    'target': self.run_chunk,
                    'kwargs': {
                        'part_suffix': part_suffix,
                        'input_files': chunk_files,
                        'lazy_run': False,
                    },
                }
                worker = threading.Thread(
                    target=PipelineStepRunAlignment.run_chunk_wrapper,
                    kwargs=wrapper_kwargs)
                worker.start()
                workers.append(worker)
        finally:
            # Always wait for the threads already started, even on error
            for worker in workers:
                worker.join()

        self.check_for_errors(mutex, results, input_chunks,
                              self.alignment_algorithm)

        assert None not in results
        # Concatenate the pieces and upload results
        self.concatenate_files(results, output_m8)
Exemple #7
0
class LongRunningCodeSection(Updater):
    """
    Periodically log a message while a long section of code is running.

    Every instance receives a sequential id taken from the class-level
    counter, which is read and incremented while holding the class-level lock.
    """
    lock = TraceLock("LongRunningCodeSection", multiprocessing.RLock())
    count = multiprocessing.Value('i', 0)

    def __init__(self, name, update_period=15):
        super().__init__(update_period, self.print_update)
        section_cls = LongRunningCodeSection
        with section_cls.lock:
            self.id = section_cls.count.value
            section_cls.count.value += 1
        self.name = name

    def print_update(self, t_elapsed):
        """Log that this code section is still active after `t_elapsed` seconds."""
        message = "LongRunningCodeSection %d (%s) still running after %3.1f seconds." % (
            self.id, self.name, t_elapsed)
        log.write(message)
    def test_aquire_free_lock(mock_log_event):
        '''Test the events logged when an uncontended lock is acquired'''
        lock = TraceLock('lock1')

        with lock:
            pass

        def trace(state):
            # Expected log_event call for one lock state on the main thread.
            return call('trace_lock',
                        values={
                            'state': state,
                            'lock_name': 'lock1',
                            'thread_name': 'MainThread'
                        },
                        debug=True)

        mock_log_event.assert_has_calls([trace('acquired'), trace('released')])
 def fetch_wiki_content(num_threads, taxid2wikidict, taxid2wikicontent,
                        id2namedict):
     '''Fetch wikipedia content for every entry of taxid2wikidict.

     For each (taxid, url) pair the page id is taken from a "curid=<digits>"
     fragment of the url when present, and the name is looked up in
     id2namedict.  Entries with neither a page id nor a name are skipped.
     Every remaining entry is handed to
     PipelineStepFetchTaxInfo.get_wiki_content_for_page on its own thread;
     a semaphore initialised to num_threads is acquired before each thread
     starts and passed to the worker (which presumably releases it - not
     visible here).  Returns once all threads have been joined.
     '''
     slots = threading.Semaphore(num_threads)
     mutex = TraceLock("fetch_wiki_content", threading.RLock())
     workers = []
     for taxid, url in taxid2wikidict.items():
         found = re.search(r"curid=(\d+)", url)
         pageid = found[1] if found else None
         name = id2namedict.get(taxid)
         if not (pageid or name):
             continue  # nothing to look the page up by
         slots.acquire()
         worker = threading.Thread(
             target=PipelineStepFetchTaxInfo.get_wiki_content_for_page,
             args=[taxid, pageid, name, taxid2wikicontent, mutex, slots])
         worker.start()
         workers.append(worker)
     for worker in workers:
         worker.join()
Exemple #10
0
    def run_remotely(self, input_fas, output_m8, service):
        """Align `input_fas` on remote machines chunk by chunk into `output_m8`.

        The input is split with chunk_input, the chunks are processed in
        random order on one thread each (self.chunks_in_flight is acquired
        before every launch), and the per-chunk outputs are concatenated.
        check_for_errors is called before each launch and again after all
        threads have been joined.  Chunk statistics are logged on a
        best-effort basis once the threads have finished.

        Args:
            input_fas: input handed to chunk_input.
            output_m8: destination handed to concatenate_files.
            service: "gsnap" or "rapsearch2"; selects the remote index dir.

        Raises:
            ValueError: if `service` is not one of the supported values
                (previously this surfaced later as an UnboundLocalError).
        """
        key_path = self.fetch_key(os.environ['KEY_PATH_S3'])
        sample_name = self.output_dir_s3.rstrip('/').replace('s3://',
                                                             '').replace(
                                                                 '/', '-')
        chunk_size = int(self.additional_attributes["chunk_size"])
        index_dir_suffix = self.additional_attributes.get("index_dir_suffix")
        remote_username = "******"
        remote_home_dir = os.path.join("/home", remote_username)
        if service == "gsnap":
            remote_index_dir = os.path.join(remote_home_dir, "share")
        elif service == "rapsearch2":
            remote_index_dir = os.path.join(remote_home_dir, "references",
                                            "nr_rapsearch")
        else:
            # Fail early and clearly instead of hitting an unbound local below.
            raise ValueError(
                f"Unsupported alignment service: {service!r} "
                "(expected 'gsnap' or 'rapsearch2')")

        if index_dir_suffix:
            remote_index_dir = os.path.join(remote_index_dir, index_dir_suffix)

        sample_remote_work_dir = os.path.join(remote_home_dir,
                                              "batch-pipeline-workdir",
                                              sample_name)

        # Split files into chunks for performance
        part_suffix, input_chunks = self.chunk_input(input_fas, chunk_size)

        # Process chunks; one result slot per chunk, None means "not done yet"
        chunk_output_files = [None] * len(input_chunks)
        chunk_threads = []
        mutex = TraceLock("run_remotely", threading.RLock())
        # Randomize execution order for performance
        randomized = list(enumerate(input_chunks))
        random.shuffle(randomized)

        try:
            for n, chunk_input_files in randomized:
                self.chunks_in_flight.acquire()
                self.check_for_errors(mutex, chunk_output_files, input_chunks,
                                      service)
                chunk_remote_work_dir = f"{sample_remote_work_dir}-chunk-{n}"
                t = threading.Thread(
                    target=PipelineStepRunAlignmentRemotely.run_chunk_wrapper,
                    args=[
                        self.chunks_in_flight, chunk_output_files, n, mutex,
                        self.run_chunk,
                        [
                            part_suffix, remote_home_dir, remote_index_dir,
                            chunk_remote_work_dir, remote_username,
                            chunk_input_files, key_path, service, True
                        ]
                    ])
                t.start()
                chunk_threads.append(t)

        finally:
            # Always wait for the threads already started, even on error
            for ct in chunk_threads:
                ct.join()
            try:
                chunk_status_tracker(service).log_stats(len(input_chunks))
            except Exception:
                # Reporting is best effort; never mask the real outcome, but
                # let KeyboardInterrupt / SystemExit through.
                log.write(f"Problem dumping status report for {service}")
                log.write(traceback.format_exc())

        self.check_for_errors(mutex, chunk_output_files, input_chunks, service)

        assert None not in chunk_output_files
        # Concatenate the pieces and upload results
        self.concatenate_files(chunk_output_files, output_m8)
    def run(self):
        '''
            Refine the alignment hits with assembled contigs:

            1. summarize hits
            2. build blast index
            3. blast assembled contigs to the index
            4. update the summary

            When the assembled contigs or the reference fasta are smaller than
            their minimum sizes, placeholder/copied outputs are written and the
            method returns early.
        '''
        inputs = self.input_files_local
        _align_m8, deduped_m8, hit_summary, orig_counts_with_dcr = inputs[0]
        assembled_contig, _assembled_scaffold, bowtie_sam, _contig_stats = inputs[1]
        [reference_fasta] = inputs[2]
        [duplicate_cluster_sizes_path] = inputs[3]

        (blast_m8, refined_m8, refined_hit_summary, refined_counts_with_dcr,
         contig_summary_json, blast_top_m8) = self.output_files_local()

        assert refined_counts_with_dcr.endswith("with_dcr.json"), self.output_files_local()
        assert orig_counts_with_dcr.endswith("with_dcr.json"), self.output_files_local()

        db_type = self.additional_attributes["db_type"]

        contigs_too_small = os.path.getsize(assembled_contig) < MIN_ASSEMBLED_CONTIG_SIZE
        if contigs_too_small or os.path.getsize(reference_fasta) < MIN_REF_FASTA_SIZE:
            # No assembled results or refseq fasta available.
            # Write placeholder outputs and pass the unrefined results through.
            command.write_text_to_file(' ', blast_m8)
            command.write_text_to_file(' ', blast_top_m8)
            command.copy_file(deduped_m8, refined_m8)
            command.copy_file(hit_summary, refined_hit_summary)
            command.copy_file(orig_counts_with_dcr, refined_counts_with_dcr)
            command.write_text_to_file('[]', contig_summary_json)
            return  # early exit: nothing to refine

        read_dict, accession_dict, _selected_genera = m8.summarize_hits(hit_summary)
        PipelineStepBlastContigs.run_blast(db_type, blast_m8, assembled_contig,
                                           reference_fasta, blast_top_m8)
        read2contig = {}  # filled in place by generate_info_from_sam
        PipelineStepRunAssembly.generate_info_from_sam(bowtie_sam, read2contig,
                                                       duplicate_cluster_sizes_path)

        refreshed = self.update_read_dict(read2contig, blast_top_m8, read_dict,
                                          accession_dict, db_type)
        updated_read_dict, read2blastm8, contig2lineage, added_reads = refreshed
        self.generate_m8_and_hit_summary(updated_read_dict, added_reads, read2blastm8,
                                         hit_summary, deduped_m8,
                                         refined_hit_summary, refined_m8)

        # Reference files needed to generate taxon counts from the updated results
        lineage_db = s3.fetch_reference(
            self.additional_files["lineage_db"],
            self.ref_dir_local,
            allow_s3mi=False)  # Too small to waste s3mi

        deuterostome_src = self.additional_files.get("deuterostome_db")
        if deuterostome_src:
            deuterostome_db = s3.fetch_reference(deuterostome_src, self.ref_dir_local,
                                                 allow_s3mi=False)  # Too small for s3mi
        else:
            deuterostome_db = None

        blacklist_s3_file = self.additional_files.get('taxon_blacklist', DEFAULT_BLACKLIST_S3)
        taxon_blacklist = s3.fetch_reference(blacklist_s3_file, self.ref_dir_local)

        taxon_whitelist = None
        if self.additional_attributes.get("use_taxon_whitelist"):
            whitelist_s3_file = self.additional_files.get("taxon_whitelist", DEFAULT_WHITELIST_S3)
            taxon_whitelist = s3.fetch_reference(whitelist_s3_file, self.ref_dir_local)

        # Taxon counts are generated while holding the class-level cya_lock
        with TraceLock("PipelineStepBlastContigs-CYA", PipelineStepBlastContigs.cya_lock, debug=False), \
                log.log_context("PipelineStepBlastContigs", {"substep": "generate_taxon_count_json_from_m8", "db_type": db_type, "refined_counts": refined_counts_with_dcr}):
            m8.generate_taxon_count_json_from_m8(refined_m8, refined_hit_summary, db_type.upper(),
                                                 lineage_db, deuterostome_db, taxon_whitelist, taxon_blacklist,
                                                 duplicate_cluster_sizes_path, refined_counts_with_dcr)

        # generate contig stats at genus/species level
        with log.log_context("PipelineStepBlastContigs", {"substep": "generate_taxon_summary"}):
            # same filter as applied in generate_taxon_count_json_from_m8
            should_keep = m8.build_should_keep_filter(deuterostome_db, taxon_whitelist, taxon_blacklist)
            contig_taxon_summary = self.generate_taxon_summary(
                read2contig,
                contig2lineage,
                updated_read_dict,
                added_reads,
                db_type,
                duplicate_cluster_sizes_path,
                should_keep,
            )

        with log.log_context("PipelineStepBlastContigs", {"substep": "generate_taxon_summary_json", "contig_summary_json": contig_summary_json}), \
                open(contig_summary_json, 'w') as contig_outf:
            json.dump(contig_taxon_summary, contig_outf)

        # Additional (hidden) output: contig -> lineage mapping, written next to the summary
        output_dir = os.path.dirname(contig_summary_json)
        contig2lineage_json = os.path.join(output_dir, f"contig2lineage.{db_type}.json")
        with log.log_context("PipelineStepBlastContigs", {"substep": "contig2lineage_json", "contig2lineage_json": contig2lineage_json}), \
                open(contig2lineage_json, 'w') as c2lf:
            json.dump(contig2lineage, c2lf)

        self.additional_output_files_hidden.append(contig2lineage_json)
Exemple #12
0
class CommandTracker(Updater):
    """Track an external or remote command: log periodic progress updates
    and enforce an optional timeout (SIGTERM first, then SIGKILL).

    Every instance receives a sequential id taken from the class-level
    counter, which is read and incremented while holding the class-level lock.
    """
    lock = TraceLock("CommandTracker", multiprocessing.RLock())
    count = multiprocessing.Value('i', 0)

    def __init__(self, update_period=15):
        super().__init__(update_period, self.print_update_and_enforce_timeout)
        self.proc = None  # Set to the subprocess once one is registered.
        self.timeout = None  # None disables timeout enforcement.
        self.t_sigterm_sent = None  # First sigterm, then sigkill.
        self.t_sigkill_sent = None
        # Time allowed between SIGTERM and SIGKILL.
        self.grace_period = update_period / 2.0
        tracker_cls = CommandTracker
        with tracker_cls.lock:
            self.id = tracker_cls.count.value
            tracker_cls.count.value += 1

    def print_update_and_enforce_timeout(self, t_elapsed):
        """Log that the command is still active after `t_elapsed` seconds,
        then apply the timeout policy.
        """
        command_active = self.proc is None or self.proc.poll() is None
        if command_active:
            template = "Command %d still running after %3.1f seconds."
        else:
            # This should be uncommon, unless there is lengthy python
            # processing following the command in the same CommandTracker
            # "with" block. Note: Not to be confused with post-processing
            # on the data.
            template = "Command %d still postprocessing after %3.1f seconds."
        log.write(template % (self.id, t_elapsed))
        self.enforce_timeout(t_elapsed)

    def enforce_timeout(self, t_elapsed):
        """End a command that has outlived its timeout: SIGTERM on the first
        call past the timeout, SIGKILL once the grace period has elapsed,
        and a log line on every later call while it is still alive.
        """
        # Nothing to do without a timeout or a registered subprocess.
        if self.timeout is None or not self.proc:
            return
        # Nothing to do if not yet timed out or already exited.
        if t_elapsed <= self.timeout or self.proc.poll() is not None:
            return

        if not self.t_sigterm_sent:
            # Send SIGTERM first.
            msg = ("Command %d has exceeded timeout of %3.1f seconds. "
                   "Sending SIGTERM." % (self.id, self.timeout))
            log.write(msg)
            self.t_sigterm_sent = time.time()
            self.proc.terminate()
            return

        if not self.t_sigkill_sent:
            # Grace_period after SIGTERM, send SIGKILL.
            if time.time() > self.t_sigterm_sent + self.grace_period:
                msg = ("Command %d still alive %3.1f seconds after "
                       "SIGTERM. Sending SIGKILL." % (self.id, time.time() - self.t_sigterm_sent))
                log.write(msg)
                self.t_sigkill_sent = time.time()
                self.proc.kill()
            return

        msg = ("Command %d still alive %3.1f seconds after "
               "SIGKILL." % (self.id, time.time() - self.t_sigkill_sent))
        log.write(msg)
Exemple #13
0
def fetch_from_s3(
        src,  # pylint: disable=dangerous-default-value
        dst,
        auto_unzip=DEFAULT_AUTO_UNZIP,
        auto_untar=DEFAULT_AUTO_UNTAR,
        allow_s3mi=DEFAULT_ALLOW_S3MI,
        okay_if_missing=False,
        is_reference=False,
        touch_only=False,
        mutex=TraceLock("fetch_from_s3", multiprocessing.RLock()),
        locks={}):
    """Fetch a file from S3 if needed, using either s3mi or aws cp.

    IT IS NOT SAFE TO CALL THIS FUNCTION FROM MULTIPLE PROCESSES.
    It is totally fine to call it from multiple threads (it is designed for that).

    When is_reference=True, "dst" must be an existing directory.

    If src does not exist or there is a failure fetching it, the function returns None,
    without raising an exception.  If the download is successful, it returns the path
    to the downloaded file or folder.  If the download already exists, it is touched
    to update its timestamp.

    When touch_only=True, if the destination does not already exist, the function
    simply returns None (as if the download failed).  If the destination does exist,
    it is touched as usual.  This is useful in implementing an LRU cache policy.

    An exception is raised only if there is a coding error or equivalent problem,
    not if src simply doesn't exist.

    Other parameters:

    auto_unzip -- when true and the destination's extension is a key of ZIP_EXTENSIONS,
        the download is piped through the matching command and the extension is
        dropped from the destination path.
    auto_untar -- when true and the destination ends with ".tar", the download is
        piped into tar and ".tar" is dropped from the destination path.
    allow_s3mi -- try s3mi first, provided S3MI_SEM can be acquired within
        MAX_S3MI_WAIT; otherwise "aws s3 cp" is used.
    okay_if_missing -- a failed s3mi download is not retried with "aws s3 cp", and a
        failure is logged as a probably missing file.
    mutex, locks -- internal state; the defaults are deliberately shared between calls
        (one lock per destination path is kept in locks, guarded by mutex).
    """
    # FIXME: this is a compatibility hack so we can replace this function
    #   We are removing ad-hoc s3 downloads from within steps and converting
    #   additional_files to wdl inputs. These files will be transparently
    #   downloaded by miniwdl. miniwdl will also handle the caching that
    #   is currently done here. This hack bypasses the s3 download if the
    #   source is already a local file, and returns the source (which is
    #   a local file path). This way, when we change the additional_files
    #   to inputs we can provide the local file path to the step instead
    #   of the s3 path and seamlessly transition without a coordinated
    #   change between idseq-dag and the idseq monorepo.
    if not src.startswith("s3://"):
        log.write(
            f"fetch_from_s3 is skipping download because source: {src} does not start with s3://"
        )
        if not os.path.isfile(src):
            return None
        if auto_untar and src.endswith(".tar"):
            # Local tarball: extract next to it (once) and return the extracted directory.
            dst = src[:-4]
            if not os.path.isdir(dst):
                # Extract into a scratch directory first, then move the result into place.
                command.make_dirs(dst + ".untarring")
                script = 'tar xvf "${src}" -C "${tmp_destdir}"'
                named_args = {"src": src, "tmp_destdir": dst + ".untarring"}
                command.execute(
                    command_patterns.ShellScriptCommand(script=script,
                                                        named_args=named_args))
                command.rename(dst + ".untarring/" + os.path.basename(dst),
                               dst)
            return dst
        return src

    # Do not be misled by the multiprocessing.RLock() above -- that just means it won't deadlock
    # if called from multiple processes but does not mean the behavior will be correct.  It will
    # be incorrect, because the locks dict (containing per-file locks) cannot be shared across
    # processes, the way it can be shared across threads.

    if is_reference:
        assert config[
            "REF_DIR"], "The is_reference code path becomes available only after initializing gloabal config['REF_DIR']"

    if os.path.exists(dst) and os.path.isdir(dst):
        dirname, basename = os.path.split(src)
        if is_reference or os.path.abspath(dst).startswith(config["REF_DIR"]):
            # Downloads to the reference dir are persisted from job to job, so we must include
            # version information from the full s3 path.
            #
            # The final destination for s3://path/to/source.db will look like /mnt/ref/s3__path__to/source.db
            # The final destination for s3://path/to/myarchive.tar will look like /mnt/ref/s3__path__to/myarchive/...
            #
            # We considered some other alternatives, for example /mnt/ref/s3__path__to__source.db, but unfortunately,
            # some tools incorporate the base name of their database input into the output filenames, so any approach
            # that changes the basename causes problems downstream.  An example such tool is srst2.
            is_reference = True
            if dirname.startswith("s3://"):
                dirname = dirname.replace("s3://", "s3__", 1)
            # If dirname contains slashes, it has to be flattened to single level.
            dirname = dirname.replace("/", "__")
            dst = os.path.join(dst, dirname, basename)
        else:
            dst = os.path.join(dst, basename)
    else:
        assert not is_reference, f"When fetching references, dst must be an existing directory: {dst}"

    # Work out the post-processing pipeline and the final destination name.
    unzip = ""
    if auto_unzip:
        file_without_ext, ext = os.path.splitext(dst)
        if ext in ZIP_EXTENSIONS:
            unzip = " | " + ZIP_EXTENSIONS[
                ext]  # this command will be used to decompress stdin to stdout
            dst = file_without_ext  # remove file extension from dst
    untar = auto_untar and dst.lower().endswith(".tar")
    if untar:
        dst = dst[:-4]  # Remove .tar

    # Downloads are staged under tmp_destdir.  Only after a download completes successfully it is moved to dst.
    destdir = os.path.dirname(dst)
    tmp_destdir = os.path.join(destdir, "tmp_downloads")
    tmp_dst = os.path.join(tmp_destdir, os.path.basename(dst))

    # Look up (or create) the lock dedicated to this destination; the global mutex is
    # held only for the dictionary access, not for the download itself.
    abspath = os.path.abspath(dst)
    with mutex:
        if abspath not in locks:
            locks[abspath] = TraceLock(f"fetch_from_s3: {abspath}",
                                       multiprocessing.RLock())
        destination_lock = locks[abspath]

    # shouldn't happen and makes it impossible to ensure that any dst that exists is complete and correct.
    assert tmp_dst != dst, f"Problematic use of fetch_from_s3 with tmp_dst==dst=='{dst}'"

    with destination_lock:
        # This check is a bit imperfect when untarring... unless you follow the discipline that
        # all contents of file foo.tar are under directory foo/... (which we do follow in IDseq)
        if os.path.exists(dst):
            command.touch(dst)
            return dst

        if touch_only:
            return None

        for (kind, ddir) in [("destinaiton", destdir),
                             ("temporary download", tmp_destdir)]:
            try:
                if ddir:
                    command.make_dirs(ddir)
            except OSError as e:
                # It's okay if the parent directory already exists, but all other
                # errors fail the download.
                if e.errno != errno.EEXIST:
                    log.write(f"Error in creating {kind} directory.")
                    return None

        with IOSTREAM:
            try:
                if allow_s3mi:
                    # From here on allow_s3mi also records whether S3MI_SEM is held
                    # (and therefore has to be released later).
                    wait_start = time.time()
                    allow_s3mi = S3MI_SEM.acquire(timeout=MAX_S3MI_WAIT)
                    wait_duration = time.time() - wait_start
                    if not allow_s3mi:
                        log.write(
                            f"Failed to acquire S3MI semaphore after waiting {wait_duration} seconds for {src}."
                        )
                    elif wait_duration >= 5:
                        log.write(
                            f"Waited {wait_duration} seconds to acquire S3MI semaphore for {src}."
                        )

                # Tail of the shell pipeline: either extract into the staging
                # directory or write the stream to the staged file.
                if untar:
                    write_dst = r''' | tar xvf - -C "${tmp_destdir}";'''
                    named_args = {'tmp_destdir': tmp_destdir}
                else:
                    write_dst = r''' > "${tmp_dst}";'''
                    named_args = {'tmp_dst': tmp_dst}
                command_params = f"{unzip} {write_dst}"

                named_args.update({'src': src})

                try_cli = not allow_s3mi
                if allow_s3mi:
                    if os.path.exists(tmp_dst):
                        command.remove_rf(tmp_dst)
                    try:
                        command.execute(
                            command_patterns.ShellScriptCommand(
                                script=
                                r'set -o pipefail; s3mi cat --quiet "${src}" '
                                + command_params,
                                named_args=named_args))
                    except subprocess.CalledProcessError:
                        # s3mi failed: give the semaphore back right away (clearing
                        # allow_s3mi so the finally clause does not release it again)
                        # and fall back to aws s3 cp unless the file may be missing.
                        try_cli = not okay_if_missing
                        allow_s3mi = False
                        S3MI_SEM.release()
                        if try_cli:
                            log.write(
                                "Failed to download with s3mi. Trying with aws s3 cp..."
                            )
                        else:
                            raise
                if try_cli:
                    if os.path.exists(tmp_dst):
                        command.remove_rf(tmp_dst)
                    if okay_if_missing:
                        script = r'set -o pipefail; aws s3 cp --quiet "${src}" - ' + command_params
                    else:
                        script = r'set -o pipefail; aws s3 cp --only-show-errors "${src}" - ' + command_params
                    command.execute(
                        command_patterns.ShellScriptCommand(
                            script=script,
                            named_args=named_args,
                            env=dict(os.environ, **refreshed_credentials())))
                # Move staged download into final location.  Leave this last, so it only happens if no exception has occurred.
                # By this point we have already asserted that tmp_dst != dst.
                command.rename(tmp_dst, dst)
                return dst
            except BaseException as e:  # Deliberately super broad to make doubly certain that dst will be removed if there has been any exception
                if os.path.exists(dst):
                    command.remove_rf(dst)
                if not isinstance(e, subprocess.CalledProcessError):
                    # Coding error of some sort.  Best not hide it.
                    raise
                if okay_if_missing:
                    # We presume.
                    log.write("File most likely does not exist in S3.")
                else:
                    log.write("Failed to fetch file from S3.")
                return None
            finally:
                # Release the semaphore if it is still held, and always clear the staging path.
                if allow_s3mi:
                    S3MI_SEM.release()
                if os.path.exists(
                        tmp_dst
                ):  # by this point we have asserted that tmp_dst != dst (and that assert may have failed, but so be it)
                    command.remove_rf(tmp_dst)
Exemple #14
0
    def run(self):
        '''
            Refine the alignment hits with BLAST results for assembled contigs.

            1. summarize hits
            2. build blast index
            3. blast assembled contigs to the index
            4. update the summary

            When the assembled contig file or the reference fasta is smaller
            than its minimum size, placeholder/copied outputs are written and
            the method returns early without running BLAST.
        '''
        # Unpack the three input groups (names starting with "_unused" are not read).
        (_unused_align_m8, deduped_m8_path, hit_summary_path,
         orig_counts_path) = self.input_files_local[0]
        (contigs_fasta, _unused_scaffold, bowtie_sam_path,
         _unused_contig_stats) = self.input_files_local[1]
        ref_fasta = self.input_files_local[2][0]

        (blast_m8_path, refined_m8_path, refined_summary_path,
         refined_counts_path, contig_summary_path,
         blast_top_m8_path) = self.output_files_local()
        db_type = self.additional_attributes["db_type"]

        inputs_too_small = (
            os.path.getsize(contigs_fasta) < MIN_ASSEMBLED_CONTIG_SIZE
            or os.path.getsize(ref_fasta) < MIN_REF_FASTA_SIZE)
        if inputs_too_small:
            # No assembled results or refseq fasta available.
            # Emit placeholder blast outputs and pass the original hit data
            # through unchanged.
            for placeholder_path in (blast_m8_path, blast_top_m8_path):
                command.write_text_to_file(' ', placeholder_path)
            passthrough_pairs = (
                (deduped_m8_path, refined_m8_path),
                (hit_summary_path, refined_summary_path),
                (orig_counts_path, refined_counts_path),
            )
            for source_path, target_path in passthrough_pairs:
                command.copy_file(source_path, target_path)
            command.write_text_to_file('[]', contig_summary_path)
            return  # early exit: nothing to blast

        read_dict, accession_dict, _unused_genera = m8.summarize_hits(
            hit_summary_path)
        PipelineStepBlastContigs.run_blast(
            db_type, blast_m8_path, contigs_fasta, ref_fasta,
            blast_top_m8_path)

        # Both containers are filled in place by generate_info_from_sam.
        read2contig = {}
        sam_contig_stats = defaultdict(int)
        PipelineStepRunAssembly.generate_info_from_sam(
            bowtie_sam_path, read2contig, sam_contig_stats)

        refreshed = self.update_read_dict(
            read2contig, blast_top_m8_path, read_dict, accession_dict,
            db_type)
        updated_read_dict, read2blastm8, contig2lineage, added_reads = refreshed
        self.generate_m8_and_hit_summary(
            updated_read_dict, added_reads, read2blastm8, hit_summary_path,
            deduped_m8_path, refined_summary_path, refined_m8_path)

        # Generating taxon counts based on updated results
        lineage_db = s3.fetch_reference(
            self.additional_files["lineage_db"],
            self.ref_dir_local,
            allow_s3mi=False)  # Too small to waste s3mi
        # The deuterostome db is optional; fetch it only when configured.
        deuterostome_db = (
            s3.fetch_reference(
                self.additional_files["deuterostome_db"],
                self.ref_dir_local,
                allow_s3mi=False)  # Too small for s3mi
            if self.additional_files.get("deuterostome_db") else None)
        evalue_type = 'raw'

        # Taxon counting runs while holding the class-level cya_lock.
        with TraceLock("PipelineStepBlastContigs-CYA",
                       PipelineStepBlastContigs.cya_lock,
                       debug=False), \
                log.log_context(
                    "PipelineStepBlastContigs", {
                        "substep": "generate_taxon_count_json_from_m8",
                        "db_type": db_type,
                        "refined_counts": refined_counts_path
                    }):
            m8.generate_taxon_count_json_from_m8(
                refined_m8_path, refined_summary_path, evalue_type,
                db_type.upper(), lineage_db, deuterostome_db,
                refined_counts_path)

        # generate contig stats at genus/species level
        with log.log_context("PipelineStepBlastContigs",
                             {"substep": "generate_taxon_summary"}):
            contig_taxon_summary = self.generate_taxon_summary(
                read2contig, contig2lineage, updated_read_dict, added_reads,
                db_type)

        with log.log_context(
                "PipelineStepBlastContigs", {
                    "substep": "generate_taxon_summary_json",
                    "contig_summary_json": contig_summary_path
                }), open(contig_summary_path, 'w') as summary_handle:
            json.dump(contig_taxon_summary, summary_handle)

        # Upload additional file: the contig -> lineage mapping is written
        # next to the contig summary and registered as a hidden output.
        summary_dir = os.path.dirname(contig_summary_path)
        contig2lineage_json = os.path.join(
            summary_dir, f"contig2lineage.{db_type}.json")
        with log.log_context(
                "PipelineStepBlastContigs", {
                    "substep": "contig2lineage_json",
                    "contig2lineage_json": contig2lineage_json
                }), open(contig2lineage_json, 'w') as lineage_handle:
            json.dump(contig2lineage, lineage_handle)

        self.additional_output_files_hidden.append(contig2lineage_json)