예제 #1
0
def test(endpoint_id=None, tasks=10, hostname=None, port=None):
    tasks_rq = RedisQueue(f'task_{endpoint_id}', hostname)
    results_rq = RedisQueue(f'results', hostname)
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    while True:
        try:
            x = results_rq.get(timeout=1)
        except:
            print("No more results left")
            break

    tasks_rq.connect()
    results_rq.connect()
    start = time.time()
    for i in range(tasks):
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({'duration':0})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        tasks_rq.put(f"0{i}", payload)

    for i in range(tasks):
        res = results_rq.get(timeout=1)
        print("Result : ", res)

    delta = time.time() - start
    print("Time to complete {} tasks: {:8.3f} s".format(tasks, delta))
    print("Throughput : {:8.3f} Tasks/s".format(tasks / delta))
    return delta
def dont_run_yet(endpoint_id=None, tasks=10, duration=1, hostname=None):
    # tasks_rq = EndpointQueue(f'task_{endpoint_id}', hostname)
    tasks_channel = RedisPubSub(hostname)
    tasks_channel.connect()
    redis_client = tasks_channel.redis_client
    redis_client.ping()
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    start = time.time()
    task_ids = {}
    for i in range(tasks):
        time.sleep(duration)
        task_id = str(uuid.uuid4())
        print("Task_id : ", task_id)
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({"duration": duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "RAW"
        task = Task(redis_client, task_id, container_id, serializer="", payload=payload)
        task.endpoint = endpoint_id
        task.status = TaskState.WAITING_FOR_EP
        # tasks_rq.enqueue(task)
        tasks_channel.put(endpoint_id, task)
        task_ids[i] = task_id

    d1 = time.time() - start
    print(f"Time to launch {tasks} tasks: {d1:8.3f} s")

    delay = 5
    print(f"Sleeping {delay} seconds")
    time.sleep(delay)
    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        task_id = task_ids[i]
        print("Task_id : ", task_id)
        task = Task.from_id(redis_client, task_id)
        # TODO: wait for task result...
        time.sleep(duration)
        try:
            result = fxs.deserialize(task.result)
            print(f"Result : {result}")
        except Exception as e:
            print(f"Task failed with exception:{e}")
            pass

    delta = time.time() - start
    print(f"Time to complete {tasks} tasks: {delta:8.3f} s")
    print(f"Throughput : {tasks / delta:8.3f} Tasks/s")
    return delta
예제 #3
0
def serialize_fx_inputs(*args, **kwargs):
    from funcx.serialize import FuncXSerializer
    fx_serializer = FuncXSerializer()
    ser_args = fx_serializer.serialize(args)
    ser_kwargs = fx_serializer.serialize(kwargs)
    payload = fx_serializer.pack_buffers([ser_args, ser_kwargs])
    return payload
예제 #4
0
def serialize_fx_inputs(*args, **kwargs):
    """Pack and serialize inputs
    """
    fx_serializer = FuncXSerializer()
    ser_args = fx_serializer.serialize(args)
    ser_kwargs = fx_serializer.serialize(kwargs)
    payload = fx_serializer.pack_buffers([ser_args, ser_kwargs])
    return payload
예제 #5
0
class Batch:
    """Utility class for creating batch submission in funcX"""
    def __init__(self):
        self.tasks = []
        self.fx_serializer = FuncXSerializer()

    def add(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Add an function invocation to a batch submission

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        None
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'function': function_id,
            'payload': payload
        }

        self.tasks.append(data)

    def prepare(self):
        """Prepare the payloads to be post to web service in a batch

        Parameters
        ----------

        Returns
        -------
        payloads in dictionary, Dict[str, list]
        """
        data = {'tasks': []}

        for task in self.tasks:
            new_task = (task['function'], task['endpoint'], task['payload'])
            data['tasks'].append(new_task)

        return data
예제 #6
0
def dont_run_yet(endpoint_id=None, tasks=10, duration=1, hostname=None):
    tasks_rq = EndpointQueue(f"task_{endpoint_id}", hostname)
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    tasks_rq.connect()
    start = time.time()
    task_ids = {}
    for i in range(tasks):
        task_id = str(uuid.uuid4())
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({"duration": duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "RAW"
        task = Task(tasks_rq.redis_client,
                    task_id,
                    container_id,
                    serializer="",
                    payload=payload)
        tasks_rq.enqueue(task)
        task_ids[i] = task_id

    d1 = time.time() - start
    print(f"Time to launch {tasks} tasks: {d1:8.3f} s")

    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        task_id = task_ids[i]
        task = Task.from_id(tasks_rq.redis_client, task_id)
        # TODO: wait for task result...
        time.sleep(2)
        print(f"Result: {task.result}")
        # res = results_rq.get('result', timeout=300)
        # print("Result : ", res)

    delta = time.time() - start
    print(f"Time to complete {tasks} tasks: {delta:8.3f} s")
    print(f"Throughput : {tasks / delta:8.3f} Tasks/s")
    return delta
예제 #7
0
def test(endpoint_id=None, tasks=10, duration=1, hostname=None, port=None):
    tasks_rq = RedisQueue(f'task_{endpoint_id}', hostname)
    results_rq = RedisQueue('results', hostname)
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    tasks_rq.connect()
    results_rq.connect()

    while True:
        try:
            _ = results_rq.get(timeout=1)
        except Exception:
            print("No more results left")
            break

    start = time.time()
    for i in range(tasks):
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({'duration': duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "odd" if i % 2 else "even"
        tasks_rq.put(f"0{i};{container_id}", payload)

    d1 = time.time() - start
    print("Time to launch {} tasks: {:8.3f} s".format(tasks, d1))

    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        _ = results_rq.get(timeout=300)
        # print("Result : ", res)

    delta = time.time() - start
    print("Time to complete {} tasks: {:8.3f} s".format(tasks, delta))
    print("Throughput : {:8.3f} Tasks/s".format(tasks / delta))
    return delta
예제 #8
0
    results_queue = Queue()
    #    set_file_logger('executor.log', name='funcx_endpoint', level=logging.DEBUG)
    htex = HighThroughputExecutor(interchange_local=True, passthrough=True)

    htex.start(results_passthrough=results_queue)
    htex._start_remote_interchange_process()
    fx_serializer = FuncXSerializer()

    for i in range(10):
        task_id = str(uuid.uuid4())
        args = (i, )
        kwargs = {}

        fn_code = fx_serializer.serialize(double)
        ser_code = fx_serializer.pack_buffers([fn_code])
        ser_params = fx_serializer.pack_buffers(
            [fx_serializer.serialize(args),
             fx_serializer.serialize(kwargs)])

        payload = Task(task_id, "RAW", ser_code + ser_params)
        f = htex.submit_raw(payload.pack())
        time.sleep(0.5)

    for i in range(10):
        result_package = results_queue.get()
        # print("Result package : ", result_package)
        r = pickle.loads(result_package)
        result = fx_serializer.deserialize(r["result"])
        print(f"Result:{i}: {result}")
예제 #9
0
print("Started the endpoint {}".format(endpoint_name))
print("Wating 10 seconds for the endpoint to start")
time.sleep(10)

# Connect to the task and result redis queue
endpoint_id = args.endpoint_id
tasks_rq = RedisQueue(f'task_{endpoint_id}', args.redis_hostname)
results_rq = RedisQueue(f'results', args.redis_hostname)
tasks_rq.connect()
results_rq.connect()
print("Redis queue connected")

# Create an instance of funcx serializer and serialize the function
fxs = FuncXSerializer()
ser_code = fxs.serialize(sleep)
fn_code = fxs.pack_buffers([ser_code])
print("Code serialized")

# Define the test function
def test(tasks=10, durs=[5, 10, 20], timeout=None):
    # Make sure there is no previous result left
    while True:
        try:
            x = results_rq.get(timeout=1)
        except:
            print("No more results left")
            break
    start_submit = time.time()
    time_table = {}
    rounds = 3
    #counts = [[1, 5, 20], [5, 20, 1], [20, 1, 5]]
예제 #10
0
class FuncXClient(throttling.ThrottledBaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = 'funcx_sdk_tokens.json'
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://api.funcx.org/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`:
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://api.funcx.org/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.func_table = {}
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename))

        # TODO: if fx_authorizer is given, we still need to get an authorizer for Search
        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
        search_scope = "urn:globus:auth:scope:search.api.globus.org:all"
        scopes = [fx_scope, search_scope, "openid"]

        search_authorizer = None

        if not fx_authorizer:
            self.native_client.login(
                requested_scopes=scopes,
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=scopes)
            fx_authorizer = all_authorizers[fx_scope]
            search_authorizer = all_authorizers[search_scope]
            openid_authorizer = all_authorizers["openid"]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

        authclient = AuthClient(authorizer=openid_authorizer)
        user_info = authclient.oauth2_userinfo()
        self.searcher = SearchHelper(authorizer=search_authorizer,
                                     owner_uuid=user_info['sub'])
        self.funcx_service_address = funcx_service_address

    def version_check(self):
        """Check this client version meets the service's minimum supported version.
        """
        resp = self.get("version", params={"service": "all"})
        versions = resp.data
        if "min_ep_version" not in versions:
            raise VersionMismatch(
                "Failed to retrieve version information from funcX service.")

        min_ep_version = versions['min_ep_version']

        if ENDPOINT_VERSION is None:
            raise VersionMismatch(
                "You do not have the funcx endpoint installed.  You can use 'pip install funcx-endpoint'."
            )
        if ENDPOINT_VERSION < min_ep_version:
            raise VersionMismatch(
                f"Your version={ENDPOINT_VERSION} is lower than the "
                f"minimum version for an endpoint: {min_ep_version}.  Please update."
            )

    def logout(self):
        """Remove credentials from your local system
        """
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """ Parses the return message from the service and updates the internal func_tables

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        status = {'pending': True}

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
                completion_t = r_dict['completion_t']
            except Exception:
                raise SerializationError("Result Object Deserialization")
            else:
                status.update({
                    'pending': False,
                    'result': r_obj,
                    'completion_t': completion_t
                })
                self.func_table[task_id] = status

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                completion_t = r_dict['completion_t']
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise SerializationError(
                    "Task's exception object deserialization")
            else:
                status.update({
                    'pending': False,
                    'exception': r_exception,
                    'completion_t': completion_t
                })
                self.func_table[task_id] = status
        return status

    def get_task(self, task_id):
        """Get a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Task block containing "status" key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.get("tasks/{task_id}".format(task_id=task_id))
        logger.debug("Response string : {}".format(r))
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        task = self.get_task(task_id)
        if task['pending'] is True:
            raise Exception("Task pending")
        else:
            if 'result' in task:
                return task['result']
            else:
                logger.warning("We have an exception : {}".format(
                    task['exception']))
                task['exception'].reraise()

    def get_batch_status(self, task_id_list):
        """ Request status for a batch of task_ids
        """
        assert isinstance(task_id_list,
                          list), "get_batch_status expects a list of task ids"

        pending_task_ids = [
            t for t in task_id_list if t not in self.func_table
        ]

        results = {}

        if pending_task_ids:
            payload = {'task_ids': pending_task_ids}
            r = self.post("/batch_status", json_body=payload)
            logger.debug("Response string : {}".format(r))

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r['results'][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug(
                        "Task {} info was not available in the batch status")
                except Exception:
                    logger.exception(
                        "Failure while unpacking results fom get_batch_status")
            else:
                results[task_id] = self.func_table[task_id]

        return results

    def get_batch_result(self, task_id_list):
        """ Request results for a batch of task_ids
        """
        pass

    def run(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        batch = self.create_batch()
        batch.add(*args,
                  endpoint_id=endpoint_id,
                  function_id=function_id,
                  **kwargs)
        r = self.batch_run(batch)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """

        return r[0]

    def create_batch(self):
        """
        Create a Batch instance to handle batch submission in funcX

        Parameters
        ----------

        Returns
        -------
        Batch instance
            Status block containing "status" key.
        """
        batch = Batch()
        return batch

    def batch_run(self, batch):
        """Initiate a batch of tasks to funcX

        Parameters
        ----------
        batch: a Batch object

        Returns
        -------
        task_ids : a list of UUID strings that identify the tasks
        """
        servable_path = 'submit'
        assert isinstance(batch, Batch), "Requires a Batch object as input"
        assert len(batch.tasks) > 0, "Requires a non-empty batch"

        data = batch.prepare()

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)
        if r.get("status", "Failure") == "Failure":
            raise MalformedResponse("FuncX Request failed: {}".format(
                r.get("reason", "Unknown")))
        return r['task_uuids']

    def map_run(self,
                *args,
                endpoint_id=None,
                function_id=None,
                asynchronous=False,
                **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit_batch'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg, ))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            'endpoints': [endpoint_id],
            'func': function_id,
            'payload': batch_payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if r.get("status", "Failure") == "Failure":
            raise MalformedResponse("FuncX Request failed: {}".format(
                r.get("reason", "Unknown")))
        return r['task_uuids']

    def register_endpoint(self,
                          name,
                          endpoint_uuid,
                          metadata=None,
                          endpoint_version=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        metadata : dict
            endpoint metadata, see default_config example
        endpoint_version: str
            Version string to be passed to the webService as a compatibility check

        Returns
        -------
        A dict
            {'endopoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        self.version_check()

        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "version": endpoint_version
        }
        if metadata:
            data['meta'] = metadata

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        int
            The port to connect to and a list of containers
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['container']

    def get_endpoint_status(self, endpoint_uuid):
        """Get the status reports for an endpoint.

        Parameters
        ----------
        endpoint_uuid : str
            UUID of the endpoint in question

        Returns
        -------
        dict
            The details of the endpoint's stats
        """
        stats_path = f'endpoints/{endpoint_uuid}/status'

        r = self.get(stats_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None,
                          public=False,
                          group=None,
                          searchable=True):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution
        function_name : str
            The entry point (function name) of the function. Default: None
        container_uuid : str
            Container UUID from registration with funcX
        description : str
            Description of the file
        public : bool
            Whether or not the function is publicly accessible. Default = False
        group : str
            A globus group uuid to share this function with
        searchable : bool
            If true, the function will be indexed into globus search with the appropriate permissions

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        source_code = ""
        try:
            source_code = getsource(function)
        except OSError:
            logger.error(
                "Failed to find source code during function registration.")

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "function_source": source_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description,
            "public": public,
            "group": group,
            "searchable": searchable
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        func_uuid = r.data['function_uuid']

        # Return the result
        return func_uuid

    def update_function(self, func_uuid, function):
        pass

    def search_function(self, q, offset=0, limit=10, advanced=False):
        """Search for function via the funcX service

        Parameters
        ----------
        q : str
            free-form query string
        offset : int
            offset into total results
        limit : int
            max number of results to return
        advanced : bool
            allows elastic-search like syntax in query string

        Returns
        -------
        FunctionSearchResults
        """
        return self.searcher.search_function(q,
                                             offset=offset,
                                             limit=limit,
                                             advanced=advanced)

    def search_endpoint(self, q, scope='all', owner_id=None):
        """

        Parameters
        ----------
        q
        scope : str
            Can be one of {'all', 'my-endpoints', 'shared-with-me'}
        owner_id
            should be urn like f"urn:globus:auth:identity:{owner_uuid}"

        Returns
        -------

        """
        return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id)

    def register_container(self,
                           location,
                           container_type,
                           name='',
                           description=''):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker). Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        container_path = 'containers'

        payload = {
            'name': name,
            'location': location,
            'description': description,
            'type': container_type
        }

        r = self.post(container_path, json_body=payload)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['container_id']

    def add_to_whitelist(self, endpoint_id, function_ids):
        """Adds the function to the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function id's to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        req_path = f'endpoints/{endpoint_id}/whitelist'

        if not isinstance(function_ids, list):
            function_ids = [function_ids]

        payload = {'func': function_ids}

        r = self.post(req_path, json_body=payload)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r

    def get_whitelist(self, endpoint_id):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint

        Returns
        -------
        json
            The response of the request
        """
        req_path = f'endpoints/{endpoint_id}/whitelist'

        r = self.get(req_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r

    def delete_from_whitelist(self, endpoint_id, function_ids):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function id's to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        if not isinstance(function_ids, list):
            function_ids = [function_ids]
        res = []
        for fid in function_ids:
            req_path = f'endpoints/{endpoint_id}/whitelist/{fid}'

            r = self.delete(req_path)
            if r.http_status != 200:
                raise HTTPError(r)
            res.append(r)

        # Return the result
        return res
예제 #11
0
파일: client.py 프로젝트: NickolausDS/funcX
class FuncXClient(BaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://dev.funcx.org/api/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`:
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        service_address: str
        The address of the funcX web service to communicate with.
        Default: https://dev.funcx.org/api/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        native_client = NativeClient(client_id=self.CLIENT_ID)

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"

        if not fx_authorizer:
            native_client.login(
                requested_scopes=[fx_scope],
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = native_client.get_authorizers_by_scope(
                requested_scopes=[fx_scope])
            fx_authorizer = all_authorizers[fx_scope]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

    def logout(self):
        """Remove credentials from your local system
        """
        logout()

    def get_task_status(self, task_id):
        """Get the status of a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Status block containing "status" key.
        """

        r = self.get("{task_id}/status".format(task_id=task_id))
        return json.loads(r.text)

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """

        r = self.get("{task_id}/status".format(task_id=task_id))

        logger.info(f"Got from globus : {r}")
        r_dict = json.loads(r.text)

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
            except Exception:
                raise Exception(
                    "Failure during deserialization of the result object")
            else:
                return r_obj

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise Exception(
                    "Failure during deserialization of the Task's exception object"
                )
            else:
                r_exception.reraise()

        else:
            raise Exception("Task pending")

    def run(self,
            *args,
            endpoint_id=None,
            function_id=None,
            asynchronous=False,
            **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'func': function_id,
            'payload': payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        if 'task_uuid' not in r:
            raise MalformedResponse(r)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """
        return r['task_uuid']

    def register_endpoint(self, name, endpoint_uuid, description=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        A dict
            {'endopoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "description": description
        }

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        int
            The port to connect to and a list of containers
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['container']

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution

        function_name : str
            The entry point (function name) of the function. Default: None

        container_uuid : str
            Container UUID from registration with funcX

        description : str
            Description of the file

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['function_uuid']
예제 #12
0
파일: client.py 프로젝트: funcx-faas/funcX
class FuncXClient:
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = "funcx_sdk_tokens.json"
    FUNCX_SDK_CLIENT_ID = os.environ.get(
        "FUNCX_SDK_CLIENT_ID", "4cf29807-cf21-49ec-9443-ff9a3fb9f81c"
    )
    FUNCX_SCOPE = os.environ.get(
        "FUNCX_SCOPE",
        "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all",
    )

    def __init__(
        self,
        http_timeout=None,
        funcx_home=_FUNCX_HOME,
        force_login=False,
        fx_authorizer=None,
        search_authorizer=None,
        openid_authorizer=None,
        funcx_service_address=None,
        check_endpoint_version=False,
        asynchronous=False,
        loop=None,
        results_ws_uri=None,
        use_offprocess_checker=True,
        environment=None,
        **kwargs,
    ):
        """
        Initialize the client

        Parameters
        ----------
        http_timeout: int
            Timeout for any call to service in seconds.
            Default is no timeout

        force_login: bool
            Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer \
            <globus_sdk.authorizers.base.GlobusAuthorizer>`:
            A custom authorizer instance to communicate with funcX.
            Default: ``None``, will be created.

        search_authorizer:class:`GlobusAuthorizer \
            <globus_sdk.authorizers.base.GlobusAuthorizer>`:
            A custom authorizer instance to communicate with Globus Search.
            Default: ``None``, will be created.

        openid_authorizer:class:`GlobusAuthorizer \
            <globus_sdk.authorizers.base.GlobusAuthorizer>`:
            A custom authorizer instance to communicate with OpenID.
            Default: ``None``, will be created.

        funcx_service_address: str
            For internal use only. The address of the web service.

        results_ws_uri: str
            For internal use only. The address of the websocket service.

        environment: str
            For internal use only. The name of the environment to use.

        asynchronous: bool
        Should the API use asynchronous interactions with the web service? Currently
        only impacts the run method
        Default: False

        loop: AbstractEventLoop
        If asynchronous mode is requested, then you can provide an optional event loop
        instance. If None, then we will access asyncio.get_event_loop()
        Default: None

        use_offprocess_checker: Bool,
            Use this option to disable the offprocess_checker in the FuncXSerializer
            used by the client.
            Default: True

        Keyword arguments are the same as for BaseClient.

        """
        # resolve URLs if not set
        if funcx_service_address is None:
            funcx_service_address = get_web_service_url(environment)
        if results_ws_uri is None:
            results_ws_uri = get_web_socket_url(environment)

        self.func_table = {}
        self.use_offprocess_checker = use_offprocess_checker
        self.funcx_home = os.path.expanduser(funcx_home)
        self.session_task_group_id = str(uuid.uuid4())

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.FUNCX_SDK_CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename),
        )

        # TODO: if fx_authorizer is given, we still need to get an authorizer for Search
        search_scope = "urn:globus:auth:scope:search.api.globus.org:all"
        scopes = [self.FUNCX_SCOPE, search_scope, "openid"]

        if not fx_authorizer or not search_authorizer or not openid_authorizer:
            self.native_client.login(
                requested_scopes=scopes,
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login,
            )

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=scopes
            )
            fx_authorizer = all_authorizers[self.FUNCX_SCOPE]
            search_authorizer = all_authorizers[search_scope]
            openid_authorizer = all_authorizers["openid"]

        self.web_client = FuncxWebClient(
            base_url=funcx_service_address, authorizer=fx_authorizer
        )
        self.fx_serializer = FuncXSerializer(
            use_offprocess_checker=self.use_offprocess_checker
        )

        authclient = AuthClient(authorizer=openid_authorizer)
        user_info = authclient.oauth2_userinfo()
        self.searcher = SearchHelper(
            authorizer=search_authorizer, owner_uuid=user_info["sub"]
        )
        self.funcx_service_address = funcx_service_address
        self.check_endpoint_version = check_endpoint_version

        self.version_check()

        self.results_ws_uri = results_ws_uri
        self.asynchronous = asynchronous
        if asynchronous:
            self.loop = loop if loop else asyncio.get_event_loop()

            # Start up an asynchronous polling loop in the background
            self.ws_polling_task = WebSocketPollingTask(
                self,
                self.loop,
                init_task_group_id=self.session_task_group_id,
                results_ws_uri=self.results_ws_uri,
            )
        else:
            self.loop = None

    def version_check(self):
        """Check this client version meets the service's minimum supported version."""
        resp = self.web_client.get_version()
        versions = resp.data
        if "min_ep_version" not in versions:
            raise VersionMismatch(
                "Failed to retrieve version information from funcX service."
            )

        min_ep_version = versions["min_ep_version"]
        min_sdk_version = versions["min_sdk_version"]

        if self.check_endpoint_version:
            if ENDPOINT_VERSION is None:
                raise VersionMismatch(
                    "You do not have the funcx endpoint installed.  "
                    "You can use 'pip install funcx-endpoint'."
                )
            if LooseVersion(ENDPOINT_VERSION) < LooseVersion(min_ep_version):
                raise VersionMismatch(
                    f"Your version={ENDPOINT_VERSION} is lower than the "
                    f"minimum version for an endpoint: {min_ep_version}.  "
                    "Please update. "
                    f"pip install funcx-endpoint>={min_ep_version}"
                )
        else:
            if LooseVersion(SDK_VERSION) < LooseVersion(min_sdk_version):
                raise VersionMismatch(
                    f"Your version={SDK_VERSION} is lower than the "
                    f"minimum version for funcx SDK: {min_sdk_version}.  "
                    "Please update. "
                    f"pip install funcx>={min_sdk_version}"
                )

    def logout(self):
        """Remove credentials from your local system"""
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """Parses the return message from the service and updates the internal func_table

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        r_status = r_dict.get("status", "unknown")
        status = {"pending": True, "status": r_status}

        if "result" in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict["result"])
                completion_t = r_dict["completion_t"]
            except Exception:
                raise SerializationError("Result Object Deserialization")
            else:
                status.update(
                    {"pending": False, "result": r_obj, "completion_t": completion_t}
                )
                self.func_table[task_id] = status

        elif "exception" in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(r_dict["exception"])
                completion_t = r_dict["completion_t"]
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise SerializationError("Task's exception object deserialization")
            else:
                status.update(
                    {
                        "pending": False,
                        "exception": r_exception,
                        "completion_t": completion_t,
                    }
                )
                self.func_table[task_id] = status
        return status

    def get_task(self, task_id):
        """Get a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Task block containing "status" key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.web_client.get_task(task_id)
        logger.debug(f"Response string : {r}")
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        task = self.get_task(task_id)
        if task["pending"] is True:
            raise TaskPending(task["status"])
        else:
            if "result" in task:
                return task["result"]
            else:
                logger.warning("We have an exception : {}".format(task["exception"]))
                task["exception"].reraise()

    def get_batch_result(self, task_id_list):
        """Request status for a batch of task_ids"""
        assert isinstance(
            task_id_list, list
        ), "get_batch_result expects a list of task ids"

        pending_task_ids = [t for t in task_id_list if t not in self.func_table]

        results = {}

        if pending_task_ids:
            r = self.web_client.get_batch_status(pending_task_ids)
            logger.debug(f"Response string : {r}")

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r["results"][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug("Task {} info was not available in the batch status")
                except Exception:
                    logger.exception(
                        "Failure while unpacking results fom get_batch_result"
                    )
            else:
                results[task_id] = self.func_table[task_id]

        return results

    def run(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task if asynchronous is False

        funcX Task: asyncio.Task
        A future that will eventually resolve into the function's result if
        asynchronous is True
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        batch = self.create_batch()
        batch.add(*args, endpoint_id=endpoint_id, function_id=function_id, **kwargs)
        r = self.batch_run(batch)

        return r[0]

    def create_batch(self, task_group_id=None):
        """
        Create a Batch instance to handle batch submission in funcX

        Parameters
        ----------

        task_group_id : str
            Override the session wide session_task_group_id with a different
            task_group_id for this batch.
            If task_group_id is not specified, it will default to using the client's
            session_task_group_id

        Returns
        -------
        Batch instance
            Status block containing "status" key.
        """
        if not task_group_id:
            task_group_id = self.session_task_group_id

        batch = Batch(task_group_id=task_group_id)
        return batch

    def batch_run(self, batch):
        """Initiate a batch of tasks to funcX

        Parameters
        ----------
        batch: a Batch object

        Returns
        -------
        task_ids : a list of UUID strings that identify the tasks
        """
        assert isinstance(batch, Batch), "Requires a Batch object as input"
        assert len(batch.tasks) > 0, "Requires a non-empty batch"

        data = batch.prepare()

        # Send the data to funcX
        r = self.web_client.submit(data)

        task_uuids = []
        for result in r["results"]:
            task_id = result["task_uuid"]
            task_uuids.append(task_id)
            if result["http_status_code"] != 200:
                # this method of handling errors for a batch response is not
                # ideal, as it will raise any error in the multi-response,
                # but it will do until batch_run is deprecated in favor of Executer
                handle_response_errors(result)

        if self.asynchronous:
            task_group_id = r["task_group_id"]
            asyncio_tasks = []
            for task_id in task_uuids:
                funcx_task = FuncXTask(task_id)
                asyncio_task = self.loop.create_task(funcx_task.get_result())
                asyncio_tasks.append(asyncio_task)

                self.ws_polling_task.add_task(funcx_task)
            self.ws_polling_task.put_task_group_id(task_group_id)
            return asyncio_tasks

        return task_uuids

    def map_run(
        self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs
    ):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg,))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            "endpoints": [endpoint_id],
            "func": function_id,
            "payload": batch_payload,
            "is_async": asynchronous,
        }

        # Send the data to funcX
        r = self.web_client.submit_batch(data)
        return r["task_uuids"]

    def register_endpoint(
        self, name, endpoint_uuid, metadata=None, endpoint_version=None
    ):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        metadata : dict
            endpoint metadata, see default_config example
        endpoint_version: str
            Version string to be passed to the webService as a compatibility check

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        self.version_check()

        r = self.web_client.register_endpoint(
            endpoint_name=name,
            endpoint_id=endpoint_uuid,
            metadata=metadata,
            endpoint_version=endpoint_version,
        )
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        int
            The port to connect to and a list of containers
        """
        data = {"endpoint_name": name, "description": description}

        r = self.web_client.post("get_containers", data=data)
        return r.data["endpoint_uuid"], r.data["endpoint_containers"]

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        self.version_check()

        r = self.web_client.get(f"containers/{container_uuid}/{container_type}")
        return r.data["container"]

    def get_endpoint_status(self, endpoint_uuid):
        """Get the status reports for an endpoint.

        Parameters
        ----------
        endpoint_uuid : str
            UUID of the endpoint in question

        Returns
        -------
        dict
            The details of the endpoint's stats
        """
        r = self.web_client.get_endpoint_status(endpoint_uuid)
        return r.data

    def register_function(
        self,
        function,
        function_name=None,
        container_uuid=None,
        description=None,
        public=False,
        group=None,
        searchable=True,
    ):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution
        function_name : str
            The entry point (function name) of the function. Default: None
        container_uuid : str
            Container UUID from registration with funcX
        description : str
            Description of the file
        public : bool
            Whether or not the function is publicly accessible. Default = False
        group : str
            A globus group uuid to share this function with
        searchable : bool
            If true, the function will be indexed into globus search with the
            appropriate permissions

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        data = FunctionRegistrationData(
            function=function,
            failover_source="",
            container_uuid=container_uuid,
            entry_point=function_name,
            description=description,
            public=public,
            group=group,
            searchable=searchable,
            serializer=self.fx_serializer,
        )
        logger.info(f"Registering function : {data}")
        r = self.web_client.register_function(data)
        return r.data["function_uuid"]

    def search_function(self, q, offset=0, limit=10, advanced=False):
        """Search for function via the funcX service

        Parameters
        ----------
        q : str
            free-form query string
        offset : int
            offset into total results
        limit : int
            max number of results to return
        advanced : bool
            allows elastic-search like syntax in query string

        Returns
        -------
        FunctionSearchResults
        """
        return self.searcher.search_function(
            q, offset=offset, limit=limit, advanced=advanced
        )

    def search_endpoint(self, q, scope="all", owner_id=None):
        """

        Parameters
        ----------
        q
        scope : str
            Can be one of {'all', 'my-endpoints', 'shared-with-me'}
        owner_id
            should be urn like f"urn:globus:auth:identity:{owner_uuid}"

        Returns
        -------

        """
        return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id)

    def register_container(self, location, container_type, name="", description=""):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker).
            Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        payload = {
            "name": name,
            "location": location,
            "description": description,
            "type": container_type,
        }

        r = self.web_client.post("containers", data=payload)
        return r.data["container_id"]

    def add_to_whitelist(self, endpoint_id, function_ids):
        """Adds the function to the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function id's to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        return self.web_client.whitelist_add(endpoint_id, function_ids)

    def get_whitelist(self, endpoint_id):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint

        Returns
        -------
        json
            The response of the request
        """
        return self.web_client.get_whitelist(endpoint_id)

    def delete_from_whitelist(self, endpoint_id, function_ids):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function id's to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        if not isinstance(function_ids, list):
            function_ids = [function_ids]
        res = []
        for fid in function_ids:
            res.append(self.web_client.whitelist_remove(endpoint_id, fid))
        return res
예제 #13
0
파일: client.py 프로젝트: vincent-pli/funcX
class FuncXClient(throttling.ThrottledBaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = 'funcx_sdk_tokens.json'
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://funcx.org/api/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`:
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        service_address: str
        The address of the funcX web service to communicate with.
        Default: https://dev.funcx.org/api/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.func_table = {}
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename))

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"

        if not fx_authorizer:
            self.native_client.login(
                requested_scopes=[fx_scope],
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=[fx_scope])
            fx_authorizer = all_authorizers[fx_scope]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

    def logout(self):
        """Remove credentials from your local system
        """
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """ Parses the return message from the service and updates the internal func_tables

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        status = {'pending': True}

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
            except Exception:
                raise Exception(
                    "Failure during deserialization of the result object")
            else:
                status.update({'pending': 'False', 'result': r_obj})
                self.func_table[task_id] = status

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise Exception(
                    "Failure during deserialization of the Task's exception object"
                )
            else:
                status.update({'pending': 'False', 'exception': r_exception})
                self.func_table[task_id] = status
        return status

    def get_task_status(self, task_id):
        """Get the status of a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Status block containing "status" key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.get("{task_id}/status".format(task_id=task_id))
        logger.debug("Response string : {}".format(r))
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        status = self.get_task_status(task_id)
        if status['pending'] is True:
            raise Exception("Task pending")
        else:
            if 'result' in status:
                return status['result']
            else:
                logger.warn("We have an exception : {}".format(
                    status['exception']))
                status['exception'].reraise()

    def get_batch_status(self, task_id_list):
        """ Request status for a batch of task_ids
        """
        assert isinstance(task_id_list,
                          list), "get_batch_status expects a list of task ids"

        pending_task_ids = [
            t for t in task_id_list if t not in self.func_table
        ]

        results = {}

        if pending_task_ids:
            payload = {'task_ids': pending_task_ids}
            r = self.post("/batch_status", json_body=payload)
            logger.debug("Response string : {}".format(r))

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r['results'][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug(
                        "Task {} info was not available in the batch status")
                except Exception as e:
                    logger.exception(
                        "Failure while unpacking results fom get_batch_status")
            else:
                results[task_id] = self.func_table[task_id]

        return results

    def get_batch_result(self, task_id_list):
        """ Request results for a batch of task_ids
        """
        pass

    def run(self,
            *args,
            endpoint_id=None,
            function_id=None,
            asynchronous=False,
            **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'func': function_id,
            'payload': payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        if 'task_uuid' not in r:
            raise MalformedResponse(r)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """
        return r['task_uuid']

    def map_run(self,
                *args,
                endpoint_id=None,
                function_id=None,
                asynchronous=False,
                **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit_batch'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg, ))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            'endpoints': [endpoint_id],
            'func': function_id,
            'payload': batch_payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        if 'task_uuids' not in r:
            raise MalformedResponse(r)

        return r['task_uuids']

    def register_endpoint(self, name, endpoint_uuid, description=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        A dict
            {'endopoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "description": description
        }

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        int
            The port to connect to and a list of containers
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['container']

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution

        function_name : str
            The entry point (function name) of the function. Default: None

        container_uuid : str
            Container UUID from registration with funcX

        description : str
            Description of the file

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['function_uuid']

    def register_container(self,
                           location,
                           container_type,
                           name='',
                           description=''):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker). Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        container_path = f'containers'

        payload = {
            'name': name,
            'location': location,
            'description': description,
            'type': container_type
        }

        r = self.post(container_path, json_body=payload)
        if r.http_status is not 200:
            raise Exception(r)

        # Return the result
        return r.data['container_id']