def main():
    """."""
    host = os.getenv('DASK_SCHEDULER_HOST', default='localhost')
    port = os.getenv('DASK_SCHEDULER_PORT', default=8786)
    print(host, port)
    client = Client('{}:{}'.format(host, port))
    # client.run(init_logging)
    # client.run_on_scheduler(init_logging)

    # Run some mock functions and gather a result
    data = client.map(print_listdir, range(10))
    future = client.submit(print_values, data)
    progress(future)
    print('')
    result = client.gather(future)
    print(result)

    # Run a second stage which runs some additional processing.
    print('here A')
    data_a = client.map(set_value, range(100))
    print('here B')
    data_b = client.map(square, data_a)
    print('here C')
    data_c = client.map(neg, data_b)
    print('here D')
    # Submit a function application to the scheduler
    total = client.submit(sum, data_c)
    print('here E')
    progress(total)
    print(total.result())
    print('here F')
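The helper functions used above (print_listdir, print_values, set_value, square, neg) are defined elsewhere in the project; below is a minimal sketch of plausible stand-ins, assuming they only need to be picklable module-level functions.

# Hypothetical stand-ins for the helpers referenced above; the real project may differ.
import os

def print_listdir(i):
    # Runs on a worker: show the worker's working-directory contents.
    print(i, os.listdir('.'))
    return i

def print_values(values):
    # Receives the gathered outputs of the print_listdir tasks.
    print(list(values))
    return list(values)

def set_value(x):
    return x

def square(x):
    return x ** 2

def neg(x):
    return -x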
Example #2
def test_dask_connection():
    cluster = LocalCluster(
        scheduler_port=0,
        silence_logs=True,
        processes=False,
        asynchronous=False,
    )
    client = Client(cluster, asynchronous=False)

    def square(x):
        return x**2

    def neg(x):
        return -x

    # Run a computation on Dask
    a = client.map(square, range(10))
    b = client.map(neg, a)
    total = client.submit(sum, b)
    result = total.result()

    if result != -285:
        raise AssertionError("Result is " + str(result))
    else:
        print("The result is correct!!!")

    client.close()
    cluster.close()
    return True
Example #3
def start_futures():
    t = time()
    isins = get_isins()

    client = Client('127.0.0.1:8786')

    data = client.map(load_data, isins)
    params_a = client.map(get_param, data, ['param_a'] * len(isins))
    params_b = client.map(get_param, data, ['param_b'] * len(isins))

    result_a = client.map(task_a, isins, params_a, params_b)

    group_args = list(chain(*zip(isins, result_a, params_b)))
    result_group = client.submit(task_group_alter, *group_args)

    result_b = client.map(task_b, isins, params_b, [result_group] * len(isins))

    result_c = client.map(task_c, isins, params_b)

    result = client.gather([result_group] + result_a + result_b + result_c)

    total = time() - t
    print(total)
    print(len(result))
    with open('/Users/vladimirmarunov/git/dask-test/res.txt', 'w') as f:
        f.write('{}\n'.format(total))
        json.dump(result, f, indent=4)
def main():
    client = Client('localhost:8786')
    A = client.map(set_value, range(100))
    B = client.map(square, A)
    C = client.map(neg, B)
    total = client.submit(sum, C)
    progress(total)
    print(total.result())
Example #6
def init_workers(env_name, num_processes):
    cluster = LocalCluster(n_workers=num_processes)
    client = Client(cluster)
    pubs_reset = [
        Pub('env{}_reset'.format(seed)) for seed in range(num_processes)
    ]
    client.map(run_env, [env_name] * num_processes, range(num_processes))
    sub_obs = Sub('observations')
    # sleep while sub/pub is initialized
    time.sleep(5)
    return client, pubs_reset, sub_obs
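run_env is not shown in this snippet. Below is a minimal sketch, assuming each worker subscribes to its own reset channel and publishes observations on the shared 'observations' topic; the gym dependency and loop body are assumptions, only the channel names come from the code above.

# Hypothetical worker loop; only the Pub/Sub topic names are taken from init_workers.
import gym
from dask.distributed import Pub, Sub

def run_env(env_name, seed):
    env = gym.make(env_name)
    sub_reset = Sub('env{}_reset'.format(seed))  # counterpart of pubs_reset above
    pub_obs = Pub('observations')                # counterpart of sub_obs above
    pub_obs.put((seed, env.reset()))
    for _ in sub_reset:                          # block until a reset message arrives
        pub_obs.put((seed, env.reset()))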
def get_words(data):
    '''
    Find the high-frequency words in a column of strings (or any
    other pd.Series whose elements are all strings).

    A local distributed client 'c' is created for the computation.

    :param data: the strings to analyse
    :type data: pd.Series
    '''
    assert isinstance(data, pd.Series)
    assert all(isinstance(i, str) for i in data)
    from dask.distributed import Client

    c = Client()
    lines = list(data)
    tasks = c.map(word_frequency, lines)
    allDicts = c.gather(tasks)
    allDict = {}

    # merge the per-line frequency dicts into one overall word count
    for dic in allDicts:
        for key, count in dic.items():
            allDict[key] = allDict.get(key, 0) + count

    words = pd.Series(list(allDict.keys()), index=list(allDict.values()))
    words = words.sort_index()
    threshold = int(words.shape[0] / 50)
    words = words.loc[threshold:]
    c.close()
    return words
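word_frequency is defined elsewhere; here is a minimal sketch, assuming it maps a single string to a word-to-count dict.

# Hypothetical helper for get_words above.
from collections import Counter

def word_frequency(line):
    # Count how often each whitespace-separated word occurs in the line.
    return dict(Counter(line.lower().split()))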
Example #8
def main():
    import tensorflow as tf
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.logging.set_verbosity(tf.logging.WARN)

    #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    start = timer()

    n_workers = 3
    n_gpus = 2

    cluster = LocalCluster(n_workers=n_workers,
                           threads_per_worker=1,
                           processes=True,
                           memory_limit=32e9)
    client = Client(cluster)
    logger.info('Created cluster')

    try:
        workers = cluster.workers
        #client.run(init_logging)

        tiles = np.array_split(np.arange(25), n_workers)
        print('Tile batches:', tiles)
        args = [(tiles[i], i % n_gpus) for i in range(n_workers)]
        res = client.map(run_pipeline, args)
        res = [r.result() for r in res]
    finally:
        client.close()
        cluster.close()

    stop = timer()

    print('Execution time:', stop - start)
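run_pipeline is external to this snippet. Each mapped argument is a (tile_indexes, gpu_id) tuple, so a sketch might unpack it and pin the GPU before any TensorFlow work; the body below is a placeholder assumption.

# Hypothetical task body; only the (tiles, gpu) argument shape is implied by the caller.
import os

def run_pipeline(args):
    tile_indexes, gpu_id = args
    # Pin this worker process to one GPU before TensorFlow initializes.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    results = []
    for tile in tile_indexes:
        results.append(tile)  # placeholder for the real per-tile processing
    return results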
def test_insert(pod):
    # Write with workers
    label = "my_label"
    repo = Repo(pod=pod)
    # Create collection and label
    collection = repo.create_collection(schema, "my_collection")
    token = pod.token
    cluster = LocalCluster(processes=False)
    client = Client(cluster)
    args = [(token, label, y) for y in years]
    with timeit(f"\nWRITE ({pod.protocol})"):
        fut = client.map(insert, args)
        assert sum(client.gather(fut)) == 10_519_200
    client.close()
    cluster.close()

    # Merge everything and read series
    with timeit(f"\nMERGE ({pod.protocol})"):
        collection.merge()

    with timeit(f"\nREAD ({pod.protocol})"):
        series = collection / label
        df = series["2015-01-01":"2015-01-02"].df()
        assert len(df) == 1440
        df = series["2015-12-31":"2016-01-02"].df()
        assert len(df) == 2880
Example #10
def main():
    n_mutation = 100
    client = Client('scheduler:8786')
    futures = client.map(initialize_network, range(n_mutation))
    results = client.gather(futures)
    results.sort(key=lambda x: -x[1])

    truncated = list(map(lambda x: x[0], results[:3]))
    futures = []
    for i, seed in enumerate(truncated):
        name = 'top-{}'.format(i)
        futures.append(
            client.submit(initialize_network, seed, store=True, name=name))
    results = client.gather(futures)
    print(results, flush=True)

    for g in range(10):
        futures = []
        for seed in range(n_mutation):
            futures.append(client.submit(update_network, seed, g + 1))
        results = client.gather(futures)
        results.sort(key=lambda x: -x[1])
        truncated = list(map(lambda x: x[0], results[:3]))

        futures = []
        for i, seed in enumerate(truncated):
            name = 'top-{}'.format(i)
            futures.append(
                client.submit(update_network,
                              seed,
                              g + 1,
                              store=True,
                              name=name))
        results = client.gather(futures)
        print(results, flush=True)
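initialize_network and update_network are defined elsewhere; judging from the sorting on x[1] and the use of x[0] as a seed, each returns a (seed, fitness) pair. A minimal sketch under that assumption:

# Hypothetical stubs matching the call sites above; each returns (seed, fitness).
import random

def initialize_network(seed, store=False, name=None):
    fitness = random.Random(seed).random()  # placeholder for evaluating the network
    if store:
        pass                                # placeholder: persist the network under `name`
    return seed, fitness

def update_network(seed, generation, store=False, name=None):
    fitness = random.Random(seed + generation).random()  # placeholder evaluation
    if store:
        pass
    return seed, fitness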
    def main(self, gmrecords):
        """
        Assemble data and organize it into an ASDF file.

        Args:
            gmrecords:
                GMrecordsApp instance.
        """
        logging.info('Running subcommand \'%s\'' % self.command_name)
        self.gmrecords = gmrecords

        self._get_events()
        print(self.events)

        logging.info('Number of events to assemble: %s' % len(self.events))

        if gmrecords.args.num_processes:
            # parallelize processing on events
            try:
                client = Client(n_workers=gmrecords.args.num_processes)
            except BaseException as ex:
                print(ex)
                print("Could not create a dask client.")
                print("To turn off paralleization, use '--num-processes 0'.")
                sys.exit(1)
            futures = client.map(self._assemble_event, self.events)
            for result in as_completed(futures, with_results=True):
                print(result)
                # print('Completed event: %s' % result)
        else:
            for event in self.events:
                self._assemble_event(event)

        self._summarize_files_created()
Example #12
def submit(cluster, config, merge_names):
    client = Client(cluster)
    opt_list = make_opt_list(config, merge_names)
    print(opt_list)
    func = locate(config["function"])
    future_list = client.map(func, opt_list)
    return opt_list, future_list
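locate here is presumably pydoc.locate from the standard library, which resolves the dotted path in config["function"] to a callable, for example:

# Assuming locate is pydoc.locate:
from pydoc import locate

func = locate("math.sqrt")  # resolves a dotted path to the actual callable
print(func(9.0))            # 3.0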
def main(args):
    config_file = args.config_file

    # Configure on cluster
    if config_file:
        with open(config_file, 'r') as stream:
            inp = yaml.safe_load(stream)
        cores = inp['jobqueue']['slurm']['cores']
        memory = inp['jobqueue']['slurm']['memory']
        jobs = inp['jobqueue']['slurm']['jobs']
        cluster = SLURMCluster(
            cores=cores,
            memory=memory,
        )
        cluster.scale(jobs=jobs)

    # Configure locally
    else:
        cluster = LocalCluster()

    client = Client(cluster)
    raised_futures = client.map(sleep_more, range(100))
    progress(raised_futures)
    raised = client.gather(raised_futures)
    print('\n', raised)
Example #14
    def main(self, gmrecords):
        """Compute waveform metrics.

        Args:
            gmrecords:
                GMrecordsApp instance.
        """
        logging.info('Running subcommand \'%s\'' % self.command_name)

        self.gmrecords = gmrecords
        self._get_events()

        if gmrecords.args.num_processes:
            # parallelize processing on events
            try:
                client = Client(n_workers=gmrecords.args.num_processes)
            except BaseException as ex:
                print(ex)
                print("Could not create a dask client.")
                print("To turn off paralleization, use '--num-processes 0'.")
                sys.exit(1)
            futures = client.map(self._compute_event_waveforms, self.events)
            for result in as_completed(futures, with_results=True):
                print(result)
        else:
            for event in self.events:
                self._compute_event_waveforms(event)

        self._summarize_files_created()
    def main(self, gmrecords):
        """Process data using steps defined in configuration file.

        Args:
            gmrecords:
                GMrecordsApp instance.
        """
        logging.info('Running subcommand \'%s\'' % self.command_name)

        self.gmrecords = gmrecords
        self._get_events()

        # get the process tag from the user or define by current datetime
        self.process_tag = (gmrecords.args.label
                            or datetime.utcnow().strftime(TAG_FMT))
        logging.info('Processing tag: %s' % self.process_tag)

        if gmrecords.args.num_processes:
            # parallelize processing on events
            try:
                client = Client(n_workers=gmrecords.args.num_processes)
            except BaseException as ex:
                print(ex)
                print("Could not create a dask client.")
                print("To turn off paralleization, use '--num-processes 0'.")
                sys.exit(1)
            futures = client.map(self._process_event, self.events)
            for result in as_completed(futures, with_results=True):
                print(result)
                # print('Completed event: %s' % result)
        else:
            for event in self.events:
                self._process_event(event)

        self._summarize_files_created()
Example #16
    def handle(self, *args, **options):
        # Unpack variables
        model_id = options['model_id']
        out_dir = options['out_dir']

        # Create output dir if does not exist
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        # datacube query
        gwf_kwargs = {
            k: options[k]
            for k in ['product', 'lat', 'long', 'region']
        }
        iterable = gwf_query(**gwf_kwargs)

        # Start cluster and run
        client = Client()
        client.restart()
        C = client.map(predict_pixel_tile, iterable, **{
            'model_id': model_id,
            'outdir': out_dir
        })
        filename_list = client.gather(C)
        print(filename_list)
Example #17
class DaskHandler(IProcessingHandler):
    """This class wraps all Dask related functions."""

    def __init__(self, number_of_workers, class_cb: Callable, brain_class, worker_log_level=logging.WARNING):
        super().__init__(number_of_workers)
        self._client: Optional[Client] = None
        self._cluster: Optional[LocalCluster] = None

        self.class_cb = class_cb
        self.brain_class = brain_class
        self.worker_log_level = worker_log_level

    def init_framework(self):
        if self._client:
            raise RuntimeError("Dask client already initialized.")

        # threads_per_worker must be one because atari-env is not thread-safe,
        # and since we lower the thread count from the default, we increase the number of workers instead.
        self._cluster = LocalCluster(processes=True, asynchronous=False, threads_per_worker=1,
                                     silence_logs=self.worker_log_level,
                                     n_workers=self.number_of_workers,
                                     memory_pause_fraction=False,
                                     lifetime='1 hour', lifetime_stagger='5 minutes', lifetime_restart=True,
                                     interface="lo")
        self._client = Client(self._cluster)
        self._client.register_worker_plugin(_CreatorPlugin(self.class_cb, self.brain_class), name="creator-plugin")
        logging.info("Dask dashboard available at port: " + str(self._client.scheduler_info()["services"]["dashboard"]))

    def map(self, func, *iterable):
        if not self._client:
            raise RuntimeError("Dask client not initialized. Call \"init_framework\" before calling \"map\"")
        return self._client.gather(self._client.map(func, *iterable))

    def cleanup_framework(self):
        self._client.shutdown()
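A hedged usage sketch for DaskHandler; the factory callback and brain class below are placeholders, not part of the original code:

# Hypothetical usage; make_brain and object stand in for the real class_cb / brain_class.
def make_brain():
    return object()

handler = DaskHandler(number_of_workers=4, class_cb=make_brain, brain_class=object)
handler.init_framework()
try:
    print(handler.map(lambda x: x * x, range(8)))  # [0, 1, 4, 9, 16, 25, 36, 49]
finally:
    handler.cleanup_framework()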
    def rank_populations(self, top: float = 0.5):
        """Rank self.populations according to the given fitness function.

        Inputs:
        =======
        top (float): fraction of top populations to return

        Outputs:
        ========
        best_populations (List): top populations
        """

        client = Client()

        client_input = [(self.data, p, self.yield_column)
                        for p in self.populations]

        futures = client.map(GeneticAlgorithm.evaluate_fitness, client_input)
        ranking = client.gather(futures)

        client.close()

        # return top performing populations
        top_n = int(top * len(self.populations))

        return [self.populations[i] for i in argsort(ranking)[-top_n:]]
    def handle(self, *args, **options):
        # Unpack variables
        name = options['name']
        model = options['model']
        segmentation = options['segmentation']
        spatial_aggregation = options['spatial_aggregation']
        categorical_variables = options['categorical_variables']
        scheduler_file = options['scheduler']

        # datacube query
        gwf_kwargs = { k: options[k] for k in ['product', 'lat', 'long', 'region']}
        iterable = gwf_query(**gwf_kwargs)

        # Start cluster and run 
        client = Client(scheduler_file=scheduler_file)
        client.restart()
        C = client.map(predict_object,
                       iterable,
                       pure=False,
                       **{'model_name': model,
                          'segmentation_name': segmentation,
                          'categorical_variables': categorical_variables,
                          'aggregation': spatial_aggregation,
                          'name': name,
                          })
        result = client.gather(C)

        print('Successfully ran prediction on %d tiles' % sum(result))
        print('%d tiles failed' % result.count(False))
Example #20
class DaskParallelRunner(object):
    """Run the simulations using dask.distributed on a cluster. This requires some set up on the cluster
    (see the dask.distributed documentation).

    TO BE DOCUMENTED.
    """

    def __init__(self, client, chunk=10):

        if isinstance(client, str):
            from dask.distributed import Client
            self.client = Client(client)
        else:
            self.client = client
        self.chunk = chunk

    def __call__(self, function, argument_list):

        def function_with_single_numerical_threads(args):
            lib.set_max_numerical_threads(1)
            return function(*args)

        # make a bag
        argument_list = list(argument_list)
        n = self.chunk

        futures = []
        for i in range(0, len(argument_list), n):
            args = argument_list[i: i + n]
            future = self.client.map(function_with_single_numerical_threads, list(args))
            futures += future

        results = self.client.gather(futures, direct=False)

        return results
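A hedged usage sketch for DaskParallelRunner, assuming a scheduler is already reachable at the given address; the address and the simulate function are placeholders, and the wrapper also expects the project's lib module to be importable on the workers.

# Hypothetical usage; 'tcp://127.0.0.1:8786' and simulate() are placeholders.
def simulate(a, b):
    return a * b

runner = DaskParallelRunner('tcp://127.0.0.1:8786', chunk=5)
arguments = [(i, i + 1) for i in range(20)]  # each tuple is unpacked into simulate(*args)
print(runner(simulate, arguments))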
Example #21
def _map_parallel_dask(
    f: 'Callable',
    *args: 'Iterable',
    processes: 'Optional[int]' = None,
    return_results: 'bool' = True,
) -> 'list':
    from dask.distributed import Client
    from dask.distributed import LocalCluster

    cluster = LocalCluster(n_workers=processes, dashboard_address=None)
    client = Client(cluster)
    if return_results:
        return [future.result() for future in client.map(f, *args)]
    else:
        for future in client.map(f, *args):
            future.result()
        return []
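A short usage sketch for _map_parallel_dask; the worker count is arbitrary:

# Square ten numbers across two local worker processes.
def square(x):
    return x ** 2

print(_map_parallel_dask(square, range(10), processes=2))
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]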
Example #22
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("config", help="Configuration yaml file")
    parser.add_argument("-p",
                        "--proc",
                        type=int,
                        default=1,
                        help="Number of processors")
    args = parser.parse_args()
    if not exists(args.config):
        raise FileNotFoundError(args.config + " not found.")
    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)
    #time_files = get_cam_output_times(config["model_path"], time_var=config["time_var"],
    #                                  file_start=config["model_file_start"],
    #                                  file_end=config["model_file_end"])
    if not exists(config["out_path"]):
        makedirs(config["out_path"])
    #print(time_files)

    #filenames = np.sort(time_files["filename"].unique())
    filenames = sorted(
        glob(
            join(config["model_path"],
                 config["model_file_start"] + "*" + config["model_file_end"])))
    if "dt" not in config.keys():
        config["dt"] = 1800
    if args.proc == 1:
        for filename in filenames:
            process_cesm_file_subset(
                filename,
                staggered_variables=config["staggered_variables"],
                out_variables=config["out_variables"],
                subset_variable=config["subset_variable"],
                subset_threshold=config["subset_threshold"],
                out_path=config["out_path"],
                out_format=config["out_format"],
                dt=config["dt"])
    else:
        cluster = LocalCluster(n_workers=0)
        cluster.scale(args.proc)
        client = Client(cluster)
        print(client)
        futures = client.map(process_cesm_file_subset,
                             filenames,
                             staggered_variables=config["staggered_variables"],
                             out_variables=config["out_variables"],
                             subset_variable=config["subset_variable"],
                             subset_threshold=config["subset_threshold"],
                             out_path=config["out_path"],
                             out_start=config["out_start"],
                             out_format=config["out_format"],
                             dt=config["dt"])
        out = client.gather(futures)
        print(out)
        client.close()
    return
def main():
    args = parse_args()
    logging.info(args)
    cluster = init_cluster(args)
    client = Client(cluster)
    future_list = client.map(dummy_function, range(args.n_jobs))
    logging.info(cluster.job_script())
    for future in as_completed(future_list):
        exception = future.exception()
        traceback.print_exception(type(exception), exception, future.traceback())
Example #24
def distributed_quickstart():
	# At least one dask-worker must be running after launching a scheduler.
	#client = Client('127.0.0.1:8786')  # Launch a Client and point it to the IP/port of the scheduler.
	#client = Client()  # Set up local cluster on your laptop.
	client = Client(n_workers=4, threads_per_worker=1)

	def square(x):
		return x ** 2

	def neg(x):
		return -x

	A = client.map(square, range(10))
	B = client.map(neg, A)

	total = client.submit(sum, B)
	print(total.result())  # Result for single future.

	print(client.gather(A))  # Gather for many futures.
Example #25
def _get_online_sp():
    client = Client()  # start a local cluster of workers
    # TODO: Figure out a way to not hardwire the pages
    futures = client.map(request_online, range(1, 48))
    df = pd.concat(client.gather(futures)).reset_index(drop=True)
    cleaned_df = df[(df['snow_depth'] != '') & (df['lat'] != '') & (df['lon'] != '')].sort_values(
        by='time').reset_index(drop=True)
    cleaned_df.loc[:, 'lon'] = cleaned_df['lon'].apply(lambda x: float(x))
    cleaned_df.loc[:, 'lat'] = cleaned_df['lat'].apply(lambda x: float(x))
    return cleaned_df
Example #26
def run(pl_conf, logging_init_fn=None):
    start = timer()

    # Initialize local dask cluster
    logger.info('Initializing pipeline tasks for %s workers', pl_conf.n_workers)
    logger.debug('Pipeline configuration: %s', pl_conf)
    cluster = LocalCluster(
        n_workers=pl_conf.n_workers, threads_per_worker=1,
        processes=True, memory_limit=pl_conf.memory_limit
    )
    client = Client(cluster)

    # Split total region + tile indexes to process into separate lists for each worker 
    # (by indexes of those index combinations)
    tiles = pl_conf.region_tiles
    idx_batches = np.array_split(np.arange(len(tiles)), pl_conf.n_workers)

    # Assign gpus to tasks in round-robin fashion
    def get_gpu(i):
        if pl_conf.gpus is None:
            return None
        return pl_conf.gpus[i % len(pl_conf.gpus)]

    # Generate a single task configuration for each worker
    tasks = [
        pl_conf.get_task_config(region_indexes=tiles[idx_batch, 0], tile_indexes=tiles[idx_batch, 1], gpu=get_gpu(i))
        for i, idx_batch in enumerate(idx_batches)
    ]

    logger.info('Starting pipeline for %s tasks', len(tasks))
    logger.debug('Task definitions:\n\t%s', '\n\t'.join([str(t) for t in tasks]))
    try:
        # Passing logging initialization operation, if given, to workers now
        # running in separate processes
        if logging_init_fn:
            client.run(logging_init_fn)

        # Disable the "auto_restart" feature of dask workers which is of no use in this context
        for worker in cluster.workers:
            worker.auto_restart = False

        # Pass tasks to each worker to execute in parallel
        res = client.map(run_pipeline_task, tasks)
        res = [r.result() for r in res]
        if len(res) != len(tasks):
            raise ValueError('Parallel execution returned {} results but {} were expected'.format(len(res), len(tasks)))
        stop = timer()
        if logger.isEnabledFor(logging.DEBUG):
            from scipy.stats import describe
            times = np.concatenate([np.array(t)[2] for t in res], 0)
            logger.debug('Per-tile execution time summary (all in seconds): %s', describe(times))
        logger.info('Pipeline execution completed in %s seconds', stop - start)
    finally:
        client.close()
        cluster.close()
Example #27
    def run_distributed(self, qnos, queries, working_set=[], extra={}):
        """Set up a cluster first:
        dask-scheduler
        env PYTHONPATH=/research/remote/petabyte/users/binsheng/trec_tools/ dask-worker segsresap10:8786 --nprocs 50 --nthreads 1 --memory-limit 0 --name segsresap10
        env PYTHONPATH=/research/remote/petabyte/users/binsheng/trec_tools/ dask-worker segsresap10:8786 --nprocs 50 --nthreads 1 --memory-limit 0 --name segsresap09
        """
        client = Client(self._scheduler)
        futures = client.map(
            self.run_single, zip(qnos, queries, repeat(working_set), repeat(extra))
        )
        output = [f.result() for f in futures]
        return output
Example #28
@attr.s
class daskerator(object):
    _DSCH = {
        'd': 'distributed',
        't': 'threads',
        'p': 'processes',
        's': 'synchronous'
    }

    def _get_sched(mp_type) -> str:
        if mp_type in daskerator._DSCH.keys():
            return daskerator._DSCH[mp_type]
        else:
            return mp_type

    mp_type = attr.ib(default='s',
                      type=str,
                      converter=_get_sched,
                      validator=attr.validators.in_(
                          list(_DSCH.keys()) + list(_DSCH.values())))
    sch_add = attr.ib(default='', type=str)

    @sch_add.validator
    def check_dask_opts(instance, attribute, value):
        if instance.mp_type != 'distributed' and value != '':
            raise ValueError(
                'Only distributed dask can accept scheduler address.')

    _client = attr.ib(default=None)
    _cluster = attr.ib(default=None)

    def __attrs_post_init__(self):
        if self.mp_type[0] == 'd':
            from dask.distributed import Client, LocalCluster
            dbg("Creating distributed client object.")
            if self.sch_add == '':
                dbg("Creating new cluster on localhost.")
                self._cluster = LocalCluster()
                self._client = Client(self._cluster)
            else:
                dbg(f"Existing scheduler address: {self.sch_add}")
                self._client = Client(self.sch_add)
            log.info(self._client)

    @curry
    def run_dask(self, func, iterator):
        dbg(f'Scheduler: {self.mp_type}')
        if self.mp_type[0] == 'd':
            dbg('Using dask client')
            return self._client.gather(self._client.map(func, iterator))
        else:
            dbg('Not using dask client.')
            return compute(*map(delayed(func), iterator),
                           scheduler=self.mp_type)
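A hedged usage sketch for daskerator; the synchronous scheduler ('s') avoids needing a running cluster, and the worker function is a placeholder:

# Hypothetical usage with the synchronous dask scheduler.
def double(x):
    return x * 2

runner = daskerator(mp_type='s')
print(runner.run_dask(double, range(5)))  # compute() returns a tuple: (0, 2, 4, 6, 8)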
Example #29
def convert_batch(apkFilenameList):
    """Convert APK files to AppGene files in batch.  Dask creates multiple threads or use multiple nodes to execute the convertSingleApk function.
       Args: 
         apkFilenameList: A list of the base filenames of APK files available from the HTTP interface to be converted 
       Returns:
         A list of conversion result objects 
    """
    client = Client(daskSchedulerConnection)
    # One APK file per new task
    futures = client.map(convertSingleApk, apkFilenameList)
    # Await until all tasks are done
    results = client.gather(futures)
    return list(results)
Example #30
    def get_basis_rep(self):

        if self.basis_instructions.get('spec_agnostic', False):
            self.get_chemical_symbols = (
                lambda x: ['X'] * len(x.get_chemical_symbols()))
        else:
            self.get_chemical_symbols = (lambda x: x.get_chemical_symbols())

        if self.num_workers > 1:
            # cluster = LocalCluster(n_workers=1, threads_per_worker=self.num_workers)
            # print(cluster)
            # client = Client(cluster)
            client = Client()

        class FakeClient():
            def map(self, *args):
                return map(*args)

        if self.num_workers == 1:
            client = FakeClient()

        atoms = self.atoms
        extension = self.basis_instructions.get('extension', 'RHOXC')
        if extension[0] != '.':
            extension = '.' + extension

        jobs = []
        for i, system in enumerate(atoms):
            filename = ''
            for file in os.listdir(pjoin(self.src_path, str(i))):
                if file.endswith(extension):
                    filename = file
                    break
            if filename == '':
                raise Exception('Density file not found in ' +\
                    pjoin(self.src_path,str(i)))

            jobs.append([
                pjoin(self.src_path, str(i), filename),
                system.get_positions() / Bohr,
                self.get_chemical_symbols(system)
            ])
        # results = np.array([j.compute(num_workers = self.num_workers) for j in jobs])
        futures = client.map(transform_one,
                             *[[j[i] for j in jobs] for i in range(3)],
                             len(jobs) * [self.basis_instructions])
        if self.num_workers == 1:
            results = list(futures)
        else:
            results = [f.result() for f in futures]
        return results
Example #31
class DaskRunManager(RunManager):
    """Manage multiple run in sequential or parallel mode"""

    def __init__(self, dataset):
        super().__init__(dataset)
        self.dask_config = dataset.config.get("dask", {})
        scheduler_config = self.dask_config.get("scheduler", {})
        mode = scheduler_config.pop("mode", "local")
        if mode == "local":
            cluster = LocalCluster(**scheduler_config)
            dataset.scheduler_address = cluster.scheduler_address
        elif mode == "distributed":
            dataset.scheduler_address = scheduler_config["scheduler_address"]
        else:
            raise ConfigurationError("{} mode is not enabled for dask".format(mode))
        self.client = Client(dataset.scheduler_address)

    def _run(self, records):

        futures = [
            self.dask_futures(self.record_structure(record)) for record in records
        ]

        asynchronous = self.dask_config.get("asynchronous", False)
        return self.client.gather(futures, asynchronous=asynchronous)

    def dask_futures(self, target):
        if isinstance(target, dict):
            return self.dask_parallel_futures(target)
        else:
            return self.dask_steps_future(target)

    def dask_parallel_futures(self, collection):
        def func(*args):
            return args

        return self.client.map(
            func, [self.dask_futures(item) for item in collection.values()]
        )

    def dask_steps_future(self, collection):

        future = None

        for item in collection:
            func = self.run_step_factory(item)
            future = self.client.submit(func, future)

        return future