Example #1
class DaskHandler(IProcessingHandler):
    """This class wraps all Dask related functions."""

    def __init__(self, number_of_workers, class_cb: Callable, brain_class, worker_log_level=logging.WARNING):
        super().__init__(number_of_workers)
        self._client: Optional[Client] = None
        self._cluster: Optional[LocalCluster] = None

        self.class_cb = class_cb
        self.brain_class = brain_class
        self.worker_log_level = worker_log_level

    def init_framework(self):
        if self._client:
            raise RuntimeError("Dask client already initialized.")

        # threads_per_worker must be one because the Atari env is not thread-safe.
        # Since we lower the thread count from the default, we increase the number of workers instead.
        self._cluster = LocalCluster(processes=True, asynchronous=False, threads_per_worker=1,
                                     silence_logs=self.worker_log_level,
                                     n_workers=self.number_of_workers,
                                     memory_pause_fraction=False,
                                     lifetime='1 hour', lifetime_stagger='5 minutes', lifetime_restart=True,
                                     interface="lo")
        self._client = Client(self._cluster)
        self._client.register_worker_plugin(_CreatorPlugin(self.class_cb, self.brain_class), name="creator-plugin")
        logging.info("Dask dashboard available at port: %s",
                     self._client.scheduler_info()["services"]["dashboard"])

    def map(self, func, *iterable):
        if not self._client:
            raise RuntimeError("Dask client not initialized. Call \"init_framework\" before calling \"map\"")
        return self._client.gather(self._client.map(func, *iterable))

    def cleanup_framework(self):
        self._client.shutdown()
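The _CreatorPlugin registered in init_framework is not part of this snippet. A minimal sketch of what such a plugin could look like is shown below, assuming the standard dask.distributed WorkerPlugin interface; the class name and attributes mirror the snippet and are otherwise hypothetical.

# Hypothetical sketch of the worker plugin registered above; not the project's actual code.
from dask.distributed import WorkerPlugin


class _CreatorPlugin(WorkerPlugin):
    def __init__(self, class_cb, brain_class):
        self.class_cb = class_cb
        self.brain_class = brain_class

    def setup(self, worker):
        # Runs once per worker process; attach whatever the callback builds
        # so that tasks running on this worker can reuse it.
        worker.brain_factory = self.class_cb(self.brain_class)

    def teardown(self, worker):
        pass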
Example #2
def add_global_preloader(
    preloader: WorkerPreloader,
    client: Client = None,
) -> None:
    with open(PRELOADER_PATH, 'wb') as f:
        cloudpickle.dump(preloader, f)
    if client is not None:
        client.register_worker_plugin(preloader, name="global_preloader")
    LOCAL_MOCK_WORKER.plugins["global_preloader"] = preloader
    LOCAL_MOCK_WORKER.plugins["global_preloader"].setup(LOCAL_MOCK_WORKER)
Example #3
def get_dask_client(n_workers,
                    processes=True,
                    n_threads=1,
                    max_mem_fraction=.9,
                    plugins=None):
    import psutil
    from dask.distributed import Client

    # Cap usable memory at max_mem_fraction of total RAM (clamped to [0, 1])
    # and split it evenly across workers, in whole gigabytes per worker.
    ml = psutil.virtual_memory().total * max(min(max_mem_fraction, 1), 0)
    ml = str(int(ml // 1e9) // n_workers)
    client = Client(processes=processes,
                    threads_per_worker=n_threads,
                    n_workers=n_workers,
                    memory_limit=ml + 'GB')

    if plugins is not None and 'pgb' in plugins:
        register_codec(codecs.PackGeneticBits)
        client.register_worker_plugin(codecs.CodecPlugin())

    return client
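A minimal usage sketch for the helper above, assuming four local single-threaded worker processes; the parameter values are illustrative only.

# Hedged usage sketch for get_dask_client (parameter values are assumptions).
client = get_dask_client(n_workers=4, processes=True, n_threads=1)
print(client.scheduler_info()["workers"].keys())  # inspect the live workers
client.close()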
Example #4
        download_if_not_exists(
            "data/model_kf{0}.h5".format(i),
            "https://jpata.web.cern.ch/jpata/hepaccelerate/model_kf{0}.h5".
            format(i),
        )

    print(
        "Trying to connect to dask cluster, please start it with examples/dask_cluster.sh or examples/dask_cluster_gpu.sh"
    )

    if args.dask_server == "debug":
        multiprocessing_initializer(args)
    else:
        client = Client(args.dask_server)
        plugin = InitializerPlugin(args)
        client.register_worker_plugin(plugin)

    print("Processing all datasets")
    arglist = []

    walltime_t0 = time.time()
    for dataset, fn_pattern, ismc in datasets:
        filenames = glob.glob(args.datapath + fn_pattern)
        if len(filenames) == 0:
            raise Exception(
                "Could not find any filenames for dataset={0}: {{datapath}}/{{fn_pattern}}={1}/{2}"
                .format(dataset, args.datapath, fn_pattern))
        ichunk = 0
        for fn in filenames:
            nev = len(uproot.open(fn).get("Events"))
            # Process in chunks of 500k events to limit peak memory usage
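            # Hedged sketch of how the chunking could continue: the 500k chunk
            # size comes from the comment above, but the exact layout of the
            # work items appended to arglist is an assumption, not the
            # project's actual argument list.
            chunksize = 500_000
            for start in range(0, nev, chunksize):
                stop = min(start + chunksize, nev)
                arglist.append((dataset, fn, ismc, start, stop, ichunk))
                ichunk += 1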
Example #5
    def run(self, current_date: datetime, dry_run: bool = False) -> None:
        """
        Run analysis using mozanalysis for a specific experiment.
        """
        global _dask_cluster
        logger.info("Analysis.run invoked for experiment %s", self.config.experiment.normandy_slug)

        self.check_runnable(current_date)
        assert self.config.experiment.start_date is not None  # for mypy

        self.ensure_enrollments(current_date)

        # set up dask
        _dask_cluster = _dask_cluster or LocalCluster(
            dashboard_address=DASK_DASHBOARD_ADDRESS,
            processes=True,
            threads_per_worker=1,
            n_workers=DASK_N_PROCESSES,
        )
        client = Client(_dask_cluster)

        results = []

        if self.log_config:
            log_plugin = LogPlugin(self.log_config)
            client.register_worker_plugin(log_plugin)

            # add profiling plugins
            # resource_profiling_plugin = ResourceProfilingPlugin(
            #     scheduler=_dask_cluster.scheduler,
            #     project_id=self.log_config.log_project_id,
            #     dataset_id=self.log_config.log_dataset_id,
            #     table_id=self.log_config.task_profiling_log_table_id,
            #     experiment=self.config.experiment.normandy_slug,
            # )
            # _dask_cluster.scheduler.add_plugin(resource_profiling_plugin)

            # task_monitoring_plugin = TaskMonitoringPlugin(
            #     scheduler=_dask_cluster.scheduler,
            #     project_id=self.log_config.log_project_id,
            #     dataset_id=self.log_config.log_dataset_id,
            #     table_id=self.log_config.task_monitoring_log_table_id,
            #     experiment=self.config.experiment.normandy_slug,
            # )
            # _dask_cluster.scheduler.add_plugin(task_monitoring_plugin)

        table_to_dataframe = dask.delayed(self.bigquery.table_to_dataframe)

        for period in self.config.metrics:
            segment_results = []
            time_limits = self._get_timelimits_if_ready(period, current_date)

            if time_limits is None:
                logger.info(
                    "Skipping %s (%s); not ready",
                    self.config.experiment.normandy_slug,
                    period.value,
                )
                continue

            exp = mozanalysis.experiment.Experiment(
                experiment_slug=self.config.experiment.normandy_slug,
                start_date=self.config.experiment.start_date.strftime("%Y-%m-%d"),
                app_id=self._app_id_to_bigquery_dataset(self.config.experiment.app_id),
            )

            analysis_bases = []

            for m in self.config.metrics[period]:
                for analysis_basis in m.metric.analysis_bases:
                    analysis_bases.append(analysis_basis)

            analysis_bases = list(set(analysis_bases))

            if len(analysis_bases) == 0:
                continue

            for analysis_basis in analysis_bases:
                metrics_table = self.calculate_metrics(
                    exp, time_limits, period, analysis_basis, dry_run
                )

                if dry_run:
                    results.append(metrics_table)
                    logger.info(
                        "Not calculating statistics %s (%s); dry run",
                        self.config.experiment.normandy_slug,
                        period.value,
                    )
                    continue

                metrics_dataframe = table_to_dataframe(metrics_table)

                segment_labels = ["all"] + [s.name for s in self.config.experiment.segments]
                for segment in segment_labels:
                    segment_data = self.subset_to_segment(segment, metrics_dataframe)
                    for m in self.config.metrics[period]:
                        segment_results += self.calculate_statistics(
                            m,
                            segment_data,
                            segment,
                            analysis_basis,
                        ).to_dict()["data"]

                    segment_results += self.counts(segment_data, segment, analysis_basis).to_dict()[
                        "data"
                    ]

            results.append(
                self.save_statistics(
                    period,
                    segment_results,
                    self._table_name(period.value, len(time_limits.analysis_windows)),
                )
            )

        result_futures = client.compute(results)
        client.gather(result_futures)  # block until futures have finished
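The tail of this example relies on the delayed/compute/gather pattern: the results list collects lazily evaluated objects built on dask.delayed (note that table_to_dataframe is wrapped with dask.delayed above), client.compute submits them to the cluster as futures, and client.gather blocks until they finish. A stripped-down sketch of that pattern, with placeholder functions instead of the project's BigQuery helpers, looks like this:

# Minimal sketch of the delayed/compute/gather pattern (load/summarize are placeholders).
import dask
from dask.distributed import Client, LocalCluster

def load(i):
    return list(range(i))

def summarize(rows):
    return sum(rows)

if __name__ == "__main__":
    client = Client(LocalCluster(processes=True, threads_per_worker=1, n_workers=2))
    lazy = [dask.delayed(summarize)(dask.delayed(load)(i)) for i in range(5)]
    futures = client.compute(lazy)   # submit the lazy graph to the cluster
    print(client.gather(futures))    # block until all futures have finished
    client.close()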
Example #6
        ])
        # Note: sys.version[:3] breaks on Python 3.10+ ("3.1"), so build the tag from version_info.
        sitepackages = os.path.join(installdir, 'lib',
                                    'python{}.{}'.format(*sys.version_info[:2]),
                                    'site-packages')
        if sitepackages not in sys.path:
            sys.path.insert(0, sitepackages)

    def teardown(self, worker):
        pass


client = Client(os.environ['DASK_SCHEDULER'])

# one-time setup
if True:
    client.register_worker_plugin(ConfigureXRootD(), 'user_proxy')
    client.register_worker_plugin(
        InstallPackage(
            'https://github.com/nsmith-/boostedhiggs/archive/dev.zip'),
        'boostedhiggs')
    # newcoffea = DistributeZipball('/home/ncsmith/coffea/dist/coffea-0.6.23.zip')
    # client.register_worker_plugin(newcoffea, 'coffeaupdate')
    # jump_assignment = WorkerJumpAssignment()
    # def put(dask_scheduler=None):
    #     dask_scheduler.add_plugin(WorkerJumpAssignment(list(dask_scheduler.workers.keys())))

    # def get(dask_scheduler=None):
    #     for p in dask_scheduler.plugins:
    #         try:
    #             return p.get_jump_mapping()
    #         except AttributeError:
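The commented-out put/get helpers above hint at the usual idiom for installing a scheduler plugin from the client side: define a function that accepts dask_scheduler as a keyword argument and hand it to client.run_on_scheduler. A minimal sketch, assuming the project's WorkerJumpAssignment scheduler plugin (not shown in this snippet):

# Hedged sketch of the run_on_scheduler idiom from the comments above;
# WorkerJumpAssignment is the project's scheduler plugin and is not defined here.
def put(dask_scheduler=None):
    dask_scheduler.add_plugin(
        WorkerJumpAssignment(list(dask_scheduler.workers.keys())))

client.run_on_scheduler(put)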
Example #7
        if args.scheduler_file:
            # We want to submit workers to other nodes and *not* run them
            # locally, because we went to the trouble of specifying a
            # scheduler file that the scheduler and workers will use to
            # coordinate with one another.
            logger.info('Using a remote distributed model')
            client = Client(scheduler_file=args.scheduler_file)
        else:
            logger.info('Using a local distributed model')
            cluster = LocalCluster(n_workers=args.workers,
                                   processes=False,
                                   silence_logs=logger.level)
            logger.info("Cluster: %s", cluster)
            client = Client(cluster)

        client.register_worker_plugin(WorkerLoggerPlugin(verbose=args.verbose))

        logger.info('Client: %s', client)

        if args.track_workers_file:
            track_workers_stream = open(args.track_workers_file, 'w')
            track_workers_func = log_worker_location(track_workers_stream)

        if args.track_pop_file is not None:
            track_pop_stream = open(args.track_pop_file, 'w')
            track_pop_func = log_pop(args.update_interval, track_pop_stream)

        final_pop = asynchronous.steady_state(
            client,
            births=args.max_births,
            init_pop_size=5,
Example #8
def start_dask_cluster(environment=os.path.basename(os.environ['CONDA_PREFIX']),
                       worker_profile='Medium Worker',
                       profile='default',
                       region='us-west-2',
                       endpoint=None,
                       worker_min=2,
                       worker_max=20,
                       adaptive_scaling=True,
                       wait_for_cluster=True,
                       cfile=None,
                       use_existing_cluster=True,
                       propagate_env=False):
    '''
    environment      - should match the running kernel; set automatically by default
    worker_profile   - 'Small Worker', 'Medium Worker', or 'Pangeo Worker' (determines available memory per worker)
    profile          - 'default' is fine in most cases, but other AWS profiles can be used
    region           - AWS region
    endpoint         - None by default, which matches the region. Set the correct endpoint for the S3 buckets
    worker_min       - minimum number of workers (for adaptive scaling)
    worker_max       - maximum number of workers
    adaptive_scaling - Default True. If False, launches worker_max workers
    wait_for_cluster - Default True.
    cfile            - None. Finds AWS credentials in this file
    use_existing_cluster - Default True.
    propagate_env    - Default False. Set to True when working with Cloud VRTs
    '''
    if not endpoint:
        endpoint = f's3.{region}.amazonaws.com'

    set_credentials2(profile=profile,
                     region=region,
                     endpoint=endpoint,
                     cfile=cfile)

    # Reuse an existing Gateway if one is already in scope; the bare except
    # also catches the NameError raised when it is not defined yet.
    try:
        gateway.list_clusters()
    except:
        gateway = Gateway()

    if gateway.list_clusters():
        print('Existing Dask clusters:')
        j = 0
        for c in gateway.list_clusters():
            print(f'Cluster Index c_idx: {j} / Name:', c.name, c.status)
            j += 1
    else:
        print('No Cluster running.')

    # TODO Check if worker_profile is the same, otherwise start new cluster
    if gateway.list_clusters() and use_existing_cluster:
        print('Using existing cluster [0].')
        cluster = gateway.connect(gateway.list_clusters()[0].name)
    else:
        print('Starting new cluster.')
        cluster = gateway.new_cluster(environment=environment,
                                      profile=worker_profile)

    if adaptive_scaling:
        print(f'Setting Adaptive Scaling min={worker_min}, max={worker_max}')
        cluster.adapt(minimum=worker_min, maximum=worker_max)
    else:
        print(f'Setting Fixed Scaling workers={worker_max}')
        cluster.scale(worker_max)

    try:
        client = Client(cluster)
        client.close()
        print('Reconnect client to clear cache')
    except:
        pass
    client = Client(cluster)

    print(
        f'client.dashboard_link (for new browser tab/window or dashboard searchbar in Jupyterhub):\n{client.dashboard_link}'
    )

    if wait_for_cluster:
        target_workers = worker_min if adaptive_scaling else worker_max
        live_workers = len(list(cluster.scheduler_info['workers']))
        t = 0
        interval = 2
        print(
            f'Elapsed time to wait for {target_workers} live workers:\n{live_workers}/{target_workers} workers - {t} seconds',
            end='')
        while not live_workers >= target_workers:
            sleep(interval)
            t += interval
            print(f'\r{live_workers}/{target_workers} workers - {t} seconds',
                  end='')
            live_workers = len(client.scheduler_info()['workers'])
        print(f'\r{live_workers}/{target_workers} workers - {t} seconds')

    # We need to propagate credentials to the workers
    #set_credentials(profile=profile,region=region,endpoint=endpoint)

    if propagate_env:
        print('Propagating environment variables to workers')

        class InitWorker(WorkerPlugin):
            name = "init_worker"

            def __init__(self, filepath=None, script=None):
                self.data = {}
                if filepath:
                    if isinstance(filepath, str):
                        filepath = [filepath]
                    for file_ in filepath:
                        with open(file_, "rb") as f:
                            filename = os.path.basename(file_)
                            self.data[filename] = f.read()
                if script:
                    filename = f"{uuid.uuid1()}.py"
                    self.data[filename] = script

            async def setup(self, worker):
                responses = await asyncio.gather(*[
                    worker.upload_file(
                        comm=None, filename=filename, data=data, load=True)
                    for filename, data in self.data.items()
                ])
                assert all(
                    len(data) == r["nbytes"]
                    for r, data in zip(responses, self.data.values()))

        # The leading \r on each line acts as a line break in the uploaded file,
        # so the f-string's indentation is discarded when the script is executed.
        script = f"""
        \rimport os
        \ros.environ["AWS_ACCESS_KEY_ID"] = "{os.getenv("AWS_ACCESS_KEY_ID")}"
        \ros.environ["AWS_SECRET_ACCESS_KEY"] = "{os.getenv("AWS_SECRET_ACCESS_KEY")}"
        \ros.environ["AWS_DEFAULT_REGION"] = "{os.getenv("AWS_DEFAULT_REGION")}"
        \ros.environ["GDAL_DISABLE_READDIR_ON_OPEN"] ="EMPTY_DIR"
        """

        plugin = InitWorker(script=script)
        client.register_worker_plugin(plugin)

    return client, cluster
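A quick way to confirm that propagate_env actually reached the workers is to run a small check on each of them with client.run; this verification step is a sketch and not part of the original helper.

# Hedged sketch: check that the propagated AWS variables are visible on every worker.
import os

def _check_env():
    return {key: bool(os.getenv(key)) for key in
            ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")}

client, cluster = start_dask_cluster(propagate_env=True)
print(client.run(_check_env))  # returns one result dict per worker address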