def determine_ports_for_nodes(allocation_id: int,
                              hostnames: List[str],
                              config: ClusterConfig,
                              raise_on_missing: bool) -> List[int]:
    """Tries to determine sshd ports for each node.

    Removes the sshd port info file if no exception was raised.

    :param allocation_id: Job id.
    :param hostnames: List of hostnames.
    :param config: Cluster config.
    :param raise_on_missing: Raise an exception if a port could not be
        determined.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Fetching port info for sshd."):
        port_info_contents = fetch_port_info(allocation_id=allocation_id,
                                             config=config)
        port_info = SshdPortInfo(contents=port_info_contents)
    with stage_debug(log, "Determining ports for each host."):
        ports = [port_info.get_port(host=host,
                                    raise_on_missing=raise_on_missing)
                 for host in hostnames]
    with stage_debug(log, "Removing the file containing sshd port info."):
        remove_port_info(allocation_id, config=config)
    return ports

def deserialize_client_config_from_json(data: dict) -> ClientConfig:
    """Deserializes :class:`.ClientConfig` from json.

    :param data: json to deserialize.
    """
    log = get_logger(__name__)
    log.debug("Loaded config: %s", data)
    if use_defaults_in_missing_fields(data=data):
        log.debug("Filled missing fields with None: %s", data)
    clusters = {
        name: ClusterConfigImpl(
            host=value['host'],
            port=value['port'],
            user=value['user'],
            auth=AuthMethod[value['auth']],
            key=value['key'],
            install_key=value['installKey'],
            disable_sshd=value['disableSshd'],
            setup_actions=SetupActionsConfigImpl(
                jupyter=value['setupActions']['jupyter'],
                dask=value['setupActions']['dask']),
            scratch=value['scratch'],
            notebook_defaults=value['notebookDefaults'],
            retries=provide_defaults_for_retries(
                deserialize_retries(value['retries'])),
            use_jupyter_lab=value['useJupyterLab'])
        for name, value in data['clusters'].items()}
    return ClientConfig(clusters=clusters,
                        log_level=data['logLevel'])

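# Illustrative sketch, not part of the library: the JSON shape that
# deserialize_client_config_from_json above appears to expect. The field
# names are taken from the lookups in the function; all concrete values
# are made up, and 'PUBLIC_KEY' assumes AuthMethod has a member of that
# name (suggested by the AuthMethod[value['auth']] lookup).
_EXAMPLE_CLIENT_CONFIG = {
    "logLevel": 20,
    "clusters": {
        "mycluster": {
            "host": "login.example.com",
            "port": 22,
            "user": "user1",
            "auth": "PUBLIC_KEY",
            "key": "~/.ssh/id_rsa_mycluster",
            "installKey": False,
            "disableSshd": False,
            "setupActions": {"jupyter": [], "dask": []},
            "scratch": "$SCRATCH",
            "notebookDefaults": {},
            "retries": {},
            "useJupyterLab": True
        }
    }
}
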
def try_generate_unique_path(suffix_length: int,
                             location: str,
                             prefix: str) -> Optional[str]:
    """Tries to generate a unique file name.

    Returns None if the file already exists.

    :param suffix_length: File name suffix length.
    :param location: File parent dir.
    :param prefix: File name prefix.
    """
    log = get_logger(__name__)
    suffix = get_key_suffix(length=suffix_length)
    private_key_path = get_key_path(location=location,
                                    prefix=prefix,
                                    suffix=suffix)
    if os.path.isfile(private_key_path):
        log.warning("File exists: '%s'.", private_key_path)
        return None
    public_key_path = get_public_key_location(
        private_key_location=private_key_path)
    if os.path.isfile(public_key_path):
        log.warning("File exists: '%s'.", public_key_path)
        return None
    return private_key_path

def deploy_workers_on_each_node(nodes: Sequence[Node],
                                scheduler: DaskSchedulerDeployment,
                                stack: ExitStack) \
        -> List[DaskWorkerDeployment]:
    """Deploys workers on each node.

    :param nodes: Nodes to deploy workers on.
    :param scheduler: Scheduler for workers.
    :param stack: Exit stack. Workers will be cancelled on failure.
    """
    log = get_logger(__name__)
    workers = []
    with stage_info(log, "Deploying workers."):
        total = len(nodes)
        for i, node in enumerate(nodes):
            worker = deploy_worker_on_node(node=node,
                                           scheduler=scheduler,
                                           worker_number=i + 1,
                                           worker_count=total)
            stack.enter_context(cancel_on_failure(worker))
            workers.append(worker)
    return workers

def extract_squeue_line(now: datetime.datetime,
                        line: str,
                        node: Node) -> Optional[SqueueResult]:
    """Extracts information from an `squeue` output line in the format
    `%A|%D|%L|%r|%R|%T`.

    :param now: Current time for calculating job finish time.
    :param line: `squeue` output line.
    :param node: Node to run `scontrol` on.
    """
    if not line:
        return None
    components = line.split('|')
    try:
        job_id = extract_squeue_format_A(value=components[0])
        node_count = extract_squeue_format_D(value=components[1])
        end_time = extract_squeue_format_L(now=now, value=components[2])
        reason = extract_squeue_format_r(value=components[3])
        node_list = extract_squeue_format_R(value=components[4], node=node)
        state = extract_squeue_format_T(value=components[5])
    except ValueError:
        log = get_logger(__name__)
        log.debug("Exception", exc_info=1)
        return None
    return SqueueResult(job_id=job_id,
                        node_count=node_count,
                        end_time=end_time,
                        reason=reason,
                        node_list=node_list,
                        state=state)

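def _example_squeue_line_split():
    """Illustrative sketch, not part of the library: shows how a sample
    line in the `%A|%D|%L|%r|%R|%T` squeue format splits into the six
    components consumed by extract_squeue_line above. The values are
    made up."""
    line = "1234|2|1:30:00|None|c[1-2]|RUNNING"
    components = line.split('|')
    assert components == ['1234',      # %A: job id
                          '2',         # %D: node count
                          '1:30:00',   # %L: time left
                          'None',      # %r: reason
                          'c[1-2]',    # %R: node list
                          'RUNNING']   # %T: state
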
def discard_invalid_workers(workers: List[DaskWorkerDeployment],
                            stack: ExitStack) \
        -> Tuple[List[DaskWorkerDeployment], List[Node]]:
    """Validates each worker.

    Returns a tuple of valid workers, and nodes for which the workers
    could not be validated.

    :param workers: Workers to validate.
    :param stack: Exit stack. Failed workers will be cancelled on exit.
    """
    log = get_logger(__name__)
    valid_workers = []
    nodes_to_redeploy = []
    worker_count = len(workers)
    for i, worker in enumerate(workers):
        try:
            with stage_info(log, "Validating worker %d/%d.",
                            i + 1, worker_count):
                validate_worker(worker=worker)
            valid_workers.append(worker)
        except Exception:  # noqa, pylint: disable=broad-except
            log.debug("Failed to validate worker. Exception:", exc_info=1)
            nodes_to_redeploy.append(worker.deployment.node)
            stack.enter_context(cancel_on_exit(worker))
    return valid_workers, nodes_to_redeploy

def cancel(self):
    """Kills the program and all its child processes.

    Removes the runtime dir.

    Raises an exception if the top level process is still running
    after :attr:`.CANCEL_TIMEOUT` seconds.

    :raises RuntimeError: If the program is still running.
    """
    parent_pid = self._pid
    node = self._node
    tree = ' '.join([str(pid)
                     for pid in ptree(pid=parent_pid, node=node)])

    def cancel_task():
        """Kills the process tree and fails if the parent is still
        running after a timeout."""
        # `kill -0` sends no signal: it only checks that the process
        # still exists, so the command exits with 1 if the parent
        # survived the kill.
        node.run("kill {tree}"
                 "; kill -0 {parent_pid} && exit 1 || exit 0".format(
                     tree=tree,
                     parent_pid=parent_pid))

    log = get_logger(__name__)
    with remove_runtime_dir_on_exit(node=self._node,
                                    runtime_dir=self._runtime_dir):
        with stage_debug(log, "Killing the process tree for pid: %d",
                         self._pid):
            retry_with_config(fun=cancel_task,
                              name=Retry.CANCEL_DEPLOYMENT,
                              config=self._node.config)

def get_port(self, host: str, raise_on_missing: bool) -> int:
    """Returns the ssh access port for the host.

    Tries to provide defaults if none found.

    :param host: Host to find the ssh port for.
    :param raise_on_missing: Raise an exception on missing port info.
    """
    log = get_logger(__name__)
    if host in self._hosts and self._hosts[host]:
        return self._hosts[host].pop()
    log.warning("Unable to find unique sshd server for %s", host)
    if raise_on_missing:
        raise RuntimeError(
            "Unable to find unique sshd server for {}".format(host))
    if self._hosts:
        log.warning("Assuming sandbox, defaulting to first found."
                    " If this is not sandbox, node access may not work"
                    " properly.")
        port = self._hosts[next(iter(self._hosts))][0]
        log.info("First found: %d", port)
        return port
    log.warning("No port info found, defaulting to %d.", NODE_DEFAULT_PORT)
    return NODE_DEFAULT_PORT

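def _example_sshd_port_lookup():
    """Illustrative sketch, not part of the library: demonstrates the
    behavior of get_port above, assuming the whitespace-separated
    'host:port' contents format parsed by SshdPortInfo.__init__ (defined
    later in this listing) and Python 3.7+ dict ordering. The hostnames
    and ports are made up."""
    info = SshdPortInfo(contents="node1:40000 node1:40001 node2:40002")
    # Ports are popped from the end of the host's list.
    assert info.get_port(host="node1", raise_on_missing=True) == 40001
    # 'node3' has no entry: with raise_on_missing=False, the first port
    # of the first recorded host is returned as a sandbox fallback.
    assert info.get_port(host="node3", raise_on_missing=False) == 40000
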
def cancel(self):
    log = get_logger(__name__)
    with stage_info(log, "Cancelling job %d.", self._job_id):
        run_scancel(job_id=self._job_id, node=self._access_node)
        for node in self._nodes:
            node.make_cancelled()

def push_environment(cluster: Cluster, path: Optional[str] = None):
    """Merges the environment on the cluster with the current environment.

    :param cluster: Cluster to push the environment to.
    :param path: Path to the remote environment file.
        Default: Remote IDACT_CONFIG_PATH environment variable,
        or ~/.idact.conf
    """
    log = get_logger(__name__)
    with stage_info(log, "Pushing the environment to cluster."):
        try:
            remote_environment = deserialize_environment_from_cluster(
                cluster=cluster,
                path=path)
        except RuntimeError:
            log.info("Remote environment is missing, current environment"
                     " will be copied to cluster.")
            log.debug("Exception", exc_info=1)
            remote_environment = EnvironmentImpl()
        local_environment = EnvironmentProvider().environment
        # The remote environment is passed as `local` and vice versa,
        # presumably so that the current local environment takes
        # precedence in the merge when pushing.
        merged_environment = merge_environments(local=remote_environment,
                                                remote=local_environment)
        serialize_environment_to_cluster(environment=merged_environment,
                                         cluster=cluster,
                                         path=path)

@contextmanager
def capture_fabric_output_to_log():
    """Turns on all Fabric output and replaces `sys.stdout`, `sys.stderr`
    with a logger DEBUG output.
    """
    saved = {group: fabric.state.output[group]
             for group in ['status',
                           'aborts',
                           'warnings',
                           'running',
                           'stdout',
                           'stderr',
                           'user',
                           'debug',
                           'exceptions']}
    for group in saved.keys():
        fabric.state.output[group] = True
    logger = get_logger(FABRIC_LOGGER_NAME)
    saved_stdout = sys.stdout
    saved_stderr = sys.stderr
    replacement_stdout = LoggerOut(logger=logger, fileno=STDOUT_FILENO)
    replacement_stderr = LoggerOut(logger=logger, fileno=STDERR_FILENO)
    try:
        sys.stdout = replacement_stdout
        sys.stderr = replacement_stderr
        yield
    finally:
        sys.stderr = saved_stderr
        sys.stdout = saved_stdout
        for group, show in saved.items():
            fabric.state.output[group] = show

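# A minimal sketch, not the library's actual implementation, of what the
# LoggerOut replacement stream used above presumably looks like: a
# file-like object that forwards writes to a logger at DEBUG level, so
# that swapping it in for sys.stdout/sys.stderr captures output.
class _LoggerOutSketch:
    """Hypothetical stand-in for LoggerOut, for illustration only."""

    def __init__(self, logger):
        self._logger = logger

    def write(self, text: str):
        # Skip bare newlines and whitespace-only writes.
        if text.strip():
            self._logger.debug(text.rstrip())

    def flush(self):
        pass  # Nothing is buffered; present to satisfy the stream API.
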
def deploy_generic(node: NodeInternal,
                   script_contents: str,
                   runtime_dir: str) -> GenericDeployment:
    """Deploys a program on the node.

    :param node: Node to deploy the program on.
    :param script_contents: Deployment script contents.
    :param runtime_dir: Runtime dir for the deployment, removed when
        the deployment is cancelled.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Uploading entry point."):
        script_path = upload_entry_point(contents=script_contents,
                                         node=node,
                                         runtime_dir=runtime_dir)
    with stage_debug(log, "Executing the deployment command."):
        output = node.run(get_deployment_command(script_path=script_path))
    lines = output.splitlines()
    # The deployment command is expected to print the program's pid
    # as the first line of output.
    pid = int(lines[0])
    return GenericDeployment(node=node,
                             pid=pid,
                             runtime_dir=runtime_dir)

def cancel(self):
    log = get_logger(__name__)
    with ExitStack() as stack:
        stack.enter_context(
            stage_info(log, "Cancelling Jupyter deployment."))
        stack.enter_context(cancel_on_exit(self._deployment))
        self.cancel_local()

def cancel(self):
    """Cancels the scheduler deployment."""
    log = get_logger(__name__)
    with ExitStack() as stack:
        stack.enter_context(
            stage_info(log, "Cancelling scheduler deployment on %s.",
                       self._deployment.node.host))
        stack.enter_context(cancel_on_exit(self._deployment))
        self.cancel_local()

def upload_entry_point(contents: str,
                       node: NodeInternal,
                       runtime_dir: Optional[str] = None) -> str:
    """Uploads the entry point script and returns its path.

    :param contents: Script contents.
    :param node: Node to upload the entry point to.
    :param runtime_dir: Runtime dir for the deployment script.
        Default: ~/.idact/entry_points.
    """
    log = get_logger(__name__)
    result = []
    entry_point_location = runtime_dir if runtime_dir \
        else ENTRY_POINT_LOCATION

    @fabric.decorators.task
    def task():
        """Creates the entry point dir and file.
        Fails if it couldn't be created."""
        with capture_fabric_output_to_log():
            run("mkdir -p {entry_point_location}"
                " && chmod 700 {entry_point_location}".format(
                    entry_point_location=entry_point_location))
            file_name = get_random_file_name(
                length=ENTRY_POINT_FILE_NAME_LENGTH)
            file_path = run(
                "echo {entry_point_location}/{file_name}".format(
                    entry_point_location=entry_point_location,
                    file_name=file_name))
            file_exists = exists(file_path)
        if file_exists:
            log.warning("Overwriting randomly named entry point file:"
                        " %s", file_path)
        with stage_debug(log, "Uploading the entry point script."):
            with capture_fabric_output_to_log():
                real_path = run("echo {file_path}".format(
                    file_path=file_path))
                file = BytesIO(contents.encode('ascii'))
                put(file, real_path, mode=0o700)
        with stage_debug(log, "Checking the entry point script"
                              " was uploaded."):
            with capture_fabric_output_to_log():
                run("cat {real_path} > /dev/null".format(
                    real_path=real_path))
        result.append(real_path)

    node.run_task(task)
    return result[0]

def create_log_file(node: Node, runtime_dir: str) -> str:
    """Creates a log file in the runtime dir.

    :param node: Node to create the log file on.
    :param runtime_dir: Runtime dir path.
    """
    log = get_logger(__name__)
    log_file = '{runtime_dir}/log'.format(runtime_dir=runtime_dir)
    with stage_debug(log, "Creating log file: '%s'.", log_file):
        node.run("touch '{}'".format(log_file))
    return log_file

def report_pulled_deployments(deployments: SynchronizedDeploymentsImpl):
    """Prints pulled deployments.

    :param deployments: Deployments to report.
    """
    log = get_logger(__name__)
    for node in deployments.nodes:
        log.info("Pulled allocation deployment: %s", node)
    for jupyter in deployments.jupyter_deployments:
        log.info("Pulled Jupyter deployment: %s", jupyter)
    for dask in deployments.dask_deployments:
        log.info("Pulled Dask deployment: %s", dask)

def discard_expired_deployments(
        deployments: DeploymentDefinitions) -> DeploymentDefinitions:  # noqa
    """Returns a new object that does not contain deployments that have
    expired, or will expire in the near future.

    :param deployments: Deployments to examine.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Discarding expired deployments."):
        now = utc_now()
        log.debug("Now: %s", now)
        log.debug("Will discard after the %d second mark"
                  " before the expiration date.", DISCARD_DELTA_SECONDS)
        discard_now = utc_now() + timedelta(seconds=DISCARD_DELTA_SECONDS)

        unexpired_nodes = {}
        for uuid, node in deployments.nodes.items():
            if node.expiration_date < discard_now:
                log.warning(
                    "Discarding a synchronized allocation deployment,"
                    " because it has expired: %s", uuid)
            else:
                unexpired_nodes[uuid] = node

        unexpired_jupyter_deployments = {}
        for uuid, jupyter in deployments.jupyter_deployments.items():
            if jupyter.expiration_date < discard_now:
                log.warning(
                    "Discarding a Jupyter deployment,"
                    " because it has expired: %s", uuid)
            else:
                unexpired_jupyter_deployments[uuid] = jupyter

        unexpired_dask_deployments = {}
        for uuid, dask in deployments.dask_deployments.items():
            if dask.expiration_date < discard_now:
                log.warning(
                    "Discarding a Dask deployment,"
                    " because it has expired: %s", uuid)
            else:
                unexpired_dask_deployments[uuid] = dask

        return DeploymentDefinitions(
            nodes=unexpired_nodes,
            jupyter_deployments=unexpired_jupyter_deployments,
            dask_deployments=unexpired_dask_deployments)

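def _example_discard_margin():
    """Illustrative sketch, not part of the library: shows why a
    deployment can be discarded before its actual expiration date. With
    a made-up margin of 120 seconds, anything expiring within the next
    two minutes already compares as expired."""
    margin = timedelta(seconds=120)  # Stand-in for DISCARD_DELTA_SECONDS.
    discard_now = utc_now() + margin
    expiration_date = utc_now() + timedelta(seconds=90)
    assert expiration_date < discard_now  # Would be discarded.
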
def push_deployment(self,
                    deployment: Union[Nodes,
                                      JupyterDeployment,
                                      DaskDeployment]):
    log = get_logger(__name__)
    with stage_info(log, "Pushing deployment: %s", deployment):
        node = self.get_access_node()
        if deployment_definitions_file_exists(node=node):
            deployments = deserialize_deployment_definitions_from_cluster(
                node=node)
        else:
            log.debug("No deployment definitions file,"
                      " defaulting to empty.")
            deployments = DeploymentDefinitions()
        deployments = discard_expired_deployments(deployments)
        add_deployment_definition(deployments=deployments,
                                  deployment=deployment)
        serialize_deployment_definitions_to_cluster(
            node=node, deployments=deployments)

def is_local_port_taken(port: int) -> bool:
    """Returns True if local port is taken (unable to bind to it).

    :param port: Port to check.
    """
    with get_socket_with_reuseaddr() as sock:
        try:
            sock.bind((LOCAL_BIND_ADDRESS, port))
        except Exception:  # noqa, pylint: disable=broad-except
            log = get_logger(__name__)
            log.debug("Exception, port probably taken.", exc_info=1)
            return True
    return False

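# A minimal sketch, not the library's actual implementation, of what
# get_socket_with_reuseaddr used above presumably provides: a TCP socket
# with SO_REUSEADDR set, usable as a context manager so it is closed
# on exit.
def _get_socket_with_reuseaddr_sketch():
    """Hypothetical stand-in for get_socket_with_reuseaddr."""
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    return sock
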
def deserialize_deployment_definitions_from_cluster(
        node: NodeInternal) -> DeploymentDefinitions:  # noqa
    """Downloads deployment definitions from the cluster.

    :param node: Node to deserialize deployment definitions from.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Deserializing deployment definitions"
                          " from cluster."):
        path = get_deployment_definitions_file_path(node=node)
        file_contents = get_file_from_node(node=node, remote_path=path)
        serialized = json.loads(file_contents)
        return DeploymentDefinitions.deserialize(serialized=serialized)

def wait(self, timeout: Optional[float]):
    log = get_logger(__name__)
    end = None
    log.debug("Waiting for allocation of job %d...", self._job_id)
    if timeout is not None:
        end = utc_now() + datetime.timedelta(seconds=timeout)
    if self._done_waiting:
        raise RuntimeError("Already waited.")
    iterations = 0
    while True:
        squeue = run_squeue(node=self._access_node)
        try:
            job = squeue[self._job_id]
        except KeyError as e:
            raise RuntimeError("Unable to obtain information"
                               " about the allocation.") from e
        if job.state in ['PENDING', 'CONFIGURING']:
            if end is not None and utc_now() >= end:
                raise TimeoutError("Timed out while waiting"
                                   " for allocation.")
            if iterations % STILL_PENDING_MESSAGE_EVERY_N_SQUEUE == 0:
                log.info(STILL_PENDING_MESSAGE)
            else:
                log.debug(STILL_PENDING_MESSAGE)
            iterations += 1
            sleep(WAIT_SQUEUE_INTERVAL)
            continue
        try:
            if job.state != 'RUNNING':
                message = ("Unable to wait: allocation entered unsupported"
                           " or failing state: '{}'")
                raise RuntimeError(message.format(job.state))
            self._done_waiting = True
            finalize_allocation(allocation_id=self._job_id,
                                hostnames=job.node_list,
                                nodes=self._nodes,
                                parameters=self._parameters,
                                allocated_until=job.end_time,
                                config=self._access_node.config)
        finally:
            self._access_node.run(
                "rm -f {entry_point_script_path}".format(
                    entry_point_script_path=self._entry_point_script_path))
        break

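def _example_deadline_loop(timeout: float) -> int:
    """Illustrative sketch, not part of the library: the polling pattern
    used by wait() above, reduced to a monotonic-clock deadline, a fake
    completion condition, and a short sleep interval."""
    import time
    end = time.monotonic() + timeout
    iterations = 0
    while True:
        if iterations >= 3:  # Stand-in for 'job reached RUNNING'.
            return iterations
        if time.monotonic() >= end:
            raise TimeoutError("Timed out while waiting.")
        iterations += 1
        time.sleep(0.01)  # Stand-in for WAIT_SQUEUE_INTERVAL.
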
def connect_to_each_node(nodes: Sequence[Node], config: ClusterConfig):
    """Connects to each node to make sure any connection issues come up
    before attempting to actually deploy anything.

    :param nodes: Nodes to deploy Dask on.
    :param config: Cluster config.
    """
    log = get_logger(__name__)
    node_count = len(nodes)
    for i, node in enumerate(nodes):
        with stage_info(log, "Connecting to %s:%d (%d/%d).",
                        node.host, node.port, i + 1, node_count):
            retry_with_config(node.connect,
                              name=Retry.DASK_NODE_CONNECT,
                              config=config)

def deploy_scheduler_on_first_node(
        nodes: Sequence[Node]) -> DaskSchedulerDeployment:  # noqa
    """Deploys a scheduler on the first node in the node sequence.

    :param nodes: Nodes to deploy Dask on.
    """
    log = get_logger(__name__)
    assert isinstance(nodes[0], NodeInternal)
    first_node = nodes[0]  # type: NodeInternal
    with stage_info(log, "Deploying scheduler on the first node: %s.",
                    first_node.host):
        scheduler = retry_with_config(
            lambda: deploy_dask_scheduler(node=first_node),
            name=Retry.DEPLOY_DASK_SCHEDULER,
            config=first_node.config)
        return scheduler

def remove_runtime_dir(node: Node, runtime_dir: str):
    """Removes a runtime dir for deployment.

    Removes all files in it that do not start with a dot.
    Does not remove nested directories.

    On failure, produces a warning.

    :param node: Node to run commands on.
    :param runtime_dir: Path to the deployment dir.
    """
    try:
        node.run("rm -f {runtime_dir}/*"
                 " && rmdir {runtime_dir}".format(runtime_dir=runtime_dir))
    except RuntimeError:
        log = get_logger(__name__)
        log.warning("Failed to remove runtime dir: '%s'.", runtime_dir)
        log.debug("Failed to remove runtime dir due to exception.",
                  exc_info=1)

def run_sbatch(args: SbatchArguments,
               node: NodeInternal) -> Tuple[int, str]:
    """Runs sbatch on the given node.

    Returns the job id and the path to the entry point script.

    :param args: Arguments to use for allocation.
    :param node: Node to run sbatch on.
    """
    log = get_logger(__name__)
    request, entry_point_script_path = prepare_sbatch_allocation_request(
        args=args,
        config=node.config,
        node=node)
    log.debug("Allocation request: %s", request)
    output = node.run_impl(request, install_keys=True)
    # The job id is expected as the first semicolon-separated field of
    # the output (the format sbatch's --parsable flag produces,
    # presumably used in the request).
    job_id = int(output.split(';')[0])
    return job_id, entry_point_script_path

def get_file_from_node(node: NodeInternal, remote_path: str) -> str:
    """Runs a task on the node that downloads a file and returns
    its contents.

    :param node: Node to download the file from.
    :param remote_path: Remote file path.
    """
    log = get_logger(__name__)

    @fabric.decorators.task
    def file_download_task():
        with capture_fabric_output_to_log():
            return get_remote_file(remote_path=remote_path)

    with stage_debug(log, "Getting file from node %s: %s",
                     node.host, remote_path):
        return node.run_task(task=file_download_task)

def __init__(self, contents: str):
    self._hosts = defaultdict(list)
    log = get_logger(__name__)
    log.debug("Sshd port directory contents: %s", contents)
    # Contents are whitespace-separated 'host:port' entries.
    entries = contents.split()
    for entry in entries:
        split = entry.split(':')
        host = split[0]
        port = int(split[1])
        self._hosts[host].append(port)
        log.debug("Host %s at %d", host, port)
    self._hosts = dict(self._hosts)
    if not self._hosts:
        log.warning("No deployed sshd servers were reported.")

def allocate_slurm_nodes(parameters: AllocationParameters,
                         config: ClusterConfig) -> Nodes:
    """Tries to allocate nodes using Slurm.

    :param parameters: Allocation parameters.
    :param config: Config for the cluster to allocate nodes on.
    """
    args = SbatchArguments(params=parameters)
    log = get_logger(__name__)
    with stage_debug(log, "Executing sbatch on access node."):
        access_node = get_access_node(config=config)
        job_id, entry_point_script_path = run_sbatch(args=args,
                                                     node=access_node)

    def run_squeue_task() -> SqueueResult:
        job_squeue = run_squeue(node=access_node)
        return job_squeue[job_id]

    try:
        with stage_debug(log, "Obtaining info about job %d using squeue.",
                         job_id):
            job = retry_with_config(run_squeue_task,
                                    name=Retry.SQUEUE_AFTER_SBATCH,
                                    config=config)
    except Exception as e:  # noqa, pylint: disable=broad-except
        run_scancel(job_id=job_id, node=access_node)
        raise RuntimeError("Unable to obtain job info"
                           " after allocation.") from e

    node_count = job.node_count
    nodes = [NodeImpl(config=config) for _ in range(node_count)]
    allocation = SlurmAllocation(
        job_id=job_id,
        access_node=access_node,
        nodes=nodes,
        entry_point_script_path=entry_point_script_path,
        parameters=parameters)
    return NodesImpl(nodes=nodes,
                     allocation=allocation)

def serialize_deployment_definitions_to_cluster(
        node: NodeInternal, deployments: DeploymentDefinitions):  # noqa
    """Uploads deployment definitions to the cluster, replacing any
    definitions file already there.

    :param node: Node to serialize definitions to.
    :param deployments: Deployments to upload.
    """
    log = get_logger(__name__)
    with stage_debug(log, "Serializing deployment definitions to cluster."):
        serialized = deployments.serialize()
        file_contents = json.dumps(serialized, sort_keys=True, indent=4)
        parent_path = get_deployment_definitions_parent_path(node=node)
        node.run("mkdir -p {parent_path}"
                 " && chmod 700 {parent_path}".format(
                     parent_path=parent_path))
        path = get_deployment_definitions_file_path(node=node)
        put_file_on_node(node=node,
                         remote_path=path,
                         contents=file_contents)