Example #1
def _exec_scp(self):
    client = ParallelSSHClient(self.hosts, port=self.port)
    # recurse=True so directories are copied too; one greenlet per host
    output = client.copy_file(self.source, self.destination, recurse=True)
    joinall(output, raise_error=True)
    # copy_file yields no command output, so return empty placeholders
    # keyed by host to keep the return shape uniform for callers
    nice_output = dict()
    for host in self.hosts:  # copy_file returns greenlets, not host names
        nice_output[host] = {'stdout': [], 'stderr': []}
    return nice_output
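
copy_file returns one gevent greenlet per host rather than per-command output, which is why the stdout/stderr lists above stay empty. A minimal standalone sketch of the same pattern (the host list, port, and paths are placeholder assumptions):

from gevent import joinall
from pssh.clients import ParallelSSHClient

hosts = ['host1.example.com', 'host2.example.com']  # placeholder hosts
client = ParallelSSHClient(hosts, port=22)
# recurse=True also copies directories; one greenlet is spawned per host
greenlets = client.copy_file('/tmp/src', '/tmp/dest', recurse=True)
joinall(greenlets, raise_error=True)  # re-raise the first copy failure, if any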
Example #2
def deploy_remote(host, user, password, port, path_to_private_key):
    path_to_monyze_agent = sys.argv[0]
    if DEBUG:
        path_to_monyze_agent = 'dist/monyze-agent'
    filename_bin = str(basename(path_to_monyze_agent))
    filename_sh = '/etc/init.d/monyze-agent'
    dist_filename_bin = '/usr/local/bin/' + filename_bin
    hosts = [host]

    try:
        # client = ParallelSSHClient(hosts, user=user, password=password, port=int(port), pkey=path_to_private_key)
        client = ParallelSSHClient(hosts,
                                   user=user,
                                   port=int(port),
                                   pkey=path_to_private_key)

        print('Copying monyze-agent to remote hosts...')
        sys.stdout.flush()

        # deploy ELF monyze-agent
        os.system("scp -i " + path_to_private_key + " -P " + port + " " +
                  path_to_monyze_agent + " " + user + "@" + host + ":" +
                  str(basename(path_to_monyze_agent)))
        remote_sudo_cmd_run(client, password,
                            'mv ' + filename_bin + ' ' + dist_filename_bin)
        remote_sudo_cmd_run(client, password,
                            'chmod 755 ' + dist_filename_bin)

        # deploy shell monyze-agent
        greenlets = client.copy_file(filename_sh, str(basename(filename_sh)))
        joinall(greenlets, raise_error=True)
        remote_sudo_cmd_run(client, password,
                            'mv ' + str(basename(filename_sh)) + ' ' + filename_sh)
        remote_sudo_cmd_run(client, password, 'chmod 755 ' + filename_sh)
        remote_sudo_cmd_run(client, password,
                            'update-rc.d monyze-agent defaults')
        remote_sudo_cmd_run(client, password, 'service monyze-agent restart')
    except Exception as error:
        print('Cannot deploy on host: %s' % host)
        print('Caught this error: ' + repr(error))
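
remote_sudo_cmd_run is not shown in this example. A minimal sketch of what such a helper could look like, inferred from the calls above and assuming sudo accepts the password on stdin via sudo -S (pssh 1.x-style dict output, as used elsewhere on this page):

import shlex

def remote_sudo_cmd_run(client, password, cmd):
    # Hypothetical helper: run `cmd` under sudo on every host,
    # feeding the password to sudo -S on standard input.
    output = client.run_command(
        'echo %s | sudo -S %s' % (shlex.quote(password), cmd))
    client.join(output)
    for host in output:
        for line in output[host]['stderr']:
            print('[%s] %s' % (host, line))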
Example #3
class HPCConnection(object):
    def __init__(self, external_init_dict=None):

        self.logger = logging.getLogger(constants.logging_name)
        init_dict = {}
        clsname = self.__class__.__name__
        if external_init_dict is not None:
            self.logger.debug(
                "{}: initializing from external dict".format(clsname))
            init_dict = external_init_dict
        else:
            self.logger.debug(
                "{}: initializing with default values".format(clsname))

        self.hostname = constants.hpc_hostname
        self.user = constants.user
        self.home_dir = os.path.join(constants.cc_working_dir, self.user)
        self.src_data_path = init_dict.get("src_data_path", "./data")
        self.template_path = constants.template_path
        self.logger.debug("Host being used is {}, under username {}".format(
            self.hostname, self.user))
        self.keypath = init_dict.get("ssh_key_filename",
                                     constants.ssh_key_filename)
        self.client = ParallelSSHClient([self.hostname],
                                        pkey=self.keypath,
                                        user=self.user,
                                        keepalive_seconds=300)
        self.remote_abs_working_folder = None
        self.remote_working_folder = None
        self.active_dataset_name = None
        self.live_job_id = None

    def check_connection(self):

        status = True
        msg = None
        self.logger.debug("Testing connection...")
        try:
            self.client.run_command("ls")
            self.logger.debug("... ok")
        except (
                AuthenticationException,
                UnknownHostException,
                ConnectionErrorException,
        ) as e:
            status = False
            msg = str(e)
            self.logger.debug("... failed ({})".format(msg))

        return status, msg

    def copy_data_to_remote(self, dataset_, remote_temp_folder=None):
        """
        Copies data contained in a local directory over to a remote destination
        """

        self.logger.debug(
            "Copying data to remote location (from {} to {})".format(
                self.src_data_path, self.home_dir))
        remote_base_path = self.home_dir
        local_datapath = self.src_data_path
        if remote_temp_folder is None:
            remote_temp_folder = rand_fname()

        full_remote_path = os.path.join(remote_base_path, remote_temp_folder)
        remote_tar = os.path.join(full_remote_path, "data.tar")
        self.remote_abs_working_folder = full_remote_path
        self.remote_working_folder = remote_temp_folder
        self.active_dataset_name = dataset_
        self.logger.debug("Creating remote folder {}".format(full_remote_path))
        self.client.run_command("mkdir " + full_remote_path)

        #    data_path_content = os.listdir(path=src_data_path)
        #    assert(len(data_path_content) == 1)
        #    df_basename = data_path_content[0]
        df_basename = dataset_

        self.logger.debug("system cmd: tar cvf /tmp/{}.tar -C {} .".format(
            remote_temp_folder, os.path.join(local_datapath, df_basename)))

        os.system("tar cf /tmp/" + remote_temp_folder + ".tar -C " +
                  os.path.join(local_datapath, df_basename) + " .")
        try:
            self.logger.debug("Copying data tar file")
            g = self.client.scp_send("/tmp/" + remote_temp_folder + ".tar",
                                     remote_tar)
            joinall(g, raise_error=True)
        except SCPError as e:
            self.logger.error("Copy failed (scp error {})".format(e))
        except Exception as e:
            self.logger.error("Copy failed: {}".format(e))
            raise Exception("scp_send failed")

        s = "tar xvf " + remote_tar + " -C " + full_remote_path
        self.logger.debug("Untarring remote data: {}".format(s))

        output = self.client.run_command(s)
        self.client.join(output)

        errmsg = next(output[self.hostname]["stderr"], None)
        if errmsg is not None:
            self.logger.error("Error: " + errmsg)
            raise Exception("Error untarring data file: " + errmsg)

        errmsg = next(output[self.hostname]["stdout"], None)
        if errmsg is not None:
            self.logger.debug("stdout: " + errmsg)

        self.logger.debug("Remove remote temp tar file " + "/tmp/" +
                          remote_temp_folder + ".tar")
        os.remove("/tmp/" + remote_temp_folder + ".tar")

    # output files in base_dir/jobname/out
    def copy_data_from_remote(self,
                              jobid,
                              absolute_local_out_dir,
                              cleanup_temp=True):

        self.logger.debug("Copying data from remote")
        absolute_tar_fname = os.path.join(
            self.remote_abs_working_folder,
            self.remote_working_folder + "_out.tar")

        absolute_output_data_path = os.path.join(
            self.remote_abs_working_folder, "out")
        stdout_file = os.path.join(self.home_dir, "slurm-" + jobid + ".out")
        self.logger.debug(
            "  Remote data is located in {}".format(absolute_output_data_path))
        self.logger.debug("  Slurm output file is {}".format(stdout_file))

        try:
            self.logger.debug(
                "  Copying slurm file to {}".format(absolute_output_data_path))
            output = self.client.run_command("cp " + stdout_file + " " +
                                             absolute_output_data_path)
            self.client.join(output)
            self.logger.debug(output)
            self.logger.debug("  Tarring remote folder")
            output = self.client.run_command("tar cf " + absolute_tar_fname +
                                             " -C " +
                                             absolute_output_data_path + " .")
            self.client.join(output)
            self.logger.debug(output)
            # time.sleep(30)  # patch since run_command seems non-blocking
            self.logger.debug("Picking up tar file size")
            output = self.client.run_command("du -sb " + absolute_tar_fname)
            self.client.join(output)
            self.logger.debug(output)
            line = ""
            for char in output[self.hostname].stdout:
                line += char
            # print(line)
            tar_size = int(re.match("[0-9]*", line).group(0))

            self.logger.info("{} bytes to copy from remote".format(tar_size))
            local_tar = "/tmp/" + self.remote_working_folder + "_out.tar"
            # g = self.client.scp_recv(absolute_tar_fname, local_tar)
            self.logger.debug(
                "Remote tar file is {}".format(absolute_tar_fname))

            tries = 0
            while tries < 3:
                self.logger.debug("Copying tar file to /tmp")
                g = self.client.copy_remote_file(absolute_tar_fname,
                                                 local_tar)  # scp_recv
                joinall(g, raise_error=True)

                output = subprocess.check_output("du -sb " + local_tar + "_" +
                                                 self.hostname,
                                                 shell=True)
                recv_tar_size = int(
                    re.match("[0-9]*", output.decode("utf-8")).group(0))

                self.logger.debug("Received: {} bytes".format(recv_tar_size))
                if recv_tar_size == tar_size:
                    break
                tries += 1

            if tries == 3:
                raise Exception("Unable to copy tar file from remote end")

            if not os.path.exists(absolute_local_out_dir):
                # shutil.rmtree(absolute_local_out_dir)
                self.logger.debug(
                    "Local destination folder {} does not exist, creating".
                    format(absolute_local_out_dir))
                os.mkdir(absolute_local_out_dir)

            # os.mkdir(path.join(absolute_local_out_dir,jobname)
            self.logger.debug(
                "Untarring received file to {}".format(absolute_local_out_dir))
            os.system("tar xf " + local_tar + "_" + self.hostname + " -C " +
                      absolute_local_out_dir)
            if cleanup_temp:
                # print("todo: cleanup tmp file")
                os.remove(local_tar + "_" + self.hostname)

        except Exception as e:
            self.logger.error(
                "Exception during file transfer from remote: {}".format(e))

    def copy_singlefile_to_remote(self,
                                  local_filename,
                                  remote_path=".",
                                  is_executable=False):

        r = os.path.join(
            self.remote_abs_working_folder,
            remote_path,
            os.path.basename(local_filename),
        )
        g = self.client.copy_file(local_filename, r)
        joinall(g, raise_error=True)
        if is_executable:
            self.client.run_command("chmod ugo+x " + r)

    def create_remote_subdir(self, remote_subdir):

        self.client.run_command(
            "mkdir -p " +
            os.path.join(self.remote_abs_working_folder, remote_subdir))
        self.client.run_command(
            "chmod 777 " +
            os.path.join(self.remote_abs_working_folder, remote_subdir))

    # executable_ is either raven or ostrich

    def copy_batchscript(
        self,
        executable_,
        guessed_duration,
        datafile_basename,
        batch_tmplt_fname,
        shub_hostname,
    ):

        with open(os.path.join(self.template_path, batch_tmplt_fname),
                  "r") as template_file:
            tmplt = template_file.read()
        abs_remote_output_dir = os.path.join(self.remote_abs_working_folder,
                                             "out")
        tmplt = tmplt.replace("ACCOUNT", constants.cc_account_info)
        tmplt = tmplt.replace("DURATION", guessed_duration)
        tmplt = tmplt.replace("TEMP_PATH", self.remote_abs_working_folder)
        tmplt = tmplt.replace("INPUT_PATH", self.remote_abs_working_folder)
        tmplt = tmplt.replace("OUTPUT_PATH", abs_remote_output_dir)
        tmplt = tmplt.replace("DATAFILE_BASENAME", datafile_basename)
        tmplt = tmplt.replace("SHUB_HOSTNAME", shub_hostname)
        tmplt = tmplt.replace("EXEC", executable_)

        # subst_template_file, subst_fname = tempfile.mkstemp(suffix=".sh")
        subst_fname = self.remote_working_folder + ".sh"
        with open("/tmp/" + subst_fname, "w") as f:
            f.write(tmplt)

        self.client.run_command("mkdir " + abs_remote_output_dir)
        self.client.run_command("chmod 777 " + self.remote_abs_working_folder)
        self.client.run_command("chmod 777 " + abs_remote_output_dir)
        g = self.client.copy_file(
            "/tmp/" + subst_fname,
            os.path.join(self.remote_abs_working_folder, subst_fname),
        )
        joinall(g, raise_error=True)
        self.client.run_command(
            "chmod ugo+x " +
            os.path.join(self.remote_abs_working_folder, subst_fname))
        os.remove("/tmp/" + subst_fname)

        return os.path.join(self.remote_abs_working_folder, subst_fname)

    def submit_job(self, script_fname):

        self.logger.debug("Submitting job {}".format(script_fname))
        # output = self.client.run_command("cd {}; ".format(self.home_dir) + constants.sbatch_cmd +
        #                                  " --parsable " + script_fname)
        output = self.client.run_command("cd {}; {} --parsable {}".format(
            self.home_dir, constants.sbatch_cmd, script_fname))
        self.client.join(output)
        errmsg = next(output[self.hostname]["stderr"], None)

        if errmsg is not None:
            for e in output[self.hostname]["stderr"]:
                errmsg += e + "\n"

            self.logger.error("  Error: {}".format(errmsg))
            raise Exception("Error: " + errmsg)

        self.live_job_id = next(output[self.hostname]["stdout"])
        self.logger.debug("  Job id {}".format(self.live_job_id))

        return self.live_job_id

    def read_from_remote(self, remote_filename):

        filecontent = []
        self.logger.debug("read_from_remote")
        retry = True
        # maybe remote file is being overwritten, try again if remote copy fails
        while True:
            try:
                local_filename = os.path.join(
                    "/tmp", self.remote_working_folder + "_progress.json")
                g = self.client.copy_remote_file(
                    os.path.join(self.remote_abs_working_folder,
                                 remote_filename),
                    local_filename,
                )
                joinall(g, raise_error=True)
                suffixed_local_filename = local_filename + "_" + self.hostname
                self.logger.debug("  Opening copied file")
                with open(suffixed_local_filename) as f:
                    for line in f:
                        self.logger.debug(line)
                        filecontent.append(line)
                break
            #        except SFTPIOError:
            #            print("SFTPIOError")
            #            return False
            except Exception as e:

                if retry:
                    self.logger.debug(
                        "exception {}, retrying".format(e)
                    )  # pass # e.g. missing progress file as execution starts
                    retry = False
                else:
                    break

        self.logger.debug("End read_from_remote")
        return filecontent

    def get_status(self, jobid):
        """
        :param jobid:
        :return:
        """
        self.logger.debug("Inside get_status: executing sacct")
        cmd = constants.squeue_cmd + " -j {} -n -p -b".format(jobid)

        output = self.client.run_command(cmd)
        self.client.join(output)
        status_output = None  # 1 line expected

        errmsg = next(output[self.hostname]["stderr"], None)
        if errmsg is not None:
            for e in output[self.hostname]["stderr"]:
                errmsg += e + "\n"
            self.logger.debug("  stderr: {}".format(errmsg))

            raise Exception("Error: " + errmsg)

        stdout_str = ""
        for line in output[self.hostname]["stdout"]:  # errmsg is None
            stdout_str += line + "\n"
            fields = line.split("|")
            if len(fields) >= 2:
                if fields[0] == jobid:
                    status_output = fields[1].split()[0]

        if status_output is None:
            raise Exception(
                "Error parsing sacct output: {}".format(stdout_str))

        if status_output not in [
                "COMPLETED",
                "PENDING",
                "RUNNING",
                "TIMEOUT",
                "CANCELLED",
        ]:
            raise Exception(
                "Status error: state {} unknown".format(status_output))

        return status_output

    def cancel_job(self, jobid):
        """
        :param jobid:
        :return:
        """
        cmd = constants.scancel_cmd + " {}".format(jobid)

        output = self.client.run_command(cmd)
        self.client.join(output)
        errmsg = next(output[self.hostname]["stderr"], None)
        if errmsg is not None:
            for e in output[self.hostname]["stderr"]:
                errmsg += e + "\n"
            self.logger.debug("  stderr: {}".format(errmsg))

            raise Exception("Cancel error: " + errmsg)

        stdout_str = ""
        for line in output[self.hostname]["stdout"]:  # errmsg is None
            stdout_str += line + "\n"
        if len(stdout_str) > 0:
            raise Exception("Cancel error: " + stdout_str)

    def reconnect(self):

        self.client = ParallelSSHClient([self.hostname],
                                        pkey=self.keypath,
                                        user=self.user,
                                        keepalive_seconds=300)

    """
    def check_slurmoutput_for(self, substr, jobid):

            slurmfname = "slurm-" + jobid + ".out"
            local_slurmfname = os.path.join("/tmp", slurmfname)
            stdout_file = os.path.join(self.home_dir, slurmfname)
            found = False
            try:
                g = self.client.copy_remote_file(stdout_file, local_slurmfname)
                joinall(g, raise_error=True)
                # scan file for substr
                with open(local_slurmfname + "_" + self.hostname) as f:
                    for line in f:
                        print("comparing {} with {}".format(substr,line))
                        match_obj = re.search(substr, line)
                        print(match_obj)
                        if match_obj:
                            found = True
                            print("found")

                os.remove(local_slurmfname + "_" + self.hostname)

            except Exception as e:
                print("Exception inside check_slurmoutput_for")
                print(e)
                pass

            return found
    """

    def cleanup(self, jobid):

        try:
            self.logger.debug("Deleting the remote folder")
            output1 = self.client.run_command("rm -rf {}".format(
                os.path.join(self.home_dir, self.remote_abs_working_folder)))
            self.logger.debug("Deleting the slurm log file")
            logfilepath = os.path.join(self.home_dir,
                                       "slurm-{}.out".format(jobid))
            output2 = self.client.run_command("rm {}".format(logfilepath))
            self.logger.debug("Deleting the local progress file")
            local_filename = os.path.join(
                "/tmp", self.remote_working_folder + "_progress.json")
            suffixed_local_filename = local_filename + "_" + self.hostname
            os.remove(suffixed_local_filename)

            self.logger.debug(next(output1[self.hostname]["stdout"]))
            self.logger.debug(next(output2[self.hostname]["stdout"]))
            self.logger.debug(next(output1[self.hostname]["stderr"]))
            self.logger.debug(next(output2[self.hostname]["stderr"]))

        except Exception as e:
            self.logger.debug("Hmm file cleanup failed: {}".format(e))
Example #4
def _main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--profile',
                        help="""profile section in ~/.aws/credentials""")
    parser.add_argument('--name',
                        dest='name_tag',
                        help=f"""instance name tag to match
                                 (default: {DEFAULT_NAME_TAG})""")
    parser.add_argument('--ssh-key',
                        metavar='FILE',
                        help=f"""SSH key file to use
                                 (default: {DEFAULT_SSH_KEY})""")
    parser.add_argument('--ssh-user',
                        metavar='USER',
                        help="""remote SSH username (default: ubuntu)""")
    parser.add_argument('--ida-dir',
                        metavar='DIR',
                        help=f"""harmony-one/ida repository path on instance
                                 (default: {DEFAULT_IDA_DIR})""")
    parser.add_argument('--action',
                        metavar='ACTION',
                        dest='actions',
                        help="""action(s) to take, a comma-separated list of:
                                 login, gen, start, stop, send, grep, update""")
    parser.add_argument('--file',
                        metavar='FILE',
                        help="""the file to send (for --action=send)""")
    parser.add_argument('--id',
                        metavar='NODE_ID_IN_REGION',
                        type=int,
                        help="""index of the node within a region to log into
                                 (for --action=login)""")
    parser.add_argument('--query',
                        metavar='QUERY',
                        help="""query to search for on the nodes
                                 (for --action=grep)""")
    parser.add_argument('--t0',
                        type=int,
                        metavar='MILLISECONDS',
                        help=f"""minimum interpacket delay
                                 (default: {DEFAULT_T0})""")
    parser.add_argument('--t1',
                        type=int,
                        metavar='MILLISECONDS',
                        help=f"""maximum interpacket delay
                                 (default: {DEFAULT_T1})""")
    parser.add_argument('--exp-base',
                        type=float,
                        metavar='NUM',
                        help=f"""interpacket delay exponential base
                                 (default: {DEFAULT_EXP_BASE})""")
    parser.add_argument('num_instances',
                        type=int,
                        metavar='N',
                        help="""number of instances""")
    parser.add_argument('regions',
                        nargs='+',
                        metavar='REGION',
                        help="""AWS regions (such as us-west-2)""")
    parser.set_defaults(name_tag=DEFAULT_NAME_TAG,
                        profile='default',
                        ssh_key=DEFAULT_SSH_KEY,
                        ssh_user='******',
                        ida_dir=DEFAULT_IDA_DIR,
                        actions='send',
                        t0=DEFAULT_T0,
                        t1=DEFAULT_T1,
                        exp_base=DEFAULT_EXP_BASE)
    args = parser.parse_args()

    logger.info(f"collecting all IP addresses")

    ips = {}
    all_ips = []

    for region in args.regions:
        proc = subprocess.run([
            'aws', f'--profile={args.profile}', f'--region={region}', 'ec2',
            'describe-instances', f'--filters=Name=tag:Name,'
            f'Values={args.name_tag!r}'
        ],
                              stdout=subprocess.PIPE,
                              check=True)
        r = json.loads(proc.stdout)
        ips1 = []
        for reservation in r['Reservations']:
            for instance in reservation['Instances']:
                if instance['State']['Name'] == 'running':
                    ips1.append(instance['PublicIpAddress'])
        ips1.sort(key=lambda ip: tuple(int(b) for b in ip.split('.')))
        if len(ips1) < args.num_instances:
            raise RuntimeError(
                f"{region} has {len(ips1)} matching instances, "
                f"expected at least {args.num_instances}")
        logger.info(f"total {len(ips1)} running instances in {region}")
        ips1 = ips1[:args.num_instances]
        ips[region] = ips1
        print(f"in {region} will use {len(ips[region])} instances")
        all_ips.extend(ips1)

    ida_dir = sq(args.ida_dir)

    def ssh(host_list, cmd, **kargs):
        try:
            client = ParallelSSHClient(host_list,
                                       user='******',
                                       pkey=f'{sq(args.ssh_key)}')
            output = client.run_command(cmd, **kargs)
            for host in output:
                logger.info(host)
                for line in output[host]['stdout']:
                    logger.info(line)
        except Exception:
            logger.info('cannot connect to all the hosts')
            return

    def ssh1(ip, cmd, **kwargs):
        return subprocess.run([
            'ssh', '-oStrictHostKeyChecking=no',
            '-oUserKnownHostsFile=/dev/null', '-oControlMaster=auto',
            '-oControlPersist=yes', f'-i{sq(args.ssh_key)}',
            f'{args.ssh_user}@{ip}', cmd
        ], **kwargs)

    for action in args.actions.split(','):
        action = action.strip()
        if action == 'login':
            logger.info(f"log into given remote host")
            # note: `region` is whatever the last region in the loop above was
            ip = ips[region][args.id]
            subprocess.run(
                ['ssh', f'-i{sq(args.ssh_key)}', f'{args.ssh_user}@{ip}'])
        elif action == 'gen':
            logger.info(f"generating configurations")

            all_peers_config = io.StringIO()
            peer_configs = [io.StringIO() for ip in all_ips]
            for idx, ip in enumerate(all_ips):
                pk = hashlib.sha1(ip.encode()).hexdigest()
                print(f"{idx} {ip} 20000 10000 {pk} 2", file=all_peers_config)
                print(f"{idx} {ip} 20000 10000 {pk} 0", file=peer_configs[idx])
                for idx2, ip2 in enumerate(all_ips):
                    if idx2 == idx:
                        continue
                    pk2 = hashlib.sha1(ip2.encode()).hexdigest()
                    print(f"{idx2} {ip2} 20000 10000 {pk2} 1",
                          file=peer_configs[idx])
            all_peers_config = all_peers_config.getvalue().encode()
            for idx, peer_config in enumerate(peer_configs):
                peer_configs[idx] = peer_config.getvalue().encode()

            logger.info(f"removing old config files")
            ssh(all_ips, f'rm -f {ida_dir}/configs/*')

            with open('all_peers.txt', 'wb') as f:
                f.write(all_peers_config)
            logger.info(f"copying all-peer configs")
            client = ParallelSSHClient(all_ips,
                                       user='******',
                                       pkey=f'{sq(args.ssh_key)}')
            greenlets = client.copy_file('all_peers.txt',
                                         f'{ida_dir}/configs/all_peers.txt')
            joinall(greenlets, raise_error=True)

            logger.info(f"copying per-peer configs")
            for idx, ip in enumerate(all_ips):
                logger.info(f"... {ip}")
                ssh1(ip,
                     f'cat > {ida_dir}/configs/config.txt',
                     input=peer_configs[idx],
                     check=True)

        elif action == 'start':
            logger.info(f"starting server")
            logger.info(all_ips[1:])
            ssh(
                all_ips[1:], f'cd {ida_dir} && ls -l ida && {{ ./ida '
                '-nbr_config configs/config.txt '
                '-all_config configs/all_peers.txt '
                '> ida.out 2>&1 & ls -l ida.out || :; }')

        elif action == 'stop':
            logger.info(f"stopping server, and remove received files")
            ssh(all_ips, f'killall ida && rm -f {ida_dir}/received/*')

        elif action == 'send':
            if args.file is None:
                parser.error("--file not specified")
            logger.info(f"sending one file: {args.file}")
            with open(args.file, 'rb') as f:
                contents = f.read()
            hex_hash = hashlib.sha1(contents).hexdigest()
            logger.info(f"{len(contents)} byte(s), sha1 {hex_hash}")
            logger.info(f"copying over")
            ssh1(all_ips[0],
                 f'cat > {ida_dir}/{hex_hash}.dat',
                 input=contents,
                 check=True)
            logger.info(f"invoking ida")
            ssh1(all_ips[0], f'cd {ida_dir} && ./ida '
                 f'-nbr_config configs/config.txt '
                 f'-all_config configs/all_peers.txt '
                 f'-broadcast -msg_file {hex_hash}.dat '
                 f'-t0 {args.t0} '
                 f'-t1 {args.t1} '
                 f'-base {args.exp_base} '
                 f'2>&1 | tee ida.out',
                 check=True)

        elif action == 'grep':
            if args.query is None:
                parser.error("--query not specified")
            query = sq(args.query)
            logger.info(f"searching for {args.query} from nodes")
            logger.info(
                "**********************************************************")
            ssh(
                all_ips, f'cd {ida_dir} && ls -l ida.out && '
                f'ag {query} ida.out | cat')

        elif action == 'update':
            logger.info(f"downloading new binary from s3")
            url = 'https://s3.us-east-2.amazonaws.com/harmony-ida-binary/ida'
            ssh(
                all_ips, f'cd {ida_dir} && '
                f'rm -f ida && '
                f'curl -LsS -o ida {url} && '
                f'chmod a+x ida')
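
The sq helper used throughout this example is not shown; from its uses (quoting key paths and queries for shell commands) it is presumably a shell-quoting wrapper, e.g.:

import shlex

def sq(s):
    # Hypothetical definition: quote a value for safe interpolation
    # into shell command lines
    return shlex.quote(str(s))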
Example #5
class ClusterConnector:
    def __init__(
        self,
        workspace,
        cluster_name,
        ssh_key,
        vm_type,
        admin_username="******",
    ):
        """Thin wrapper class around azureml.core.compute.AmlCluster

        Provides parallel ssh objects and helper for master node and all node commands
        and file copies.

        Usage:
        >>> cc = ClusterConnector(workspace, "MyCluster", sshkey, "Standard_ND40rs_v2")
        >>> cc.initialize(min_nodes=0, max_nodes=4, idle_timeout_secs=30)
        >>> cluster = cc.cluster
        >>> [print(node['name']) for node in cc.cluster.list_nodes()]
        """

        self.cluster_name = cluster_name
        self.workspace = workspace
        self.ssh_key = ssh_key
        self.vm_type = vm_type
        self.admin_username = admin_username

        enable_host_logger()
        hlog = logging.getLogger("pssh.host_logger")
        tstr = datetime.now().isoformat(timespec="minutes")
        for handler in list(hlog.handlers):
            if isinstance(handler, logging.StreamHandler):
                hlog.removeHandler(handler)
        os.makedirs("clusterlogs", exist_ok=True)
        self.logfile = "clusterlogs/{}_{}.log".format(self.workspace.name,
                                                      tstr)
        hlog.addHandler(logging.FileHandler(self.logfile))

        self.cluster = None
        self._master_scp = None
        self._master_ssh = None
        self._all_ssh = None

    def initialise(self, min_nodes=0, max_nodes=0, idle_timeout_secs=1800):
        """Initialise underlying AmlCompute cluster instance"""
        self._create_or_update_cluster(min_nodes, max_nodes, idle_timeout_secs)

    def _check_logs_emessage(self, host, port):
        msg = "Remote command failed on {}:{}. For details see {}".format(
            host, port, self.logfile)
        return msg

    def terminate(self):

        print('Attempting to terminate cluster "{}"'.format(
            colored(self.cluster_name, "green")))
        try:
            self.cluster.update(min_nodes=0,
                                max_nodes=0,
                                idle_seconds_before_scaledown=10)
            self.cluster.wait_for_completion()
        except ComputeTargetException as err:
            raise RuntimeError(
                "Failed to terminate cluster nodes ({})".format(err))

        if len(self.cluster.list_nodes()):
            raise RuntimeError(
                "Failed to terminate cluster nodes (nodes still running)")

    @property
    def cluster_nodes(self):
        self.cluster.refresh_state()
        return sorted(self.cluster.list_nodes(), key=lambda n: n["port"])

    def _create_or_update_cluster(self, min_nodes, max_nodes,
                                  idle_timeout_secs):

        try:
            self.cluster = AmlCompute(workspace=self.workspace,
                                      name=self.cluster_name)
            print('Updating existing cluster "{}"'.format(
                colored(self.cluster_name, "green")))
            self.cluster.update(
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                idle_seconds_before_scaledown=idle_timeout_secs,
            )
        except ComputeTargetException:
            print('Creating new cluster "{}"'.format(
                colored(self.cluster_name, "green")))
            cluster_config = AmlCompute.provisioning_configuration(
                vm_size=self.vm_type,
                min_nodes=min_nodes,
                max_nodes=max_nodes,
                idle_seconds_before_scaledown=idle_timeout_secs,
                admin_username=self.admin_username,
                admin_user_ssh_key=self.ssh_key,
                remote_login_port_public_access="Enabled",
            )
            self.cluster = AmlCompute.create(self.workspace, self.cluster_name,
                                             cluster_config)

        self.cluster.wait_for_completion()

        if len(self.cluster_nodes) < min_nodes:
            sleep(30)
            if len(self.cluster_nodes) < min_nodes:
                raise RuntimeError("Failed to provision sufficient nodes")

    def _copy_nodefile_to_nodes(self):

        if len(self.cluster_nodes) == 1:
            cprint("Single node cluster -- skipping IB config", "yellow")
            return

        print("Collecting cluster IB info")

        outputs = self._all_ssh.run_command(
            r'ifconfig ib0 | grep -oe "inet[^6][adr: ]*[0-9.]*" | cut -d" " -f2',
            shell="bash -c",
        )
        self._all_ssh.join(outputs)

        ibaddrs = []
        for output in outputs:
            host = output.host
            port = output.client.port
            if output.exit_code != 0:
                print(list(output.stdout))
                print(list(output.stderr))
                raise RuntimeError("Failed to get IB ip for {}:{}".format(
                    host, port))
            try:
                ibaddr = list(output.stdout)[0].split()[0]
            except IndexError:
                raise RuntimeError("Failed to get IB ip for {}:{} - "
                                   "No ib interface found!".format(host, port))
            print("Mapping {}:{} -> {}".format(host, port, ibaddr))
            if port == self._master_scp.port:
                cprint("IB Master: {}".format(ibaddr), "green")
                ibaddrs = [ibaddr] + ibaddrs
            else:
                ibaddrs.append(ibaddr)

        with NamedTemporaryFile(delete=False, mode="wt") as nfh:
            self.nodefile = nfh.name
            for addr in ibaddrs:
                nfh.write("{}\n".format(addr))

        self.ibaddrs = ibaddrs
        self.copy_to_all_nodes(self.nodefile, "./nodefile")

    def _create_cluster_ssh_conns(self):

        hostips = [n["publicIpAddress"] for n in self.cluster_nodes]
        hostconfigs = [HostConfig(port=n["port"]) for n in self.cluster_nodes]

        self._all_ssh = ParallelSSHClient(hostips,
                                          host_config=hostconfigs,
                                          user=self.admin_username)

        self._master_ssh = ParallelSSHClient(hostips[:1],
                                             host_config=hostconfigs[:1],
                                             user=self.admin_username)

        self._master_scp = SSHClient(hostips[0],
                                     port=hostconfigs[0].port,
                                     user=self.admin_username)

    def copy_to_all_nodes(self, source, dest):

        copy_jobs = self._all_ssh.copy_file(source, dest)
        joinall(copy_jobs, raise_error=True)

    def copy_to_master_node(self, source, dest):

        self._master_scp.copy_file(source, dest)

    def copy_from_master_node(self, source, dest):

        self._master_scp.copy_remote_file(source, dest)

    def run_on_all_nodes(self, command):

        outputs = self._all_ssh.run_command(command, shell="bash -c")
        self._all_ssh.join(outputs, consume_output=True)

        for output in outputs:
            if int(output.exit_code) != 0:
                host = output.host
                port = output.client.port
                raise RuntimeError(self._check_logs_emessage(host, port))

    def run_on_master_node(self, command):

        outputs = self._master_ssh.run_command(command, shell="bash -c")
        self._master_ssh.join(outputs)

        for output in outputs:
            if int(output.exit_code) != 0:
                host = output.host
                port = output.client.port
                raise RuntimeError(self._check_logs_emessage(host, port))

    def attempt_termination(self):
        try:
            self.terminate()
        except RuntimeError as err:
            print(colored("ERROR: {}\n\n", "red", attrs=["bold"]).format(err))
            self.warn_unterminated()

    def warn_unterminated(self):
        print(
            colored("WARNING: {}", "red", attrs=["bold"]).format(
                colored(
                    "Cluster {} is still running - terminate manually to avoid "
                    "additional compute costs".format(
                        colored(self.cluster_name, "green")),
                    "red",
                )))
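
The ClusterConnector example assumes roughly the following imports, reconstructed from the names it uses (exact module paths may differ between library versions):

import logging
import os
from datetime import datetime
from tempfile import NamedTemporaryFile
from time import sleep

from azureml.core.compute import AmlCompute
from azureml.exceptions import ComputeTargetException
from gevent import joinall
from pssh.clients import ParallelSSHClient, SSHClient
from pssh.config import HostConfig
from pssh.utils import enable_host_logger
from termcolor import colored, cprint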
Example #6
import os
from gevent import joinall
from datetime import datetime
from pssh.clients import ParallelSSHClient

# write a ~10 MB test file locally
with open('file_copy', 'wb') as fh:
    for _ in range(2000000):
        fh.write(b'asdfa')

fileinfo = os.stat('file_copy')
client = ParallelSSHClient(['localhost'])
now = datetime.now()
cmd = client.copy_file('file_copy', '/tmp/file_copy')
joinall(cmd, raise_error=True)
taken = datetime.now() - now
mb_size = fileinfo.st_size / (1024000.0)
rate = mb_size / taken.total_seconds()
print("File size %sMB transfered in %s, transfer rate %s MB/s" %
      (mb_size, taken, rate))
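
pssh also provides scp_send (seen in Example #3), which on pssh 2.x can be faster than the SFTP-based copy_file; a variant of the same measurement, assuming a pssh 2.x client:

# SCP variant of the same benchmark
now = datetime.now()
greenlets = client.scp_send('file_copy', '/tmp/file_copy_scp')
joinall(greenlets, raise_error=True)
taken = datetime.now() - now
print("SCP transfer of %sMB took %s (%s MB/s)" %
      (mb_size, taken, mb_size / taken.total_seconds()))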
Example #7
def copy(self, fname):
    """Copy a file to all the instances."""
    client = ParallelSSHClient([i.ip for i in self.hosts],
                               timeout=60, pkey=self.keysfile)
    logger.debug(f'Copy file {fname}.')
    cmds = client.copy_file(fname, basename(fname))
    joinall(cmds, raise_error=True)