Example #1
# Imports assumed by this snippet; IndexedFasta, FragmentDf and
# create_fragment_dataframe come from the surrounding project.
import numpy as np
import pandas as pd
from pathlib import Path
from time import sleep
from typing import Optional
def create_virtual_digest(
    reference_fasta: IndexedFasta,
    digest_type: str,
    digest_param: str,
    fragments: Optional[Path] = None,
    digest_stats: Optional[Path] = None,
    n_workers: int = 1,
) -> FragmentDf:
    """Iterate over the sequences in a fasta file and find the match positions for the restriction fragment"""

    parallel = n_workers > 1
    if parallel:
        from dask.distributed import Client, LocalCluster

        cluster = LocalCluster(processes=True, n_workers=n_workers, threads_per_worker=1)
        client = Client(cluster)

    # convert the sequences to a dask bag
    seq_bag = reference_fasta.to_dask()
    chrom_dtype = pd.CategoricalDtype(reference_fasta._chroms, ordered=True)
    FragmentDf.set_dtype("chrom", chrom_dtype)

    frag_df = (
        pd.concat(
            seq_bag.map(lambda x: (x["seqid"], x["seq"], digest_type, digest_param))
            .starmap(create_fragment_dataframe)
            .compute()
        )
        .astype({"chrom": chrom_dtype})
        .sort_values(["chrom", "start"])
        .assign(fragment_id=lambda x: np.arange(len(x), dtype=int) + 1)
        .fragmentdf.cast(subset=True)
    )

    if parallel:
        # block until no worker reports in-flight tasks, then shut down
        while True:
            processing = client.processing()
            still_running = [len(v) > 0 for k, v in processing.items()]
            if any(still_running):
                sleep(10)
            else:
                break
        client.close()
        cluster.close()

    # use pandas accessor extension
    frag_df.fragmentdf.assert_valid()

    frag_df.to_parquet(str(fragments), index=False)

    summary_stats_df = (
        frag_df.groupby("chrom")["fragment_length"]
        .agg(["size", "mean", "median", "min", "max"])
        .fillna(-1)
        .astype({"size": int, "min": int, "max": int})
        .rename(columns={"size": "num_fragments"})
    )
    summary_stats_df.to_csv(digest_stats)

    return frag_df
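A hypothetical invocation of the digest above; the FASTA path, digest settings, and output paths are illustrative assumptions rather than values from the source.

# Hypothetical usage sketch: every literal below is an assumption.
ref = IndexedFasta(Path("reference.fa"))      # assumed constructor signature
frag_df = create_virtual_digest(
    ref,
    digest_type="enzyme",                     # assumed digest type keyword
    digest_param="NlaIII",                    # assumed restriction enzyme
    fragments=Path("fragments.parquet"),
    digest_stats=Path("digest_stats.csv"),
    n_workers=4,
)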
Example #2
# Imports assumed by this snippet; `logger` stands in for the project's own
# module-level logger.
import logging
from contextlib import AbstractContextManager
from time import sleep
from typing import Optional
from dask.distributed import Client, LocalCluster

logger = logging.getLogger(__name__)
class DaskExecEnv(AbstractContextManager):
    def __init__(
        self,
        n_workers: int = 1,
        processes: bool = True,
        threads_per_worker: int = 1,
        scheduler_port: int = 0,
        dashboard_port: Optional[int] = None,
    ):
        self._cluster_kwds = {
            "processes": processes,
            "n_workers": n_workers,
            "scheduler_port": scheduler_port,
            "dashboard_address": (
                f"127.0.0.1:{dashboard_port}" if dashboard_port is not None else None
            ),
            "threads_per_worker": threads_per_worker,
        }
        self._cluster, self._client = None, None

    def scatter(self, data):
        return self._client.scatter(data)

    def __enter__(self):
        self._cluster = LocalCluster(**self._cluster_kwds)
        self._client = Client(self._cluster)
        logger.debug(f"Cluster started: {self._cluster}")
        logger.debug(f"Client started: {self._client}")
        return self

    def __exit__(self, *args):
        if self._cluster:
            max_tries = 10
            backoff = 2
            delay = 1
            # poll with exponential backoff until the workers report no tasks
            while max_tries > 0:
                processing = self._client.processing()
                still_running = [len(v) > 0 for k, v in processing.items()]
                if any(still_running):
                    sleep(delay)
                    max_tries -= 1
                    delay = delay * backoff
                else:
                    sleep(1)
                    break
            self._client.close()
            self._cluster.close()
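A minimal usage sketch for the context manager above, assuming only that dask and dask.bag are installed; the squares computation is purely illustrative.

import dask.bag as db

# Entering the context starts the cluster; because the Client registers
# itself as the default scheduler, compute() below runs on it.
with DaskExecEnv(n_workers=2):
    squares = db.from_sequence(range(10)).map(lambda x: x * x).compute()
    print(squares)
# On exit, __exit__ drains in-flight tasks, then closes client and cluster.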
Example #3
# Import assumed by this snippet:
from dask.distributed import Client

class MyDaskClient:
    def __init__(self, address=None):
        self._client = Client(address)

    def _who_has(self, key):
        # Client.who_has() maps each key to the list of workers holding it
        who_has_dict = self._client.who_has()
        if key in who_has_dict:
            return {"key": key, "worker": who_has_dict[key]}
        return None

    def get_status(self, key):
        # first we check if a worker has it
        processing_dict = self._client.processing()
        for worker in processing_dict.keys():
            if key in processing_dict[worker]:
                return {"status": "running", "worker": worker}
        # then we check whether the task appears in the recorded task stream
        for task in reversed(self._client.get_task_stream()):
            if task["key"] == key:
                return {"status": "done", "dask_status": task["status"]}
        return None
Example #4
# Imports assumed by this snippet; NameSortedBamSource, TableWriter,
# AlignmentProgress, ReadProgress, FragmentDf and filter_read_alignments
# come from the surrounding project.
import sys
from pathlib import Path
from typing import Optional
from streamz import Stream
from tqdm import tqdm
def parse_alignment_bam(
    input_bam: Path,
    fragment_df: FragmentDf,
    alignment_table: Optional[Path] = None,
    read_table: Optional[Path] = None,
    overlap_table: Optional[Path] = None,
    alignment_summary: Optional[Path] = None,
    read_summary: Optional[Path] = None,
    chunksize: int = 50000,
    n_workers: int = 1,
):
    """Filter alignments to keep only alignments that contribute to contacts

    Parameters
    ----------

    input_bam : str
                Path to a namesorted bam with unfiltered alignments
    chunksize: int
                The alignments are batched for processing, this controls the batch size

    """

    source_aligns = NameSortedBamSource(input_bam, metadata={})
    source_aligns.discover()

    parallel = n_workers > 1
    fragment_df = fragment_df.set_index(["fragment_id"]).sort_index()
    if parallel:
        from dask.distributed import Client, LocalCluster
        from time import sleep

        cluster = LocalCluster(processes=True,
                               n_workers=n_workers,
                               threads_per_worker=1)
        client = Client(cluster)
        fragment_df = client.scatter(fragment_df)

    writers = dict(
        alignment_table=TableWriter(alignment_table),
        read_table=TableWriter(read_table),
        overlap_table=TableWriter(overlap_table),
    )

    batch_progress_bar = tqdm(
        total=None, desc="Alignments submitted: ", unit=" alignments", position=0
    )
    alignment_progress = AlignmentProgress(position=1)
    read_progress = ReadProgress(position=2)

    # stream that holds the raw alignment dfs
    bam_stream = Stream()

    # stream that holds the filtered/processed alignments
    if parallel:
        filtered_align_stream = (
            bam_stream.scatter()
            .map(filter_read_alignments, fragment_df=fragment_df)
            .buffer(n_workers)
            .gather()
        )
    else:
        filtered_align_stream = bam_stream.map(
            filter_read_alignments, fragment_df=fragment_df
        )

    # write the alignments using the table writers, updating progress bars as we go
    align_sink = (  # noqa: F841
        filtered_align_stream.pluck("alignment_table")
        .accumulate(alignment_progress, returns_state=True, start=alignment_progress)
        .sink(writers["alignment_table"])
    )

    read_sink = (  # noqa: F841
        filtered_align_stream.pluck("read_table")
        .accumulate(read_progress, returns_state=True, start=read_progress)
        .sink(writers["read_table"])
    )

    overlap_sink = (  # noqa: F841
        filtered_align_stream.pluck("overlap_table").sink(writers["overlap_table"])
    )

    for batch_idx, align_df in enumerate(
            source_aligns.read_chunked(chunksize=chunksize)):
        bam_stream.emit(align_df)
        batch_progress_bar.update(len(align_df))
        batch_progress_bar.set_postfix({"batches": batch_idx})

    if parallel:
        # block until no worker reports in-flight tasks, then shut down
        while True:
            processing = client.processing()
            still_running = [len(v) > 0 for k, v in processing.items()]
            if any(still_running):
                sleep(10)
            else:
                break
        client.close()
        cluster.close()

    batch_progress_bar.close()
    alignment_progress.close()
    alignment_progress.save(alignment_summary)
    read_progress.close()
    read_progress.save(read_summary)
    sys.stderr.write("\n\n\n")
    sys.stdout.write("\n")
    return read_progress.final_stats()
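A hypothetical call to the parser above; all paths are illustrative, and frag_df stands for the fragment dataframe produced by a virtual digest such as the one in Example #1.

# Hypothetical usage sketch: every path below is an assumption.
stats = parse_alignment_bam(
    Path("reads.namesorted.bam"),
    frag_df,                                    # FragmentDf from a virtual digest
    alignment_table=Path("alignments.parquet"),
    read_table=Path("reads.parquet"),
    overlap_table=Path("overlaps.parquet"),
    alignment_summary=Path("alignment_summary.csv"),
    read_summary=Path("read_summary.csv"),
    chunksize=50000,
    n_workers=4,
)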
Example #5
# Imports assumed by this snippet; GenericProcessManager, EDASLogger,
# edasOpManager, Job, EDASDataset, ExecHandler and SubmissionThread come
# from the surrounding project.
import multiprocessing
import time
import traceback
from threading import Thread
from typing import Dict, Optional
import requests
from dask.distributed import Client, LocalCluster
class ProcessManager(GenericProcessManager):
    manager: Optional["ProcessManager"] = None

    @classmethod
    def getManager(cls) -> Optional["ProcessManager"]:
        return cls.manager

    @classmethod
    def initManager(cls, serverConfiguration: Dict[str, str]) -> "ProcessManager":
        if cls.manager is None:
            cls.manager = ProcessManager(serverConfiguration)
        return cls.manager

    def __init__(self, serverConfiguration: Dict[str, str]):
        self.config = serverConfiguration
        self.logger = EDASLogger.getLogger()
        self.num_wps_requests = 0
        self.scheduler_address = serverConfiguration.get(
            "scheduler.address", None)
        self.submitters = []
        self.active = True
        if self.scheduler_address is not None:
            self.logger.info(
                "Initializing Dask-distributed cluster with scheduler address: "
                + self.scheduler_address)
            self.client = Client(self.scheduler_address, timeout=60)
        else:
            nWorkers = int(
                self.config.get("dask.nworkers", multiprocessing.cpu_count()))
            self.client = Client(LocalCluster(n_workers=nWorkers))
            self.scheduler_address = self.client.scheduler.address
            self.logger.info(
                f"Initializing Local Dask cluster with {nWorkers} workers, scheduler address = {self.scheduler_address}"
            )
            self.client.submit(lambda x: edasOpManager.buildIndices(x),
                               nWorkers)
        self.ncores = self.client.ncores()
        self.logger.info(f" ncores: {self.ncores}")
        self.scheduler_info = self.client.scheduler_info()
        self.workers: Dict = self.scheduler_info.pop("workers")
        self.logger.info(f" workers: {self.workers}")
        log_metrics = serverConfiguration.get("log.scheduler.metrics", False)
        if log_metrics:
            self.metricsThread = Thread(target=self.trackMetrics)
            self.metricsThread.start()

    def getCWTMetrics(self) -> Dict:
        metrics_data = {
            key: {}
            for key in [
                'user_jobs_queued', 'user_jobs_running', 'wps_requests',
                'cpu_ave', 'cpu_count', 'memory_usage', 'memory_available'
            ]
        }
        metrics = self.getProfileData()
        counts = metrics["counts"]
        workers = metrics["workers"]
        for key in [
                'tasks', 'processing', 'released', 'memory', 'saturated',
                'waiting', 'waiting_data', 'unrunnable'
        ]:
            metrics_data['user_jobs_running'][key] = counts[key]
        for key in ['tasks', 'waiting', 'waiting_data', 'unrunnable']:
            metrics_data['user_jobs_queued'][key] = counts[key]
        for wId, wData in workers.items():
            worker_metrics = wData["metrics"]
            total_memory = wData["memory_limit"]
            memory_usage = worker_metrics["memory"]
            metrics_data['memory_usage'][wId] = memory_usage
            metrics_data['memory_available'][wId] = total_memory - memory_usage
            metrics_data['cpu_count'][wId] = wData["ncores"]
            metrics_data['cpu_ave'][wId] = worker_metrics["cpu"]
        return metrics_data

    def trackMetrics(self, sleepTime=1.0):
        isIdle = False
        self.logger.info(" ** TRACKING METRICS ** ")
        while self.active:
            metrics = self.getProfileData()
            counts = metrics["counts"]
            if counts['processing'] == 0:
                if not isIdle:
                    self.logger.info(" ** CLUSTER IS IDLE ** ")
                    isIdle = True
            else:
                isIdle = False
                self.logger.info(f" METRICS: {metrics['counts']} ")
                workers = metrics["workers"]
                for key, value in workers.items():
                    self.logger.info(f" *** {key}: {value}")
                self.logger.info(f" HEALTH: {self.getHealth()}")
            # sleep every iteration so the idle branch does not busy-spin
            time.sleep(sleepTime)

    def getWorkerMetrics(self):
        metrics = {}
        wkeys = ['ncores', 'memory_limit', 'last_seen', 'metrics']
        scheduler_info = self.client.scheduler_info()
        workers: Dict = scheduler_info.get("workers", {})
        for iW, worker in enumerate(workers.values()):
            metrics[f"W{iW}"] = {wkey: worker[wkey] for wkey in wkeys}
        return metrics

    def getDashboardAddress(self):
        # assumes the scheduler dashboard is on the default port 8787
        stoks = self.scheduler_address.split(":")
        host_address = stoks[-2].strip("/")
        return f"http://{host_address}:8787"

    def getCounts(self) -> Dict:
        profile_address = f"{self.getDashboardAddress()}/json/counts.json"
        return requests.get(profile_address).json()

    def getHealth(self, mtype: str = "") -> str:
        profile_address = f"{self.getDashboardAddress()}/health"
        return requests.get(profile_address).text

    def getMetrics(self, mtype: str = "") -> Optional[Dict]:
        counts = self.getCounts()
        if counts['processing'] == 0:
            return None
        mtypes = mtype.split(",")
        metrics = {"counts": counts}
        if "processing" in mtypes:
            metrics["processing"] = self.client.processing()
        if "profile" in mtypes:
            metrics["profile"] = self.client.profile()
        return metrics

    def getProfileData(self, mtype: str = "") -> Dict:
        try:
            return {
                "counts": self.getCounts(),
                "workers": self.getWorkerMetrics()
            }
        except Exception as err:
            self.logger.error("Error in getProfileData")
            self.logger.error(traceback.format_exc())

    def term(self):
        self.active = False
        self.client.close()

    def runProcess(self, job: Job) -> EDASDataset:
        start_time = time.time()
        try:
            self.logger.info(
                f"Running workflow for requestId: {job.requestId}, scheduler: {self.scheduler_address}"
            )
            result = edasOpManager.buildTask(job)
            self.logger.info("Completed EDAS workflow in time " +
                             str(time.time() - start_time))
            return result
        except Exception as err:
            self.logger.error("Execution error: " + str(err))
            traceback.print_exc()

    def submitProcess(self, service: str, job: Job,
                      resultHandler: ExecHandler):
        submitter: SubmissionThread = SubmissionThread(job, resultHandler)
        self.submitters.append(submitter)
        submitter.start()
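A sketch of how this manager might be bootstrapped; the configuration values are assumptions, and the metrics call assumes the scheduler's dashboard HTTP endpoints are reachable on port 8787, as getDashboardAddress() expects.

# Usage sketch: configuration values below are assumptions.
config = {"dask.nworkers": "4"}         # no "scheduler.address": start a local cluster
manager = ProcessManager.initManager(config)
print(manager.getDashboardAddress())
print(manager.getCWTMetrics())          # polls the dashboard's JSON endpoints
manager.term()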