Beispiel #1
0
 def command(self, command, timeout=60, sudo=False):
     """Run ``command`` over ssh on every host in ``self.hosts``.

     A ``ParallelSSHClient`` is built from the ``ip`` attribute of each entry
     in ``self.hosts`` and authenticated with ``self.keysfile``. ``timeout``
     is forwarded as the client's ``read_timeout`` and ``sudo`` is passed
     through unchanged. The value produced by the client's ``run_command`` is
     returned after ``join()`` has been called.
     """
     logger.debug(f'Executing {command} with sudo {sudo}.')
     addresses = [instance.ip for instance in self.hosts]
     ssh = ParallelSSHClient(addresses, pkey=self.keysfile)
     result = ssh.run_command(command, read_timeout=timeout, sudo=sudo)
     ssh.join()
     return result
def fix_hostnames():
    """Give every host a numbered ``synthetic-bot-up-<index>`` hostname.

    The host list comes from ``get_ips()``. Each host receives its own
    ``hostnamectl`` command through per-host arguments; the index used in the
    name is the host's position in that list.
    """
    targets = get_ips()
    ssh = ParallelSSHClient(targets, user=USER)
    per_host_args = []
    for position in range(len(targets)):
        rename = "sudo hostnamectl set-hostname synthetic-bot-up-%s" % (
            position, )
        per_host_args.append({"cmd": rename})
    # "%(cmd)s" is filled in per host from the matching per_host_args entry.
    ssh.run_command("%(cmd)s", host_args=per_host_args)
    ssh.join()
    def ssh_all_hosts(self, command):
        """Dispatch a shell command to all known remote hosts via ssh.

        Parameters:
            command: command line to run on each address in
                ``self.addresses``

        ``join()`` is called on the client after the command is sent;
        nothing is returned.
        """
        ssh = ParallelSSHClient(self.addresses, user='******')
        ssh.run_command(command)
        ssh.join()
Beispiel #4
0
def parallel_ssh(host, user, password, command):
    """Run ``command`` on a single ``host`` and return its standard output.

    The connection uses password authentication. Every stdout line is
    returned followed by a newline character.

    Raises:
        Exception: if a host output reports an exit code other than 0.
    """
    ssh = ParallelSSHClient([host], user=user, password=password)
    results = ssh.run_command(command)
    ssh.join()
    collected = []
    for result in results:
        if result.exit_code != 0:
            raise Exception("host returned exit code " +
                            str(result.exit_code))
        collected.extend(line + "\n" for line in result.stdout)
    return "".join(collected)
 def run_command(command,
                 hosts,
                 user,
                 verbose=False,
                 proxy_host=None,
                 timeout=10,
                 **kwargs):
     """Run an ssh command on ``hosts`` using Parallel SSH.

     Returns a dict with two host lists: key ``"0"`` holds the hosts whose
     command exited with code 0, key ``"1"`` holds every other host,
     including hosts missing from the per-host results. When ``proxy_host``
     is given the connection goes through it, with ``user`` used as the proxy
     user. With ``verbose`` set, the stdout lines of successful hosts are
     printed. Extra keyword arguments are forwarded to the client's
     ``run_command``.
     """
     client_options = {"pkey": SSH_KEY, "timeout": timeout}
     if proxy_host:
         client_options.update(user='******',
                               proxy_host=proxy_host,
                               proxy_user=user,
                               proxy_pkey=SSH_KEY)
     else:
         client_options["user"] = user
     client = ParallelSSHClient(hosts, **client_options)
     host_outputs = client.run_command(command,
                                       stop_on_errors=False,
                                       **kwargs)
     client.join(host_outputs)
     succeeded, failed = [], []
     # host_outputs is a list of pssh.output.HostOutput objects
     for entry in host_outputs:
         if entry.exit_code != 0:
             if entry.host is not None:
                 failed.append(entry.host)
             continue
         if verbose and entry.stdout:
             for line in entry.stdout:
                 print(line)
         succeeded.append(entry.host)
     # Hosts that raised an exception (authentication, connection) come back
     # with host = None, so they are recovered by elimination against the
     # requested host list.
     failed.extend(set(hosts) - set(succeeded + failed))
     return {"0": succeeded, "1": failed}
Beispiel #6
0
    def run_command(command,
                    hosts,
                    user,
                    verbose=False,
                    proxy_host=None,
                    timeout=10,
                    **kwargs):
        """Run an ssh command on ``hosts`` using Parallel SSH.

        Returns a dict with two host lists: key ``"0"`` for hosts whose
        command exited with code 0 and key ``"1"`` for all others. When
        ``proxy_host`` is given the connection goes through it, with ``user``
        used as the proxy user. With ``verbose`` set, each host's stdout is
        iterated to exhaustion (nothing is printed here).

        Raises:
            OpenA8SshAuthenticationException: if a requested host has no
                entry in the command output.
        """
        client_options = {"timeout": timeout}
        if proxy_host:
            client_options.update(user='******',
                                  proxy_host=proxy_host,
                                  proxy_user=user)
        else:
            client_options["user"] = user
        client = ParallelSSHClient(hosts, **client_options)
        output = client.run_command(command, stop_on_errors=False, **kwargs)
        client.join(output)
        succeeded, failed = [], []
        for host in hosts:
            if host not in output:
                # On a Pssh AuthenticationException the output dict key is
                # duplicated with a suffix, e.g.
                # {'saclay.iot-lab.info': {'exception': ...},
                #  'saclay.iot-lab.info_qzhtyxlt': {'exception': ...}}
                # so report the first key in sorted order.
                raise OpenA8SshAuthenticationException(
                    next(iter(sorted(output))))
            bucket = succeeded if output[host]['exit_code'] == 0 else failed
            bucket.append(host)
        if verbose:
            for host in hosts:
                # Pssh >= 1.0.0: stdout is None instead of a generator
                # object after a ConnectionErrorException
                stream = output[host].get('stdout')
                if not stream:
                    continue
                for _line in stream:
                    pass

        return {"0": succeeded, "1": failed}
