def command(self, command, timeout=60, sudo=False):
    """Run ``command`` over SSH on every host in ``self.hosts``.

    Args:
        command: Command line to execute remotely.
        timeout: Forwarded to the client as ``read_timeout``.
        sudo: Forwarded to ``run_command`` unchanged.

    Returns:
        Whatever ``ParallelSSHClient.run_command`` returned for this call.
    """
    logger.debug(f'Executing {command} with sudo {sudo}.')
    addresses = [instance.ip for instance in self.hosts]
    ssh = ParallelSSHClient(addresses, pkey=self.keysfile)
    result = ssh.run_command(command, read_timeout=timeout, sudo=sudo)
    # Wait for the remote commands before handing the output back.
    ssh.join()
    return result
def fix_hostnames():
    """Set a numbered ``synthetic-bot-up-<index>`` hostname on each host.

    The hosts come from ``get_ips()``; the index used in the name is the
    host's position in that list.
    """
    addresses = get_ips()
    ssh = ParallelSSHClient(addresses, user=USER)
    # One substitution dict per host, consumed by the "%(cmd)s" template.
    per_host_args = []
    for index, _ in enumerate(addresses):
        per_host_args.append({
            "cmd":
            "sudo hostnamectl set-hostname synthetic-bot-up-%s" % (index, )
        })
    ssh.run_command("%(cmd)s", host_args=per_host_args)
    ssh.join()
def ssh_all_hosts(self, command):
    """
    Run one command over ssh on every address in ``self.addresses``.

    Parameters:
        command: command to be sent over ssh
    """
    ssh = ParallelSSHClient(self.addresses, user='******')
    ssh.run_command(command)
    # Wait for the command to complete on the remote hosts.
    ssh.join()
def parallel_ssh(host, user, password, command):
    """Run ``command`` on a single host and return its stdout as one string.

    Args:
        host: Host to connect to.
        user: Login user name.
        password: Login password.
        command: Command line to execute.

    Returns:
        The stdout lines, each terminated by a newline.

    Raises:
        Exception: If the host reports an exit code other than 0.
    """
    ssh = ParallelSSHClient([host], user=user, password=password)
    results = ssh.run_command(command)
    ssh.join()

    collected = []
    for result in results:
        code = result.exit_code
        if code != 0:
            raise Exception("host returned exit code " + str(code))
        collected.extend(line + "\n" for line in result.stdout)
    return "".join(collected)
def run_command(command, hosts, user, verbose=False, proxy_host=None,
                timeout=10, **kwargs):
    """Run ssh command using Parallel SSH.

    Args:
        command: Command line to execute on every host.
        hosts: Hosts to run on.
        user: Login user; used as the proxy user when ``proxy_host`` is set.
        verbose: When true, print the stdout of hosts that exited with 0.
        proxy_host: Optional proxy host to connect through.
        timeout: Client timeout.
        **kwargs: Forwarded to ``ParallelSSHClient.run_command``.

    Returns:
        ``{"0": [hosts that exited with 0], "1": [all other hosts]}``.
    """
    client_options = {"pkey": SSH_KEY, "timeout": timeout}
    if proxy_host:
        client_options.update(user='******',
                              proxy_host=proxy_host,
                              proxy_user=user,
                              proxy_pkey=SSH_KEY)
    else:
        client_options["user"] = user
    ssh = ParallelSSHClient(hosts, **client_options)

    outputs = ssh.run_command(command, stop_on_errors=False, **kwargs)
    ssh.join(outputs)

    succeeded, failed = [], []
    # outputs = pssh.output.HostOutput objects list
    for host_output in outputs:
        if host_output.exit_code != 0:
            if host_output.host is not None:
                failed.append(host_output.host)
            continue
        if verbose and host_output.stdout:
            for line in host_output.stdout:
                print(line)
        succeeded.append(host_output.host)

    # Hosts that raised an exception (authentication, connection) come back
    # with host.host = None, so they are found by elimination.
    unreported = list(set(hosts) - set(succeeded + failed))
    if unreported:
        failed.extend(unreported)
    return {"0": succeeded, "1": failed}
def run_command(command, hosts, user, verbose=False, proxy_host=None,
                timeout=10, **kwargs):
    """Run ssh command using Parallel SSH.

    Args:
        command: Command line to execute on every host.
        hosts: Hosts to run on; also the keys looked up in the output.
        user: Login user; used as the proxy user when ``proxy_host`` is set.
        verbose: When true, iterate every host's stdout to exhaustion.
        proxy_host: Optional proxy host to connect through.
        timeout: Client timeout.
        **kwargs: Forwarded to ``ParallelSSHClient.run_command``.

    Returns:
        ``{"0": [hosts with exit code 0], "1": [hosts with any other code]}``.

    Raises:
        OpenA8SshAuthenticationException: If a requested host has no entry
            in the command output.
    """
    if proxy_host:
        ssh = ParallelSSHClient(hosts,
                                user='******',
                                proxy_host=proxy_host,
                                proxy_user=user,
                                timeout=timeout)
    else:
        ssh = ParallelSSHClient(hosts, user=user, timeout=timeout)
    outputs = ssh.run_command(command, stop_on_errors=False, **kwargs)
    ssh.join(outputs)

    by_status = {"0": [], "1": []}
    for name in hosts:
        if name in outputs:
            key = '0' if outputs[name]['exit_code'] == 0 else '1'
            by_status[key].append(name)
            continue
        # Pssh AuthenticationException duplicate output dict key
        # {'saclay.iot-lab.info': {'exception': ...},
        # {'saclay.iot-lab.info_qzhtyxlt': {'exception': ...}}
        site = next(iter(sorted(outputs)))
        raise OpenA8SshAuthenticationException(site)

    if not verbose:
        return by_status

    for name in hosts:
        # Pssh >= 1.0.0: stdout is None instead of generator object
        # when you have ConnectionErrorException
        stream = outputs[name].get('stdout')
        if not stream:
            continue
        for _ in stream:
            pass
    return by_status
