Code Example #1
import pytest
from time import sleep

from distributed.utils import warn_on_duration


def test_warn_on_duration():
    # A fast block stays well under the 10s threshold, so nothing is recorded.
    with pytest.warns(None) as record:
        with warn_on_duration("10s", "foo"):
            pass
    assert not record

    # Sleeping 100ms overruns the 1ms threshold, so a warning is recorded
    # and it carries the supplied message.
    with pytest.warns(None) as record:
        with warn_on_duration("1ms", "foo"):
            sleep(0.100)

    assert record
    assert any("foo" in str(rec.message) for rec in record)
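All of the examples on this page wrap a potentially slow block in `warn_on_duration` and rely on it to emit a warning only when the block overruns. For reference, such a helper can be written as a thin timer around a `yield`; the sketch below approximates its shape (an approximation, not the exact `distributed.utils` source) and uses `dask.utils.parse_timedelta` to turn strings like "10s" or "1ms" into seconds.

import warnings
from contextlib import contextmanager
from time import time

from dask.utils import parse_timedelta


@contextmanager
def warn_on_duration(duration, msg):
    # Sketch: time the wrapped block and warn with `msg` if it ran longer
    # than `duration` (a string such as "10s" or "1ms").
    start = time()
    yield
    if time() - start > parse_timedelta(duration):
        warnings.warn(msg, stacklevel=2)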
Code Example #2
    async def _start(self):
        while self.status == Status.starting:
            await asyncio.sleep(0.01)
        if self.status == Status.running:
            return
        if self.status == Status.closed:
            raise ValueError("Cluster is closed")

        self.scheduler_spec = {
            "cls": self.scheduler_class,
            "options": self.scheduler_options,
        }
        self.new_spec = {
            "cls": self.worker_class,
            "options": self.worker_options
        }
        self.worker_spec = {i: self.new_spec for i in range(self._n_workers)}

        with warn_on_duration(
                "10s",
                "Creating your cluster is taking a surprisingly long time. "
                "This is likely due to pending resources. "
                "Hang tight! ",
        ):
            await super()._start()
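Both this example and the next populate `scheduler_spec` and `worker_spec` before handing off to `super()._start()`: in distributed's `SpecCluster`, every `worker_spec` entry is a `{"cls": ..., "options": ...}` pair that the base class instantiates against the scheduler's address. The stand-alone sketch below shows the same spec format with illustrative options (the two-worker count and the `nthreads` value are arbitrary choices, not taken from the example above):

from distributed import Scheduler, SpecCluster, Worker

# Illustrative specs: SpecCluster constructs each "cls" with its "options".
scheduler_spec = {"cls": Scheduler, "options": {}}
worker_spec = {i: {"cls": Worker, "options": {"nthreads": 1}} for i in range(2)}

cluster = SpecCluster(scheduler=scheduler_spec, workers=worker_spec)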
Code Example #3
File: ecs.py  Project: yaojiach/dask-cloudprovider
    async def _start(self):
        while self.status == "starting":
            await asyncio.sleep(0.01)
        if self.status == "running":
            return
        if self.status == "closed":
            raise ValueError("Cluster is closed")

        self.config = dask.config.get("cloudprovider.ecs", {})

        # Cleanup any stale resources before we start
        if self._skip_cleanup is None:
            self._skip_cleanup = self.config.get("skip_cleanup")
        if not self._skip_cleanup:
            await _cleanup_stale_resources()

        self._clients = await self._get_clients(
            aws_access_key_id=self._aws_access_key_id,
            aws_secret_access_key=self._aws_secret_access_key,
            region_name=self._region_name,
        )

        if self._fargate_scheduler is None:
            self._fargate_scheduler = self.config.get("fargate_scheduler")
        if self._fargate_workers is None:
            self._fargate_workers = self.config.get("fargate_workers")

        if self._tags is None:
            self._tags = self.config.get("tags")

        if self._environment is None:
            self._environment = self.config.get("environment")

        if self._find_address_timeout is None:
            self._find_address_timeout = self.config.get("find_address_timeout", 60)

        if self._worker_gpu is None:
            self._worker_gpu = self.config.get(
                "worker_gpu"
            )  # TODO Detect whether cluster is GPU capable

        if self.image is None:
            if self._worker_gpu:
                self.image = self.config.get("gpu_image")
            else:
                self.image = self.config.get("image")

        if self._scheduler_cpu is None:
            self._scheduler_cpu = self.config.get("scheduler_cpu")

        if self._scheduler_mem is None:
            self._scheduler_mem = self.config.get("scheduler_mem")

        if self._scheduler_timeout is None:
            self._scheduler_timeout = self.config.get("scheduler_timeout")

        if self._worker_cpu is None:
            self._worker_cpu = self.config.get("worker_cpu")

        if self._worker_mem is None:
            self._worker_mem = self.config.get("worker_mem")

        if self._n_workers is None:
            self._n_workers = self.config.get("n_workers")

        if self._cluster_name_template is None:
            self._cluster_name_template = self.config.get("cluster_name_template")

        if self.cluster_arn is None:
            self.cluster_arn = (
                self.config.get("cluster_arn") or await self._create_cluster()
            )

        if self.cluster_name is None:
            [cluster_info] = (
                await self._clients["ecs"].describe_clusters(
                    clusters=[self.cluster_arn]
                )
            )["clusters"]
            self.cluster_name = cluster_info["clusterName"]

        if self._execution_role_arn is None:
            self._execution_role_arn = (
                self.config.get("execution_role_arn")
                or await self._create_execution_role()
            )

        if self._task_role_policies is None:
            self._task_role_policies = self.config.get("task_role_policies")

        if self._task_role_arn is None:
            self._task_role_arn = (
                self.config.get("task_role_arn") or await self._create_task_role()
            )

        if self._cloudwatch_logs_stream_prefix is None:
            self._cloudwatch_logs_stream_prefix = self.config.get(
                "cloudwatch_logs_stream_prefix"
            ).format(cluster_name=self.cluster_name)

        if self._cloudwatch_logs_default_retention is None:
            self._cloudwatch_logs_default_retention = self.config.get(
                "cloudwatch_logs_default_retention"
            )

        if self.cloudwatch_logs_group is None:
            self.cloudwatch_logs_group = (
                self.config.get("cloudwatch_logs_group")
                or await self._create_cloudwatch_logs_group()
            )

        if self._vpc is None:
            self._vpc = self.config.get("vpc")

        if self._vpc == "default":
            self._vpc = await self._get_default_vpc()

        if self._vpc_subnets is None:
            self._vpc_subnets = (
                self.config.get("subnets") or await self._get_vpc_subnets()
            )

        if self._security_groups is None:
            self._security_groups = (
                self.config.get("security_groups")
                or await self._create_security_groups()
            )

        self.scheduler_task_definition_arn = (
            await self._create_scheduler_task_definition_arn()
        )
        self.worker_task_definition_arn = (
            await self._create_worker_task_definition_arn()
        )

        options = {
            "clients": self._clients,
            "cluster_arn": self.cluster_arn,
            "vpc_subnets": self._vpc_subnets,
            "security_groups": self._security_groups,
            "log_group": self.cloudwatch_logs_group,
            "log_stream_prefix": self._cloudwatch_logs_stream_prefix,
            "environment": self._environment,
            "tags": self.tags,
            "find_address_timeout": self._find_address_timeout,
        }
        scheduler_options = {
            "task_definition_arn": self.scheduler_task_definition_arn,
            "fargate": self._fargate_scheduler,
            **options,
        }
        worker_options = {
            "task_definition_arn": self.worker_task_definition_arn,
            "fargate": self._fargate_workers,
            "cpu": self._worker_cpu,
            "mem": self._worker_mem,
            "gpu": self._worker_gpu,
            **options,
        }

        self.scheduler_spec = {"cls": Scheduler, "options": scheduler_options}
        self.new_spec = {"cls": Worker, "options": worker_options}
        self.worker_spec = {i: self.new_spec for i in range(self._n_workers)}

        with warn_on_duration(
            "10s",
            "Creating your cluster is taking a surprisingly long time. "
            "This is likely due to pending resources on AWS. "
            "Hang tight! ",
        ):
            await super()._start()
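The long run of `if self._x is None:` checks above follows a single pattern: a value passed explicitly to the constructor wins, otherwise it falls back to the `cloudprovider.ecs` section of the dask config. A hypothetical helper (not part of dask-cloudprovider) that captures the same pattern, reading the dask config directly instead of the cached `self.config` dict:

import dask


def from_config(explicit, key, default=None):
    # Hypothetical helper: prefer an explicitly passed value, otherwise
    # fall back to the "cloudprovider.ecs" section of the dask config.
    if explicit is not None:
        return explicit
    return dask.config.get("cloudprovider.ecs." + key, default)


# e.g. n_workers = from_config(self._n_workers, "n_workers")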
Code Example #4
def main(
    scheduler,
    host,
    nthreads,
    name,
    memory_limit,
    device_memory_limit,
    pid_file,
    reconnect,
    resources,
    dashboard,
    dashboard_address,
    local_directory,
    scheduler_file,
    interface,
    death_timeout,
    preload,
    preload_argv,
    bokeh_prefix,
    tls_ca_file,
    tls_cert,
    tls_key,
):
    enable_proctitle_on_current()
    enable_proctitle_on_children()

    sec = Security(tls_ca_file=tls_ca_file,
                   tls_worker_cert=tls_cert,
                   tls_worker_key=tls_key)

    try:
        nprocs = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
    except KeyError:
        nprocs = get_n_gpus()

    if not nthreads:
        nthreads = min(1, multiprocessing.cpu_count() // nprocs)

    if pid_file:
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        def del_pid_file():
            if os.path.exists(pid_file):
                os.remove(pid_file)

        atexit.register(del_pid_file)

    services = {}

    if dashboard:
        try:
            from distributed.dashboard import BokehWorker
        except ImportError:
            pass
        else:
            if bokeh_prefix:
                result = (BokehWorker, {"prefix": bokeh_prefix})
            else:
                result = BokehWorker
            services[("dashboard", dashboard_address)] = result

    if resources:
        resources = resources.replace(",", " ").split()
        resources = dict(pair.split("=") for pair in resources)
        resources = valmap(float, resources)
    else:
        resources = None

    loop = IOLoop.current()

    kwargs = {"worker_port": None, "listen_address": None}
    t = Nanny

    if not scheduler and not scheduler_file and "scheduler-address" not in config:
        raise ValueError("Need to provide scheduler address like\n"
                         "dask-worker SCHEDULER_ADDRESS:8786")

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host")
        else:
            host = get_ip_interface(interface)

    if host:
        addr = uri_from_host_port(host, 0, 0)
    else:
        # Choose appropriate address for scheduler
        addr = None

    if death_timeout is not None:
        death_timeout = parse_timedelta(death_timeout, "s")

    local_dir = kwargs.get("local_dir", "dask-worker-space")
    with warn_on_duration(
            "1s",
            "Creating scratch directories is taking a surprisingly long time. "
            "This is often due to running workers on a network file system. "
            "Consider specifying a local-directory to point workers to write "
            "scratch data to a local disk.",
    ):
        _workspace = WorkSpace(os.path.abspath(local_dir))
        _workdir = _workspace.new_work_dir(prefix="worker-")
        local_dir = _workdir.dir_path

    nannies = [
        t(
            scheduler,
            scheduler_file=scheduler_file,
            nthreads=nthreads,
            services=services,
            loop=loop,
            resources=resources,
            memory_limit=memory_limit,
            reconnect=reconnect,
            local_dir=local_directory,
            death_timeout=death_timeout,
            preload=(preload or []) + ["dask_cuda.initialize_context"],
            preload_argv=preload_argv,
            security=sec,
            contact_address=None,
            env={"CUDA_VISIBLE_DEVICES": cuda_visible_devices(i)},
            name=name if nprocs == 1 or not name else name + "-" + str(i),
            data=(
                DeviceHostFile,
                {
                    "device_memory_limit": get_device_total_memory(index=i)
                    if (device_memory_limit == "auto" or device_memory_limit == 0)
                    else parse_bytes(device_memory_limit),
                    "memory_limit": parse_memory_limit(
                        memory_limit, nthreads, total_cores=nprocs
                    ),
                    "local_dir": local_dir,
                },
            ),
            **kwargs,
        ) for i in range(nprocs)
    ]

    @gen.coroutine
    def close_all():
        # Unregister all workers from scheduler
        yield [n._close(timeout=2) for n in nannies]

    def on_signal(signum):
        logger.info("Exiting on signal %d", signum)
        close_all()

    @gen.coroutine
    def run():
        yield [n._start(addr) for n in nannies]
        while all(n.status != "closed" for n in nannies):
            yield gen.sleep(0.2)

    install_signal_handlers(loop, cleanup=on_signal)

    try:
        loop.run_sync(run)
    except (KeyboardInterrupt, TimeoutError):
        pass
    finally:
        logger.info("End worker")