def get_result(task_id):
    """Fetch, print and return the result of *task_id* from the dev funcX service."""
    result = FuncXClient(
        funcx_service_address='https://dev.funcx.org/api/v1'
    ).get_result(task_id)
    print(result)
    return result
def __init__(
    self,
    fx_auth,
    search_auth,
    openid_auth,
    endpoint_id,
    func,
    expected,
    args=None,
    timeout=15,
    concurrency=1,
    tol=1e-5,
):
    """Store the test parameters, register *func* with funcX and set up logging.

    Parameters
    ----------
    fx_auth, search_auth, openid_auth
        Authorizers handed to ``FuncXClient`` as ``fx_authorizer``,
        ``search_authorizer`` and ``openid_authorizer``.
    endpoint_id : str
        funcX endpoint the function is meant to run on.
    func : callable
        Function registered with funcX; its id is stored as ``self.func_uuid``.
    expected
        Expected result of running *func*.
    args : optional
        Arguments for *func* (default ``None``).
    timeout : int
        Stored as ``self.timeout`` (default 15).
    concurrency : int
        Stored as ``self.concurrency`` (default 1).
    tol : float
        Stored as ``self.tol`` (default 1e-5).
    """
    self.endpoint_id = endpoint_id
    self.func = func
    self.expected = expected
    self.args = args
    self.timeout = timeout
    self.concurrency = concurrency
    self.tol = tol
    self.fxc = FuncXClient(
        fx_authorizer=fx_auth,
        search_authorizer=search_auth,
        openid_authorizer=openid_auth,
    )
    self.func_uuid = self.fxc.register_function(self.func)

    self.logger = logging.getLogger(__name__)
    self.logger.setLevel(logging.DEBUG)

    # getLogger(__name__) returns the same module-wide logger for every
    # instance. Attach the stdout handler only once; otherwise each new
    # instance adds another handler and every message is printed N times.
    handler_name = f"{__name__}.stdout"
    if not any(h.get_name() == handler_name for h in self.logger.handlers):
        handler = logging.StreamHandler(sys.stdout)
        handler.set_name(handler_name)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            "%(asctime)s %(name)s:%(lineno)d [%(levelname)s] %(message)s")
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
def get_deser_result(task_id):
    """Fetch a task record from the dev funcX service with server-side deserialization.

    The raw response of the ``GET /tasks/<task_id>?deserialize=True`` call
    is printed and returned.
    """
    client = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')
    path = f"/tasks/{task_id}?deserialize=True"
    response = client.get(path)
    print(response)
    return response
def run_function_wait_result(
        py_fn, py_fn_args, py_fn_kwargs=None,
        endpoint_id="3c3f0b4f-4ae4-4241-8497-d7339972ff4a",
        print_status=True):
    """
    Register and run a function with FuncX, wait for execution, and return
    results when they are available

    :param py_fn: Handle of Python function
    :param py_fn_args: List of positional args for py function
    :param py_fn_kwargs: Dict of keyword args for py function (default: none)
    :param endpoint_id: ID of endpoint to run command on - must be configured
        in config.py
    :param print_status: Print "Waiting for results..." before every poll
    :return: The task result decoded from UTF-8 bytes to ``str``
    :raises Exception: Any error from ``get_result`` other than the
        "waiting-for-..." not-ready errors, which are retried
    """
    if py_fn_kwargs is None:
        py_fn_kwargs = {}

    fxc = FuncXClient()
    func_uuid = fxc.register_function(py_fn)
    res = fxc.run(*py_fn_args, **py_fn_kwargs,
                  endpoint_id=endpoint_id, function_id=func_uuid)

    while True:
        if print_status:
            print("Waiting for results...")
        time.sleep(FUNCX_SLEEP_TIME)
        try:
            result = fxc.get_result(res)
        except Exception as e:
            # The service signals "not finished yet" with an exception whose
            # message contains "waiting-for-"; keep polling in that case only.
            if "waiting-for-" in str(e):
                continue
            raise
        return str(result, encoding="utf-8")
def register_container():
    """Register the eigen Singularity image and ``eigen_corr`` inside it.

    Returns the funcX function id of ``eigen_corr`` bound to that container.
    """
    from funcx.sdk.client import FuncXClient
    client = FuncXClient()

    from gladier_xpcs.tools.corr import eigen_corr

    image_dir = '/eagle/APSDataAnalysis/XPCS_test/containers/'
    image_name = 'eigen_v2.simg'
    container_id = client.register_container(
        location=f"{image_dir}{image_name}",
        container_type='singularity',
    )
    function_id = client.register_function(eigen_corr, container_uuid=container_id)
    return function_id
def __init__(self, dlh_authorizer=None, search_client=None, http_timeout=None,
             force_login=False, fx_authorizer=None, **kwargs):
    """Initialize the client

    Args:
        dlh_authorizer (:class:`GlobusAuthorizer
                        <globus_sdk.authorizers.base.GlobusAuthorizer>`):
            An authorizer instance used to communicate with DLHub.
            If ``None``, will be created.
        search_client (:class:`SearchClient <globus_sdk.SearchClient>`):
            An authenticated SearchClient to communicate with Globus Search.
            If ``None``, will be created.
        http_timeout (int): Timeout for any call to service in seconds.
            (default is no timeout)
        force_login (bool): Whether to force a login to get new credentials.
            A login will always occur if ``dlh_authorizer``, ``search_client``
            or ``fx_authorizer`` are not provided.
        fx_authorizer (:class:`GlobusAuthorizer
                        <globus_sdk.authorizers.base.GlobusAuthorizer>`):
            An authorizer instance used to communicate with funcX.
            If ``None``, will be created.

    Keyword Args (read from ``**kwargs``):
        no_local_server (bool): Disable spinning up a local server to
            automatically copy-paste the auth code. THIS IS REQUIRED if you
            are on a remote server. When used locally with
            no_local_server=False, the domain is localhost with a randomly
            chosen open port number. **Default**: ``True``.
        no_browser (bool): Do not automatically open the browser for the
            Globus Auth URL. Display the URL instead and let the user navigate
            to that location manually. **Default**: ``True``.

    All ``**kwargs`` are also forwarded to BaseClient.
    """
    fx_service_address = 'https://funcx.org/api/v1'

    if force_login or not dlh_authorizer or not search_client or not fx_authorizer:
        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
        auth_res = login(services=["search", "dlhub", fx_scope],
                         app_name="DLHub_Client",
                         client_id=CLIENT_ID, clear_old_tokens=force_login,
                         token_dir=_token_dir,
                         no_local_server=kwargs.get("no_local_server", True),
                         no_browser=kwargs.get("no_browser", True))
        dlh_authorizer = auth_res["dlhub"]
        fx_authorizer = auth_res[fx_scope]
        self._search_client = auth_res["search"]
        self._fx_client = FuncXClient(force_login=True, fx_authorizer=fx_authorizer,
                                      funcx_service_address=fx_service_address)
    else:
        # Every credential was supplied by the caller: use them. Previously
        # this path left _search_client/_fx_client unset, so the provided
        # search_client was ignored and later use raised AttributeError.
        self._search_client = search_client
        self._fx_client = FuncXClient(fx_authorizer=fx_authorizer,
                                      funcx_service_address=fx_service_address)

    # funcX endpoint to use
    self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
    self.fx_serializer = FuncXSerializer()
    self.fx_cache = {}
    super(DLHubClient, self).__init__("DLHub", environment='dlhub',
                                      authorizer=dlh_authorizer,
                                      http_timeout=http_timeout,
                                      base_url=DLHUB_SERVICE_ADDRESS,
                                      **kwargs)
def handle(self, *args, **options):
    """Run a single reprocessing maintenance action chosen by *options*.

    Options are checked in order: ``register``, ``test``, ``check`` (truthy),
    then ``payload`` and ``output`` (not None); the first match runs and the
    command returns.
    """
    if options['register']:
        client = FuncXClient()
        function_id = client.register_function(process_hdfs,
                                               description="Process an hdf")
        self.stderr.write(f'FuncX function endpoint has been '
                          f'registered: {function_id}')
        self.stderr.write(f'You need to add this somewhere manually!')
        return

    if options['test']:
        collection_name = options['test']
        if not collection_name:
            self.stderr.write('test needs the name of a search collection')
        bag = Bag.objects.filter(search_collection__name=collection_name).first()
        if not bag:
            known = [b.search_collection.name for b in Bag.objects.all()]
            self.stderr.write(f'No bag named {options["test"]}, please '
                              f'use one of the following instead {known}')
            return
        action = ReprocessingTask.new_action(bag, user=bag.user)
        action.save()
        task = ReprocessingTask(bag=bag, action=action)
        task.save()
        task.action.start_flow()
        self.stdout.write(f'Started {action}')
        return

    if options['check']:
        active_tasks = ReprocessingTask.objects.filter(action__status='ACTIVE')
        if not active_tasks:
            self.stderr.write('No Tasks to update.')
        for task in active_tasks:
            previous_status = task.action.status
            task.action.update_flow()
            self.stdout.write(f'Updated {task.bag.search_collection.name} '
                              f'from "{previous_status}" to "{task.action.status}".')
        return

    if options.get('payload') is not None:
        # Not implemented; the intended behaviour (deserialize and pretty-print
        # the ProcessDataInput payload) was unreachable code.
        raise NotImplementedError('This does not work yet...')

    if options.get('output') is not None:
        task = self.get_task_or_list_all(options['output'])
        if not task:
            return
        step_outputs = task.action.cache['details']['output']
        details_list = [step['details'] for step in step_outputs.values()
                        if 'details' in step]
        for details in details_list:
            if 'result' in details:
                pprint(deserialize_payload(details['result']))
            elif 'exception' in details:
                deserialize_payload(details['exception']).reraise()
def run_real(payload):
    """Submit ``test_func`` to the dev endpoint with the fixed input ``{'name': 'real'}``.

    *payload* is accepted but not used. Prints and returns the value of
    ``FuncXClient.run``.
    """
    client = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')
    function_id = client.register_function(test_func)
    endpoint = '60ad46e1-c912-468b-8674-4d582e9dc9ee'
    task = client.run({'name': 'real'},
                      function_id=function_id,
                      endpoint_id=endpoint)
    print(task)
    return task
