def dont_run_yet(endpoint_id=None, tasks=10, duration=1, hostname=None):
    # tasks_rq = EndpointQueue(f'task_{endpoint_id}', hostname)
    tasks_channel = RedisPubSub(hostname)
    tasks_channel.connect()
    redis_client = tasks_channel.redis_client
    redis_client.ping()

    fxs = FuncXSerializer()
    ser_code = fxs.serialize(slow_double)
    fn_code = fxs.pack_buffers([ser_code])

    start = time.time()
    task_ids = {}
    for i in range(tasks):
        time.sleep(duration)
        task_id = str(uuid.uuid4())
        print("Task_id : ", task_id)
        ser_args = fxs.serialize([i])
        ser_kwargs = fxs.serialize({"duration": duration})
        input_data = fxs.pack_buffers([ser_args, ser_kwargs])
        payload = fn_code + input_data
        container_id = "RAW"
        task = Task(redis_client, task_id, container_id, serializer="", payload=payload)
        task.endpoint = endpoint_id
        task.status = TaskState.WAITING_FOR_EP
        # tasks_rq.enqueue(task)
        tasks_channel.put(endpoint_id, task)
        task_ids[i] = task_id

    d1 = time.time() - start
    print(f"Time to launch {tasks} tasks: {d1:8.3f} s")

    delay = 5
    print(f"Sleeping {delay} seconds")
    time.sleep(delay)

    print(f"Launched {tasks} tasks")
    for i in range(tasks):
        task_id = task_ids[i]
        print("Task_id : ", task_id)
        task = Task.from_id(redis_client, task_id)
        # TODO: wait for task result...
        time.sleep(duration)
        try:
            result = fxs.deserialize(task.result)
            print(f"Result : {result}")
        except Exception as e:
            print(f"Task failed with exception:{e}")

    delta = time.time() - start
    print(f"Time to complete {tasks} tasks: {delta:8.3f} s")
    print(f"Throughput : {tasks / delta:8.3f} Tasks/s")
    return delta
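# `slow_double` is not defined in this snippet. Below is a minimal sketch of the
# shape the benchmark above assumes (hypothetical; the real function may differ):
# one positional argument plus a `duration` keyword, mirroring the serialized
# args/kwargs packed into the payload.
import time


def slow_double(x, duration=1):
    """Sleep for `duration` seconds, then return twice the input."""
    time.sleep(duration)
    return 2 * x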
def deserialize():
    """Return the deserialized result"""
    fx_serializer = FuncXSerializer()
    # Return a failure message if all else fails
    ret_package = {'error': 'Failed to deserialize result'}
    try:
        inputs = request.json
        res = fx_serializer.deserialize(inputs)
        ret_package = jsonify(res)
    except Exception as e:
        print(e)
        return jsonify(ret_package), 500
    return ret_package, 200
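# A hedged usage sketch for the Flask handler above. The route path and
# host/port are not shown here, so "/deserialize" and localhost:5000 below are
# assumptions for illustration only: the client POSTs a funcX-serialized
# payload as the JSON body and reads back the deserialized value (or a 500).
import requests
from funcx.serialize import FuncXSerializer

fxs = FuncXSerializer()
payload = fxs.serialize({"hello": "world"})  # any serializable object

resp = requests.post("http://localhost:5000/deserialize", json=payload)
if resp.status_code == 200:
    print("Deserialized on the server:", resp.json())
else:
    print("Server failed to deserialize:", resp.json())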
def server(port=0, host="", debug=False, datasize=102400):
    try:
        from funcx.serialize import FuncXSerializer
        fxs = FuncXSerializer(use_offprocess_checker=False)

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind((host, port))
            bound_port = s.getsockname()[1]
            print(f"BINDING TO:{bound_port}", flush=True)
            s.listen(1)
            conn, addr = s.accept()  # we only expect one incoming connection here.
            with conn:
                while True:
                    b_msg = conn.recv(datasize)
                    if not b_msg:
                        print("Exiting")
                        return
                    msg = pickle.loads(b_msg)
                    if msg == "PING":
                        ret_value = ("PONG", None)
                    else:
                        try:
                            method = fxs.deserialize(msg)  # noqa
                            del method
                        except Exception as e:
                            ret_value = ("DESERIALIZE_FAIL", str(e))
                        else:
                            ret_value = ("SUCCESS", None)
                    ret_buf = pickle.dumps(ret_value)
                    conn.sendall(ret_buf)
    except Exception as e:
        print(f"OFF_PROCESS_CHECKER FAILURE, Exception:{e}")
        sys.exit()
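# A minimal client-side sketch (an assumption, not part of the module above)
# showing the wire protocol the off-process checker speaks: pickled requests
# in, pickled ("STATUS", detail) tuples out. The port would normally be parsed
# from the "BINDING TO:<port>" line the server prints. Note the server accepts
# a single connection, so all messages go over one socket.
import pickle
import socket


def off_process_check(port, messages, host="localhost", datasize=102400):
    """Send a sequence of pickled messages over one connection and collect replies."""
    replies = []
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((host, port))
        for msg in messages:
            s.sendall(pickle.dumps(msg))
            replies.append(pickle.loads(s.recv(datasize)))
    return replies

# e.g. off_process_check(50001, ["PING", serialized_buffer])
#   -> [("PONG", None), ("SUCCESS", None)]
#   or ("DESERIALIZE_FAIL", <error message>) for a buffer that cannot be deserialized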
# set_file_logger('executor.log', name='funcx_endpoint', level=logging.DEBUG)
htex = HighThroughputExecutor(interchange_local=True, passthrough=True)
htex.start(results_passthrough=results_queue)
htex._start_remote_interchange_process()

fx_serializer = FuncXSerializer()

for i in range(10):
    task_id = str(uuid.uuid4())
    args = (i, )
    kwargs = {}
    fn_code = fx_serializer.serialize(double)
    ser_code = fx_serializer.pack_buffers([fn_code])
    ser_params = fx_serializer.pack_buffers(
        [fx_serializer.serialize(args), fx_serializer.serialize(kwargs)])
    payload = Task(task_id, "RAW", ser_code + ser_params)
    f = htex.submit_raw(payload.pack())
    time.sleep(0.5)

for i in range(10):
    result_package = results_queue.get()
    # print("Result package : ", result_package)
    r = pickle.loads(result_package)
    result = fx_serializer.deserialize(r["result"])
    print(f"Result:{i}: {result}")

print("All done")
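# Neither `double` nor `results_queue` is defined in the driver above. A
# plausible sketch of what it assumes (hypothetical; these would be defined
# before the driver runs): a trivial function to serialize, and a queue the
# executor can pass raw result packages through.
import multiprocessing


def double(x):
    """Return twice the input; the toy function serialized and submitted above."""
    return 2 * x


# Assumed here to be a multiprocessing Queue; adjust if your version of
# HighThroughputExecutor expects a different results_passthrough type.
results_queue = multiprocessing.Queue()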
class FuncXClient(throttling.ThrottledBaseClient): """Main class for interacting with the funcX service Holds helper operations for performing common tasks with the funcX service. """ TOKEN_DIR = os.path.expanduser("~/.funcx/credentials") TOKEN_FILENAME = 'funcx_sdk_tokens.json' CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c' def __init__(self, http_timeout=None, funcx_home=os.path.join('~', '.funcx'), force_login=False, fx_authorizer=None, funcx_service_address='https://api.funcx.org/v1', **kwargs): """ Initialize the client Parameters ---------- http_timeout: int Timeout for any call to service in seconds. Default is no timeout force_login: bool Whether to force a login to get new credentials. fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`: A custom authorizer instance to communicate with funcX. Default: ``None``, will be created. funcx_service_address: str The address of the funcX web service to communicate with. Default: https://api.funcx.org/v1 Keyword arguments are the same as for BaseClient. """ self.func_table = {} self.ep_registration_path = 'register_endpoint_2' self.funcx_home = os.path.expanduser(funcx_home) if not os.path.exists(self.TOKEN_DIR): os.makedirs(self.TOKEN_DIR) tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME) self.native_client = NativeClient( client_id=self.CLIENT_ID, app_name="FuncX SDK", token_storage=JSONTokenStorage(tokens_filename)) # TODO: if fx_authorizer is given, we still need to get an authorizer for Search fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all" search_scope = "urn:globus:auth:scope:search.api.globus.org:all" scopes = [fx_scope, search_scope, "openid"] search_authorizer = None if not fx_authorizer: self.native_client.login( requested_scopes=scopes, no_local_server=kwargs.get("no_local_server", True), no_browser=kwargs.get("no_browser", True), refresh_tokens=kwargs.get("refresh_tokens", True), force=force_login) all_authorizers = self.native_client.get_authorizers_by_scope( requested_scopes=scopes) fx_authorizer = all_authorizers[fx_scope] search_authorizer = all_authorizers[search_scope] openid_authorizer = all_authorizers["openid"] super(FuncXClient, self).__init__("funcX", environment='funcx', authorizer=fx_authorizer, http_timeout=http_timeout, base_url=funcx_service_address, **kwargs) self.fx_serializer = FuncXSerializer() authclient = AuthClient(authorizer=openid_authorizer) user_info = authclient.oauth2_userinfo() self.searcher = SearchHelper(authorizer=search_authorizer, owner_uuid=user_info['sub']) self.funcx_service_address = funcx_service_address def version_check(self): """Check this client version meets the service's minimum supported version. """ resp = self.get("version", params={"service": "all"}) versions = resp.data if "min_ep_version" not in versions: raise VersionMismatch( "Failed to retrieve version information from funcX service.") min_ep_version = versions['min_ep_version'] if ENDPOINT_VERSION is None: raise VersionMismatch( "You do not have the funcx endpoint installed. You can use 'pip install funcx-endpoint'." ) if ENDPOINT_VERSION < min_ep_version: raise VersionMismatch( f"Your version={ENDPOINT_VERSION} is lower than the " f"minimum version for an endpoint: {min_ep_version}. Please update." 
) def logout(self): """Remove credentials from your local system """ self.native_client.logout() def update_table(self, return_msg, task_id): """ Parses the return message from the service and updates the internal func_tables Parameters ---------- return_msg : str Return message received from the funcx service task_id : str task id string """ if isinstance(return_msg, str): r_dict = json.loads(return_msg) else: r_dict = return_msg status = {'pending': True} if 'result' in r_dict: try: r_obj = self.fx_serializer.deserialize(r_dict['result']) completion_t = r_dict['completion_t'] except Exception: raise SerializationError("Result Object Deserialization") else: status.update({ 'pending': False, 'result': r_obj, 'completion_t': completion_t }) self.func_table[task_id] = status elif 'exception' in r_dict: try: r_exception = self.fx_serializer.deserialize( r_dict['exception']) completion_t = r_dict['completion_t'] logger.info(f"Exception : {r_exception}") except Exception: raise SerializationError( "Task's exception object deserialization") else: status.update({ 'pending': False, 'exception': r_exception, 'completion_t': completion_t }) self.func_table[task_id] = status return status def get_task(self, task_id): """Get a funcX task. Parameters ---------- task_id : str UUID of the task Returns ------- dict Task block containing "status" key. """ if task_id in self.func_table: return self.func_table[task_id] r = self.get("tasks/{task_id}".format(task_id=task_id)) logger.debug("Response string : {}".format(r)) try: rets = self.update_table(r.text, task_id) except Exception as e: raise e return rets def get_result(self, task_id): """ Get the result of a funcX task Parameters ---------- task_id: str UUID of the task Returns ------- Result obj: If task completed Raises ------ Exception obj: Exception due to which the task failed """ task = self.get_task(task_id) if task['pending'] is True: raise Exception("Task pending") else: if 'result' in task: return task['result'] else: logger.warning("We have an exception : {}".format( task['exception'])) task['exception'].reraise() def get_batch_status(self, task_id_list): """ Request status for a batch of task_ids """ assert isinstance(task_id_list, list), "get_batch_status expects a list of task ids" pending_task_ids = [ t for t in task_id_list if t not in self.func_table ] results = {} if pending_task_ids: payload = {'task_ids': pending_task_ids} r = self.post("/batch_status", json_body=payload) logger.debug("Response string : {}".format(r)) pending_task_ids = set(pending_task_ids) for task_id in task_id_list: if task_id in pending_task_ids: try: data = r['results'][task_id] rets = self.update_table(data, task_id) results[task_id] = rets except KeyError: logger.debug( "Task {} info was not available in the batch status") except Exception: logger.exception( "Failure while unpacking results fom get_batch_status") else: results[task_id] = self.func_table[task_id] return results def get_batch_result(self, task_id_list): """ Request results for a batch of task_ids """ pass def run(self, *args, endpoint_id=None, function_id=None, **kwargs): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. 
Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task """ assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" batch = self.create_batch() batch.add(*args, endpoint_id=endpoint_id, function_id=function_id, **kwargs) r = self.batch_run(batch) """ Create a future to deal with the result funcx_future = FuncXFuture(self, task_id, async_poll) if not asynchronous: return funcx_future.result() # Return the result return funcx_future """ return r[0] def create_batch(self): """ Create a Batch instance to handle batch submission in funcX Parameters ---------- Returns ------- Batch instance Status block containing "status" key. """ batch = Batch() return batch def batch_run(self, batch): """Initiate a batch of tasks to funcX Parameters ---------- batch: a Batch object Returns ------- task_ids : a list of UUID strings that identify the tasks """ servable_path = 'submit' assert isinstance(batch, Batch), "Requires a Batch object as input" assert len(batch.tasks) > 0, "Requires a non-empty batch" data = batch.prepare() # Send the data to funcX r = self.post(servable_path, json_body=data) if r.http_status != 200: raise HTTPError(r) if r.get("status", "Failure") == "Failure": raise MalformedResponse("FuncX Request failed: {}".format( r.get("reason", "Unknown"))) return r['task_uuids'] def map_run(self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task """ servable_path = 'submit_batch' assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" ser_kwargs = self.fx_serializer.serialize(kwargs) batch_payload = [] iterator = args[0] for arg in iterator: ser_args = self.fx_serializer.serialize((arg, )) payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs]) batch_payload.append(payload) data = { 'endpoints': [endpoint_id], 'func': function_id, 'payload': batch_payload, 'is_async': asynchronous } # Send the data to funcX r = self.post(servable_path, json_body=data) if r.http_status != 200: raise Exception(r) if r.get("status", "Failure") == "Failure": raise MalformedResponse("FuncX Request failed: {}".format( r.get("reason", "Unknown"))) return r['task_uuids'] def register_endpoint(self, name, endpoint_uuid, metadata=None, endpoint_version=None): """Register an endpoint with the funcX service. 
Parameters ---------- name : str Name of the endpoint endpoint_uuid : str The uuid of the endpoint metadata : dict endpoint metadata, see default_config example endpoint_version: str Version string to be passed to the webService as a compatibility check Returns ------- A dict {'endopoint_id' : <>, 'address' : <>, 'client_ports': <>} """ self.version_check() data = { "endpoint_name": name, "endpoint_uuid": endpoint_uuid, "version": endpoint_version } if metadata: data['meta'] = metadata r = self.post(self.ep_registration_path, json_body=data) if r.http_status != 200: raise HTTPError(r) # Return the result return r.data def get_containers(self, name, description=None): """Register a DLHub endpoint with the funcX service and get the containers to launch. Parameters ---------- name : str Name of the endpoint description : str Description of the endpoint Returns ------- int The port to connect to and a list of containers """ registration_path = 'get_containers' data = {"endpoint_name": name, "description": description} r = self.post(registration_path, json_body=data) if r.http_status != 200: raise HTTPError(r) # Return the result return r.data['endpoint_uuid'], r.data['endpoint_containers'] def get_container(self, container_uuid, container_type): """Get the details of a container for staging it locally. Parameters ---------- container_uuid : str UUID of the container in question container_type : str The type of containers that will be used (Singularity, Shifter, Docker) Returns ------- dict The details of the containers to deploy """ container_path = f'containers/{container_uuid}/{container_type}' r = self.get(container_path) if r.http_status != 200: raise HTTPError(r) # Return the result return r.data['container'] def get_endpoint_status(self, endpoint_uuid): """Get the status reports for an endpoint. Parameters ---------- endpoint_uuid : str UUID of the endpoint in question Returns ------- dict The details of the endpoint's stats """ stats_path = f'endpoints/{endpoint_uuid}/status' r = self.get(stats_path) if r.http_status != 200: raise HTTPError(r) # Return the result return r.data def register_function(self, function, function_name=None, container_uuid=None, description=None, public=False, group=None, searchable=True): """Register a function code with the funcX service. Parameters ---------- function : Python Function The function to be registered for remote execution function_name : str The entry point (function name) of the function. Default: None container_uuid : str Container UUID from registration with funcX description : str Description of the file public : bool Whether or not the function is publicly accessible. 
Default = False group : str A globus group uuid to share this function with searchable : bool If true, the function will be indexed into globus search with the appropriate permissions Returns ------- function uuid : str UUID identifier for the registered function """ registration_path = 'register_function' source_code = "" try: source_code = getsource(function) except OSError: logger.error( "Failed to find source code during function registration.") serialized_fn = self.fx_serializer.serialize(function) packed_code = self.fx_serializer.pack_buffers([serialized_fn]) data = { "function_name": function.__name__, "function_code": packed_code, "function_source": source_code, "container_uuid": container_uuid, "entry_point": function_name if function_name else function.__name__, "description": description, "public": public, "group": group, "searchable": searchable } logger.info("Registering function : {}".format(data)) r = self.post(registration_path, json_body=data) if r.http_status != 200: raise HTTPError(r) func_uuid = r.data['function_uuid'] # Return the result return func_uuid def update_function(self, func_uuid, function): pass def search_function(self, q, offset=0, limit=10, advanced=False): """Search for function via the funcX service Parameters ---------- q : str free-form query string offset : int offset into total results limit : int max number of results to return advanced : bool allows elastic-search like syntax in query string Returns ------- FunctionSearchResults """ return self.searcher.search_function(q, offset=offset, limit=limit, advanced=advanced) def search_endpoint(self, q, scope='all', owner_id=None): """ Parameters ---------- q scope : str Can be one of {'all', 'my-endpoints', 'shared-with-me'} owner_id should be urn like f"urn:globus:auth:identity:{owner_uuid}" Returns ------- """ return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id) def register_container(self, location, container_type, name='', description=''): """Register a container with the funcX service. Parameters ---------- location : str The location of the container (e.g., its docker url). Required container_type : str The type of containers that will be used (Singularity, Shifter, Docker). Required name : str A name for the container. Default = '' description : str A description to associate with the container. 
Default = '' Returns ------- str The id of the container """ container_path = 'containers' payload = { 'name': name, 'location': location, 'description': description, 'type': container_type } r = self.post(container_path, json_body=payload) if r.http_status != 200: raise HTTPError(r) # Return the result return r.data['container_id'] def add_to_whitelist(self, endpoint_id, function_ids): """Adds the function to the endpoint's whitelist Parameters ---------- endpoint_id : str The uuid of the endpoint function_ids : list A list of function id's to be whitelisted Returns ------- json The response of the request """ req_path = f'endpoints/{endpoint_id}/whitelist' if not isinstance(function_ids, list): function_ids = [function_ids] payload = {'func': function_ids} r = self.post(req_path, json_body=payload) if r.http_status != 200: raise HTTPError(r) # Return the result return r def get_whitelist(self, endpoint_id): """List the endpoint's whitelist Parameters ---------- endpoint_id : str The uuid of the endpoint Returns ------- json The response of the request """ req_path = f'endpoints/{endpoint_id}/whitelist' r = self.get(req_path) if r.http_status != 200: raise HTTPError(r) # Return the result return r def delete_from_whitelist(self, endpoint_id, function_ids): """List the endpoint's whitelist Parameters ---------- endpoint_id : str The uuid of the endpoint function_ids : list A list of function id's to be whitelisted Returns ------- json The response of the request """ if not isinstance(function_ids, list): function_ids = [function_ids] res = [] for fid in function_ids: req_path = f'endpoints/{endpoint_id}/whitelist/{fid}' r = self.delete(req_path) if r.http_status != 200: raise HTTPError(r) res.append(r) # Return the result return res
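# A hedged end-to-end sketch for the client above: register a function, submit
# a small batch, then poll for results. The endpoint UUID is a placeholder, the
# polling cadence is illustrative, and the import path is assumed.
import time

from funcx.sdk.client import FuncXClient  # import path assumed


def add(a, b):
    return a + b


fxc = FuncXClient()
func_id = fxc.register_function(add, description="toy add function")

endpoint_id = "<your-endpoint-uuid>"
batch = fxc.create_batch()
for i in range(5):
    batch.add(i, i, endpoint_id=endpoint_id, function_id=func_id)
task_ids = fxc.batch_run(batch)

while True:
    statuses = fxc.get_batch_status(task_ids)
    if all(t in statuses and not statuses[t]["pending"] for t in task_ids):
        break
    time.sleep(2)

print([statuses[t].get("result") for t in task_ids])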
class FuncXClient(BaseClient): """Main class for interacting with the funcX service Holds helper operations for performing common tasks with the funcX service. """ TOKEN_DIR = os.path.expanduser("~/.funcx/credentials") CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c' def __init__(self, http_timeout=None, funcx_home=os.path.join('~', '.funcx'), force_login=False, fx_authorizer=None, funcx_service_address='https://dev.funcx.org/api/v1', **kwargs): """ Initialize the client Parameters ---------- http_timeout: int Timeout for any call to service in seconds. Default is no timeout force_login: bool Whether to force a login to get new credentials. fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`: A custom authorizer instance to communicate with funcX. Default: ``None``, will be created. service_address: str The address of the funcX web service to communicate with. Default: https://dev.funcx.org/api/v1 Keyword arguments are the same as for BaseClient. """ self.ep_registration_path = 'register_endpoint_2' self.funcx_home = os.path.expanduser(funcx_home) native_client = NativeClient(client_id=self.CLIENT_ID) fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all" if not fx_authorizer: native_client.login( requested_scopes=[fx_scope], no_local_server=kwargs.get("no_local_server", True), no_browser=kwargs.get("no_browser", True), refresh_tokens=kwargs.get("refresh_tokens", True), force=force_login) all_authorizers = native_client.get_authorizers_by_scope( requested_scopes=[fx_scope]) fx_authorizer = all_authorizers[fx_scope] super(FuncXClient, self).__init__("funcX", environment='funcx', authorizer=fx_authorizer, http_timeout=http_timeout, base_url=funcx_service_address, **kwargs) self.fx_serializer = FuncXSerializer() def logout(self): """Remove credentials from your local system """ logout() def get_task_status(self, task_id): """Get the status of a funcX task. Parameters ---------- task_id : str UUID of the task Returns ------- dict Status block containing "status" key. """ r = self.get("{task_id}/status".format(task_id=task_id)) return json.loads(r.text) def get_result(self, task_id): """ Get the result of a funcX task Parameters ---------- task_id: str UUID of the task Returns ------- Result obj: If task completed Raises ------ Exception obj: Exception due to which the task failed """ r = self.get("{task_id}/status".format(task_id=task_id)) logger.info(f"Got from globus : {r}") r_dict = json.loads(r.text) if 'result' in r_dict: try: r_obj = self.fx_serializer.deserialize(r_dict['result']) except Exception: raise Exception( "Failure during deserialization of the result object") else: return r_obj elif 'exception' in r_dict: try: r_exception = self.fx_serializer.deserialize( r_dict['exception']) logger.info(f"Exception : {r_exception}") except Exception: raise Exception( "Failure during deserialization of the Task's exception object" ) else: r_exception.reraise() else: raise Exception("Task pending") def run(self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. 
Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task """ servable_path = 'submit' assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" ser_args = self.fx_serializer.serialize(args) ser_kwargs = self.fx_serializer.serialize(kwargs) payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs]) data = { 'endpoint': endpoint_id, 'func': function_id, 'payload': payload, 'is_async': asynchronous } # Send the data to funcX r = self.post(servable_path, json_body=data) if r.http_status is not 200: raise Exception(r) if 'task_uuid' not in r: raise MalformedResponse(r) """ Create a future to deal with the result funcx_future = FuncXFuture(self, task_id, async_poll) if not asynchronous: return funcx_future.result() # Return the result return funcx_future """ return r['task_uuid'] def register_endpoint(self, name, endpoint_uuid, description=None): """Register an endpoint with the funcX service. Parameters ---------- name : str Name of the endpoint endpoint_uuid : str The uuid of the endpoint description : str Description of the endpoint Returns ------- A dict {'endopoint_id' : <>, 'address' : <>, 'client_ports': <>} """ data = { "endpoint_name": name, "endpoint_uuid": endpoint_uuid, "description": description } r = self.post(self.ep_registration_path, json_body=data) if r.http_status is not 200: raise Exception(r) # Return the result return r.data def get_containers(self, name, description=None): """Register a DLHub endpoint with the funcX service and get the containers to launch. Parameters ---------- name : str Name of the endpoint description : str Description of the endpoint Returns ------- int The port to connect to and a list of containers """ registration_path = 'get_containers' data = {"endpoint_name": name, "description": description} r = self.post(registration_path, json_body=data) if r.http_status is not 200: raise Exception(r) # Return the result return r.data['endpoint_uuid'], r.data['endpoint_containers'] def get_container(self, container_uuid, container_type): """Get the details of a container for staging it locally. Parameters ---------- container_uuid : str UUID of the container in question container_type : str The type of containers that will be used (Singularity, Shifter, Docker) Returns ------- dict The details of the containers to deploy """ container_path = f'containers/{container_uuid}/{container_type}' r = self.get(container_path) if r.http_status is not 200: raise Exception(r) # Return the result return r.data['container'] def register_function(self, function, function_name=None, container_uuid=None, description=None): """Register a function code with the funcX service. Parameters ---------- function : Python Function The function to be registered for remote execution function_name : str The entry point (function name) of the function. 
Default: None container_uuid : str Container UUID from registration with funcX description : str Description of the file Returns ------- function uuid : str UUID identifier for the registered function """ registration_path = 'register_function' serialized_fn = self.fx_serializer.serialize(function) packed_code = self.fx_serializer.pack_buffers([serialized_fn]) data = { "function_name": function.__name__, "function_code": packed_code, "container_uuid": container_uuid, "entry_point": function_name if function_name else function.__name__, "description": description } logger.info("Registering function : {}".format(data)) r = self.post(registration_path, json_body=data) if r.http_status != 200: raise Exception(r) # Return the result return r.data['function_uuid']
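# A brief usage sketch for this earlier client: run() returns a task UUID, and
# the caller polls until a result or exception appears. The UUIDs below are
# placeholders and the sleep interval is arbitrary; get_result() raises a plain
# Exception("Task pending") while the task is still running, which the helper
# relies on.
import time


def poll_result(fxc, task_id, interval=3):
    """Poll get_result() until the task is no longer pending."""
    while True:
        try:
            return fxc.get_result(task_id)
        except Exception as e:
            if "pending" in str(e).lower():
                time.sleep(interval)
                continue
            raise


# task_id = fxc.run(10, endpoint_id="<endpoint-uuid>", function_id="<function-uuid>")
# print(poll_result(fxc, task_id))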
class FuncXClient: """Main class for interacting with the funcX service Holds helper operations for performing common tasks with the funcX service. """ TOKEN_DIR = os.path.expanduser("~/.funcx/credentials") TOKEN_FILENAME = "funcx_sdk_tokens.json" FUNCX_SDK_CLIENT_ID = os.environ.get( "FUNCX_SDK_CLIENT_ID", "4cf29807-cf21-49ec-9443-ff9a3fb9f81c" ) FUNCX_SCOPE = os.environ.get( "FUNCX_SCOPE", "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all", ) def __init__( self, http_timeout=None, funcx_home=_FUNCX_HOME, force_login=False, fx_authorizer=None, search_authorizer=None, openid_authorizer=None, funcx_service_address=None, check_endpoint_version=False, asynchronous=False, loop=None, results_ws_uri=None, use_offprocess_checker=True, environment=None, **kwargs, ): """ Initialize the client Parameters ---------- http_timeout: int Timeout for any call to service in seconds. Default is no timeout force_login: bool Whether to force a login to get new credentials. fx_authorizer:class:`GlobusAuthorizer \ <globus_sdk.authorizers.base.GlobusAuthorizer>`: A custom authorizer instance to communicate with funcX. Default: ``None``, will be created. search_authorizer:class:`GlobusAuthorizer \ <globus_sdk.authorizers.base.GlobusAuthorizer>`: A custom authorizer instance to communicate with Globus Search. Default: ``None``, will be created. openid_authorizer:class:`GlobusAuthorizer \ <globus_sdk.authorizers.base.GlobusAuthorizer>`: A custom authorizer instance to communicate with OpenID. Default: ``None``, will be created. funcx_service_address: str For internal use only. The address of the web service. results_ws_uri: str For internal use only. The address of the websocket service. environment: str For internal use only. The name of the environment to use. asynchronous: bool Should the API use asynchronous interactions with the web service? Currently only impacts the run method Default: False loop: AbstractEventLoop If asynchronous mode is requested, then you can provide an optional event loop instance. If None, then we will access asyncio.get_event_loop() Default: None use_offprocess_checker: Bool, Use this option to disable the offprocess_checker in the FuncXSerializer used by the client. Default: True Keyword arguments are the same as for BaseClient. 
""" # resolve URLs if not set if funcx_service_address is None: funcx_service_address = get_web_service_url(environment) if results_ws_uri is None: results_ws_uri = get_web_socket_url(environment) self.func_table = {} self.use_offprocess_checker = use_offprocess_checker self.funcx_home = os.path.expanduser(funcx_home) self.session_task_group_id = str(uuid.uuid4()) if not os.path.exists(self.TOKEN_DIR): os.makedirs(self.TOKEN_DIR) tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME) self.native_client = NativeClient( client_id=self.FUNCX_SDK_CLIENT_ID, app_name="FuncX SDK", token_storage=JSONTokenStorage(tokens_filename), ) # TODO: if fx_authorizer is given, we still need to get an authorizer for Search search_scope = "urn:globus:auth:scope:search.api.globus.org:all" scopes = [self.FUNCX_SCOPE, search_scope, "openid"] if not fx_authorizer or not search_authorizer or not openid_authorizer: self.native_client.login( requested_scopes=scopes, no_local_server=kwargs.get("no_local_server", True), no_browser=kwargs.get("no_browser", True), refresh_tokens=kwargs.get("refresh_tokens", True), force=force_login, ) all_authorizers = self.native_client.get_authorizers_by_scope( requested_scopes=scopes ) fx_authorizer = all_authorizers[self.FUNCX_SCOPE] search_authorizer = all_authorizers[search_scope] openid_authorizer = all_authorizers["openid"] self.web_client = FuncxWebClient( base_url=funcx_service_address, authorizer=fx_authorizer ) self.fx_serializer = FuncXSerializer( use_offprocess_checker=self.use_offprocess_checker ) authclient = AuthClient(authorizer=openid_authorizer) user_info = authclient.oauth2_userinfo() self.searcher = SearchHelper( authorizer=search_authorizer, owner_uuid=user_info["sub"] ) self.funcx_service_address = funcx_service_address self.check_endpoint_version = check_endpoint_version self.version_check() self.results_ws_uri = results_ws_uri self.asynchronous = asynchronous if asynchronous: self.loop = loop if loop else asyncio.get_event_loop() # Start up an asynchronous polling loop in the background self.ws_polling_task = WebSocketPollingTask( self, self.loop, init_task_group_id=self.session_task_group_id, results_ws_uri=self.results_ws_uri, ) else: self.loop = None def version_check(self): """Check this client version meets the service's minimum supported version.""" resp = self.web_client.get_version() versions = resp.data if "min_ep_version" not in versions: raise VersionMismatch( "Failed to retrieve version information from funcX service." ) min_ep_version = versions["min_ep_version"] min_sdk_version = versions["min_sdk_version"] if self.check_endpoint_version: if ENDPOINT_VERSION is None: raise VersionMismatch( "You do not have the funcx endpoint installed. " "You can use 'pip install funcx-endpoint'." ) if LooseVersion(ENDPOINT_VERSION) < LooseVersion(min_ep_version): raise VersionMismatch( f"Your version={ENDPOINT_VERSION} is lower than the " f"minimum version for an endpoint: {min_ep_version}. " "Please update. " f"pip install funcx-endpoint>={min_ep_version}" ) else: if LooseVersion(SDK_VERSION) < LooseVersion(min_sdk_version): raise VersionMismatch( f"Your version={SDK_VERSION} is lower than the " f"minimum version for funcx SDK: {min_sdk_version}. " "Please update. 
" f"pip install funcx>={min_sdk_version}" ) def logout(self): """Remove credentials from your local system""" self.native_client.logout() def update_table(self, return_msg, task_id): """Parses the return message from the service and updates the internal func_table Parameters ---------- return_msg : str Return message received from the funcx service task_id : str task id string """ if isinstance(return_msg, str): r_dict = json.loads(return_msg) else: r_dict = return_msg r_status = r_dict.get("status", "unknown") status = {"pending": True, "status": r_status} if "result" in r_dict: try: r_obj = self.fx_serializer.deserialize(r_dict["result"]) completion_t = r_dict["completion_t"] except Exception: raise SerializationError("Result Object Deserialization") else: status.update( {"pending": False, "result": r_obj, "completion_t": completion_t} ) self.func_table[task_id] = status elif "exception" in r_dict: try: r_exception = self.fx_serializer.deserialize(r_dict["exception"]) completion_t = r_dict["completion_t"] logger.info(f"Exception : {r_exception}") except Exception: raise SerializationError("Task's exception object deserialization") else: status.update( { "pending": False, "exception": r_exception, "completion_t": completion_t, } ) self.func_table[task_id] = status return status def get_task(self, task_id): """Get a funcX task. Parameters ---------- task_id : str UUID of the task Returns ------- dict Task block containing "status" key. """ if task_id in self.func_table: return self.func_table[task_id] r = self.web_client.get_task(task_id) logger.debug(f"Response string : {r}") try: rets = self.update_table(r.text, task_id) except Exception as e: raise e return rets def get_result(self, task_id): """Get the result of a funcX task Parameters ---------- task_id: str UUID of the task Returns ------- Result obj: If task completed Raises ------ Exception obj: Exception due to which the task failed """ task = self.get_task(task_id) if task["pending"] is True: raise TaskPending(task["status"]) else: if "result" in task: return task["result"] else: logger.warning("We have an exception : {}".format(task["exception"])) task["exception"].reraise() def get_batch_result(self, task_id_list): """Request status for a batch of task_ids""" assert isinstance( task_id_list, list ), "get_batch_result expects a list of task ids" pending_task_ids = [t for t in task_id_list if t not in self.func_table] results = {} if pending_task_ids: r = self.web_client.get_batch_status(pending_task_ids) logger.debug(f"Response string : {r}") pending_task_ids = set(pending_task_ids) for task_id in task_id_list: if task_id in pending_task_ids: try: data = r["results"][task_id] rets = self.update_table(data, task_id) results[task_id] = rets except KeyError: logger.debug("Task {} info was not available in the batch status") except Exception: logger.exception( "Failure while unpacking results fom get_batch_result" ) else: results[task_id] = self.func_table[task_id] return results def run(self, *args, endpoint_id=None, function_id=None, **kwargs): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. 
Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task if asynchronous is False funcX Task: asyncio.Task A future that will eventually resolve into the function's result if asynchronous is True """ assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" batch = self.create_batch() batch.add(*args, endpoint_id=endpoint_id, function_id=function_id, **kwargs) r = self.batch_run(batch) return r[0] def create_batch(self, task_group_id=None): """ Create a Batch instance to handle batch submission in funcX Parameters ---------- task_group_id : str Override the session wide session_task_group_id with a different task_group_id for this batch. If task_group_id is not specified, it will default to using the client's session_task_group_id Returns ------- Batch instance Status block containing "status" key. """ if not task_group_id: task_group_id = self.session_task_group_id batch = Batch(task_group_id=task_group_id) return batch def batch_run(self, batch): """Initiate a batch of tasks to funcX Parameters ---------- batch: a Batch object Returns ------- task_ids : a list of UUID strings that identify the tasks """ assert isinstance(batch, Batch), "Requires a Batch object as input" assert len(batch.tasks) > 0, "Requires a non-empty batch" data = batch.prepare() # Send the data to funcX r = self.web_client.submit(data) task_uuids = [] for result in r["results"]: task_id = result["task_uuid"] task_uuids.append(task_id) if result["http_status_code"] != 200: # this method of handling errors for a batch response is not # ideal, as it will raise any error in the multi-response, # but it will do until batch_run is deprecated in favor of Executer handle_response_errors(result) if self.asynchronous: task_group_id = r["task_group_id"] asyncio_tasks = [] for task_id in task_uuids: funcx_task = FuncXTask(task_id) asyncio_task = self.loop.create_task(funcx_task.get_result()) asyncio_tasks.append(asyncio_task) self.ws_polling_task.add_task(funcx_task) self.ws_polling_task.put_task_group_id(task_group_id) return asyncio_tasks return task_uuids def map_run( self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs ): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task """ assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" ser_kwargs = self.fx_serializer.serialize(kwargs) batch_payload = [] iterator = args[0] for arg in iterator: ser_args = self.fx_serializer.serialize((arg,)) payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs]) batch_payload.append(payload) data = { "endpoints": [endpoint_id], "func": function_id, "payload": batch_payload, "is_async": asynchronous, } # Send the data to funcX r = self.web_client.submit_batch(data) return r["task_uuids"] def register_endpoint( self, name, endpoint_uuid, metadata=None, endpoint_version=None ): """Register an endpoint with the funcX service. 
Parameters ---------- name : str Name of the endpoint endpoint_uuid : str The uuid of the endpoint metadata : dict endpoint metadata, see default_config example endpoint_version: str Version string to be passed to the webService as a compatibility check Returns ------- A dict {'endpoint_id' : <>, 'address' : <>, 'client_ports': <>} """ self.version_check() r = self.web_client.register_endpoint( endpoint_name=name, endpoint_id=endpoint_uuid, metadata=metadata, endpoint_version=endpoint_version, ) return r.data def get_containers(self, name, description=None): """Register a DLHub endpoint with the funcX service and get the containers to launch. Parameters ---------- name : str Name of the endpoint description : str Description of the endpoint Returns ------- int The port to connect to and a list of containers """ data = {"endpoint_name": name, "description": description} r = self.web_client.post("get_containers", data=data) return r.data["endpoint_uuid"], r.data["endpoint_containers"] def get_container(self, container_uuid, container_type): """Get the details of a container for staging it locally. Parameters ---------- container_uuid : str UUID of the container in question container_type : str The type of containers that will be used (Singularity, Shifter, Docker) Returns ------- dict The details of the containers to deploy """ self.version_check() r = self.web_client.get(f"containers/{container_uuid}/{container_type}") return r.data["container"] def get_endpoint_status(self, endpoint_uuid): """Get the status reports for an endpoint. Parameters ---------- endpoint_uuid : str UUID of the endpoint in question Returns ------- dict The details of the endpoint's stats """ r = self.web_client.get_endpoint_status(endpoint_uuid) return r.data def register_function( self, function, function_name=None, container_uuid=None, description=None, public=False, group=None, searchable=True, ): """Register a function code with the funcX service. Parameters ---------- function : Python Function The function to be registered for remote execution function_name : str The entry point (function name) of the function. Default: None container_uuid : str Container UUID from registration with funcX description : str Description of the file public : bool Whether or not the function is publicly accessible. 
Default = False group : str A globus group uuid to share this function with searchable : bool If true, the function will be indexed into globus search with the appropriate permissions Returns ------- function uuid : str UUID identifier for the registered function """ data = FunctionRegistrationData( function=function, failover_source="", container_uuid=container_uuid, entry_point=function_name, description=description, public=public, group=group, searchable=searchable, serializer=self.fx_serializer, ) logger.info(f"Registering function : {data}") r = self.web_client.register_function(data) return r.data["function_uuid"] def search_function(self, q, offset=0, limit=10, advanced=False): """Search for function via the funcX service Parameters ---------- q : str free-form query string offset : int offset into total results limit : int max number of results to return advanced : bool allows elastic-search like syntax in query string Returns ------- FunctionSearchResults """ return self.searcher.search_function( q, offset=offset, limit=limit, advanced=advanced ) def search_endpoint(self, q, scope="all", owner_id=None): """ Parameters ---------- q scope : str Can be one of {'all', 'my-endpoints', 'shared-with-me'} owner_id should be urn like f"urn:globus:auth:identity:{owner_uuid}" Returns ------- """ return self.searcher.search_endpoint(q, scope=scope, owner_id=owner_id) def register_container(self, location, container_type, name="", description=""): """Register a container with the funcX service. Parameters ---------- location : str The location of the container (e.g., its docker url). Required container_type : str The type of containers that will be used (Singularity, Shifter, Docker). Required name : str A name for the container. Default = '' description : str A description to associate with the container. Default = '' Returns ------- str The id of the container """ payload = { "name": name, "location": location, "description": description, "type": container_type, } r = self.web_client.post("containers", data=payload) return r.data["container_id"] def add_to_whitelist(self, endpoint_id, function_ids): """Adds the function to the endpoint's whitelist Parameters ---------- endpoint_id : str The uuid of the endpoint function_ids : list A list of function id's to be whitelisted Returns ------- json The response of the request """ return self.web_client.whitelist_add(endpoint_id, function_ids) def get_whitelist(self, endpoint_id): """List the endpoint's whitelist Parameters ---------- endpoint_id : str The uuid of the endpoint Returns ------- json The response of the request """ return self.web_client.get_whitelist(endpoint_id) def delete_from_whitelist(self, endpoint_id, function_ids): """List the endpoint's whitelist Parameters ---------- endpoint_id : str The uuid of the endpoint function_ids : list A list of function id's to be whitelisted Returns ------- json The response of the request """ if not isinstance(function_ids, list): function_ids = [function_ids] res = [] for fid in function_ids: res.append(self.web_client.whitelist_remove(endpoint_id, fid)) return res
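# A hedged sketch of the asynchronous mode supported by this client: with
# asynchronous=True, run()/batch_run() return asyncio tasks that are resolved
# by the background WebSocket polling task. The endpoint UUID is a placeholder,
# the import path is assumed, and exact await semantics may differ between
# funcx versions.
import asyncio

from funcx.sdk.client import FuncXClient  # import path assumed


def double(x):
    return 2 * x


async def main():
    fxc = FuncXClient(asynchronous=True)
    func_id = fxc.register_function(double)
    task = fxc.run(21, endpoint_id="<endpoint-uuid>", function_id=func_id)
    # In asynchronous mode the return value is an asyncio.Task, not a task UUID.
    print(await task)


# asyncio.get_event_loop().run_until_complete(main())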
class CentralScheduler(object): def __init__(self, endpoints, strategy='round-robin', runtime_predictor='rolling-average', last_n=3, train_every=1, log_level='INFO', import_model_file=None, transfer_model_file=None, sync_level='exists', max_backups=0, backup_delay_threshold=2.0, *args, **kwargs): self._fxc = FuncXClient(*args, **kwargs) # Initialize a transfer client self._transfer_manger = TransferManager(endpoints=endpoints, sync_level=sync_level, log_level=log_level) # Info about FuncX endpoints we can execute on self._endpoints = endpoints self._dead_endpoints = set() self.last_result_time = defaultdict(float) self.temperature = defaultdict(lambda: 'WARM') self._imports = defaultdict(list) self._imports_required = defaultdict(list) # Track which endpoints a function can't run on self._blocked = defaultdict(set) # Track pending tasks # We will provide the client our own task ids, since we may submit the # same task multiple times to the FuncX service, and sometimes we may # wait to submit a task to FuncX (e.g., wait for a data transfer). self._task_id_translation = {} self._pending = {} self._pending_by_endpoint = defaultdict(set) self._task_info = {} # List of endpoints a (virtual) task was scheduled to self._endpoints_sent_to = defaultdict(list) self.max_backups = max_backups self.backup_delay_threshold = backup_delay_threshold self._latest_status = {} self._last_task_ETA = defaultdict(float) # Maximum ETA, if any, of a task which we allow to be scheduled on an # endpoint. This is to prevent backfill tasks to be longer than the # estimated time for when a pending data transfer will finish. self._transfer_ETAs = defaultdict(dict) # Estimated error in the pending-task time of an endpoint. # Updated every time a task result is received from an endpoint. 
self._queue_error = defaultdict(float) # Set logging levels logger.setLevel(log_level) self.execution_log = [] # Intialize serializer self.fx_serializer = FuncXSerializer() self.fx_serializer.use_custom('03\n', 'code') # Initialize runtime predictor self.runtime = init_runtime_predictor(runtime_predictor, endpoints=endpoints, last_n=last_n, train_every=train_every) logger.info(f"Runtime predictor using strategy {self.runtime}") # Initialize transfer-time predictor self.transfer_time = TransferPredictor(endpoints=endpoints, train_every=train_every, state_file=transfer_model_file) # Initialize import-time predictor self.import_predictor = ImportPredictor(endpoints=endpoints, state_file=import_model_file) # Initialize scheduling strategy self.strategy = init_strategy(strategy, endpoints=endpoints, runtime_predictor=self.runtime, queue_predictor=self.queue_delay, cold_start_predictor=self.cold_start, transfer_predictor=self.transfer_time) logger.info(f"Scheduler using strategy {self.strategy}") # Start thread to check on endpoints regularly self._endpoint_watchdog = Thread(target=self._check_endpoints) self._endpoint_watchdog.start() # Start thread to monitor tasks and send tasks to FuncX service self._scheduled_tasks = Queue() self._task_watchdog_sleep = 0.15 self._task_watchdog = Thread(target=self._monitor_tasks) self._task_watchdog.start() def block(self, func, endpoint): if endpoint not in self._endpoints: logger.error('Cannot block unknown endpoint {}'.format(endpoint)) return { 'status': 'Failed', 'reason': 'Unknown endpoint {}'.format(endpoint) } elif len(self._blocked[func]) == len(self._endpoints) - 1: logger.error( 'Cannot block last remaining endpoint {}'.format(endpoint)) return { 'status': 'Failed', 'reason': 'Cannot block all endpoints for {}'.format(func) } else: logger.info('Blocking endpoint {} for function {}'.format( endpoint_name(endpoint), func)) self._blocked[func].add(endpoint) return {'status': 'Success'} def register_imports(self, func, imports): logger.info('Registered function {} with imports {}'.format( func, imports)) self._imports_required[func] = imports def batch_submit(self, tasks, headers): # TODO: smarter scheduling for batch submissions task_ids = [] endpoints = [] for func, payload in tasks: _, ser_kwargs = self.fx_serializer.unpack_buffers(payload) kwargs = self.fx_serializer.deserialize(ser_kwargs) files = kwargs['_globus_files'] task_id, endpoint = self._schedule_task(func=func, payload=payload, headers=headers, files=files) task_ids.append(task_id) endpoints.append(endpoint) return task_ids, endpoints def _schedule_task(self, func, payload, headers, files, task_id=None): # If this is the first time scheduling this task_id # (i.e., non-backup task), record the necessary metadata if task_id is None: # Create (fake) task id to return to client task_id = str(uuid.uuid4()) # Store task information self._task_id_translation[task_id] = set() # Information required to schedule the task, now and in the future info = { 'function_id': func, 'payload': payload, 'headers': headers, 'files': files, 'time_requested': time.time() } self._task_info[task_id] = info # TODO: do not choose a dead endpoint (reliably) # exclude = self._blocked[func] | self._dead_endpoints | set(self._endpoints_sent_to[task_id]) # noqa if len(self._dead_endpoints) > 0: logger.warn('{} endpoints seem dead. 
Hope they still work!'.format( len(self._dead_endpoints))) exclude = self._blocked[func] | set(self._endpoints_sent_to[task_id]) choice = self.strategy.choose_endpoint( func, payload=payload, files=files, exclude=exclude, transfer_ETAs=self._transfer_ETAs) # noqa endpoint = choice['endpoint'] logger.info('Choosing endpoint {} for func {}, task id {}'.format( endpoint_name(endpoint), func, task_id)) choice['ETA'] = self.strategy.predict_ETA(func, endpoint, payload, files=files) # Start Globus transfer of required files, if any if len(files) > 0: transfer_num = self._transfer_manger.transfer( files, endpoint, task_id) if transfer_num is not None: transfer_ETA = time.time() + self.transfer_time( files, endpoint) self._transfer_ETAs[endpoint][transfer_num] = transfer_ETA else: transfer_num = None # Record endpoint ETA for queue-delay prediction here, # since task will be immediately scheduled self._last_task_ETA[endpoint] = choice['ETA'] # If a cold endpoint is being started, mark it as no longer cold, # so that subsequent launch-time predictions are correct (i.e., 0) if self.temperature[endpoint] == 'COLD': self.temperature[endpoint] = 'WARMING' logger.info( 'A cold endpoint {} was chosen; marked as warming.'.format( endpoint_name(endpoint))) # Schedule task for sending to FuncX self._endpoints_sent_to[task_id].append(endpoint) self._scheduled_tasks.put((task_id, endpoint, transfer_num)) return task_id, endpoint def translate_task_id(self, task_id): return self._task_id_translation[task_id] def log_status(self, real_task_id, data): if real_task_id not in self._pending: logger.warn('Ignoring unknown task id {}'.format(real_task_id)) return task_id = self._pending[real_task_id]['task_id'] func = self._pending[real_task_id]['function_id'] endpoint = self._pending[real_task_id]['endpoint_id'] # Don't overwrite latest status if it is a result/exception if task_id not in self._latest_status or \ self._latest_status[task_id].get('status') == 'PENDING': self._latest_status[task_id] = data if 'result' in data: result = self.fx_serializer.deserialize(data['result']) runtime = result['runtime'] name = endpoint_name(endpoint) logger.info('Got result from {} for task {} with time {}'.format( name, real_task_id, runtime)) self.runtime.update(self._pending[real_task_id], runtime) self._pending[real_task_id]['runtime'] = runtime self._record_completed(real_task_id) self.last_result_time[endpoint] = time.time() self._imports[endpoint] = result['imports'] elif 'exception' in data: exception = self.fx_serializer.deserialize(data['exception']) try: exception.reraise() except Exception as e: logger.error('Got exception on task {}: {}'.format( real_task_id, e)) exc_type, _, _ = sys.exc_info() if exc_type in BLOCK_ERRORS: self.block(func, endpoint) self._record_completed(real_task_id) self.last_result_time[endpoint] = time.time() elif 'status' in data and data['status'] == 'PENDING': pass else: logger.error('Unexpected status message: {}'.format(data)) def get_status(self, task_id): if task_id not in self._task_id_translation: logger.warn('Unknown client task id {}'.format(task_id)) elif len(self._task_id_translation[task_id]) == 0: return {'status': 'PENDING'} # Task has not been scheduled yet elif task_id not in self._latest_status: return {'status': 'PENDING'} # Status has not been queried yet else: return self._latest_status[task_id] def queue_delay(self, endpoint): # Otherwise, queue delay is the ETA of most recent task, # plus the estimated error in the ETA prediction. 
# Note that if there are no pending tasks on endpoint, no queue delay. # This is implicit since, in this case, both summands will be 0. delay = self._last_task_ETA[endpoint] + self._queue_error[endpoint] return max(delay, time.time()) def _record_completed(self, real_task_id): info = self._pending[real_task_id] endpoint = info['endpoint_id'] # If this is the last pending task on this endpoint, reset ETA offset if len(self._pending_by_endpoint[endpoint]) == 1: self._last_task_ETA[endpoint] = 0.0 self._queue_error[endpoint] = 0.0 else: prediction_error = time.time() - self._pending[real_task_id]['ETA'] self._queue_error[endpoint] = prediction_error # print(colored(f'Prediction error {prediction_error}', 'red')) info['ATA'] = time.time() del info['headers'] self.execution_log.append(info) logger.info( 'Task exec time: expected = {:.3f}, actual = {:.3f}'.format( info['ETA'] - info['time_sent'], time.time() - info['time_sent'])) # logger.info(f'ETA_offset = {self._queue_error[endpoint]:.3f}') # Stop tracking this task del self._pending[real_task_id] self._pending_by_endpoint[endpoint].remove(real_task_id) if info['task_id'] in self._task_info: del self._task_info[info['task_id']] def cold_start(self, endpoint, func): # If endpoint is warm, there is no launch time if self.temperature[endpoint] != 'COLD': launch_time = 0.0 # Otherwise, return the launch time in the endpoint config elif 'launch_time' in self._endpoints[endpoint]: launch_time = self._endpoints[endpoint]['launch_time'] else: logger.warn( 'Endpoint {} should always be warm, but is cold'.format( endpoint_name(endpoint))) launch_time = 0.0 # Time to import dependencies import_time = 0.0 for pkg in self._imports_required[func]: if pkg not in self._imports[endpoint]: logger.debug( 'Cold-start has import time for pkg {} on {}'.format( pkg, endpoint_name(endpoint))) import_time += self.import_predictor(pkg, endpoint) return launch_time + import_time def _monitor_tasks(self): logger.info('Starting task-watchdog thread') scheduled = {} while True: time.sleep(self._task_watchdog_sleep) # Get newly scheduled tasks while True: try: task_id, end, num = self._scheduled_tasks.get_nowait() if task_id not in self._task_info: logger.warn( 'Task id {} scheduled but no info found'.format( task_id)) continue info = self._task_info[task_id] scheduled[task_id] = dict(info) # Create new copy of info scheduled[task_id]['task_id'] = task_id scheduled[task_id]['endpoint_id'] = end scheduled[task_id]['transfer_num'] = num except Empty: break # Filter out all tasks whose data transfer has not been completed ready_to_send = set() for task_id, info in scheduled.items(): transfer_num = info['transfer_num'] if transfer_num is None: ready_to_send.add(task_id) info['transfer_time'] = 0.0 elif self._transfer_manger.is_complete(transfer_num): ready_to_send.add(task_id) del self._transfer_ETAs[info['endpoint_id']][transfer_num] info[ 'transfer_time'] = self._transfer_manger.get_transfer_time( transfer_num) # noqa else: # This task cannot be scheduled yet continue if len(ready_to_send) == 0: logger.debug('No new tasks to send. Task watchdog sleeping...') continue # TODO: different clients send different headers. 
change eventually headers = list(scheduled.values())[0]['headers'] logger.info('Scheduling a batch of {} tasks'.format( len(ready_to_send))) # Submit all ready tasks to FuncX data = {'tasks': []} for task_id in ready_to_send: info = scheduled[task_id] submit_info = (info['function_id'], info['endpoint_id'], info['payload']) data['tasks'].append(submit_info) res_str = requests.post(f'{FUNCX_API}/submit', headers=headers, data=json.dumps(data)) try: res = res_str.json() except ValueError: logger.error(f'Could not parse JSON from {res_str.text}') continue if res['status'] != 'Success': logger.error( 'Could not send tasks to FuncX. Got response: {}'.format( res)) continue # Update task info with submission info for task_id, real_task_id in zip(ready_to_send, res['task_uuids']): info = scheduled[task_id] # This ETA calculation does not take into account transfer time # since, at this point, the transfer has already completed. info['ETA'] = self.strategy.predict_ETA( info['function_id'], info['endpoint_id'], info['payload']) # Record if this ETA prediction is "reliable". If it is not # (e.g., when we have not learned about this (func, ep) pair), # backup tasks will not be sent for this task if it is delayed. info['is_ETA_reliable'] = self.runtime.has_learned( info['function_id'], info['endpoint_id']) info['time_sent'] = time.time() endpoint = info['endpoint_id'] self._task_id_translation[task_id].add(real_task_id) self._pending[real_task_id] = info self._pending_by_endpoint[endpoint].add(real_task_id) # Record endpoint ETA for queue-delay prediction self._last_task_ETA[endpoint] = info['ETA'] logger.info( 'Sent task id {} to {} with real task id {}'.format( task_id, endpoint_name(endpoint), real_task_id)) # Stop tracking all newly sent tasks for task_id in ready_to_send: del scheduled[task_id] def _check_endpoints(self): logger.info('Starting endpoint-watchdog thread') while True: for end in self._endpoints.keys(): statuses = self._fxc.get_endpoint_status(end) if len(statuses) == 0: logger.warn( 'Endpoint {} does not have any statuses'.format( endpoint_name(end))) else: status = statuses[0] # Most recent endpoint status # Mark endpoint as dead/alive based on heartbeat's age # Heartbeats are delayed when an endpoint is executing # tasks, so take into account last execution too age = time.time() - max(status['timestamp'], self.last_result_time[end]) is_dead = end in self._dead_endpoints if not is_dead and age > HEARTBEAT_THRESHOLD: self._dead_endpoints.add(end) logger.warn( 'Endpoint {} seems to have died! ' 'Last heartbeat was {:.2f} seconds ago.'.format( endpoint_name(end), age)) elif is_dead and age <= HEARTBEAT_THRESHOLD: self._dead_endpoints.remove(end) logger.warn( 'Endpoint {} is back alive! 
' 'Last heartbeat was {:.2f} seconds ago.'.format( endpoint_name(end), age)) # Mark endpoint as "cold" or "warm" depending on if it # has active managers (nodes) allocated to it if self.temperature[end] == 'WARM' \ and status['active_managers'] == 0: self.temperature[end] = 'COLD' logger.info('Endpoint {} is cold!'.format( endpoint_name(end))) elif self.temperature[end] != 'WARM' \ and status['active_managers'] > 0: self.temperature[end] = 'WARM' logger.info('Endpoint {} is warm again!'.format( endpoint_name(end))) # Send backup tasks if needed self._send_backups_if_needed() # Sleep before checking statuses again time.sleep(5) def _send_backups_if_needed(self): # Get all tasks which have not been completed yet and still have a # pending (real) task on a dead endpoint task_ids = { self._pending[real_task_id]['task_id'] for endpoint in self._dead_endpoints for real_task_id in self._pending_by_endpoint[endpoint] if self._pending[real_task_id]['task_id'] in self._task_info } # Get all tasks for which we had ETA-predictions but haven't # been completed even past their ETA for real_task_id, info in self._pending.items(): # If the predicted ETA wasn't reliable, don't send backups if not info['is_ETA_reliable']: continue expected = info['ETA'] - info['time_sent'] elapsed = time.time() - info['time_sent'] if elapsed / expected > self.backup_delay_threshold: task_ids.add(info['task_id']) for task_id in task_ids: if len(self._endpoints_sent_to[task_id]) > self.max_backups: logger.debug(f'Skipping sending new backup task for {task_id}') else: logger.info(f'Sending new backup task for {task_id}') info = self._task_info[task_id] self._schedule_task(info['function_id'], info['payload'], info['headers'], info['files'], task_id)
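# A minimal, standalone sketch of the backup-task decision rule used above, for
# illustration only. The threshold name `backup_delay_threshold` and the info-dict
# keys ('ETA', 'time_sent', 'is_ETA_reliable') mirror the scheduler code; the
# function name and the example values below are hypothetical.
import time


def needs_backup(info, backup_delay_threshold=2.0, now=None):
    """Return True if a pending task has overrun its predicted ETA badly
    enough to justify sending a backup copy to another endpoint."""
    if not info.get('is_ETA_reliable', False):
        # Without a trustworthy prediction we cannot tell "slow" from "normal".
        return False
    now = now or time.time()
    expected = info['ETA'] - info['time_sent']   # predicted run time in seconds
    elapsed = now - info['time_sent']            # how long the task has been out
    return expected > 0 and (elapsed / expected) > backup_delay_threshold


# Example: a task predicted to take 10 s that has been pending for 25 s
# exceeds a 2x threshold and would get a backup.
example = {'ETA': 1010.0, 'time_sent': 1000.0, 'is_ETA_reliable': True}
print(needs_backup(example, now=1025.0))  # True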
import time

# Submit a small batch of empty events to a single endpoint, then poll until
# results (or exceptions) come back. Serialized exceptions are deserialized and
# re-raised so the original remote traceback is surfaced locally.
event = None
endpoint = '68bade94-bf58-4a7a-bfeb-9c6a61fa5443'

items_to_batch = [{"func_id": func_uuid, "event": {}},
                  {"func_id": func_uuid, "event": {}}]

x = remote_extract_batch(items_to_batch, endpoint, headers=headers)

fx_ser = FuncXSerializer()

while True:
    a = remote_poll_batch(x, headers)
    print(f"The returned: {a}")
    for tid in a:
        if "exception" in a[tid]:
            exception = a[tid]["exception"]
            print(f"The serialized exception: {exception}")
            d_exception = fx_ser.deserialize(exception)
            print(f"The deserialized exception: {d_exception}")
            print("RERAISING EXCEPTION!")
            d_exception.reraise()
            # print(fx_ser.deserialize(a[tid]["exception"]))
        else:
            print(a[tid])
    time.sleep(5.1)
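# A hedged variant of the loop above that terminates once every task in the
# batch has returned, instead of polling forever. It assumes the same
# `remote_poll_batch(task_ids, headers)` helper and the same response shape
# ({task_id: {"result": ...} or {"exception": ...} or still-pending}) seen in
# the script above; the wrapper function itself is hypothetical.
import time


def poll_until_done(task_ids, headers, fx_ser, poll_gap_s=5.0):
    """Poll a batch of task ids until each one has a result or an exception."""
    remaining = set(task_ids)
    results, errors = {}, {}
    while remaining:
        statuses = remote_poll_batch(list(remaining), headers)
        for tid, status in statuses.items():
            if "result" in status:
                results[tid] = fx_ser.deserialize(status["result"])
                remaining.discard(tid)
            elif "exception" in status:
                errors[tid] = fx_ser.deserialize(status["exception"])
                remaining.discard(tid)
        if remaining:
            time.sleep(poll_gap_s)
    return results, errors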
class ExtractorOrchestrator: def __init__(self, funcx_eid, mdata_store_path, source_eid=None, dest_eid=None, gdrive_token=None, extractor_finder='gdrive', prefetch_remote=False, data_prefetch_path=None, dataset_mdata=None): prefetch_remote = False # TODO -- fix this. # self.crawl_type = 'from_file' self.write_cpe = False self.dataset_mdata = dataset_mdata self.t_crawl_start = time.time() self.t_send_batch = 0 self.t_transfer = 0 self.prefetch_remote = prefetch_remote self.data_prefetch_path = data_prefetch_path self.extractor_finder = extractor_finder self.funcx_eid = funcx_eid self.func_dict = { "image": xtract_images.ImageExtractor(), "images": xtract_images.ImageExtractor(), "tabular": xtract_tabular.TabularExtractor(), "text": xtract_keyword.KeywordExtractor(), "matio": xtract_matio.MatioExtractor() } self.fx_ser = FuncXSerializer() self.send_status = "STARTING" self.poll_status = "STARTING" self.commit_completed = False self.source_endpoint = source_eid self.dest_endpoint = dest_eid self.gdrive_token = gdrive_token self.num_families_fetched = 0 self.get_families_start_time = time.time() self.last_checked = time.time() self.pre_launch_counter = 0 self.success_returns = 0 self.failed_returns = 0 self.to_send_queue = Queue() self.poll_gap_s = 5 self.get_families_status = "STARTING" self.task_dict = { "active": Queue(), "pending": Queue(), "failed": Queue() } # Batch size we use to send tasks to funcx. (and the subbatch size) self.map_size = 8 self.fx_batch_size = 16 self.fx_task_sublist_size = 500 # Want to store attributes about funcX requests/responses. self.tot_fx_send_payload_size = 0 self.tot_fx_poll_payload_size = 0 self.tot_fx_poll_result_size = 0 self.num_send_reqs = 0 self.num_poll_reqs = 0 self.t_first_funcx_invoke = None self.max_result_size = 0 # Number (current and max) of number of tasks sent to funcX for extraction. self.max_extracting_tasks = 5 self.num_extracting_tasks = 0 self.max_pre_prefetch = 15000 # TODO: Integrate this to actually fix timing bug. self.status_things = Queue() # If this is turned on, should mean that we hit our local task maximum and don't want to pull down new work... self.pause_q_consume = False self.file_count = 0 self.current_batch = [] self.extract_end = None self.mdata_store_path = mdata_store_path self.n_fams_transferred = 0 self.prefetcher_tid = None self.prefetch_status = None self.fx_headers = { "Authorization": f"Bearer {self.headers['FuncX']}", 'FuncX': self.headers['FuncX'] } self.family_headers = None if 'Petrel' in self.headers: self.fx_headers['Petrel'] = self.headers['Petrel'] self.family_headers = { 'Authorization': f"Bearer {self.headers['Petrel']}", 'Transfer': self.headers['Transfer'], 'FuncX': self.headers['FuncX'], 'Petrel': self.headers['Petrel'] } self.logger = logging.getLogger(__name__) handler = logging.StreamHandler() formatter = logging.Formatter( '%(asctime)s %(name)-12s %(levelname)-8s %(message)s') handler.setFormatter(formatter) self.logger.addHandler(handler) self.logger.setLevel( logging.INFO) # TODO: let's make this configurable. 
self.families_to_process = Queue() self.to_validate_q = Queue() self.sqs_push_threads = {} self.thr_ls = [] self.commit_threads = 1 self.get_family_threads = 20 if self.prefetch_remote: self.logger.info("Launching prefetcher...") self.logger.info("Prefetcher successfully launched!") prefetch_thread = threading.Thread( target=self.prefetcher.main_poller_loop, args=()) prefetch_thread.start() for i in range(0, self.commit_threads): thr = threading.Thread(target=self.validate_enqueue_loop, args=(i, )) self.thr_ls.append(thr) thr.start() self.sqs_push_threads[i] = True self.logger.info( f"Successfully started {len(self.sqs_push_threads)} SQS push threads!" ) if self.crawl_type != 'from_file': for i in range(0, self.get_family_threads): self.logger.info( f"Attempting to start get_next_families() as its own thread [{i}]... " ) consumer_thr = threading.Thread( target=self.get_next_families_loop, args=()) consumer_thr.start() print( f"Successfully started the get_next_families() thread number {i} " ) else: print("ATTEMPTING TO LAUNCH **FILE** CRAWL THREAD. ") file_crawl_thr = threading.Thread( target=self.read_next_families_from_file_loop, args=()) file_crawl_thr.start() print("Successfully started the **FILE** CRAWL thread!") for i in range(0, 15): fx_push_thr = threading.Thread(target=self.send_subbatch_thread, args=()) fx_push_thr.start() print("Successfully spun up {i} send threads!") with open("cpe_times.csv", 'w') as f: f.close() def send_subbatch_thread(self): # TODO: THIS IS NEW. HUNGRY STANDALONE THREADPOOL THAT SENDS ALL TASKS. # TODO: SHOULD TERMINATE WHEN COMPLETED <-- do after paper deadline. while True: sub_batch = [] for i in range(10): if self.to_send_queue.empty(): break # I believe this is a blocking call. part_batch = self.to_send_queue.get() sub_batch.extend(part_batch) if len(sub_batch) == 0: time.sleep(0.5) continue batch_send_t = time.time() task_ids = remote_extract_batch(sub_batch, ep_id=self.funcx_eid, headers=self.fx_headers) batch_recv_t = time.time() print(f"Time to send batch: {batch_recv_t - batch_send_t}") self.num_send_reqs += 1 self.pre_launch_counter -= len(sub_batch) if type(task_ids) is dict: self.logger.exception( f"Caught funcX error: {task_ids['exception_caught']}. \n" f"Putting the tasks back into active queue for retry") for reject_fam_batch in self.current_batch: fam_batch_dict = reject_fam_batch['event']['family_batch'] for reject_fam in fam_batch_dict['families']: self.families_to_process.put(json.dumps(reject_fam)) self.logger.info(f"Pausing for 10 seconds...") for task_id in task_ids: self.task_dict["active"].put(task_id) self.num_extracting_tasks += 1 time.sleep(0.5) def validate_enqueue_loop(self, thr_id): self.logger.debug("[VALIDATE] In validation enqueue loop!") while True: insertables = [] # If empty, then we want to return. if self.to_validate_q.empty(): # If ingest queue empty, we can demote to "idle" if self.poll_status == "COMMITTING": self.sqs_push_threads[thr_id] = "IDLE" print(f"Thread {thr_id} is committing and idle!") time.sleep(0.25) # NOW if all threads idle, then return! if all(value == "IDLE" for value in self.sqs_push_threads.values()): self.commit_completed = True self.poll_status = "COMPLETED" self.extract_end = time.time() self.logger.info(f"Thread {thr_id} is terminating!") return 0 self.logger.info("[Validate]: sleeping for 5 seconds... ") time.sleep(5) continue self.sqs_push_threads[thr_id] = "ACTIVE" # Remove up to n elements from queue, where n is current_batch. 
current_batch = 1 while not self.to_validate_q.empty( ) and current_batch < 8: # TODO: manual downsize so fits on Q. item_to_add = self.to_validate_q.get() # TODO: THIS IS PULLING THINGS BEFORE THEY GET TO VALIDATE QUEUE. insertables.append(item_to_add) current_batch += 1 # TODO: boot all of this out to file. if self.write_cpe: with open("cpe_times.csv", 'a') as f: csv_writer = csv.writer(f) for item in insertables: fam_batch = json.loads(item['MessageBody']) for family in fam_batch['families']: # print(family) # crawl_timestamp = family['metadata']['crawl_timestamp'] pf_timestamp = family['metadata'][ 'pf_transfer_completed'] fx_timestamp = family['metadata'][ 't_funcx_req_received'] total_file_size = 0 all_files = family['files'] total_files = len(all_files) for file_obj in all_files: total_file_size += file_obj['metadata'][ 'physical']['size'] csv_writer.writerow([ 'x', 0, pf_timestamp, fx_timestamp, total_files, total_file_size ]) # TODO: investigate this. Why is this here? # try: # ta = time.time() # self.client.send_message_batch(QueueUrl=self.validation_queue_url, # Entries=insertables) # tb = time.time() # self.t_send_batch += tb-ta # # except Exception as e: # TODO: too vague # print(f"WAS UNABLE TO PROPERLY CONNECT to SQS QUEUE: {e}") def send_families_loop(self): self.send_status = "RUNNING" last_send = time.time() while True: # print(f"Time since last send: {last_send - time.time()}") last_send = time.time() if self.num_extracting_tasks > self.max_extracting_tasks: self.logger.info( f"[SECOND] Num. active tasks ({self.num_extracting_tasks}) " f"above threshold. Sleeping for 1 second and continuing..." ) time.sleep(1) continue if self.prefetch_remote: while not self.prefetcher.orch_reader_q.empty(): family = self.prefetcher.orch_reader_q.get() family_size = self.prefetcher.get_family_size( json.loads(family)) self.prefetcher.bytes_pf_completed -= family_size self.prefetcher.orch_unextracted_bytes += family_size self.pre_launch_counter += 1 self.families_to_process.put(family) family_list = [] # Now keeping filling our list of families until it is empty. while len( family_list ) < self.fx_batch_size and not self.families_to_process.empty(): family_list.append(self.families_to_process.get()) if len(family_list) == 0: # Here we check if the crawl is complete. If so, then we can start the teardown checks. status_dict = get_crawl_status(self.crawl_id) if status_dict['crawl_status'] in ["SUCCEEDED", "FAILED"]: # Checking second time due to narrow race condition. if self.families_to_process.empty( ) and self.get_families_status == "IDLE": if self.prefetch_remote and self.prefetch_status in [ "SUCCEEDED", "FAILED" ]: self.send_status = "SUCCEEDED" self.logger.info( "[SEND] Queue still empty -- terminating!") # this should terminate thread, because there is nothing to process and queue empty return else: # Something snuck in during the race condition... process it! self.logger.info( "[SEND] Discovered final output despite crawl termination. Processing..." ) time.sleep( 0.5 ) # This is a multi-minute task and only is reached due to starvation. # self.map_size = # Cast list to FamilyBatch family_batch = FamilyBatch() for family in family_list: # Get extractor out of each group if self.extractor_finder == 'matio': d_type = "HTTPS" extr_code = 'matio' xtr_fam_obj = Family(download_type=d_type) xtr_fam_obj.from_dict(json.loads(family)) xtr_fam_obj.headers = self.family_headers # TODO: kick this logic for finding extractor into sdk/crawler. 
elif self.extractor_finder == 'gdrive': d_type = 'gdrive' xtr_fam_obj = Family(download_type=d_type) xtr_fam_obj.from_dict(json.loads(family)) xtr_fam_obj.headers = self.headers extr_code = xtr_fam_obj.groups[list( xtr_fam_obj.groups.keys())[0]].parser else: raise ValueError( f"Incorrect extractor_finder arg: {self.extractor_finder}" ) # TODO: add the decompression work and the hdf5/netcdf extractors! if extr_code is None or extr_code == 'hierarch' or extr_code == 'compressed': continue extractor = self.func_dict[extr_code] # TODO TYLER: Get the proper function ID here!!! # ex_func_id = extractor.func_id ex_func_id = mapping['xtract-matio::midway2']['func_uuid'] # Putting into family batch -- we use funcX batching now, but no use rewriting... # family_batch = FamilyBatch() family_batch.add_family(xtr_fam_obj) # print(f"Length of family batch: {len(family_batch.families)}") if len(family_batch.families) >= self.map_size: if d_type == "gdrive": self.current_batch.append({ "event": { "family_batch": family_batch, "creds": self.gdrive_token[0] }, "func_id": ex_func_id }) elif d_type == "HTTPS": self.current_batch.append({ "event": { "family_batch": family_batch.to_dict() }, "func_id": ex_func_id }) family_batch = FamilyBatch() # Catch any tasks currently in the map and append them to the batch if len(family_batch.families) > 0: if d_type == "gdrive": self.current_batch.append({ "event": { "family_batch": family_batch, "creds": self.gdrive_token[0] }, "func_id": ex_func_id }) elif d_type == "HTTPS": self.current_batch.append({ "event": { "family_batch": family_batch.to_dict() }, "func_id": ex_func_id }) # Now take that straggling family batch and append it. if len(self.current_batch) > 0: if self.t_first_funcx_invoke is None: self.t_first_funcx_invoke = time.time() req_size = 0 for item in self.current_batch: req_size += sys.getsizeof(item) req_size += 2 * sys.getsizeof( self.funcx_eid) # need size of fx ep and function id. req_size += sys.getsizeof(self.fx_headers) self.tot_fx_send_payload_size += req_size sub_batches = create_list_chunks(self.current_batch, self.fx_task_sublist_size) # print(f"Total sub_batches: {len(sub_batches)}") send_threads = [] # for sub_batch in sub_batches: self.to_send_queue.put(self.current_batch) # i = 0 # # TODO: THESE NEED TO BE STANDALONE THREADPOOL # for subbatch in sub_batches: # send_thr = threading.Thread(target=self.send_subbatch_thread, args=(subbatch,)) # send_thr.start() # send_threads.append(send_thr) # i += 1 # if i > 0: # print(f"Spun up {i} task-send threads!") # # for thr in send_threads: # thr.join() # Empty the batch! Everything in here has been sent :) self.current_batch = [] def launch_poll(self): self.logger.info("Launching poller...") po_thr = threading.Thread(target=self.poll_extractions_and_stats, args=()) po_thr.start() self.logger.info("Poller successfully launched!") def read_next_families_from_file_loop(self): """ This loads saved crawler state (from a json file) and quickly adds all files to our local families_to_process queue. This avoids making calls to SQS to retrieve crawl results. """ with open( '/Users/tylerskluzacek/PycharmProjects/xtracthub-service/experiments/tyler_200k.json', 'r') as f: all_families = json.load(f) num_fams = 0 for family in all_families: self.families_to_process.put(json.dumps(family)) num_fams += 1 if num_fams > self.task_cap_until_termination: break # self.prefetcher.kill_prefetcher = True # self.prefetcher.last_batch = True # TODO: bring this back for prefetcher. 
print("ENDING LOOP") def launch_extract(self): ex_thr = threading.Thread(target=self.send_families_loop, args=()) ex_thr.start() def unpack_returned_family_batch(self, family_batch): fam_batch_dict = family_batch.to_dict() return fam_batch_dict def update_and_print_stats(self): # STAT CHECK: if we haven't updated stats in 5 seconds, then we update. cur_time = time.time() if cur_time - self.last_checked > 5: # TODO: all the commented-out jazz here should really be 'if prefetch_remote'. # total_bytes = self.prefetcher.orch_unextracted_bytes + \ # self.prefetcher.bytes_pf_completed + \ # self.prefetcher.bytes_pf_in_flight print("Phase 4: polling") print(f"\t Successes: {self.success_returns}") print(f"\t Failures: {self.failed_returns}\n") self.logger.debug( f"[VALIDATE] Length of validation queue: {self.to_validate_q.qsize()}" ) # n_pulled = self.prefetcher.next_prefetch_queue.qsize() # n_pulled_per_sec = self.num_families_fetched / (cur_time - self.get_families_start_time) # n_pf_batch = self.prefetcher.pf_msgs_pulled_since_last_batch # n_families_pf_per_sec = self.prefetcher.num_families_transferred / (cur_time - self.get_families_start_time) # n_pf = self.prefetcher.num_families_mid_transfer # n_awaiting_fx = self.pre_launch_counter + self.prefetcher.orch_reader_q.qsize() n_in_fx = self.num_extracting_tasks n_success = self.success_returns # total_tracked = n_success + n_in_fx + n_awaiting_fx + n_pf + n_pf_batch + n_pulled # self.logger.info(f"\n** TASK LOCATION BREAKDOWN **\n" # f"--Pulled: {n_pulled}\t|\t({n_pulled_per_sec}/s)\n" # f"--In pf batching: {n_pf_batch}\n" # f"--Prefetching: {n_pf}\t|\t({n_families_pf_per_sec})\n" # f"--Awaiting extraction: {n_awaiting_fx}\n" # f"--In extraction: {n_in_fx}\n" # f"--Completed: {n_success}\n" # f"\n-- Fetch-Track Delta: {self.num_families_fetched - total_tracked}\n") if self.prefetch_remote: # print(f"\t Eff. dir size (GB): {total_bytes / 1024 / 1024 / 1024}") print(f"\t N Transfer Tasks: {self.prefetcher.num_transfers}") if self.num_send_reqs > 0 and self.num_poll_reqs > 0 and n_success > 0: self.logger.info( f"\n** funcX Stats **\n" f"Num. Send Requests: {self.num_send_reqs}\t|\t({time.time() - self.t_first_funcx_invoke})\n" f"Num. Poll Requests: {self.num_poll_reqs}\t|\t()\n" f"Avg. Send Request Size: {self.tot_fx_send_payload_size / self.num_send_reqs}\n" f"Avg. Poll Request Size: {self.tot_fx_poll_payload_size / self.num_poll_reqs}\n" f"Avg. Result Size: {self.tot_fx_poll_result_size / n_success}\n" f"Max Result Size: {self.max_result_size}\n") self.last_checked = time.time() print( f"[GET] Elapsed Extract time: {time.time() - self.t_crawl_start}" ) def poll_batch_chunks(self, thr_num, sublist, headers): status_thing = remote_poll_batch(sublist, headers) self.num_poll_reqs += 1 status_tup = (thr_num, status_thing) self.status_things.put(status_tup) time.sleep( 1 ) # TODO: MIGHT WANT TO TAKE THIS OUT. JUST SEEING IF THIS FIXES funcX SCALE-UP. return def poll_extractions_and_stats(self): mod_not_found = 0 type_errors = 0 self.poll_status = "RUNNING" while True: # This will attempt to print stats on every iteration. # Note: internally, this will only execute max every 5 seconds. self.update_and_print_stats() if self.task_dict["active"].empty(): if self.send_status == "SUCCEEDED": print("[POLL] Send status is SUCCEEDED... ") print( f"[POLL] Active tasks: {self.task_dict['active'].empty()}" ) if self.task_dict["active"].empty( ): # check second time b/c of rare r.c. print( "Extraction completed. Upgrading status to COMMITTING." 
) self.poll_status = "COMMITTING" return self.logger.debug("No live IDs... sleeping...") time.sleep(10) continue # Here we pull all values from active task queue to create a batch of them! num_elem = self.task_dict["active"].qsize() print(f"[POLL] Size active queue: {num_elem}" ) # TODO: need to shush this. tids_to_poll = [] # the batch for i in range(0, num_elem): ex_id = self.task_dict["active"].get() tids_to_poll.append(ex_id) t_last_poll = time.time() # Send off task_ids to poll, retrieve a bunch of statuses. tid_sublists = create_list_chunks(tids_to_poll, self.fx_task_sublist_size) # TODO: ^^ this'll need to be a dictionary of some kind so we can track. polling_threads = [] thr_to_sublist_map = dict() i = 0 for tid_sublist in tid_sublists: # TODO: put a cap on this (if i>10, break (for example) # If there's an actual thing in the list... if len(tid_sublist) > 0 and i < 15: self.tot_fx_poll_payload_size += sys.getsizeof(tid_sublist) self.tot_fx_poll_payload_size += sys.getsizeof( self.fx_headers) * len(tid_sublist) thr = threading.Thread(target=self.poll_batch_chunks, args=(i, tid_sublist, self.fx_headers)) thr.start() polling_threads.append(thr) # NEW STEP: add the thread ID to the dict thr_to_sublist_map[i] = tid_sublist i += 1 print(f"[POLL] Spun up {i} polling threads!") # Now we should wait for our fan-out threads to fan-in for thread in polling_threads: thread.join() # TODO: here is where we can fix this # TODO: Create tuples in status things that's a tuple with tid_sublist (OR THREAD ID) # that the status thing is about. while not self.status_things.empty(): thr_id, status_thing = self.status_things.get() if "exception_caught" in status_thing: self.logger.exception( f"Caught funcX error: {status_thing['exception_caught']}" ) self.logger.exception( f"Putting the tasks back into active queue for retry") for reject_tid in thr_to_sublist_map[thr_id]: self.task_dict["active"].put(reject_tid) print(f"Pausing for 20 seconds...") self.pause_q_consume = True time.sleep(20) continue else: self.pause_q_consume = False for tid in status_thing: task_obj = status_thing[tid] if "result" in task_obj: res = self.fx_ser.deserialize(task_obj['result']) # Decrement number of tasks being extracted! self.num_extracting_tasks -= 1 if "family_batch" in res: family_batch = res["family_batch"] unpacked_metadata = self.unpack_returned_family_batch( family_batch) # print(unpacked_metadata) # TODO: make this a regular feature for matio (so this code isn't necessary...) 
if 'event' in unpacked_metadata: family_batch = unpacked_metadata['event'][ 'family_batch'] unpacked_metadata = family_batch.to_dict() unpacked_metadata[ 'dataset'] = self.dataset_mdata if self.prefetch_remote: total_family_size = 0 for family in family_batch.families: # family['metadata']["t_funcx_req_received"] = time.time() total_family_size += self.prefetcher.get_family_size( family.to_dict()) self.prefetcher.orch_unextracted_bytes -= total_family_size for family in unpacked_metadata['families']: family['metadata'][ "t_funcx_req_received"] = time.time() json_mdata = json.dumps(unpacked_metadata, cls=NumpyEncoder) result_size = sys.getsizeof(json_mdata) self.tot_fx_poll_result_size += result_size if result_size > self.max_result_size: self.max_result_size = result_size try: self.to_validate_q.put({ "Id": str(self.file_count), "MessageBody": json_mdata }) self.file_count += 1 except TypeError as e1: self.logger.exception(f"Type error: {e1}") type_errors += 1 self.logger.exception( f"Total type errors: {type_errors}") else: self.logger.error( f"[Poller]: \"family_batch\" not in res!") self.success_returns += 1 # Leave this in. Great for debugging and we have the IO for it. self.logger.debug(f"Received response: {res}") if type(res) is not dict: self.logger.exception(f"Res is not dict: {res}") continue elif 'transfer_time' in res: if res['transfer_time'] > 0: self.t_transfer += res['transfer_time'] self.n_fams_transferred += 1 print( f"Avg. Transfer time: {self.t_transfer/self.n_fams_transferred}" ) # This just means fixing Google Drive extractors... elif 'trans_time' in res: if res['trans_time'] > 0: self.t_transfer += res['trans_time'] self.n_fams_transferred += 1 print( f"Avg. Transfer time: {self.t_transfer/self.n_fams_transferred}" ) elif "exception" in task_obj: exc = self.fx_ser.deserialize(task_obj['exception']) try: exc.reraise() except ModuleNotFoundError: mod_not_found += 1 self.logger.exception( f"Num. ModuleNotFound: {mod_not_found}") except Exception as e: self.logger.exception(f"Caught exception: {e}") self.logger.exception("Continuing!") pass self.failed_returns += 1 self.logger.exception(f"Exception caught: {exc}") else: self.task_dict["active"].put(tid) t_since_last_poll = time.time() - t_last_poll t_poll_gap = self.poll_gap_s - t_since_last_poll if t_poll_gap > 0: time.sleep(t_poll_gap)
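# The orchestrator above relies on a `create_list_chunks` helper to split a long
# list of task ids into fixed-size sublists before fanning polling requests out
# across threads. A minimal sketch of what such a helper could look like,
# assuming only that it returns consecutive, non-overlapping slices -- the real
# implementation may differ.
def create_list_chunks(items, chunk_size):
    """Split `items` into consecutive sublists of at most `chunk_size` elements."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]


# Example: 1050 task ids with a sublist size of 500 -> chunks of 500, 500, 50.
chunks = create_list_chunks(list(range(1050)), 500)
print([len(c) for c in chunks])  # [500, 500, 50]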
# TODO: Add an ad-hoc retry for pending tasks to catch lost ones.
# TODO x10000: move this logic
timeout = 120
failed_counter = 0

# Drain the active-task queue once: fetch each task's status from the funcX
# REST API, collect results, print exceptions, and re-queue anything still pending.
while True:
    if task_dict["active"].empty():
        print("Active task queue empty... sleeping... ")
        time.sleep(0.5)
        break  # This should jump out to main_loop

    cur_tid = task_dict["active"].get()
    print(cur_tid)

    status_thing = requests.get(get_url.format(cur_tid), headers=headers).json()

    if 'result' in status_thing:
        result = fx_ser.deserialize(status_thing['result'])
        print(f"Result: {result}")
        task_dict["results"].append(result)
        print(len(task_dict["results"]))
    elif 'exception' in status_thing:
        print(f"Exception: {fx_ser.deserialize(status_thing['exception'])}")
        # break
    else:
        task_dict["active"].put(cur_tid)

# Manager: f2ab81f26f40
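# `timeout` and `failed_counter` above are declared but never used; the TODO hints
# at a retry path for tasks that never return. A hedged sketch of what that could
# look like: remember when each task id was first seen and stop re-queuing any
# task that stays pending longer than `timeout` seconds. The `first_seen` dict
# and the drop policy are assumptions, not part of the original script.
import time

first_seen = {}


def requeue_or_drop(task_dict, cur_tid, timeout=120):
    """Re-queue a still-pending task unless it has been pending longer than `timeout` seconds."""
    now = time.time()
    first_seen.setdefault(cur_tid, now)
    if now - first_seen[cur_tid] > timeout:
        return False  # give up on this task; caller can bump failed_counter
    task_dict["active"].put(cur_tid)
    return True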
class FuncXClient(throttling.ThrottledBaseClient): """Main class for interacting with the funcX service Holds helper operations for performing common tasks with the funcX service. """ TOKEN_DIR = os.path.expanduser("~/.funcx/credentials") TOKEN_FILENAME = 'funcx_sdk_tokens.json' CLIENT_ID = '4cf29807-cf21-49ec-9443-ff9a3fb9f81c' def __init__(self, http_timeout=None, funcx_home=os.path.join('~', '.funcx'), force_login=False, fx_authorizer=None, funcx_service_address='https://funcx.org/api/v1', **kwargs): """ Initialize the client Parameters ---------- http_timeout: int Timeout for any call to service in seconds. Default is no timeout force_login: bool Whether to force a login to get new credentials. fx_authorizer:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`: A custom authorizer instance to communicate with funcX. Default: ``None``, will be created. service_address: str The address of the funcX web service to communicate with. Default: https://dev.funcx.org/api/v1 Keyword arguments are the same as for BaseClient. """ self.func_table = {} self.ep_registration_path = 'register_endpoint_2' self.funcx_home = os.path.expanduser(funcx_home) if not os.path.exists(self.TOKEN_DIR): os.makedirs(self.TOKEN_DIR) tokens_filename = os.path.join(self.TOKEN_DIR, self.TOKEN_FILENAME) self.native_client = NativeClient( client_id=self.CLIENT_ID, app_name="FuncX SDK", token_storage=JSONTokenStorage(tokens_filename)) fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all" if not fx_authorizer: self.native_client.login( requested_scopes=[fx_scope], no_local_server=kwargs.get("no_local_server", True), no_browser=kwargs.get("no_browser", True), refresh_tokens=kwargs.get("refresh_tokens", True), force=force_login) all_authorizers = self.native_client.get_authorizers_by_scope( requested_scopes=[fx_scope]) fx_authorizer = all_authorizers[fx_scope] super(FuncXClient, self).__init__("funcX", environment='funcx', authorizer=fx_authorizer, http_timeout=http_timeout, base_url=funcx_service_address, **kwargs) self.fx_serializer = FuncXSerializer() def logout(self): """Remove credentials from your local system """ self.native_client.logout() def update_table(self, return_msg, task_id): """ Parses the return message from the service and updates the internal func_tables Parameters ---------- return_msg : str Return message received from the funcx service task_id : str task id string """ if isinstance(return_msg, str): r_dict = json.loads(return_msg) else: r_dict = return_msg status = {'pending': True} if 'result' in r_dict: try: r_obj = self.fx_serializer.deserialize(r_dict['result']) except Exception: raise Exception( "Failure during deserialization of the result object") else: status.update({'pending': 'False', 'result': r_obj}) self.func_table[task_id] = status elif 'exception' in r_dict: try: r_exception = self.fx_serializer.deserialize( r_dict['exception']) logger.info(f"Exception : {r_exception}") except Exception: raise Exception( "Failure during deserialization of the Task's exception object" ) else: status.update({'pending': 'False', 'exception': r_exception}) self.func_table[task_id] = status return status def get_task_status(self, task_id): """Get the status of a funcX task. Parameters ---------- task_id : str UUID of the task Returns ------- dict Status block containing "status" key. 
""" if task_id in self.func_table: return self.func_table[task_id] r = self.get("{task_id}/status".format(task_id=task_id)) logger.debug("Response string : {}".format(r)) try: rets = self.update_table(r.text, task_id) except Exception as e: raise e return rets def get_result(self, task_id): """ Get the result of a funcX task Parameters ---------- task_id: str UUID of the task Returns ------- Result obj: If task completed Raises ------ Exception obj: Exception due to which the task failed """ status = self.get_task_status(task_id) if status['pending'] is True: raise Exception("Task pending") else: if 'result' in status: return status['result'] else: logger.warn("We have an exception : {}".format( status['exception'])) status['exception'].reraise() def get_batch_status(self, task_id_list): """ Request status for a batch of task_ids """ assert isinstance(task_id_list, list), "get_batch_status expects a list of task ids" pending_task_ids = [ t for t in task_id_list if t not in self.func_table ] results = {} if pending_task_ids: payload = {'task_ids': pending_task_ids} r = self.post("/batch_status", json_body=payload) logger.debug("Response string : {}".format(r)) pending_task_ids = set(pending_task_ids) for task_id in task_id_list: if task_id in pending_task_ids: try: data = r['results'][task_id] rets = self.update_table(data, task_id) results[task_id] = rets except KeyError: logger.debug( "Task {} info was not available in the batch status") except Exception as e: logger.exception( "Failure while unpacking results fom get_batch_status") else: results[task_id] = self.func_table[task_id] return results def get_batch_result(self, task_id_list): """ Request results for a batch of task_ids """ pass def run(self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task """ servable_path = 'submit' assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" ser_args = self.fx_serializer.serialize(args) ser_kwargs = self.fx_serializer.serialize(kwargs) payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs]) data = { 'endpoint': endpoint_id, 'func': function_id, 'payload': payload, 'is_async': asynchronous } # Send the data to funcX r = self.post(servable_path, json_body=data) if r.http_status is not 200: raise Exception(r) if 'task_uuid' not in r: raise MalformedResponse(r) """ Create a future to deal with the result funcx_future = FuncXFuture(self, task_id, async_poll) if not asynchronous: return funcx_future.result() # Return the result return funcx_future """ return r['task_uuid'] def map_run(self, *args, endpoint_id=None, function_id=None, asynchronous=False, **kwargs): """Initiate an invocation Parameters ---------- *args : Any Args as specified by the function signature endpoint_id : uuid str Endpoint UUID string. Required function_id : uuid str Function UUID string. 
Required asynchronous : bool Whether or not to run the function asynchronously Returns ------- task_id : str UUID string that identifies the task """ servable_path = 'submit_batch' assert endpoint_id is not None, "endpoint_id key-word argument must be set" assert function_id is not None, "function_id key-word argument must be set" ser_kwargs = self.fx_serializer.serialize(kwargs) batch_payload = [] iterator = args[0] for arg in iterator: ser_args = self.fx_serializer.serialize((arg, )) payload = self.fx_serializer.pack_buffers([ser_args, ser_kwargs]) batch_payload.append(payload) data = { 'endpoints': [endpoint_id], 'func': function_id, 'payload': batch_payload, 'is_async': asynchronous } # Send the data to funcX r = self.post(servable_path, json_body=data) if r.http_status is not 200: raise Exception(r) if 'task_uuids' not in r: raise MalformedResponse(r) return r['task_uuids'] def register_endpoint(self, name, endpoint_uuid, description=None): """Register an endpoint with the funcX service. Parameters ---------- name : str Name of the endpoint endpoint_uuid : str The uuid of the endpoint description : str Description of the endpoint Returns ------- A dict {'endopoint_id' : <>, 'address' : <>, 'client_ports': <>} """ data = { "endpoint_name": name, "endpoint_uuid": endpoint_uuid, "description": description } r = self.post(self.ep_registration_path, json_body=data) if r.http_status is not 200: raise Exception(r) # Return the result return r.data def get_containers(self, name, description=None): """Register a DLHub endpoint with the funcX service and get the containers to launch. Parameters ---------- name : str Name of the endpoint description : str Description of the endpoint Returns ------- int The port to connect to and a list of containers """ registration_path = 'get_containers' data = {"endpoint_name": name, "description": description} r = self.post(registration_path, json_body=data) if r.http_status is not 200: raise Exception(r) # Return the result return r.data['endpoint_uuid'], r.data['endpoint_containers'] def get_container(self, container_uuid, container_type): """Get the details of a container for staging it locally. Parameters ---------- container_uuid : str UUID of the container in question container_type : str The type of containers that will be used (Singularity, Shifter, Docker) Returns ------- dict The details of the containers to deploy """ container_path = f'containers/{container_uuid}/{container_type}' r = self.get(container_path) if r.http_status is not 200: raise Exception(r) # Return the result return r.data['container'] def register_function(self, function, function_name=None, container_uuid=None, description=None): """Register a function code with the funcX service. Parameters ---------- function : Python Function The function to be registered for remote execution function_name : str The entry point (function name) of the function. 
            Default: None
        container_uuid : str
            Container UUID from registration with funcX
        description : str
            Description of the function

        Returns
        -------
        function uuid : str
            UUID identifier for the registered function
        """
        registration_path = 'register_function'

        serialized_fn = self.fx_serializer.serialize(function)
        packed_code = self.fx_serializer.pack_buffers([serialized_fn])

        data = {
            "function_name": function.__name__,
            "function_code": packed_code,
            "container_uuid": container_uuid,
            "entry_point": function_name if function_name else function.__name__,
            "description": description
        }

        logger.info("Registering function : {}".format(data))

        r = self.post(registration_path, json_body=data)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['function_uuid']

    def register_container(self, location, container_type, name='', description=''):
        """Register a container with the funcX service.

        Parameters
        ----------
        location : str
            The location of the container (e.g., its docker url). Required
        container_type : str
            The type of containers that will be used (Singularity, Shifter, Docker). Required
        name : str
            A name for the container. Default = ''
        description : str
            A description to associate with the container. Default = ''

        Returns
        -------
        str
            The id of the container
        """
        container_path = 'containers'

        payload = {
            'name': name,
            'location': location,
            'description': description,
            'type': container_type
        }

        r = self.post(container_path, json_body=payload)
        if r.http_status != 200:
            raise Exception(r)

        # Return the result
        return r.data['container_id']
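# A short usage sketch of the client defined above: register a function, submit
# it to an endpoint, and poll for the result. The endpoint UUID is a placeholder,
# and get_result() raises "Task pending" until the task completes, so it is
# wrapped in a simple retry loop here.
import time


def double(x):
    return x * 2


if __name__ == '__main__':
    fxc = FuncXClient()
    func_uuid = fxc.register_function(double, description="double a number")

    endpoint_uuid = '<endpoint-uuid-here>'  # placeholder, not a real endpoint
    task_id = fxc.run(21, endpoint_id=endpoint_uuid, function_id=func_uuid)

    while True:
        try:
            print(fxc.get_result(task_id))  # -> 42 once the task finishes
            break
        except Exception as e:
            if 'pending' in str(e).lower():
                time.sleep(2)               # still running; poll again
            else:
                raise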