Beispiel #7
0
class SSHManager(object):
    """Keep a ParallelSSHClient's host list and per-host configs in step.

    ``all_hosts`` is the sorted list of every managed host and is what
    ``client.hosts`` is set to; ``client.host_config`` is maintained
    index-aligned with it. ``hosts`` is the currently selected subset (the
    "context"); it follows ``all_hosts`` until one of the
    ``change_context_*`` methods selects a subset.
    """

    # Set to True on the instance once a context has been selected explicitly;
    # while False, ``hosts`` is refreshed from ``all_hosts`` on add/remove.
    context_changed = False

    def __init__(self, hosts, host_config):
        """Create the client for ``hosts`` (stored sorted).

        NOTE(review): ``host_config`` is passed to the client as given, so it
        presumably has to be ordered like ``sorted(hosts)`` - confirm with
        callers.
        """
        self.hosts = sorted(hosts)
        self.all_hosts = copy.deepcopy(self.hosts)
        self.client = ParallelSSHClient(self.hosts, host_config=host_config)

    def remove_hosts(self, hosts):
        """Drop every host in ``hosts`` together with its host config."""
        # Positions are taken before all_hosts is rebuilt so they still line
        # up with the client's host_config list.
        removed_positions = {
            position
            for position, name in enumerate(self.all_hosts) if name in hosts
        }
        self.all_hosts = [
            name for name in self.all_hosts if name not in hosts
        ]
        if self.context_changed:
            self.hosts = [name for name in self.hosts if name not in hosts]
        else:
            self.hosts = copy.deepcopy(self.all_hosts)
        self.client.host_config = [
            config
            for position, config in enumerate(self.client.host_config)
            if position not in removed_positions
        ]
        self.client.hosts = self.all_hosts

    def add_host(self, host):
        """Register ``host`` and place its config at the matching position.

        ``host`` must expose a ``host`` attribute (the name) and a
        ``build_host_config()`` method. Adding a host that is already known
        replaces its existing config instead of inserting a second entry,
        which would leave ``host_config`` longer than the host list.
        """
        already_known = host.host in self.all_hosts
        self.all_hosts = sorted(set(self.all_hosts) | {host.host})
        position = self.all_hosts.index(host.host)
        host_configs = self.client.host_config
        if already_known:
            host_configs[position] = host.build_host_config()
        else:
            host_configs.insert(position, host.build_host_config())
        self.client.hosts = self.all_hosts
        self.client.host_config = host_configs
        if not self.context_changed:
            self.hosts = copy.deepcopy(self.all_hosts)

    def run_command(self, command, commands=None, sudo=False):
        """Run ``command`` through the client and return its output.

        When ``commands`` is given it is forwarded as the client's
        ``host_args`` (per-host arguments for ``command``).
        """
        if commands is None:
            return self.client.run_command(command, sudo=sudo)
        return self.client.run_command(command,
                                       host_args=commands,
                                       sudo=sudo)

    def join(self, output):
        """Call the client's ``join`` on ``output``."""
        self.client.join(output)

    def change_context_hosts_all(self):
        """Select every managed host as the current context."""
        self.change_context_hosts(self.all_hosts)

    def change_context_hosts(self, new_hosts):
        """Select ``new_hosts`` as the current context.

        Raises:
            SSHManager.ContextException: if a requested host is not managed.
        """
        new_hosts = sorted(set(new_hosts))

        for h in new_hosts:
            if h not in self.all_hosts:
                raise SSHManager.ContextException(f"Host {h} not in host list")

        # Keep the order (and any repeats) of all_hosts.
        self.hosts = [h for h in self.all_hosts if h in new_hosts]
        self.context_changed = True

    def change_context_indices(self, indices):
        """Select the hosts at the given positions of ``all_hosts``.

        ``indices`` must be non-empty.

        Raises:
            SSHManager.ContextException: if an index is negative or beyond
                the end of ``all_hosts``.
        """
        indices = sorted(indices)
        if indices[0] < 0 or indices[-1] >= len(self.all_hosts):
            raise SSHManager.ContextException("Indices out of range")
        self.change_context_hosts([self.all_hosts[i] for i in indices])

    def reset_context(self):
        """Go back to the default context of all managed hosts."""
        self.hosts = copy.deepcopy(self.all_hosts)
        self.context_changed = False

    @staticmethod
    def build_host_config(*,
                          n=0,
                          users=None,
                          passwords=None,
                          user=None,
                          password=None):
        """Build a list of ``n`` ``HostConfig`` objects.

        Exactly one of ``users``/``user`` and exactly one of
        ``passwords``/``password`` must be supplied. The plural form gives
        one value per host (its length must equal ``n``); the singular form
        is shared by every host.

        Raises:
            HostConfigException: if the arguments break the rules above or
                ``n`` is not positive.
        """
        if (users is None) == (user is None):
            raise HostConfigException(
                "Users or User (not both) needs to be defined")
        elif (passwords is None) == (password is None):
            raise HostConfigException(
                "Passwords or Password (not both) needs to be defined")
        elif passwords is not None and len(passwords) != n:
            raise HostConfigException("Length of Passwords != n")
        elif users is not None and len(users) != n:
            raise HostConfigException("Length of Users != n")
        elif n <= 0:
            raise HostConfigException("n should be greater than 0")

        # Expand the shared (singular) values so both sequences have n
        # entries. The per-host password branch used to index the singular
        # ``password`` (which is None there) instead of ``passwords``.
        per_host_users = users if users is not None else [user] * n
        per_host_passwords = (passwords
                              if passwords is not None else [password] * n)
        return [
            HostConfig(user=host_user, password=host_password)
            for host_user, host_password in zip(per_host_users,
                                                per_host_passwords)
        ]

    class ContextException(Exception):
        """Raised when a requested context refers to unmanaged hosts."""
        pass