class SSHManager(object):
    """Wrap a ``ParallelSSHClient`` and keep its host list and per-host
    configs in step, with an optional "context" subset of hosts.

    ``all_hosts`` is the sorted list of every managed host and is what the
    client is given. ``hosts`` is the current context: equal to ``all_hosts``
    until one of the ``change_context_*`` methods narrows it.

    NOTE(review): the constructor sorts ``hosts`` but passes ``host_config``
    through untouched, so it presumably expects ``host_config`` to already be
    in sorted-host order -- confirm with callers.
    """

    # Becomes True once a change_context_* call has narrowed ``hosts``.
    context_changed = False

    def __init__(self, hosts, host_config):
        """Create the client for ``hosts`` with the given ``host_config``."""
        self.hosts = sorted(hosts)
        self.all_hosts = copy.deepcopy(self.hosts)
        self.client = ParallelSSHClient(self.hosts, host_config=host_config)

    def remove_hosts(self, hosts):
        """Drop ``hosts`` from the managed hosts, the context and the client.

        The host configs at the same positions as the removed hosts are
        dropped too.
        """
        removed_positions = set()
        remaining = []
        for position, name in enumerate(self.all_hosts):
            if name in hosts:
                removed_positions.add(position)
            else:
                remaining.append(name)
        self.all_hosts = remaining

        if self.context_changed:
            self.hosts = [name for name in self.hosts if name not in hosts]
        else:
            self.hosts = copy.deepcopy(self.all_hosts)

        self.client.host_config = [
            config
            for position, config in enumerate(self.client.host_config)
            if position not in removed_positions
        ]
        self.client.hosts = self.all_hosts

    def add_host(self, host):
        """Add ``host`` (an object exposing ``.host`` and
        ``.build_host_config()``), keeping ``all_hosts`` sorted and unique.

        If the host name is already managed its config is replaced rather
        than inserted a second time, so hosts and configs stay the same
        length.
        """
        already_managed = host.host in self.all_hosts
        self.all_hosts = sorted(set(self.all_hosts) | {host.host})
        host_configs = self.client.host_config
        idx = self.all_hosts.index(host.host)
        if already_managed:
            host_configs[idx] = host.build_host_config()
        else:
            host_configs.insert(idx, host.build_host_config())
        self.client.hosts = self.all_hosts
        self.client.host_config = host_configs
        if not self.context_changed:
            self.hosts = copy.deepcopy(self.all_hosts)

    def run_command(self, command, commands=None, sudo=False):
        """Run ``command`` through the client.

        When ``commands`` is given it is passed as ``host_args``.
        Returns the client's ``run_command`` result.
        """
        if commands is None:
            return self.client.run_command(command, sudo=sudo)
        return self.client.run_command(command,
                                       host_args=commands,
                                       sudo=sudo)

    def join(self, output):
        """Delegate to the client's ``join`` for ``output``."""
        self.client.join(output)

    def change_context_hosts_all(self):
        """Make the context cover every managed host (context stays marked
        as changed)."""
        self.change_context_hosts(self.all_hosts)

    def change_context_hosts(self, new_hosts):
        """Restrict the context to ``new_hosts``.

        Raises:
            SSHManager.ContextException: If any requested host is not managed.
        """
        wanted = set(new_hosts)
        for name in sorted(wanted):
            if name not in self.all_hosts:
                raise SSHManager.ContextException(
                    f"Host {name} not in host list")
        # Keep the order of all_hosts rather than the caller's order.
        self.hosts = [name for name in self.all_hosts if name in wanted]
        self.context_changed = True

    def change_context_indices(self, indices):
        """Restrict the context to the hosts at ``indices`` of ``all_hosts``.

        An empty ``indices`` selects no hosts, the same as
        ``change_context_hosts([])``.

        Raises:
            SSHManager.ContextException: If an index is negative or past the
                end of ``all_hosts``.
        """
        indices = sorted(indices)
        if indices and (indices[0] < 0
                        or indices[-1] >= len(self.all_hosts)):
            raise SSHManager.ContextException("Indices out of range")
        self.change_context_hosts([self.all_hosts[i] for i in indices])

    def reset_context(self):
        """Restore the context to all managed hosts."""
        self.hosts = copy.deepcopy(self.all_hosts)
        self.context_changed = False

    @staticmethod
    def build_host_config(*,
                          n=0,
                          users=None,
                          passwords=None,
                          user=None,
                          password=None):
        """Build ``n`` ``HostConfig`` objects.

        Exactly one of ``users`` (a length-``n`` sequence) or ``user`` (one
        value shared by all hosts) must be given, and likewise exactly one of
        ``passwords`` or ``password``.

        Raises:
            HostConfigException: If the arguments violate the rules above or
                ``n`` is not positive.
        """
        if (users is None) == (user is None):
            raise HostConfigException(
                "Users or User (not both) needs to be defined")
        elif (passwords is None) == (password is None):
            raise HostConfigException(
                "Passwords or Password (not both) needs to be defined")
        elif passwords is not None and len(passwords) != n:
            raise HostConfigException("Length of Passwords != n")
        elif users is not None and len(users) != n:
            raise HostConfigException("Length of Users != n")
        elif n <= 0:
            raise HostConfigException("n should be greater than 0")

        # Broadcast the shared value when no per-host sequence was given.
        per_host_users = users if users is not None else [user] * n
        per_host_passwords = (passwords if passwords is not None else
                              [password] * n)
        return [
            HostConfig(user=per_host_users[i], password=per_host_passwords[i])
            for i in range(n)
        ]

    class ContextException(Exception):
        """Raised when a requested context refers to unknown hosts."""
        pass