def init_endpoint():
    """Set up the funcX directory with a default endpoint configuration.

    If ``State.FUNCX_CONFIG_FILE`` already exists, the user must confirm
    (declining aborts via ``typer.confirm(..., abort=True)``). The existing
    ``State.FUNCX_DIR`` is then moved to ``State.FUNCX_DIR + ".bak"`` after any
    previous backup is deleted. Finally the default template is copied into
    place and ``init_endpoint_dir("default")`` is called.

    TODO: Every mechanism that will update the config file must use a
    locking mechanism, ideally something like fcntl
    (https://docs.python.org/3/library/fcntl.html), to ensure that multiple
    endpoint invocations do not mangle the funcx config files or the
    lockfile module.
    """
    # The client object is discarded; it is built only for whatever side
    # effects FuncXClient() has (presumably authentication -- confirm).
    _ = FuncXClient()

    if os.path.exists(State.FUNCX_CONFIG_FILE):
        typer.confirm(
            "Are you sure you want to initialize this directory? "
            f"This will erase everything in {State.FUNCX_DIR}",
            abort=True,
        )
        logger.info(f"Wiping all current configs in {State.FUNCX_DIR}")
        backup = State.FUNCX_DIR + ".bak"
        try:
            logger.debug(f"Removing old backups in {backup}")
            shutil.rmtree(backup)
        except OSError:
            pass  # best effort: e.g. there is no previous backup
        os.renames(State.FUNCX_DIR, backup)

    # Leave an existing config untouched.
    if os.path.exists(State.FUNCX_CONFIG_FILE):
        logger.debug(f"Config file exists at {State.FUNCX_CONFIG_FILE}")
        return

    try:
        os.makedirs(State.FUNCX_DIR, exist_ok=True)
    except Exception as exc:
        print(f"[FuncX] Caught exception during registration {exc}")

    shutil.copyfile(State.FUNCX_DEFAULT_CONFIG_TEMPLATE, State.FUNCX_CONFIG_FILE)
    init_endpoint_dir("default")
def run_ser(payload):
    """POST *payload* for ``test_func`` to the dev service's ``submit`` route.

    The request asks the service to serialize the payload. The returned
    ``task_uuid`` is printed and returned.
    """
    client = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')
    function_id = client.register_function(test_func)
    endpoint = '60ad46e1-c912-468b-8674-4d582e9dc9ee'
    body = dict(serialize=True, payload=payload, endpoint=endpoint, func=function_id)
    task_uuid = client.post('submit', json_body=body)['task_uuid']
    print(task_uuid)
    return task_uuid
def run_function_async(py_fn, py_fn_args, py_fn_kwargs=None,
                       endpoint_id="3c3f0b4f-4ae4-4241-8497-d7339972ff4a"):
    """
    Asynchronously register and run a Python function on a FuncX endpoint

    :param py_fn: Handle of Python function
    :param py_fn_args: List of positional args for py function
    :param py_fn_kwargs: Dict of keyword args for py function (default: none)
    :param endpoint_id: ID of endpoint to run command on - must be configured
        in config.py
    :return: The value returned by ``FuncXClient.run``, for later polling
    """
    if py_fn_kwargs is None:
        py_fn_kwargs = {}

    fxc = FuncXClient()
    func_uuid = fxc.register_function(py_fn)
    # Forward the caller's keyword arguments. The previous code referenced an
    # undefined name ``kwargs`` and raised NameError on every call.
    return fxc.run(*py_fn_args, **py_fn_kwargs,
                   endpoint_id=endpoint_id, function_id=func_uuid)
class FuncXFuture(Future):
    """A ``Future`` whose completion is discovered by polling funcX for ``task_id``.

    ``client`` and ``serializer`` are class attributes created when the class
    is defined and shared by every instance.
    """

    client = FuncXClient()
    serializer = FuncXSerializer()

    def __init__(self, task_id, poll_period=1):
        """
        Parameters
        ----------
        task_id : str
            funcX task id to poll.
        poll_period : int or float
            Seconds to sleep between status polls (default 1).
        """
        super().__init__()
        self.task_id = task_id
        self.poll_period = poll_period
        self.__result = None
        # Completion is tracked separately from the value: a remote function
        # may legitimately return None, and that must still count as done.
        self.__done = False
        self.submitted = time.time()

    def done(self):
        """Return True once the task's result has been fetched and deserialized.

        Returns False if the status request fails or the task is PENDING. The
        PENDING case first sleeps ``poll_period`` seconds. If the task failed,
        the deserialized remote exception is re-raised via ``reraise()``.

        Raises NotImplementedError for an unrecognised status payload.
        """
        if self.__done:
            return True

        try:
            data = FuncXFuture.client.get_task_status(self.task_id)
        except Exception:
            # Treat a failed status request as "not finished yet".
            return False

        if 'status' in data and data['status'] == 'PENDING':
            time.sleep(self.poll_period)  # needed to not overwhelm the FuncX server
            return False
        elif 'result' in data:
            self.__result = FuncXFuture.serializer.deserialize(data['result'])
            self.__done = True
            self.returned = time.time()
            # FIXME AW benchmarking
            self.connected_managers = os.environ.get('connected_managers', -1)
            return True
        elif 'exception' in data:
            e = FuncXFuture.serializer.deserialize(data['exception'])
            e.reraise()
        else:
            raise NotImplementedError(
                'task {} is neither pending or finished: {}'.format(
                    self.task_id, str(data)))

    def result(self, timeout=None):
        """Block until the task finishes and return its deserialized result.

        Parameters
        ----------
        timeout : float or None
            Maximum number of seconds to wait; ``None`` waits indefinitely.

        Raises
        ------
        TimeoutError
            If the task has not finished before *timeout* elapses.
        """
        # Use a wall-clock deadline so time spent inside done() (including its
        # own PENDING sleep and the network round trip) counts toward timeout.
        deadline = None if timeout is None else time.monotonic() + timeout
        while not self.done():
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError
            time.sleep(self.poll_period)
        return self.__result
def register_funcx(task):
    """Register the function and the container with funcX.

    Parameters
    ----------
    task : dict
        A dict of the task to publish. Reads ``task['dlhub']`` keys
        ``funcx_token``, ``shorthand_name``, ``ecr_uri`` and ``name``, and
        optionally ``task['datacite']['descriptions'][0]['description']``.

    Returns
    -------
    str
        The funcX function id
    """
    # Get the funcX dependent token
    fx_token = task['dlhub']['funcx_token']
    # Create a client using this token
    fx_auth = globus_sdk.AccessTokenAuthorizer(fx_token)
    fxc = FuncXClient(fx_authorizer=fx_auth, use_offprocess_checker=False)

    description = f"A container for the DLHub model {task['dlhub']['shorthand_name']}"
    try:
        description = task['datacite']['descriptions'][0]['description']
    except (KeyError, IndexError, TypeError):
        # It doesn't have a simple description; keep the default above.
        # (Narrowed from a bare ``except:`` so real errors and
        # KeyboardInterrupt are no longer swallowed.)
        pass

    # Register the container with funcX
    container_id = fxc.register_container(task['dlhub']['ecr_uri'], 'docker',
                                          name=task['dlhub']['shorthand_name'],
                                          description=description)

    # Register a function
    funcx_id = fxc.register_function(dlhub_run, function_name=task['dlhub']['name'],
                                     container_uuid=container_id,
                                     description=description, public=True)

    # Whitelist the function on DLHub's endpoint
    # First create a new fxc client on DLHub's behalf
    fxc = FuncXClient(use_offprocess_checker=False)
    endpoint_uuid = '86a47061-f3d9-44f0-90dc-56ddc642c000'
    res = fxc.add_to_whitelist(endpoint_uuid, [funcx_id])
    print(res)
    return funcx_id
def init_endpoint(self):
    """Create the funcX directory and copy in the default endpoint config.

    When ``self.funcx_config_file`` already exists, the user is asked to
    confirm (declining aborts via ``typer.confirm(..., abort=True)``). The
    existing ``self.funcx_dir`` is then moved to ``<funcx_dir>.bak`` after any
    older backup is removed.

    TODO: Every mechanism that updates the config file must use a locking
    mechanism, ideally something like fcntl [1], so that multiple endpoint
    invocations do not mangle the funcx config files or the lockfile module.

    [1] https://docs.python.org/3/library/fcntl.html
    """
    # Constructed and discarded; kept only for FuncXClient()'s side effects
    # (presumably authentication -- confirm).
    _ = FuncXClient()

    has_config = os.path.exists(self.funcx_config_file)
    if has_config:
        typer.confirm(
            "Are you sure you want to initialize this directory? "
            f"This will erase everything in {self.funcx_dir}",
            abort=True,
        )
        log.info(f"Wiping all current configs in {self.funcx_dir}")
        backup = self.funcx_dir + ".bak"
        try:
            log.debug(f"Removing old backups in {backup}")
            shutil.rmtree(backup)
        except OSError:
            pass  # best effort: e.g. no earlier backup exists
        os.renames(self.funcx_dir, backup)

    # Never overwrite a config file that is (still) present.
    if os.path.exists(self.funcx_config_file):
        log.debug(f"Config file exists at {self.funcx_config_file}")
        return

    try:
        os.makedirs(self.funcx_dir, exist_ok=True)
    except Exception as exc:
        print(f"[FuncX] Caught exception during registration {exc}")

    shutil.copyfile(self.funcx_default_config_template, self.funcx_config_file)
