Example #1
    def task():
        """Creates the .ssh dir with proper permissions. Adds the public key
            to the authorized keys file if it's not been added already."""
        with stage_debug(log, "Creating authorized keys file: %s",
                         authorized_keys_path):
            with capture_fabric_output_to_log():
                run("mkdir -p ~/.ssh"
                    " && chmod 700 ~/.ssh"
                    " && touch '{authorized_keys_path}'"
                    " && chmod 644 '{authorized_keys_path}'".format(
                        authorized_keys_path=authorized_keys_path))

        with stage_debug(log, "Downloading authorized keys file."):
            with capture_fabric_output_to_log():
                authorized_keys_fd = BytesIO()
                get(authorized_keys_path, authorized_keys_fd)
                authorized_keys_contents = \
                    authorized_keys_fd.getvalue().decode('ascii').splitlines()

        if public_key not in authorized_keys_contents:
            with stage_debug(log, "Appending to authorized keys file."):
                log.debug("Warning: This operation is not atomic on NFS.")
                with capture_fabric_output_to_log():
                    run("echo '{public_key}' >> {authorized_keys_path}".format(
                        public_key=public_key,
                        authorized_keys_path=authorized_keys_path))

        with stage_debug(log, "Checking if public key was added."):
            with capture_fabric_output_to_log():
                run("grep '{public_key}' '{authorized_keys_path}'".format(
                    public_key=public_key,
                    authorized_keys_path=authorized_keys_path))
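The commands above splice `public_key` and `authorized_keys_path` into shell strings with `str.format`. Below is a minimal sketch of the same idempotent-append sequence with the interpolated values shell-quoted; the helper is not part of the library, only an illustration of command strings that could be passed to Fabric's `run`:

import shlex


def build_authorized_keys_commands(public_key: str,
                                   authorized_keys_path: str) -> list:
    """Builds the same command sequence as the task above,
        with shell quoting applied to the interpolated values."""
    path = shlex.quote(authorized_keys_path)
    key = shlex.quote(public_key)
    return [
        # Create the .ssh dir and the authorized keys file
        # with proper permissions.
        "mkdir -p ~/.ssh"
        " && chmod 700 ~/.ssh"
        " && touch {path}"
        " && chmod 644 {path}".format(path=path),
        # Append the key only if it is not there yet (not atomic on NFS).
        "grep -qF {key} {path} || echo {key} >> {path}".format(
            key=key, path=path),
        # Check that the key was added.
        "grep -F {key} {path}".format(key=key, path=path)]


if __name__ == '__main__':
    for command in build_authorized_keys_commands(
            public_key="ssh-rsa AAAA... user@host",
            authorized_keys_path=".ssh/authorized_keys"):
        print(command)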
Example #2
    def task():
        """Creates the entry point dir and file.
            Fails if it couldn't be created."""
        with capture_fabric_output_to_log():
            run("mkdir -p {entry_point_location}"
                " && chmod 700 {entry_point_location}".format(
                    entry_point_location=entry_point_location))

            file_name = get_random_file_name(
                length=ENTRY_POINT_FILE_NAME_LENGTH)
            file_path = run("echo {entry_point_location}/{file_name}".format(
                entry_point_location=entry_point_location,
                file_name=file_name))
            file_exists = exists(file_path)

        if file_exists:
            log.warning("Overwriting randomly named entry point file:"
                        " %s", file_path)

        with stage_debug(log, "Uploading the entry point script."):
            with capture_fabric_output_to_log():
                real_path = run("echo {file_path}".format(file_path=file_path))
                file = BytesIO(contents.encode('ascii'))
                put(file, real_path, mode=0o700)

        with stage_debug(log, "Checking the entry point script was uploaded."):
            with capture_fabric_output_to_log():
                run("cat {real_path} > /dev/null".format(real_path=real_path))
        result.append(real_path)
Example #3
def deploy_generic(node: NodeInternal,
                   script_contents: str,
                   runtime_dir: str) -> GenericDeployment:
    """Deploys a program on the node.

        :param node: Node to deploy the program on.

        :param script_contents: Deployment script contents.

        :param runtime_dir: Runtime dir to remove when the deployment
                            is cancelled.

    """
    log = get_logger(__name__)
    with stage_debug(log, "Uploading entry point."):
        script_path = upload_entry_point(contents=script_contents,
                                         node=node,
                                         runtime_dir=runtime_dir)

    with stage_debug(log, "Executing the deployment command."):
        output = node.run(get_deployment_command(
            script_path=script_path))

    lines = output.splitlines()
    pid = int(lines[0])
    return GenericDeployment(node=node,
                             pid=pid,
                             runtime_dir=runtime_dir)
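The `int(lines[0])` parsing above implies that the deployment command prints the PID of the detached process as its first output line. A hypothetical sketch of such a command; the library's actual `get_deployment_command` may differ:

def get_deployment_command_sketch(script_path: str) -> str:
    """Runs the script detached from the session and prints its PID
        as the first line of output (illustration only)."""
    return ("nohup bash '{script_path}' >/dev/null 2>&1 &"
            " echo $!".format(script_path=script_path))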
Example #4
def determine_ports_for_nodes(allocation_id: int,
                              hostnames: List[str],
                              config: ClusterConfig,
                              raise_on_missing: bool) -> List[int]:
    """Tries to determine sshd ports for each node.
        Removes the file if no exception was raised.

        :param allocation_id: Job id.

        :param hostnames: List of hostnames.

        :param config: Cluster config.

        :param raise_on_missing: Raise an exception if port could not
                                 be determined.

    """
    log = get_logger(__name__)
    with stage_debug(log, "Fetching port info for sshd."):
        port_info_contents = fetch_port_info(allocation_id=allocation_id,
                                             config=config)
        port_info = SshdPortInfo(contents=port_info_contents)

    with stage_debug(log, "Determining ports for each host."):
        ports = [port_info.get_port(host=host,
                                    raise_on_missing=raise_on_missing)
                 for host in hostnames]

    with stage_debug(log, "Removing the file containing sshd port info."):
        remove_port_info(allocation_id, config=config)

    return ports
Example #5
    def cancel(self):
        """Kills the program and all its child processes.
            Removes the runtime dir.
            Raises an exception if the top level process is still running
            after :attr:`.CANCEL_TIMEOUT` seconds.

            :raises RuntimeError: If the program is still running.

        """

        parent_pid = self._pid
        node = self._node
        tree = ' '.join([str(pid) for pid in ptree(pid=parent_pid, node=node)])

        def cancel_task():
            """Kills the process tree and fails if the parent is still
                running after a timeout."""
            node.run("kill {tree}"
                     "; kill -0 {parent_pid} && exit 1 || exit 0".format(
                         tree=tree, parent_pid=parent_pid))

        log = get_logger(__name__)
        with remove_runtime_dir_on_exit(node=self._node,
                                        runtime_dir=self._runtime_dir):
            with stage_debug(log, "Killing the process tree for pid: %d",
                             self._pid):
                retry_with_config(fun=cancel_task,
                                  name=Retry.CANCEL_DEPLOYMENT,
                                  config=self._node.config)
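`ptree` gathers the parent process and all of its descendants before the kill, and `kill -0` afterwards only checks whether the parent still exists, so the task exits with 1 exactly when the parent survived. A rough local sketch of how such a process tree can be collected from `ps` output; the library's `ptree` runs against the remote node and may be implemented differently:

import subprocess
from typing import Dict, List


def ptree_sketch(pid: int) -> List[int]:
    """Returns the pid and the pids of all its descendants,
        based on local `ps` output (illustration only)."""
    output = subprocess.check_output(
        ["ps", "-e", "-o", "pid=,ppid="], text=True)
    children = {}  # type: Dict[int, List[int]]
    for line in output.splitlines():
        child, parent = (int(field) for field in line.split())
        children.setdefault(parent, []).append(child)

    tree = []
    stack = [pid]
    while stack:
        current = stack.pop()
        tree.append(current)
        stack.extend(children.get(current, []))
    return tree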
Example #6
def allocate_slurm_nodes(parameters: AllocationParameters,
                         config: ClusterConfig) -> Nodes:
    """Tries to allocate nodes using Slurm.

        :param parameters: Allocation parameters.

        :param config: Config for the cluster to allocate nodes on.

    """
    args = SbatchArguments(params=parameters)

    log = get_logger(__name__)
    with stage_debug(log, "Executing sbatch on access node."):
        access_node = get_access_node(config=config)
        job_id, entry_point_script_path = run_sbatch(args=args,
                                                     node=access_node)

    def run_squeue_task() -> SqueueResult:
        job_squeue = run_squeue(node=access_node)
        return job_squeue[job_id]

    try:
        with stage_debug(log, "Obtaining info about job %d using squeue.",
                         job_id):
            job = retry_with_config(run_squeue_task,
                                    name=Retry.SQUEUE_AFTER_SBATCH,
                                    config=config)
    except Exception as e:  # noqa, pylint: disable=broad-except
        run_scancel(job_id=job_id, node=access_node)
        raise RuntimeError("Unable to obtain job info"
                           " after allocation.") from e

    node_count = job.node_count
    nodes = [NodeImpl(config=config) for _ in range(node_count)]

    allocation = SlurmAllocation(
        job_id=job_id,
        access_node=access_node,
        nodes=nodes,
        entry_point_script_path=entry_point_script_path,
        parameters=parameters)

    return NodesImpl(nodes=nodes,
                     allocation=allocation)
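`run_sbatch`, `run_squeue` and `run_scancel` wrap the standard Slurm command line tools, executed on the access node over SSH. A rough local sketch of the underlying calls, with assumed arguments; the library builds the real commands from `SbatchArguments` and parses richer `squeue` output:

import subprocess


def slurm_lifecycle_sketch() -> None:
    """Submits, inspects and cancels a throwaway Slurm job (illustration)."""
    # `--parsable` makes sbatch print "<job_id>" (or "<job_id>;<cluster>").
    submitted = subprocess.check_output(
        ["sbatch", "--parsable", "--wrap", "sleep 600"], text=True)
    job_id = submitted.strip().split(";")[0]

    # Ask squeue for the state of that single job.
    state = subprocess.check_output(
        ["squeue", "-j", job_id, "-o", "%T", "--noheader"],
        text=True).strip()
    print("Job {} state: {}".format(job_id, state))

    # Give the allocation back.
    subprocess.run(["scancel", job_id], check=True)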
Example #7
def create_log_file(node: Node, runtime_dir: str) -> str:
    """Creates a log file in the runtime dir.

        :param node: Node to create the log file on.

        :param runtime_dir: Runtime dir path.

    """
    log = get_logger(__name__)
    log_file = '{runtime_dir}/log'.format(runtime_dir=runtime_dir)
    with stage_debug(log, "Creating log file: '%s'.", log_file):
        node.run("touch '{}'".format(log_file))
        return log_file
Example #8
def discard_expired_deployments(
        deployments: DeploymentDefinitions) -> DeploymentDefinitions:  # noqa
    """Returns a new object that does not contain deployments
        that have expired, or will expire in the near future.

        :param deployments: Deployments to examine.

    """
    log = get_logger(__name__)

    with stage_debug(log, "Discarding expired deployments."):
        now = utc_now()
        log.debug("Now: %s", now)
        log.debug(
            "Will discard after the %d second mark"
            " before the expiration date.", DISCARD_DELTA_SECONDS)

        discard_now = utc_now() + timedelta(seconds=DISCARD_DELTA_SECONDS)

        unexpired_nodes = {}
        for uuid, node in deployments.nodes.items():
            if node.expiration_date < discard_now:
                log.warning(
                    "Discarding a synchronized allocation deployment,"
                    " because it has expired: %s", uuid)
            else:
                unexpired_nodes[uuid] = node

        unexpired_jupyter_deployments = {}
        for uuid, jupyter in deployments.jupyter_deployments.items():
            if jupyter.expiration_date < discard_now:
                log.warning(
                    "Discarding a Jupyter deployment,"
                    " because it has expired: %s", uuid)
            else:
                unexpired_jupyter_deployments[uuid] = jupyter

        unexpired_dask_deployments = {}
        for uuid, dask in deployments.dask_deployments.items():
            if dask.expiration_date < discard_now:
                log.warning(
                    "Discarding a Dask deployment,"
                    " because it has expired: %s", uuid)
            else:
                unexpired_dask_deployments[uuid] = dask

        return DeploymentDefinitions(
            nodes=unexpired_nodes,
            jupyter_deployments=unexpired_jupyter_deployments,
            dask_deployments=unexpired_dask_deployments)
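A short worked example of the cutoff above: a deployment is discarded not only when it has already expired, but also when its expiration date falls within `DISCARD_DELTA_SECONDS` of now. The value 120 below is only an assumption for illustration:

from datetime import datetime, timedelta, timezone

DISCARD_DELTA_SECONDS = 120  # assumed value, for illustration only

now = datetime(2018, 11, 12, 12, 0, 0, tzinfo=timezone.utc)
discard_now = now + timedelta(seconds=DISCARD_DELTA_SECONDS)

expires_in_one_minute = datetime(2018, 11, 12, 12, 1, 0, tzinfo=timezone.utc)
expires_in_one_hour = datetime(2018, 11, 12, 13, 0, 0, tzinfo=timezone.utc)

assert expires_in_one_minute < discard_now     # discarded
assert not expires_in_one_hour < discard_now   # kept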
Example #9
def deserialize_deployment_definitions_from_cluster(
        node: NodeInternal) -> DeploymentDefinitions:  # noqa
    """Downloads deployment definitions from the cluster.

        :param node: Node to deserialize deployment definitions from.

    """
    log = get_logger(__name__)
    with stage_debug(log, "Deserializing deployment definitions"
                     " from cluster."):
        path = get_deployment_definitions_file_path(node=node)
        file_contents = get_file_from_node(node=node, remote_path=path)

        serialized = json.loads(file_contents)
        return DeploymentDefinitions.deserialize(serialized=serialized)
Example #10
def get_file_from_node(node: NodeInternal, remote_path: str) -> str:
    """Runs a task on the node that downloads a file and returns its contents.

        :param node: Node to download the file from.

        :param remote_path: Remote file path.

    """
    log = get_logger(__name__)

    @fabric.decorators.task
    def file_download_task():
        with capture_fabric_output_to_log():
            return get_remote_file(remote_path=remote_path)

    with stage_debug(log, "Getting file from node %s: %s", node.host,
                     remote_path):
        return node.run_task(task=file_download_task)
Example #11
def serialize_deployment_definitions_to_cluster(
        node: NodeInternal, deployments: DeploymentDefinitions):  # noqa
    """Uploads deployment definitions to the cluster, replacing
        any definitions file already there.

        :param node: Node to serialize definitions to.

        :param deployments: Deployments to upload.

    """
    log = get_logger(__name__)
    with stage_debug(log, "Serializing deployment definitions to cluster."):
        serialized = deployments.serialize()
        file_contents = json.dumps(serialized, sort_keys=True, indent=4)
        parent_path = get_deployment_definitions_parent_path(node=node)
        node.run("mkdir -p {parent_path}"
                 " && chmod 700 {parent_path}".format(parent_path=parent_path))
        path = get_deployment_definitions_file_path(node=node)
        put_file_on_node(node=node, remote_path=path, contents=file_contents)
Example #12
def put_file_on_node(node: NodeInternal, remote_path: str, contents: str):
    """Runs a task on the node that uploads a file.

        :param node: Node to upload the file to.

        :param remote_path: Remote file path.

        :param contents: File contents.

    """
    log = get_logger(__name__)
    with stage_debug(log, "Putting file on node %s: %s", node.host,
                     remote_path):

        @fabric.decorators.task
        def file_upload_task():
            with capture_fabric_output_to_log():
                put_remote_file(remote_path=remote_path, contents=contents)

        node.run_task(task=file_upload_task)
Example #13
def deserialize_environment_from_cluster(
        cluster: Cluster,
        path: Optional[str] = None) -> Environment:  # noqa
    """Loads the environment from remote file.

        See :func:`.pull_environment`.

        :param cluster: Cluster to pull the environment from.

        :param path: Remote file path.
                     Default: Remote IDACT_CONFIG_PATH environment variable,
                     or ~/.idact.conf

    """
    log = get_logger(__name__)
    with stage_debug(log, "Deserializing the environment from cluster."):
        node = cluster.get_access_node()
        assert isinstance(node, NodeInternal)
        path = get_remote_environment_path(node=node, path=path)

        file_contents = get_file_from_node(node=node, remote_path=path)
        return deserialize_environment(text=file_contents)
Example #14
def serialize_environment_to_cluster(environment: Environment,
                                     cluster: Cluster, path: Optional[str]):
    """Dumps the environment to remote file.

        See :func:`.push_environment`.

        :param environment: Environment to save.

        :param cluster: Cluster to push the environment to.

        :param path: Remote file path.
                     Default: IDACT_CONFIG_PATH environment variable,
                     or ~/.idact.conf

    """
    log = get_logger(__name__)
    with stage_debug(log, "Serializing the environment to cluster."):
        node = cluster.get_access_node()
        assert isinstance(node, NodeInternal)
        path = get_remote_environment_path(node=node, path=path)

        file_contents = serialize_environment(environment)
        put_file_on_node(node=node, remote_path=path, contents=file_contents)
Example #15
    def tunnel(self, there: int, here: Optional[int] = None) -> TunnelInternal:
        """Opens an SSH tunnel from a local port to a remote port
            on this node.

            :param there: Remote port to tunnel to.

            :param here: Local binding port. Default: any free port.

        """
        try:
            log = get_logger(__name__)
            with stage_debug(log, "Opening tunnel %s -> %d to %s", here, there,
                             self):
                self._ensure_allocated()

                here, there = validate_tunnel_ports(here=here, there=there)

                first_try = [True]

                def get_bindings_and_build_tunnel() -> TunnelInternal:
                    bindings = get_bindings_with_single_gateway(
                        here=here if first_try[0] else ANY_TUNNEL_PORT,
                        node_host=self._host,
                        node_port=self._port,
                        there=there)
                    first_try[0] = False
                    return build_tunnel(config=self._config,
                                        bindings=bindings,
                                        ssh_password=env.password,
                                        ssh_pkey=env.key_filename)

                with authenticate(host=self._host,
                                  port=self._port,
                                  config=self._config):
                    if here == ANY_TUNNEL_PORT:
                        return get_bindings_and_build_tunnel()
                    return retry_with_config(
                        get_bindings_and_build_tunnel,
                        name=Retry.TUNNEL_TRY_AGAIN_WITH_ANY_PORT,
                        config=self._config)

        except RuntimeError as e:
            raise RuntimeError(
                "Unable to tunnel {there} on node '{host}'.".format(
                    there=there, host=self._host)) from e
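The `first_try = [True]` list is a mutable-closure idiom: on a retry, `get_bindings_and_build_tunnel` downgrades the requested local port to `ANY_TUNNEL_PORT`. The same fallback written with `nonlocal`, as a standalone sketch (names and the sentinel value are assumptions):

ANY_TUNNEL_PORT = 0  # assumed sentinel meaning "any free port"


def make_port_chooser(here: int):
    """Returns a callable that yields `here` once, then falls back
        to ANY_TUNNEL_PORT on every retry (illustration only)."""
    first_try = True

    def choose_port() -> int:
        nonlocal first_try
        chosen = here if first_try else ANY_TUNNEL_PORT
        first_try = False
        return chosen

    return choose_port


choose = make_port_chooser(here=8080)
assert choose() == 8080             # first attempt uses the requested port
assert choose() == ANY_TUNNEL_PORT  # retries accept any free port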
Example #16
def install_key(config: ClusterConfig, authorized_keys: Optional[str] = None):
    """Installs the public key on the access node.

        If the key has not been generated yet, or is missing, a new one
        is generated.
        Expects password authentication to have already been performed.

        :param config: Cluster config for connection.

        :param authorized_keys: Path to authorized_keys.
                                Default: `~/.ssh/authorized_keys`
    """
    log = get_logger(__name__)

    authorized_keys_path = (authorized_keys
                            if authorized_keys else ".ssh/authorized_keys")

    with stage_debug(log, "Attempting to determine key path."):
        public_key_path = try_getting_public_key_from_config(config=config,
                                                             log=log)

    if public_key_path is None:
        with stage_debug(log, "Generating key."):
            config.key = generate_key(host=config.host)
            public_key_path = get_public_key_location(
                private_key_location=config.key)

    with stage_debug(log, "Reading key."):
        public_key = read_public_key(public_key_path=public_key_path)

    @fabric.decorators.task
    def task():
        """Creates the .ssh dir with proper permissions. Adds the public key
            to the authorized keys file if it's not been added already."""
        with stage_debug(log, "Creating authorized keys file: %s",
                         authorized_keys_path):
            with capture_fabric_output_to_log():
                run("mkdir -p ~/.ssh"
                    " && chmod 700 ~/.ssh"
                    " && touch '{authorized_keys_path}'"
                    " && chmod 644 '{authorized_keys_path}'".format(
                        authorized_keys_path=authorized_keys_path))

        with stage_debug(log, "Downloading authorized keys file."):
            with capture_fabric_output_to_log():
                authorized_keys_fd = BytesIO()
                get(authorized_keys_path, authorized_keys_fd)
                authorized_keys_contents = \
                    authorized_keys_fd.getvalue().decode('ascii').splitlines()

        if public_key not in authorized_keys_contents:
            with stage_debug(log, "Appending to authorized keys file."):
                log.debug("Warning: This operation is not atomic on NFS.")
                with capture_fabric_output_to_log():
                    run("echo '{public_key}' >> {authorized_keys_path}".format(
                        public_key=public_key,
                        authorized_keys_path=authorized_keys_path))

        with stage_debug(log, "Checking if public key was added."):
            with capture_fabric_output_to_log():
                run("grep '{public_key}' '{authorized_keys_path}'".format(
                    public_key=public_key,
                    authorized_keys_path=authorized_keys_path))

    with raise_on_remote_fail(exception=RuntimeError):
        fabric.tasks.execute(task)
Example #17
def deploy_dask_scheduler(node: NodeInternal) -> DaskSchedulerDeployment:
    """Deploys a Dask scheduler on the node.

        :param node: Node to deploy on.

    """
    log = get_logger(__name__)

    with ExitStack() as stack:
        with stage_debug(log, "Creating a runtime dir."):
            runtime_dir = create_runtime_dir(node=node)
            stack.enter_context(
                remove_runtime_dir_on_failure(node=node,
                                              runtime_dir=runtime_dir))

        with stage_debug(log, "Obtaining free remote ports."):
            remote_port, bokeh_port = get_free_remote_ports(count=2, node=node)

        with stage_debug(log, "Creating a scratch subdirectory."):
            scratch_subdir = create_scratch_subdir(node=node)

        log_file = create_log_file(node=node, runtime_dir=runtime_dir)

        script_contents = get_scheduler_deployment_script(
            remote_port=remote_port,
            bokeh_port=bokeh_port,
            scratch_subdir=scratch_subdir,
            log_file=log_file,
            config=node.config)

        log.debug("Deployment script contents: %s", script_contents)

        with stage_debug(log, "Deploying script."):
            deployment = deploy_generic(node=node,
                                        script_contents=script_contents,
                                        runtime_dir=runtime_dir)
            stack.enter_context(cancel_on_failure(deployment))

        @fabric.decorators.task
        def extract_address_from_log() -> str:
            """Extracts scheduler address from a log file."""
            with capture_fabric_output_to_log():
                output = get_remote_file(remote_path=log_file)
            log.debug("Log file: %s", output)
            return extract_address_from_output(output=output)

        with stage_debug(log, "Obtaining scheduler address."):
            address = retry_with_config(
                lambda: node.run_task(task=extract_address_from_log),
                name=Retry.GET_SCHEDULER_ADDRESS,
                config=node.config)

        with stage_debug(log, "Opening a tunnel to scheduler."):
            tunnel = node.tunnel(here=remote_port, there=remote_port)
            stack.enter_context(close_tunnel_on_failure(tunnel))
            log.debug("Scheduler local port: %d", tunnel.here)

        with stage_debug(log, "Opening a tunnel to bokeh diagnostics server."):
            bokeh_tunnel = node.tunnel(here=bokeh_port, there=bokeh_port)
            stack.enter_context(close_tunnel_on_failure(bokeh_tunnel))
            log.debug("Diagnostics local port: %d", bokeh_tunnel.here)

        return DaskSchedulerDeployment(deployment=deployment,
                                       tunnel=tunnel,
                                       bokeh_tunnel=bokeh_tunnel,
                                       address=address)
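The `remove_runtime_dir_on_failure`, `cancel_on_failure` and `close_tunnel_on_failure` helpers entered into the `ExitStack` only act when an exception escapes the block, so a successful return keeps every resource alive. A minimal sketch of that pattern with `contextlib`; `cleanup_on_failure` and `deploy_sketch` below are hypothetical, not the library's helpers:

from contextlib import ExitStack, contextmanager


@contextmanager
def cleanup_on_failure(cleanup):
    """Runs `cleanup` only if an exception escapes the block."""
    try:
        yield
    except BaseException:
        cleanup()
        raise


def deploy_sketch(create_resource, deploy_rest):
    """Acquires a resource, registers cleanup-on-failure, then continues."""
    with ExitStack() as stack:
        resource = create_resource()
        # If deploy_rest raises, resource.close() runs before re-raising;
        # on success the resource stays open for the caller.
        stack.enter_context(cleanup_on_failure(resource.close))
        return deploy_rest(resource)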
Example #18
def deploy_jupyter(node: NodeInternal, local_port: int) -> JupyterDeployment:
    """Deploys a Jupyter Notebook server on the node, and creates a tunnel
        to a local port.

        :param node: Node to deploy Jupyter Notebook on.

        :param local_port: Local tunnel binding port.

    """
    log = get_logger(__name__)

    with stage_debug(log, "Creating a runtime dir."):
        runtime_dir = create_runtime_dir(node=node)

    with stage_debug(log, "Obtaining a free remote port."):
        remote_port = get_free_remote_port(node=node)

    if node.config.use_jupyter_lab:
        jupyter_version = 'lab'
    else:
        jupyter_version = 'notebook'

    deployment_commands = [
        'export JUPYTER_RUNTIME_DIR="{runtime_dir}"'.format(
            runtime_dir=runtime_dir),
        get_command_to_append_local_bin()
    ]

    log_file = create_log_file(node=node, runtime_dir=runtime_dir)

    deployment_commands.append('jupyter {jupyter_version}'
                               ' --ip 127.0.0.1'
                               ' --port "{remote_port}"'
                               ' --no-browser > {log_file} 2>&1'.format(
                                   jupyter_version=jupyter_version,
                                   remote_port=remote_port,
                                   log_file=log_file))

    script_contents = get_deployment_script_contents(
        deployment_commands=deployment_commands,
        setup_actions=node.config.setup_actions.jupyter)

    log.debug("Deployment script contents: %s", script_contents)
    with stage_debug(log, "Deploying script."):
        deployment = deploy_generic(node=node,
                                    script_contents=script_contents,
                                    runtime_dir=runtime_dir)

    with cancel_on_failure(deployment):

        @fabric.decorators.task
        def load_nbserver_json():
            """Loads notebook parameters from a json file."""
            with capture_fabric_output_to_log():
                with cd(runtime_dir):
                    nbserver_json_path = run(
                        "readlink -vf $PWD/nbserver-*.json").splitlines()[0]
                run("cat '{log_file}' || exit 0".format(log_file=log_file))
                run("cat '{nbserver_json_path}' > /dev/null".format(
                    nbserver_json_path=nbserver_json_path))
                nbserver_json_str = get_remote_file(nbserver_json_path)
                nbserver_json = json.loads(nbserver_json_str)
                return int(nbserver_json['port']), nbserver_json['token']

        with stage_debug(log, "Obtaining info about notebook from json file."):
            actual_port, token = retry_with_config(
                lambda: node.run_task(task=load_nbserver_json),
                name=Retry.JUPYTER_JSON,
                config=node.config)

        with stage_debug(log, "Opening a tunnel to notebook."):
            tunnel = node.tunnel(there=actual_port, here=local_port)

        return JupyterDeploymentImpl(deployment=deployment,
                                     tunnel=tunnel,
                                     token=token)
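`load_nbserver_json` reads the `nbserver-*.json` file that Jupyter writes into `JUPYTER_RUNTIME_DIR`; the deployment only needs the `port` and `token` fields. Roughly what such a file contains and how the two fields are extracted (exact contents vary between Jupyter versions, and the values below are made up):

import json

nbserver_json_str = json.dumps({
    "port": 46000,
    "token": "3b8f0d...redacted",
    "url": "http://127.0.0.1:46000/",
    "notebook_dir": "/home/user",
})

nbserver_json = json.loads(nbserver_json_str)
actual_port, token = int(nbserver_json['port']), nbserver_json['token']
assert actual_port == 46000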
Example #19
def discard_non_functional_deployments(
        deployments: SynchronizedDeployments
) -> SynchronizedDeployments:  # noqa
    """Discards deployments that have not expired but are no longer
        functional, e.g. because they were cancelled."""

    log = get_logger(__name__)
    all_nodes = []
    for nodes in deployments.nodes:
        with stage_debug(
                log, "Checking whether allocation deployment"
                " is functional: %s.", nodes):
            nodes_functional = not nodes.waited or nodes.running()

        if nodes_functional:
            all_nodes.append(nodes)
        else:
            log.info(
                "Discarding an allocation deployment,"
                " because it is no longer functional: %s.", nodes)

    all_jupyter_deployments = []
    for jupyter in deployments.jupyter_deployments:
        jupyter_impl = jupyter
        assert isinstance(jupyter_impl, JupyterDeploymentImpl)
        with stage_debug(
                log, "Checking whether Jupyter deployment"
                " is functional: %s.", jupyter_impl):
            try:
                validate_tunnel_http_connection(tunnel=jupyter_impl.tunnel)
                all_jupyter_deployments.append(jupyter_impl)
            except Exception:  # pylint: disable=broad-except
                log.info(
                    "Discarding a Jupyter deployment,"
                    " because it is no longer functional: %s.", jupyter_impl)
                log.debug("Exception", exc_info=1)
                with stage_debug(log,
                                 "Cancelling tunnel to discarded notebook."):
                    jupyter_impl.cancel_local()

    all_dask_deployments = []
    for dask in deployments.dask_deployments:
        dask_impl = dask
        assert isinstance(dask_impl, DaskDeploymentImpl)
        with stage_debug(
                log, "Checking whether Dask deployment"
                " is functional: %s.", dask_impl):
            try:
                validate_tunnel_http_connection(
                    tunnel=dask_impl.scheduler.bokeh_tunnel)
                all_dask_deployments.append(dask_impl)
            except Exception:  # pylint: disable=broad-except
                log.info(
                    "Discarding a Dask deployment,"
                    " because it is no longer functional: %s.", dask_impl)
                log.debug("Exception", exc_info=1)
                with stage_debug(
                        log, "Cancelling tunnels for discarded Dask"
                        " deployment."):
                    dask_impl.cancel_local()

    return SynchronizedDeploymentsImpl(
        nodes=all_nodes,
        jupyter_deployments=all_jupyter_deployments,
        dask_deployments=all_dask_deployments)
Example #20
def deploy_dask_worker(node: NodeInternal,
                       scheduler: DaskSchedulerDeployment) -> DaskWorkerDeployment:  # noqa, pylint: disable=line-too-long
    """Deploys a Dask worker on the node.

        :param node: Node to deploy on.

        :param scheduler: Already deployed scheduler.

    """
    log = get_logger(__name__)

    with ExitStack() as stack:
        with stage_debug(log, "Creating a runtime dir."):
            runtime_dir = create_runtime_dir(node=node)
            stack.enter_context(
                remove_runtime_dir_on_failure(node=node,
                                              runtime_dir=runtime_dir))

        with stage_debug(log, "Obtaining a free remote port."):
            bokeh_port = get_free_remote_port(node=node)

        with stage_debug(log, "Creating a scratch subdirectory."):
            scratch_subdir = create_scratch_subdir(node=node)

        log_file = create_log_file(node=node, runtime_dir=runtime_dir)

        script_contents = get_worker_deployment_script(
            scheduler_address=scheduler.address,
            bokeh_port=bokeh_port,
            scratch_subdir=scratch_subdir,
            cores=node.cores,
            memory_limit=node.memory,
            log_file=log_file,
            config=node.config)

        log.debug("Deployment script contents: %s", script_contents)

        with stage_debug(log, "Deploying script."):
            deployment = deploy_generic(node=node,
                                        script_contents=script_contents,
                                        runtime_dir=runtime_dir)
            stack.enter_context(cancel_on_failure(deployment))

        @fabric.decorators.task
        def validate_worker_started_from_log():
            """Checks that the worker has started correctly based on
                the log file."""
            with capture_fabric_output_to_log():
                output = get_remote_file(remote_path=log_file)
            log.debug("Log file: %s", output)
            validate_worker_started(output=output)

        with stage_debug(log, "Checking if worker started."):
            retry_with_config(
                lambda: node.run_task(task=validate_worker_started_from_log),
                name=Retry.CHECK_WORKER_STARTED,
                config=node.config)

        with stage_debug(log, "Opening a tunnel to bokeh diagnostics server."):
            bokeh_tunnel = node.tunnel(here=bokeh_port, there=bokeh_port)
            stack.enter_context(close_tunnel_on_failure(bokeh_tunnel))
            log.debug("Diagnostics local port: %d", bokeh_tunnel.here)

        return DaskWorkerDeployment(deployment=deployment,
                                    bokeh_tunnel=bokeh_tunnel)
Example #21
def build_tunnel(config: ClusterConfig,
                 bindings: List[Binding],
                 ssh_password: Optional[str] = None,
                 ssh_pkey: Optional[str] = None) -> TunnelInternal:
    """Builds a multi-hop tunnel from a sequence of bindings.

        :param config: Cluster config.

        :param bindings: Sequence of bindings, starting with the local binding.

        :param ssh_password: SSH password.

        :param ssh_pkey: SSH private key.
    """
    if len(bindings) < 2:
        raise ValueError("At least one local and one remote binding"
                         " are required to build a tunnel.")

    with ExitStack() as stack:
        tunnels = []

        log = get_logger(__name__)
        log.debug("Ssh username: %s", config.user)
        log.debug("Using password: %r", ssh_password is not None)
        log.debug("Using key file: %s", ssh_pkey)

        logger = get_debug_logger("{}/Tunnels".format(__name__))

        # First hop if not the only one
        if len(bindings) != 2:
            with stage_debug(log, "Adding first hop."):
                ssh_address_or_host = (config.host, config.port)
                local_bind_address = ANY_ADDRESS
                remote_bind_address = bindings[1].as_tuple()
                log.debug("Ssh address is %s", ssh_address_or_host)
                log.debug("Local bind address is %s", local_bind_address)
                log.debug("Remote bind address is %s", remote_bind_address)

                def create_first_tunnel():
                    return FirstHopTunnel(
                        forwarder=SSHTunnelForwarder(
                            ssh_address_or_host,
                            ssh_config_file=None,
                            ssh_username=config.user,
                            ssh_password=ssh_password,
                            ssh_pkey=ssh_pkey,
                            local_bind_address=local_bind_address,
                            remote_bind_address=remote_bind_address,
                            set_keepalive=TUNNEL_KEEPALIVE,
                            allow_agent=False,
                            logger=logger),
                        there=bindings[1].port,
                        config=config)

                tunnel = retry_with_config(create_first_tunnel,
                                           name=Retry.OPEN_TUNNEL,
                                           config=config)
                stack.enter_context(close_tunnel_on_failure(tunnel))
                tunnels.append(tunnel)

        # Middle hops if any
        prev_tunnel = tunnels[0] if tunnels else None
        for i, next_binding in enumerate(bindings[2:-1]):
            with stage_debug(log, "Adding middle hop %d.", i):
                # Connect through previous tunnel
                ssh_address_or_host = (
                    "127.0.0.1",
                    prev_tunnel.forwarder.local_bind_port)
                local_bind_address = ANY_ADDRESS
                remote_bind_address = next_binding.as_tuple()
                log.debug("Ssh address is %s", ssh_address_or_host)
                log.debug("Local bind address is %s", local_bind_address)
                log.debug("Remote bind address is %s", remote_bind_address)

                def create_middle_tunnel():
                    next_binding_port = next_binding.port  # noqa, pylint: disable=cell-var-from-loop, line-too-long
                    return FirstHopTunnel(
                        forwarder=SSHTunnelForwarder(
                            ssh_address_or_host,
                            ssh_config_file=None,
                            ssh_username=config.user,
                            ssh_password=ssh_password,
                            ssh_pkey=ssh_pkey,
                            local_bind_address=local_bind_address,
                            remote_bind_address=remote_bind_address,
                            set_keepalive=TUNNEL_KEEPALIVE,
                            allow_agent=False,
                            logger=logger),
                        there=next_binding_port,
                        config=config)

                next_tunnel = retry_with_config(create_middle_tunnel,
                                                name=Retry.OPEN_TUNNEL,
                                                config=config)
                stack.enter_context(close_tunnel_on_failure(next_tunnel))
                tunnels.append(next_tunnel)
                prev_tunnel = next_tunnel

        with stage_debug(log, "Adding last hop."):
            # Last hop
            last_hop_port = (config.port
                             if len(bindings) == 2
                             else tunnels[-1].forwarder.local_bind_port)

            ssh_address_or_host = ("127.0.0.1", last_hop_port)
            local_bind_address = bindings[0].as_tuple()
            remote_bind_address = bindings[-1].as_tuple()
            log.debug("Ssh address is %s", ssh_address_or_host)
            log.debug("Local bind address is %s", local_bind_address)
            log.debug("Remote bind address is %s", remote_bind_address)

            def create_last_tunnel():
                return FirstHopTunnel(
                    forwarder=SSHTunnelForwarder(
                        ssh_address_or_host,
                        ssh_config_file=None,
                        ssh_username=config.user,
                        ssh_password=ssh_password,
                        ssh_pkey=ssh_pkey,
                        local_bind_address=local_bind_address,
                        remote_bind_address=remote_bind_address,
                        set_keepalive=TUNNEL_KEEPALIVE,
                        allow_agent=False,
                        logger=logger),
                    there=bindings[-1].port,
                    config=config)

            last_tunnel = retry_with_config(create_last_tunnel,
                                            name=Retry.OPEN_TUNNEL,
                                            config=config)
            stack.enter_context(close_tunnel_on_failure(last_tunnel))
            tunnels.append(last_tunnel)

        if len(tunnels) == 1:
            return tunnels[0]
        return MultiHopTunnel(tunnels=tunnels,
                              config=config)
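A hypothetical usage sketch: the first binding is the local side, the last is the final remote target, and anything in between is an intermediate hop. `Binding(address, port)` is an assumption, as is the availability of `config` and `Binding` in scope; `build_tunnel` itself only relies on `.as_tuple()` and `.port`:

# Two-hop tunnel: local 8080 -> access node (from config) -> p0123:22
# -> 127.0.0.1:8888 on the compute node. Host names and ports are made up.
bindings = [Binding("127.0.0.1", 8080),   # local bind address
            Binding("p0123", 22),         # intermediate hop (node sshd)
            Binding("127.0.0.1", 8888)]   # final remote target

tunnel = build_tunnel(config=config,
                      bindings=bindings,
                      ssh_password=None,
                      ssh_pkey="~/.ssh/id_rsa")
print(tunnel.here)  # local port of the first binding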