class HPCConnection(object):
    """Talk to one remote HPC host over SSH: stage input data, install and
    submit a batch script, poll or cancel the job, and fetch the results.

    All remote work goes through a ``ParallelSSHClient`` built for the single
    host ``constants.hpc_hostname``. Command output is read with dict-style
    access (``output[hostname]["stderr"]``).
    NOTE(review): that access style presumably ties this class to a specific
    parallel-ssh API generation -- confirm before upgrading the library.
    """

    def __init__(self, external_init_dict=None):
        """Set up logging, paths and the SSH client.

        Args:
            external_init_dict: Optional dict; the keys read from it are
                ``"src_data_path"`` and ``"ssh_key_filename"``.
        """
        self.logger = logging.getLogger(constants.logging_name)
        init_dict = {}
        clsname = self.__class__.__name__
        if external_init_dict is not None:
            self.logger.debug(
                "{}: initializing from external dict".format(clsname))
            init_dict = external_init_dict
        else:
            self.logger.debug(
                "{}: initializing with default values".format(clsname))
        self.hostname = constants.hpc_hostname
        self.user = constants.user
        self.home_dir = os.path.join(constants.cc_working_dir, self.user)
        self.src_data_path = init_dict.get("src_data_path", "./data")
        self.template_path = constants.template_path
        self.logger.debug("Host being used is {}, under username {}".format(
            self.hostname, self.user))
        self.keypath = init_dict.get("ssh_key_filename",
                                     constants.ssh_key_filename)
        self.client = ParallelSSHClient([self.hostname],
                                        pkey=self.keypath,
                                        user=self.user,
                                        keepalive_seconds=300)
        # Set by copy_data_to_remote / submit_job.
        self.remote_abs_working_folder = None
        self.remote_working_folder = None
        self.active_dataset_name = None
        self.live_job_id = None

    def _read_stderr(self, output):
        """Return every stderr line of ``output`` joined by newlines, or
        ``None`` when the command wrote nothing to stderr."""
        lines = list(output[self.hostname]["stderr"])
        if not lines:
            return None
        return "\n".join(lines)

    def check_connection(self):
        """Run ``ls`` remotely to test the connection.

        Returns:
            ``(True, None)`` on success, ``(False, message)`` when an
            authentication, unknown-host or connection error is raised.
        """
        status = True
        msg = None
        self.logger.debug("Testing connection...")
        try:
            self.client.run_command("ls")
            self.logger.debug("... ok")
        except (
                AuthenticationException,
                UnknownHostException,
                ConnectionErrorException,
        ) as e:
            status = False
            msg = str(e)
            self.logger.debug("... failed ({})".format(msg))
        return status, msg

    def copy_data_to_remote(self, dataset_, remote_temp_folder=None):
        """
        Copies data contained in a local directory over to a remote destination

        The dataset directory ``<src_data_path>/<dataset_>`` is tarred into
        ``/tmp``, sent to ``<home_dir>/<remote_temp_folder>/data.tar`` and
        unpacked there. The remote folder becomes the working folder used by
        the other methods.

        Args:
            dataset_: Name of the dataset sub-directory to send.
            remote_temp_folder: Remote folder name; ``rand_fname()`` is used
                when omitted.

        Raises:
            Exception: If the transfer fails with a non-SCP error, or the
                remote untar writes to stderr.
        """
        self.logger.debug(
            "Copying data to remote location (from {} to {})".format(
                self.src_data_path, self.home_dir))
        if remote_temp_folder is None:
            remote_temp_folder = rand_fname()
        full_remote_path = os.path.join(self.home_dir, remote_temp_folder)
        remote_tar = os.path.join(full_remote_path, "data.tar")
        local_source = os.path.join(self.src_data_path, dataset_)
        local_tar = "/tmp/" + remote_temp_folder + ".tar"

        self.remote_abs_working_folder = full_remote_path
        self.remote_working_folder = remote_temp_folder
        self.active_dataset_name = dataset_

        self.logger.debug("Creating remote folder {}".format(full_remote_path))
        # Wait for mkdir: run_command alone does not guarantee the folder
        # exists by the time the tar file is sent into it.
        self.client.join(self.client.run_command("mkdir " + full_remote_path))

        self.logger.debug("system cmd: tar cvf /tmp/{}.tar -C {} .".format(
            remote_temp_folder, local_source))
        os.system("tar cf " + local_tar + " -C " + local_source + " .")

        try:
            self.logger.debug("Copying data tar file")
            g = self.client.scp_send(local_tar, remote_tar)
            joinall(g, raise_error=True)
        except SCPError as e:
            # NOTE(review): SCP errors are only logged here; the untar below
            # is then what reports the failure. Confirm this is intended.
            self.logger.error("Copy failed (scp error {})".format(e))
        except Exception as e:
            self.logger.error("Copy failed: {}".format(e))
            raise Exception("scp_send failed") from e

        s = "tar xvf " + remote_tar + " -C " + full_remote_path
        self.logger.debug("Untarring remote data: {}".format(s))
        output = self.client.run_command(s)
        self.client.join(output)
        errmsg = next(output[self.hostname]["stderr"], None)
        if errmsg is not None:
            self.logger.error("Error: " + errmsg)
            raise Exception("Error untarring data file: " + errmsg)
        first_line = next(output[self.hostname]["stdout"], None)
        if first_line is not None:
            self.logger.debug("stdout: " + first_line)

        self.logger.debug("Remove remote temp tar file " + local_tar)
        os.remove(local_tar)

    # output files in base_dir/jobname/out
    def copy_data_from_remote(self,
                              jobid,
                              absolute_local_out_dir,
                              cleanup_temp=True):
        """Fetch the job's output folder and unpack it locally.

        The slurm log ``slurm-<jobid>.out`` is copied into the remote ``out``
        folder, the folder is tarred, the tar is downloaded to ``/tmp`` (up
        to three attempts, verified by size) and extracted into
        ``absolute_local_out_dir``.

        Any exception is logged and swallowed; nothing is returned.

        Args:
            jobid: Job id as a string.
            absolute_local_out_dir: Local destination; created if missing.
            cleanup_temp: Remove the downloaded tar afterwards.
        """
        self.logger.debug("Copying data from remote")
        absolute_tar_fname = os.path.join(
            self.remote_abs_working_folder,
            self.remote_working_folder + "_out.tar")
        absolute_output_data_path = os.path.join(
            self.remote_abs_working_folder, "out")
        stdout_file = os.path.join(self.home_dir, "slurm-" + jobid + ".out")
        self.logger.debug(
            " Remote data is located in {}".format(absolute_output_data_path))
        self.logger.debug(" Slurm output file is {}".format(stdout_file))

        try:
            self.logger.debug(
                " Copying slurm file to {}".format(absolute_output_data_path))
            output = self.client.run_command("cp " + stdout_file + " " +
                                             absolute_output_data_path)
            self.client.join(output)
            self.logger.debug(output)

            self.logger.debug(" Tarring remote folder")
            output = self.client.run_command("tar cf " + absolute_tar_fname +
                                             " -C " +
                                             absolute_output_data_path + " .")
            self.client.join(output)
            self.logger.debug(output)

            self.logger.debug("Picking up tar file size")
            output = self.client.run_command("du -sb " + absolute_tar_fname)
            self.client.join(output)
            self.logger.debug(output)
            du_text = "".join(output[self.hostname].stdout)
            tar_size = int(re.match("[0-9]*", du_text).group(0))
            self.logger.info("{} bytes to copy from remote".format(tar_size))

            local_tar = "/tmp/" + self.remote_working_folder + "_out.tar"
            # The downloaded file gets "_<hostname>" appended to its name.
            received_tar = local_tar + "_" + self.hostname
            self.logger.debug(
                "Remote tar file is {}".format(absolute_tar_fname))

            for _attempt in range(3):
                self.logger.debug("Copying tar file to /tmp")
                g = self.client.copy_remote_file(absolute_tar_fname,
                                                 local_tar)
                joinall(g, raise_error=True)
                recv_tar_size = os.path.getsize(received_tar)
                self.logger.debug("Received: {} bytes".format(recv_tar_size))
                if recv_tar_size == tar_size:
                    break
            else:
                raise Exception("Unable to copy tar file from remote end")

            if not os.path.exists(absolute_local_out_dir):
                self.logger.debug(
                    "Local destination folder {} does not exist, creating".
                    format(absolute_local_out_dir))
                os.mkdir(absolute_local_out_dir)

            self.logger.debug(
                "Untarring received file to {}".format(absolute_local_out_dir))
            os.system("tar xf " + received_tar + " -C " +
                      absolute_local_out_dir)
            if cleanup_temp:
                os.remove(received_tar)
        except Exception as e:
            self.logger.error(
                "Exception during file transfer from remote: {}".format(e))

    def copy_singlefile_to_remote(self,
                                  local_filename,
                                  remote_path=".",
                                  is_executable=False):
        """Copy one local file into ``remote_path`` under the remote working
        folder, optionally marking it executable."""
        r = os.path.join(
            self.remote_abs_working_folder,
            remote_path,
            os.path.basename(local_filename),
        )
        g = self.client.copy_file(local_filename, r)
        joinall(g, raise_error=True)
        if is_executable:
            self.client.run_command("chmod ugo+x " + r)

    def create_remote_subdir(self, remote_subdir):
        """Create ``remote_subdir`` under the remote working folder and make
        it world-writable (``chmod 777``)."""
        target = os.path.join(self.remote_abs_working_folder, remote_subdir)
        self.client.run_command("mkdir -p " + target)
        self.client.run_command("chmod 777 " + target)

    # executable_ is either raven or ostrich
    def copy_batchscript(
        self,
        executable_,
        guessed_duration,
        datafile_basename,
        batch_tmplt_fname,
        shub_hostname,
    ):
        """Fill in the batch-script template and install it remotely.

        The placeholders ``ACCOUNT``, ``DURATION``, ``TEMP_PATH``,
        ``INPUT_PATH``, ``OUTPUT_PATH``, ``DATAFILE_BASENAME``,
        ``SHUB_HOSTNAME`` and ``EXEC`` are substituted in that order, the
        result is copied into the remote working folder as
        ``<remote_working_folder>.sh`` and made executable.

        Returns:
            The absolute remote path of the installed script.
        """
        template_fname = os.path.join(self.template_path, batch_tmplt_fname)
        with open(template_fname, "r") as template_file:
            tmplt = template_file.read()
        abs_remote_output_dir = os.path.join(self.remote_abs_working_folder,
                                             "out")

        # Order matters: later placeholders are applied to the output of the
        # earlier substitutions.
        substitutions = (
            ("ACCOUNT", constants.cc_account_info),
            ("DURATION", guessed_duration),
            ("TEMP_PATH", self.remote_abs_working_folder),
            ("INPUT_PATH", self.remote_abs_working_folder),
            ("OUTPUT_PATH", abs_remote_output_dir),
            ("DATAFILE_BASENAME", datafile_basename),
            ("SHUB_HOSTNAME", shub_hostname),
            ("EXEC", executable_),
        )
        for placeholder, value in substitutions:
            tmplt = tmplt.replace(placeholder, value)

        subst_fname = self.remote_working_folder + ".sh"
        local_script = "/tmp/" + subst_fname
        remote_script = os.path.join(self.remote_abs_working_folder,
                                     subst_fname)
        with open(local_script, "w") as script_file:
            script_file.write(tmplt)

        self.client.run_command("mkdir " + abs_remote_output_dir)
        self.client.run_command("chmod 777 " + self.remote_abs_working_folder)
        self.client.run_command("chmod 777 " + abs_remote_output_dir)
        g = self.client.copy_file(local_script, remote_script)
        joinall(g, raise_error=True)
        self.client.run_command("chmod ugo+x " + remote_script)
        os.remove(local_script)
        return remote_script

    def submit_job(self, script_fname):
        """Submit ``script_fname`` with ``constants.sbatch_cmd --parsable``.

        Returns:
            The first stdout line of the submit command, also stored in
            ``self.live_job_id``.

        Raises:
            Exception: If the command writes to stderr or prints no job id.
        """
        self.logger.debug("Submitting job {}".format(script_fname))
        output = self.client.run_command("cd {}; {} --parsable {}".format(
            self.home_dir, constants.sbatch_cmd, script_fname))
        self.client.join(output)
        errmsg = self._read_stderr(output)
        if errmsg is not None:
            self.logger.error(" Error: {}".format(errmsg))
            raise Exception("Error: " + errmsg)
        job_id = next(output[self.hostname]["stdout"], None)
        if job_id is None:
            # Avoid leaking a bare StopIteration to the caller.
            raise Exception("Error: submit command returned no job id")
        self.live_job_id = job_id
        self.logger.debug(" Job id {}".format(self.live_job_id))
        return self.live_job_id

    def read_from_remote(self, remote_filename):
        """Download ``remote_filename`` from the remote working folder and
        return its lines.

        The copy is attempted twice; if both attempts fail the error is
        swallowed and whatever was read on the last attempt (possibly
        nothing) is returned.
        """
        filecontent = []
        self.logger.debug("read_from_remote")
        local_filename = os.path.join(
            "/tmp", self.remote_working_folder + "_progress.json")
        suffixed_local_filename = local_filename + "_" + self.hostname
        # maybe remote file is being overwritten, try again if remote copy fails
        for attempt in range(2):
            # Start clean so a failed first read cannot duplicate lines.
            filecontent = []
            try:
                g = self.client.copy_remote_file(
                    os.path.join(self.remote_abs_working_folder,
                                 remote_filename),
                    local_filename,
                )
                joinall(g, raise_error=True)
                self.logger.debug(" Opening copied file")
                with open(suffixed_local_filename) as f:
                    for line in f:
                        self.logger.debug(line)
                        filecontent.append(line)
                break
            except Exception as e:
                # e.g. missing progress file as execution starts
                if attempt == 0:
                    self.logger.debug("exception {}, retrying".format(e))
        self.logger.debug("End read_from_remote")
        return filecontent

    def get_status(self, jobid):
        """Query the state of ``jobid`` with ``constants.squeue_cmd``.

        :param jobid: job id as a string; compared with the first
            ``|``-separated field of each output line
        :return: one of ``COMPLETED``, ``PENDING``, ``RUNNING``, ``TIMEOUT``,
            ``CANCELLED``
        :raises Exception: on stderr output, unparsable output or an unknown
            state
        """
        self.logger.debug("Inside get_status: executing sacct")
        cmd = constants.squeue_cmd + " -j {} -n -p -b".format(jobid)
        output = self.client.run_command(cmd)
        self.client.join(output)

        errmsg = self._read_stderr(output)
        if errmsg is not None:
            self.logger.debug(" stderr: {}".format(errmsg))
            raise Exception("Error: " + errmsg)

        status_output = None  # 1 line expected
        seen_lines = []
        for line in output[self.hostname]["stdout"]:
            seen_lines.append(line + "\n")
            fields = line.split("|")
            if len(fields) >= 2 and fields[0] == jobid:
                status_output = fields[1].split()[0]
        if status_output is None:
            raise Exception("Error parsing sacct output: {}".format(
                "".join(seen_lines)))

        if status_output not in [
                "COMPLETED",
                "PENDING",
                "RUNNING",
                "TIMEOUT",
                "CANCELLED",
        ]:
            raise Exception(
                "Status error: state {} unknown".format(status_output))
        return status_output

    def cancel_job(self, jobid):
        """Cancel ``jobid`` with ``constants.scancel_cmd``.

        :param jobid: job id to cancel
        :return: None
        :raises Exception: if the command writes anything to stderr or stdout
        """
        cmd = constants.scancel_cmd + " {}".format(jobid)
        output = self.client.run_command(cmd)
        self.client.join(output)

        errmsg = self._read_stderr(output)
        if errmsg is not None:
            self.logger.debug(" stderr: {}".format(errmsg))
            raise Exception("Cancel error: " + errmsg)

        stdout_str = "".join(line + "\n"
                             for line in output[self.hostname]["stdout"])
        if len(stdout_str) > 0:
            raise Exception("Cancel error: " + stdout_str)

    def reconnect(self):
        """Replace ``self.client`` with a freshly constructed client."""
        self.client = ParallelSSHClient([self.hostname],
                                        pkey=self.keypath,
                                        user=self.user,
                                        keepalive_seconds=300)

    # Disabled helper, kept for reference only:
    #
    # def check_slurmoutput_for(self, substr, jobid):
    #     slurmfname = "slurm-" + jobid + ".out"
    #     local_slurmfname = os.path.join("/tmp", slurmfname)
    #     stdout_file = os.path.join(self.home_dir, slurmfname)
    #     found = False
    #     try:
    #         g = self.client.copy_remote_file(stdout_file, local_slurmfname)
    #         joinall(g, raise_error=True)
    #         # scan file for substr
    #         with open(local_slurmfname + "_" + self.hostname) as f:
    #             for line in f:
    #                 print("comparing {} with {}".format(substr, line))
    #                 match_obj = re.search(substr, line)
    #                 print(match_obj)
    #                 if match_obj:
    #                     found = True
    #                     print("found")
    #         os.remove(local_slurmfname + "_" + self.hostname)
    #     except Exception as e:
    #         print("Exception inside check_slurmoutput_for")
    #         print(e)
    #         pass
    #     return found

    def cleanup(self, jobid):
        """Best-effort removal of the remote working folder, the job's slurm
        log and the local progress file.

        Failures are logged at debug level and never raised.
        """
        try:
            self.logger.debug("Deleting the remote folder")
            output1 = self.client.run_command("rm -rf {}".format(
                os.path.join(self.home_dir, self.remote_abs_working_folder)))
            self.logger.debug("Deleting the slurm log file")
            logfilepath = os.path.join(self.home_dir,
                                       "slurm-{}.out".format(jobid))
            output2 = self.client.run_command("rm {}".format(logfilepath))
            self.logger.debug("Deleting the local progress file")
            local_filename = os.path.join(
                "/tmp", self.remote_working_folder + "_progress.json")
            suffixed_local_filename = local_filename + "_" + self.hostname
            os.remove(suffixed_local_filename)
            # Log the first line of each stream, if any. A silent (i.e.
            # successful) rm must not be reported as a cleanup failure.
            for stream_name in ("stdout", "stderr"):
                for command_output in (output1, output2):
                    first_line = next(
                        command_output[self.hostname][stream_name], None)
                    if first_line is not None:
                        self.logger.debug(first_line)
        except Exception as e:
            self.logger.debug("Hmm file cleanup failed: {}".format(e))
