Example #1
class SwissFelCluster:
    def __init__(self, cores=8, memory="24 GB", workers=5):
        self.cluster = SLURMCluster(cores=cores, memory=memory)
        self.client = Client(self.cluster)
        self.ip = socket.gethostbyname(socket.gethostname())
        self.dashboard_port_scheduler = self.client._scheduler_identity.get(
            "services")["dashboard"]
        self.username = getpass.getuser()

    def _repr_html_(self):
        return self.client._repr_html_()

    def scale_workers(self, N_workers):
        self.cluster.scale(N_workers)

    def create_dashboard_tunnel(self, ssh_host="ra"):
        print(
            "Type the following command in a terminal; if the port is taken, change the first number in the command."
        )
        print(" ".join([
            f"jupdbport={self.dashboard_port_scheduler}",
            "&&",
            "ssh",
            "-f",
            "-L",
            f"$jupdbport:{self.ip}:{self.dashboard_port_scheduler}",
            f"{self.username}@{ssh_host}",
            "sleep 10",
            "&&",
            "firefox",
            "http://localhost:$jupdbport",
        ]))
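The class above assumes module-level imports of SLURMCluster, Client, socket, and getpass. A minimal, hypothetical usage sketch (the worker count and SSH host below are illustrative, not values from the original project):

# Hypothetical usage of SwissFelCluster; the module defining the class must
# import SLURMCluster, Client, socket and getpass for it to work.
cluster = SwissFelCluster(cores=8, memory="24 GB")  # starts the scheduler, no SLURM jobs yet
cluster.scale_workers(10)                           # request 10 dask workers from SLURM
cluster.create_dashboard_tunnel(ssh_host="ra")      # print the ssh port-forward command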
Example #2
def test_run_sorters_dask():
    cache_folder = './local_cache'
    working_folder = 'test_run_sorters_dask'
    if os.path.exists(cache_folder):
        shutil.rmtree(cache_folder)
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)

    # create recording
    recording_dict = {}
    for i in range(8):
        rec, _ = toy_example(num_channels=4, duration=30, seed=0, num_segments=1)
        # make dumpable
        rec = rec.save(name=f'rec_{i}')
        recording_dict[f'rec_{i}'] = rec

    sorter_list = ['tridesclous', ]

    # create a dask Client for a slurm queue
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster

    python = '/home/samuel.garcia/.virtualenvs/py36/bin/python3.6'
    cluster = SLURMCluster(processes=1, cores=1, memory="12GB", python=python, walltime='12:00:00', )
    cluster.scale(5)
    client = Client(cluster)

    # dask
    t0 = time.perf_counter()
    run_sorters(sorter_list, recording_dict, working_folder,
                engine='dask', engine_kwargs={'client': client},
                with_output=False,
                mode_if_folder_exists='keep')
    t1 = time.perf_counter()
    print(t1 - t0)
def main(args):
    config_file = args.config_file

    # Configure on cluster
    if config_file:
        with open(config_file, 'r') as stream:
            inp = yaml.safe_load(stream)
        cores = inp['jobqueue']['slurm']['cores']
        memory = inp['jobqueue']['slurm']['memory']
        jobs = inp['jobqueue']['slurm']['jobs']
        cluster = SLURMCluster(
            cores=cores,
            memory=memory,
        )
        cluster.scale(jobs=jobs)

    # Configure locally
    else:
        cluster = LocalCluster()

    client = Client(cluster)
    raised_futures = client.map(sleep_more, range(100))
    progress(raised_futures)
    raised = client.gather(raised_futures)
    print('\n', raised)
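For reference, a hypothetical config file containing the keys this function reads (the resource values are placeholders, not taken from the original project):

# Sketch of a matching YAML config, written from Python for illustration only
import yaml

example_config = {
    "jobqueue": {
        "slurm": {
            "cores": 4,        # read as inp['jobqueue']['slurm']['cores']
            "memory": "16GB",  # read as inp['jobqueue']['slurm']['memory']
            "jobs": 10,        # read as inp['jobqueue']['slurm']['jobs']
        }
    }
}
with open("cluster_config.yaml", "w") as f:
    yaml.safe_dump(example_config, f)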
Example #4
def initialize_dask(n, factor = 5, slurm = False):

    if not slurm:
        cores =  len(os.sched_getaffinity(0))
        cluster = distributed.LocalCluster(processes = False,
                                           n_workers = 1,
                                           threads_per_worker = 1)

    else:
        n = min(100, n)
        py = './enter_conda.sh python3'
        params = {
            'python' : py,
            'cores' : 1,
            'memory' : '512MB',
            'walltime' : '180',
            'processes' : 1,
            'job_extra' : [
                '--qos use-everything',
                '--array 0-{0:d}'.format(n - 1),
                '--requeue',
                '--output "/dev/null"'
            ],
            'env_extra' : [
                'JOB_ID=${SLURM_ARRAY_JOB_ID%;*}_${SLURM_ARRAY_TASK_ID%;*}',
                'source /etc/profile.d/modules.sh',
                'cd {0!s}'.format(CONFIG['PATHS', 'root']),
            ]
        }
        cluster = SLURMCluster(**params)
        print(cluster.job_script())
        cluster.scale(1)

    print(cluster.dashboard_link)
    return distributed.Client(cluster)
def make_cluster():
    if socket.gethostname() == 'sgw1':

        # number of processing units per node. for ease of use, cores to the
        # number of CPU per node warning: this is the unitary increment by
        # which you can scale your number of workers inside your cluster.
        proc_per_worker = 24

        # total number of slurm node to request. Max number of dask workers
        # will be proc_per_worker * max_slurm_nodes
        max_slurm_nodes = 4

        cluster = SLURMCluster(
            workers=0,  # number of initial slurm jobs
            memory="16GB",
            # cores = number processing units per worker, can be
            # dask.Worker (processes) or threads of a worker's
            # ThreadPoolExecutor
            cores=proc_per_worker,
            # among those $cores workers, how many should be dask Workers,
            # (each worker will then have cores // processes threads inside
            # their ThreadPoolExecutor)
            # sets cpus-per-task=processes inside batch script
            processes=proc_per_worker,
            # job_extra=[get_sbatch_args(max_workers, proc_per_worker)],
        )
        # scale the number of unitary dask workers (and not batch jobs)
        cluster.scale(96)
    else:
        cluster = LocalCluster(
            n_workers=2, threads_per_worker=1, processes=False,
            dashboard_address=':7777'
        )
    return cluster
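To make the comments above concrete: with cores=24 and processes=24, each SLURM job hosts 24 single-threaded dask workers, so scaling to 96 workers corresponds to 4 jobs. A small arithmetic sketch (illustrative only):

# Illustrative arithmetic only: how worker counts map to SLURM jobs above
cores = 24                                     # processing units per SLURM job
processes = 24                                 # dask worker processes per SLURM job
threads_per_worker = cores // processes        # -> 1 thread per worker
target_workers = 96                            # argument passed to cluster.scale()
jobs_needed = -(-target_workers // processes)  # ceiling division -> 4 SLURM jobs
print(threads_per_worker, jobs_needed)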
def train_on_jz_dask(job_name, train_function, *args, **kwargs):
    cluster = SLURMCluster(
        cores=1,
        job_cpu=20,
        memory='80GB',
        job_name=job_name,
        walltime='60:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:1',
            '--qos=qos_gpu-t4',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/fastmri-reproducible-benchmark',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    cluster.scale(1)

    print(cluster.job_script())

    client = Client(cluster)
    futures = client.submit(
        # function to execute
        train_function,
        *args,
        **kwargs,
        # this function has potential side effects
        pure=True,
    )
    client.gather(futures)
    print('Shutting down dask workers')
Example #7
def main(args):

    split_files = split_file(args.url_file)


    if args.distribute:
        extra_args = [
            "-J newsnet_worker"
            "--mail-type=ALL",
            "[email protected]"
            "--gres=nvme:100"]

        cluster = SLURMCluster(
            name = "newsnet_worker",
            cores = 20,
            memory="2GB",
            queue="small",
            walltime="3:00:00",
            local_directory = '/tmp',
            log_directory = f"{os.environ.get('PWD')}/dask-worker-space",
            project = args.project,
            job_extra = extra_args)

        with Client(cluster) as client:
            print("\n\nLaunching Dask SLURM cluster...")
            cluster.scale(4)
            to_upload = f'{os.path.dirname(os.path.abspath(sys.argv[0]))}/parse_articles.py'
            client.upload_file(to_upload)
            print(to_upload)
            _ = [run_parse(args, file) for file in split_files]
            [os.remove(sf) for sf in split_files]
    else:
        with Client() as client:
            _ = [run_parse(args, file) for file in split_files]
            [os.remove(sf) for sf in split_files]
Example #8
class ManagedSLURMCluster(ManagedCluster):
    """
    Args:
        project (str, optional): project name
        queue (str, optional): queue to submit to
        walltime (str, optional): maximum wall time
    """
    def __init__(self,
                 project=None,
                 queue=None,
                 walltime="24:00:00",
                 **kwargs):
        super().__init__(**kwargs)
        self._project = project
        self._queue = queue
        self._walltime = walltime

    def open(self):
        from dask_jobqueue import SLURMCluster

        args = {
            "cores": self.threads_per_worker,
            "processes": 1,
            "memory": self.memory,
            "project": self._project,
            "queue": self._queue,
            "walltime": self._walltime,
            "log_directory": "/tmp",
        }
        self._cluster = SLURMCluster(**args)
        self._cluster.scale(self.n_workers)
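A hypothetical instantiation of this class; the ManagedCluster base class is not part of this excerpt, so its constructor arguments (n_workers, threads_per_worker, memory) are assumed, and the project and queue names are placeholders:

# Hypothetical usage; n_workers, threads_per_worker and memory are assumed
# keyword arguments of the ManagedCluster base class, which is not shown here.
cluster = ManagedSLURMCluster(
    project="my_project",      # placeholder project/account name
    queue="normal",            # placeholder partition
    walltime="24:00:00",
    n_workers=4,
    threads_per_worker=8,
    memory="32GB",
)
cluster.open()  # builds the SLURMCluster and scales it to n_workers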
def train_on_jz_dask(job_name, train_function, *args, **kwargs):
    cluster = SLURMCluster(
        cores=1,
        job_cpu=40,
        memory='80GB',
        job_name=job_name,
        walltime='20:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:4',
            '--qos=qos_gpu-t3',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/understanding-unets',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    cluster.scale(1)

    print(cluster.job_script())

    client = Client(cluster)
    futures = client.submit(
        # function to execute
        train_function,
        *args,
        **kwargs,
        # this function has potential side effects
        pure=True,
    )
    run_id = client.gather(futures)
    print(f'Train run id: {run_id}')
Example #10
def launch_dask_cluster(queue, nodes, localcluster):
    """
    Usage from script:
        from distributed import Client
        from lsst_dashboard.cli import launch_dask_cluster
        cluster, port = launch_dask_cluster('normal', 6, False)
        client = Client(cluster)
    """
    # Launch Dask Cluster
    if "lsst-dev" in host:
        # Set up allowed ports
        (scheduler_port, ) = find_available_ports(1, *DASK_ALLOWED_PORTS)
        (lsst_dashboard_port, ) = find_available_ports(
            1, *DASHBOARD_ALLOWED_PORTS)
        (dask_dashboard_port, ) = find_available_ports(
            1, *DASK_DASHBOARD_ALLOWED_PORTS)
    else:
        localcluster = True
        lsst_dashboard_port = 52001
        dask_dashboard_port = 52002

    if not localcluster:
        from dask_jobqueue import SLURMCluster

        print(
            f"...starting dask cluster using slurm on {host} (queue={queue})")
        procs_per_node = 6
        cluster = SLURMCluster(
            queue=queue,
            cores=24,
            processes=procs_per_node,
            memory="128GB",
            scheduler_port=scheduler_port,
            extra=[
                f'--worker-port {":".join(str(p) for p in DASK_ALLOWED_PORTS)}'
            ],
            dashboard_address=f":{dask_dashboard_port}",
        )

        print(f"...requesting {nodes} nodes")
        cluster.scale(nodes * procs_per_node)
        print(
            "run the command below from your local machine to forward ports for viewing the dashboard and dask diagnostics:"
        )
        print(
            f"\nssh -N -L {lsst_dashboard_port}:{host}:{lsst_dashboard_port} -L {dask_dashboard_port}:{host}:{dask_dashboard_port} {username}@{hostname}\n"
        )
    else:
        from dask.distributed import LocalCluster

        print(f"starting local dask cluster on {host}")
        cluster = LocalCluster(dashboard_address=f":{dask_dashboard_port}")

    print(
        f"### dask dashboard available at http://localhost:{dask_dashboard_port} ###"
    )
    return cluster, lsst_dashboard_port
def eval_parameter_grid(run_ids,
                        job_name,
                        eval_function,
                        parameter_grid,
                        n_gpus=1):
    parameters = list(ParameterGrid(parameter_grid))
    n_parameters_config = len(parameters)
    # eval
    eval_cluster = SLURMCluster(
        cores=1,
        job_cpu=40,
        memory='80GB',
        job_name=job_name,
        walltime='5:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:{n_gpus}',
            '--qos=qos_gpu-t3',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/fastmri-reproducible-benchmark',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    eval_cluster.scale(n_parameters_config)
    client = Client(eval_cluster)
    original_parameters = []
    for params in parameters:
        original_params = {}
        original_params['n_samples'] = params.pop('n_samples', None)
        original_params['loss'] = params.pop('loss', 'mae')
        original_params['fixed_masks'] = params.pop('fixed_masks', False)
        original_parameters.append(original_params)
    futures = [
        client.submit(
            # function to execute
            eval_function,
            run_id=run_id,
            n_samples=50,
            **params,
        ) for run_id, params in zip(run_ids, parameters)
    ]

    for params, original_params, future in zip(parameters, original_parameters,
                                               futures):
        metrics_names, eval_res = client.gather(future)
        params.update(original_params)
        print('Parameters', params)
        print(metrics_names)
        print(eval_res)
    print('Shutting down dask workers')
    client.close()
    eval_cluster.close()
def start(cpus=0, gpus=0, mem_size="10GB"):
    #################
    # Setup dask cluster
    #################

    if cpus > 0:
        #job args
        extra_args = [
            "--error=/orange/idtrees-collab/logs/dask-worker-%j.err",
            "--account=ewhite",
            "--output=/orange/idtrees-collab/logs/dask-worker-%j.out"
        ]

        cluster = SLURMCluster(
            processes=1,
            queue='hpg2-compute',
            cores=1,
            memory=mem_size,
            walltime='24:00:00',
            job_extra=extra_args,
            extra=['--resources cpu=1'],
            scheduler_options={"dashboard_address": ":8781"},
            local_directory="/orange/idtrees-collab/tmp/",
            death_timeout=300)

        print(cluster.job_script())
        cluster.scale(cpus)

    if gpus:
        #job args
        extra_args = [
            "--error=/orange/idtrees-collab/logs/dask-worker-%j.err",
            "--account=ewhite",
            "--output=/orange/idtrees-collab/logs/dask-worker-%j.out",
            "--partition=gpu", "--gpus=1"
        ]

        cluster = SLURMCluster(
            processes=1,
            cores=1,
            memory=mem_size,
            walltime='24:00:00',
            job_extra=extra_args,
            extra=['--resources gpu=1'],
            scheduler_options={"dashboard_address": ":8787"},
            local_directory="/orange/idtrees-collab/tmp/",
            death_timeout=300)

        cluster.scale(gpus)

    dask_client = Client(cluster)

    #Start dask
    dask_client.run_on_scheduler(start_tunnel)

    return dask_client
Example #13
def get_slurm_dask_client(n_workers):
    cluster = SLURMCluster(cores=24,
                           memory='128GB',
                           project="co_aiolos",
                           walltime="24:00:00",
                           queue="savio2_bigmem")

    cluster.scale(n_workers)
    client = Client(cluster)
    return client
Example #14
def get_slurm_dask_client_savio3(n_nodes):
    cluster = SLURMCluster(cores=32,
                           memory='96GB',
                           project="co_aiolos",
                           walltime="72:00:00",
                           queue="savio3",
                           job_extra=['--qos="aiolos_savio3_normal"'])

    cluster.scale(n_nodes*32)
    client = Client(cluster)
    return client
Example #15
def get_slurm_dask_client(n_workers, n_cores):
    cluster = SLURMCluster(cores=n_cores,
                           memory='32GB',
                           project="co_aiolos",
                           walltime="02:00:00",
                           queue="savio2_gpu",
                           job_extra=['--gres=gpu:1','--cpus-per-task=2'])

    cluster.scale(n_workers)
    client = Client(cluster)
    return client
Example #16
def get_slurm_dask_client_bigmem(n_nodes):
    cluster = SLURMCluster(cores=24,
                           memory='128GB',
                           project="co_aiolos",
                           walltime="02:00:00",
                           queue="savio2_bigmem",
                           job_extra=['--qos="savio_lowprio"'])

    cluster.scale(n_nodes*6)
    client = Client(cluster)
    return client
def eval_parameter_grid(job_name,
                        eval_function,
                        parameter_grid,
                        run_ids,
                        n_samples_eval=None):
    parameters = list(ParameterGrid(parameter_grid))
    n_parameters_config = len(parameters)
    assert n_parameters_config == len(
        run_ids), 'Not enough run ids provided for grid evaluation'
    eval_cluster = SLURMCluster(
        cores=1,
        job_cpu=40,
        memory='60GB',
        job_name=job_name,
        walltime='3:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:4',
            '--qos=qos_gpu-t3',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/understanding-unets',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    eval_cluster.scale(n_parameters_config)
    client = Client(eval_cluster)
    n_samples_list = []
    for params in parameters:
        n_samples = params.pop('n_samples', -1)
        n_samples_list.append(n_samples)
    futures = [
        client.submit(
            # function to execute
            eval_function,
            run_id=run_id,
            n_samples=n_samples_eval,
            **params,
        ) for run_id, params in zip(run_ids, parameters)
    ]

    results = []
    for params, future, n_samples in zip(parameters, futures, n_samples_list):
        metrics_names, eval_res = client.gather(future)
        if n_samples != -1:
            params.update({'n_samples': n_samples})
        results.append((params, eval_res))
    print('Shutting down dask workers')
    client.close()
    eval_cluster.close()
    return metrics_names, results
Example #18
def get_slurm_dask_client_bigmem(n_nodes):
    cluster = SLURMCluster(cores=24,
                           memory='128GB',
                           project="co_aiolos",
                           walltime="02:00:00",
                           queue="savio2_bigmem",
                           local_directory='/global/home/users/qindan_zhu/myscratch/qindan_zhu/SatelliteNO2',
                           job_extra=['--qos="savio_lowprio"'])

    cluster.scale(n_nodes*4)
    client = Client(cluster)
    return client
Example #19
def main():
    cluster = SLURMCluster(cores=2, memory="10GB", walltime='00:05:00')
    cluster.scale(5)  # Start 5 workers in 5 jobs that match the description above
    client = Client(cluster)  # Connect to that cluster
    print(client)
    start = datetime.now()
    results = run_test(client=client)
    end = datetime.now()

    print(f"Time taken: {end - start}")
    print(results)
Example #20
def createSLURMCluster():
    cluster = SLURMCluster(queue=single_worker['queue'],
                           project=single_worker['project'],
                           cores=single_worker['cores'],
                           memory=single_worker['memory'],
                           walltime=single_worker['time'],
                           interface='ib0',
                           local_directory=single_worker['temp_folder'])

    cluster.scale(number_of_workers)
    client = Client(cluster)
    print(client)
Example #21
def get_slurm_dask_client_savio3(n_nodes):
    cluster = SLURMCluster(cores=32,
                           memory='96GB',
                           project="co_aiolos",
                           walltime="72:00:00",
                           queue="savio3",
                           local_directory='/global/home/users/qindan_zhu/myscratch/qindan_zhu/SatelliteNO2',
                           job_extra=['--qos="aiolos_savio3_normal"'])

    cluster.scale(n_nodes*8)
    client = Client(cluster)
    return client
Example #22
def slurm_cluster_setup(nodes=1, **kwargs):
    """
    Set up SLURM cluster

    Parameters
    ----------
    nodes: int
        Number of nodes to use
    **kwargs:
        Keyword arguments for cluster specifications
    """
    from dask_jobqueue import SLURMCluster
    cluster = SLURMCluster(**kwargs)
    cluster.scale(nodes)
    return cluster
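A hypothetical call, forwarding typical SLURMCluster keyword arguments through **kwargs (the queue name and resource values are placeholders, not from the original project):

# Hypothetical usage of slurm_cluster_setup; queue, cores, memory and
# walltime are illustrative values only.
from dask.distributed import Client

cluster = slurm_cluster_setup(
    nodes=4,            # scales the cluster to 4 dask workers
    queue="normal",
    cores=16,
    memory="64GB",
    walltime="02:00:00",
)
client = Client(cluster)
print(client)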
Example #23
def test_run_sorters_dask():
    # create a dask Client for a slurm queue
    from dask.distributed import Client
    from dask_jobqueue import SLURMCluster

    python = '/home/samuel.garcia/.virtualenvs/py36/bin/python3.6'
    cluster = SLURMCluster(
        processes=1,
        cores=1,
        memory="12GB",
        python=python,
        walltime='12:00:00',
    )
    cluster.scale(5)
    client = Client(cluster)

    # create recording
    recording_dict = {}
    for i in range(8):
        rec, _ = se.example_datasets.toy_example(num_channels=8,
                                                 duration=30,
                                                 seed=0,
                                                 dumpable=True)
        recording_dict['rec_' + str(i)] = rec

    # sorter_list = ['mountainsort4', 'klusta', 'tridesclous']
    sorter_list = [
        'tridesclous',
    ]
    # ~ sorter_list = ['tridesclous', 'herdingspikes']

    working_folder = 'test_run_sorters_dask'
    if os.path.exists(working_folder):
        shutil.rmtree(working_folder)

    # dask
    t0 = time.perf_counter()
    results = run_sorters(sorter_list,
                          recording_dict,
                          working_folder,
                          engine='dask',
                          engine_kwargs={'client': client},
                          with_output=True)
    # with the dask engine, run_sorters does not return results (always None)
    assert results is None
    t1 = time.perf_counter()
    print(t1 - t0)
def evaluate_pdnet_sense_dask(run_id, contrast, af, n_iter,
                              cuda_visible_devices, n_samples):
    job_name = f'evaluate_pdnet_sense_{af}'
    if contrast is not None:
        job_name += f'_{contrast}'

    cluster = SLURMCluster(
        cores=1,
        job_cpu=40,
        memory='160GB',
        job_name=job_name,
        walltime='20:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:4',
            '--qos=qos_gpu-t3',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/fastmri-reproducible-benchmark',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    cluster.scale(1)

    print(cluster.job_script())

    client = Client(cluster)
    futures = client.submit(
        # function to execute
        evaluate_pdnet_sense,
        # *args
        run_id,
        contrast,
        int(af),
        n_iter,
        n_samples,
        cuda_visible_devices,
        # this function has potential side effects
        pure=True,
    )
    metrics_names, eval_res = client.gather(futures)
    print(metrics_names)
    print(eval_res)
    print('Shutting down dask workers')
def train_eval_parameter_grid(job_name,
                              train_function,
                              eval_function,
                              parameter_grid,
                              n_samples_eval=None):
    parameters = list(ParameterGrid(parameter_grid))
    n_parameters_config = len(parameters)
    train_cluster = SLURMCluster(
        cores=1,
        job_cpu=40,
        memory='60GB',
        job_name=job_name,
        walltime='20:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:4',
            '--qos=qos_gpu-t3',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/understanding-unets',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    train_cluster.scale(n_parameters_config)
    client = Client(train_cluster)
    futures = [
        client.submit(
            # function to execute
            train_function,
            **params,
        ) for params in parameters
    ]
    run_ids = client.gather(futures)
    client.close()
    train_cluster.close()
    # eval
    return eval_parameter_grid(
        job_name,
        eval_function,
        parameter_grid,
        run_ids,
        n_samples_eval=n_samples_eval,
    )
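For context, ParameterGrid (presumably scikit-learn's) expands a dict of lists into one configuration per combination, and the cluster above is scaled to one worker per configuration. An illustrative sketch (the grid values are placeholders):

# Illustrative only: how ParameterGrid turns a grid into per-run configs
from sklearn.model_selection import ParameterGrid

parameter_grid = {'n_samples': [100, 200], 'loss': ['mae', 'mse']}
parameters = list(ParameterGrid(parameter_grid))
print(len(parameters))   # 4 configurations -> 4 dask workers requested
for params in parameters:
    print(params)        # e.g. {'loss': 'mae', 'n_samples': 100}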
Example #26
def create_dask_cluster(use_slurm: bool, n_workers: int,
                        threads_per_worker: int):
    if use_slurm:
        cluster = SLURMCluster(
            workers=0,  # number of initial slurm jobs
            memory="16GB",
            extra=['--nthreads 1 --nprocs=4'],  # arguments to dask-worker CLI
            job_extra=[get_sbatch_args(n_workers)],
        )
        num_jobs = (n_workers - 1) // WORKER_PER_JOBS + 1
        cluster.scale(num_jobs)
    else:
        cluster = LocalCluster(
            n_workers=n_workers,
            threads_per_worker=threads_per_worker,
            processes=True,
        )
    return cluster
Example #27
def start_client(num_workers):
    '''
    initialize dask client
    '''
    cluster = SLURMCluster(
        queue='batch',
        walltime='04-23:00:00',
        cores=1,
        memory='10000MiB',  #1 GiB = 1,024 MiB
        processes=1)
    print('dashboard link: ', cluster.dashboard_link)
    cluster.scale(num_workers)
    client = Client(cluster)
    print(client)
    print('scheduler info: ', client.scheduler_info())
    time.sleep(5)

    return client
def launch_dask_tasks(batch_sizes, save):
    job_name = 'dask_mnist_tf_example'

    cluster = SLURMCluster(
        cores=1,
        job_cpu=10,
        memory='10GB',
        job_name=job_name,
        walltime='1:00:00',
        interface='ib0',
        job_extra=[
            f'--gres=gpu:1',
            '--qos=qos_gpu-dev',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
    )
    n_jobs = len(batch_sizes)
    cluster.scale(jobs=n_jobs)
    print(cluster.job_script())

    client = Client(cluster)
    futures = [
        client.submit(
            # function to execute
            train_dense_model,
            # *args
            None,
            save,
            batch_size,
            # this function has potential side effects
            pure=not save,
        ) for batch_size in batch_sizes
    ]
    job_result = client.gather(futures)
    if all(job_result):
        print('All jobs finished without errors')
    else:
        print('One job errored out')
    print('Shutting down dask workers')
def eval_on_jz_dask(job_name, eval_function, *args, **kwargs):
    cluster = SLURMCluster(
        cores=1,
        job_cpu=40,
        memory='80GB',
        job_name=job_name,
        walltime='20:00:00',
        interface='ib0',
        job_extra=[
            # for now we can't use 4 GPUs because of
            # https://github.com/tensorflow/tensorflow/issues/39268
            f'--gres=gpu:1',
            '--qos=qos_gpu-t3',
            '--distribution=block:block',
            '--hint=nomultithread',
            '--output=%x_%j.out',
        ],
        env_extra=[
            'cd $WORK/fastmri-reproducible-benchmark',
            '. ./submission_scripts_jean_zay/env_config.sh',
        ],
    )
    cluster.scale(1)

    print(cluster.job_script())

    client = Client(cluster)
    futures = client.submit(
        # function to execute
        eval_function,
        *args,
        **kwargs,
        # this function has potential side effects
        pure=True,
    )
    metrics_names, eval_res = client.gather(futures)
    print(metrics_names)
    print(eval_res)
    print('Shutting down dask workers')
def start_slurm_scheduler(account, cores, walltime, memory, processes, interface, local_dir,
                          scheduler_port, dash_port,
                          num_workers, adapt_min, adapt_max):

    # choose either adaptive mode or a fixed number of workers (you can
    # always connect to the cluster and scale it manually without adapt
    # mode), but adapt mode is the default since it is the most
    # no-nonsense DWIM approach
    adapt_mode = True
    if num_workers > -1:
        adapt_mode = False

    local_cluster_kwargs = {'scheduler_port' : scheduler_port,
                            'dashboard_address' : ':{}'.format(dash_port)}

    cluster = SLURMCluster(project=account,
                           cores=cores,
                           walltime=walltime,
                           memory=memory,
                           processes=processes,
                           interface=interface,
                           **local_cluster_kwargs)
    with cluster:

        click.echo("Scheduler address: {}".format(cluster.scheduler_address))
        click.echo("Dashboard port: {}".format(cluster.dashboard_link))

        if adapt_mode:
            cluster.adapt(minimum=adapt_min, maximum=adapt_max)
        else:
            cluster.scale(num_workers)

        # loop forever to block
        while True:
            # sleep so we avoid evaluating the loop too frequently
            time.sleep(2)
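Once this scheduler is up, a separate process can attach to it by address; a minimal, hypothetical client-side sketch (the host and port are placeholders for the values printed above):

# Hypothetical client process; replace the address with the scheduler
# address printed by start_slurm_scheduler.
from dask.distributed import Client

client = Client("tcp://login-node.example.org:8786")
print(client)
futures = client.map(lambda x: x ** 2, range(10))
print(client.gather(futures))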