def dont_run_yet(endpoint_id=None, tasks=10, duration=1, hostname=None):
    # tasks_rq = EndpointQueue(f'task_{endpoint_id}', hostname)
    tasks_channel = RedisPubSub(hostname)
    tasks_channel.connect()
    redis_client = tasks_channel.redis_client
    redis_client.ping()
    fxs = FuncXSerializer()

    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    start = time.time()
    task_ids = {}
    for i in range(tasks):
        time.sleep(duration)
        task_id = str(uuid.uuid4())
        print("Task_id : ", task_id)
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({"duration": duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "RAW"
        task = Task(redis_client, task_id, container_id, serializer="", payload=payload)
        task.endpoint = endpoint_id
        task.status = TaskState.WAITING_FOR_EP
        # tasks_rq.enqueue(task)
        tasks_channel.put(endpoint_id, task)
        task_ids[i] = task_id

    d1 = time.time() - start
    print(f"Time to launch {tasks} tasks: {d1:8.3f} s")

    delay = 5
    print(f"Sleeping {delay} seconds")
    time.sleep(delay)
    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        task_id = task_ids[i]
        print("Task_id : ", task_id)
        task = Task.from_id(redis_client, task_id)
        # TODO: wait for task result...
        time.sleep(duration)
        try:
            result = fxs.deserialize(task.result)
            print(f"Result : {result}")
        except Exception as e:
            print(f"Task failed with exception: {e}")

    delta = time.time() - start
    print(f"Time to complete {tasks} tasks: {delta:8.3f} s")
    print(f"Throughput : {tasks / delta:8.3f} Tasks/s")
    return delta
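The slow_double function serialized above is not defined in this snippet; a minimal sketch consistent with how it is invoked (one positional argument plus a "duration" keyword) might be:

import time

def slow_double(x, duration=0):
    """Sleep for `duration` seconds, then return twice the input."""
    time.sleep(duration)
    return x * 2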
Example #2
def deserialize():
    """Return the deserialized result
    """

    fx_serializer = FuncXSerializer()
    # Return a failure message if all else fails
    ret_package = {'error': 'Failed to deserialize result'}
    try:
        inputs = request.json
        res = fx_serializer.deserialize(inputs)
        ret_package = jsonify(res)
    except Exception as e:
        print(e)
        return jsonify(ret_package), 500
    return ret_package, 200
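This view uses Flask's request and jsonify helpers, but the route registration is not shown; a minimal sketch of wiring it up and calling it (the app object and the /deserialize path are assumptions, not part of the original):

from flask import Flask, request, jsonify
from funcx.serialize import FuncXSerializer

app = Flask(__name__)
# Hypothetical route; the original snippet does not show where the view is mounted.
app.add_url_rule("/deserialize", view_func=deserialize, methods=["POST"])

# A client would POST a serialized buffer and read back the deserialized value:
#   fxs = FuncXSerializer()
#   requests.post("http://localhost:5000/deserialize", json=fxs.serialize(42))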
Example #3
def server(port=0, host="", debug=False, datasize=102400):

    try:
        from funcx.serialize import FuncXSerializer

        fxs = FuncXSerializer(use_offprocess_checker=False)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind((host, port))
            bound_port = s.getsockname()[1]
            print(f"BINDING TO:{bound_port}", flush=True)
            s.listen(1)
            conn, addr = s.accept()  # we only expect one incoming connection here.
            with conn:
                while True:

                    b_msg = conn.recv(datasize)
                    if not b_msg:
                        print("Exiting")
                        return

                    msg = pickle.loads(b_msg)

                    if msg == "PING":
                        ret_value = ("PONG", None)
                    else:
                        try:
                            method = fxs.deserialize(msg)  # noqa
                            del method
                        except Exception as e:
                            ret_value = ("DESERIALIZE_FAIL", str(e))

                        else:
                            ret_value = ("SUCCESS", None)

                    ret_buf = pickle.dumps(ret_value)
                    conn.sendall(ret_buf)
    except Exception as e:
        print(f"OFF_PROCESS_CHECKER FAILURE, Exception:{e}")
        sys.exit()
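A hedged sketch of a matching client for this checker (not part of the original): it connects to the bound port, sends a pickled "PING", and reads back the pickled ("PONG", None) reply.

import pickle
import socket

def ping_checker(port, host="localhost"):
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        s.sendall(pickle.dumps("PING"))
        status, info = pickle.loads(s.recv(102400))
        return status  # expected to be "PONG"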
Example #4
    #    set_file_logger('executor.log', name='funcx_endpoint', level=logging.DEBUG)
    htex = HighThroughputExecutor(interchange_local=True, passthrough=True)

    htex.start(results_passthrough=results_queue)
    htex._start_remote_interchange_process()
    fx_serializer = FuncXSerializer()

    for i in range(10):
        task_id = str(uuid.uuid4())
        args = (i, )
        kwargs = {}

        fn_code = fx_serializer.serialize(double)
        ser_code = fx_serializer.pack_buffers([fn_code])
        ser_params = fx_serializer.pack_buffers(
            [fx_serializer.serialize(args),
             fx_serializer.serialize(kwargs)])

        payload = Task(task_id, "RAW", ser_code + ser_params)
        f = htex.submit_raw(payload.pack())
        time.sleep(0.5)

    for i in range(10):
        result_package = results_queue.get()
        # print("Result package : ", result_package)
        r = pickle.loads(result_package)
        result = fx_serializer.deserialize(r["result"])
        print(f"Result:{i}: {result}")

    print("All done")
Example #5
class FuncXClient(throttling.ThrottledBaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = 'funcx_sdk_tokens.json'
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://api.funcx.org/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`:
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://api.funcx.org/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.func_table = {}
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename))

        # TODO: if fx_authorizer is given, we still need to get an authorizer for Search
        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
        search_scope = "urn:globus:auth:scope:search.api.globus.org:all"
        scopes = [fx_scope, search_scope, "openid"]

        search_authorizer = None

        if not fx_authorizer:
            self.native_client.login(
                requested_scopes=scopes,
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=scopes)
            fx_authorizer = all_authorizers[fx_scope]
            search_authorizer = all_authorizers[search_scope]
            openid_authorizer = all_authorizers["openid"]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

        authclient = AuthClient(authorizer=openid_authorizer)
        user_info = authclient.oauth2_userinfo()
        self.searcher = SearchHelper(authorizer=search_authorizer,
                                     owner_uuid=user_info['sub'])
        self.funcx_service_address = funcx_service_address

    def version_check(self):
        """Check this client version meets the service's minimum supported version.
        """
        resp = self.get("version", params={"service": "all"})
        versions = resp.data
        if "min_ep_version" not in versions:
            raise VersionMismatch(
                "Failed to retrieve version information from funcX service.")

        min_ep_version = versions['min_ep_version']

        if ENDPOINT_VERSION is None:
            raise VersionMismatch(
                "You do not have the funcx endpoint installed.  You can use 'pip install funcx-endpoint'."
            )
        if ENDPOINT_VERSION < min_ep_version:
            raise VersionMismatch(
                f"Your version={ENDPOINT_VERSION} is lower than the "
                f"minimum version for an endpoint: {min_ep_version}.  Please update."
            )

    def logout(self):
        """Remove credentials from your local system
        """
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """ Parses the return message from the service and updates the internal func_tables

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        status = {'pending': True}

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
                completion_t = r_dict['completion_t']
            except Exception:
                raise SerializationError("Result Object Deserialization")
            else:
                status.update({
                    'pending': False,
                    'result': r_obj,
                    'completion_t': completion_t
                })
                self.func_table[task_id] = status

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                completion_t = r_dict['completion_t']
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise SerializationError(
                    "Task's exception object deserialization")
            else:
                status.update({
                    'pending': False,
                    'exception': r_exception,
                    'completion_t': completion_t
                })
                self.func_table[task_id] = status
        return status

    def get_task(self, task_id):
        """Get a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Task block containing "status" key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.get("tasks/{task_id}".format(task_id=task_id))
        logger.debug("Response string : {}".format(r))
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        task = self.get_task(task_id)
        if task['pending'] is True:
            raise Exception("Task pending")
        else:
            if 'result' in task:
                return task['result']
            else:
                logger.warning("We have an exception : {}".format(
                    task['exception']))
                task['exception'].reraise()

    def get_batch_status(self, task_id_list):
        """ Request status for a batch of task_ids
        """
        assert isinstance(task_id_list,
                          list), "get_batch_status expects a list of task ids"

        pending_task_ids = [
            t for t in task_id_list if t not in self.func_table
        ]

        results = {}

        if pending_task_ids:
            payload = {'task_ids': pending_task_ids}
            r = self.post("/batch_status", json_body=payload)
            logger.debug("Response string : {}".format(r))

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r['results'][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug(
                        "Task {} info was not available in the batch status")
                except Exception:
                    logger.exception(
                        "Failure while unpacking results fom get_batch_status")
            else:
                results[task_id] = self.func_table[task_id]

        return results

    def get_batch_result(self, task_id_list):
        """ Request results for a batch of task_ids
        """
        pass

    def run(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        batch = self.create_batch()
        batch.add(*args,
                  endpoint_id=endpoint_id,
                  function_id=function_id,
                  **kwargs)
        r = self.batch_run(batch)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """

        return r[0]

    def create_batch(self):
        """
        Create a Batch instance to handle batch submission in funcX

        Parameters
        ----------

        Returns
        -------
        Batch instance
            Status block containing "status" key.
        """
        batch = Batch()
        return batch

    def batch_run(self, batch):
        """Initiate a batch of tasks to funcX

        Parameters
        ----------
        batch: a Batch object

        Returns
        -------
        task_ids : a list of UUID strings that identify the tasks
        """
        servable_path = 'submit'
        assert isinstance(batch, Batch), "Requires a Batch object as input"
        assert len(batch.tasks) > 0, "Requires a non-empty batch"

        data = batch.prepare()

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)
        if r.get("status", "Failure") == "Failure":
            raise MalformedResponse("FuncX Request failed: {}".format(
                r.get("reason", "Unknown")))
        return r['task_uuids']

    def map_run(self,
                *args,
                endpoint_id=None,
                function_id=None,
                asynchronous=False,
                **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit_batch'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg, ))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            'endpoints': [endpoint_id],
            'func': function_id,
            'payload': batch_payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if r.get("status", "Failure") == "Failure":
            raise MalformedResponse("FuncX Request failed: {}".format(
                r.get("reason", "Unknown")))
        return r['task_uuids']

    def register_endpoint(self,
                          name,
                          endpoint_uuid,
                          metadata=None,
                          endpoint_version=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        metadata : dict
            endpoint metadata, see default_config example
        endpoint_version: str
            Version string to be passed to the webService as a compatibility check

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        self.version_check()

        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "version": endpoint_version
        }
        if metadata:
            data['meta'] = metadata

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        tuple
            The endpoint UUID and a list of containers
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['container']

    def get_endpoint_status(self, endpoint_uuid):
        """Get the status reports for an endpoint.

        Parameters
        ----------
        endpoint_uuid : str
            UUID of the endpoint in question

        Returns
        -------
        dict
            The details of the endpoint's stats
        """
        stats_path = f'endpoints/{endpoint_uuid}/status'

        r = self.get(stats_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None,
                          public=False,
                          group=None,
                          searchable=True):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution
        function_name : str
            The entry point (function name) of the function. Default: None
        container_uuid : str
            Container UUID from registration with funcX
        description : str
            Description of the file
        public : bool
            Whether or not the function is publicly accessible. Default = False
        group : str
            A globus group uuid to share this function with
        searchable : bool
            If true, the function will be indexed into globus search with the appropriate permissions

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        source_code = ""
        try:
            source_code = getsource(function)
        except OSError:
            logger.error(
                "Failed to find source code during function registration.")

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "function_source": source_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description,
            "public": public,
            "group": group,
            "searchable": searchable
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise HTTPError(r)

        func_uuid = r.data['function_uuid']

        # Return the result
        return func_uuid

    def update_function(self, func_uuid, function):
        pass

    def search_function(self, q, offset=0, limit=10, advanced=False):
        """Search for function via the funcX service

        Parameters
        ----------
        q : str
            free-form query string
        offset : int
            offset into total results
        limit : int
            max number of results to return
        advanced : bool
            allows elastic-search like syntax in query string

        Returns
        -------
        FunctionSearchResults
        """
        return self.searcher.search_function(q,
                                             offset=offset,
                                             limit=limit,
                                             advanced=advanced)

    def search_endpoint(self, q, scope='all', owner_id=None):
        """

        Parameters
        ----------
        q
        scope : str
            Can be one of {'all', 'my-endpoints', 'shared-with-me'}
        owner_id
            should be urn like f"urn:globus:auth:identity:{owner_uuid}"

        Returns
        -------

        """
        return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id)

    def register_container(self,
                           location,
                           container_type,
                           name='',
                           description=''):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker). Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        container_path = 'containers'

        payload = {
            'name': name,
            'location': location,
            'description': description,
            'type': container_type
        }

        r = self.post(container_path, json_body=payload)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r.data['container_id']

    def add_to_whitelist(self, endpoint_id, function_ids):
        """Adds the function to the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function ids to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        req_path = f'endpoints/{endpoint_id}/whitelist'

        if not isinstance(function_ids, list):
            function_ids = [function_ids]

        payload = {'func': function_ids}

        r = self.post(req_path, json_body=payload)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r

    def get_whitelist(self, endpoint_id):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint

        Returns
        -------
        json
            The response of the request
        """
        req_path = f'endpoints/{endpoint_id}/whitelist'

        r = self.get(req_path)
        if r.http_status != 200:
            raise HTTPError(r)

        # Return the result
        return r

    def delete_from_whitelist(self, endpoint_id, function_ids):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function ids to be removed from the whitelist

        Returns
        -------
        json
            The response of the request
        """
        if not isinstance(function_ids, list):
            function_ids = [function_ids]
        res = []
        for fid in function_ids:
            req_path = f'endpoints/{endpoint_id}/whitelist/{fid}'

            r = self.delete(req_path)
            if r.http_status != 200:
                raise HTTPError(r)
            res.append(r)

        # Return the result
        return res
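A hedged usage sketch of this client (the endpoint UUID is a placeholder and not part of the original listing):

import time

def double(x):
    return x * 2

fxc = FuncXClient()
func_uuid = fxc.register_function(double, description="double a number")
task_id = fxc.run(21, endpoint_id="<endpoint-uuid>", function_id=func_uuid)

# get_task reports pending=True until the service has a result or exception.
while fxc.get_task(task_id)["pending"]:
    time.sleep(2)
print(fxc.get_result(task_id))  # 42 once the task completes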
Example #6
File: client.py Project: NickolausDS/funcX
class FuncXClient(BaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://dev.funcx.org/api/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`:
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://dev.funcx.org/api/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        native_client = NativeClient(client_id=self.CLIENT_ID)

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"

        if not fx_authorizer:
            native_client.login(
                requested_scopes=[fx_scope],
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = native_client.get_authorizers_by_scope(
                requested_scopes=[fx_scope])
            fx_authorizer = all_authorizers[fx_scope]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

    def logout(self):
        """Remove credentials from your local system
        """
        logout()

    def get_task_status(self, task_id):
        """Get the status of a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Status block containing "status" key.
        """

        r = self.get("{task_id}/status".format(task_id=task_id))
        return json.loads(r.text)

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """

        r = self.get("{task_id}/status".format(task_id=task_id))

        logger.info(f"Got from globus : {r}")
        r_dict = json.loads(r.text)

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
            except Exception:
                raise Exception(
                    "Failure during deserialization of the result object")
            else:
                return r_obj

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise Exception(
                    "Failure during deserialization of the Task's exception object"
                )
            else:
                r_exception.reraise()

        else:
            raise Exception("Task pending")

    def run(self,
            *args,
            endpoint_id=None,
            function_id=None,
            asynchronous=False,
            **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'func': function_id,
            'payload': payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if 'task_uuid' not in r:
            raise MalformedResponse(r)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """
        return r['task_uuid']

    def register_endpoint(self, name, endpoint_uuid, description=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "description": description
        }

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        tuple
            The endpoint UUID and a list of containers
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['container']

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution

        function_name : str
            The entry point (function name) of the function. Default: None

        container_uuid : str
            Container UUID from registration with funcX

        description : str
            Description of the file

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['function_uuid']
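A hedged sketch of polling this older client for completion (the UUIDs are placeholders, not part of the original):

import time

fxc = FuncXClient()
task_id = fxc.run(3, 4, endpoint_id="<endpoint-uuid>", function_id="<function-uuid>")

# get_task_status returns the raw status dict; get_result deserializes the
# payload (or re-raises the remote exception) once the task has finished.
status = fxc.get_task_status(task_id)
while "result" not in status and "exception" not in status:
    time.sleep(2)
    status = fxc.get_task_status(task_id)
print(fxc.get_result(task_id))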
Example #7
File: client.py Project: funcx-faas/funcX
class FuncXClient:
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = "funcx_sdk_tokens.json"
    FUNCX_SDK_CLIENT_ID = os.environ.get(
        "FUNCX_SDK_CLIENT_ID", "4cf29807-cf21-49ec-9443-ff9a3fb9f81c"
    )
    FUNCX_SCOPE = os.environ.get(
        "FUNCX_SCOPE",
        "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all",
    )

    def __init__(
        self,
        http_timeout=None,
        funcx_home=_FUNCX_HOME,
        force_login=False,
        fx_authorizer=None,
        search_authorizer=None,
        openid_authorizer=None,
        funcx_service_address=None,
        check_endpoint_version=False,
        asynchronous=False,
        loop=None,
        results_ws_uri=None,
        use_offprocess_checker=True,
        environment=None,
        **kwargs,
    ):
        """
        Initialize the client

        Parameters
        ----------
        http_timeout: int
            Timeout for any call to service in seconds.
            Default is no timeout

        force_login: bool
            Whether to force a login to get new credentials.

        fx_authorizer:class:`GlobusAuthorizer \
            <globus_sdk.authorizers.base.GlobusAuthorizer>`:
            A custom authorizer instance to communicate with funcX.
            Default: ``None``, will be created.

        search_authorizer:class:`GlobusAuthorizer \
            <globus_sdk.authorizers.base.GlobusAuthorizer>`:
            A custom authorizer instance to communicate with Globus Search.
            Default: ``None``, will be created.

        openid_authorizer:class:`GlobusAuthorizer \
            <globus_sdk.authorizers.base.GlobusAuthorizer>`:
            A custom authorizer instance to communicate with OpenID.
            Default: ``None``, will be created.

        funcx_service_address: str
            For internal use only. The address of the web service.

        results_ws_uri: str
            For internal use only. The address of the websocket service.

        environment: str
            For internal use only. The name of the environment to use.

        asynchronous: bool
        Should the API use asynchronous interactions with the web service? Currently
        only impacts the run method
        Default: False

        loop: AbstractEventLoop
        If asynchronous mode is requested, then you can provide an optional event loop
        instance. If None, then we will access asyncio.get_event_loop()
        Default: None

        use_offprocess_checker: bool
            Use this option to disable the offprocess_checker in the FuncXSerializer
            used by the client.
            Default: True

        Keyword arguments are the same as for BaseClient.

        """
        # resolve URLs if not set
        if funcx_service_address is None:
            funcx_service_address = get_web_service_url(environment)
        if results_ws_uri is None:
            results_ws_uri = get_web_socket_url(environment)

        self.func_table = {}
        self.use_offprocess_checker = use_offprocess_checker
        self.funcx_home = os.path.expanduser(funcx_home)
        self.session_task_group_id = str(uuid.uuid4())

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.FUNCX_SDK_CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename),
        )

        # TODO: if fx_authorizer is given, we still need to get an authorizer for Search
        search_scope = "urn:globus:auth:scope:search.api.globus.org:all"
        scopes = [self.FUNCX_SCOPE, search_scope, "openid"]

        if not fx_authorizer or not search_authorizer or not openid_authorizer:
            self.native_client.login(
                requested_scopes=scopes,
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login,
            )

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=scopes
            )
            fx_authorizer = all_authorizers[self.FUNCX_SCOPE]
            search_authorizer = all_authorizers[search_scope]
            openid_authorizer = all_authorizers["openid"]

        self.web_client = FuncxWebClient(
            base_url=funcx_service_address, authorizer=fx_authorizer
        )
        self.fx_serializer = FuncXSerializer(
            use_offprocess_checker=self.use_offprocess_checker
        )

        authclient = AuthClient(authorizer=openid_authorizer)
        user_info = authclient.oauth2_userinfo()
        self.searcher = SearchHelper(
            authorizer=search_authorizer, owner_uuid=user_info["sub"]
        )
        self.funcx_service_address = funcx_service_address
        self.check_endpoint_version = check_endpoint_version

        self.version_check()

        self.results_ws_uri = results_ws_uri
        self.asynchronous = asynchronous
        if asynchronous:
            self.loop = loop if loop else asyncio.get_event_loop()

            # Start up an asynchronous polling loop in the background
            self.ws_polling_task = WebSocketPollingTask(
                self,
                self.loop,
                init_task_group_id=self.session_task_group_id,
                results_ws_uri=self.results_ws_uri,
            )
        else:
            self.loop = None

    def version_check(self):
        """Check this client version meets the service's minimum supported version."""
        resp = self.web_client.get_version()
        versions = resp.data
        if "min_ep_version" not in versions:
            raise VersionMismatch(
                "Failed to retrieve version information from funcX service."
            )

        min_ep_version = versions["min_ep_version"]
        min_sdk_version = versions["min_sdk_version"]

        if self.check_endpoint_version:
            if ENDPOINT_VERSION is None:
                raise VersionMismatch(
                    "You do not have the funcx endpoint installed.  "
                    "You can use 'pip install funcx-endpoint'."
                )
            if LooseVersion(ENDPOINT_VERSION) < LooseVersion(min_ep_version):
                raise VersionMismatch(
                    f"Your version={ENDPOINT_VERSION} is lower than the "
                    f"minimum version for an endpoint: {min_ep_version}.  "
                    "Please update. "
                    f"pip install funcx-endpoint>={min_ep_version}"
                )
        else:
            if LooseVersion(SDK_VERSION) < LooseVersion(min_sdk_version):
                raise VersionMismatch(
                    f"Your version={SDK_VERSION} is lower than the "
                    f"minimum version for funcx SDK: {min_sdk_version}.  "
                    "Please update. "
                    f"pip install funcx>={min_sdk_version}"
                )

    def logout(self):
        """Remove credentials from your local system"""
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """Parses the return message from the service and updates the internal func_table

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        r_status = r_dict.get("status", "unknown")
        status = {"pending": True, "status": r_status}

        if "result" in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict["result"])
                completion_t = r_dict["completion_t"]
            except Exception:
                raise SerializationError("Result Object Deserialization")
            else:
                status.update(
                    {"pending": False, "result": r_obj, "completion_t": completion_t}
                )
                self.func_table[task_id] = status

        elif "exception" in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(r_dict["exception"])
                completion_t = r_dict["completion_t"]
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise SerializationError("Task's exception object deserialization")
            else:
                status.update(
                    {
                        "pending": False,
                        "exception": r_exception,
                        "completion_t": completion_t,
                    }
                )
                self.func_table[task_id] = status
        return status

    def get_task(self, task_id):
        """Get a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Task block containing "status" key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.web_client.get_task(task_id)
        logger.debug(f"Response string : {r}")
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        task = self.get_task(task_id)
        if task["pending"] is True:
            raise TaskPending(task["status"])
        else:
            if "result" in task:
                return task["result"]
            else:
                logger.warning("We have an exception : {}".format(task["exception"]))
                task["exception"].reraise()

    def get_batch_result(self, task_id_list):
        """Request status for a batch of task_ids"""
        assert isinstance(
            task_id_list, list
        ), "get_batch_result expects a list of task ids"

        pending_task_ids = [t for t in task_id_list if t not in self.func_table]

        results = {}

        if pending_task_ids:
            r = self.web_client.get_batch_status(pending_task_ids)
            logger.debug(f"Response string : {r}")

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r["results"][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug("Task {} info was not available in the batch status")
                except Exception:
                    logger.exception(
                        "Failure while unpacking results fom get_batch_result"
                    )
            else:
                results[task_id] = self.func_table[task_id]

        return results

    def run(self, *args, endpoint_id=None, function_id=None, **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task if asynchronous is False

        funcX Task: asyncio.Task
        A future that will eventually resolve into the function's result if
        asynchronous is True
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        batch = self.create_batch()
        batch.add(*args, endpoint_id=endpoint_id, function_id=function_id, **kwargs)
        r = self.batch_run(batch)

        return r[0]

    def create_batch(self, task_group_id=None):
        """
        Create a Batch instance to handle batch submission in funcX

        Parameters
        ----------

        task_group_id : str
            Override the session wide session_task_group_id with a different
            task_group_id for this batch.
            If task_group_id is not specified, it will default to using the client's
            session_task_group_id

        Returns
        -------
        Batch instance
            Status block containing "status" key.
        """
        if not task_group_id:
            task_group_id = self.session_task_group_id

        batch = Batch(task_group_id=task_group_id)
        return batch

    def batch_run(self, batch):
        """Initiate a batch of tasks to funcX

        Parameters
        ----------
        batch: a Batch object

        Returns
        -------
        task_ids : a list of UUID strings that identify the tasks
        """
        assert isinstance(batch, Batch), "Requires a Batch object as input"
        assert len(batch.tasks) > 0, "Requires a non-empty batch"

        data = batch.prepare()

        # Send the data to funcX
        r = self.web_client.submit(data)

        task_uuids = []
        for result in r["results"]:
            task_id = result["task_uuid"]
            task_uuids.append(task_id)
            if result["http_status_code"] != 200:
                # this method of handling errors for a batch response is not
                # ideal, as it will raise any error in the multi-response,
                # but it will do until batch_run is deprecated in favor of Executor
                handle_response_errors(result)

        if self.asynchronous:
            task_group_id = r["task_group_id"]
            asyncio_tasks = []
            for task_id in task_uuids:
                funcx_task = FuncXTask(task_id)
                asyncio_task = self.loop.create_task(funcx_task.get_result())
                asyncio_tasks.append(asyncio_task)

                self.ws_polling_task.add_task(funcx_task)
            self.ws_polling_task.put_task_group_id(task_group_id)
            return asyncio_tasks

        return task_uuids

    def map_run(
        self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs
    ):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg,))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            "endpoints": [endpoint_id],
            "func": function_id,
            "payload": batch_payload,
            "is_async": asynchronous,
        }

        # Send the data to funcX
        r = self.web_client.submit_batch(data)
        return r["task_uuids"]

    def register_endpoint(
        self, name, endpoint_uuid, metadata=None, endpoint_version=None
    ):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        metadata : dict
            endpoint metadata, see default_config example
        endpoint_version: str
            Version string to be passed to the webService as a compatibility check

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        self.version_check()

        r = self.web_client.register_endpoint(
            endpoint_name=name,
            endpoint_id=endpoint_uuid,
            metadata=metadata,
            endpoint_version=endpoint_version,
        )
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        tuple
            The endpoint UUID and a list of containers
        """
        data = {"endpoint_name": name, "description": description}

        r = self.web_client.post("get_containers", data=data)
        return r.data["endpoint_uuid"], r.data["endpoint_containers"]

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        self.version_check()

        r = self.web_client.get(f"containers/{container_uuid}/{container_type}")
        return r.data["container"]

    def get_endpoint_status(self, endpoint_uuid):
        """Get the status reports for an endpoint.

        Parameters
        ----------
        endpoint_uuid : str
            UUID of the endpoint in question

        Returns
        -------
        dict
            The details of the endpoint's stats
        """
        r = self.web_client.get_endpoint_status(endpoint_uuid)
        return r.data

    def register_function(
        self,
        function,
        function_name=None,
        container_uuid=None,
        description=None,
        public=False,
        group=None,
        searchable=True,
    ):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution
        function_name : str
            The entry point (function name) of the function. Default: None
        container_uuid : str
            Container UUID from registration with funcX
        description : str
            Description of the function
        public : bool
            Whether or not the function is publicly accessible. Default = False
        group : str
            A globus group uuid to share this function with
        searchable : bool
            If true, the function will be indexed into globus search with the
            appropriate permissions

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        data = FunctionRegistrationData(
            function=function,
            failover_source="",
            container_uuid=container_uuid,
            entry_point=function_name,
            description=description,
            public=public,
            group=group,
            searchable=searchable,
            serializer=self.fx_serializer,
        )
        logger.info(f"Registering function : {data}")
        r = self.web_client.register_function(data)
        return r.data["function_uuid"]
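
    # Usage sketch (a minimal, hypothetical example of registering a local
    # function; `fxc` and `double` are not part of the original code):
    #
    #   def double(x):
    #       return 2 * x
    #
    #   fxc = FuncXClient()
    #   func_id = fxc.register_function(double, description="Double a number")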

    def search_function(self, q, offset=0, limit=10, advanced=False):
        """Search for function via the funcX service

        Parameters
        ----------
        q : str
            free-form query string
        offset : int
            offset into total results
        limit : int
            max number of results to return
        advanced : bool
            allows elastic-search like syntax in query string

        Returns
        -------
        FunctionSearchResults
        """
        return self.searcher.search_function(
            q, offset=offset, limit=limit, advanced=advanced
        )

    def search_endpoint(self, q, scope="all", owner_id=None):
        """

        Parameters
        ----------
        q
        scope : str
            Can be one of {'all', 'my-endpoints', 'shared-with-me'}
        owner_id
            should be urn like f"urn:globus:auth:identity:{owner_uuid}"

        Returns
        -------

        """
        return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id)

    def register_container(self, location, container_type, name="", description=""):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker).
            Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        payload = {
            "name": name,
            "location": location,
            "description": description,
            "type": container_type,
        }

        r = self.web_client.post("containers", data=payload)
        return r.data["container_id"]
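
    # Usage sketch (the image location and names are hypothetical):
    #
    #   container_id = fxc.register_container(
    #       "docker.io/library/python:3.9", "docker",
    #       name="py39", description="Plain Python 3.9 image")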

    def add_to_whitelist(self, endpoint_id, function_ids):
        """Adds the function to the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function ids to be whitelisted

        Returns
        -------
        json
            The response of the request
        """
        return self.web_client.whitelist_add(endpoint_id, function_ids)

    def get_whitelist(self, endpoint_id):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint

        Returns
        -------
        json
            The response of the request
        """
        return self.web_client.get_whitelist(endpoint_id)

    def delete_from_whitelist(self, endpoint_id, function_ids):
        """List the endpoint's whitelist

        Parameters
        ----------
        endpoint_id : str
            The uuid of the endpoint
        function_ids : list
            A list of function ids to be removed from the whitelist

        Returns
        -------
        json
            The response of the request
        """
        if not isinstance(function_ids, list):
            function_ids = [function_ids]
        res = []
        for fid in function_ids:
            res.append(self.web_client.whitelist_remove(endpoint_id, fid))
        return res
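
    # A combined usage sketch for the whitelist helpers above (the endpoint
    # and function ids are hypothetical placeholders):
    #
    #   fxc.add_to_whitelist("<endpoint-uuid>", ["<function-uuid>"])
    #   print(fxc.get_whitelist("<endpoint-uuid>"))
    #   fxc.delete_from_whitelist("<endpoint-uuid>", ["<function-uuid>"])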
Example #8
class CentralScheduler(object):
    def __init__(self,
                 endpoints,
                 strategy='round-robin',
                 runtime_predictor='rolling-average',
                 last_n=3,
                 train_every=1,
                 log_level='INFO',
                 import_model_file=None,
                 transfer_model_file=None,
                 sync_level='exists',
                 max_backups=0,
                 backup_delay_threshold=2.0,
                 *args,
                 **kwargs):
        self._fxc = FuncXClient(*args, **kwargs)

        # Initialize a transfer client
        self._transfer_manager = TransferManager(endpoints=endpoints,
                                                 sync_level=sync_level,
                                                 log_level=log_level)

        # Info about FuncX endpoints we can execute on
        self._endpoints = endpoints
        self._dead_endpoints = set()
        self.last_result_time = defaultdict(float)
        self.temperature = defaultdict(lambda: 'WARM')
        self._imports = defaultdict(list)
        self._imports_required = defaultdict(list)

        # Track which endpoints a function can't run on
        self._blocked = defaultdict(set)

        # Track pending tasks
        # We will provide the client our own task ids, since we may submit the
        # same task multiple times to the FuncX service, and sometimes we may
        # wait to submit a task to FuncX (e.g., wait for a data transfer).
        self._task_id_translation = {}
        self._pending = {}
        self._pending_by_endpoint = defaultdict(set)
        self._task_info = {}
        # List of endpoints a (virtual) task was scheduled to
        self._endpoints_sent_to = defaultdict(list)
        self.max_backups = max_backups
        self.backup_delay_threshold = backup_delay_threshold
        self._latest_status = {}
        self._last_task_ETA = defaultdict(float)
        # Maximum ETA, if any, of a task which we allow to be scheduled on an
        # endpoint. This prevents backfill tasks from running longer than the
        # estimated time until a pending data transfer finishes.
        self._transfer_ETAs = defaultdict(dict)
        # Estimated error in the pending-task time of an endpoint.
        # Updated every time a task result is received from an endpoint.
        self._queue_error = defaultdict(float)

        # Set logging levels
        logger.setLevel(log_level)
        self.execution_log = []

        # Initialize serializer
        self.fx_serializer = FuncXSerializer()
        self.fx_serializer.use_custom('03\n', 'code')

        # Initialize runtime predictor
        self.runtime = init_runtime_predictor(runtime_predictor,
                                              endpoints=endpoints,
                                              last_n=last_n,
                                              train_every=train_every)
        logger.info(f"Runtime predictor using strategy {self.runtime}")

        # Initialize transfer-time predictor
        self.transfer_time = TransferPredictor(endpoints=endpoints,
                                               train_every=train_every,
                                               state_file=transfer_model_file)

        # Initialize import-time predictor
        self.import_predictor = ImportPredictor(endpoints=endpoints,
                                                state_file=import_model_file)

        # Initialize scheduling strategy
        self.strategy = init_strategy(strategy,
                                      endpoints=endpoints,
                                      runtime_predictor=self.runtime,
                                      queue_predictor=self.queue_delay,
                                      cold_start_predictor=self.cold_start,
                                      transfer_predictor=self.transfer_time)
        logger.info(f"Scheduler using strategy {self.strategy}")

        # Start thread to check on endpoints regularly
        self._endpoint_watchdog = Thread(target=self._check_endpoints)
        self._endpoint_watchdog.start()

        # Start thread to monitor tasks and send tasks to FuncX service
        self._scheduled_tasks = Queue()
        self._task_watchdog_sleep = 0.15
        self._task_watchdog = Thread(target=self._monitor_tasks)
        self._task_watchdog.start()

    def block(self, func, endpoint):
        if endpoint not in self._endpoints:
            logger.error('Cannot block unknown endpoint {}'.format(endpoint))
            return {
                'status': 'Failed',
                'reason': 'Unknown endpoint {}'.format(endpoint)
            }
        elif len(self._blocked[func]) == len(self._endpoints) - 1:
            logger.error(
                'Cannot block last remaining endpoint {}'.format(endpoint))
            return {
                'status': 'Failed',
                'reason': 'Cannot block all endpoints for {}'.format(func)
            }
        else:
            logger.info('Blocking endpoint {} for function {}'.format(
                endpoint_name(endpoint), func))
            self._blocked[func].add(endpoint)
            return {'status': 'Success'}

    def register_imports(self, func, imports):
        logger.info('Registered function {} with imports {}'.format(
            func, imports))
        self._imports_required[func] = imports

    def batch_submit(self, tasks, headers):
        # TODO: smarter scheduling for batch submissions

        task_ids = []
        endpoints = []

        for func, payload in tasks:
            _, ser_kwargs = self.fx_serializer.unpack_buffers(payload)
            kwargs = self.fx_serializer.deserialize(ser_kwargs)
            files = kwargs['_globus_files']

            task_id, endpoint = self._schedule_task(func=func,
                                                    payload=payload,
                                                    headers=headers,
                                                    files=files)
            task_ids.append(task_id)
            endpoints.append(endpoint)

        return task_ids, endpoints
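
    # Sketch of the `tasks` structure batch_submit expects, inferred from the
    # unpacking above (a list of (function_id, payload) pairs whose kwargs
    # include a '_globus_files' entry); the variable names are illustrative:
    #
    #   ser_args = fx_serializer.serialize((arg,))
    #   ser_kwargs = fx_serializer.serialize({"_globus_files": files})
    #   payload = fx_serializer.pack_buffers([ser_args, ser_kwargs])
    #   tasks = [(function_id, payload)]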

    def _schedule_task(self, func, payload, headers, files, task_id=None):

        # If this is the first time scheduling this task_id
        # (i.e., non-backup task), record the necessary metadata
        if task_id is None:
            # Create (fake) task id to return to client
            task_id = str(uuid.uuid4())

            # Store task information
            self._task_id_translation[task_id] = set()

            # Information required to schedule the task, now and in the future
            info = {
                'function_id': func,
                'payload': payload,
                'headers': headers,
                'files': files,
                'time_requested': time.time()
            }
            self._task_info[task_id] = info

        # TODO: do not choose a dead endpoint (reliably)
        # exclude = self._blocked[func] | self._dead_endpoints | set(self._endpoints_sent_to[task_id])  # noqa
        if len(self._dead_endpoints) > 0:
            logger.warn('{} endpoints seem dead. Hope they still work!'.format(
                len(self._dead_endpoints)))
        exclude = self._blocked[func] | set(self._endpoints_sent_to[task_id])
        choice = self.strategy.choose_endpoint(
            func,
            payload=payload,
            files=files,
            exclude=exclude,
            transfer_ETAs=self._transfer_ETAs)  # noqa
        endpoint = choice['endpoint']
        logger.info('Choosing endpoint {} for func {}, task id {}'.format(
            endpoint_name(endpoint), func, task_id))
        choice['ETA'] = self.strategy.predict_ETA(func,
                                                  endpoint,
                                                  payload,
                                                  files=files)

        # Start Globus transfer of required files, if any
        if len(files) > 0:
            transfer_num = self._transfer_manager.transfer(
                files, endpoint, task_id)
            if transfer_num is not None:
                transfer_ETA = time.time() + self.transfer_time(
                    files, endpoint)
                self._transfer_ETAs[endpoint][transfer_num] = transfer_ETA
        else:
            transfer_num = None
            # Record endpoint ETA for queue-delay prediction here,
            # since task will be immediately scheduled
            self._last_task_ETA[endpoint] = choice['ETA']

        # If a cold endpoint is being started, mark it as no longer cold,
        # so that subsequent launch-time predictions are correct (i.e., 0)
        if self.temperature[endpoint] == 'COLD':
            self.temperature[endpoint] = 'WARMING'
            logger.info(
                'A cold endpoint {} was chosen; marked as warming.'.format(
                    endpoint_name(endpoint)))

        # Schedule task for sending to FuncX
        self._endpoints_sent_to[task_id].append(endpoint)
        self._scheduled_tasks.put((task_id, endpoint, transfer_num))

        return task_id, endpoint

    def translate_task_id(self, task_id):
        return self._task_id_translation[task_id]

    def log_status(self, real_task_id, data):
        if real_task_id not in self._pending:
            logger.warn('Ignoring unknown task id {}'.format(real_task_id))
            return

        task_id = self._pending[real_task_id]['task_id']
        func = self._pending[real_task_id]['function_id']
        endpoint = self._pending[real_task_id]['endpoint_id']
        # Don't overwrite latest status if it is a result/exception
        if task_id not in self._latest_status or \
                self._latest_status[task_id].get('status') == 'PENDING':
            self._latest_status[task_id] = data

        if 'result' in data:
            result = self.fx_serializer.deserialize(data['result'])
            runtime = result['runtime']
            name = endpoint_name(endpoint)
            logger.info('Got result from {} for task {} with time {}'.format(
                name, real_task_id, runtime))

            self.runtime.update(self._pending[real_task_id], runtime)
            self._pending[real_task_id]['runtime'] = runtime
            self._record_completed(real_task_id)
            self.last_result_time[endpoint] = time.time()
            self._imports[endpoint] = result['imports']

        elif 'exception' in data:
            exception = self.fx_serializer.deserialize(data['exception'])
            try:
                exception.reraise()
            except Exception as e:
                logger.error('Got exception on task {}: {}'.format(
                    real_task_id, e))
                exc_type, _, _ = sys.exc_info()
                if exc_type in BLOCK_ERRORS:
                    self.block(func, endpoint)

            self._record_completed(real_task_id)
            self.last_result_time[endpoint] = time.time()

        elif 'status' in data and data['status'] == 'PENDING':
            pass

        else:
            logger.error('Unexpected status message: {}'.format(data))

    def get_status(self, task_id):
        if task_id not in self._task_id_translation:
            logger.warn('Unknown client task id {}'.format(task_id))

        elif len(self._task_id_translation[task_id]) == 0:
            return {'status': 'PENDING'}  # Task has not been scheduled yet

        elif task_id not in self._latest_status:
            return {'status': 'PENDING'}  # Status has not been queried yet

        else:
            return self._latest_status[task_id]

    def queue_delay(self, endpoint):
        # The queue delay is the ETA of the most recent task, plus the
        # estimated error in the ETA prediction. Note that if there are no
        # pending tasks on the endpoint, there is no queue delay; this is
        # implicit since, in that case, both summands will be 0.
        delay = self._last_task_ETA[endpoint] + self._queue_error[endpoint]
        return max(delay, time.time())
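
    # Worked example (illustrative numbers only): if the most recent task on
    # an endpoint has ETA = now + 12.0 s and the last observed prediction
    # error was +1.5 s, the queue-delay estimate is now + 13.5 s. With no
    # pending tasks both summands are 0, so max() falls back to the current
    # time, i.e., no queue delay.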

    def _record_completed(self, real_task_id):
        info = self._pending[real_task_id]
        endpoint = info['endpoint_id']

        # If this is the last pending task on this endpoint, reset ETA offset
        if len(self._pending_by_endpoint[endpoint]) == 1:
            self._last_task_ETA[endpoint] = 0.0
            self._queue_error[endpoint] = 0.0
        else:
            prediction_error = time.time() - self._pending[real_task_id]['ETA']
            self._queue_error[endpoint] = prediction_error
            # print(colored(f'Prediction error {prediction_error}', 'red'))

        info['ATA'] = time.time()
        del info['headers']
        self.execution_log.append(info)

        logger.info(
            'Task exec time: expected = {:.3f}, actual = {:.3f}'.format(
                info['ETA'] - info['time_sent'],
                time.time() - info['time_sent']))
        # logger.info(f'ETA_offset = {self._queue_error[endpoint]:.3f}')

        # Stop tracking this task
        del self._pending[real_task_id]
        self._pending_by_endpoint[endpoint].remove(real_task_id)
        if info['task_id'] in self._task_info:
            del self._task_info[info['task_id']]

    def cold_start(self, endpoint, func):
        # If endpoint is warm, there is no launch time
        if self.temperature[endpoint] != 'COLD':
            launch_time = 0.0
        # Otherwise, return the launch time in the endpoint config
        elif 'launch_time' in self._endpoints[endpoint]:
            launch_time = self._endpoints[endpoint]['launch_time']
        else:
            logger.warn(
                'Endpoint {} should always be warm, but is cold'.format(
                    endpoint_name(endpoint)))
            launch_time = 0.0

        # Time to import dependencies
        import_time = 0.0
        for pkg in self._imports_required[func]:
            if pkg not in self._imports[endpoint]:
                logger.debug(
                    'Cold-start has import time for pkg {} on {}'.format(
                        pkg, endpoint_name(endpoint)))
                import_time += self.import_predictor(pkg, endpoint)

        return launch_time + import_time
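
    # Worked example (illustrative numbers only): for a COLD endpoint with
    # launch_time = 30 s in its config, running a function that still needs
    # one import predicted at 2 s on that endpoint, cold_start returns
    # 30 + 2 = 32 s; for a WARM endpoint whose imports are already cached it
    # returns 0.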

    def _monitor_tasks(self):
        logger.info('Starting task-watchdog thread')

        scheduled = {}

        while True:

            time.sleep(self._task_watchdog_sleep)

            # Get newly scheduled tasks
            while True:
                try:
                    task_id, end, num = self._scheduled_tasks.get_nowait()
                    if task_id not in self._task_info:
                        logger.warn(
                            'Task id {} scheduled but no info found'.format(
                                task_id))
                        continue
                    info = self._task_info[task_id]
                    scheduled[task_id] = dict(info)  # Create new copy of info
                    scheduled[task_id]['task_id'] = task_id
                    scheduled[task_id]['endpoint_id'] = end
                    scheduled[task_id]['transfer_num'] = num
                except Empty:
                    break

            # Filter out all tasks whose data transfer has not been completed
            ready_to_send = set()
            for task_id, info in scheduled.items():
                transfer_num = info['transfer_num']
                if transfer_num is None:
                    ready_to_send.add(task_id)
                    info['transfer_time'] = 0.0
                elif self._transfer_manager.is_complete(transfer_num):
                    ready_to_send.add(task_id)
                    del self._transfer_ETAs[info['endpoint_id']][transfer_num]
                    info['transfer_time'] = \
                        self._transfer_manager.get_transfer_time(transfer_num)
                else:  # This task cannot be scheduled yet
                    continue

            if len(ready_to_send) == 0:
                logger.debug('No new tasks to send. Task watchdog sleeping...')
                continue

            # TODO: different clients send different headers. change eventually
            headers = list(scheduled.values())[0]['headers']

            logger.info('Scheduling a batch of {} tasks'.format(
                len(ready_to_send)))

            # Submit all ready tasks to FuncX
            data = {'tasks': []}
            for task_id in ready_to_send:
                info = scheduled[task_id]
                submit_info = (info['function_id'], info['endpoint_id'],
                               info['payload'])
                data['tasks'].append(submit_info)

            res_str = requests.post(f'{FUNCX_API}/submit',
                                    headers=headers,
                                    data=json.dumps(data))
            try:
                res = res_str.json()
            except ValueError:
                logger.error(f'Could not parse JSON from {res_str.text}')
                continue
            if res['status'] != 'Success':
                logger.error(
                    'Could not send tasks to FuncX. Got response: {}'.format(
                        res))
                continue

            # Update task info with submission info
            for task_id, real_task_id in zip(ready_to_send, res['task_uuids']):
                info = scheduled[task_id]
                # This ETA calculation does not take into account transfer time
                # since, at this point, the transfer has already completed.
                info['ETA'] = self.strategy.predict_ETA(
                    info['function_id'], info['endpoint_id'], info['payload'])
                # Record if this ETA prediction is "reliable". If it is not
                # (e.g., when we have not learned about this (func, ep) pair),
                # backup tasks will not be sent for this task if it is delayed.
                info['is_ETA_reliable'] = self.runtime.has_learned(
                    info['function_id'], info['endpoint_id'])

                info['time_sent'] = time.time()

                endpoint = info['endpoint_id']
                self._task_id_translation[task_id].add(real_task_id)

                self._pending[real_task_id] = info
                self._pending_by_endpoint[endpoint].add(real_task_id)

                # Record endpoint ETA for queue-delay prediction
                self._last_task_ETA[endpoint] = info['ETA']

                logger.info(
                    'Sent task id {} to {} with real task id {}'.format(
                        task_id, endpoint_name(endpoint), real_task_id))

            # Stop tracking all newly sent tasks
            for task_id in ready_to_send:
                del scheduled[task_id]

    def _check_endpoints(self):
        logger.info('Starting endpoint-watchdog thread')

        while True:
            for end in self._endpoints.keys():
                statuses = self._fxc.get_endpoint_status(end)
                if len(statuses) == 0:
                    logger.warn(
                        'Endpoint {} does not have any statuses'.format(
                            endpoint_name(end)))
                else:
                    status = statuses[0]  # Most recent endpoint status

                    # Mark endpoint as dead/alive based on heartbeat's age
                    # Heartbeats are delayed when an endpoint is executing
                    # tasks, so take into account last execution too
                    age = time.time() - max(status['timestamp'],
                                            self.last_result_time[end])
                    is_dead = end in self._dead_endpoints
                    if not is_dead and age > HEARTBEAT_THRESHOLD:
                        self._dead_endpoints.add(end)
                        logger.warn(
                            'Endpoint {} seems to have died! '
                            'Last heartbeat was {:.2f} seconds ago.'.format(
                                endpoint_name(end), age))
                    elif is_dead and age <= HEARTBEAT_THRESHOLD:
                        self._dead_endpoints.remove(end)
                        logger.warn(
                            'Endpoint {} is back alive! '
                            'Last heartbeat was {:.2f} seconds ago.'.format(
                                endpoint_name(end), age))

                    # Mark endpoint as "cold" or "warm" depending on if it
                    # has active managers (nodes) allocated to it
                    if self.temperature[end] == 'WARM' \
                            and status['active_managers'] == 0:
                        self.temperature[end] = 'COLD'
                        logger.info('Endpoint {} is cold!'.format(
                            endpoint_name(end)))
                    elif self.temperature[end] != 'WARM' \
                            and status['active_managers'] > 0:
                        self.temperature[end] = 'WARM'
                        logger.info('Endpoint {} is warm again!'.format(
                            endpoint_name(end)))

            # Send backup tasks if needed
            self._send_backups_if_needed()

            # Sleep before checking statuses again
            time.sleep(5)

    def _send_backups_if_needed(self):
        # Get all tasks which have not been completed yet and still have a
        # pending (real) task on a dead endpoint
        task_ids = {
            self._pending[real_task_id]['task_id']
            for endpoint in self._dead_endpoints
            for real_task_id in self._pending_by_endpoint[endpoint]
            if self._pending[real_task_id]['task_id'] in self._task_info
        }

        # Get all tasks for which we had ETA predictions but which still
        # have not completed even though their ETA has passed
        for real_task_id, info in self._pending.items():
            # If the predicted ETA wasn't reliable, don't send backups
            if not info['is_ETA_reliable']:
                continue

            expected = info['ETA'] - info['time_sent']
            elapsed = time.time() - info['time_sent']

            if elapsed / expected > self.backup_delay_threshold:
                task_ids.add(info['task_id'])

        for task_id in task_ids:
            if len(self._endpoints_sent_to[task_id]) > self.max_backups:
                logger.debug(f'Skipping sending new backup task for {task_id}')
            else:
                logger.info(f'Sending new backup task for {task_id}')
                info = self._task_info[task_id]
                self._schedule_task(info['function_id'], info['payload'],
                                    info['headers'], info['files'], task_id)
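
    # A minimal instantiation sketch (the endpoint UUIDs are hypothetical;
    # the per-endpoint dict layout follows the 'launch_time' lookup used in
    # cold_start above):
    #
    #   endpoints = {
    #       "<endpoint-uuid-1>": {"launch_time": 30.0},
    #       "<endpoint-uuid-2>": {"launch_time": 60.0},
    #   }
    #   scheduler = CentralScheduler(endpoints,
    #                                strategy="round-robin",
    #                                runtime_predictor="rolling-average")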
Example #9
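# Note: func_uuid, headers, remote_extract_batch, and remote_poll_batch are
# assumed to be defined earlier in this script; they are not shown here.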
event = None

endpoint = '68bade94-bf58-4a7a-bfeb-9c6a61fa5443'


items_to_batch = [{"func_id": func_uuid, "event": {}}, {"func_id": func_uuid, "event": {}}]
x = remote_extract_batch(items_to_batch, endpoint, headers=headers)

fx_ser = FuncXSerializer()

import time
while True:
    a = remote_poll_batch(x, headers)
    print(f"The returned: {a}")

    for tid in a:
        if "exception"  in a[tid]:
            exception = a[tid]["exception"]
            print(f"The serialized exception: {exception}")
            d_exception = fx_ser.deserialize(exception)
            print(f"The deserialized exception {d_exception}")

            print("RERAISING EXCEPTION!")

            d_exception.reraise()
            # print(fx_ser.deserialize(a[tid]["exception"]))
        else:
            print(a[tid])
    time.sleep(5.1)
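
# A hedged sketch of a terminating variant of the poll loop above: stop once
# every task id in `x` has produced a result or an exception (it reuses the
# remote_poll_batch, headers, and fx_ser names already used in this snippet):
#
#   pending = set(x)
#   while pending:
#       statuses = remote_poll_batch(list(pending), headers)
#       for tid, status in statuses.items():
#           if "result" in status or "exception" in status:
#               pending.discard(tid)
#       time.sleep(5)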

class ExtractorOrchestrator:
    def __init__(self,
                 funcx_eid,
                 mdata_store_path,
                 source_eid=None,
                 dest_eid=None,
                 gdrive_token=None,
                 extractor_finder='gdrive',
                 prefetch_remote=False,
                 data_prefetch_path=None,
                 dataset_mdata=None):

        prefetch_remote = False

        # TODO -- fix this.
        # self.crawl_type = 'from_file'

        self.write_cpe = False

        self.dataset_mdata = dataset_mdata

        self.t_crawl_start = time.time()
        self.t_send_batch = 0
        self.t_transfer = 0

        self.prefetch_remote = prefetch_remote
        self.data_prefetch_path = data_prefetch_path

        self.extractor_finder = extractor_finder

        self.funcx_eid = funcx_eid
        self.func_dict = {
            "image": xtract_images.ImageExtractor(),
            "images": xtract_images.ImageExtractor(),
            "tabular": xtract_tabular.TabularExtractor(),
            "text": xtract_keyword.KeywordExtractor(),
            "matio": xtract_matio.MatioExtractor()
        }

        self.fx_ser = FuncXSerializer()

        self.send_status = "STARTING"
        self.poll_status = "STARTING"
        self.commit_completed = False

        self.source_endpoint = source_eid
        self.dest_endpoint = dest_eid
        self.gdrive_token = gdrive_token

        self.num_families_fetched = 0
        self.get_families_start_time = time.time()
        self.last_checked = time.time()

        self.pre_launch_counter = 0

        self.success_returns = 0
        self.failed_returns = 0

        self.to_send_queue = Queue()

        self.poll_gap_s = 5

        self.get_families_status = "STARTING"

        self.task_dict = {
            "active": Queue(),
            "pending": Queue(),
            "failed": Queue()
        }

        # Batch size we use to send tasks to funcx.  (and the subbatch size)
        self.map_size = 8
        self.fx_batch_size = 16
        self.fx_task_sublist_size = 500

        # Want to store attributes about funcX requests/responses.
        self.tot_fx_send_payload_size = 0
        self.tot_fx_poll_payload_size = 0
        self.tot_fx_poll_result_size = 0
        self.num_send_reqs = 0
        self.num_poll_reqs = 0
        self.t_first_funcx_invoke = None
        self.max_result_size = 0

        # Current and maximum number of tasks sent to funcX for extraction.
        self.max_extracting_tasks = 5
        self.num_extracting_tasks = 0

        self.max_pre_prefetch = 15000  # TODO: Integrate this to actually fix timing bug.

        self.status_things = Queue()

        # If this is turned on, it means we hit our local task maximum and don't want to pull down new work.
        self.pause_q_consume = False

        self.file_count = 0
        self.current_batch = []
        self.extract_end = None

        self.mdata_store_path = mdata_store_path
        self.n_fams_transferred = 0

        self.prefetcher_tid = None
        self.prefetch_status = None
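
        # NOTE: several attributes referenced later (self.headers, self.crawl_type,
        # self.crawl_id, self.task_cap_until_termination, and self.prefetcher)
        # are never set in this constructor; they are assumed to be provided
        # elsewhere (e.g., by a subclass or an earlier setup step).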

        self.fx_headers = {
            "Authorization": f"Bearer {self.headers['FuncX']}",
            'FuncX': self.headers['FuncX']
        }

        self.family_headers = None
        if 'Petrel' in self.headers:
            self.fx_headers['Petrel'] = self.headers['Petrel']
            self.family_headers = {
                'Authorization': f"Bearer {self.headers['Petrel']}",
                'Transfer': self.headers['Transfer'],
                'FuncX': self.headers['FuncX'],
                'Petrel': self.headers['Petrel']
            }

        self.logger = logging.getLogger(__name__)
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            '%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(
            logging.INFO)  # TODO: let's make this configurable.
        self.families_to_process = Queue()
        self.to_validate_q = Queue()

        self.sqs_push_threads = {}
        self.thr_ls = []
        self.commit_threads = 1
        self.get_family_threads = 20

        if self.prefetch_remote:
            self.logger.info("Launching prefetcher...")

            self.logger.info("Prefetcher successfully launched!")

            prefetch_thread = threading.Thread(
                target=self.prefetcher.main_poller_loop, args=())
            prefetch_thread.start()

        for i in range(0, self.commit_threads):
            thr = threading.Thread(target=self.validate_enqueue_loop,
                                   args=(i, ))
            self.thr_ls.append(thr)
            thr.start()
            self.sqs_push_threads[i] = True
        self.logger.info(
            f"Successfully started {len(self.sqs_push_threads)} SQS push threads!"
        )

        if self.crawl_type != 'from_file':
            for i in range(0, self.get_family_threads):
                self.logger.info(
                    f"Attempting to start get_next_families() as its own thread [{i}]... "
                )
                consumer_thr = threading.Thread(
                    target=self.get_next_families_loop, args=())
                consumer_thr.start()
                print(
                    f"Successfully started the get_next_families() thread number {i} "
                )
        else:
            print("ATTEMPTING TO LAUNCH **FILE** CRAWL THREAD. ")
            file_crawl_thr = threading.Thread(
                target=self.read_next_families_from_file_loop, args=())
            file_crawl_thr.start()
            print("Successfully started the **FILE** CRAWL thread!")

        for i in range(0, 15):
            fx_push_thr = threading.Thread(target=self.send_subbatch_thread,
                                           args=())
            fx_push_thr.start()
        print("Successfully spun up {i} send threads!")

        with open("cpe_times.csv", 'w') as f:
            f.close()

    def send_subbatch_thread(self):

        # TODO: THIS IS NEW. HUNGRY STANDALONE THREADPOOL THAT SENDS ALL TASKS.
        # TODO: SHOULD TERMINATE WHEN COMPLETED <-- do after paper deadline.

        while True:
            sub_batch = []
            for i in range(10):

                if self.to_send_queue.empty():
                    break

                # I believe this is a blocking call.
                part_batch = self.to_send_queue.get()
                sub_batch.extend(part_batch)

            if len(sub_batch) == 0:
                time.sleep(0.5)
                continue

            batch_send_t = time.time()
            task_ids = remote_extract_batch(sub_batch,
                                            ep_id=self.funcx_eid,
                                            headers=self.fx_headers)
            batch_recv_t = time.time()

            print(f"Time to send batch: {batch_recv_t - batch_send_t}")

            self.num_send_reqs += 1
            self.pre_launch_counter -= len(sub_batch)

            if type(task_ids) is dict:
                self.logger.exception(
                    f"Caught funcX error: {task_ids['exception_caught']}. \n"
                    f"Putting the tasks back into active queue for retry")

                for reject_fam_batch in self.current_batch:

                    fam_batch_dict = reject_fam_batch['event']['family_batch']

                    for reject_fam in fam_batch_dict['families']:
                        self.families_to_process.put(json.dumps(reject_fam))

                self.logger.info(f"Pausing for 10 seconds...")

            for task_id in task_ids:
                self.task_dict["active"].put(task_id)
                self.num_extracting_tasks += 1

            time.sleep(0.5)

    def validate_enqueue_loop(self, thr_id):

        self.logger.debug("[VALIDATE] In validation enqueue loop!")
        while True:
            insertables = []

            # If empty, then we want to return.
            if self.to_validate_q.empty():
                # If ingest queue empty, we can demote to "idle"
                if self.poll_status == "COMMITTING":
                    self.sqs_push_threads[thr_id] = "IDLE"
                    print(f"Thread {thr_id} is committing and idle!")
                    time.sleep(0.25)

                    # NOW if all threads idle, then return!
                    if all(value == "IDLE"
                           for value in self.sqs_push_threads.values()):
                        self.commit_completed = True
                        self.poll_status = "COMPLETED"
                        self.extract_end = time.time()
                        self.logger.info(f"Thread {thr_id} is terminating!")
                        return 0

                self.logger.info("[Validate]: sleeping for 5 seconds... ")
                time.sleep(5)
                continue

            self.sqs_push_threads[thr_id] = "ACTIVE"

            # Remove up to n elements from queue, where n is current_batch.
            current_batch = 1
            # TODO: manual downsize so it fits on the queue.
            while not self.to_validate_q.empty() and current_batch < 8:
                item_to_add = self.to_validate_q.get()

                # TODO: THIS IS PULLING THINGS BEFORE THEY GET TO VALIDATE QUEUE.

                insertables.append(item_to_add)
                current_batch += 1

            # TODO: boot all of this out to file.
            if self.write_cpe:
                with open("cpe_times.csv", 'a') as f:

                    csv_writer = csv.writer(f)

                    for item in insertables:

                        fam_batch = json.loads(item['MessageBody'])

                        for family in fam_batch['families']:
                            # print(family)
                            # crawl_timestamp = family['metadata']['crawl_timestamp']
                            pf_timestamp = family['metadata'][
                                'pf_transfer_completed']
                            fx_timestamp = family['metadata'][
                                't_funcx_req_received']

                            total_file_size = 0

                            all_files = family['files']
                            total_files = len(all_files)
                            for file_obj in all_files:
                                total_file_size += file_obj['metadata'][
                                    'physical']['size']

                            csv_writer.writerow([
                                'x', 0, pf_timestamp, fx_timestamp,
                                total_files, total_file_size
                            ])

            # TODO: investigate this. Why is this here?
            # try:
            #     ta = time.time()
            #     self.client.send_message_batch(QueueUrl=self.validation_queue_url,
            #                                    Entries=insertables)
            #     tb = time.time()
            #     self.t_send_batch += tb-ta
            #
            # except Exception as e:  # TODO: too vague
            #     print(f"WAS UNABLE TO PROPERLY CONNECT to SQS QUEUE: {e}")

    def send_families_loop(self):

        self.send_status = "RUNNING"
        last_send = time.time()

        while True:

            # print(f"Time since last send: {last_send - time.time()}")
            last_send = time.time()

            if self.num_extracting_tasks > self.max_extracting_tasks:
                self.logger.info(
                    f"[SECOND] Num. active tasks ({self.num_extracting_tasks}) "
                    f"above threshold. Sleeping for 1 second and continuing..."
                )
                time.sleep(1)
                continue

            if self.prefetch_remote:
                while not self.prefetcher.orch_reader_q.empty():
                    family = self.prefetcher.orch_reader_q.get()
                    family_size = self.prefetcher.get_family_size(
                        json.loads(family))

                    self.prefetcher.bytes_pf_completed -= family_size
                    self.prefetcher.orch_unextracted_bytes += family_size
                    self.pre_launch_counter += 1

                    self.families_to_process.put(family)

            family_list = []
            # Keep filling our list of families until the batch is full or
            # the queue is empty.
            while (len(family_list) < self.fx_batch_size
                   and not self.families_to_process.empty()):
                family_list.append(self.families_to_process.get())

            if len(family_list) == 0:
                # Here we check if the crawl is complete. If so, then we can start the teardown checks.
                status_dict = get_crawl_status(self.crawl_id)

                if status_dict['crawl_status'] in ["SUCCEEDED", "FAILED"]:

                    # Check a second time due to a narrow race condition.
                    if (self.families_to_process.empty()
                            and self.get_families_status == "IDLE"):

                        if self.prefetch_remote and self.prefetch_status in [
                                "SUCCEEDED", "FAILED"
                        ]:

                            self.send_status = "SUCCEEDED"
                            self.logger.info(
                                "[SEND] Queue still empty -- terminating!")

                            # this should terminate thread, because there is nothing to process and queue empty
                            return

                    else:  # Something snuck in during the race condition... process it!
                        self.logger.info(
                            "[SEND] Discovered final output despite crawl termination. Processing..."
                        )
                        # This is a multi-minute task and is only reached due to starvation.
                        time.sleep(0.5)

            # self.map_size =
            # Cast list to FamilyBatch

            family_batch = FamilyBatch()
            for family in family_list:

                # Get extractor out of each group
                if self.extractor_finder == 'matio':
                    d_type = "HTTPS"
                    extr_code = 'matio'
                    xtr_fam_obj = Family(download_type=d_type)

                    xtr_fam_obj.from_dict(json.loads(family))
                    xtr_fam_obj.headers = self.family_headers

                # TODO: kick this logic for finding extractor into sdk/crawler.
                elif self.extractor_finder == 'gdrive':
                    d_type = 'gdrive'
                    xtr_fam_obj = Family(download_type=d_type)

                    xtr_fam_obj.from_dict(json.loads(family))
                    xtr_fam_obj.headers = self.headers

                    extr_code = xtr_fam_obj.groups[list(
                        xtr_fam_obj.groups.keys())[0]].parser

                else:
                    raise ValueError(
                        f"Incorrect extractor_finder arg: {self.extractor_finder}"
                    )

                # TODO: add the decompression work and the hdf5/netcdf extractors!
                if extr_code is None or extr_code == 'hierarch' or extr_code == 'compressed':
                    continue

                extractor = self.func_dict[extr_code]

                # TODO TYLER: Get the proper function ID here!!!
                # ex_func_id = extractor.func_id
                ex_func_id = mapping['xtract-matio::midway2']['func_uuid']

                # Putting into family batch -- we use funcX batching now, but no use rewriting...
                # family_batch = FamilyBatch()
                family_batch.add_family(xtr_fam_obj)
                # print(f"Length of family batch: {len(family_batch.families)}")

                if len(family_batch.families) >= self.map_size:

                    if d_type == "gdrive":
                        self.current_batch.append({
                            "event": {
                                "family_batch": family_batch,
                                "creds": self.gdrive_token[0]
                            },
                            "func_id": ex_func_id
                        })
                    elif d_type == "HTTPS":
                        self.current_batch.append({
                            "event": {
                                "family_batch": family_batch.to_dict()
                            },
                            "func_id": ex_func_id
                        })

                    family_batch = FamilyBatch()

            # Catch any tasks currently in the map and append them to the batch
            if len(family_batch.families) > 0:

                if d_type == "gdrive":
                    self.current_batch.append({
                        "event": {
                            "family_batch": family_batch,
                            "creds": self.gdrive_token[0]
                        },
                        "func_id": ex_func_id
                    })
                elif d_type == "HTTPS":
                    self.current_batch.append({
                        "event": {
                            "family_batch": family_batch.to_dict()
                        },
                        "func_id": ex_func_id
                    })

            # Now take that straggling family batch and append it.
            if len(self.current_batch) > 0:
                if self.t_first_funcx_invoke is None:
                    self.t_first_funcx_invoke = time.time()

                req_size = 0
                for item in self.current_batch:
                    req_size += sys.getsizeof(item)
                    req_size += 2 * sys.getsizeof(
                        self.funcx_eid)  # need size of fx ep and function id.
                req_size += sys.getsizeof(self.fx_headers)

                self.tot_fx_send_payload_size += req_size

            sub_batches = create_list_chunks(self.current_batch,
                                             self.fx_task_sublist_size)

            # print(f"Total sub_batches: {len(sub_batches)}")

            send_threads = []

            # for sub_batch in sub_batches:

            self.to_send_queue.put(self.current_batch)

            # i = 0
            # # TODO: THESE NEED TO BE STANDALONE THREADPOOL
            # for subbatch in sub_batches:
            #     send_thr = threading.Thread(target=self.send_subbatch_thread, args=(subbatch,))
            #     send_thr.start()
            #     send_threads.append(send_thr)
            #     i += 1
            # if i > 0:
            #     print(f"Spun up {i} task-send threads!")
            #
            # for thr in send_threads:
            #     thr.join()

            # Empty the batch! Everything in here has been sent :)
            self.current_batch = []

    def launch_poll(self):
        self.logger.info("Launching poller...")
        po_thr = threading.Thread(target=self.poll_extractions_and_stats,
                                  args=())
        po_thr.start()
        self.logger.info("Poller successfully launched!")

    def read_next_families_from_file_loop(self):
        """
        This loads saved crawler state (from a json file) and quickly adds all files to our local
        families_to_process queue. This avoids making calls to SQS to retrieve crawl results.
        """

        with open(
                '/Users/tylerskluzacek/PycharmProjects/xtracthub-service/experiments/tyler_200k.json',
                'r') as f:
            all_families = json.load(f)

            num_fams = 0
            for family in all_families:

                self.families_to_process.put(json.dumps(family))
                num_fams += 1
                if num_fams > self.task_cap_until_termination:
                    break

        # self.prefetcher.kill_prefetcher = True
        # self.prefetcher.last_batch = True  # TODO: bring this back for prefetcher.
        print("ENDING LOOP")

    def launch_extract(self):
        ex_thr = threading.Thread(target=self.send_families_loop, args=())
        ex_thr.start()

    def unpack_returned_family_batch(self, family_batch):
        fam_batch_dict = family_batch.to_dict()
        return fam_batch_dict

    def update_and_print_stats(self):
        # STAT CHECK: if we haven't updated stats in 5 seconds, then we update.
        cur_time = time.time()
        if cur_time - self.last_checked > 5:

            # TODO: all the commented-out jazz here should really be 'if prefetch_remote'.
            # total_bytes = self.prefetcher.orch_unextracted_bytes + \
            #               self.prefetcher.bytes_pf_completed + \
            #               self.prefetcher.bytes_pf_in_flight

            print("Phase 4: polling")
            print(f"\t Successes: {self.success_returns}")
            print(f"\t Failures: {self.failed_returns}\n")
            self.logger.debug(
                f"[VALIDATE] Length of validation queue: {self.to_validate_q.qsize()}"
            )

            # n_pulled = self.prefetcher.next_prefetch_queue.qsize()
            # n_pulled_per_sec = self.num_families_fetched / (cur_time - self.get_families_start_time)
            # n_pf_batch = self.prefetcher.pf_msgs_pulled_since_last_batch
            # n_families_pf_per_sec = self.prefetcher.num_families_transferred / (cur_time - self.get_families_start_time)
            # n_pf = self.prefetcher.num_families_mid_transfer
            # n_awaiting_fx = self.pre_launch_counter + self.prefetcher.orch_reader_q.qsize()
            n_in_fx = self.num_extracting_tasks
            n_success = self.success_returns

            # total_tracked = n_success + n_in_fx + n_awaiting_fx + n_pf + n_pf_batch + n_pulled

            # self.logger.info(f"\n** TASK LOCATION BREAKDOWN **\n"
            #                  f"--Pulled: {n_pulled}\t|\t({n_pulled_per_sec}/s)\n"
            #                  f"--In pf batching: {n_pf_batch}\n"
            #                  f"--Prefetching: {n_pf}\t|\t({n_families_pf_per_sec})\n"
            #                  f"--Awaiting extraction: {n_awaiting_fx}\n"
            #                  f"--In extraction: {n_in_fx}\n"
            #                  f"--Completed: {n_success}\n"
            #                  f"\n-- Fetch-Track Delta: {self.num_families_fetched - total_tracked}\n")

            if self.prefetch_remote:
                # print(f"\t Eff. dir size (GB): {total_bytes / 1024 / 1024 / 1024}")
                print(f"\t N Transfer Tasks: {self.prefetcher.num_transfers}")

            if self.num_send_reqs > 0 and self.num_poll_reqs > 0 and n_success > 0:
                self.logger.info(
                    f"\n** funcX Stats **\n"
                    f"Num. Send Requests: {self.num_send_reqs}\t|\t({time.time() - self.t_first_funcx_invoke})\n"
                    f"Num. Poll Requests: {self.num_poll_reqs}\t|\t()\n"
                    f"Avg. Send Request Size: {self.tot_fx_send_payload_size / self.num_send_reqs}\n"
                    f"Avg. Poll Request Size: {self.tot_fx_poll_payload_size / self.num_poll_reqs}\n"
                    f"Avg. Result Size: {self.tot_fx_poll_result_size / n_success}\n"
                    f"Max Result Size: {self.max_result_size}\n")

            self.last_checked = time.time()
            print(
                f"[GET] Elapsed Extract time: {time.time() - self.t_crawl_start}"
            )

    def poll_batch_chunks(self, thr_num, sublist, headers):
        status_thing = remote_poll_batch(sublist, headers)
        self.num_poll_reqs += 1

        status_tup = (thr_num, status_thing)

        self.status_things.put(status_tup)
        # TODO: might want to take this out; just seeing if this fixes funcX scale-up.
        time.sleep(1)
        return

    def poll_extractions_and_stats(self):
        mod_not_found = 0
        type_errors = 0

        self.poll_status = "RUNNING"

        while True:

            # This will attempt to print stats on every iteration.
            #   Note: internally, this will only execute max every 5 seconds.
            self.update_and_print_stats()

            if self.task_dict["active"].empty():

                if self.send_status == "SUCCEEDED":
                    print("[POLL] Send status is SUCCEEDED... ")
                    print(
                        f"[POLL] Active tasks: {self.task_dict['active'].empty()}"
                    )
                    # Check a second time because of a rare race condition.
                    if self.task_dict["active"].empty():
                        print(
                            "Extraction completed. Upgrading status to COMMITTING."
                        )
                        self.poll_status = "COMMITTING"
                        return

                self.logger.debug("No live IDs... sleeping...")
                time.sleep(10)
                continue

            # Here we pull all values from active task queue to create a batch of them!
            num_elem = self.task_dict["active"].qsize()
            print(f"[POLL] Size active queue: {num_elem}"
                  )  # TODO: need to shush this.
            tids_to_poll = []  # the batch

            for i in range(0, num_elem):

                ex_id = self.task_dict["active"].get()
                tids_to_poll.append(ex_id)

            t_last_poll = time.time()

            # Send off task_ids to poll, retrieve a bunch of statuses.
            tid_sublists = create_list_chunks(tids_to_poll,
                                              self.fx_task_sublist_size)
            # TODO: ^^ this'll need to be a dictionary of some kind so we can track.

            polling_threads = []

            thr_to_sublist_map = dict()

            i = 0
            for tid_sublist in tid_sublists:  # TODO: put a cap on this (if i>10, break (for example)

                # If there's an actual thing in the list...
                if len(tid_sublist) > 0 and i < 15:
                    self.tot_fx_poll_payload_size += sys.getsizeof(tid_sublist)
                    self.tot_fx_poll_payload_size += sys.getsizeof(
                        self.fx_headers) * len(tid_sublist)

                    thr = threading.Thread(target=self.poll_batch_chunks,
                                           args=(i, tid_sublist,
                                                 self.fx_headers))
                    thr.start()
                    polling_threads.append(thr)

                    # NEW STEP: add the thread ID to the dict
                    thr_to_sublist_map[i] = tid_sublist
                    i += 1

            print(f"[POLL] Spun up {i} polling threads!")

            # Now we should wait for our fan-out threads to fan-in
            for thread in polling_threads:
                thread.join()

            # TODO: here is where we can fix this
            # TODO: Create tuples in status things that's a tuple with tid_sublist (OR THREAD ID)
            #  that the status thing is about.
            while not self.status_things.empty():

                thr_id, status_thing = self.status_things.get()

                if "exception_caught" in status_thing:
                    self.logger.exception(
                        f"Caught funcX error: {status_thing['exception_caught']}"
                    )
                    self.logger.exception(
                        f"Putting the tasks back into active queue for retry")

                    for reject_tid in thr_to_sublist_map[thr_id]:
                        self.task_dict["active"].put(reject_tid)

                    print(f"Pausing for 20 seconds...")
                    self.pause_q_consume = True
                    time.sleep(20)
                    continue
                else:
                    self.pause_q_consume = False

                for tid in status_thing:
                    task_obj = status_thing[tid]

                    if "result" in task_obj:
                        res = self.fx_ser.deserialize(task_obj['result'])

                        # Decrement number of tasks being extracted!
                        self.num_extracting_tasks -= 1

                        if "family_batch" in res:
                            family_batch = res["family_batch"]
                            unpacked_metadata = self.unpack_returned_family_batch(
                                family_batch)

                            # print(unpacked_metadata)

                            # TODO: make this a regular feature for matio (so this code isn't necessary...)
                            if 'event' in unpacked_metadata:
                                family_batch = unpacked_metadata['event'][
                                    'family_batch']
                                unpacked_metadata = family_batch.to_dict()
                                unpacked_metadata[
                                    'dataset'] = self.dataset_mdata

                            if self.prefetch_remote:
                                total_family_size = 0
                                for family in family_batch.families:
                                    # family['metadata']["t_funcx_req_received"] = time.time()
                                    total_family_size += self.prefetcher.get_family_size(
                                        family.to_dict())

                                self.prefetcher.orch_unextracted_bytes -= total_family_size

                            for family in unpacked_metadata['families']:
                                family['metadata'][
                                    "t_funcx_req_received"] = time.time()

                            json_mdata = json.dumps(unpacked_metadata,
                                                    cls=NumpyEncoder)

                            result_size = sys.getsizeof(json_mdata)

                            self.tot_fx_poll_result_size += result_size

                            if result_size > self.max_result_size:
                                self.max_result_size = result_size

                            try:
                                self.to_validate_q.put({
                                    "Id":
                                    str(self.file_count),
                                    "MessageBody":
                                    json_mdata
                                })
                                self.file_count += 1
                            except TypeError as e1:
                                self.logger.exception(f"Type error: {e1}")
                                type_errors += 1
                                self.logger.exception(
                                    f"Total type errors: {type_errors}")
                        else:
                            self.logger.error(
                                '[Poller]: "family_batch" not in res!')

                        self.success_returns += 1

                        # Leave this in. Great for debugging and we have the IO for it.
                        self.logger.debug(f"Received response: {res}")

                        if type(res) is not dict:
                            self.logger.exception(f"Res is not dict: {res}")
                            continue

                        elif 'transfer_time' in res:
                            if res['transfer_time'] > 0:
                                self.t_transfer += res['transfer_time']
                                self.n_fams_transferred += 1
                                print(
                                    f"Avg. Transfer time: {self.t_transfer/self.n_fams_transferred}"
                                )

                        # This just means fixing Google Drive extractors...
                        elif 'trans_time' in res:
                            if res['trans_time'] > 0:
                                self.t_transfer += res['trans_time']
                                self.n_fams_transferred += 1
                                print(
                                    f"Avg. Transfer time: {self.t_transfer/self.n_fams_transferred}"
                                )

                    elif "exception" in task_obj:
                        exc = self.fx_ser.deserialize(task_obj['exception'])
                        try:
                            exc.reraise()
                        except ModuleNotFoundError:
                            mod_not_found += 1
                            self.logger.exception(
                                f"Num. ModuleNotFound: {mod_not_found}")

                        except Exception as e:
                            self.logger.exception(f"Caught exception: {e}")
                            self.logger.exception("Continuing!")
                            pass

                        self.failed_returns += 1
                        self.logger.exception(f"Exception caught: {exc}")

                    else:
                        self.task_dict["active"].put(tid)

            t_since_last_poll = time.time() - t_last_poll
            t_poll_gap = self.poll_gap_s - t_since_last_poll

            if t_poll_gap > 0:
                time.sleep(t_poll_gap)
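The loop above hands tids_to_poll to a create_list_chunks helper that is not shown in this excerpt. A minimal sketch of such a helper, under the assumption that it only needs to slice a list into fixed-size sublists:

def create_list_chunks(items, chunk_size):
    # Yield successive chunk_size-sized slices of items. The real helper used
    # by the poller may differ; this is an illustrative assumption.
    for start in range(0, len(items), chunk_size):
        yield items[start:start + chunk_size]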
Example #11
    # TODO: Add a crude retry for pending tasks to catch lost ones.
    # TODO x10000: move this logic
    timeout = 120
    failed_counter = 0
    while True:

        if task_dict["active"].empty():
            print("Active task queue empty... sleeping... ")
            time.sleep(0.5)
            break  # This should jump out to main_loop

        cur_tid = task_dict["active"].get()
        print(cur_tid)
        status_thing = requests.get(get_url.format(cur_tid),
                                    headers=headers).json()

        if 'result' in status_thing:
            result = fx_ser.deserialize(status_thing['result'])
            print(f"Result: {result}")
            task_dict["results"].append(result)
            print(len(task_dict["results"]))

        elif 'exception' in status_thing:
            print(
                f"Exception: {fx_ser.deserialize(status_thing['exception'])}")
            # break
        else:
            task_dict["active"].put(cur_tid)

            # Manager: f2ab81f26f40
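Note that timeout and failed_counter above are initialized but never used. A hedged sketch of how the same polling loop could honor that deadline, reusing the task_dict, get_url, headers, and fx_ser names assumed by the snippet (not part of the original code):

deadline = time.time() + timeout
while not task_dict["active"].empty() and time.time() < deadline:
    cur_tid = task_dict["active"].get()
    status_thing = requests.get(get_url.format(cur_tid), headers=headers).json()
    if 'result' in status_thing:
        task_dict["results"].append(fx_ser.deserialize(status_thing['result']))
    elif 'exception' in status_thing:
        failed_counter += 1  # count failures instead of looping forever
    else:
        task_dict["active"].put(cur_tid)  # still pending; re-queue and wait
        time.sleep(0.5)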
Example #12
File: client.py Project: vincent-pli/funcX
class FuncXClient(throttling.ThrottledBaseClient):
    """Main class for interacting with the funcX service

    Holds helper operations for performing common tasks with the funcX service.
    """

    TOKEN_DIR = os.path.expanduser("~/.funcx/credentials")
    TOKEN_FILENAME = 'funcx_sdk_tokens.json'
    CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c'

    def __init__(self,
                 http_timeout=None,
                 funcx_home=os.path.join('~', '.funcx'),
                 force_login=False,
                 fx_authorizer=None,
                 funcx_service_address='https://funcx.org/api/v1',
                 **kwargs):
        """ Initialize the client

        Parameters
        ----------
        http_timeout: int
        Timeout for any call to service in seconds.
        Default is no timeout

        force_login: bool
        Whether to force a login to get new credentials.

        fx_authorizer: :class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`
        A custom authorizer instance to communicate with funcX.
        Default: ``None``, will be created.

        funcx_service_address: str
        The address of the funcX web service to communicate with.
        Default: https://funcx.org/api/v1

        Keyword arguments are the same as for BaseClient.
        """
        self.func_table = {}
        self.ep_registration_path = 'register_endpoint_2'
        self.funcx_home = os.path.expanduser(funcx_home)

        if not os.path.exists(self.TOKEN_DIR):
            os.makedirs(self.TOKEN_DIR)

        tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME)
        self.native_client = NativeClient(
            client_id=self.CLIENT_ID,
            app_name="FuncX SDK",
            token_storage=JSONTokenStorage(tokens_filename))

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"

        if not fx_authorizer:
            self.native_client.login(
                requested_scopes=[fx_scope],
                no_local_server=kwargs.get("no_local_server", True),
                no_browser=kwargs.get("no_browser", True),
                refresh_tokens=kwargs.get("refresh_tokens", True),
                force=force_login)

            all_authorizers = self.native_client.get_authorizers_by_scope(
                requested_scopes=[fx_scope])
            fx_authorizer = all_authorizers[fx_scope]

        super(FuncXClient, self).__init__("funcX",
                                          environment='funcx',
                                          authorizer=fx_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=funcx_service_address,
                                          **kwargs)
        self.fx_serializer = FuncXSerializer()

    def logout(self):
        """Remove credentials from your local system
        """
        self.native_client.logout()

    def update_table(self, return_msg, task_id):
        """ Parses the return message from the service and updates the internal func_tables

        Parameters
        ----------

        return_msg : str
           Return message received from the funcx service
        task_id : str
           task id string
        """
        if isinstance(return_msg, str):
            r_dict = json.loads(return_msg)
        else:
            r_dict = return_msg

        status = {'pending': True}

        if 'result' in r_dict:
            try:
                r_obj = self.fx_serializer.deserialize(r_dict['result'])
            except Exception:
                raise Exception(
                    "Failure during deserialization of the result object")
            else:
                status.update({'pending': False, 'result': r_obj})
                self.func_table[task_id] = status

        elif 'exception' in r_dict:
            try:
                r_exception = self.fx_serializer.deserialize(
                    r_dict['exception'])
                logger.info(f"Exception : {r_exception}")
            except Exception:
                raise Exception(
                    "Failure during deserialization of the Task's exception object"
                )
            else:
                status.update({'pending': False, 'exception': r_exception})
                self.func_table[task_id] = status
        return status

    def get_task_status(self, task_id):
        """Get the status of a funcX task.

        Parameters
        ----------
        task_id : str
            UUID of the task

        Returns
        -------
        dict
            Status block containing a 'pending' key and, once the task has
            completed, a 'result' or 'exception' key.
        """
        if task_id in self.func_table:
            return self.func_table[task_id]

        r = self.get("{task_id}/status".format(task_id=task_id))
        logger.debug("Response string : {}".format(r))
        try:
            rets = self.update_table(r.text, task_id)
        except Exception as e:
            raise e
        return rets

    def get_result(self, task_id):
        """ Get the result of a funcX task

        Parameters
        ----------
        task_id: str
            UUID of the task

        Returns
        -------
        Result obj: If task completed

        Raises
        ------
        Exception obj: Exception due to which the task failed
        """
        status = self.get_task_status(task_id)
        if status['pending'] is True:
            raise Exception("Task pending")
        else:
            if 'result' in status:
                return status['result']
            else:
                logger.warning("We have an exception : {}".format(
                    status['exception']))
                status['exception'].reraise()

    def get_batch_status(self, task_id_list):
        """ Request status for a batch of task_ids
        """
        assert isinstance(task_id_list,
                          list), "get_batch_status expects a list of task ids"

        pending_task_ids = [
            t for t in task_id_list if t not in self.func_table
        ]

        results = {}

        if pending_task_ids:
            payload = {'task_ids': pending_task_ids}
            r = self.post("/batch_status", json_body=payload)
            logger.debug("Response string : {}".format(r))

        pending_task_ids = set(pending_task_ids)

        for task_id in task_id_list:
            if task_id in pending_task_ids:
                try:
                    data = r['results'][task_id]
                    rets = self.update_table(data, task_id)
                    results[task_id] = rets
                except KeyError:
                    logger.debug(
                        "Task {} info was not available in the batch status"
                        .format(task_id))
                except Exception:
                    logger.exception(
                        "Failure while unpacking results from get_batch_status")
            else:
                results[task_id] = self.func_table[task_id]

        return results
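
    # Illustrative only: get_batch_status() takes a list of task ids and
    # returns a dict keyed by task id; completed entries carry 'result' or
    # 'exception'. `fxc` below is an assumed FuncXClient instance.
    #
    #     statuses = fxc.get_batch_status(task_ids)
    #     for tid, status in statuses.items():
    #         if 'result' in status:
    #             print(tid, status['result'])
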

    def get_batch_result(self, task_id_list):
        """ Request results for a batch of task_ids
        """
        pass

    def run(self,
            *args,
            endpoint_id=None,
            function_id=None,
            asynchronous=False,
            **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_id : str
        UUID string that identifies the task
        """
        servable_path = 'submit'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_args = self.fx_serializer.serialize(args)
        ser_kwargs = self.fx_serializer.serialize(kwargs)
        payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])

        data = {
            'endpoint': endpoint_id,
            'func': function_id,
            'payload': payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if 'task_uuid' not in r:
            raise MalformedResponse(r)
        """
        Create a future to deal with the result
        funcx_future = FuncXFuture(self, task_id, async_poll)

        if not asynchronous:
            return funcx_future.result()

        # Return the result
        return funcx_future
        """
        return r['task_uuid']
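
    # Illustrative only: a typical call, assuming `fxc` is a FuncXClient
    # instance and the endpoint/function UUIDs came from register_endpoint()
    # and register_function().
    #
    #     task_id = fxc.run(10, duration=1,
    #                       endpoint_id=endpoint_uuid,
    #                       function_id=function_uuid)
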

    def map_run(self,
                *args,
                endpoint_id=None,
                function_id=None,
                asynchronous=False,
                **kwargs):
        """Initiate an invocation

        Parameters
        ----------
        *args : Any
            Args as specified by the function signature
        endpoint_id : uuid str
            Endpoint UUID string. Required
        function_id : uuid str
            Function UUID string. Required
        asynchronous : bool
            Whether or not to run the function asynchronously

        Returns
        -------
        task_ids : list of str
        UUID strings that identify the submitted tasks
        """
        servable_path = 'submit_batch'
        assert endpoint_id is not None, "endpoint_id key-word argument must be set"
        assert function_id is not None, "function_id key-word argument must be set"

        ser_kwargs = self.fx_serializer.serialize(kwargs)

        batch_payload = []
        iterator = args[0]
        for arg in iterator:
            ser_args = self.fx_serializer.serialize((arg, ))
            payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs])
            batch_payload.append(payload)

        data = {
            'endpoints': [endpoint_id],
            'func': function_id,
            'payload': batch_payload,
            'is_async': asynchronous
        }

        # Send the data to funcX
        r = self.post(servable_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        if 'task_uuids' not in r:
            raise MalformedResponse(r)

        return r['task_uuids']
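
    # Illustrative only: map_run() fans one iterable of positional arguments
    # out as a batch (one task per element); keyword arguments are shared by
    # every task. `fxc` and the UUIDs below are assumptions.
    #
    #     task_ids = fxc.map_run([1, 2, 3],
    #                            endpoint_id=endpoint_uuid,
    #                            function_id=function_uuid,
    #                            duration=1)
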

    def register_endpoint(self, name, endpoint_uuid, description=None):
        """Register an endpoint with the funcX service.

        Parameters
        ----------
        name : str
            Name of the endpoint
        endpoint_uuid : str
                The uuid of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        A dict
            {'endpoint_id' : <>,
             'address' : <>,
             'client_ports': <>}
        """
        data = {
            "endpoint_name": name,
            "endpoint_uuid": endpoint_uuid,
            "description": description
        }

        r = self.post(self.ep_registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data

    def get_containers(self, name, description=None):
        """Register a DLHub endpoint with the funcX service and get the containers to launch.

        Parameters
        ----------
        name : str
            Name of the endpoint
        description : str
            Description of the endpoint

        Returns
        -------
        tuple
            The endpoint uuid and the list of containers to launch
        """
        registration_path = 'get_containers'

        data = {"endpoint_name": name, "description": description}

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['endpoint_uuid'], r.data['endpoint_containers']

    def get_container(self, container_uuid, container_type):
        """Get the details of a container for staging it locally.

        Parameters
        ----------
        container_uuid : str
            UUID of the container in question
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker)

        Returns
        -------
        dict
            The details of the containers to deploy
        """
        container_path = f'containers/{container_uuid}/{container_type}'

        r = self.get(container_path)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['container']

    def register_function(self,
                          function,
                          function_name=None,
                          container_uuid=None,
                          description=None):
        """Register a function code with the funcX service.

        Parameters
        ----------
        function : Python Function
            The function to be registered for remote execution

        function_name : str
            The entry point (function name) of the function. Default: None

        container_uuid : str
            Container UUID from registration with funcX

        description : str
            Description of the function

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "container_uuid": container_uuid,
            "entry_point":
            function_name if function_name else function.__name__,
            "description": description
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['function_uuid']

    def register_container(self,
                           location,
                           container_type,
                           name='',
                           description=''):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker). Required

        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        container_path = 'containers'

        payload = {
            'name': name,
            'location': location,
            'description': description,
            'type': container_type
        }

        r = self.post(container_path, json_body=payload)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['container_id']
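
A minimal end-to-end sketch of how this client is typically used: register a function, submit it to an endpoint with run(), and fetch the result. The double function, the placeholder endpoint UUID, and the fixed sleep before get_result are illustrative assumptions, not part of the class above.

if __name__ == '__main__':
    import time

    def double(x):
        return 2 * x

    fxc = FuncXClient()
    func_uuid = fxc.register_function(double, description="Double the input")
    endpoint_uuid = '<your-endpoint-uuid>'  # placeholder: a registered endpoint

    task_id = fxc.run(21, endpoint_id=endpoint_uuid, function_id=func_uuid)

    # get_result() raises while the task is still pending, so wait briefly
    # before fetching (see the polling sketch next to get_result above).
    time.sleep(5)
    print("Result:", fxc.get_result(task_id))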