def task(): """Creates the .ssh dir with proper permissions. Adds the public key to the authorized keys file if it's not been added already.""" with stage_debug(log, "Creating authorized keys file: %s", authorized_keys_path): with capture_fabric_output_to_log(): run("mkdir -p ~/.ssh" " && chmod 700 ~/.ssh" " && touch '{authorized_keys_path}'" " && chmod 644 '{authorized_keys_path}'".format( authorized_keys_path=authorized_keys_path)) with stage_debug(log, "Downloading authorized keys file."): with capture_fabric_output_to_log(): authorized_keys_fd = BytesIO() get(authorized_keys_path, authorized_keys_fd) authorized_keys_contents = \ authorized_keys_fd.getvalue().decode('ascii').splitlines() if public_key not in authorized_keys_contents: with stage_debug(log, "Appending to authorized keys file."): log.debug("Warning: This operation is not atomic on NFS.") with capture_fabric_output_to_log(): run("echo '{public_key}' >> {authorized_keys_path}".format( public_key=public_key, authorized_keys_path=authorized_keys_path)) with stage_debug(log, "Checking if public key was added."): with capture_fabric_output_to_log(): run("grep '{public_key}' '{authorized_keys_path}'".format( public_key=public_key, authorized_keys_path=authorized_keys_path))
def task(): """Creates the entry point dir and file. Fails if it couldn't be created.""" with capture_fabric_output_to_log(): run("mkdir -p {entry_point_location}" " && chmod 700 {entry_point_location}".format( entry_point_location=entry_point_location)) file_name = get_random_file_name( length=ENTRY_POINT_FILE_NAME_LENGTH) file_path = run("echo {entry_point_location}/{file_name}".format( entry_point_location=entry_point_location, file_name=file_name)) file_exists = exists(file_path) if file_exists: log.warning("Overwriting randomly named entry point file:" " %s", file_path) with stage_debug(log, "Uploading the entry point script."): with capture_fabric_output_to_log(): real_path = run("echo {file_path}".format(file_path=file_path)) file = BytesIO(contents.encode('ascii')) put(file, real_path, mode=0o700) with stage_debug(log, "Checking the entry point script was uploaded."): with capture_fabric_output_to_log(): run("cat {real_path} > /dev/null".format(real_path=real_path)) result.append(real_path)
def deploy_generic(node: NodeInternal,
                   script_contents: str,
                   runtime_dir: str) -> GenericDeployment:
    """Deploys a program on the node.

    :param node: Node to deploy the program on.
    :param script_contents: Deployment script contents.
    :param runtime_dir: Runtime dir for the deployment, removed on cancel.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Uploading entry point."):
        script_path = upload_entry_point(contents=script_contents,
                                         node=node,
                                         runtime_dir=runtime_dir)
    with stage_debug(log, "Executing the deployment command."):
        output = node.run(get_deployment_command(
            script_path=script_path))
    lines = output.splitlines()
    pid = int(lines[0])
    return GenericDeployment(node=node,
                             pid=pid,
                             runtime_dir=runtime_dir)
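
# A minimal local sketch of the contract deploy_generic relies on: the
# deployment command prints the pid of the started program as the first
# line of its output, and the caller parses it. The echo command below is
# a simulation, not idact's actual deployment command.
import subprocess

def parse_pid_from_deployment_output(output: str) -> int:
    # deploy_generic takes the first output line and parses it as the pid.
    return int(output.splitlines()[0])

# Simulate a deployment command that backgrounds a program and echoes $!.
simulated_output = subprocess.run(
    ["sh", "-c", "sleep 1 & echo $!"],
    capture_output=True, text=True, check=True).stdout
print(parse_pid_from_deployment_output(simulated_output))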
def determine_ports_for_nodes(allocation_id: int,
                              hostnames: List[str],
                              config: ClusterConfig,
                              raise_on_missing: bool) -> List[int]:
    """Tries to determine sshd ports for each node.

       Removes the port info file if no exception was raised.

    :param allocation_id: Job id.
    :param hostnames: List of hostnames.
    :param config: Cluster config.
    :param raise_on_missing: Raise an exception if a port could not be
                             determined.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Fetching port info for sshd."):
        port_info_contents = fetch_port_info(allocation_id=allocation_id,
                                             config=config)
        port_info = SshdPortInfo(contents=port_info_contents)
    with stage_debug(log, "Determining ports for each host."):
        ports = [port_info.get_port(host=host,
                                    raise_on_missing=raise_on_missing)
                 for host in hostnames]
    with stage_debug(log, "Removing the file containing sshd port info."):
        remove_port_info(allocation_id, config=config)
    return ports
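
# Hedged usage sketch for determine_ports_for_nodes. The job id and
# hostnames below are placeholders; in idact they come from a finished
# Slurm allocation, not from literals like these.
def example_determine_ports(config: ClusterConfig) -> List[int]:
    hostnames = ['node-01', 'node-02']  # hypothetical allocated hosts
    return determine_ports_for_nodes(allocation_id=12345,  # hypothetical
                                     hostnames=hostnames,
                                     config=config,
                                     raise_on_missing=True)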
def cancel(self):
    """Kills the program and all its child processes.

       Removes the runtime dir.

       Raises an exception if the top level process is still running
       after :attr:`.CANCEL_TIMEOUT` seconds.

    :raises RuntimeError: If the program is still running.
    """
    parent_pid = self._pid
    node = self._node
    tree = ' '.join([str(pid)
                     for pid in ptree(pid=parent_pid, node=node)])

    def cancel_task():
        """Kills the process tree and fails if the parent is still
           running after a timeout."""
        node.run("kill {tree}"
                 "; kill -0 {parent_pid} && exit 1 || exit 0".format(
                     tree=tree,
                     parent_pid=parent_pid))

    log = get_logger(__name__)
    with remove_runtime_dir_on_exit(node=self._node,
                                    runtime_dir=self._runtime_dir):
        with stage_debug(log, "Killing the process tree for pid: %d",
                         self._pid):
            retry_with_config(fun=cancel_task,
                              name=Retry.CANCEL_DEPLOYMENT,
                              config=self._node.config)
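
# A local, runnable demonstration of the shell idiom in cancel_task above:
# "kill <tree>; kill -0 <pid> && exit 1 || exit 0". `kill -0` sends no
# signal; it only tests whether the pid still exists, so the command exits
# non-zero (failing the task and triggering retry_with_config) exactly when
# the parent survived. Plain subprocess stands in for node.run here.
import subprocess

proc = subprocess.Popen(["sleep", "60"])  # stand-in for the process tree
subprocess.run(["sh", "-c", "kill {pid}".format(pid=proc.pid)], check=True)
proc.wait()  # reap the child, so the pid is truly gone
check = subprocess.run(
    ["sh", "-c", "kill -0 {pid} && exit 1 || exit 0".format(pid=proc.pid)])
print("exit code:", check.returncode)  # 0: nothing left to kill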
def allocate_slurm_nodes(parameters: AllocationParameters,
                         config: ClusterConfig) -> Nodes:
    """Tries to allocate nodes using Slurm.

    :param parameters: Allocation parameters.
    :param config: Config for the cluster to allocate nodes on.
    """
    args = SbatchArguments(params=parameters)

    log = get_logger(__name__)
    with stage_debug(log, "Executing sbatch on access node."):
        access_node = get_access_node(config=config)
        job_id, entry_point_script_path = run_sbatch(args=args,
                                                     node=access_node)

    def run_squeue_task() -> SqueueResult:
        job_squeue = run_squeue(node=access_node)
        return job_squeue[job_id]

    try:
        with stage_debug(log, "Obtaining info about job %d using squeue.",
                         job_id):
            job = retry_with_config(run_squeue_task,
                                    name=Retry.SQUEUE_AFTER_SBATCH,
                                    config=config)
    except Exception as e:  # noqa, pylint: disable=broad-except
        run_scancel(job_id=job_id, node=access_node)
        raise RuntimeError("Unable to obtain job info"
                           " after allocation.") from e

    node_count = job.node_count
    nodes = [NodeImpl(config=config) for _ in range(node_count)]

    allocation = SlurmAllocation(
        job_id=job_id,
        access_node=access_node,
        nodes=nodes,
        entry_point_script_path=entry_point_script_path,
        parameters=parameters)

    return NodesImpl(nodes=nodes,
                     allocation=allocation)
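
# Hedged usage sketch. The AllocationParameters fields below (nodes, cores,
# walltime) and the Walltime helper are assumptions about its constructor,
# not confirmed by this file; only allocate_slurm_nodes' own signature is
# shown above.
def example_allocate(config: ClusterConfig) -> Nodes:
    parameters = AllocationParameters(nodes=2,    # assumed field
                                      cores=4,    # assumed field
                                      walltime=Walltime(minutes=10))  # assumed
    return allocate_slurm_nodes(parameters=parameters, config=config)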
def create_log_file(node: Node, runtime_dir: str) -> str:
    """Creates a log file in the runtime dir.

    :param node: Node to create the log file on.
    :param runtime_dir: Runtime dir path.
    """
    log = get_logger(__name__)
    log_file = '{runtime_dir}/log'.format(runtime_dir=runtime_dir)
    with stage_debug(log, "Creating log file: '%s'.", log_file):
        node.run("touch '{}'".format(log_file))
    return log_file
def discard_expired_deployments(
        deployments: DeploymentDefinitions) -> DeploymentDefinitions:  # noqa
    """Returns a new object that does not contain deployments that have
       expired, or will expire in the near future.

    :param deployments: Deployments to examine.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Discarding expired deployments."):
        now = utc_now()
        log.debug("Now: %s", now)
        log.debug("Will discard after the %d second mark"
                  " before the expiration date.", DISCARD_DELTA_SECONDS)
        discard_now = now + timedelta(seconds=DISCARD_DELTA_SECONDS)

        unexpired_nodes = {}
        for uuid, node in deployments.nodes.items():
            if node.expiration_date < discard_now:
                log.warning("Discarding a synchronized allocation"
                            " deployment, because it has expired: %s", uuid)
            else:
                unexpired_nodes[uuid] = node

        unexpired_jupyter_deployments = {}
        for uuid, jupyter in deployments.jupyter_deployments.items():
            if jupyter.expiration_date < discard_now:
                log.warning("Discarding a Jupyter deployment,"
                            " because it has expired: %s", uuid)
            else:
                unexpired_jupyter_deployments[uuid] = jupyter

        unexpired_dask_deployments = {}
        for uuid, dask in deployments.dask_deployments.items():
            if dask.expiration_date < discard_now:
                log.warning("Discarding a Dask deployment,"
                            " because it has expired: %s", uuid)
            else:
                unexpired_dask_deployments[uuid] = dask

        return DeploymentDefinitions(
            nodes=unexpired_nodes,
            jupyter_deployments=unexpired_jupyter_deployments,
            dask_deployments=unexpired_dask_deployments)
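
# A self-contained illustration of the discard window above: a deployment
# is discarded not only when already expired, but also when it will expire
# within DISCARD_DELTA_SECONDS. The value 60 is a stand-in, not the real
# constant.
from datetime import datetime, timedelta, timezone

EXAMPLE_DISCARD_DELTA_SECONDS = 60  # placeholder value

def should_discard(expiration_date: datetime) -> bool:
    discard_now = datetime.now(timezone.utc) + timedelta(
        seconds=EXAMPLE_DISCARD_DELTA_SECONDS)
    return expiration_date < discard_now

soon = datetime.now(timezone.utc) + timedelta(seconds=30)
later = datetime.now(timezone.utc) + timedelta(hours=1)
print(should_discard(soon))   # True: expires inside the window
print(should_discard(later))  # False: comfortably in the future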
def deserialize_deployment_definitions_from_cluster(
        node: NodeInternal) -> DeploymentDefinitions:  # noqa
    """Downloads deployment definitions from the cluster.

    :param node: Node to deserialize deployment definitions from.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Deserializing deployment definitions"
                          " from cluster."):
        path = get_deployment_definitions_file_path(node=node)
        file_contents = get_file_from_node(node=node,
                                           remote_path=path)
        serialized = json.loads(file_contents)
        return DeploymentDefinitions.deserialize(serialized=serialized)
def get_file_from_node(node: NodeInternal,
                       remote_path: str) -> str:
    """Runs a task on the node that downloads a file and returns
       its contents.

    :param node: Node to download the file from.
    :param remote_path: Remote file path.
    """
    log = get_logger(__name__)

    @fabric.decorators.task
    def file_download_task():
        with capture_fabric_output_to_log():
            return get_remote_file(remote_path=remote_path)

    with stage_debug(log, "Getting file from node %s: %s",
                     node.host, remote_path):
        return node.run_task(task=file_download_task)
def serialize_deployment_definitions_to_cluster(
        node: NodeInternal, deployments: DeploymentDefinitions):  # noqa
    """Uploads deployment definitions to the cluster, replacing any
       definitions file already there.

    :param node: Node to serialize definitions to.
    :param deployments: Deployments to upload.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Serializing deployment definitions to cluster."):
        serialized = deployments.serialize()
        file_contents = json.dumps(serialized, sort_keys=True, indent=4)
        parent_path = get_deployment_definitions_parent_path(node=node)
        node.run("mkdir -p {parent_path}"
                 " && chmod 700 {parent_path}".format(
                     parent_path=parent_path))
        path = get_deployment_definitions_file_path(node=node)
        put_file_on_node(node=node,
                         remote_path=path,
                         contents=file_contents)
def put_file_on_node(node: NodeInternal,
                     remote_path: str,
                     contents: str):
    """Runs a task on the node that uploads a file.

    :param node: Node to upload the file to.
    :param remote_path: Remote file path.
    :param contents: File contents.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Putting file on node %s: %s",
                     node.host, remote_path):
        @fabric.decorators.task
        def file_upload_task():
            with capture_fabric_output_to_log():
                put_remote_file(remote_path=remote_path,
                                contents=contents)

        node.run_task(task=file_upload_task)
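
# A hedged round-trip sketch combining put_file_on_node with
# get_file_from_node (defined above): upload a file, read it back, and
# check the contents survived. The remote path is illustrative.
def example_roundtrip(node: NodeInternal):
    remote_path = "~/.idact-example.txt"  # hypothetical path
    put_file_on_node(node=node,
                     remote_path=remote_path,
                     contents="hello")
    assert get_file_from_node(node=node, remote_path=remote_path) == "hello"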
def deserialize_environment_from_cluster(
        cluster: Cluster,
        path: Optional[str] = None) -> Environment:  # noqa, pylint: disable=line-too-long
    """Loads the environment from a remote file.

       See :func:`.pull_environment`.

    :param cluster: Cluster to pull the environment from.
    :param path: Remote file path. Default: Remote IDACT_CONFIG_PATH
                 environment variable, or ~/.idact.conf
    """
    log = get_logger(__name__)
    with stage_debug(log, "Deserializing the environment from cluster."):
        node = cluster.get_access_node()
        assert isinstance(node, NodeInternal)
        path = get_remote_environment_path(node=node, path=path)
        file_contents = get_file_from_node(node=node,
                                           remote_path=path)
        return deserialize_environment(text=file_contents)
def serialize_environment_to_cluster(environment: Environment,
                                     cluster: Cluster,
                                     path: Optional[str]):
    """Dumps the environment to a remote file.

       See :func:`.push_environment`.

    :param environment: Environment to save.
    :param cluster: Cluster to push the environment to.
    :param path: Remote file path. Default: IDACT_CONFIG_PATH
                 environment variable, or ~/.idact.conf
    """
    log = get_logger(__name__)
    with stage_debug(log, "Serializing the environment to cluster."):
        node = cluster.get_access_node()
        assert isinstance(node, NodeInternal)
        path = get_remote_environment_path(node=node, path=path)
        file_contents = serialize_environment(environment)
        put_file_on_node(node=node,
                         remote_path=path,
                         contents=file_contents)
def tunnel(self,
           there: int,
           here: Optional[int] = None) -> TunnelInternal:
    try:
        log = get_logger(__name__)
        with stage_debug(log, "Opening tunnel %s -> %d to %s",
                         here, there, self):
            self._ensure_allocated()

            here, there = validate_tunnel_ports(here=here,
                                                there=there)

            first_try = [True]

            def get_bindings_and_build_tunnel() -> TunnelInternal:
                bindings = get_bindings_with_single_gateway(
                    here=here if first_try[0] else ANY_TUNNEL_PORT,
                    node_host=self._host,
                    node_port=self._port,
                    there=there)
                first_try[0] = False
                return build_tunnel(config=self._config,
                                    bindings=bindings,
                                    ssh_password=env.password,
                                    ssh_pkey=env.key_filename)

            with authenticate(host=self._host,
                              port=self._port,
                              config=self._config):
                if here == ANY_TUNNEL_PORT:
                    return get_bindings_and_build_tunnel()
                return retry_with_config(
                    get_bindings_and_build_tunnel,
                    name=Retry.TUNNEL_TRY_AGAIN_WITH_ANY_PORT,
                    config=self._config)
    except RuntimeError as e:
        raise RuntimeError(
            "Unable to tunnel {there} on node '{host}'.".format(
                there=there,
                host=self._host)) from e
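
# Hedged usage sketch for the tunnel method above: `there` is the remote
# port, `here` the local one. Passing here=None (the default) lets
# validate_tunnel_ports pick the local port. 8787 is illustrative.
def example_tunnel(node) -> TunnelInternal:
    return node.tunnel(there=8787, here=8787)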
def install_key(config: ClusterConfig,
                authorized_keys: Optional[str] = None):
    """Installs the public key on the access node.

       If the key was not generated, or is missing, generates one.

       Expects password authentication to have already been performed.

    :param config: Cluster config for connection.
    :param authorized_keys: Path to authorized_keys.
                            Default: `~/.ssh/authorized_keys`
    """
    log = get_logger(__name__)
    authorized_keys_path = (authorized_keys
                            if authorized_keys
                            else ".ssh/authorized_keys")
    with stage_debug(log, "Attempting to determine key path."):
        public_key_path = try_getting_public_key_from_config(config=config,
                                                             log=log)
    if public_key_path is None:
        with stage_debug(log, "Generating key."):
            config.key = generate_key(host=config.host)
            public_key_path = get_public_key_location(
                private_key_location=config.key)
    with stage_debug(log, "Reading key."):
        public_key = read_public_key(public_key_path=public_key_path)

    @fabric.decorators.task
    def task():
        """Creates the .ssh dir with proper permissions. Adds the public
           key to the authorized keys file, if it hasn't been added
           already."""
        with stage_debug(log, "Creating authorized keys file: %s",
                         authorized_keys_path):
            with capture_fabric_output_to_log():
                run("mkdir -p ~/.ssh"
                    " && chmod 700 ~/.ssh"
                    " && touch '{authorized_keys_path}'"
                    " && chmod 644 '{authorized_keys_path}'".format(
                        authorized_keys_path=authorized_keys_path))
        with stage_debug(log, "Downloading authorized keys file."):
            with capture_fabric_output_to_log():
                authorized_keys_fd = BytesIO()
                get(authorized_keys_path, authorized_keys_fd)
                authorized_keys_contents = \
                    authorized_keys_fd.getvalue().decode(
                        'ascii').splitlines()
        if public_key not in authorized_keys_contents:
            with stage_debug(log, "Appending to authorized keys file."):
                log.debug("Warning: This operation is not atomic on NFS.")
                with capture_fabric_output_to_log():
                    run("echo '{public_key}'"
                        " >> '{authorized_keys_path}'".format(
                            public_key=public_key,
                            authorized_keys_path=authorized_keys_path))
        with stage_debug(log, "Checking if public key was added."):
            with capture_fabric_output_to_log():
                run("grep '{public_key}' '{authorized_keys_path}'".format(
                    public_key=public_key,
                    authorized_keys_path=authorized_keys_path))

    with raise_on_remote_fail(exception=RuntimeError):
        fabric.tasks.execute(task)
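
# Hedged usage sketch. install_key expects password authentication to have
# already happened; wrapping the call in the `authenticate` context manager
# seen in Node.tunnel above is an assumption about the call site, not a
# requirement stated in this file.
def example_install_key(config: ClusterConfig):
    with authenticate(host=config.host, port=config.port, config=config):
        install_key(config=config)  # default: ~/.ssh/authorized_keys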
def deploy_dask_scheduler(node: NodeInternal) -> DaskSchedulerDeployment:
    """Deploys a Dask scheduler on the node.

    :param node: Node to deploy on.
    """
    log = get_logger(__name__)
    with ExitStack() as stack:
        with stage_debug(log, "Creating a runtime dir."):
            runtime_dir = create_runtime_dir(node=node)
            stack.enter_context(
                remove_runtime_dir_on_failure(node=node,
                                              runtime_dir=runtime_dir))

        with stage_debug(log, "Obtaining free remote ports."):
            remote_port, bokeh_port = get_free_remote_ports(count=2,
                                                            node=node)

        with stage_debug(log, "Creating a scratch subdirectory."):
            scratch_subdir = create_scratch_subdir(node=node)

        log_file = create_log_file(node=node, runtime_dir=runtime_dir)

        script_contents = get_scheduler_deployment_script(
            remote_port=remote_port,
            bokeh_port=bokeh_port,
            scratch_subdir=scratch_subdir,
            log_file=log_file,
            config=node.config)

        log.debug("Deployment script contents: %s", script_contents)

        with stage_debug(log, "Deploying script."):
            deployment = deploy_generic(node=node,
                                        script_contents=script_contents,
                                        runtime_dir=runtime_dir)
            stack.enter_context(cancel_on_failure(deployment))

        @fabric.decorators.task
        def extract_address_from_log() -> str:
            """Extracts the scheduler address from the log file."""
            with capture_fabric_output_to_log():
                output = get_remote_file(remote_path=log_file)
            log.debug("Log file: %s", output)
            return extract_address_from_output(output=output)

        with stage_debug(log, "Obtaining scheduler address."):
            address = retry_with_config(
                lambda: node.run_task(task=extract_address_from_log),
                name=Retry.GET_SCHEDULER_ADDRESS,
                config=node.config)

        with stage_debug(log, "Opening a tunnel to scheduler."):
            tunnel = node.tunnel(here=remote_port,
                                 there=remote_port)
            stack.enter_context(close_tunnel_on_failure(tunnel))
        log.debug("Scheduler local port: %d", tunnel.here)

        with stage_debug(log, "Opening a tunnel to bokeh diagnostics"
                              " server."):
            bokeh_tunnel = node.tunnel(here=bokeh_port,
                                       there=bokeh_port)
            stack.enter_context(close_tunnel_on_failure(bokeh_tunnel))
        log.debug("Diagnostics local port: %d", bokeh_tunnel.here)

        return DaskSchedulerDeployment(deployment=deployment,
                                       tunnel=tunnel,
                                       bokeh_tunnel=bokeh_tunnel,
                                       address=address)
def deploy_jupyter(node: NodeInternal, local_port: int) -> JupyterDeployment:
    """Deploys a Jupyter Notebook server on the node, and creates a tunnel
       to a local port.

    :param node: Node to deploy Jupyter Notebook on.
    :param local_port: Local tunnel binding port.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Creating a runtime dir."):
        runtime_dir = create_runtime_dir(node=node)
    with stage_debug(log, "Obtaining a free remote port."):
        remote_port = get_free_remote_port(node=node)

    if node.config.use_jupyter_lab:
        jupyter_version = 'lab'
    else:
        jupyter_version = 'notebook'

    deployment_commands = [
        'export JUPYTER_RUNTIME_DIR="{runtime_dir}"'.format(
            runtime_dir=runtime_dir),
        get_command_to_append_local_bin()
    ]

    log_file = create_log_file(node=node, runtime_dir=runtime_dir)

    deployment_commands.append('jupyter {jupyter_version}'
                               ' --ip 127.0.0.1'
                               ' --port "{remote_port}"'
                               ' --no-browser > {log_file} 2>&1'.format(
                                   jupyter_version=jupyter_version,
                                   remote_port=remote_port,
                                   log_file=log_file))

    script_contents = get_deployment_script_contents(
        deployment_commands=deployment_commands,
        setup_actions=node.config.setup_actions.jupyter)

    log.debug("Deployment script contents: %s", script_contents)
    with stage_debug(log, "Deploying script."):
        deployment = deploy_generic(node=node,
                                    script_contents=script_contents,
                                    runtime_dir=runtime_dir)
    with cancel_on_failure(deployment):
        @fabric.decorators.task
        def load_nbserver_json():
            """Loads notebook parameters from a json file."""
            with capture_fabric_output_to_log():
                with cd(runtime_dir):
                    nbserver_json_path = run(
                        "readlink -vf $PWD/nbserver-*.json").splitlines()[0]
                    run("cat '{log_file}' || exit 0".format(
                        log_file=log_file))
                    run("cat '{nbserver_json_path}' > /dev/null".format(
                        nbserver_json_path=nbserver_json_path))
                    nbserver_json_str = get_remote_file(nbserver_json_path)
                    nbserver_json = json.loads(nbserver_json_str)
                    return int(nbserver_json['port']), \
                        nbserver_json['token']

        with stage_debug(log, "Obtaining info about notebook from json"
                              " file."):
            actual_port, token = retry_with_config(
                lambda: node.run_task(task=load_nbserver_json),
                name=Retry.JUPYTER_JSON,
                config=node.config)

        with stage_debug(log, "Opening a tunnel to notebook."):
            tunnel = node.tunnel(there=actual_port,
                                 here=local_port)

        return JupyterDeploymentImpl(deployment=deployment,
                                     tunnel=tunnel,
                                     token=token)
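
# Hedged usage sketch: deploy a notebook on an allocated node, tunnelled to
# local port 8080 (an illustrative choice). The token needed to log in is
# carried by the returned deployment object.
def example_jupyter(node: NodeInternal) -> JupyterDeployment:
    return deploy_jupyter(node=node, local_port=8080)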
def discard_non_functional_deployments(
        deployments: SynchronizedDeployments) -> SynchronizedDeployments:  # noqa
    """Discards deployments that have not expired, but are no longer
       functional, e.g. were cancelled."""
    log = get_logger(__name__)

    all_nodes = []
    for nodes in deployments.nodes:
        with stage_debug(log,
                         "Checking whether allocation deployment"
                         " is functional: %s.", nodes):
            nodes_functional = not nodes.waited or nodes.running()
        if nodes_functional:
            all_nodes.append(nodes)
        else:
            log.info("Discarding an allocation deployment,"
                     " because it is no longer functional: %s.", nodes)

    all_jupyter_deployments = []
    for jupyter in deployments.jupyter_deployments:
        jupyter_impl = jupyter
        assert isinstance(jupyter_impl, JupyterDeploymentImpl)
        with stage_debug(log,
                         "Checking whether Jupyter deployment"
                         " is functional: %s.", jupyter_impl):
            try:
                validate_tunnel_http_connection(tunnel=jupyter_impl.tunnel)
                all_jupyter_deployments.append(jupyter_impl)
            except Exception:  # pylint: disable=broad-except
                log.info("Discarding a Jupyter deployment,"
                         " because it is no longer functional: %s.",
                         jupyter_impl)
                log.debug("Exception", exc_info=1)
                with stage_debug(log,
                                 "Cancelling tunnel to discarded"
                                 " notebook."):
                    jupyter_impl.cancel_local()

    all_dask_deployments = []
    for dask in deployments.dask_deployments:
        dask_impl = dask
        assert isinstance(dask_impl, DaskDeploymentImpl)
        with stage_debug(log,
                         "Checking whether Dask deployment"
                         " is functional: %s.", dask_impl):
            try:
                validate_tunnel_http_connection(
                    tunnel=dask_impl.scheduler.bokeh_tunnel)
                all_dask_deployments.append(dask_impl)
            except Exception:  # pylint: disable=broad-except
                log.info("Discarding a Dask deployment,"
                         " because it is no longer functional: %s.",
                         dask_impl)
                log.debug("Exception", exc_info=1)
                with stage_debug(log,
                                 "Cancelling tunnels for discarded Dask"
                                 " deployment."):
                    dask_impl.cancel_local()

    return SynchronizedDeploymentsImpl(
        nodes=all_nodes,
        jupyter_deployments=all_jupyter_deployments,
        dask_deployments=all_dask_deployments)
def deploy_dask_worker(node: NodeInternal,
                       scheduler: DaskSchedulerDeployment) -> DaskWorkerDeployment:  # noqa, pylint: disable=line-too-long
    """Deploys a Dask worker on the node.

    :param node: Node to deploy on.
    :param scheduler: Already deployed scheduler.
    """
    log = get_logger(__name__)
    with ExitStack() as stack:
        with stage_debug(log, "Creating a runtime dir."):
            runtime_dir = create_runtime_dir(node=node)
            stack.enter_context(
                remove_runtime_dir_on_failure(node=node,
                                              runtime_dir=runtime_dir))

        with stage_debug(log, "Obtaining a free remote port."):
            bokeh_port = get_free_remote_port(node=node)

        with stage_debug(log, "Creating a scratch subdirectory."):
            scratch_subdir = create_scratch_subdir(node=node)

        log_file = create_log_file(node=node, runtime_dir=runtime_dir)

        script_contents = get_worker_deployment_script(
            scheduler_address=scheduler.address,
            bokeh_port=bokeh_port,
            scratch_subdir=scratch_subdir,
            cores=node.cores,
            memory_limit=node.memory,
            log_file=log_file,
            config=node.config)

        log.debug("Deployment script contents: %s", script_contents)

        with stage_debug(log, "Deploying script."):
            deployment = deploy_generic(node=node,
                                        script_contents=script_contents,
                                        runtime_dir=runtime_dir)
            stack.enter_context(cancel_on_failure(deployment))

        @fabric.decorators.task
        def validate_worker_started_from_log():
            """Checks that the worker has started correctly, based on
               the log file."""
            with capture_fabric_output_to_log():
                output = get_remote_file(remote_path=log_file)
            log.debug("Log file: %s", output)
            validate_worker_started(output=output)

        with stage_debug(log, "Checking if worker started."):
            retry_with_config(
                lambda: node.run_task(task=validate_worker_started_from_log),
                name=Retry.CHECK_WORKER_STARTED,
                config=node.config)

        with stage_debug(log, "Opening a tunnel to bokeh diagnostics"
                              " server."):
            bokeh_tunnel = node.tunnel(here=bokeh_port,
                                       there=bokeh_port)
            stack.enter_context(close_tunnel_on_failure(bokeh_tunnel))
        log.debug("Diagnostics local port: %d", bokeh_tunnel.here)

        return DaskWorkerDeployment(deployment=deployment,
                                    bokeh_tunnel=bokeh_tunnel)
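
# Hedged sketch of how the scheduler and worker deployments above compose:
# one scheduler, then one worker per allocated node, each worker pointed at
# the scheduler it should join. How the results are wrapped into a single
# Dask deployment object is outside this file.
def example_deploy_dask(nodes: List[NodeInternal]):
    scheduler = deploy_dask_scheduler(node=nodes[0])
    workers = [deploy_dask_worker(node=node, scheduler=scheduler)
               for node in nodes]
    return scheduler, workers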
def build_tunnel(config: ClusterConfig,
                 bindings: List[Binding],
                 ssh_password: Optional[str] = None,
                 ssh_pkey: Optional[str] = None) -> TunnelInternal:
    """Builds a multi-hop tunnel from a sequence of bindings.

    :param config: Cluster config.
    :param bindings: Sequence of bindings, starting with the local binding.
    :param ssh_password: Ssh password.
    :param ssh_pkey: Ssh private key.
    """
    if len(bindings) < 2:
        raise ValueError("At least one local and one remote binding"
                         " is required to build a tunnel")

    with ExitStack() as stack:
        tunnels = []

        log = get_logger(__name__)
        log.debug("Ssh username: %s", config.user)
        log.debug("Using password: %r", ssh_password is not None)
        log.debug("Using key file: %s", ssh_pkey)

        logger = get_debug_logger("{}/Tunnels".format(__name__))

        # First hop, if it's not the only one.
        if len(bindings) != 2:
            with stage_debug(log, "Adding first hop."):
                ssh_address_or_host = (config.host, config.port)
                local_bind_address = ANY_ADDRESS
                remote_bind_address = bindings[1].as_tuple()
                log.debug("Ssh address is %s", ssh_address_or_host)
                log.debug("Local bind address is %s", local_bind_address)
                log.debug("Remote bind address is %s", remote_bind_address)

                def create_first_tunnel():
                    return FirstHopTunnel(
                        forwarder=SSHTunnelForwarder(
                            ssh_address_or_host,
                            ssh_config_file=None,
                            ssh_username=config.user,
                            ssh_password=ssh_password,
                            ssh_pkey=ssh_pkey,
                            local_bind_address=local_bind_address,
                            remote_bind_address=remote_bind_address,
                            set_keepalive=TUNNEL_KEEPALIVE,
                            allow_agent=False,
                            logger=logger),
                        there=bindings[1].port,
                        config=config)

                tunnel = retry_with_config(create_first_tunnel,
                                           name=Retry.OPEN_TUNNEL,
                                           config=config)
                stack.enter_context(close_tunnel_on_failure(tunnel))
                tunnels.append(tunnel)

        # Middle hops, if any.
        prev_tunnel = tunnels[0] if tunnels else None
        for i, next_binding in enumerate(bindings[2:-1]):
            with stage_debug(log, "Adding middle hop %d.", i):
                # Connect through the previous tunnel.
                ssh_address_or_host = (
                    "127.0.0.1", prev_tunnel.forwarder.local_bind_port)
                local_bind_address = ANY_ADDRESS
                remote_bind_address = next_binding.as_tuple()
                log.debug("Ssh address is %s", ssh_address_or_host)
                log.debug("Local bind address is %s", local_bind_address)
                log.debug("Remote bind address is %s", remote_bind_address)

                def create_middle_tunnel():
                    next_binding_port = next_binding.port  # noqa, pylint: disable=cell-var-from-loop, line-too-long
                    return FirstHopTunnel(
                        forwarder=SSHTunnelForwarder(
                            ssh_address_or_host,
                            ssh_config_file=None,
                            ssh_username=config.user,
                            ssh_password=ssh_password,
                            ssh_pkey=ssh_pkey,
                            local_bind_address=local_bind_address,
                            remote_bind_address=remote_bind_address,
                            set_keepalive=TUNNEL_KEEPALIVE,
                            allow_agent=False,
                            logger=logger),
                        there=next_binding_port,
                        config=config)

                next_tunnel = retry_with_config(create_middle_tunnel,
                                                name=Retry.OPEN_TUNNEL,
                                                config=config)
                stack.enter_context(close_tunnel_on_failure(next_tunnel))
                tunnels.append(next_tunnel)
                prev_tunnel = next_tunnel

        # Last hop.
        with stage_debug(log, "Adding last hop."):
            last_hop_port = (config.port
                             if len(bindings) == 2
                             else tunnels[-1].forwarder.local_bind_port)
            ssh_address_or_host = ("127.0.0.1", last_hop_port)
            local_bind_address = bindings[0].as_tuple()
            remote_bind_address = bindings[-1].as_tuple()
            log.debug("Ssh address is %s", ssh_address_or_host)
            log.debug("Local bind address is %s", local_bind_address)
            log.debug("Remote bind address is %s", remote_bind_address)

            def create_last_tunnel():
                return FirstHopTunnel(
                    forwarder=SSHTunnelForwarder(
                        ssh_address_or_host,
                        ssh_config_file=None,
                        ssh_username=config.user,
                        ssh_password=ssh_password,
                        ssh_pkey=ssh_pkey,
                        local_bind_address=local_bind_address,
                        remote_bind_address=remote_bind_address,
                        set_keepalive=TUNNEL_KEEPALIVE,
                        allow_agent=False,
                        logger=logger),
                    there=bindings[-1].port,
                    config=config)

            last_tunnel = retry_with_config(create_last_tunnel,
                                            name=Retry.OPEN_TUNNEL,
                                            config=config)
            stack.enter_context(close_tunnel_on_failure(last_tunnel))
            tunnels.append(last_tunnel)

        if len(tunnels) == 1:
            return tunnels[0]
        return MultiHopTunnel(tunnels=tunnels,
                              config=config)
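
# Hedged sketch of the bindings contract above: bindings[0] is the local
# bind address, bindings[-1] is the final destination, and anything in
# between is a gateway hop reached through the access node. The
# Binding(host, port) constructor is assumed from its as_tuple()/port usage
# above; the hosts and ports are illustrative.
def example_two_hop(config: ClusterConfig) -> TunnelInternal:
    bindings = [Binding("127.0.0.1", 8080),  # local end
                Binding("node-01", 22),      # assumed compute node sshd hop
                Binding("127.0.0.1", 8787)]  # destination port on node-01
    return build_tunnel(config=config, bindings=bindings)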