"""Run toxicity inference with FuncX""" from funcx.sdk.client import FuncXClient import json # Get FuncX ready fxc = FuncXClient() theta_ep = 'd3a23590-3282-429a-8bce-e0ca0f4177f3' with open('func_uuid.json') as fp: func_id = json.load(fp) # Run the infernece smiles = ['C', 'CC', 'CCC'] task_id = fxc.run(smiles, endpoint_id=theta_ep, function_id=func_id) print(task_id)
with open(result_path, 'w') as f: print('problem loading processor instance for {}'.format( str(item)), file=f) print(e, file=f) print('environment:', file=f) for key, value in os.environ.items(): print('{}: {}'.format(key, value), file=f) print('hostname: {}'.format(subprocess.check_output('hostname', shell=True), file=f)) if stageout_url.startswith('root://'): command = 'xrdcp {} {}'.format(result_path, os.path.join(stageout_url, subdir)) subprocess.call(command, shell=True) os.unlink(result_path) return os.path.join(subdir, os.path.basename(result_path)) client = FuncXClient() uuids = {} for func in [process]: f = timeout(func) uuids[func.__name__] = client.register_function(f) with open('data/function_uuids.json', 'w') as f: f.write(json.dumps(uuids, indent=4, sort_keys=True))
import mdml_client as mdml # Connect to the MDML to query for data exp = mdml.experiment(params['experiment_id'], params['user'], params['pass'], params['host']) # Grabbing the latest temperature value query = [{"device": "DATA1", "variables": ["temperature"], "last": 1}] res = exp.query(query, verify_cert=False) # Running the query tempF = res['DATA1'][0]['temperature'] tempC = (tempF - 32) * (5 / 9) return {'time': mdml.unix_time(True), 'tempC': tempC} # Registering the function if args.register: from funcx.sdk.client import FuncXClient fxc = FuncXClient() funcx_func_uuid = fxc.register_function( basic_analysis, description="Temperature conversion") print(f'funcX UUID: {funcx_func_uuid}') else: # Use the most recent function funcx ID (manually put here after running --register once) funcx_func_uuid = "1712a2fc-cc40-4b2c-ae44-405d58f78c5d" # Sept 16th 2020 # Now that the function is ready for use, we need to start an experiment to use it with import sys import time sys.path.insert(1, '../') # using local mdml_client import mdml_client as mdml exp = mdml.experiment("TEST", args.username, args.password, args.host) exp.add_config(auto=True) exp.send_config()
def __init__(self, config,
             client_address="127.0.0.1",
             interchange_address="127.0.0.1",
             client_ports=(50055, 50056, 50057),
             worker_ports=None,
             worker_port_range=(54000, 55000),
             cores_per_worker=1.0,
             worker_debug=False,
             launch_cmd=None,
             heartbeat_threshold=60,
             logdir=".",
             logging_level=logging.INFO,
             poll_period=10,
             endpoint_id=None,
             suppress_failure=False,
             max_heartbeats_missed=2
             ):
    """Set up the interchange: client sockets, worker sockets, then load the config.

    Parameters
    ----------
    config : funcx.Config object
        Funcx config object that describes how compute should be provisioned.
        Its ``strategy`` attribute is stored as ``self.strategy``.
    client_address : str
        The ip address at which the parsl client can be reached. Default: "127.0.0.1"
    interchange_address : str
        The ip address at which the workers will be able to reach the Interchange.
        Default: "127.0.0.1"
    client_ports : triple(int, int, int)
        The ports at which the client can be reached: index 0 is the incoming
        task channel, index 1 the outgoing results channel and index 2 the
        command channel.
    launch_cmd : str
        Worker launch command template. When falsy, a default
        ``funcx-manager ...`` template is used.
    worker_ports : tuple(int, int)
        The specific two ports (task, result) at which workers will connect to
        the Interchange. Default: None
    worker_port_range : tuple(int, int)
        The interchange picks ports at random from the range which will be used
        by workers. This is overridden when the worker_ports option is set.
        Default: (54000, 55000)
    cores_per_worker : float
        cores to be assigned to each worker. Oversubscription is possible
        by setting cores_per_worker < 1.0. Default=1
        NOTE(review): not referenced anywhere in this constructor.
    worker_debug : Bool
        Enables worker debug logging.
        NOTE(review): not referenced anywhere in this constructor.
    heartbeat_threshold : int
        Number of seconds since the last heartbeat after which worker is
        considered lost.
    logdir : str
        Parsl log directory paths. Logs and temp files go here. Default: '.'
        Created if missing; ``interchange.log`` is written inside it.
    logging_level : int
        Logging level as defined in the logging module. Default: logging.INFO (20)
    endpoint_id : str
        Identity string that identifies the endpoint to the broker
    poll_period : int
        The main thread polling period, in milliseconds. Default: 10ms
    suppress_failure : Bool
        When set to True, the interchange will attempt to suppress failures.
        Default: False
    max_heartbeats_missed : int
        Number of heartbeats missed before setting kill_event

    Raises
    ------
    Exception
        Anything raised by ``self.load_config()`` is logged with its
        traceback and re-raised.
    """
    # Ensure the log directory exists (an existing directory is fine).
    self.logdir = logdir
    try:
        os.makedirs(self.logdir)
    except FileExistsError:
        pass

    start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level)
    logger.info("logger location {}".format(logger.handlers))
    logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id))
    self.config = config
    logger.info("Got config : {}".format(config))

    self.strategy = self.config.strategy
    self.client_address = client_address
    self.interchange_address = interchange_address
    self.suppress_failure = suppress_failure
    self.poll_period = poll_period

    self.serializer = FuncXSerializer()
    logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
        client_address, client_ports[0], client_ports[1], client_ports[2]))
    self.context = zmq.Context()

    # Client channel 1: tasks arrive here (DEALER, unbounded HWM, 10 ms recv timeout).
    self.task_incoming = self.context.socket(zmq.DEALER)
    self.task_incoming.set_hwm(0)
    self.task_incoming.RCVTIMEO = 10  # in milliseconds
    logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0]))
    self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))

    # Client channel 2: results are sent back to the client.
    self.results_outgoing = self.context.socket(zmq.DEALER)
    self.results_outgoing.set_hwm(0)
    logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1]))
    self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))

    # Client channel 3: command requests/replies (1 s recv timeout).
    self.command_channel = self.context.socket(zmq.DEALER)
    self.command_channel.RCVTIMEO = 1000  # in milliseconds
    # self.command_channel.set_hwm(0)
    logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2]))
    self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
    logger.info("Connected to client")

    self.pending_task_queue = {}
    self.containers = {}
    self.total_pending_task_count = 0
    self.fxs = FuncXClient()

    logger.info("Interchange address is {}".format(self.interchange_address))
    self.worker_ports = worker_ports
    self.worker_port_range = worker_port_range

    # Worker-facing ROUTER sockets: tasks out to workers, results in from them.
    self.task_outgoing = self.context.socket(zmq.ROUTER)
    self.task_outgoing.set_hwm(0)
    self.results_incoming = self.context.socket(zmq.ROUTER)
    self.results_incoming.set_hwm(0)

    # Initialize the last heartbeat time to start the loop
    self.last_heartbeat = time.time()
    self.max_heartbeats_missed = max_heartbeats_missed

    self.endpoint_id = endpoint_id
    # Bind to the explicit worker ports if given, otherwise to random ports
    # within worker_port_range (up to 100 attempts each).
    if self.worker_ports:
        self.worker_task_port = self.worker_ports[0]
        self.worker_result_port = self.worker_ports[1]

        self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port))
        self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port))

    else:
        self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*',
                                                                       min_port=worker_port_range[0],
                                                                       max_port=worker_port_range[1], max_tries=100)
        self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*',
                                                                            min_port=worker_port_range[0],
                                                                            max_port=worker_port_range[1], max_tries=100)

    logger.info("Bound to ports {},{} for incoming worker connections".format(
        self.worker_task_port, self.worker_result_port))

    self._ready_manager_queue = {}

    self.heartbeat_threshold = heartbeat_threshold
    self.blocks = {}  # type: Dict[str, str]
    self.block_id_map = {}
    self.launch_cmd = launch_cmd
    self.last_core_hr_counter = 0
    # Default manager launch template. Single-brace fields are filled by a
    # later .format() call; doubled braces ({{block_id}}, {{worker_type}})
    # survive that call as literal {block_id} / {worker_type} placeholders.
    if not launch_cmd:
        self.launch_cmd = ("funcx-manager {debug} {max_workers} "
                           "-c {cores_per_worker} "
                           "--poll {poll_period} "
                           "--task_url={task_url} "
                           "--result_url={result_url} "
                           "--logdir={logdir} "
                           "--block_id={{block_id}} "
                           "--hb_period={heartbeat_period} "
                           "--hb_threshold={heartbeat_threshold} "
                           "--worker_mode={worker_mode} "
                           "--scheduler_mode={scheduler_mode} "
                           "--worker_type={{worker_type}} ")

    self.current_platform = {'parsl_v': PARSL_VERSION,
                             'python_v': "{}.{}.{}".format(sys.version_info.major,
                                                           sys.version_info.minor,
                                                           sys.version_info.micro),
                             'os': platform.system(),
                             'hname': platform.node(),
                             'dir': os.getcwd()}

    logger.info("Platform info: {}".format(self.current_platform))
    self._block_counter = 0
    # Any config-loading failure is logged with traceback and re-raised so
    # construction fails loudly.
    try:
        self.load_config()
    except Exception as e:
        logger.exception("Caught exception")
        raise