Beispiel #8
0
class HPCConnection(object):
    """SSH/SCP helper that stages data on one HPC host and drives Slurm jobs.

    Host name, user, remote working directory, template path and the Slurm
    command names all come from the project's ``constants`` module. Remote
    commands go through a single-host ``ParallelSSHClient``.
    """

    def __init__(self, external_init_dict=None):
        """Set up the logger, the paths and the SSH client.

        :param external_init_dict: optional dict; its ``src_data_path`` and
            ``ssh_key_filename`` entries override the defaults when present.
        """
        self.logger = logging.getLogger(constants.logging_name)
        init_dict = {}
        clsname = self.__class__.__name__
        if external_init_dict is not None:
            self.logger.debug(
                "{}: initializing from external dict".format(clsname))
            init_dict = external_init_dict
        else:
            self.logger.debug(
                "{}: initializing with default values".format(clsname))

        self.hostname = constants.hpc_hostname
        self.user = constants.user
        self.home_dir = os.path.join(constants.cc_working_dir, self.user)
        self.src_data_path = init_dict.get("src_data_path", "./data")
        self.template_path = constants.template_path
        self.logger.debug("Host being used is {}, under username {}".format(
            self.hostname, self.user))
        self.keypath = init_dict.get("ssh_key_filename",
                                     constants.ssh_key_filename)
        self.client = ParallelSSHClient([self.hostname],
                                        pkey=self.keypath,
                                        user=self.user,
                                        keepalive_seconds=300)
        # Filled in by copy_data_to_remote() / submit_job()
        self.remote_abs_working_folder = None
        self.remote_working_folder = None
        self.active_dataset_name = None
        self.live_job_id = None

    def _run_and_wait(self, cmd):
        """Run ``cmd`` on the remote host, join on it and return the output.

        run_command() on its own does not wait for the command to finish, so
        steps that depend on each other (mkdir before upload, mkdir before
        chmod) have to join first.
        """
        output = self.client.run_command(cmd)
        self.client.join(output)
        return output

    def _stderr_text(self, output):
        """Return the command's stderr lines joined by newlines, or None."""
        lines = list(output[self.hostname]["stderr"])
        if not lines:
            return None
        return "\n".join(lines)

    def check_connection(self):
        """Try a trivial remote command.

        :return: ``(True, None)`` on success, ``(False, message)`` when an
            authentication, unknown-host or connection error is raised.
        """
        status = True
        msg = None
        self.logger.debug("Testing connection...")
        try:
            self.client.run_command("ls")
            self.logger.debug("... ok")
        except (
                AuthenticationException,
                UnknownHostException,
                ConnectionErrorException,
        ) as e:
            status = False
            msg = str(e)
            self.logger.debug("... failed ({})".format(msg))

        return status, msg

    def copy_data_to_remote(self, dataset_, remote_temp_folder=None):
        """
        Copies data contained in a local directory over to a remote destination

        The dataset folder ``<src_data_path>/<dataset_>`` is tarred into /tmp,
        uploaded into a new remote folder under ``home_dir`` and unpacked
        there. The remote folder is remembered on the instance for the other
        methods.

        :param dataset_: name of the dataset sub-folder to upload
        :param remote_temp_folder: remote folder name; a random name from
            ``rand_fname()`` is used when omitted
        :raises Exception: if the upload fails or the remote untar writes
            anything to stderr
        """
        self.logger.debug(
            "Copying data to remote location (from {} to {})".format(
                self.src_data_path, self.home_dir))
        if remote_temp_folder is None:
            remote_temp_folder = rand_fname()

        full_remote_path = os.path.join(self.home_dir, remote_temp_folder)
        remote_tar = os.path.join(full_remote_path, "data.tar")
        local_tar = "/tmp/" + remote_temp_folder + ".tar"
        local_dataset_dir = os.path.join(self.src_data_path, dataset_)
        self.remote_abs_working_folder = full_remote_path
        self.remote_working_folder = remote_temp_folder
        self.active_dataset_name = dataset_
        self.logger.debug("Creating remote folder {}".format(full_remote_path))
        # The upload below needs the folder, so wait for mkdir to finish.
        self._run_and_wait("mkdir " + full_remote_path)

        self.logger.debug("system cmd: tar cvf /tmp/{}.tar -C {} .".format(
            remote_temp_folder, local_dataset_dir))
        os.system("tar cf " + local_tar + " -C " + local_dataset_dir + " .")
        try:
            try:
                self.logger.debug("Copying data tar file")
                g = self.client.scp_send(local_tar, remote_tar)
                joinall(g, raise_error=True)
            except SCPError as e:
                # Previously only logged; the untar below would then run on a
                # file that was never uploaded.
                self.logger.error("Copy failed (scp error {})".format(e))
                raise Exception("scp_send failed") from e
            except Exception as e:
                self.logger.error("Copy failed: {}".format(e))
                raise Exception("scp_send failed") from e

            s = "tar xvf " + remote_tar + " -C " + full_remote_path
            self.logger.debug("Untarring remote data: {}".format(s))
            output = self._run_and_wait(s)

            errmsg = self._stderr_text(output)
            if errmsg is not None:
                self.logger.error("Error: " + errmsg)
                raise Exception("Error untarring data file: " + errmsg)

            first_line = next(output[self.hostname]["stdout"], None)
            if first_line is not None:
                self.logger.debug("stdout: " + first_line)
        finally:
            # Runs on failure too, so the local tar is not left in /tmp.
            self.logger.debug("Remove local temp tar file " + local_tar)
            if os.path.exists(local_tar):
                os.remove(local_tar)

    # output files in base_dir/jobname/out
    def copy_data_from_remote(self,
                              jobid,
                              absolute_local_out_dir,
                              cleanup_temp=True):
        """Fetch the job's ``out`` folder (plus its slurm log) and unpack it.

        The remote output folder is tarred, downloaded to /tmp (up to three
        attempts, comparing byte sizes) and unpacked into
        ``absolute_local_out_dir``. Failures inside the transfer are logged,
        not raised.

        :param jobid: job id as a string (used in the slurm log file name)
        :param absolute_local_out_dir: local destination folder; created when
            missing
        :param cleanup_temp: remove the downloaded tar file afterwards
        """
        self.logger.debug("Copying data from remote")
        absolute_tar_fname = os.path.join(
            self.remote_abs_working_folder,
            self.remote_working_folder + "_out.tar")

        absolute_output_data_path = os.path.join(
            self.remote_abs_working_folder, "out")
        stdout_file = os.path.join(self.home_dir, "slurm-" + jobid + ".out")
        self.logger.debug(
            "  Remote data is located in {}".format(absolute_output_data_path))
        self.logger.debug("  Slurm output file is {}".format(stdout_file))

        try:
            self.logger.debug(
                "  Copying slurm file to {}".format(absolute_output_data_path))
            output = self._run_and_wait("cp " + stdout_file + " " +
                                        absolute_output_data_path)
            self.logger.debug(output)
            self.logger.debug("  Tarring remote folder")
            output = self._run_and_wait("tar cf " + absolute_tar_fname +
                                        " -C " + absolute_output_data_path +
                                        " .")
            self.logger.debug(output)
            self.logger.debug("Picking up tar file size")
            output = self._run_and_wait("du -sb " + absolute_tar_fname)
            self.logger.debug(output)
            line = "".join(output[self.hostname]["stdout"])
            tar_size = int(re.match("[0-9]*", line).group(0))

            self.logger.info("{} bytes to copy from remote".format(tar_size))
            local_tar = "/tmp/" + self.remote_working_folder + "_out.tar"
            # The downloaded file is looked up under "<local name>_<host>".
            received_tar = local_tar + "_" + self.hostname
            self.logger.debug(
                "Remote tar file is {}".format(absolute_tar_fname))

            for _attempt in range(3):
                self.logger.debug("Copying tar file to /tmp")
                g = self.client.copy_remote_file(absolute_tar_fname,
                                                 local_tar)  # scp_recv
                joinall(g, raise_error=True)

                # Same byte count "du -sb" reports, without spawning a shell.
                recv_tar_size = os.path.getsize(received_tar)
                self.logger.debug("Received: {} bytes".format(recv_tar_size))
                if recv_tar_size == tar_size:
                    break
            else:
                raise Exception("Unable to copy tar file from remote end")

            if not os.path.exists(absolute_local_out_dir):
                self.logger.debug(
                    "Local destination folder {} does not exist, creating".
                    format(absolute_local_out_dir))
                os.mkdir(absolute_local_out_dir)

            self.logger.debug(
                "Untarring received file to {}".format(absolute_local_out_dir))
            os.system("tar xf " + received_tar + " -C " +
                      absolute_local_out_dir)
            if cleanup_temp:
                os.remove(received_tar)

        except Exception as e:
            self.logger.error(
                "Exception during file transfer from remote: {}".format(e))

    def copy_singlefile_to_remote(self,
                                  local_filename,
                                  remote_path=".",
                                  is_executable=False):
        """Upload one file into the remote working folder.

        :param local_filename: local file to upload; its base name is kept
        :param remote_path: sub-path inside the remote working folder
        :param is_executable: also make the uploaded file executable
        """
        r = os.path.join(
            self.remote_abs_working_folder,
            remote_path,
            os.path.basename(local_filename),
        )
        g = self.client.copy_file(local_filename, r)
        joinall(g, raise_error=True)
        if is_executable:
            self._run_and_wait("chmod ugo+x " + r)

    def create_remote_subdir(self, remote_subdir):
        """Create ``remote_subdir`` in the remote working folder, mode 777."""
        target = os.path.join(self.remote_abs_working_folder, remote_subdir)
        # Wait for mkdir before chmod so chmod never sees a missing folder.
        self._run_and_wait("mkdir -p " + target)
        self._run_and_wait("chmod 777 " + target)

    # executable_ is either raven or ostrich

    def copy_batchscript(
        self,
        executable_,
        guessed_duration,
        datafile_basename,
        batch_tmplt_fname,
        shub_hostname,
    ):
        """Fill in the batch script template and upload it as an executable.

        The placeholders ACCOUNT, DURATION, TEMP_PATH, INPUT_PATH,
        OUTPUT_PATH, DATAFILE_BASENAME, SHUB_HOSTNAME and EXEC are replaced in
        that order. The remote ``out`` folder is created along the way.

        :return: absolute remote path of the uploaded script
        """
        template_path = os.path.join(self.template_path, batch_tmplt_fname)
        with open(template_path, "r") as template_file:
            tmplt = template_file.read()
        abs_remote_output_dir = os.path.join(self.remote_abs_working_folder,
                                             "out")
        tmplt = tmplt.replace("ACCOUNT", constants.cc_account_info)
        tmplt = tmplt.replace("DURATION", guessed_duration)
        tmplt = tmplt.replace("TEMP_PATH", self.remote_abs_working_folder)
        tmplt = tmplt.replace("INPUT_PATH", self.remote_abs_working_folder)
        tmplt = tmplt.replace("OUTPUT_PATH", abs_remote_output_dir)
        tmplt = tmplt.replace("DATAFILE_BASENAME", datafile_basename)
        tmplt = tmplt.replace("SHUB_HOSTNAME", shub_hostname)
        tmplt = tmplt.replace("EXEC", executable_)

        subst_fname = self.remote_working_folder + ".sh"
        local_script = "/tmp/" + subst_fname
        remote_script = os.path.join(self.remote_abs_working_folder,
                                     subst_fname)
        with open(local_script, "w") as script_file:
            script_file.write(tmplt)

        try:
            self._run_and_wait("mkdir " + abs_remote_output_dir)
            self._run_and_wait("chmod 777 " + self.remote_abs_working_folder)
            self._run_and_wait("chmod 777 " + abs_remote_output_dir)
            g = self.client.copy_file(local_script, remote_script)
            joinall(g, raise_error=True)
            self._run_and_wait("chmod ugo+x " + remote_script)
        finally:
            # Remove the local temp script even when the upload fails.
            os.remove(local_script)

        return remote_script

    def submit_job(self, script_fname):
        """Submit ``script_fname`` from ``home_dir`` and return the job id.

        The job id is the first stdout line of the submit command and is also
        stored in ``live_job_id``.

        :raises Exception: if the command writes to stderr or prints nothing
        """
        self.logger.debug("Submitting job {}".format(script_fname))
        output = self._run_and_wait("cd {}; {} --parsable {}".format(
            self.home_dir, constants.sbatch_cmd, script_fname))
        errmsg = self._stderr_text(output)

        if errmsg is not None:
            self.logger.error("  Error: {}".format(errmsg))
            raise Exception("Error: " + errmsg)

        job_id = next(output[self.hostname]["stdout"], None)
        if job_id is None:
            # Avoid leaking a bare StopIteration to the caller.
            raise Exception("Error: job submission returned no job id")
        self.live_job_id = job_id
        self.logger.debug("  Job id {}".format(self.live_job_id))

        return self.live_job_id

    def read_from_remote(self, remote_filename):
        """Download a file from the remote working folder and return its lines.

        The copy is attempted at most twice, because the remote file may be
        in the middle of being overwritten or may not exist yet.

        :return: list of lines; empty when both attempts fail
        """
        filecontent = []
        self.logger.debug("read_from_remote")
        local_filename = os.path.join(
            "/tmp", self.remote_working_folder + "_progress.json")
        suffixed_local_filename = local_filename + "_" + self.hostname
        for attempt in range(2):
            # Start clean so a failed first attempt cannot leave partial lines.
            filecontent = []
            try:
                g = self.client.copy_remote_file(
                    os.path.join(self.remote_abs_working_folder,
                                 remote_filename),
                    local_filename,
                )
                joinall(g, raise_error=True)
                self.logger.debug("  Opening copied file")
                with open(suffixed_local_filename) as f:
                    for line in f:
                        self.logger.debug(line)
                        filecontent.append(line)
                break
            except Exception as e:
                # e.g. missing progress file as execution starts
                if attempt == 0:
                    self.logger.debug("exception {}, retrying".format(e))

        self.logger.debug("End read_from_remote")
        return filecontent

    def get_status(self, jobid):
        """
        Query the queue command for the state of ``jobid``.

        :param jobid: id of the job to look up
        :return: one of COMPLETED, PENDING, RUNNING, TIMEOUT, CANCELLED
        :raises Exception: on stderr output, when no matching line is found,
            or when the reported state is not one of the known ones
        """
        self.logger.debug("Inside get_status: executing sacct")
        cmd = constants.squeue_cmd + " -j {} -n -p -b".format(jobid)

        output = self._run_and_wait(cmd)
        errmsg = self._stderr_text(output)
        if errmsg is not None:
            self.logger.debug("  stderr: {}".format(errmsg))
            raise Exception("Error: " + errmsg)

        status_output = None  # 1 line expected
        stdout_lines = []
        wanted_id = str(jobid)
        for line in output[self.hostname]["stdout"]:
            stdout_lines.append(line + "\n")
            fields = line.split("|")
            # Expected shape: "<jobid>|<state ...>|..."
            if len(fields) >= 2 and fields[0] == wanted_id:
                status_output = fields[1].split()[0]

        if status_output is None:
            raise Exception("Error parsing sacct output: {}".format(
                "".join(stdout_lines)))

        if status_output not in [
                "COMPLETED",
                "PENDING",
                "RUNNING",
                "TIMEOUT",
                "CANCELLED",
        ]:
            raise Exception(
                "Status error: state {} unknown".format(status_output))

        return status_output

    def cancel_job(self, jobid):
        """
        Cancel ``jobid`` with the configured cancel command.

        :param jobid: id of the job to cancel
        :raises Exception: if the command prints anything on stderr or stdout
        """
        cmd = constants.scancel_cmd + " {}".format(jobid)

        output = self._run_and_wait(cmd)
        errmsg = self._stderr_text(output)
        if errmsg is not None:
            self.logger.debug("  stderr: {}".format(errmsg))
            raise Exception("Cancel error: " + errmsg)

        stdout_str = "".join(line + "\n"
                             for line in output[self.hostname]["stdout"])
        if len(stdout_str) > 0:
            raise Exception("Cancel error: " + stdout_str)

    def reconnect(self):
        """Replace the SSH client with a freshly created one."""
        self.client = ParallelSSHClient([self.hostname],
                                        pkey=self.keypath,
                                        user=self.user,
                                        keepalive_seconds=300)

    def cleanup(self, jobid):
        """Best-effort removal of the remote folder, slurm log and local file.

        Any exception is logged at debug level and not raised.

        :param jobid: id of the job whose slurm log should be deleted
        """
        try:
            self.logger.debug("Deleting the remote folder")
            output1 = self._run_and_wait("rm -rf {}".format(
                os.path.join(self.home_dir, self.remote_abs_working_folder)))
            self.logger.debug("Deleting the slurm log file")
            logfilepath = os.path.join(self.home_dir,
                                       "slurm-{}.out".format(jobid))
            output2 = self._run_and_wait("rm {}".format(logfilepath))
            self.logger.debug("Deleting the local progress file")
            local_filename = os.path.join(
                "/tmp", self.remote_working_folder + "_progress.json")
            suffixed_local_filename = local_filename + "_" + self.hostname
            # The progress file only exists once read_from_remote() succeeded.
            if os.path.exists(suffixed_local_filename):
                os.remove(suffixed_local_filename)

            # Log the first line of each stream when there is one; next()
            # without a default raised StopIteration on silent commands and
            # was reported as a cleanup failure.
            for stream in ("stdout", "stderr"):
                for output in (output1, output2):
                    first_line = next(output[self.hostname][stream], None)
                    if first_line is not None:
                        self.logger.debug(first_line)

        except Exception as e:
            self.logger.debug("Hmm file cleanup failed: {}".format(e))