class ClusterConnector:
    def __init__(
        self,
        workspace,
        cluster_name,
        ssh_key,
        vm_type,
        admin_username="******",
    ):
        """Thin wrapper class around azureml.core.compute.AmlCluster

        Provides parallel ssh objects and helper for master node and all
        node commands and file copies.

        Constructing an instance also redirects the ``pssh.host_logger``
        output to a file under ``clusterlogs/``: existing stream handlers are
        removed from that logger and a file handler is added.

        Usage:

        >>> cc = ClusterConnector(workspace, "MyCluster", sshkey, "Standard_ND40rs_v2")
        >>> cc.initialise(min_nodes=0, max_nodes=4, idle_timeout_secs=30)
        >>> cluster = cc.cluster
        >>> [print(node['name']) for node in cc.cluster.list_nodes()]
        """
        self.cluster_name = cluster_name
        self.workspace = workspace
        self.ssh_key = ssh_key
        self.vm_type = vm_type
        self.admin_username = admin_username

        enable_host_logger()
        hlog = logging.getLogger("pssh.host_logger")
        tstr = datetime.now().isoformat(timespec="minutes")
        # Iterate over a copy: removeHandler mutates hlog.handlers, and
        # removing while iterating the live list skips entries.
        for handler in list(hlog.handlers):
            if isinstance(handler, logging.StreamHandler):
                hlog.removeHandler(handler)
        os.makedirs("clusterlogs", exist_ok=True)
        self.logfile = "clusterlogs/{}_{}.log".format(self.workspace.name,
                                                      tstr)
        hlog.addHandler(logging.FileHandler(self.logfile))

        # Populated by initialise() / _create_cluster_ssh_conns().
        self.cluster = None
        self._master_scp = None
        self._master_ssh = None
        self._all_ssh = None

    def initialise(self, min_nodes=0, max_nodes=0, idle_timeout_secs=1800):
        """Initialise underlying AmlCompute cluster instance"""
        self._create_or_update_cluster(min_nodes, max_nodes, idle_timeout_secs)

    def _check_logs_emessage(self, host, port):
        """Return the error text pointing the user at the host log file."""
        return "Remote command failed on {}:{}. For details see {}".format(
            host, port, self.logfile)

    def terminate(self):
        """Scale the cluster down to zero nodes.

        Raises:
            RuntimeError: If the update fails or nodes are still listed
                afterwards.
        """
        print('Attempting to terminate cluster "{}"'.format(
            colored(self.cluster_name, "green")))
        try:
            self.cluster.update(min_nodes=0,
                                max_nodes=0,
                                idle_seconds_before_scaledown=10)
            self.cluster.wait_for_completion()
        except ComputeTargetException as err:
            raise RuntimeError(
                "Failed to terminate cluster nodes ({})".format(err)) from err
        if len(self.cluster.list_nodes()):
            raise RuntimeError(
                "Failed to terminate cluster nodes (nodes still running)")

    @property
    def cluster_nodes(self):
        """Current node list, refreshed and sorted by the nodes' ``port``."""
        self.cluster.refresh_state()
        return sorted(self.cluster.list_nodes(), key=lambda n: n["port"])

    def _create_or_update_cluster(self, min_nodes, max_nodes,
                                  idle_timeout_secs):
        """Update the named cluster if it exists, otherwise create it, then
        wait until at least ``min_nodes`` nodes are listed.

        Raises:
            RuntimeError: If fewer than ``min_nodes`` nodes are present even
                after one extra 30 second wait.
        """
        try:
            self.cluster = AmlCompute(workspace=self.workspace,
                                      name=self.cluster_name)
            print('Updating existing cluster "{}"'.format(
                colored(self.cluster_name, "green")))
            self.cluster.update(
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                idle_seconds_before_scaledown=idle_timeout_secs,
            )
        except ComputeTargetException:
            print('Creating new cluster "{}"'.format(
                colored(self.cluster_name, "green")))
            cluster_config = AmlCompute.provisioning_configuration(
                vm_size=self.vm_type,
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                idle_seconds_before_scaledown=idle_timeout_secs,
                admin_username=self.admin_username,
                admin_user_ssh_key=self.ssh_key,
                remote_login_port_public_access="Enabled",
            )
            self.cluster = AmlCompute.create(self.workspace,
                                             self.cluster_name,
                                             cluster_config)
        self.cluster.wait_for_completion()

        if len(self.cluster_nodes) < min_nodes:
            # Give the nodes one more chance to appear before giving up.
            sleep(30)
            if len(self.cluster_nodes) < min_nodes:
                raise RuntimeError("Failed to provision sufficient nodes")

    def _copy_nodefile_to_nodes(self):
        """Collect each node's ``ib0`` address and copy the list to all nodes
        as ``./nodefile``, with the master node's address first.

        Skipped on single-node clusters.

        Raises:
            RuntimeError: If the address lookup fails on a node or yields no
                output.
        """
        if len(self.cluster_nodes) == 1:
            cprint("Single node cluster -- skipping IB config", "yellow")
            return

        print("Collecting cluster IB info")
        outputs = self._all_ssh.run_command(
            r'ifconfig ib0 | grep -oe "inet[^6][adr: ]*[0-9.]*" | cut -d" " -f2',
            shell="bash -c",
        )
        self._all_ssh.join(outputs)

        ibaddrs = []
        for output in outputs:
            host = output.host
            port = output.client.port
            if output.exit_code != 0:
                print(list(output.stdout))
                print(list(output.stderr))
                raise RuntimeError("Failed to get IB ip for {}:{}".format(
                    host, port))
            try:
                ibaddr = list(output.stdout)[0].split()[0]
            except IndexError as err:
                raise RuntimeError("Failed to get IB ip for {}:{} - "
                                   "No ib interface found!".format(
                                       host, port)) from err
            print("Mapping {}:{} -> {}".format(host, port, ibaddr))
            if port == self._master_scp.port:
                cprint("IB Master: {}".format(ibaddr), "green")
                ibaddrs.insert(0, ibaddr)
            else:
                ibaddrs.append(ibaddr)

        # delete=False: the file must outlive the with-block so it can be
        # copied; its path stays available as self.nodefile.
        with NamedTemporaryFile(delete=False, mode="wt") as nfh:
            self.nodefile = nfh.name
            for addr in ibaddrs:
                nfh.write("{}\n".format(addr))
        self.ibaddrs = ibaddrs
        # Copy only after the file is closed so its contents are flushed.
        self.copy_to_all_nodes(self.nodefile, "./nodefile")

    def _create_cluster_ssh_conns(self):
        """Create the SSH clients: one for all nodes, one for the first node
        (the "master") and a single-host client for master file copies."""
        nodes = self.cluster_nodes
        hostips = [n["publicIpAddress"] for n in nodes]
        hostconfigs = [HostConfig(port=n["port"]) for n in nodes]

        self._all_ssh = ParallelSSHClient(hostips,
                                          host_config=hostconfigs,
                                          user=self.admin_username)
        self._master_ssh = ParallelSSHClient(hostips[:1],
                                             host_config=hostconfigs[:1],
                                             user=self.admin_username)
        self._master_scp = SSHClient(hostips[0],
                                     port=hostconfigs[0].port,
                                     user=self.admin_username)

    def copy_to_all_nodes(self, source, dest):
        """Copy local ``source`` to ``dest`` on every node."""
        copy_jobs = self._all_ssh.copy_file(source, dest)
        joinall(copy_jobs, raise_error=True)

    def copy_to_master_node(self, source, dest):
        """Copy local ``source`` to ``dest`` on the master node."""
        self._master_scp.copy_file(source, dest)

    def copy_from_master_node(self, source, dest):
        """Copy ``source`` on the master node to local ``dest``."""
        self._master_scp.copy_remote_file(source, dest)

    def run_on_all_nodes(self, command):
        """Run ``command`` under ``bash -c`` on every node.

        Raises:
            RuntimeError: For the first node whose exit code is not 0.
        """
        outputs = self._all_ssh.run_command(command, shell="bash -c")
        self._all_ssh.join(outputs, consume_output=True)
        for output in outputs:
            if int(output.exit_code) != 0:
                host = output.host
                port = output.client.port
                raise RuntimeError(self._check_logs_emessage(host, port))

    def run_on_master_node(self, command):
        """Run ``command`` under ``bash -c`` on the master node.

        Raises:
            RuntimeError: If the exit code is not 0.
        """
        outputs = self._master_ssh.run_command(command, shell="bash -c")
        self._master_ssh.join(outputs)
        for output in outputs:
            if int(output.exit_code) != 0:
                host = output.host
                port = output.client.port
                raise RuntimeError(self._check_logs_emessage(host, port))

    def attempt_termination(self):
        """Call ``terminate()``; on failure print the error and the
        still-running warning instead of raising."""
        try:
            self.terminate()
        except RuntimeError as err:
            print(colored("ERROR: {}\n\n", "red", attrs=["bold"]).format(err))
            self.warn_unterminated()

    def warn_unterminated(self):
        """Print a warning that the cluster is still running."""
        print(
            colored("WARNING: {}", "red", attrs=["bold"]).format(
                colored(
                    "Cluster {} is still running - terminate manually to avoid "
                    "additional compute costs".format(
                        colored(self.cluster_name, "green")),
                    "red",
                )))