class Interchange(object): """ Interchange is a task orchestrator for distributed systems. 1. Asynchronously queue large volume of tasks (>100K) 2. Allow for workers to join and leave the union 3. Detect workers that have failed using heartbeats 4. Service single and batch requests from workers 5. Be aware of requests worker resource capacity, eg. schedule only jobs that fit into walltime. TODO: We most likely need a PUB channel to send out global commands, like shutdown """ def __init__(self, config, client_address="127.0.0.1", interchange_address="127.0.0.1", client_ports=(50055, 50056, 50057), worker_ports=None, worker_port_range=(54000, 55000), cores_per_worker=1.0, worker_debug=False, launch_cmd=None, heartbeat_threshold=60, logdir=".", logging_level=logging.INFO, poll_period=10, endpoint_id=None, suppress_failure=False, max_heartbeats_missed=2 ): """ Parameters ---------- config : funcx.Config object Funcx config object that describes how compute should be provisioned client_address : str The ip address at which the parsl client can be reached. Default: "127.0.0.1" interchange_address : str The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" client_ports : triple(int, int, int) The ports at which the client can be reached launch_cmd : str TODO : update worker_ports : tuple(int, int) The specific two ports at which workers will connect to the Interchange. Default: None worker_port_range : tuple(int, int) The interchange picks ports at random from the range which will be used by workers. This is overridden when the worker_ports option is set. Defauls: (54000, 55000) cores_per_worker : float cores to be assigned to each worker. Oversubscription is possible by setting cores_per_worker < 1.0. Default=1 worker_debug : Bool Enables worker debug logging. heartbeat_threshold : int Number of seconds since the last heartbeat after which worker is considered lost. logdir : str Parsl log directory paths. 
Logs and temp files go here. Default: '.' logging_level : int Logging level as defined in the logging module. Default: logging.INFO (20) endpoint_id : str Identity string that identifies the endpoint to the broker poll_period : int The main thread polling period, in milliseconds. Default: 10ms suppress_failure : Bool When set to True, the interchange will attempt to suppress failures. Default: False max_heartbeats_missed : int Number of heartbeats missed before setting kill_event """ self.logdir = logdir try: os.makedirs(self.logdir) except FileExistsError: pass start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level) logger.info("logger location {}".format(logger.handlers)) logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id)) self.config = config logger.info("Got config : {}".format(config)) self.strategy = self.config.strategy self.client_address = client_address self.interchange_address = interchange_address self.suppress_failure = suppress_failure self.poll_period = poll_period self.serializer = FuncXSerializer() logger.info("Attempting connection to client at {} on ports: {},{},{}".format( client_address, client_ports[0], client_ports[1], client_ports[2])) self.context = zmq.Context() self.task_incoming = self.context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.RCVTIMEO = 10 # in milliseconds logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0])) self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) self.results_outgoing = self.context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1])) self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) self.command_channel = self.context.socket(zmq.DEALER) self.command_channel.RCVTIMEO = 1000 # in milliseconds # self.command_channel.set_hwm(0) logger.info("Command 
channel on tcp://{}:{}".format(client_address, client_ports[2])) self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) logger.info("Connected to client") self.pending_task_queue = {} self.containers = {} self.total_pending_task_count = 0 self.fxs = FuncXClient() logger.info("Interchange address is {}".format(self.interchange_address)) self.worker_ports = worker_ports self.worker_port_range = worker_port_range self.task_outgoing = self.context.socket(zmq.ROUTER) self.task_outgoing.set_hwm(0) self.results_incoming = self.context.socket(zmq.ROUTER) self.results_incoming.set_hwm(0) # initalize the last heartbeat time to start the loop self.last_heartbeat = time.time() self.max_heartbeats_missed = max_heartbeats_missed self.endpoint_id = endpoint_id if self.worker_ports: self.worker_task_port = self.worker_ports[0] self.worker_result_port = self.worker_ports[1] self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port)) self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port)) else: self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*', min_port=worker_port_range[0], max_port=worker_port_range[1], max_tries=100) self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*', min_port=worker_port_range[0], max_port=worker_port_range[1], max_tries=100) logger.info("Bound to ports {},{} for incoming worker connections".format( self.worker_task_port, self.worker_result_port)) self._ready_manager_queue = {} self.heartbeat_threshold = heartbeat_threshold self.blocks = {} # type: Dict[str, str] self.block_id_map = {} self.launch_cmd = launch_cmd self.last_core_hr_counter = 0 if not launch_cmd: self.launch_cmd = ("funcx-manager {debug} {max_workers} " "-c {cores_per_worker} " "--poll {poll_period} " "--task_url={task_url} " "--result_url={result_url} " "--logdir={logdir} " "--block_id={{block_id}} " "--hb_period={heartbeat_period} " "--hb_threshold={heartbeat_threshold} " 
"--worker_mode={worker_mode} " "--scheduler_mode={scheduler_mode} " "--worker_type={{worker_type}} ") self.current_platform = {'parsl_v': PARSL_VERSION, 'python_v': "{}.{}.{}".format(sys.version_info.major, sys.version_info.minor, sys.version_info.micro), 'os': platform.system(), 'hname': platform.node(), 'dir': os.getcwd()} logger.info("Platform info: {}".format(self.current_platform)) self._block_counter = 0 try: self.load_config() except Exception as e: logger.exception("Caught exception") raise def load_config(self): """ Load the config """ logger.info("Loading endpoint local config") working_dir = self.config.working_dir if self.config.working_dir is None: working_dir = "{}/{}".format(self.logdir, "worker_logs") logger.info("Setting working_dir: {}".format(working_dir)) self.config.provider.script_dir = working_dir if hasattr(self.config.provider, 'channel'): self.config.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts') self.config.provider.channel.makedirs(self.config.provider.channel.script_dir, exist_ok=True) os.makedirs(self.config.provider.script_dir, exist_ok=True) debug_opts = "--debug" if self.config.worker_debug else "" max_workers = "" if self.config.max_workers_per_node == float('inf') \ else "--max_workers={}".format(self.config.max_workers_per_node) worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}" worker_result_url = f"tcp://{self.interchange_address}:{self.worker_result_port}" l_cmd = self.launch_cmd.format(debug=debug_opts, max_workers=max_workers, cores_per_worker=self.config.cores_per_worker, #mem_per_worker=self.config.mem_per_worker, prefetch_capacity=self.config.prefetch_capacity, task_url=worker_task_url, result_url=worker_result_url, nodes_per_block=self.config.provider.nodes_per_block, heartbeat_period=self.config.heartbeat_period, heartbeat_threshold=self.config.heartbeat_threshold, poll_period=self.config.poll_period, worker_mode=self.config.worker_mode, 
scheduler_mode=self.config.scheduler_mode, logdir=working_dir) self.launch_cmd = l_cmd logger.info("Launch command: {}".format(self.launch_cmd)) if self.config.scaling_enabled: logger.info("Scaling ...") self.scale_out(self.config.provider.init_blocks) def get_tasks(self, count): """ Obtains a batch of tasks from the internal pending_task_queue Parameters ---------- count: int Count of tasks to get from the queue Returns ------- List of upto count tasks. May return fewer than count down to an empty list eg. [{'task_id':<x>, 'buffer':<buf>} ... ] """ tasks = [] for i in range(0, count): try: x = self.pending_task_queue.get(block=False) except queue.Empty: break else: tasks.append(x) return tasks def migrate_tasks_to_internal(self, kill_event, status_request): """Pull tasks from the incoming tasks 0mq pipe onto the internal pending task queue Parameters: ----------- kill_event : threading.Event Event to let the thread know when it is time to die. """ logger.info("[TASK_PULL_THREAD] Starting") task_counter = 0 poller = zmq.Poller() poller.register(self.task_incoming, zmq.POLLIN) while not kill_event.is_set(): # Check when the last heartbeat was. # logger.debug(f"[TASK_PULL_THREAD] Last heartbeat: {self.last_heartbeat}") if int(time.time() - self.last_heartbeat) > (self.heartbeat_threshold * self.max_heartbeats_missed): logger.critical("[TASK_PULL_THREAD] Missed too many heartbeats. 
Setting kill event.") kill_event.set() break try: msg = self.task_incoming.recv_pyobj() self.last_heartbeat = time.time() except zmq.Again: # We just timed out while attempting to receive logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count)) continue if msg == 'STOP': kill_event.set() break elif msg == 'STATUS_REQUEST': logger.info("Got STATUS_REQUEST") status_request.set() else: logger.info("[TASK_PULL_THREAD] Received task:{}".format(msg)) task_type = self.get_container(msg['task_id'].split(";")[1]) msg['container'] = task_type if task_type not in self.pending_task_queue: self.pending_task_queue[task_type] = queue.Queue(maxsize=10 ** 6) self.pending_task_queue[task_type].put(msg) self.total_pending_task_count += 1 logger.debug("[TASK_PULL_THREAD] pending task count: {}".format(self.total_pending_task_count)) task_counter += 1 logger.debug("[TASK_PULL_THREAD] Fetched task:{}".format(task_counter)) def get_container(self, container_uuid): """ Get the container image location if it is not known to the interchange""" if container_uuid not in self.containers: if container_uuid == 'RAW' or not container_uuid: self.containers[container_uuid] = 'RAW' else: try: container = self.fxs.get_container(container_uuid, self.config.container_type) except Exception: logger.exception("[FETCH_CONTAINER] Unable to resolve container location") self.containers[container_uuid] = 'RAW' else: logger.info("[FETCH_CONTAINER] Got container info: {}".format(container)) self.containers[container_uuid] = container.get('location', 'RAW') return self.containers[container_uuid] def get_total_tasks_outstanding(self): """ Get the outstanding tasks in total """ outstanding = {} for task_type in self.pending_task_queue: outstanding[task_type] = outstanding.get(task_type, 0) + self.pending_task_queue[task_type].qsize() for manager in self._ready_manager_queue: for task_type in self._ready_manager_queue[manager]['tasks']: outstanding[task_type] = 
outstanding.get(task_type, 0) + len(self._ready_manager_queue[manager]['tasks'][task_type]) return outstanding def get_total_live_workers(self): """ Get the total active workers """ active = 0 for manager in self._ready_manager_queue: if self._ready_manager_queue[manager]['active']: active += self._ready_manager_queue[manager]['max_worker_count'] return active def get_outstanding_breakdown(self): """ Get outstanding breakdown per manager and in the interchange queues Returns ------- List of status for online elements [ (element, tasks_pending, status) ... ] """ pending_on_interchange = self.total_pending_task_count # Reporting pending on interchange is a deviation from Parsl reply = [('interchange', pending_on_interchange, True)] for manager in self._ready_manager_queue: resp = (manager.decode('utf-8'), sum([len(tids) for tids in self._ready_manager_queue[manager]['tasks'].values()]), self._ready_manager_queue[manager]['active']) reply.append(resp) return reply def _hold_block(self, block_id): """ Sends hold command to all managers which are in a specific block Parameters ---------- block_id : str Block identifier of the block to be put on hold """ for manager in self._ready_manager_queue: if self._ready_manager_queue[manager]['active'] and \ self._ready_manager_queue[manager]['block_id'] == block_id: logger.debug("[HOLD_BLOCK]: Sending hold to manager: {}".format(manager)) self.hold_manager(manager) def hold_manager(self, manager): """ Put manager on hold Parameters ---------- manager : str Manager id to be put on hold while being killed """ if manager in self._ready_manager_queue: self._ready_manager_queue[manager]['active'] = False reply = True else: reply = False def _command_server(self, kill_event): """ Command server to run async command to the interchange """ logger.debug("[COMMAND] Command Server Starting") while not kill_event.is_set(): try: command_req = self.command_channel.recv_pyobj() logger.debug("[COMMAND] Received command request: 
{}".format(command_req)) if command_req == "OUTSTANDING_C": reply = self.get_total_outstanding() elif command_req == "MANAGERS": reply = self.get_outstanding_breakdown() elif command_req.startswith("HOLD_WORKER"): cmd, s_manager = command_req.split(';') manager = s_manager.encode('utf-8') logger.info("[CMD] Received HOLD_WORKER for {}".format(manager)) if manager in self._ready_manager_queue: self._ready_manager_queue[manager]['active'] = False reply = True else: reply = False elif command_req == "HEARTBEAT": logger.info("[CMD] Received heartbeat message from hub") reply = "HBT,{}".format(self.endpoint_id) elif command_req == "SHUTDOWN": logger.info("[CMD] Received SHUTDOWN command") kill_event.set() reply = True else: reply = None logger.debug("[COMMAND] Reply: {}".format(reply)) self.command_channel.send_pyobj(reply) except zmq.Again: logger.debug("[COMMAND] is alive") continue def stop(self): """Prepare the interchange for shutdown""" self._kill_event.set() self._task_puller_thread.join() self._command_thread.join() def start(self, poll_period=None): """ Start the Interchange Parameters: ---------- poll_period : int poll_period in milliseconds """ logger.info("Incoming ports bound") if poll_period is None: poll_period = self.poll_period start = time.time() count = 0 self._kill_event = threading.Event() self._status_request = threading.Event() self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal, args=(self._kill_event, self._status_request, )) self._task_puller_thread.start() self._command_thread = threading.Thread(target=self._command_server, args=(self._kill_event, )) self._command_thread.start() try: logger.debug("Starting strategy.") self.strategy.start(self) except RuntimeError as e: # This is raised when re-registering an endpoint as strategy already exists logger.debug("Failed to start strategy.") logger.info(e) poller = zmq.Poller() # poller.register(self.task_incoming, zmq.POLLIN) poller.register(self.task_outgoing, 
zmq.POLLIN) poller.register(self.results_incoming, zmq.POLLIN) # These are managers which we should examine in an iteration # for scheduling a job (or maybe any other attention?). # Anything altering the state of the manager should add it # onto this list. interesting_managers = set() while not self._kill_event.is_set(): self.socks = dict(poller.poll(timeout=poll_period)) # Listen for requests for work if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN: logger.debug("[MAIN] starting task_outgoing section") message = self.task_outgoing.recv_multipart() manager = message[0] if manager not in self._ready_manager_queue: reg_flag = False try: msg = json.loads(message[1].decode('utf-8')) reg_flag = True except Exception: logger.warning("[MAIN] Got a non-json registration message from manager:{}".format( manager)) logger.debug("[MAIN] Message :\n{}\n".format(message)) # By default we set up to ignore bad nodes/registration messages. self._ready_manager_queue[manager] = {'last': time.time(), 'reg_time': time.time(), 'free_capacity': {'total_workers': 0}, 'max_worker_count': 0, 'active': True, 'tasks': collections.defaultdict(set), 'total_tasks': 0} if reg_flag is True: interesting_managers.add(manager) logger.info("[MAIN] Adding manager: {} to ready queue".format(manager)) self._ready_manager_queue[manager].update(msg) logger.info("[MAIN] Registration info for manager {}: {}".format(manager, msg)) if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or msg['parsl_v'] != self.current_platform['parsl_v']): logger.warn("[MAIN] Manager {} has incompatible version info with the interchange".format(manager)) if self.suppress_failure is False: logger.debug("Setting kill event") self._kill_event.set() e = ManagerLost(manager) result_package = {'task_id': -1, 'exception': self.serializer.serialize(e)} pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) logger.warning("[MAIN] 
Sent failure reports, unregistering manager") else: logger.debug("[MAIN] Suppressing shutdown due to version incompatibility") else: # Registration has failed. if self.suppress_failure is False: logger.debug("Setting kill event for bad manager") self._kill_event.set() e = BadRegistration(manager, critical=True) result_package = {'task_id': -1, 'exception': self.serializer.serialize(e)} pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) else: logger.debug("[MAIN] Suppressing bad registration from manager:{}".format( manager)) else: self._ready_manager_queue[manager]['last'] = time.time() if message[1] == b'HEARTBEAT': logger.debug("[MAIN] Manager {} sends heartbeat".format(manager)) self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE]) else: manager_adv = pickle.loads(message[1]) logger.debug("[MAIN] Manager {} requested {}".format(manager, manager_adv)) self._ready_manager_queue[manager]['free_capacity'].update(manager_adv) self._ready_manager_queue[manager]['free_capacity']['total_workers'] = sum(manager_adv.values()) interesting_managers.add(manager) # If we had received any requests, check if there are tasks that could be passed logger.debug("[MAIN] Managers count (total/interesting): {}/{}".format( len(self._ready_manager_queue), len(interesting_managers))) task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers, self.pending_task_queue, self._ready_manager_queue, scheduler_mode=self.config.scheduler_mode) self.total_pending_task_count -= dispatched_task for manager in task_dispatch: tasks = task_dispatch[manager] if tasks: logger.info("[MAIN] Sending task message {} to manager {}".format(tasks, manager)) self.task_outgoing.send_multipart([manager, b'', pickle.dumps(tasks)]) # Receive any results and forward to client if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN: logger.debug("[MAIN] entering results_incoming section") manager, *b_messages 
= self.results_incoming.recv_multipart() if manager not in self._ready_manager_queue: logger.warning("[MAIN] Received a result from a un-registered manager: {}".format(manager)) else: logger.info("[MAIN] Got {} result items in batch".format(len(b_messages))) for b_message in b_messages: r = pickle.loads(b_message) # logger.debug("[MAIN] Received result for task {} from {}".format(r['task_id'], manager)) task_type = self.containers[r['task_id'].split(';')[1]] self._ready_manager_queue[manager]['tasks'][task_type].remove(r['task_id']) self._ready_manager_queue[manager]['total_tasks'] -= len(b_messages) self.results_outgoing.send_multipart(b_messages) logger.debug("[MAIN] Current tasks: {}".format(self._ready_manager_queue[manager]['tasks'])) logger.debug("[MAIN] leaving results_incoming section") # logger.debug("[MAIN] entering bad_managers section") bad_managers = [manager for manager in self._ready_manager_queue if time.time() - self._ready_manager_queue[manager]['last'] > self.heartbeat_threshold] for manager in bad_managers: logger.debug("[MAIN] Last: {} Current: {}".format(self._ready_manager_queue[manager]['last'], time.time())) logger.warning("[MAIN] Too many heartbeats missed for manager {}".format(manager)) e = ManagerLost(manager) for task_type in self._ready_manager_queue[manager]['tasks']: for tid in self._ready_manager_queue[manager]['tasks'][task_type]: result_package = {'task_id': tid, 'exception': self.serializer.serialize(e)} pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) logger.warning("[MAIN] Sent failure reports, unregistering manager") self._ready_manager_queue.pop(manager, 'None') if manager in interesting_managers: interesting_managers.remove(manager) logger.debug("[MAIN] ending one main loop iteration") if self._status_request.is_set(): logger.info("status request response") result_package = self.get_status_report() pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) 
logger.info("[MAIN] Sent info response") self._status_request.clear() delta = time.time() - start logger.info("Processed {} tasks in {} seconds".format(count, delta)) logger.warning("Exiting") def get_status_report(self): """ Get utilization numbers """ total_cores = 0 total_mem = 0 core_hrs = 0 active_managers = 0 free_capacity = 0 outstanding_tasks = self.get_total_tasks_outstanding() pending_tasks = self.total_pending_task_count num_managers = len(self._ready_manager_queue) live_workers = self.get_total_live_workers() for manager in self._ready_manager_queue: total_cores += self._ready_manager_queue[manager]['cores'] total_mem += self._ready_manager_queue[manager]['mem'] active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time']) core_hrs += (active_dur * total_cores) / 3600 if self._ready_manager_queue[manager]['active']: active_managers += 1 free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers'] result_package = {'task_id': -2, 'info': {'total_cores': total_cores, 'total_mem' : total_mem, 'new_core_hrs': core_hrs - self.last_core_hr_counter, 'total_core_hrs': round(core_hrs, 2), 'managers': num_managers, 'active_managers': active_managers, 'total_workers': live_workers, 'idle_workers': free_capacity, 'pending_tasks': pending_tasks, 'outstanding_tasks': outstanding_tasks, 'worker_mode': self.config.worker_mode, 'scheduler_mode': self.config.scheduler_mode, 'scaling_enabled': self.config.scaling_enabled, 'mem_per_worker': self.config.mem_per_worker, 'cores_per_worker': self.config.cores_per_worker, 'prefetch_capacity': self.config.prefetch_capacity, 'max_blocks': self.config.provider.max_blocks, 'min_blocks': self.config.provider.min_blocks, 'max_workers_per_node': self.config.max_workers_per_node, 'nodes_per_block': self.config.provider.nodes_per_block }} self.last_core_hr_counter = core_hrs return result_package def scale_out(self, blocks=1, task_type=None): """Scales out the number of blocks by "blocks" 
Raises: NotImplementedError """ r = [] for i in range(blocks): if self.config.provider: self._block_counter += 1 external_block_id = str(self._block_counter) if not task_type and self.config.scheduler_mode == 'hard': launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW') else: launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type) if not task_type: internal_block = self.config.provider.submit(launch_cmd, 1) else: internal_block = self.config.provider.submit(launch_cmd, 1, task_type) logger.debug("Launched block {}->{}".format(external_block_id, internal_block)) if not internal_block: raise(ScalingFailed(self.provider.label, "Attempts to provision nodes via provider has failed")) self.blocks[external_block_id] = internal_block self.block_id_map[internal_block] = external_block_id else: logger.error("No execution provider available") r = None return r def scale_in(self, blocks=None, block_ids=[], task_type=None): """Scale in the number of active blocks by specified amount. Parameters ---------- blocks : int # of blocks to terminate block_ids : [str.. ] List of external block ids to terminate """ if task_type: logger.info("Scaling in blocks of specific task type {}. 
Let the provider decide which to kill".format(task_type)) if self.config.scaling_enabled and self.config.provider: to_kill, r = self.config.provider.cancel(blocks, task_type) logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r)) for job in to_kill: logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job)) block_id = self.block_id_map[job] logger.info("[scale_in] Holding block {}".format(block_id)) self._hold_block(block_id) self.blocks.pop(block_id) return r if block_ids: block_ids_to_kill = block_ids else: block_ids_to_kill = list(self.blocks.keys())[:blocks] # Try a polite terminate # TODO : Missing logic to hold blocks for block_id in block_ids_to_kill: self._hold_block(block_id) # Now kill via provider to_kill = [self.blocks.pop(bid) for bid in block_ids_to_kill] if self.config.scaling_enabled and self.config.provider: r = self.config.provider.cancel(to_kill) return r def provider_status(self): """ Get status of all blocks from the provider """ status = [] if self.config.provider: logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values()))) status = self.config.provider.status(list(self.blocks.values())) logger.debug("[MAIN] The status is {}".format(status)) return status
def start_endpoint(self, name, endpoint_uuid, endpoint_config):
    """Prepare, register (best effort) and launch the endpoint ``name`` as a daemon.

    Creates the endpoint's certificates, checks for an existing daemon via its
    pidfile, and loads/persists outstanding results. It then builds the daemon
    context, attempts registration with the funcX service, and finally enters
    the daemon context and calls ``self.daemon_launch``.

    Parameters
    ----------
    name : str
        Endpoint name; the endpoint directory is ``<self.funcx_dir>/<name>``.
    endpoint_uuid : str
        Passed to ``check_endpoint_json``, whose return value is the uuid used.
    endpoint_config :
        Object whose ``config`` attribute holds the endpoint config; it must
        define at least one executor.

    Raises
    ------
    Exception
        If the config defines no executors.

    Registration errors with HTTP status >= 500 are logged and launch continues
    (registration is retried in the daemon). Other registration errors are
    re-raised. The process exits with -1 if an endpoint daemon is already
    active, outstanding results cannot be loaded, the daemon context cannot be
    created, or the funcX service is unreachable.
    """
    self.name = name

    endpoint_dir = os.path.join(self.funcx_dir, self.name)
    endpoint_json = os.path.join(endpoint_dir, "endpoint.json")

    # These certs need to be recreated for every registration
    keys_dir = os.path.join(endpoint_dir, "certificates")
    os.makedirs(keys_dir, exist_ok=True)
    client_public_file, client_secret_file = zmq.auth.create_certificates(
        keys_dir, "endpoint")
    client_public_key, _ = zmq.auth.load_certificate(client_public_file)
    client_public_key = client_public_key.decode("utf-8")

    # This is to ensure that at least 1 executor is defined
    if not endpoint_config.config.executors:
        raise Exception(
            f"Endpoint config file at {endpoint_dir} is missing "
            "executor definitions")

    funcx_client_options = {
        "funcx_service_address": endpoint_config.config.funcx_service_address,
        "check_endpoint_version": True,
    }
    funcx_client = FuncXClient(**funcx_client_options)

    endpoint_uuid = self.check_endpoint_json(endpoint_json, endpoint_uuid)

    log.info(f"Starting endpoint with uuid: {endpoint_uuid}")

    pid_file = os.path.join(endpoint_dir, "daemon.pid")
    pid_check = self.check_pidfile(pid_file)
    # if the pidfile exists, we should return early because we don't
    # want to attempt to create a new daemon when one is already
    # potentially running with the existing pidfile
    if pid_check["exists"]:
        if pid_check["active"]:
            log.info("Endpoint is already active")
            sys.exit(-1)
        else:
            log.info(
                "A prior Endpoint instance appears to have been terminated without "
                "proper cleanup. Cleaning up now.")
            self.pidfile_cleanup(pid_file)

    results_ack_handler = ResultsAckHandler(endpoint_dir=endpoint_dir)

    try:
        results_ack_handler.load()
        results_ack_handler.persist()
    except Exception:
        log.exception(
            "Caught exception while attempting load and persist of outstanding "
            "results")
        sys.exit(-1)

    # Create a daemon context
    # If we are running a full detached daemon then we will send the output to
    # log files, otherwise we can piggy back on our stdout
    if endpoint_config.config.detach_endpoint:
        stdout = open(
            os.path.join(endpoint_dir, endpoint_config.config.stdout), "a+")
        stderr = open(
            os.path.join(endpoint_dir, endpoint_config.config.stderr), "a+")
    else:
        stdout = sys.stdout
        stderr = sys.stderr

    try:
        context = daemon.DaemonContext(
            working_directory=endpoint_dir,
            umask=0o002,
            pidfile=daemon.pidfile.PIDLockFile(pid_file),
            stdout=stdout,
            stderr=stderr,
            detach_process=endpoint_config.config.detach_endpoint,
        )

    except Exception:
        log.exception(
            "Caught exception while trying to setup endpoint context dirs")
        sys.exit(-1)

    # place registration after everything else so that the endpoint will
    # only be registered if everything else has been set up successfully
    reg_info = None
    try:
        reg_info = register_endpoint(funcx_client, endpoint_uuid, endpoint_dir,
                                     self.name)
    # if the service sends back an error response, it will be a FuncxResponseError
    except FuncxResponseError as e:
        # an example of an error that could conceivably occur here would be
        # if the service could not register this endpoint with the forwarder
        # because the forwarder was unreachable
        if e.http_status_code >= 500:
            log.exception(
                "Caught exception while attempting endpoint registration")
            log.critical(
                "Endpoint registration will be retried in the new endpoint daemon "
                "process. The endpoint will not work until it is successfully "
                "registered.")
        else:
            raise
    # if the service has an unexpected internal error and is unable to send
    # back a FuncxResponseError
    except GlobusAPIError as e:
        if e.http_status >= 500:
            log.exception(
                "Caught exception while attempting endpoint registration")
            log.critical(
                "Endpoint registration will be retried in the new endpoint daemon "
                "process. The endpoint will not work until it is successfully "
                "registered.")
        else:
            raise
    # if the service is unreachable due to a timeout or connection error
    except NetworkError as e:
        # the output of a NetworkError exception is huge and unhelpful, so
        # it seems better to just stringify it here and get a concise error
        log.exception(
            f"Caught exception while attempting endpoint registration: {e}"
        )
        log.critical(
            "funcx-endpoint is unable to reach the funcX service due to a "
            "NetworkError \n"
            "Please make sure that the funcX service address you provided is "
            "reachable \n"
            "and then attempt restarting the endpoint")
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is absent when Python runs without `site`.
        sys.exit(-1)

    if reg_info:
        log.info("Launching endpoint daemon process")
    else:
        log.critical(
            "Launching endpoint daemon process with errors noted above")

    # NOTE
    # It's important that this log is emitted before we enter the daemon context
    # because daemonization closes down everything, a log message inside the
    # context won't write the currently configured loggers
    logfile = os.path.join(endpoint_dir, "endpoint.log")
    log.info(
        "Logging will be reconfigured for the daemon. logfile=%s , debug=%s",
        logfile,
        self.debug,
    )

    with context:
        setup_logging(logfile=logfile, debug=self.debug, console_enabled=False)
        self.daemon_launch(
            endpoint_uuid,
            endpoint_dir,
            keys_dir,
            endpoint_config,
            reg_info,
            funcx_client_options,
            results_ack_handler,
        )
from funcx.sdk.client import FuncXClient


def funcx_test():
    """Print "Viana" forever.

    Deliberately non-terminating payload. NOTE(review): presumably meant to
    exercise long-running/hung task handling on the endpoint — confirm.
    """
    while True:
        print("Viana")


if __name__ == "__main__":
    # Client creation and submission only happen when run as a script, not on import.
    fxc = FuncXClient()
    func_uuid = fxc.register_function(funcx_test)
    tutorial_endpoint = '70d29c21-66c3-4ba8-98fc-91490b522699'  # Public tutorial endpoint
    res = fxc.run(endpoint_id=tutorial_endpoint, function_id=func_uuid)
    # Report the submitted task instead of calling funcx_test() locally,
    # which would never return.
    print(res)
return sum(event) """ @funcx.register(description="...") def sum_yadu_new01(event): return sum(event) """ def test(fxc, ep_id): fn_uuid = fxc.register_function( sum_yadu_new01, ep_id, # TODO: We do not need ep id here description="New sum function defined without string spec", ) print("FN_UUID : ", fn_uuid) res = fxc.run([1, 2, 3, 99], endpoint_id=ep_id, function_id=fn_uuid) print(res) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) args = parser.parse_args() fxc = FuncXClient() test(fxc, args.endpoint)
from funcx.sdk.client import FuncXClient


def _main():
    """Build a client, show it, and register an endpoint named "foobar"."""
    client = FuncXClient()
    print(client)
    client.register_endpoint("foobar", None)


if __name__ == "__main__":
    _main()
def __init__(
    self,
    config,
    client_address="127.0.0.1",
    interchange_address="127.0.0.1",
    client_ports: Tuple[int, int, int] = (50055, 50056, 50057),
    launch_cmd=None,
    logdir=".",
    endpoint_id=None,
    keys_dir=".curve",
    suppress_failure=True,
    endpoint_dir=".",
    endpoint_name="default",
    reg_info=None,
    funcx_client_options=None,
    results_ack_handler=None,
):
    """Initialize the endpoint interchange's state and load ``config``.

    Parameters
    ----------
    config : funcx.Config object
         Funcx config object that describes how compute should be provisioned

    client_address : str
         The ip address at which the parsl client can be reached. Default: "127.0.0.1"

    interchange_address : str
         The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1"

    client_ports : Tuple[int, int, int]
         The ports at which the client can be reached

    launch_cmd : str
         Currently unused by this initializer (it is not stored on the instance).

    logdir : str
         Parsl log directory paths. Logs and temp files go here. Default: '.'

    keys_dir : str
         Directory from where keys used for communicating with the funcX service (forwarders)
         are stored

    endpoint_id : str
         Identity string that identifies the endpoint to the broker

    suppress_failure : Bool
         When set to True, the interchange will attempt to suppress failures. Default: True

    endpoint_dir : str
         Endpoint directory path to store registration info in

    endpoint_name : str
         Name of endpoint

    reg_info : Dict
         Registration info from initial registration on endpoint start, if it succeeded.
         When given, it is applied immediately and initial registration is marked complete.

    funcx_client_options : Dict
         FuncXClient initialization options (default: none)

    results_ack_handler : ResultsAckHandler
         Handler for outstanding (not yet acknowledged) results; stored as-is.
    """
    self.logdir = logdir
    log.info(
        "Initializing EndpointInterchange process with Endpoint ID: {}".format(endpoint_id))
    self.config = config
    log.info(f"Got config: {config}")

    self.client_address = client_address
    self.interchange_address = interchange_address
    self.client_ports = client_ports
    self.suppress_failure = suppress_failure

    self.endpoint_dir = endpoint_dir
    self.endpoint_name = endpoint_name

    if funcx_client_options is None:
        funcx_client_options = {}
    self.funcx_client = FuncXClient(**funcx_client_options)

    self.initial_registration_complete = False
    if reg_info:
        self.initial_registration_complete = True
        self.apply_reg_info(reg_info)

    self.heartbeat_period = self.config.heartbeat_period
    self.heartbeat_threshold = self.config.heartbeat_threshold
    # initialize the last heartbeat time to start the loop
    self.last_heartbeat = time.time()
    self.keys_dir = keys_dir
    self.serializer = FuncXSerializer()

    self.pending_task_queue = Queue()
    self.containers = {}
    self.total_pending_task_count = 0

    self._quiesce_event = threading.Event()
    self._kill_event = threading.Event()

    self.results_ack_handler = results_ack_handler

    log.info(f"Interchange address is {self.interchange_address}")

    self.endpoint_id = endpoint_id

    # Version/host info describing this interchange.
    self.current_platform = {
        "parsl_v": PARSL_VERSION,
        "python_v": "{}.{}.{}".format(sys.version_info.major,
                                      sys.version_info.minor,
                                      sys.version_info.micro),
        "libzmq_v": zmq.zmq_version(),
        "pyzmq_v": zmq.pyzmq_version(),
        "os": platform.system(),
        "hname": platform.node(),
        "funcx_sdk_version": funcx_sdk_version,
        "funcx_endpoint_version": funcx_endpoint_version,
        "registration": self.endpoint_id,
        "dir": os.getcwd(),
    }

    log.info(f"Platform info: {self.current_platform}")
    try:
        self.load_config()
    except Exception:
        log.exception("Caught exception")
        raise

    self.tasks = set()
    self.task_status_deltas = {}

    self._test_start = False
# Batch-submission smoke test: registers three small arithmetic functions and
# prepares a funcX batch for one endpoint. The loop that fills the batch is
# truncated in this chunk; presumably it adds task_count calls per function.
from funcx.sdk.client import FuncXClient
import time

fx = FuncXClient()


def test_batch1(a, b, c=2, d=2):
    """Return a + b + c + d."""
    return a + b + c + d


def test_batch2(a, b, c=2, d=2):
    """Return a * b * c * d."""
    return a * b * c * d


def test_batch3(a, b, c=2, d=2):
    """Return a + 2*b + 3*c + 4*d."""
    return a + 2 * b + 3 * c + 4 * d


funcs = [test_batch1, test_batch2, test_batch3]
# Register each function once; func_ids keeps the same order as funcs.
func_ids = []
for func in funcs:
    func_ids.append(fx.register_function(func, description='test'))

# Endpoint the batch is presumably submitted to (submission not visible here).
ep_id = '4b116d3c-1703-4f8f-9f6f-39921e5864df'
print("FN_UUID : ", func_ids)

# Wall-clock start; presumably used for timing later in the script.
start = time.time()
task_count = 5
batch = fx.create_batch()
for func_id in func_ids:
    for i in range(task_count):
def __init__(self, dlh_authorizer: Optional[GlobusAuthorizer] = None,
             search_authorizer: Optional[GlobusAuthorizer] = None,
             fx_authorizer: Optional[GlobusAuthorizer] = None,
             openid_authorizer: Optional[GlobusAuthorizer] = None,
             http_timeout: Optional[int] = None, force_login: bool = False, **kwargs):
    """Initialize the client

    Args:
        dlh_authorizer (:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`):
            An authorizer instance used to communicate with DLHub.
            If ``None``, will be created from your account's credentials.
        search_authorizer (:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`):
            An authorizer instance used to communicate with Globus Search.
            If ``None``, will be created from your account's credentials.
        fx_authorizer (:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`):
            An authorizer instance used to communicate with funcX.
            If ``None``, will be created from your account's credentials.
        openid_authorizer (:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`):
            An authorizer instance used to communicate with OpenID.
            If ``None``, will be created from your account's credentials.
        http_timeout (int): Timeout for any call to service in seconds. (default is no timeout)
        force_login (bool): Whether to discard saved tokens and log in again to get new
            credentials. A login always occurs if any of the four authorizers is not provided;
            if all four are provided, no login occurs and this flag has no effect.
        no_local_server (bool): Disable spinning up a local server to automatically
            copy-paste the auth code. THIS IS REQUIRED if you are on a remote server.
            When used locally with no_local_server=False, the domain is localhost with
            a randomly chosen open port number.
            **Default**: ``True``.
        no_browser (bool): Do not automatically open the browser for the Globus Auth URL.
            Display the URL instead and let the user navigate to that location manually.
            **Default**: ``True``.

    Keyword arguments are the same as for :class:`BaseClient <globus_sdk.base.BaseClient>`.
    """

    authorizers = [dlh_authorizer, search_authorizer, openid_authorizer, fx_authorizer]
    # Get authorizers through Globus login if any are not provided
    if any(a is None for a in authorizers):
        # If some but not all were provided, warn the user they could be making a mistake
        if any(a is not None for a in authorizers):
            logger.warning('You have defined some of the authorizers but not all. DLHub is falling back to login. '
                           'You must provide authorizers for DLHub, Search, OpenID, FuncX.')

        fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
        auth_res = login(services=["search", "dlhub",
                                   fx_scope, "openid"],
                         app_name="DLHub_Client",
                         make_clients=False,
                         client_id=CLIENT_ID,
                         clear_old_tokens=force_login,
                         token_dir=_token_dir,
                         no_local_server=kwargs.get("no_local_server", True),
                         no_browser=kwargs.get("no_browser", True))

        # Unpack the authorizers
        dlh_authorizer = auth_res["dlhub"]
        fx_authorizer = auth_res[fx_scope]
        openid_authorizer = auth_res['openid']
        search_authorizer = auth_res['search']

    # Define the subclients needed by the service
    self._fx_client = FuncXClient(fx_authorizer=fx_authorizer,
                                  search_authorizer=search_authorizer,
                                  openid_authorizer=openid_authorizer,
                                  no_local_server=kwargs.get("no_local_server", True),
                                  no_browser=kwargs.get("no_browser", True))
    # Search calls use a fixed 5-minute timeout, independent of http_timeout.
    self._search_client = globus_sdk.SearchClient(authorizer=search_authorizer,
                                                  http_timeout=5 * 60)

    # funcX endpoint to use
    self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
    self.fx_cache = {}
    super(DLHubClient, self).__init__("DLHub", environment='dlhub', authorizer=dlh_authorizer,
                                      http_timeout=http_timeout, base_url=DLHUB_SERVICE_ADDRESS,
                                      **kwargs)
class DLHubClient(BaseClient):
    """Main class for interacting with the DLHub service

    Holds helper operations for performing common tasks with the DLHub service. For example,
    `get_servables` produces a list of all servables registered with DLHub.

    For most cases, we recommend creating a new DLHubClient with no arguments
    (``DLHubClient()``). The initializer will perform a Globus login when credentials are not
    supplied, reusing any tokens previously saved to disk.
    For cases where disk access is unacceptable, you can create the client by creating an authorizer
    following the
    `tutorial for the Globus SDK <https://globus-sdk-python.readthedocs.io/en/stable/tutorial/>`_
    and providing that authorizer to the initializer (e.g., ``DLHubClient(auth)``)"""

    def __init__(self, dlh_authorizer=None, search_client=None, http_timeout=None,
                 force_login=False, fx_authorizer=None, **kwargs):
        """Initialize the client

        Args:
            dlh_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with DLHub.
                If ``None``, will be created.
            search_client (:class:`SearchClient <globus_sdk.SearchClient>`):
                An authenticated SearchClient to communicate with Globus Search.
                If ``None``, will be created.
            http_timeout (int): Timeout for any call to service in seconds. (default is no timeout)
            force_login (bool): Whether to force a login to get new credentials.
                A login will always occur if ``dlh_authorizer``, ``search_client``
                or ``fx_authorizer`` are not provided.
            no_local_server (bool): Disable spinning up a local server to automatically
                copy-paste the auth code. THIS IS REQUIRED if you are on a remote server.
                When used locally with no_local_server=False, the domain is localhost with
                a randomly chosen open port number.
                **Default**: ``True``.
            fx_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with funcX.
                If ``None``, will be created.
            no_browser (bool): Do not automatically open the browser for the Globus Auth URL.
                Display the URL instead and let the user navigate to that location manually.
                **Default**: ``True``.
        Keyword arguments are the same as for BaseClient.
        """
        no_local_server = kwargs.get("no_local_server", True)
        no_browser = kwargs.get("no_browser", True)

        if force_login or not dlh_authorizer or not search_client or not fx_authorizer:
            fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
            auth_res = login(services=["search", "dlhub", fx_scope, "openid"],
                             app_name="DLHub_Client",
                             client_id=CLIENT_ID, clear_old_tokens=force_login,
                             token_dir=_token_dir, no_local_server=no_local_server,
                             no_browser=no_browser)
            dlh_authorizer = auth_res["dlhub"]
            # NOTE(review): fx_authorizer is not forwarded to FuncXClient below, which
            # performs its own login
            fx_authorizer = auth_res[fx_scope]
            search_client = auth_res["search"]

        # Always store the search client. Previously it was only set in the login branch, so
        # a client built from caller-supplied credentials had no ``_search_client`` attribute
        # and every search method failed.
        self._search_client = search_client

        self._fx_client = FuncXClient(
            force_login=force_login,
            no_local_server=no_local_server,
            no_browser=no_browser,
            funcx_service_address='https://api.funcx.org/v1',
        )

        # funcX endpoint to use
        self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
        self.fx_serializer = FuncXSerializer()
        # Maps servable shorthand names ("owner/name") to funcX function ids
        self.fx_cache = {}
        super(DLHubClient, self).__init__("DLHub", environment='dlhub', authorizer=dlh_authorizer,
                                          http_timeout=http_timeout, base_url=DLHUB_SERVICE_ADDRESS,
                                          **kwargs)

    def logout(self):
        """Remove credentials from your local system"""
        logout()

    @property
    def query(self):
        """Access a query of the DLHub Search repository"""
        return DLHubSearchHelper(search_client=self._search_client)

    def get_username(self):
        """Get the username associated with the current credentials"""
        res = self.get('/namespaces')
        return res.data['namespace']

    def get_servables(self, only_latest_version=True):
        """Get all of the servables available in the service

        Also refreshes ``fx_cache`` with the funcX id of the most recent version of each
        returned servable.

        Args:
            only_latest_version (bool): Whether to only return the latest version of each servable
        Returns:
            ([list]) Complete metadata for all servables found in DLHub
        Raises:
            RuntimeError: If there are more servables than one search request can return
        """
        # Get all of the servables, grouped by owner/name with the newest version first
        results, info = self.query.match_field('dlhub.type', 'servable')\
            .add_sort('dlhub.owner', ascending=True).add_sort('dlhub.name', ascending=False)\
            .add_sort('dlhub.publication_date', ascending=False).search(info=True)
        if info['total_query_matches'] > SEARCH_LIMIT:
            raise RuntimeError(
                'DLHub contains more servables than we can return in one entry. '
                'DLHub SDK needs to be updated.')

        if only_latest_version:
            # Sort out only the most recent versions (they come first in the sorted list)
            names = set()
            output = []
            for r in results:
                name = r['dlhub']['shorthand_name']
                if name not in names:
                    names.add(name)
                    output.append(r)
            results = output

        # Add these to the cache. Keep the FIRST (newest) funcX id seen for each name;
        # assigning every entry in order would let older versions overwrite the newest
        # one when only_latest_version is False.
        newest_ids = {}
        for r in results:
            newest_ids.setdefault(r['dlhub']['shorthand_name'], r['dlhub']['funcx_id'])
        self.fx_cache.update(newest_ids)
        return results

    def list_servables(self):
        """Get a list of the servables available in the service

        Returns:
            [string]: List of all servable names in username/servable_name format
        """
        servables = self.get_servables(only_latest_version=True)
        return [x['dlhub']['shorthand_name'] for x in servables]

    def get_task_status(self, task_id):
        """Get the status of a DLHub task.

        Args:
            task_id (string): UUID of the task
        Returns:
            dict: status block containing "status" key.
        """
        r = self._fx_client.get_task(task_id)
        return r

    def describe_servable(self, name):
        """Get the description for a certain servable

        Args:
            name (string): DLHub name of the servable of the form <user>/<servable_name>
        Returns:
            dict: Summary of the servable (metadata of its most recent version)
        Raises:
            AttributeError: If ``name`` is malformed or no such servable exists
        """
        split_name = name.split('/')
        if len(split_name) < 2:
            raise AttributeError(
                'Please enter name in the form <user>/<servable_name>')

        # Create a query for a single servable
        query = self.query.match_servable('/'.join(split_name[1:]))\
            .match_owner(split_name[0]).add_sort("dlhub.publication_date", False)\
            .search(limit=1)

        # Raise error if servable is not found
        if len(query) == 0:
            raise AttributeError('No such servable: {}'.format(name))
        return query[0]

    def describe_methods(self, name, method=None):
        """Get the description for the method(s) of a certain servable

        Args:
            name (string): DLHub name of the servable of the form <user>/<servable_name>
            method (string): Optional: Name of the method
        Returns:
            dict: Description of a certain method if ``method`` provided, all methods
                if the method name was not provided.
        """
        metadata = self.describe_servable(name)
        return get_method_details(metadata, method)

    def run(self, name, inputs, asynchronous=False, async_wait=5,
            timeout: Optional[float] = None) -> Union[Any, DLHubFuture]:
        """Invoke a DLHub servable

        Args:
            name (string): DLHub name of the servable of the form <user>/<servable_name>
            inputs: Data to be used as input to the function. Can be a string of file paths or URLs
            asynchronous (bool): Whether to return from the function immediately or
                wait for the execution to finish.
            async_wait (float): How many seconds to wait between checking async status
            timeout (float): How long to wait for a result to return.
                Only used for synchronous calls
        Returns:
            Results of running the servable. If asynchronous, then a DLHubFuture holding the result
        """
        if name not in self.fx_cache:
            # Look it up and add it to the cache, this will raise an exception if not found.
            serv = self.describe_servable(name)
            self.fx_cache.update({name: serv['dlhub']['funcx_id']})

        funcx_id = self.fx_cache[name]
        payload = {'data': inputs}
        task_id = self._fx_client.run(payload, endpoint_id=self.fx_endpoint, function_id=funcx_id)

        # Return the result
        future = DLHubFuture(self, task_id, async_wait)
        return future.result(timeout=timeout) if not asynchronous else future

    def run_serial(self, servables, inputs, async_wait=5):
        """Invoke each servable in a serial pipeline.

        This function accepts a list of servables and will run each one,
        passing the output of one as the input to the next.

        Args:
            servables (list): A list of servable strings
            inputs: Data to pass to the first servable
            async_wait (float): Seconds to wait between status checks
        Returns:
            Results of running the servable
        """
        if not isinstance(servables, list):
            # Only a warning: other iterables are still processed below
            print("run_serial requires a list of servables to invoke.")

        serv_data = inputs
        for serv in servables:
            serv_data = self.run(serv, serv_data, async_wait=async_wait)
        return serv_data

    def get_result(self, task_id, verbose=False):
        """Get the result of a task_id

        Args:
            task_id str: The task's uuid
            verbose bool: whether or not to return the full dlhub response

        Returns:
            Result of running the servable. When the raw result is a tuple and
            ``verbose`` is False, only its first element is returned.
        """
        result = self._fx_client.get_result(task_id)
        if isinstance(result, tuple) and not verbose:
            result = result[0]
        return result

    def publish_servable(self, model):
        """Submit a servable to DLHub

        If this servable has not been published before, it will be assigned a unique identifier.

        If it has been published before (DLHub detects if it has an identifier), then DLHub
        will update the servable to the new version.

        Args:
            model (BaseMetadataModel): Servable to be submitted
        Returns:
            (string): Task ID of this submission, used for checking for success
        Raises:
            Exception: With the response body if the service does not reply with HTTP 200
        """
        # Get the metadata
        metadata = model.to_dict(simplify_paths=True)

        # Mark the method used to submit the model
        metadata['dlhub']['transfer_method'] = {'POST': 'file'}

        # Validate against the servable schema
        validate_against_dlhub_schema(metadata, 'servable')

        # Wipe the fx cache so we don't keep reusing an old servable
        self.clear_funcx_cache()

        # Reserve a unique temporary name, then remove the empty file so that
        # get_zip_file can create the archive at that path itself
        fp, zip_filename = mkstemp('.zip')
        os.close(fp)
        os.unlink(zip_filename)
        try:
            model.get_zip_file(zip_filename)

            # Get the authorization headers
            headers = {}
            self.authorizer.set_authorization_header(headers)

            # Submit data to DLHub service
            with open(zip_filename, 'rb') as zf:
                reply = requests.post(
                    slash_join(self.base_url, 'publish'),
                    headers=headers,
                    files={
                        'json': ('dlhub.json', json.dumps(metadata), 'application/json'),
                        'file': ('servable.zip', zf, 'application/octet-stream')
                    }
                )

            # Return the task id
            if reply.status_code != 200:
                raise Exception(reply.text)

            return reply.json()['task_id']
        finally:
            # get_zip_file may fail before creating the archive; don't let a
            # FileNotFoundError here mask the original exception
            if os.path.exists(zip_filename):
                os.unlink(zip_filename)

    def publish_repository(self, repository):
        """Submit a repository to DLHub for publication

        Args:
            repository (string): Repository to publish
        Returns:
            (string): Task ID of this submission, used for checking for success
        """
        # Publish to DLHub
        metadata = {"repository": repository}

        # Wipe the fx cache so we don't keep reusing an old servable
        self.clear_funcx_cache()

        response = self.post('publish_repo', json_body=metadata)

        task_id = response.data['task_id']
        return task_id

    def search(self, query, advanced=False, limit=None, only_latest=True):
        """Query the DLHub servable library

        By default, the query is used as a simple plaintext search of all model metadata.
        Optionally, you can provided an advanced query on any of the indexed fields in
        the DLHub model metadata by setting :code:`advanced=True` and following the guide for
        constructing advanced queries found in the
        `Globus Search documentation <https://docs.globus.org/api/search/search/#query_syntax>`_.

        Args:
             query (string): Query to be performed
             advanced (bool): Whether to perform an advanced query
             limit (int): Maximum number of entries to return
             only_latest (bool): Whether to return only the latest version of the model
        Returns:
            ([dict]): All records matching the search query
        """
        results = self.query.search(query, advanced=advanced, limit=limit)
        return filter_latest(results) if only_latest else results

    def search_by_servable(self, servable_name=None, owner=None, version=None,
                           only_latest=True, limit=None, get_info=False):
        """Search by the ownership, name, or version of a servable

        Args:
            servable_name (str): The name of the servable. **Default**: None, to match
                all servable names.
            owner (str): The name of the owner of the servable. **Default**: ``None``,
                to match all owners.
            version (int): Model version, which corresponds to the date when the
                servable was published. **Default**: ``None``, to match all versions.
            only_latest (bool): When ``True``, will only return the latest version
                of each servable. When ``False``, will return all matching versions.
                **Default**: ``True``.
            limit (int): The maximum number of results to return.
                **Default:** ``None``, for no limit.
            get_info (bool): If ``False``, search will return a list of the results.
                If ``True``, search will return a tuple containing the results list
                and other information about the query.
                **Default:** ``False``.

        Returns:
            If ``get_info`` is ``False``, *list*: The search results.
            If ``get_info`` is ``True``, *tuple*: The search results,
            and a dictionary of query information.
        Raises:
            ValueError: If none of ``servable_name``, ``owner`` or ``version`` is given
        """
        if not servable_name and not owner and not version:
            raise ValueError(
                "One of 'servable_name', 'owner', or 'version' is required.")

        # Perform the query
        results, info = (self.query.match_servable(servable_name=servable_name, owner=owner,
                                                   publication_date=version)
                         .search(limit=limit, info=True))

        # Filter out the latest models
        if only_latest:
            results = filter_latest(results)

        if get_info:
            return results, info
        return results

    def search_by_authors(self, authors, match_all=True, limit=None, only_latest=True):
        """Execute a search for servables from certain authors.

        Authors in DLHub may be different than the owners of the servable and generally are
        the people who developed functionality of a certain servable (e.g., the creators
        of the machine learning model used in a servable).

        If you want to search by ownership, see :meth:`search_by_servable`

        Args:
            authors (str or list of str): The authors to match. Names must be in
                "Family Name, Given Name" format
            match_all (bool): If ``True``, will require all authors be on any results.
                If ``False``, will only require one author to be in results.
                **Default**: ``True``.
            limit (int): The maximum number of results to return.
                **Default:** ``None``, for no limit.
            only_latest (bool): When ``True``, will only return the latest version
                of each servable. When ``False``, will return all matching versions.
                **Default**: ``True``.

        Returns:
            [dict]: List of servables from the desired authors
        """
        results = self.query.match_authors(authors, match_all=match_all).search(limit=limit)
        return filter_latest(results) if only_latest else results

    def search_by_related_doi(self, doi, limit=None, only_latest=True):
        """Get all of the servables associated with a certain publication

        Args:
            doi (string): DOI of related paper
            limit (int): Maximum number of results to return
            only_latest (bool): Whether to return only the most recent version of the model
        Returns:
            [dict]: List of servables from the requested paper
        """
        results = self.query.match_doi(doi).search(limit=limit)
        return filter_latest(results) if only_latest else results

    def clear_funcx_cache(self, servable=None):
        """Remove functions from the cache.

        Either remove a specific servable or wipe the whole cache.
        Removing a servable that is not cached is a no-op.

        Args:
            servable (str): The name of the servable to remove. Default None
                (wipe the whole cache)
        Returns:
            dict: The cache after the removal
        """
        if servable:
            self.fx_cache.pop(servable, None)
        else:
            self.fx_cache = {}

        return self.fx_cache
with open("/home/zzli/.funcx/{}/config.py".format(args.endpoint_name), 'w') as f: f.write(config) # Start the endpoint endpoint_name = args.endpoint_name cmd = "funcx-endpoint start {}".format(endpoint_name) try: subprocess.call(cmd, shell=True) except Exception as e: print(e) print("Started the endpoint {}".format(args.endpoint_id)) print("Wating 60 seconds for the endpoint to start") time.sleep(60) fxc = FuncXClient() func_uuid = fxc.register_function(dlhub_test, description="A sum function") print("The functoin uuid is {}".format(func_uuid)) fxs = FuncXSerializer() def test(tasks=1, data=[1], timeout=float('inf'), endpoint_id=None, function_id=None, poll=0.1): start = time.time() res = fxc.run(data, endpoint_id=endpoint_id, function_id=function_id)
import sys from time import sleep import json NUM_RUNS = 70 import requests from funcx.sdk.client import FuncXClient pyhf_endpoint = 'a727e996-7836-4bec-9fa2-44ebf7ca5302' fxc = FuncXClient() fxc.max_requests = 200 def prepare_workspace(data): import pyhf w = pyhf.Workspace(data) return w prepare_func = fxc.register_function(prepare_workspace) def infer_hypotest(w, metadata, doc): import pyhf import time tick = time.time() m = w.model(patches=[doc], modifier_settings={ "normsys": { "interpcode": "code4"
task_ids = fxc.map_run(list(range(task_count)), endpoint_id=ep_id, function_id=fn_uuid) delta = time.time() - start print(f"Time to launch {task_count} tasks: {delta:8.3f} s") print(f"Got {len(task_ids)} tasks_ids ") for _i in range(3): x = fxc.get_batch_result(task_ids) complete_count = sum(1 for t in task_ids if t in x and x[t].get("pending", False)) print(f"Batch status : {complete_count}/{len(task_ids)} complete") if complete_count == len(task_ids): break time.sleep(2) delta = time.time() - start print(f"Time to complete {task_count} tasks: {delta:8.3f} s") print(f"Throughput : {task_count / delta:8.3f} Tasks/s") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-e", "--endpoint", required=True) parser.add_argument("-c", "--count", default="10") args = parser.parse_args() print("FuncX version : ", funcx.__version__) fxc = FuncXClient(funcx_service_address="https://dev.funcx.org/api/v1") test(fxc, args.endpoint, task_count=int(args.count))