class ClusterConnector:
    """Manage an AmlCompute cluster and run commands / copy files on its nodes."""

    def __init__(
        self,
        workspace,
        cluster_name,
        ssh_key,
        vm_type,
        admin_username="******",
    ):
        """Thin wrapper class around azureml.core.compute.AmlCompute

        Provides parallel ssh objects and helper for master node and all node commands
        and file copies.

        Creating an instance also redirects the ``pssh.host_logger`` output
        from the console to a file under ``clusterlogs/``.

        Usage:
        >>> cc = ClusterConnector(workspace, "MyCluster", sshkey, "Standard_ND40rs_v2")
        >>> cc.initialise(min_nodes=0, max_nodes=4, idle_timeout_secs=30)
        >>> cluster = cc.cluster
        >>> [print(node['name']) for node in cc.cluster.list_nodes()]
        """

        self.cluster_name = cluster_name
        self.workspace = workspace
        self.ssh_key = ssh_key
        self.vm_type = vm_type
        self.admin_username = admin_username

        enable_host_logger()
        hlog = logging.getLogger("pssh.host_logger")
        tstr = datetime.now().isoformat(timespec="minutes")
        self._detach_stream_handlers(hlog)
        os.makedirs("clusterlogs", exist_ok=True)
        self.logfile = "clusterlogs/{}_{}.log".format(self.workspace.name,
                                                      tstr)
        hlog.addHandler(logging.FileHandler(self.logfile))

        self.cluster = None
        self._master_scp = None
        self._master_ssh = None
        self._all_ssh = None

    @staticmethod
    def _detach_stream_handlers(log):
        """Remove every ``logging.StreamHandler`` (and subclass) from *log*."""
        # Iterate over a copy: removeHandler mutates ``log.handlers`` and
        # iterating the live list would skip the handler after each removal.
        for handler in list(log.handlers):
            if isinstance(handler, logging.StreamHandler):
                log.removeHandler(handler)

    def initialise(self, min_nodes=0, max_nodes=0, idle_timeout_secs=1800):
        """Initialise underlying AmlCompute cluster instance"""
        self._create_or_update_cluster(min_nodes, max_nodes, idle_timeout_secs)

    def _check_logs_emessage(self, host, port):
        """Return the error text pointing the user at the host log file."""
        msg = "Remote command failed on {}:{}. For details see {}".format(
            host, port, self.logfile)
        return msg

    def _raise_on_failure(self, outputs):
        """Raise RuntimeError for the first output with a non-zero exit code.

        An exit code of ``None`` is treated as a failure as well.
        """
        for output in outputs:
            exit_code = output.exit_code
            if exit_code is None or int(exit_code) != 0:
                host = output.host
                port = output.client.port
                raise RuntimeError(self._check_logs_emessage(host, port))

    def terminate(self):
        """Scale the cluster down to zero nodes.

        Raises:
            RuntimeError: if the update fails or nodes are still listed
                after the update completed.
        """

        print('Attempting to terminate cluster "{}"'.format(
            colored(self.cluster_name, "green")))
        try:
            self.cluster.update(min_nodes=0,
                                max_nodes=0,
                                idle_seconds_before_scaledown=10)
            self.cluster.wait_for_completion()
        except ComputeTargetException as err:
            raise RuntimeError(
                "Failed to terminate cluster nodes ({})".format(err)) from err

        if len(self.cluster.list_nodes()):
            raise RuntimeError(
                "Failed to terminate cluster nodes (nodes still running)")

    @property
    def cluster_nodes(self):
        """Current node list, refreshed on every access and sorted by port."""
        self.cluster.refresh_state()
        return sorted(self.cluster.list_nodes(), key=lambda n: n["port"])

    def _create_or_update_cluster(self, min_nodes, max_nodes,
                                  idle_timeout_secs):
        """Update the named cluster, or create it if it cannot be loaded.

        Raises:
            RuntimeError: if fewer than ``min_nodes`` nodes are listed after
                provisioning and one further 30 second wait.
        """

        try:
            self.cluster = AmlCompute(workspace=self.workspace,
                                      name=self.cluster_name)
            print('Updating existing cluster "{}"'.format(
                colored(self.cluster_name, "green")))
            self.cluster.update(
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                idle_seconds_before_scaledown=idle_timeout_secs,
            )
        except ComputeTargetException:
            print('Creating new cluster "{}"'.format(
                colored(self.cluster_name, "green")))
            cluster_config = AmlCompute.provisioning_configuration(
                vm_size=self.vm_type,
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                idle_seconds_before_scaledown=idle_timeout_secs,
                admin_username=self.admin_username,
                admin_user_ssh_key=self.ssh_key,
                remote_login_port_public_access="Enabled",
            )
            self.cluster = AmlCompute.create(self.workspace, self.cluster_name,
                                             cluster_config)

        self.cluster.wait_for_completion()

        # Give the node list one more chance to catch up before failing.
        if len(self.cluster_nodes) < min_nodes:
            sleep(30)
            if len(self.cluster_nodes) < min_nodes:
                raise RuntimeError("Failed to provision sufficient nodes")

    def _copy_nodefile_to_nodes(self):
        """Collect each node's ``ib0`` address and distribute it as a node file.

        The addresses are written one per line to a temporary file, with the
        master node's address first, and the file is copied to ``./nodefile``
        on every node. Does nothing on a single node cluster.

        Raises:
            RuntimeError: if the address lookup fails on any node.
        """

        if len(self.cluster_nodes) == 1:
            cprint("Single node cluster -- skipping IB config", "yellow")
            return

        print("Collecting cluster IB info")

        outputs = self._all_ssh.run_command(
            r'ifconfig ib0 | grep -oe "inet[^6][adr: ]*[0-9.]*" | cut -d" " -f2',
            shell="bash -c",
        )
        self._all_ssh.join(outputs)

        ibaddrs = []
        for output in outputs:
            host = output.host
            port = output.client.port
            if output.exit_code != 0:
                print(list(output.stdout))
                print(list(output.stderr))
                raise RuntimeError("Failed to get IB ip for {}:{}".format(
                    host, port))
            try:
                ibaddr = list(output.stdout)[0].split()[0]
            except IndexError:
                raise RuntimeError("Failed to get IB ip for {}:{} - "
                                   "No ib interface found!".format(host, port))
            print("Mapping {}:{} -> {}".format(host, port, ibaddr))
            if port == self._master_scp.port:
                # The master node's address must come first in the node file.
                cprint("IB Master: {}".format(ibaddr), "green")
                ibaddrs.insert(0, ibaddr)
            else:
                ibaddrs.append(ibaddr)

        # delete=False: the file has to outlive this block so it can be copied.
        with NamedTemporaryFile(delete=False, mode="wt") as nfh:
            self.nodefile = nfh.name
            for addr in ibaddrs:
                nfh.write("{}\n".format(addr))

        self.ibaddrs = ibaddrs
        self.copy_to_all_nodes(self.nodefile, "./nodefile")

    def _create_cluster_ssh_conns(self):
        """Create the all-node, master-node and master copy ssh clients.

        The first node of the port-sorted node list is used as master.
        """

        # ``cluster_nodes`` refreshes cluster state on every access; take one
        # snapshot so addresses and ports are guaranteed to line up.
        nodes = self.cluster_nodes
        hostips = [n["publicIpAddress"] for n in nodes]
        hostconfigs = [HostConfig(port=n["port"]) for n in nodes]

        self._all_ssh = ParallelSSHClient(hostips,
                                          host_config=hostconfigs,
                                          user=self.admin_username)

        self._master_ssh = ParallelSSHClient(hostips[:1],
                                             host_config=hostconfigs[:1],
                                             user=self.admin_username)

        self._master_scp = SSHClient(hostips[0],
                                     port=hostconfigs[0].port,
                                     user=self.admin_username)

    def copy_to_all_nodes(self, source, dest):
        """Copy local *source* to *dest* on every node; raise on any failure."""

        copy_jobs = self._all_ssh.copy_file(source, dest)
        joinall(copy_jobs, raise_error=True)

    def copy_to_master_node(self, source, dest):
        """Copy local *source* to *dest* on the master node."""

        self._master_scp.copy_file(source, dest)

    def copy_from_master_node(self, source, dest):
        """Copy *source* on the master node to local *dest*."""

        self._master_scp.copy_remote_file(source, dest)

    def run_on_all_nodes(self, command):
        """Run *command* via ``bash -c`` on every node and wait for it.

        Raises:
            RuntimeError: if the command fails on any node.
        """

        outputs = self._all_ssh.run_command(command, shell="bash -c")
        self._all_ssh.join(outputs, consume_output=True)
        self._raise_on_failure(outputs)

    def run_on_master_node(self, command):
        """Run *command* via ``bash -c`` on the master node and wait for it.

        Raises:
            RuntimeError: if the command fails.
        """

        outputs = self._master_ssh.run_command(command, shell="bash -c")
        self._master_ssh.join(outputs)
        self._raise_on_failure(outputs)

    def attempt_termination(self):
        """Try to terminate the cluster; print a warning instead of raising."""
        try:
            self.terminate()
        except RuntimeError as err:
            print(colored("ERROR: {}\n\n", "red", attrs=["bold"]).format(err))
            self.warn_unterminated()

    def warn_unterminated(self):
        """Print a warning that the cluster is still running."""
        print(
            colored("WARNING: {}", "red", attrs=["bold"]).format(
                colored(
                    "Cluster {} is still running - terminate manually to avoid "
                    "additional compute costs".format(
                        colored(self.cluster_name, "green")),
                    "red",
                )))