import datetime

# Demo: start ten 'sleep 5; uname' commands on localhost (listed twice)
# without waiting in between, then join them and print what they produced.
output = []
host = 'localhost'
hosts = [host] * 2
client = ParallelSSHClient(hosts)

# Run 10 five second sleeps
cmds = ['sleep 5; uname'] * 10

start = datetime.datetime.now()
output.extend(
    client.run_command(cmd, stop_on_errors=False, return_list=True)
    for cmd in cmds)
end = datetime.datetime.now()
print("Started %s 'sleep 5' commands on %s host(s) in %s" % (
    len(cmds),
    len(hosts),
    end - start,
))

start = datetime.datetime.now()
for command_output in output:
    client.join(command_output)
    for host_out in command_output:
        # stdout first, then stderr, then the exit code.
        for stream in (host_out.stdout, host_out.stderr):
            for line in stream:
                print(line)
        print(f"Exit code: {host_out.exit_code}")
end = datetime.datetime.now()
print("All commands finished in %s" % (end - start, ))
class Runner:
    """Run commands on a set of hosts through ``ParallelSSHClient`` and log
    their output with a per-host prefix.

    Hosts come either from ``hostNames`` or from a host file loaded with
    ``getHostConfig``; in the latter case the per-host settings are kept in
    ``self.hostConfig`` (``None`` otherwise).
    """

    def __init__(self,
                 timeout,
                 log: Log,
                 hostNames=None,
                 hostFile=None,
                 quiet=False,
                 only=None):
        """
        Args:
            timeout: Client timeout.
            log: Logger object used for all output.
            hostNames: Iterable of host names (exclusive with ``hostFile``).
            hostFile: Host file path (exclusive with ``hostNames``).
            quiet: Default for whether command output is logged.
            only: Optional collection of hosts (address or configured name)
                to keep; all other hosts are dropped from the client.

        Raises:
            ValueError: If both or neither of ``hostNames`` / ``hostFile``
                are given.
        """
        if hostFile and hostNames:
            raise ValueError('Cannot specify both "hostNames" and "hostFile"')
        elif hostNames:
            # Materialise once: hostNames is needed again below and may be a
            # one-shot iterable.
            allHosts = list(hostNames)
            self.client = ParallelSSHClient(allHosts, timeout=timeout)
            self.hostConfig = None
        elif hostFile:
            self.hostConfig = getHostConfig(hostFile)
            allHosts = list(self.hostConfig.keys())
            self.client = ParallelSSHClient(allHosts,
                                            host_config=self.hostConfig,
                                            timeout=timeout)
        else:
            raise ValueError('One of "hostNames" or "hostFile" is required')

        if only:
            # Iterate over a copy because hosts are removed inside the loop.
            for host in list(self.client.hosts):
                if str(host) in only or self._hostName(host) in only:
                    continue
                self.client.hosts.remove(host)

        self.quiet = quiet
        self.log = log
        self.pool = ThreadPoolExecutor()
        # Width of the longest display name, used to align log prefixes.
        self.justification = max(
            len(self._hostName(host)) for host in allHosts)

    def runCommand(self, command, sudo, consumeOutput=None, logExit=True):
        """Run ``command`` on all client hosts and wait for it to finish.

        For sudo commands, hosts whose config has both
        ``sudoRequiresPassword`` and ``password`` get the password written
        to stdin.

        Args:
            command: Command line to run.
            sudo: Run with sudo.
            consumeOutput: Log stdout/stderr; defaults to ``not self.quiet``.
            logExit: Log every host's exit code afterwards.

        Returns:
            The output mapping returned by the client.
        """
        consumeOutput = not self.quiet if consumeOutput is None else consumeOutput
        output = self.client.run_command(command,
                                         sudo=sudo,
                                         stop_on_errors=False)
        if sudo and self.hostConfig:
            for host in output:
                config = self.hostConfig[host]
                if not (config.get('sudoRequiresPassword', False)
                        and config.get('password', False)):
                    continue
                stdin = output[host].stdin
                if stdin:
                    pwd = config['password']
                    stdin.write(f'{pwd}\n')
                    stdin.flush()
        if consumeOutput:
            self._logCommand(output)
        self.client.join(output, consume_output=True)
        if logExit:
            for host, hostOutput in output.items():
                self._print(host, f'exit code: {hostOutput.exit_code}')
        return output

    def runCommandList(self, commandList: List[Command], verbose=None):
        """Run each command in order.

        For commands with ``abortOnError``, any host that returns a non-zero
        exit code or an exception is reported and removed from the client so
        later commands skip it.
        """
        verbose = not self.quiet if verbose is None else verbose
        for command in commandList:
            if verbose:
                self.log(
                    self.log.colour(Fore.GREEN,
                                    f'Running command {command.name}'))
            # NOTE(review): ``commad`` looks like a typo for ``command``, but
            # it must match the attribute name on Command, which is not
            # visible here -- confirm before renaming.
            output = self.runCommand(command.commad, command.sudo, verbose,
                                     False)
            if not command.abortOnError:
                continue
            for host, hostOutput in output.items():
                problems = []
                if hostOutput.exit_code:
                    problems.append(
                        f'command {command.name} exited with error '
                        f'{hostOutput.exit_code}. '
                        'No further commands will be executed.')
                if hostOutput.exception:
                    problems.append(
                        f'command {command.name} exited with exception '
                        f'{hostOutput.exception}. '
                        'No further commands will be executed.')
                for text in problems:
                    self._print(host, text, True)
                # Remove once, even when both conditions were reported.
                if problems and host in self.client.hosts:
                    self.client.hosts.remove(host)

    def close(self):
        """Shut down the logging thread pool without waiting."""
        self.pool.shutdown(False)

    def _hostName(self, host):
        """Return the configured display name of ``host``, or ``host``
        itself when there is no host config or no ``name`` entry."""
        if not self.hostConfig:
            return host
        return self.hostConfig[host].get('name', host)

    def _prefix(self, host):
        """Return the coloured, padded ``[name]`` prefix for log lines."""
        return self.log.colour(
            Fore.LIGHTBLUE_EX,
            f'[{self._hostName(host)}]'.ljust(self.justification + 2))

    def _logCommand(self, output: Dict[str, HostOutput]):
        """Log every host's output concurrently and wait for completion."""
        futures = []
        for hostOutput in output.values():
            f = self.pool.submit(self._logHostOutput, hostOutput)
            futures.append(f)
        wait(futures)

    def _print(self, host, line, isError=False):
        """Log ``line`` with the host prefix; errors get an ``[err]`` tag and
        go to ``log.error`` instead of ``log.print``."""
        if isError:
            line = f'{self.log.colour(Fore.RED, "[err]")} {line}'
        text = f'{self._prefix(host)}\t{line}'
        fn = self.log.error if isError else self.log.print
        fn(text)

    def _logHostOutput(self, hostOutput: HostOutput):
        """Log all stdout lines of one host, then all its stderr lines."""
        for line in hostOutput.stdout:
            self._print(hostOutput.host, line, False)
        for line in hostOutput.stderr:
            self._print(hostOutput.host, line, True)
# sys is needed for the sys.exit() call below.
import sys

from pssh.clients import ParallelSSHClient

# Run the installer one-liner with sudo on both hosts and print its stdout.
hosts = ['66.246.107.24', '66.246.107.15']
client = ParallelSSHClient(hosts,
                           user="******",
                           pkey="~/keys/denisl-pubkey.pem")

cmd = 'python3 -c "$(curl -fsSL https://pgsql-io-download.s3.amazonaws.com/REPO/install.py)"'
# Alternative commands used while experimenting:
#cmd='/usr/bin/python3 --version'
#cmd='uname'
#cmd='yum install -y python3 python3-devel wget curl'

output = client.run_command(cmd, sudo=True)
client.join(output)

for host_out in output:
    for line in host_out.stdout:
        print(line)

# NOTE(review): the script stops here with exit status 1, so the
# interactive-shell variant below never runs. Confirm whether the early exit
# (and its non-zero status) is intentional.
sys.exit(1)

shells = client.open_shell(read_timeout=10)
client.run_shell_commands(shells, [cmd])
client.join_shells(shells)

for shell in shells:
    stdout = list(shell.stdout)
    for s in stdout:
        print(s)