Exemple #1
0
class ClusterShell:
    """
    ClusterShell lets you run commands on multiple EC2 instances.

    ClusterShell takes in information about a set of EC2 instances that exist and allows you to run commands on
    some or all of the nodes. It also has convenience methods for copying files between the local filesystem and
    the cluster.
    """


    def __init__(self, username, master_ip, worker_ips, ssh_key_path, use_bastion=False,
                 wait_for_ssh=True, wait_for_ssh_timeout=120):
        """Create SSH connections to the master and worker nodes.

        Args:
            username: The username used to ssh to the instance. Often 'ubuntu' or 'ec2-user'
            master_ip: A single IP for the master node. Typically should be the public IP if the location this code is
                       running is outside of the VPC and the private IP if running from another EC2 node in the same
                       VPC. In many cases, the distinction between master and workers is arbitrary. If use_bastion is
                       True, the master node will be the bastion host.
            worker_ips: A possibly empty list (or tuple/set) of ips for the worker nodes. A single ip may be passed as
                        a plain string, and None is treated as "no workers".
            ssh_key_path: The path to the SSH key required to SSH into the EC2 instances. Often ~/.ssh/something.pem
            use_bastion (bool): Whether or not to use the master node as the bastion host for SSHing to worker nodes.
            wait_for_ssh (bool): If true, block until commands can be run on all instances. This can be useful when you
                                 are launching EC2 instances, because the instances may be in the RUNNING state but the
                                 SSH daemon may not yet be running.
            wait_for_ssh_timeout: Number of seconds to spend trying to run commands on the instances before failing.
                                  This is NOT the SSH timeout, this upper bounds the amount of time spent retrying
                                  failed SSH connections. Only used if wait_for_ssh=True.
        """
        # Normalise worker_ips to a fresh list so the caller's container is never aliased
        # and so tuples/sets are not mistaken for a single host.
        if worker_ips is None:
            worker_ips = []
        elif isinstance(worker_ips, (list, tuple, set, frozenset)):
            worker_ips = list(worker_ips)
        else:
            # Any other single value (typically a str) is one worker.
            worker_ips = [worker_ips]

        self._username = username
        self._master_ip = master_ip
        self._worker_ips = worker_ips
        self._all_ips = [self._master_ip] + self._worker_ips
        self.use_bastion = use_bastion

        connect_kwargs = {
            "key_filename": [os.path.expanduser(ssh_key_path)],
            # The original comment here pointed at a "NOTE 1" that lives elsewhere in the file.
            "banner_timeout": 30
        }

        self._master_conn = Connection(user=self._username,
                                       host=self._master_ip,
                                       forward_agent=True,
                                       connect_kwargs=connect_kwargs)

        worker_conns = []
        for worker_ip in self._worker_ips:
            if self.use_bastion:
                # Each worker gets its own gateway connection through the master node.
                c = Connection(user=self._username,
                               host=worker_ip,
                               connect_kwargs=connect_kwargs,
                               gateway=Connection(user=self._username,
                                                  host=self._master_ip,
                                                  forward_agent=True,
                                                  connect_kwargs=connect_kwargs))
            else:
                c = Connection(user=self._username, host=worker_ip, connect_kwargs=connect_kwargs)

            worker_conns.append(c)

        self._individual_worker_conns = worker_conns
        self._worker_conns = ThreadingGroup.from_connections(worker_conns)
        self._all_conns = ThreadingGroup.from_connections([self._master_conn] + worker_conns)

        if wait_for_ssh:
            self.wait_for_ssh_ready(wait_timeout=wait_for_ssh_timeout)

    def wait_for_ssh_ready(self, wait_timeout=120):
        """Repeatedly try to run commands on all instances until successful or until timeout is reached.

        Runs ``hostname`` on every node via ``run_on_all``. After each ``GroupException`` it sleeps one second
        and retries.

        Args:
            wait_timeout: Upper bound, in seconds, on the total time spent retrying.

        Raises:
            RuntimeError: If a retry fails after more than ``wait_timeout`` seconds have elapsed. The message
                          lists every exception collected along the way.
        """
        start_time = time.time()
        failures = []
        while True:
            try:
                self.run_on_all("hostname", hide=True)
                break
            except fabric2.exceptions.GroupException as e:
                failures.append(e)

                elapsed_time = time.time() - start_time
                if elapsed_time > wait_timeout:
                    # Distinct loop name so the caught exception `e` is not shadowed.
                    exceptions_str = "\n".join(str(failure) for failure in failures)
                    raise RuntimeError(
                            f"[ClusterShell.wait_for_ssh_ready] Unable to establish an SSH connection after "
                            f"{wait_timeout} seconds. On EC2 this is often due to a problem with the security group, "
                            f"although there are many potential causes."
                            f"\nExceptions encountered:\n{exceptions_str}")

                secs_to_timeout = int(wait_timeout - elapsed_time)
                print(f"[ClusterShell.wait_for_ssh_ready] Exception when SSHing to instances. Retrying until timeout "
                      f"in {secs_to_timeout} seconds")
                time.sleep(1)

    def run_local(self, cmd):
        """Run a shell command on the local machine.

        Will wait for the command to finish and raise an exception if the return code is non-zero.

        Args:
            cmd: The shell command to run

        Returns:
             The stdout of the command as a byte string.
        """
        return subprocess.check_output(shlex.split(cmd))


    def run_on_master(self, cmd, **kwargs):
        """Execute a shell command on the master node only.

        Args:
            cmd: The shell command to run
            kwargs: Forwarded unchanged to the master connection's ``run``. See
                    http://docs.fabfile.org/en/2.4/api/connection.html#fabric.connection.Connection.run

        Returns:
            Result: Whatever the connection's ``run`` returns, i.e. an invoke Result object.
                    `http://docs.pyinvoke.org/en/latest/api/runners.html#invoke.runners.Result`
        """
        master = self._master_conn
        return master.run(cmd, **kwargs)


    def run_on_all(self, cmd, **run_kwargs):
        """Execute a shell command on the master and on every worker.

        In bastion mode, once the number of workers reaches ``MAX_CONNS_PER_GROUP - 1`` the command is dispatched
        through ``_run_on_all_workaround`` in smaller groups; otherwise a single group covering all nodes is used.

        Args:
            cmd: The shell command to run
            run_kwargs: Keyword args to pass to fabric.run(). Fabric passes them through to Invoke, which are
                        documented here: http://docs.pyinvoke.org/en/latest/api/runners.html#invoke.runners.Runner.run.
                        Potentially useful args:
                            hide=True will prevent run output from being output locally

        Returns:
            List of invoke.Result objects. Order is not guaranteed. http://docs.pyinvoke.org/en/latest/api/runners.html#invoke.runners.Result
        """
        # Short-circuit keeps the size check (and the constant lookup) bastion-only.
        needs_batches = self.use_bastion and len(self._worker_ips) >= (MAX_CONNS_PER_GROUP - 1)
        if needs_batches:
            return list(self._run_on_all_workaround(cmd, MAX_CONNS_PER_GROUP, **run_kwargs))

        per_connection = self._all_conns.run(cmd, **run_kwargs)
        return list(per_connection.values())


    # TODO: Confirm this is required with (10+ nodes)
    def _run_on_all_workaround(self, cmd, group_size, **run_kwargs):
        """Run a command on every node using several small connection groups.

        Workers are batched, in order, into groups of at most ``group_size`` connections. The master joins the
        last batch when that batch has spare room; otherwise (all batches full, or no workers at all) the command
        is first run on the master by itself via ``run_on_master``.

        NOTE(review): in that last case the master is run outside a group, so a failure presumably surfaces
        as a different exception type than a group failure would - confirm callers cope with both.

        Args:
            cmd: The shell command to run
            group_size: Maximum number of connections placed in one group
            run_kwargs: Forwarded to every ``run`` call

        Returns:
            A flat list with one result per node.
        """
        node_count = len(self._worker_conns) + 1
        print(f'{node_count} Nodes')

        batches = []
        pending = []
        for index, conn in enumerate(self._individual_worker_conns):
            # A new batch starts every group_size workers (but not before the very first one).
            if index % group_size == 0 and index != 0:
                batches.append(ThreadingGroup.from_connections(pending))
                pending = []
            pending.append(conn)

        results = []
        if len(pending) not in (0, group_size):
            # The trailing batch has room left: the master rides along with it.
            pending.append(self._master_conn)
            batches.append(ThreadingGroup.from_connections(pending))
        else:
            if pending:
                batches.append(ThreadingGroup.from_connections(pending))
            results.append(self.run_on_master(cmd, **run_kwargs))

        for batch in batches:
            results.extend(batch.run(cmd, **run_kwargs).values())

        return results


    def copy_from_master_to_local(self, remote_path, local_path):
        """Copy a file from the master node to the local node.

        Args:
            remote_path: The path of the file on the master node. If not an absolute path, will be relative to the
                         working directory, typically the home directory. Will not expand tilde (~).
            local_path: The path to save the file to on the local file system.
        """
        local_abs_path = Path(local_path).absolute()
        return self._master_conn.get(remote_path, local_abs_path)


    def copy_from_all_to_local(self, remote_abs_path, local_path):
        """Copy files from all nodes to the local filesystem.

        There will be one directory per node containing the file.

        Args:
            remote_abs_path: The absolute path of the file to download. Can be a directory or a cp/scp string including
                             wildcards
            local_path: The absolute path of a directory on the local filesystem to download the files into. The path
                        must not point to a file.
        """
        if self.use_bastion:
            raise NotImplementedError("Copying has not yet been implemented for bastion mode. Please open a ticket at "
                                      "https://github.com/armandmcqueen/ec2-cluster if you would like to see this "
                                      "feature implemented")
        local_abs_path = Path(local_path).absolute()

        if not local_abs_path.exists():
            local_abs_path.mkdir(parents=True)
        else:
            if local_abs_path.is_file():
                raise RuntimeError(f'[ClusterShell.copy_from_all_to_local] local_path points to a file: '
                                   f'{local_abs_path}')

        master_dir = local_abs_path / "0"
        master_dir.mkdir()
        master_ip_path = master_dir / "ip.txt"

        with open(master_ip_path, 'w') as f:
            f.write(self.master_ip)

        self.run_local(f'scp '
                       f'-o StrictHostKeyChecking=no '
                       f'-o "UserKnownHostsFile /dev/null" '
                       f'-o "LogLevel QUIET" '
                       f'-r '
                       f'{self._username}@{self.master_ip}:{remote_abs_path} {master_dir}/')

        # Create and populate staging folder for each worker's data
        for ind, worker_ip in enumerate(self._worker_ips):
            worker_id = ind + 1
            worker_node_dir = local_abs_path / str(worker_id)
            worker_node_dir.mkdir()
            worker_ip_path = worker_node_dir / "ip.txt"
            with open(worker_ip_path, 'w') as f:
                f.write(worker_ip)
            self.run_local(f'scp '
                           f'-o StrictHostKeyChecking=no '
                           f'-o "UserKnownHostsFile /dev/null" '
                           f'-o "LogLevel QUIET" '
                           f'-r '
                           f'{self._username}@{worker_ip}:{remote_abs_path} {worker_node_dir}/')



    def copy_from_local_to_master(self, local_path, remote_path):
        """Copy a file from the local filesystem to the master node.

        Args:
            local_path: The path of the file to send to the master node
            remote_path: The path where the file will be saved on the master node. Does not expand tilde (~), but if not
                         an absolute path, will usually interpret the path as relative to the home directory.
        """
        local_abs_path = Path(local_path).absolute()
        return self._master_conn.put(local_abs_path, remote_path)

    def copy_from_local_to_all(self, local_path, remote_path):
        """Copy a file from the local filesystem to every node in the cluster.

        Args:
            local_path: The path of the file to send to the master and worker nodes
            remote_path: The path where the file will be saved on the master and worker nodes. Does not expand tilde (~),
                         but if not an absolute path, will usually interpret the path as relative to the home directory.
        """
        if self.use_bastion:
            raise NotImplementedError("Copying has not yet been implemented for bastion mode. Please open a ticket at "
                                      "https://github.com/armandmcqueen/ec2-cluster if you would like to see this "
                                      "feature implemented")

        local_abs_path = Path(local_path).absolute()
        self.copy_from_local_to_master(local_abs_path, remote_path)
        for worker_conn in self._individual_worker_conns:
            worker_conn.put(local_abs_path, remote_path)

    @property
    def username(self):
        """SSH username that was passed to the constructor."""
        name = self._username
        return name

    @property
    def master_ip(self):
        """IP of the master node that was passed to the constructor."""
        ip = self._master_ip
        return ip

    @property
    def non_master_ips(self):
        """All IPs other than the master node. May be an empty list.

        A new list is returned on every access so callers cannot modify the
        shell's internal worker list by accident.
        """
        return list(self._worker_ips)

    @property
    def all_ips(self):
        """A list of master and worker IPs, master first.

        A new list is returned on every access so callers cannot modify the
        shell's internal list by accident.
        """
        return list(self._all_ips)