import datetime

# The same machine is listed twice, so it counts as two hosts.
host = 'localhost'
hosts = [host] * 2
client = ParallelSSHClient(hosts)

# Run 10 five second sleeps
cmds = ['sleep 5; uname'] * 10

# Phase 1: time how long it takes just to start every command.
start = datetime.datetime.now()
output = [
    client.run_command(pending_cmd, stop_on_errors=False, return_list=True)
    for pending_cmd in cmds
]
end = datetime.datetime.now()
started_summary = (len(cmds), len(hosts), end - start)
print("Started %s 'sleep 5' commands on %s host(s) in %s" % started_summary)

# Phase 2: wait for each batch in turn and print what every host produced.
start = datetime.datetime.now()
for batch in output:
    client.join(batch)
    for result in batch:
        for out_line in result.stdout:
            print(out_line)
        for err_line in result.stderr:
            print(err_line)
        print(f"Exit code: {result.exit_code}")
end = datetime.datetime.now()
print("All commands finished in %s" % (end - start, ))
Beispiel #11
0
class Runner:
    """Run commands on a group of hosts over parallel ssh and log the output."""

    def __init__(self, timeout, log: Log, hostNames=None, hostFile=None, quiet=False, only=None):
        """Create the ssh client for the given hosts.

        Args:
            timeout: Passed through to ``ParallelSSHClient``.
            log: Logger object used for all output.
            hostNames: Iterable of hosts; mutually exclusive with *hostFile*.
            hostFile: Host configuration source handed to ``getHostConfig``;
                mutually exclusive with *hostNames*.
            quiet: When true, command output is not logged by default.
            only: Optional collection of hosts (address or configured name)
                to keep; every other host is dropped from the client.

        Raises:
            ValueError: if both or neither of *hostNames* / *hostFile* are given.
        """
        if hostFile and hostNames:
            raise ValueError('Cannot specify both "hostNames" and "hostFile"')
        elif hostNames:
            # Materialise once: hostNames may be a one-shot iterable and is
            # needed again below to compute the prefix width.
            knownHosts = list(hostNames)
            self.hostConfig = None
            self.client = ParallelSSHClient(list(knownHosts), timeout=timeout)
        elif hostFile:
            self.hostConfig = getHostConfig(hostFile)
            knownHosts = list(self.hostConfig.keys())
            self.client = ParallelSSHClient(
                list(knownHosts), host_config=self.hostConfig, timeout=timeout)
        else:
            raise ValueError('One of "hostNames" or "hostFile" is required')
        if only:
            # Iterate over a copy because hosts are removed from the list.
            for host in list(self.client.hosts):
                if str(host) in only or self._hostName(host) in only:
                    continue
                self.client.hosts.remove(host)
        self.quiet = quiet
        self.log = log
        self.pool = ThreadPoolExecutor()
        # Width of the longest host name, used to align the log prefixes.
        self.justification = max(
            (len(self._hostName(host)) for host in knownHosts), default=0)

    def runCommand(self, command, sudo, consumeOutput=None, logExit=True):
        """Run *command* on all current hosts and wait for completion.

        Args:
            command: Command line to execute.
            sudo: Whether to run the command through sudo.
            consumeOutput: Log the hosts' output; defaults to ``not quiet``.
            logExit: Log each host's exit code when the command finished.

        Returns:
            The output mapping returned by the client's ``run_command``.
        """
        consumeOutput = not self.quiet if consumeOutput is None else consumeOutput
        output = self.client.run_command(command, sudo=sudo, stop_on_errors=False)
        if sudo and self.hostConfig:
            # Feed the configured password to sudo on hosts that ask for one.
            for host in output:
                if (self.hostConfig[host].get('sudoRequiresPassword', False)
                        and self.hostConfig[host].get('password', False)):
                    stdin = output[host].stdin
                    if stdin:
                        pwd = self.hostConfig[host]['password']
                        stdin.write(f'{pwd}\n')
                        stdin.flush()
        if consumeOutput:
            self._logCommand(output)
        self.client.join(output, consume_output=True)
        if logExit:
            for host, hostOutput in output.items():
                self._print(host, f'exit code: {hostOutput.exit_code}')
        return output

    def runCommandList(self, commandList: List[Command], verbose=None):
        """Run the commands in order, dropping hosts that fail an aborting command.

        A host is removed from the client (so later commands skip it) when a
        command with ``abortOnError`` set exits non-zero or with an exception.

        Args:
            commandList: Commands to execute in sequence.
            verbose: Log command names and output; defaults to ``not quiet``.
        """
        verbose = not self.quiet if verbose is None else verbose
        for command in commandList:
            if verbose:
                self.log(self.log.colour(Fore.GREEN, f'Running command {command.name}'))
            # NOTE(review): ``commad`` looks like a typo for ``command``, but the
            # Command class is not visible here - confirm before renaming.
            output = self.runCommand(command.commad, command.sudo, verbose, False)
            if not command.abortOnError:
                continue
            for host, hostOutput in output.items():
                failed = False
                if hostOutput.exit_code:
                    text = (f'command {command.name} exited with error {hostOutput.exit_code}. '
                            'No further commands will be executed.')
                    self._print(host, text, True)
                    failed = True
                if hostOutput.exception:
                    text = (f'command {command.name} exited with exception '
                            f'{hostOutput.exception}. No further commands will be executed.')
                    self._print(host, text, True)
                    failed = True
                # Remove at most once: a host can report both an exit code and
                # an exception, and a second list.remove would raise ValueError.
                if failed and host in self.client.hosts:
                    self.client.hosts.remove(host)

    def close(self):
        """Shut down the logging thread pool without waiting for pending work."""
        self.pool.shutdown(False)

    def _hostName(self, host):
        """Return the configured display name for *host*, or *host* itself."""
        if not self.hostConfig:
            return host
        return self.hostConfig[host].get('name', host)

    def _prefix(self, host):
        """Return the coloured, width-aligned ``[name]`` prefix for *host*."""
        return self.log.colour(Fore.LIGHTBLUE_EX,
                               f'[{self._hostName(host)}]'.ljust(self.justification + 2))

    def _logCommand(self, output: Dict[str, HostOutput]):
        """Log every host's output concurrently and wait until all are done."""
        futures = [
            self.pool.submit(self._logHostOutput, hostOutput)
            for hostOutput in output.values()
        ]
        wait(futures)

    def _print(self, host, line, isError=False):
        """Log *line* prefixed with the host name, as an error if *isError*."""
        if isError:
            line = f'{self.log.colour(Fore.RED, "[err]")} {line}'
        text = f'{self._prefix(host)}\t{line}'
        fn = self.log.error if isError else self.log.print
        fn(text)

    def _logHostOutput(self, hostOutput: HostOutput):
        """Log all stdout lines of one host, then all of its stderr lines."""
        for line in hostOutput.stdout:
            self._print(hostOutput.host, line, False)
        for line in hostOutput.stderr:
            self._print(hostOutput.host, line, True)
Beispiel #12
0
# Runs an installer command with sudo on a fixed list of hosts and prints
# each host's stdout.
import sys

from pssh.clients import ParallelSSHClient

hosts = ['66.246.107.24', '66.246.107.15']

client = ParallelSSHClient(hosts,
                           user="******",
                           pkey="~/keys/denisl-pubkey.pem")

cmd = 'python3 -c "$(curl -fsSL https://pgsql-io-download.s3.amazonaws.com/REPO/install.py)"'
# Alternative commands kept for manual experimentation:
#cmd='/usr/bin/python3 --version'
#cmd='uname'
#cmd='yum install -y python3 python3-devel wget curl'

output = client.run_command(cmd, sudo=True)
client.join(output)
for host_out in output:
    for line in host_out.stdout:
        print(line)

# The script stops here, so the interactive-shell variant below never runs.
# ``sys`` was previously not imported, which made this line fail with
# NameError instead of exiting cleanly.
sys.exit(1)

# Same command executed through interactive shells (currently unreachable).
shells = client.open_shell(read_timeout=10)
client.run_shell_commands(shells, [cmd])
client.join_shells(shells)

for shell in shells:
    stdout = list(shell.stdout)
    for s in stdout:
        print(s)