Example 1
class DaskDelegate(Delegate):
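    """
    Delegate implementation that runs jobs on a Dask cluster: it submits work
    through a shared distributed Client, queries task state directly on the
    scheduler, and serves finished results from the configured cache provider.
    """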
    type: str = "dask"

    def __init__(self, delegate_config: DaskDelegateConfig):
        super().__init__()

        self.delegate_config = delegate_config
        self.cache_provider = self.delegate_config.cache_provider

        # Attempt to load the global Dask client.
        try:
            self.client = get_client()

        except ValueError:
            if self.delegate_config.kube_cluster is not None:
                self.client = Client(self.delegate_config.kube_cluster)
                print(self.delegate_config.kube_cluster)

            else:
                self.client = Client(f"{self.delegate_config.dask_cluster_address}:{self.delegate_config.dask_cluster_port}")

        # Setup functions to be run on the schedule.
        def __scheduler_job_exists(dask_scheduler, job_id: str) -> bool:
            return job_id in dask_scheduler.tasks

        def __scheduler_job_state(dask_scheduler, job_id: str) -> TaskState:
            return dask_scheduler.tasks[job_id].state

        self.scheduler_job_exists = __scheduler_job_exists
        self.scheduler_job_state = __scheduler_job_state

    def __job_state(self, job_id: str) -> TaskState:
        return self.client.run_on_scheduler(self.scheduler_job_state, job_id=job_id)

    def connect(self) -> bool:
        # No need to connect.
        return True

    def test_connection(self) -> bool:
        # Shim this out until I figure out a good way to test a Dask and Redis connection.
        return True

    def create_job(self, job_id: str) -> bool:
        # No concept of creating a job.
        return True

    def start_job(self, job_id: str, work: Callable, *args, **kwargs) -> bool:
        if self.job_exists(job_id) or self.job_complete(job_id):
            return False

        # Parse and replace instances of the internal `result://` proxy protocol.
        # In short, this allows for callees to reference an in-progress or remote job without needing direct access.
        function_args = [
            self.client.get_dataset(arg.replace("result://", ""))
            if isinstance(arg, str) and arg.startswith("result://")
            else arg
            for arg in args
        ]

        # Create a job to run the desired function.
        job_future: Future = self.client.submit(work, *function_args, **kwargs, key=job_id, pure=False)

        # Start additional cache job which depends on the results of the previous.
        cache_future: Future = self.client.submit(self.cache_provider.put, job_id, job_future, pure=False)

        # Publish the job as a dataset to maintain state across requests.
        self.client.publish_dataset(job_future, name=job_id, override=True)
        self.client.publish_dataset(cache_future, override=True)

        return True

    def stop_job(self, job_id: str) -> bool:
        if not self.job_exists(job_id):
            return False

        # Collect the job future plus any strong dependencies so they can all be cancelled.
        futures = [Future(job_id)]

        try:
            # Iterate through the dependencies of this job.
            dependencies = self.client.run_on_scheduler(lambda dask_scheduler: [state.key for state in dask_scheduler.tasks[job_id].dependencies])

            # Filter out any weak dependencies. Strong dependency keys are prefixed with the job ID followed by "/".
            dependencies = [dependency for dependency in dependencies if dependency.replace(job_id, "").startswith("/")]

            futures.extend(Future(key) for key in dependencies)
        except KeyError:
            # The job has no tracked dependencies; just cancel the job itself.
            pass

        self.client.cancel(futures)
        self.client.unpublish_dataset(job_id)

        # Hacky fix -- Simulation processes continue executing EVEN IF the parent task is killed.
        def hacky():
            os.system("pkill -f 'Simulation.out'")

        self.client.run(hacky, nanny=True)

        return True

    def job_status(self, job_id: str) -> JobStatus:
        # If the job is complete (results exist as a dataset or in the vault).
        if self.job_complete(job_id):
            status = JobStatus()
            status.status_id = JobState.DONE
            status.status_text = "The job is complete."
            status.has_failed = False
            status.is_done = True

            return status

        # If the job doesn't exist.
        if not self.job_exists(job_id):
            status = JobStatus()
            status.status_id = JobState.DOES_NOT_EXIST
            status.status_text = f"A job with job_id: '{job_id}' does not exist."
            status.has_failed = True
            status.is_done = False

            return status

        status_mapping = {
            "released": (JobState.STOPPED, "The job is known but not actively computing or in memory."),
            "waiting": (JobState.WAITING, "The job is waiting for dependencies to arrive in memory."),
            "no-worker": (JobState.WAITING, "The job is waiting for a worker to become available."),
            "processing": (JobState.RUNNING, "The job is running."),
            "memory": (JobState.DONE, "The job is done and is being held in memory."),
            "erred": (JobState.FAILED, "The job has failed."),
            "done": (JobState.DONE, "The job is done and has been cached / stored on disk.")
        }

        # Grab the task state from the scheduler.
        future_status = self.__job_state(job_id)

        status = JobStatus()
        status.status_id = status_mapping[future_status][0]
        status.status_text = status_mapping[future_status][1]

        status.is_done = status.status_id is JobState.DONE
        status.has_failed = status.status_id is JobState.FAILED

        return status

    def job_results(self, job_id: str):
        # The results of this job may exist on the client dataset.
        if job_id in self.client.datasets:
            print("[DEBUG] Getting results from dataset.")
            return self.client.get_dataset(name=job_id).result()

        # If the results are not in the cache, raise an exception.
        if not self.cache_provider.exists(job_id):
            raise Exception(f"Result with ID '{job_id}' does not exist in the cache.")

        return self.cache_provider.get(job_id)

    def job_complete(self, job_id: str) -> bool:
        # Finished job results must exist within the cache for it to be considered 'done'.
        return self.cache_provider.exists(job_id)

    def job_exists(self, job_id: str) -> bool:
        # Check if the job exists in the scheduler.
        return self.client.run_on_scheduler(self.scheduler_job_exists, job_id=job_id)

    def get_remote_dependency(self, dependency_id: str):
        # Check to see if the job exists as a published dataset.
        try:
            return self.client.get_dataset(name=dependency_id)
        except KeyError:
            raise Exception(f"Dependency '{dependency_id}' does not exist within distributed memory.")
Example 2
class DaskExecutor(BaseExecutor):
    """
    DaskExecutor submits tasks to a Dask Distributed cluster.
    """
    def __init__(self, cluster_address=None):
        super().__init__(parallelism=0)
        if cluster_address is None:
            cluster_address = conf.get('dask', 'cluster_address')
        if not cluster_address:
            raise ValueError(
                'Please provide a Dask cluster address in airflow.cfg')
        self.cluster_address = cluster_address
        # ssl / tls parameters
        self.tls_ca = conf.get('dask', 'tls_ca')
        self.tls_key = conf.get('dask', 'tls_key')
        self.tls_cert = conf.get('dask', 'tls_cert')
        self.client: Optional[Client] = None
        self.futures: Optional[Dict[Future, TaskInstanceKeyType]] = None

    def start(self) -> None:
        if self.tls_ca or self.tls_key or self.tls_cert:
            security = Security(
                tls_client_key=self.tls_key,
                tls_client_cert=self.tls_cert,
                tls_ca_file=self.tls_ca,
                require_encryption=True,
            )
        else:
            security = None

        self.client = Client(self.cluster_address, security=security)
        self.futures = {}

    def execute_async(self,
                      key: TaskInstanceKeyType,
                      command: CommandType,
                      queue: Optional[str] = None,
                      executor_config: Optional[Any] = None) -> None:
        def airflow_run():
            return subprocess.check_call(command, close_fds=True)

        if not self.client:
            raise AirflowException(NOT_STARTED_MESSAGE)

        future = self.client.submit(airflow_run, pure=False)
        self.futures[future] = key  # type: ignore

    def _process_future(self, future: Future) -> None:
        if not self.futures:
            raise AirflowException(NOT_STARTED_MESSAGE)
        if future.done():
            key = self.futures[future]
            if future.exception():
                self.log.error("Failed to execute task: %s",
                               repr(future.exception()))
                self.fail(key)
            elif future.cancelled():
                self.log.error("Failed to execute task")
                self.fail(key)
            else:
                self.success(key)
            self.futures.pop(future)

    def sync(self) -> None:
        if not self.futures:
            raise AirflowException(NOT_STARTED_MESSAGE)
        # make a copy so futures can be popped during iteration
        for future in self.futures.copy():
            self._process_future(future)

    def end(self) -> None:
        if not self.client:
            raise AirflowException(NOT_STARTED_MESSAGE)
        if not self.futures:
            raise AirflowException(NOT_STARTED_MESSAGE)
        self.client.cancel(list(self.futures.keys()))
        for future in as_completed(self.futures.copy()):
            self._process_future(future)

    def terminate(self) -> None:
        if not self.futures:
            raise AirflowException(NOT_STARTED_MESSAGE)
        self.client.cancel(list(self.futures.keys()))
        self.end()
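For orientation, a rough driver sketch for the executor above, using placeholder key and command values and assuming a standard Airflow configuration with the default `[dask]` section; in a real deployment the Airflow scheduler calls these hooks itself.

# Placeholder values: Airflow normally supplies a task-instance key and its
# own task-run command line.
executor = DaskExecutor(cluster_address="tcp://127.0.0.1:8786")
executor.start()  # connects the distributed Client and resets the futures dict

executor.execute_async(key=("example_dag", "example_task", "2021-01-01", 1),
                       command=["echo", "hello from a dask worker"])

# The scheduler loop calls sync() periodically; finished futures are popped
# and reported through self.success() / self.fail().
executor.sync()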
Example 3
class Runner:
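    """
    Drives the benchmark suite: reads a YAML parameter file, spins up a
    dask_jobqueue cluster, runs the selected operations over generated
    datasets, and writes the timing results to CSV.
    """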
    def __init__(self, input_file):
        import yaml

        with open(input_file) as f:
            self.params = yaml.safe_load(f)
        self.operations = {}
        self.operations['computations'] = [
            spatial_mean, temporal_mean, climatology, anomaly
        ]
        self.operations['readwrite'] = [
            writefile, openfile, readfile, deletefile
        ]
        self.operations['write'] = [writefile]
        self.operations['read'] = [openfile, readfile]
        self.client = None

    def create_cluster(self, job_scheduler, maxcore, walltime, memory, queue,
                       wpn):
        """ Creates a dask cluster using dask_jobqueue
        """
        logger.warning('Creating a dask cluster using dask_jobqueue')
        logger.warning(f'Job Scheduler: {job_scheduler}')
        logger.warning(f'Memory size for each node: {memory}')
        logger.warning(f'Number of cores for each node: {maxcore}')
        logger.warning(f'Number of workers for each node: {wpn}')

        from dask_jobqueue import PBSCluster, SLURMCluster

        job_schedulers = {'pbs': PBSCluster, 'slurm': SLURMCluster}

        # Note about OMP_NUM_THREADS=1 and --nthreads 1:
        # These two settings ensure that each benchmark worker uses only a
        # single thread. The generated job script contains --nthreads twice,
        # but the final --nthreads 1 overrides the earlier value.
        cluster = job_schedulers[job_scheduler](
            cores=maxcore,
            memory=memory,
            processes=wpn,
            local_directory='$TMPDIR',
            interface='ib0',
            queue=queue,
            walltime=walltime,
            env_extra=['OMP_NUM_THREADS=1'],
            extra=['--nthreads 1'],
            project='ntdd0004',
        )

        self.client = Client(cluster)

        logger.warning('************************************\n'
                       'Job script created by dask_jobqueue:\n'
                       f'{cluster.job_script()}\n'
                       '***************************************')
        logger.warning(
            f'Dask cluster dashboard_link: {self.client.cluster.dashboard_link}'
        )

    def run(self):
        logger.warning('Reading the YAML configuration file')
        operation_choice = self.params['operation_choice']
        machine = self.params['machine']
        job_scheduler = self.params['job_scheduler']
        queue = self.params['queue']
        walltime = self.params['walltime']
        maxmemory_per_node = self.params['maxmemory_per_node']
        maxcore_per_node = self.params['maxcore_per_node']
        chunk_per_worker = self.params['chunk_per_worker']
        freq = self.params['freq']
        spil = self.params['spil']
        output_dir = self.params.get('output_dir', results_dir)
        now = datetime.datetime.now()
        output_dir = os.path.join(output_dir, f'{machine}/{str(now.date())}')
        os.makedirs(output_dir, exist_ok=True)
        parameters = self.params['parameters']
        num_workers = parameters['number_of_workers_per_nodes']
        num_threads = parameters.get('number_of_threads_per_workers', 1)
        num_nodes = parameters['number_of_nodes']
        chunking_schemes = parameters['chunking_scheme']
        io_formats = parameters['io_format']
        filesystems = parameters['filesystem']
        fixed_totalsize = parameters['fixed_totalsize']
        chsz = parameters['chunk_size']
        writefile_dir = parameters['writefile_dir']
        for wpn in num_workers:
            self.create_cluster(
                job_scheduler=job_scheduler,
                maxcore=maxcore_per_node,
                walltime=walltime,
                memory=maxmemory_per_node,
                queue=queue,
                wpn=wpn,
            )
            for num in num_nodes:
                self.client.cluster.scale(num * wpn)
                cluster_wait(self.client, num * wpn)
                timer = DiagnosticTimer()
                # dfs = []
                logger.warning(
                    '#####################################################################\n'
                    f'Dask cluster:\n'
                    f'\t{self.client.cluster}\n')
                now = datetime.datetime.now()
                csv_filename = f"{output_dir}/compute_study_{now.strftime('%Y-%m-%d_%H-%M-%S')}.csv"
                for chunk_size in chsz:

                    for io_format in io_formats:

                        for filesystem in filesystems:

                            if filesystem == 's3':
                                profile = parameters['profile']
                                bucket = parameters['bucket']
                                endpoint_url = parameters['endpoint_url']
                                fs = fsspec.filesystem(
                                    's3',
                                    profile=profile,
                                    anon=False,
                                    client_kwargs={
                                        'endpoint_url': endpoint_url
                                    },
                                )
                                root = f'{bucket}/test1'
                            elif filesystem == 'posix':
                                fs = LocalFileSystem()
                                root = writefile_dir
                                os.makedirs(root, exist_ok=True)
                            for chunking_scheme in chunking_schemes:

                                logger.warning(
                                    f'Benchmark starting with: \n\tworker_per_node = {wpn},'
                                    f'\n\tnum_nodes = {num}, \n\tchunk_size = {chunk_size},'
                                    f'\n\tchunking_scheme = {chunking_scheme},'
                                    f'\n\tchunk per worker = {chunk_per_worker}'
                                    f'\n\tio_format = {io_format}'
                                    f'\n\tfilesystem = {filesystem}')
                                ds, chunks = timeseries(
                                    fixed_totalsize=fixed_totalsize,
                                    chunk_per_worker=chunk_per_worker,
                                    chunk_size=chunk_size,
                                    chunking_scheme=chunking_scheme,
                                    io_format=io_format,
                                    num_nodes=num,
                                    freq=freq,
                                    worker_per_node=wpn,
                                )
                                # wait(ds)
                                dataset_size = format_bytes(ds.nbytes)
                                logger.warning(ds)
                                logger.warning(
                                    f'Dataset total size: {dataset_size}')

                                for op in self.operations[operation_choice]:
                                    with timer.time(
                                            'runtime',
                                            operation=op.__name__,
                                            fixed_totalsize=fixed_totalsize,
                                            chunk_size=chunk_size,
                                            chunk_per_worker=chunk_per_worker,
                                            dataset_size=dataset_size,
                                            worker_per_node=wpn,
                                            threads_per_worker=num_threads,
                                            num_nodes=num,
                                            chunking_scheme=chunking_scheme,
                                            io_format=io_format,
                                            filesystem=filesystem,
                                            root=root,
                                            machine=machine,
                                            maxmemory_per_node=
                                            maxmemory_per_node,
                                            maxcore_per_node=maxcore_per_node,
                                            spil=spil,
                                    ):
                                        fname = f'{chunk_size}{chunking_scheme}{filesystem}{num}'
                                        if op.__name__ == 'writefile':
                                            print(ds.sst.data.chunksize)
                                            filename = op(
                                                ds, fs, io_format, root, fname)
                                        elif op.__name__ == 'openfile':
                                            ds = op(fs, io_format, root,
                                                    chunks, chunk_size)
                                        elif op.__name__ == 'deletefile':
                                            ds = op(fs, io_format, root,
                                                    filename)
                                        else:
                                            op(ds)
                        logger.warning('Computation done')
                        # Cancelling ds also releases every dependent computation.
                        self.client.cancel(ds)
                        temp_df = timer.dataframe()
                        temp_df.to_csv(csv_filename, index=False)
                        # dfs.append(temp_df)

                # now = datetime.datetime.now()
                # filename = f"{output_dir}/compute_study_{now.strftime('%Y-%m-%d_%H-%M-%S')}.csv"
                # df = pd.concat(dfs)
                # df.to_csv(filename, index=False)
                logger.warning(
                    f'Persisted benchmark result file: {csv_filename}')

            logger.warning(
                'Shutting down the client and cluster before changing number of workers per nodes'
            )
            self.client.cluster.close()
            logger.warning('Cluster shutdown finished')
            self.client.close()
            logger.warning('Client shutdown finished')

        logger.warning('=====> The End <=========')
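The `self.params[...]` lookups in `run()` imply a parameter file shaped roughly like the sketch below; all values are illustrative placeholders, not tested settings.

# Hypothetical configuration; every key mirrors a lookup in Runner.run().
example_params = """
operation_choice: computations
machine: example-machine
job_scheduler: slurm
queue: regular
walltime: '01:00:00'
maxmemory_per_node: 100GB
maxcore_per_node: 36
chunk_per_worker: 2
freq: 1D
spil: false
parameters:
  number_of_workers_per_nodes: [4]
  number_of_threads_per_workers: 1
  number_of_nodes: [1, 2]
  chunking_scheme: [temporal]      # placeholder value
  io_format: [zarr]                # placeholder value
  filesystem: [posix]
  fixed_totalsize: false
  chunk_size: [64]                 # placeholder value
  writefile_dir: /tmp/benchmark-output
"""

with open("benchmark_config.yml", "w") as f:
    f.write(example_params)

Runner("benchmark_config.yml").run()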