Exemple #2
0
class gce_api:
    
    URI = 'https://www.googleapis.com'
    
    CommonCalls = {'machineTypeList': 'https://www.googleapis.com/compute/v1/projects/{project}/zones/{zone}/machineTypes',
                   'imagesList':      'https://www.googleapis.com/compute/v1/projects/{project}/global/images',
                   'projectInfo':     'https://www.googleapis.com/compute/v1/projects/{project}',
                   'firewallList':    'https://www.googleapis.com/compute/v1/projects/{project}/global/firewalls',
                   'firewallResource':'https://www.googleapis.com/compute/v1/projects/{project}/global/firewalls/{firewallName}', 
                   'instances':       'https://www.googleapis.com/compute/v1/projects/{project}/zones/{zone}/instances',
                   'serialPort':      'https://www.googleapis.com/compute/v1/projects/{project}/zones/{zone}/instances/{instanceName}/serialPort',
                   'instanceInfo':    'https://www.googleapis.com/compute/v1/projects/{project}/zones/{zone}/instances/{instanceName}'
    }
    
    def __init__(self,json_key,properties,storage_key):
        """Set up credentials, authorised HTTP sessions and a storage client.

        Args:
            json_key: Path to the service-account JSON file used for the cloud-platform scope.
            properties: Dict of settings. Must contain ``keyDir`` and ``instanceName``; the derived entries
                        ``keyFile`` and ``pubKeyFile`` are written back into this same dict.
            storage_key: Path to the service-account JSON file used for storage access. It is also exported
                         through the ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
        """
        self.properties = properties
        key_file = str(os.path.join(self.properties["keyDir"], self.properties["instanceName"]))
        self.properties['keyFile'] = key_file
        self.properties['pubKeyFile'] = self.properties["keyFile"] + ".pub"

        compute_scopes = ['https://www.googleapis.com/auth/cloud-platform']
        storage_scopes = ['https://www.googleapis.com/auth/devstorage.full_control']

        self.credentials = service_account.Credentials.from_service_account_file(json_key)
        self.credentials_storage = service_account.Credentials.from_service_account_file(storage_key)
        self.scoped_credentials = self.credentials.with_scopes(compute_scopes)
        self.storage_credentials = self.credentials_storage.with_scopes(storage_scopes)

        self.authed_session = AuthorizedSession(self.scoped_credentials)
        self.storage_session = AuthorizedSession(self.storage_credentials)

        # storage.Client() is created without arguments here, so the key file is supplied
        # through the environment variable instead of being passed in directly.
        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = storage_key
        self.storage_client = storage.Client()
    
   
    def waitUntilDone(func):
        """Decorator: call ``func`` repeatedly until its response reports ``status == "DONE"``.

        The wrapped method is called once; while the response is a mapping that has a ``status`` entry other
        than ``"DONE"``, the response is displayed, the wrapper sleeps 0.5 s and calls ``func`` again with the
        same arguments.

        Returns (from the wrapper):
            The first response whose status is ``"DONE"``, or ``None`` if a response is ``None`` or carries
            no ``status`` entry.
        """
        import functools

        @functools.wraps(func)
        def wrapper(self,*args,**kwargs):
            response = func(self,*args,**kwargs)
            # Test for None *before* touching the mapping; func may return None on failure.
            while response is not None and 'status' in response and response['status'] != "DONE":
                display(response)
                time.sleep(0.5)
                response = func(self,*args,**kwargs)

            if response is None or 'status' not in response:
                return None
            return response
        return wrapper
    
    def get(self,*args,**kwargs):
        """Record "get" as the current HTTP method, then dispatch through selectRunType."""
        self.method = "get"
        outcome = self.selectRunType(*args,**kwargs)
        return outcome

    def post(self,*args,**kwargs):
        """Record "post" as the current HTTP method, then dispatch through selectRunType."""
        self.method = "post"
        outcome = self.selectRunType(*args,**kwargs)
        return outcome
    
    def delete(self,*args,**kwargs):
        """Record "delete" as the current HTTP method, then dispatch through selectRunType."""
        self.method = "delete"
        outcome = self.selectRunType(*args,**kwargs)
        return outcome
    
    
    def selectRunType(self,*args,**kwargs):
        """Route a request to ``persistent`` or ``runRequest``.

        The ``wait`` keyword (default False) is removed from ``kwargs``; a truthy value selects
        ``persistent``, anything else selects ``runRequest``. All remaining arguments are forwarded.
        """
        wait = kwargs.pop('wait',False)
        runner = self.persistent if wait else self.runRequest
        return runner(*args,**kwargs)
        
       
    def runRequest(self,*args,**kwargs):
        """Send one request to a Compute Engine endpoint listed in ``CommonCalls``.

        Args:
            args: ``args[0]`` is the key into ``gce_api.CommonCalls``; the URL template is filled in from
                  ``self.properties``.
            kwargs: Forwarded to the session method named by ``self.method``. The optional ``properties``
                    keyword is consumed here: when given (and not None) it replaces ``self.properties``
                    for this and all later calls.

        Returns:
            The decoded JSON body when the status code is 200, otherwise ``None`` (after displaying
            the status code).
        """
        properties = kwargs.pop('properties',None)
        if properties is not None:
            self.properties = properties

        call = gce_api.CommonCalls[args[0]].format(**self.properties)
        send = getattr(self.authed_session,self.method)
        response = send(call,**kwargs)

        if response.status_code == 200:
            return json.loads(response.text)

        display("Response code was {}. It might not have worked".format(response.status_code))
        return None
        
    def request_storage(self,url, payload='None', method='get'):
        """Send a request through the storage session.

        Args:
            url: Target URL.
            payload: JSON-serialisable body. The default string ``'None'`` (or a real ``None``) means
                     "no body", in which case the request is sent without a ``json`` argument.
            method: Name of the session method to call, e.g. ``'get'`` or ``'post'``.

        Returns:
            Whatever the session method returns.
        """
        send = getattr(self.storage_session,method)
        # Compare by value: `is 'None'` depends on string interning and is not reliable.
        no_payload = payload is None or (isinstance(payload, str) and payload == 'None')
        if no_payload:
            return send(url)
        return send(url,json=payload)
        
        
    @waitUntilDone
    def persistent(self,*args,**kwargs):
        """Forward the request to runRequest; the waitUntilDone decorator wraps the call and decides what is returned."""
        response = self.runRequest(*args,**kwargs)
        return response
    
    def create_bucket(self,name):
        """Ask the storage client to create a bucket called ``name`` and return its result."""
        client = self.storage_client
        return client.create_bucket(name)
    
    
    def generateSSHKey(self):
        """Generate a fresh RSA key pair locally and load the public key into ``self.pub``.

        Reads ``keyFile``, ``username`` and ``pubKeyFile`` from ``self.properties``. Any existing public key
        file is removed first; ``ssh-keygen`` is fed ``yes`` on stdin so that it overwrites an existing
        private key without waiting for input.
        """
        display('Generating ssh key...')
        c = Connection('localhost')
        try:
            # The previous pattern "{keyFile}.*" was inside double quotes, so the shell never expanded it
            # and nothing was removed. Remove the stale public key explicitly instead.
            c.local('rm -f "{keyFile}.pub"'.format(**self.properties))
            c.local("echo 'yes' | ssh-keygen -t rsa -f {keyFile} -C {username}  -N '' ".format(**self.properties),hide='out')
        finally:
            # Always release the connection object, even if a local command fails.
            c.close()

        with open (self.properties['pubKeyFile'],'r') as f:
            display('Opening {}'.format(self.properties['pubKeyFile']))
            self.pub = f.read().strip()
            
    def setConnection(self):
        """Build ``self.connection`` from the ``ip``, ``username`` and ``keyFile`` properties.

        The connection object is only constructed here; this method does not call ``open()`` on it.
        """
        props = self.properties
        target_host = props['ip']
        login_user = props['username']
        ssh_options = {"key_filename": props['keyFile']}
        self.connection = Connection(host=target_host, user=login_user, connect_kwargs=ssh_options)
        
    def setSSHPort(self,ip='',inOffice='True'):
        """(Re)create the ``ssh`` firewall rule allowing TCP port 22.

        An existing rule named ``ssh`` is deleted first, and the method polls the firewall list every 0.5 s
        until it is gone before posting the new rule. The rule targets the ``ssh`` tag.

        Args:
            ip: Extra source range to allow in addition to the built-in ``151.157.0.0/16`` range. Required
                when ``inOffice`` is false.
            inOffice: Truthy means only the built-in range is allowed. Strings are accepted for backward
                      compatibility; ``''``, ``'false'``, ``'0'`` and ``'no'`` (any case) count as false.

        Raises:
            ValueError: If ``inOffice`` is false and no ``ip`` was given (checked before anything is deleted).
            RuntimeError: If a list, delete or create request returns no response.
        """
        if isinstance(inOffice, str):
            # The default is the string 'True'; without this, the string 'False' would count as truthy.
            in_office = inOffice.strip().lower() not in ('', 'false', '0', 'no')
        else:
            in_office = bool(inOffice)

        ipList = ["151.157.0.0/16",]
        if not in_office:
            if not ip:
                raise ValueError("setSSHPort: an ip/source range is required when inOffice is false")
            ipList.append(ip)

        def firewall_names():
            """Return the names of all firewall rules currently listed for the project."""
            info = self.get('firewallList')
            if info is None:
                raise RuntimeError("setSSHPort: could not list firewall rules")
            # Tolerate a response without an 'items' entry (treated as "no rules").
            return [item['name'] for item in info.get('items', [])]

        ssh = {
              "name" : "ssh",
              "allowed": [
                {
                  "IPProtocol": "tcp",
                  "ports": [
                    "22",
                  ]
                }
              ],
              "sourceRanges": ipList,
              "targetTags": [
                "ssh"
              ]
            }

        if 'ssh' in firewall_names():
            self.properties['firewallName'] = 'ssh'
            info = self.delete('firewallResource')
            if info is None:
                raise RuntimeError("setSSHPort: deleting the existing 'ssh' firewall rule failed")
            display(info['operationType'],info['targetLink'])

        # Waiting until the firewall has been deleted
        while 'ssh' in firewall_names():
            time.sleep(0.5)

        # Actually creating the firewall
        info = self.post('firewallList',json=ssh)
        if info is None:
            raise RuntimeError("setSSHPort: creating the 'ssh' firewall rule failed")
        display(info['operationType'],info['targetLink'])
        
        
    def runScript(self,file,getResults=False,out='results.txt'):
        """Upload a script over ``self.connection``, make it executable and run it.

        Args:
            file: Local path of the script. It is uploaded with ``put`` and then invoked remotely by its
                  base name.
            getResults: When truthy, download the remote file ``results.txt`` afterwards.
            out: Local destination for the downloaded results.
        """
        conn = self.connection
        conn.put(file)

        script = os.path.basename(file)
        for command in ('chmod +x {}'.format(script), './{}'.format(script)):
            conn.run(command)

        if getResults:
            conn.get("results.txt",out)