Example #1
def get_result(task_id):
    """Get the result of the function.
    """
    fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')
    res = fxc.get_result(task_id)
    print(res)
    return res
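A minimal companion sketch, not part of the original snippet: because the client raises while a task is still pending (Example #4 checks for a "waiting-for-" message), a caller might poll get_result until the value is ready.

import time
from funcx.sdk.client import FuncXClient

def wait_for_result(task_id, poll_interval=5):
    """Hypothetical helper: poll until the task result is available."""
    fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')
    while True:
        try:
            return fxc.get_result(task_id)
        except Exception as e:
            # The client raises while the task is still pending.
            if 'waiting-for-' in str(e):
                time.sleep(poll_interval)
            else:
                raise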
Example #2
    def __init__(
        self,
        fx_auth,
        search_auth,
        openid_auth,
        endpoint_id,
        func,
        expected,
        args=None,
        timeout=15,
        concurrency=1,
        tol=1e-5,
    ):
        self.endpoint_id = endpoint_id
        self.func = func
        self.expected = expected
        self.args = args
        self.timeout = timeout
        self.concurrency = concurrency
        self.tol = tol
        self.fxc = FuncXClient(
            fx_authorizer=fx_auth,
            search_authorizer=search_auth,
            openid_authorizer=openid_auth,
        )
        self.func_uuid = self.fxc.register_function(self.func)

        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)
        handler = logging.StreamHandler(sys.stdout)
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            "%(asctime)s %(name)s:%(lineno)d [%(levelname)s]  %(message)s")
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
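A sketch of a sibling method such a tester might expose, assuming `time` is imported in the same module and that `expected` is numeric; the method name and polling loop are illustrative and not part of the original snippet.

    def run(self):
        """Hypothetical driver: submit the registered function and check the result."""
        task_id = self.fxc.run(*(self.args or []),
                               endpoint_id=self.endpoint_id,
                               function_id=self.func_uuid)
        deadline = time.time() + self.timeout
        while time.time() < deadline:
            try:
                result = self.fxc.get_result(task_id)
            except Exception as e:
                if 'waiting-for-' in str(e):  # still pending on the endpoint
                    time.sleep(1)
                    continue
                raise
            if abs(result - self.expected) > self.tol:
                raise ValueError(f"got {result}, expected {self.expected}")
            self.logger.info("result %s within tol of expected %s", result, self.expected)
            return result
        raise TimeoutError(f"no result within {self.timeout}s")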
Example #3
def get_deser_result(task_id):
    """Get the result of the function.
    """
    fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')
    res = fxc.get(f"/tasks/{task_id}?deserialize=True")
    print(res)
    return res
Example #4
def run_function_wait_result(
        py_fn,
        py_fn_args,
        py_fn_kwargs={},
        endpoint_id="3c3f0b4f-4ae4-4241-8497-d7339972ff4a",
        print_status=True):
    """
    Register and run a function with FuncX, wait for execution,
        and return results when they are available

    :param py_fn: Handle of Python function
    :param py_fn_args: List of positional args for py function
    :param py_fn_kwargs: Dict of keyword args for py function,
    :param endpoint_id: ID of endpoint to run command on
        - must be configured in config.py
    """
    fxc = FuncXClient()
    func_uuid = fxc.register_function(py_fn)
    res = fxc.run(*py_fn_args,
                  **py_fn_kwargs,
                  endpoint_id=endpoint_id,
                  function_id=func_uuid)
    while True:
        try:
            if print_status:
                print("Waiting for results...")
            time.sleep(FUNCX_SLEEP_TIME)
            return str(fxc.get_result(res), encoding="utf-8")
        except Exception as e:
            if "waiting-for-" in str(e):
                continue
            else:
                raise e
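A usage sketch for the helper above; `add` is a made-up example function, and it assumes the hard-coded default endpoint is reachable and that FUNCX_SLEEP_TIME is defined at module level as the snippet implies.

def add(x, y):
    return x + y

# Blocks until the endpoint returns a value for add(1, 2).
result = run_function_wait_result(add, [1, 2])
print(result)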
Example #5
def register_container():
    from funcx.sdk.client import FuncXClient
    fxc = FuncXClient()
    from gladier_xpcs.tools.corr import eigen_corr
    cont_dir = '/eagle/APSDataAnalysis/XPCS_test/containers/'
    container_name = 'eigen_v2.simg'
    eigen_cont_id = fxc.register_container(location=cont_dir + container_name,
                                           container_type='singularity')
    corr_cont_fxid = fxc.register_function(eigen_corr, container_uuid=eigen_cont_id)
    return corr_cont_fxid
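A possible follow-up sketch (not in the original): run eigen_corr inside the registered Singularity container on some endpoint. Here `payload` and `endpoint_id` stand in for whatever arguments and XPCS endpoint the real flow supplies.

def run_corr(payload, endpoint_id):
    from funcx.sdk.client import FuncXClient
    fxc = FuncXClient()
    corr_cont_fxid = register_container()
    # Submit the containerized correlation function and return the task id.
    return fxc.run(payload, endpoint_id=endpoint_id, function_id=corr_cont_fxid)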
Example #6
    def __init__(self, dlh_authorizer=None, search_client=None, http_timeout=None,
                 force_login=False, fx_authorizer=None, **kwargs):
        """Initialize the client

        Args:
            dlh_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with DLHub.
                If ``None``, will be created.
            search_client (:class:`SearchClient <globus_sdk.SearchClient>`):
                An authenticated SearchClient to communicate with Globus Search.
                If ``None``, will be created.
            http_timeout (int): Timeout for any call to service in seconds. (default is no timeout)
            force_login (bool): Whether to force a login to get new credentials.
                A login will always occur if ``dlh_authorizer`` or ``search_client``
                are not provided.
            no_local_server (bool): Disable spinning up a local server to automatically
                copy-paste the auth code. THIS IS REQUIRED if you are on a remote server.
                When used locally with no_local_server=False, the domain is localhost with
                a randomly chosen open port number.
                **Default**: ``True``.
            fx_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with funcX.
                If ``None``, will be created.
            no_browser (bool): Do not automatically open the browser for the Globus Auth URL.
                Display the URL instead and let the user navigate to that location manually.
                **Default**: ``True``.
        Keyword arguments are the same as for BaseClient.
        """
        if force_login or not dlh_authorizer or not search_client or not fx_authorizer:

            fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
            auth_res = login(services=["search", "dlhub",
                                       fx_scope],
                             app_name="DLHub_Client",
                             client_id=CLIENT_ID, clear_old_tokens=force_login,
                             token_dir=_token_dir, no_local_server=kwargs.get("no_local_server", True),
                             no_browser=kwargs.get("no_browser", True))
            dlh_authorizer = auth_res["dlhub"]
            fx_authorizer = auth_res[fx_scope]
            self._search_client = auth_res["search"]
            self._fx_client = FuncXClient(force_login=True, fx_authorizer=fx_authorizer,
                                          funcx_service_address='https://funcx.org/api/v1')

        # funcX endpoint to use
        self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
        # self.fx_endpoint = '2c92a06a-015d-4bfa-924c-b3d0c36bdad7'
        self.fx_serializer = FuncXSerializer()
        self.fx_cache = {}
        super(DLHubClient, self).__init__("DLHub", environment='dlhub', authorizer=dlh_authorizer,
                                          http_timeout=http_timeout, base_url=DLHUB_SERVICE_ADDRESS,
                                          **kwargs)
Example #7
 def handle(self, *args, **options):
     if options['register']:
         fxc = FuncXClient()
         ep = fxc.register_function(process_hdfs,
                                    description="Process an hdf")
         self.stderr.write(f'FuncX function endpoint has been '
                           f'registered: {ep}')
         self.stderr.write('You need to add this somewhere manually!')
     elif options['test']:
         name = options['test']
         if not name:
             self.stderr.write('test needs the name of a search collection')
         bag = Bag.objects.filter(
             search_collection__name=options['test']).first()
         if bag:
             action = ReprocessingTask.new_action(bag, user=bag.user)
             action.save()
             rt = ReprocessingTask(bag=bag, action=action)
             rt.save()
             rt.action.start_flow()
             self.stdout.write(f'Started {action}')
         else:
             bags = [b.search_collection.name for b in Bag.objects.all()]
             self.stderr.write(f'No bag named {options["test"]}, please '
                               f'use one of the following instead {bags}')
     elif options['check']:
         rts = ReprocessingTask.objects.filter(action__status='ACTIVE')
         if not rts:
             self.stderr.write('No Tasks to update.')
         for rt in rts:
             old = rt.action.status
             rt.action.update_flow()
             self.stdout.write(f'Updated {rt.bag.search_collection.name} '
                               f'from "{old}" to "{rt.action.status}".')
     elif options.get('payload') is not None:
         raise NotImplementedError('This does not work yet...')
         pl = self.get_task_or_list_all(options['payload']).action.payload
         plain_pl = deserialize_payload(pl['ProcessDataInput']['payload'])
         pprint(plain_pl)
     elif options.get('output') is not None:
         task = self.get_task_or_list_all(options['output'])
         if not task:
             return
         automate_output = task.action.cache['details']['output']
         outputs = [
             data['details'] for name, data in automate_output.items()
             if 'details' in data
         ]
         for output in outputs:
             if 'result' in output:
                 pprint(deserialize_payload(output['result']))
             elif 'exception' in output:
                 deserialize_payload(output['exception']).reraise()
Example #8
def run_real(payload):
    """Run a function with some raw json input.
    """
    fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')

    # register a function
    func_id = fxc.register_function(test_func)
    ep_id = '60ad46e1-c912-468b-8674-4d582e9dc9ee'

    res = fxc.run(payload, function_id=func_id, endpoint_id=ep_id)

    print(res)

    return res
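The returned task id can then be handed to the get_result helper from Example #1 once the task has finished; a brief sketch.

task_id = run_real({'name': 'real'})
# Later, after the endpoint has executed the task:
result = get_result(task_id)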
Example #9
def init_endpoint():
    """Setup funcx dirs and default endpoint config files

    TODO : Every mechanism that updates the config file must use a locking
    mechanism, ideally something like fcntl
    (https://docs.python.org/3/library/fcntl.html) or the lockfile module,
    to ensure that multiple endpoint invocations do not mangle the funcx
    config files.
    """
    _ = FuncXClient()

    if os.path.exists(State.FUNCX_CONFIG_FILE):
        typer.confirm(
            "Are you sure you want to initialize this directory? "
            f"This will erase everything in {State.FUNCX_DIR}", abort=True
        )
        logger.info("Wiping all current configs in {}".format(State.FUNCX_DIR))
        backup_dir = State.FUNCX_DIR + ".bak"
        try:
            logger.debug(f"Removing old backups in {backup_dir}")
            shutil.rmtree(backup_dir)
        except OSError:
            pass
        os.renames(State.FUNCX_DIR, backup_dir)

    if os.path.exists(State.FUNCX_CONFIG_FILE):
        logger.debug("Config file exists at {}".format(State.FUNCX_CONFIG_FILE))
        return

    try:
        os.makedirs(State.FUNCX_DIR, exist_ok=True)
    except Exception as e:
        print("[FuncX] Caught exception during registration {}".format(e))

    shutil.copyfile(State.FUNCX_DEFAULT_CONFIG_TEMPLATE, State.FUNCX_CONFIG_FILE)
    init_endpoint_dir("default")
Example #10
def run_ser(payload):
    """Run a function with some raw json input.
    """
    fxc = FuncXClient(funcx_service_address='https://dev.funcx.org/api/v1')

    # register a function
    func_id = fxc.register_function(test_func)
    ep_id = '60ad46e1-c912-468b-8674-4d582e9dc9ee'
    payload = {
        'serialize': True,
        'payload': payload,
        'endpoint': ep_id,
        'func': func_id
    }

    res = fxc.post('submit', json_body=payload)
    res = res['task_uuid']
    print(res)

    return res
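Pairing this with Example #3, a sketch of fetching the deserialized output for the returned task uuid (both snippets target the same dev service).

task_id = run_ser({'name': 'serialized'})
# Later, fetch the deserialized result over the REST API (see Example #3).
result = get_deser_result(task_id)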
Example #11
def run_function_async(py_fn,
                       py_fn_args,
                       py_fn_kwargs={},
                       endpoint_id="3c3f0b4f-4ae4-4241-8497-d7339972ff4a"):
    """
    Asynchronously register and run a Python function on a FuncX endpoint

    :param py_fn: Handle of Python function
    :param py_fn_args: List of positional args for py function
    :param py_fn_kwargs: Dict of keyword args for py function,
    :param endpoint_id: ID of endpoint to run command on
        - must be configured in config.py
    """
    # Use return value for Funcx polling
    fxc = FuncXClient()
    func_uuid = fxc.register_function(py_fn)
    res = fxc.run(*py_fn_args,
                  **py_fn_kwargs,
                  endpoint_id=endpoint_id,
                  function_id=func_uuid)
    return res
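A usage sketch for the asynchronous helper; `square` is made up, the 5-second sleep is arbitrary, and the 'waiting-for-' check mirrors the polling loop in Example #4.

def square(x):
    return x * x

task_id = run_function_async(square, [4])
fxc = FuncXClient()
while True:
    try:
        print(fxc.get_result(task_id))
        break
    except Exception as e:
        if 'waiting-for-' in str(e):  # still pending on the endpoint
            time.sleep(5)
        else:
            raise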
Example #12
class FuncXFuture(Future):
    client = FuncXClient()
    serializer = FuncXSerializer()

    def __init__(self, task_id, poll_period=1):
        super().__init__()
        self.task_id = task_id
        self.poll_period = poll_period
        self.__result = None
        self.submitted = time.time()

    def done(self):
        if self.__result is not None:
            return True
        try:
            data = FuncXFuture.client.get_task_status(self.task_id)
        except Exception:
            return False
        if 'status' in data and data['status'] == 'PENDING':
            time.sleep(
                self.poll_period)  # needed to not overwhelm the FuncX server
            return False
        elif 'result' in data:
            self.__result = FuncXFuture.serializer.deserialize(data['result'])
            self.returned = time.time()
            # FIXME AW benchmarking
            self.connected_managers = os.environ.get('connected_managers', -1)

            return True
        elif 'exception' in data:
            e = FuncXFuture.serializer.deserialize(data['exception'])
            e.reraise()
        else:
            raise NotImplementedError(
                'task {} is neither pending nor finished: {}'.format(
                    self.task_id, str(data)))

    def result(self, timeout=None):
        if self.__result is not None:
            return self.__result
        while True:
            if self.done():
                break
            else:
                time.sleep(self.poll_period)
                if timeout is not None:
                    timeout -= self.poll_period
                    if timeout < 0:
                        raise TimeoutError

        return self.__result
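A usage sketch for FuncXFuture; the `double` function, the endpoint placeholder, and the poll period are illustrative only.

def double(x):
    return 2 * x

fxc = FuncXFuture.client
func_id = fxc.register_function(double)
task_id = fxc.run(21, endpoint_id='<endpoint-uuid>', function_id=func_id)
future = FuncXFuture(task_id, poll_period=2)
print(future.result(timeout=60))  # blocks until done or raises TimeoutError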
Example #13
def register_funcx(task):
    """Register the function and the container with funcX.

    Parameters
    ----------
    task : dict
        A dict of the task to publish

    Returns
    -------
    str
        The funcX function id
    """

    # Get the funcX dependent token
    fx_token = task['dlhub']['funcx_token']
    # Create a client using this token
    fx_auth = globus_sdk.AccessTokenAuthorizer(fx_token)
    fxc = FuncXClient(fx_authorizer=fx_auth, use_offprocess_checker=False)
    description = f"A container for the DLHub model {task['dlhub']['shorthand_name']}"
    try:
        description = task['datacite']['descriptions'][0]['description']
    except (KeyError, IndexError):
        # It doesn't have a simple description
        pass
    # Register the container with funcX
    container_id = fxc.register_container(task['dlhub']['ecr_uri'],
                                          'docker',
                                          name=task['dlhub']['shorthand_name'],
                                          description=description)

    # Register a function
    funcx_id = fxc.register_function(dlhub_run,
                                     function_name=task['dlhub']['name'],
                                     container_uuid=container_id,
                                     description=description,
                                     public=True)

    # Whitelist the function on DLHub's endpoint
    # First create a new fxc client on DLHub's behalf
    fxc = FuncXClient(use_offprocess_checker=False)
    endpoint_uuid = '86a47061-f3d9-44f0-90dc-56ddc642c000'
    res = fxc.add_to_whitelist(endpoint_uuid, [funcx_id])
    print(res)
    return funcx_id
Example #14
    def init_endpoint(self):
        """Setup funcx dirs and default endpoint config files

        TODO : Every mechanism that updates the config file must use a locking
        mechanism, ideally something like fcntl [1] or the lockfile module,
        to ensure that multiple endpoint invocations do not mangle the funcx
        config files.

        [1] https://docs.python.org/3/library/fcntl.html
        """
        _ = FuncXClient()

        if os.path.exists(self.funcx_config_file):
            typer.confirm(
                "Are you sure you want to initialize this directory? "
                f"This will erase everything in {self.funcx_dir}",
                abort=True,
            )
            log.info(f"Wiping all current configs in {self.funcx_dir}")
            backup_dir = self.funcx_dir + ".bak"
            try:
                log.debug(f"Removing old backups in {backup_dir}")
                shutil.rmtree(backup_dir)
            except OSError:
                pass
            os.renames(self.funcx_dir, backup_dir)

        if os.path.exists(self.funcx_config_file):
            log.debug(f"Config file exists at {self.funcx_config_file}")
            return

        try:
            os.makedirs(self.funcx_dir, exist_ok=True)
        except Exception as e:
            print(f"[FuncX] Caught exception during registration {e}")

        shutil.copyfile(self.funcx_default_config_template,
                        self.funcx_config_file)
"""Run toxicity inference with FuncX"""
from funcx.sdk.client import FuncXClient
import json

# Get FuncX ready
fxc = FuncXClient()
theta_ep = 'd3a23590-3282-429a-8bce-e0ca0f4177f3'
with open('func_uuid.json') as fp:
    func_id = json.load(fp)

# Run the inference
smiles = ['C', 'CC', 'CCC']
task_id = fxc.run(smiles, endpoint_id=theta_ep, function_id=func_id)
print(task_id)
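To retrieve the predictions, the printed task id would be polled in the same way as the earlier examples; a short sketch with an arbitrary 10-second interval.

import time

while True:
    try:
        print(fxc.get_result(task_id))
        break
    except Exception as e:
        if 'waiting-for-' in str(e):  # still pending on Theta
            time.sleep(10)
        else:
            raise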
Example #16
        with open(result_path, 'w') as f:
            print('problem loading processor instance for {}'.format(
                str(item)),
                  file=f)
            print(e, file=f)
            print('environment:', file=f)
            for key, value in os.environ.items():
                print('{}: {}'.format(key, value), file=f)
            print('hostname: {}'.format(subprocess.check_output('hostname',
                                                                shell=True)),
                  file=f)

    if stageout_url.startswith('root://'):
        command = 'xrdcp {} {}'.format(result_path,
                                       os.path.join(stageout_url, subdir))
        subprocess.call(command, shell=True)
        os.unlink(result_path)

    return os.path.join(subdir, os.path.basename(result_path))


client = FuncXClient()
uuids = {}
for func in [process]:
    f = timeout(func)
    uuids[func.__name__] = client.register_function(f)

with open('data/function_uuids.json', 'w') as f:
    f.write(json.dumps(uuids, indent=4, sort_keys=True))
Example #17
def basic_analysis(params):
    import mdml_client as mdml
    # Connect to the MDML to query for data
    exp = mdml.experiment(params['experiment_id'], params['user'],
                          params['pass'], params['host'])
    # Grabbing the latest temperature value
    query = [{"device": "DATA1", "variables": ["temperature"], "last": 1}]
    res = exp.query(query, verify_cert=False)  # Running the query
    tempF = res['DATA1'][0]['temperature']
    tempC = (tempF - 32) * (5 / 9)
    return {'time': mdml.unix_time(True), 'tempC': tempC}


# Registering the function
if args.register:
    from funcx.sdk.client import FuncXClient
    fxc = FuncXClient()
    funcx_func_uuid = fxc.register_function(
        basic_analysis, description="Temperature conversion")
    print(f'funcX UUID: {funcx_func_uuid}')
else:  # Use the most recent function funcx ID (manually put here after running --register once)
    funcx_func_uuid = "1712a2fc-cc40-4b2c-ae44-405d58f78c5d"  # Sept 16th 2020

# Now that the function is ready for use, we need to start an experiment to use it with
import sys
import time
sys.path.insert(1, '../')  # using local mdml_client
import mdml_client as mdml

exp = mdml.experiment("TEST", args.username, args.password, args.host)
exp.add_config(auto=True)
exp.send_config()
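A hedged sketch of actually invoking the registered analysis on a funcX endpoint; the endpoint id is a placeholder, and `params` mirrors the keys the function body reads.

from funcx.sdk.client import FuncXClient

fxc = FuncXClient()
params = {'experiment_id': 'TEST', 'user': args.username,
          'pass': args.password, 'host': args.host}
task_id = fxc.run(params, endpoint_id='<endpoint-uuid>', function_id=funcx_func_uuid)
print(task_id)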
Example #18
    def __init__(self,
                 config,
                 client_address="127.0.0.1",
                 interchange_address="127.0.0.1",
                 client_ports=(50055, 50056, 50057),
                 worker_ports=None,
                 worker_port_range=(54000, 55000),
                 cores_per_worker=1.0,
                 worker_debug=False,
                 launch_cmd=None,
                 heartbeat_threshold=60,
                 logdir=".",
                 logging_level=logging.INFO,
                 poll_period=10,
                 endpoint_id=None,
                 suppress_failure=False,
                 max_heartbeats_missed=2
                 ):
        """
        Parameters
        ----------
        config : funcx.Config object
             Funcx config object that describes how compute should be provisioned

        client_address : str
             The ip address at which the parsl client can be reached. Default: "127.0.0.1"

        interchange_address : str
             The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1"

        client_ports : triple(int, int, int)
             The ports at which the client can be reached

        launch_cmd : str
             TODO : update

        worker_ports : tuple(int, int)
             The specific two ports at which workers will connect to the Interchange. Default: None

        worker_port_range : tuple(int, int)
             The interchange picks ports at random from the range which will be used by workers.
             This is overridden when the worker_ports option is set. Default: (54000, 55000)

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        worker_debug : Bool
             Enables worker debug logging.

        heartbeat_threshold : int
             Number of seconds since the last heartbeat after which worker is considered lost.

        logdir : str
             Parsl log directory paths. Logs and temp files go here. Default: '.'

        logging_level : int
             Logging level as defined in the logging module. Default: logging.INFO (20)

        endpoint_id : str
             Identity string that identifies the endpoint to the broker

        poll_period : int
             The main thread polling period, in milliseconds. Default: 10ms

        suppress_failure : Bool
             When set to True, the interchange will attempt to suppress failures. Default: False

        max_heartbeats_missed : int
             Number of heartbeats missed before setting kill_event

        """
        self.logdir = logdir
        try:
            os.makedirs(self.logdir)
        except FileExistsError:
            pass

        start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level)
        logger.info("logger location {}".format(logger.handlers))
        logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id))
        self.config = config
        logger.info("Got config : {}".format(config))

        self.strategy = self.config.strategy
        self.client_address = client_address
        self.interchange_address = interchange_address
        self.suppress_failure = suppress_failure
        self.poll_period = poll_period

        self.serializer = FuncXSerializer()
        logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
            client_address, client_ports[0], client_ports[1], client_ports[2]))
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.set_hwm(0)
        self.task_incoming.RCVTIMEO = 10  # in milliseconds
        logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0]))
        self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))

        self.results_outgoing = self.context.socket(zmq.DEALER)
        self.results_outgoing.set_hwm(0)
        logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1]))
        self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))

        self.command_channel = self.context.socket(zmq.DEALER)
        self.command_channel.RCVTIMEO = 1000  # in milliseconds
        # self.command_channel.set_hwm(0)
        logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2]))
        self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
        logger.info("Connected to client")

        self.pending_task_queue = {}
        self.containers = {}
        self.total_pending_task_count = 0
        self.fxs = FuncXClient()

        logger.info("Interchange address is {}".format(self.interchange_address))
        self.worker_ports = worker_ports
        self.worker_port_range = worker_port_range

        self.task_outgoing = self.context.socket(zmq.ROUTER)
        self.task_outgoing.set_hwm(0)
        self.results_incoming = self.context.socket(zmq.ROUTER)
        self.results_incoming.set_hwm(0)

        # initialize the last heartbeat time to start the loop
        self.last_heartbeat = time.time()
        self.max_heartbeats_missed = max_heartbeats_missed

        self.endpoint_id = endpoint_id
        if self.worker_ports:
            self.worker_task_port = self.worker_ports[0]
            self.worker_result_port = self.worker_ports[1]

            self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port))
            self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port))

        else:
            self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*',
                                                                           min_port=worker_port_range[0],
                                                                           max_port=worker_port_range[1], max_tries=100)
            self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*',
                                                                                min_port=worker_port_range[0],
                                                                                max_port=worker_port_range[1], max_tries=100)

        logger.info("Bound to ports {},{} for incoming worker connections".format(
            self.worker_task_port, self.worker_result_port))

        self._ready_manager_queue = {}

        self.heartbeat_threshold = heartbeat_threshold
        self.blocks = {}  # type: Dict[str, str]
        self.block_id_map = {}
        self.launch_cmd = launch_cmd
        self.last_core_hr_counter = 0
        if not launch_cmd:
            self.launch_cmd = ("funcx-manager {debug} {max_workers} "
                               "-c {cores_per_worker} "
                               "--poll {poll_period} "
                               "--task_url={task_url} "
                               "--result_url={result_url} "
                               "--logdir={logdir} "
                               "--block_id={{block_id}} "
                               "--hb_period={heartbeat_period} "
                               "--hb_threshold={heartbeat_threshold} "
                               "--worker_mode={worker_mode} "
                               "--scheduler_mode={scheduler_mode} "
                               "--worker_type={{worker_type}} ")

        self.current_platform = {'parsl_v': PARSL_VERSION,
                                 'python_v': "{}.{}.{}".format(sys.version_info.major,
                                                               sys.version_info.minor,
                                                               sys.version_info.micro),
                                 'os': platform.system(),
                                 'hname': platform.node(),
                                 'dir': os.getcwd()}

        logger.info("Platform info: {}".format(self.current_platform))
        self._block_counter = 0
        try:
            self.load_config()
        except Exception as e:
            logger.exception("Caught exception")
            raise
Example #19
class Interchange(object):
    """ Interchange is a task orchestrator for distributed systems.

    1. Asynchronously queue large volume of tasks (>100K)
    2. Allow for workers to join and leave the union
    3. Detect workers that have failed using heartbeats
    4. Service single and batch requests from workers
    5. Be aware of each worker's resource capacity,
       e.g. schedule only jobs that fit into walltime.

    TODO: We most likely need a PUB channel to send out global commands, like shutdown
    """
    def __init__(self,
                 config,
                 client_address="127.0.0.1",
                 interchange_address="127.0.0.1",
                 client_ports=(50055, 50056, 50057),
                 worker_ports=None,
                 worker_port_range=(54000, 55000),
                 cores_per_worker=1.0,
                 worker_debug=False,
                 launch_cmd=None,
                 heartbeat_threshold=60,
                 logdir=".",
                 logging_level=logging.INFO,
                 poll_period=10,
                 endpoint_id=None,
                 suppress_failure=False,
                 max_heartbeats_missed=2
                 ):
        """
        Parameters
        ----------
        config : funcx.Config object
             Funcx config object that describes how compute should be provisioned

        client_address : str
             The ip address at which the parsl client can be reached. Default: "127.0.0.1"

        interchange_address : str
             The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1"

        client_ports : triple(int, int, int)
             The ports at which the client can be reached

        launch_cmd : str
             TODO : update

        worker_ports : tuple(int, int)
             The specific two ports at which workers will connect to the Interchange. Default: None

        worker_port_range : tuple(int, int)
             The interchange picks ports at random from the range which will be used by workers.
             This is overridden when the worker_ports option is set. Default: (54000, 55000)

        cores_per_worker : float
             cores to be assigned to each worker. Oversubscription is possible
             by setting cores_per_worker < 1.0. Default=1

        worker_debug : Bool
             Enables worker debug logging.

        heartbeat_threshold : int
             Number of seconds since the last heartbeat after which worker is considered lost.

        logdir : str
             Parsl log directory paths. Logs and temp files go here. Default: '.'

        logging_level : int
             Logging level as defined in the logging module. Default: logging.INFO (20)

        endpoint_id : str
             Identity string that identifies the endpoint to the broker

        poll_period : int
             The main thread polling period, in milliseconds. Default: 10ms

        suppress_failure : Bool
             When set to True, the interchange will attempt to suppress failures. Default: False

        max_heartbeats_missed : int
             Number of heartbeats missed before setting kill_event

        """
        self.logdir = logdir
        try:
            os.makedirs(self.logdir)
        except FileExistsError:
            pass

        start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level)
        logger.info("logger location {}".format(logger.handlers))
        logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id))
        self.config = config
        logger.info("Got config : {}".format(config))

        self.strategy = self.config.strategy
        self.client_address = client_address
        self.interchange_address = interchange_address
        self.suppress_failure = suppress_failure
        self.poll_period = poll_period

        self.serializer = FuncXSerializer()
        logger.info("Attempting connection to client at {} on ports: {},{},{}".format(
            client_address, client_ports[0], client_ports[1], client_ports[2]))
        self.context = zmq.Context()
        self.task_incoming = self.context.socket(zmq.DEALER)
        self.task_incoming.set_hwm(0)
        self.task_incoming.RCVTIMEO = 10  # in milliseconds
        logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0]))
        self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0]))

        self.results_outgoing = self.context.socket(zmq.DEALER)
        self.results_outgoing.set_hwm(0)
        logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1]))
        self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1]))

        self.command_channel = self.context.socket(zmq.DEALER)
        self.command_channel.RCVTIMEO = 1000  # in milliseconds
        # self.command_channel.set_hwm(0)
        logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2]))
        self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2]))
        logger.info("Connected to client")

        self.pending_task_queue = {}
        self.containers = {}
        self.total_pending_task_count = 0
        self.fxs = FuncXClient()

        logger.info("Interchange address is {}".format(self.interchange_address))
        self.worker_ports = worker_ports
        self.worker_port_range = worker_port_range

        self.task_outgoing = self.context.socket(zmq.ROUTER)
        self.task_outgoing.set_hwm(0)
        self.results_incoming = self.context.socket(zmq.ROUTER)
        self.results_incoming.set_hwm(0)

        # initialize the last heartbeat time to start the loop
        self.last_heartbeat = time.time()
        self.max_heartbeats_missed = max_heartbeats_missed

        self.endpoint_id = endpoint_id
        if self.worker_ports:
            self.worker_task_port = self.worker_ports[0]
            self.worker_result_port = self.worker_ports[1]

            self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port))
            self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port))

        else:
            self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*',
                                                                           min_port=worker_port_range[0],
                                                                           max_port=worker_port_range[1], max_tries=100)
            self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*',
                                                                                min_port=worker_port_range[0],
                                                                                max_port=worker_port_range[1], max_tries=100)

        logger.info("Bound to ports {},{} for incoming worker connections".format(
            self.worker_task_port, self.worker_result_port))

        self._ready_manager_queue = {}

        self.heartbeat_threshold = heartbeat_threshold
        self.blocks = {}  # type: Dict[str, str]
        self.block_id_map = {}
        self.launch_cmd = launch_cmd
        self.last_core_hr_counter = 0
        if not launch_cmd:
            self.launch_cmd = ("funcx-manager {debug} {max_workers} "
                               "-c {cores_per_worker} "
                               "--poll {poll_period} "
                               "--task_url={task_url} "
                               "--result_url={result_url} "
                               "--logdir={logdir} "
                               "--block_id={{block_id}} "
                               "--hb_period={heartbeat_period} "
                               "--hb_threshold={heartbeat_threshold} "
                               "--worker_mode={worker_mode} "
                               "--scheduler_mode={scheduler_mode} "
                               "--worker_type={{worker_type}} ")

        self.current_platform = {'parsl_v': PARSL_VERSION,
                                 'python_v': "{}.{}.{}".format(sys.version_info.major,
                                                               sys.version_info.minor,
                                                               sys.version_info.micro),
                                 'os': platform.system(),
                                 'hname': platform.node(),
                                 'dir': os.getcwd()}

        logger.info("Platform info: {}".format(self.current_platform))
        self._block_counter = 0
        try:
            self.load_config()
        except Exception as e:
            logger.exception("Caught exception")
            raise


    def load_config(self):
        """ Load the config
        """
        logger.info("Loading endpoint local config")
        working_dir = self.config.working_dir
        if self.config.working_dir is None:
            working_dir = "{}/{}".format(self.logdir, "worker_logs")
        logger.info("Setting working_dir: {}".format(working_dir))

        self.config.provider.script_dir = working_dir
        if hasattr(self.config.provider, 'channel'):
            self.config.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts')
            self.config.provider.channel.makedirs(self.config.provider.channel.script_dir, exist_ok=True)
            os.makedirs(self.config.provider.script_dir, exist_ok=True)

        debug_opts = "--debug" if self.config.worker_debug else ""
        max_workers = "" if self.config.max_workers_per_node == float('inf') \
                      else "--max_workers={}".format(self.config.max_workers_per_node)

        worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}"
        worker_result_url = f"tcp://{self.interchange_address}:{self.worker_result_port}"

        l_cmd = self.launch_cmd.format(debug=debug_opts,
                                       max_workers=max_workers,
                                       cores_per_worker=self.config.cores_per_worker,
                                       #mem_per_worker=self.config.mem_per_worker,
                                       prefetch_capacity=self.config.prefetch_capacity,
                                       task_url=worker_task_url,
                                       result_url=worker_result_url,
                                       nodes_per_block=self.config.provider.nodes_per_block,
                                       heartbeat_period=self.config.heartbeat_period,
                                       heartbeat_threshold=self.config.heartbeat_threshold,
                                       poll_period=self.config.poll_period,
                                       worker_mode=self.config.worker_mode,
                                       scheduler_mode=self.config.scheduler_mode,
                                       logdir=working_dir)
        self.launch_cmd = l_cmd
        logger.info("Launch command: {}".format(self.launch_cmd))

        if self.config.scaling_enabled:
            logger.info("Scaling ...")
            self.scale_out(self.config.provider.init_blocks)


    def get_tasks(self, count):
        """ Obtains a batch of tasks from the internal pending_task_queue

        Parameters
        ----------
        count: int
            Count of tasks to get from the queue

        Returns
        -------
        List of up to count tasks. May return fewer than count, down to an empty list
            eg. [{'task_id':<x>, 'buffer':<buf>} ... ]
        """
        tasks = []
        for i in range(0, count):
            try:
                x = self.pending_task_queue.get(block=False)
            except queue.Empty:
                break
            else:
                tasks.append(x)

        return tasks

    def migrate_tasks_to_internal(self, kill_event, status_request):
        """Pull tasks from the incoming tasks 0mq pipe onto the internal
        pending task queue

        Parameters:
        -----------
        kill_event : threading.Event
              Event to let the thread know when it is time to die.
        """
        logger.info("[TASK_PULL_THREAD] Starting")
        task_counter = 0
        poller = zmq.Poller()
        poller.register(self.task_incoming, zmq.POLLIN)

        while not kill_event.is_set():
            # Check when the last heartbeat was.
            # logger.debug(f"[TASK_PULL_THREAD] Last heartbeat: {self.last_heartbeat}")
            if int(time.time() - self.last_heartbeat) > (self.heartbeat_threshold * self.max_heartbeats_missed):
                logger.critical("[TASK_PULL_THREAD] Missed too many heartbeats. Setting kill event.")
                kill_event.set()
                break

            try:
                msg = self.task_incoming.recv_pyobj()
                self.last_heartbeat = time.time()
            except zmq.Again:
                # We just timed out while attempting to receive
                logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count))
                continue

            if msg == 'STOP':
                kill_event.set()
                break
            elif msg == 'STATUS_REQUEST':
                logger.info("Got STATUS_REQUEST")
                status_request.set()
            else:
                logger.info("[TASK_PULL_THREAD] Received task:{}".format(msg))
                task_type = self.get_container(msg['task_id'].split(";")[1])
                msg['container'] = task_type
                if task_type not in self.pending_task_queue:
                    self.pending_task_queue[task_type] = queue.Queue(maxsize=10 ** 6)
                self.pending_task_queue[task_type].put(msg)
                self.total_pending_task_count += 1
                logger.debug("[TASK_PULL_THREAD] pending task count: {}".format(self.total_pending_task_count))
                task_counter += 1
                logger.debug("[TASK_PULL_THREAD] Fetched task:{}".format(task_counter))

    def get_container(self, container_uuid):
        """ Get the container image location if it is not known to the interchange"""
        if container_uuid not in self.containers:
            if container_uuid == 'RAW' or not container_uuid:
                self.containers[container_uuid] = 'RAW'
            else:
                try:
                    container = self.fxs.get_container(container_uuid, self.config.container_type)
                except Exception:
                    logger.exception("[FETCH_CONTAINER] Unable to resolve container location")
                    self.containers[container_uuid] = 'RAW'
                else:
                    logger.info("[FETCH_CONTAINER] Got container info: {}".format(container))
                    self.containers[container_uuid] = container.get('location', 'RAW')
        return self.containers[container_uuid]

    def get_total_tasks_outstanding(self):
        """ Get the outstanding tasks in total
        """
        outstanding = {}
        for task_type in self.pending_task_queue:
            outstanding[task_type] = outstanding.get(task_type, 0) + self.pending_task_queue[task_type].qsize()
        for manager in self._ready_manager_queue:
            for task_type in self._ready_manager_queue[manager]['tasks']:
                outstanding[task_type] = outstanding.get(task_type, 0) + len(self._ready_manager_queue[manager]['tasks'][task_type])
        return outstanding

    def get_total_live_workers(self):
        """ Get the total active workers
        """
        active = 0
        for manager in self._ready_manager_queue:
            if self._ready_manager_queue[manager]['active']:
                active += self._ready_manager_queue[manager]['max_worker_count']
        return active

    def get_outstanding_breakdown(self):
        """ Get outstanding breakdown per manager and in the interchange queues

        Returns
        -------
        List of status for online elements
        [ (element, tasks_pending, status) ... ]
        """

        pending_on_interchange = self.total_pending_task_count
        # Reporting pending on interchange is a deviation from Parsl
        reply = [('interchange', pending_on_interchange, True)]
        for manager in self._ready_manager_queue:
            resp = (manager.decode('utf-8'),
                    sum([len(tids) for tids in self._ready_manager_queue[manager]['tasks'].values()]),
                    self._ready_manager_queue[manager]['active'])
            reply.append(resp)
        return reply

    def _hold_block(self, block_id):
        """ Sends hold command to all managers which are in a specific block

        Parameters
        ----------
        block_id : str
             Block identifier of the block to be put on hold
        """
        for manager in self._ready_manager_queue:
            if self._ready_manager_queue[manager]['active'] and \
               self._ready_manager_queue[manager]['block_id'] == block_id:
                logger.debug("[HOLD_BLOCK]: Sending hold to manager: {}".format(manager))
                self.hold_manager(manager)

    def hold_manager(self, manager):
        """ Put manager on hold
        Parameters
        ----------

        manager : str
          Manager id to be put on hold while being killed
        """
        if manager in self._ready_manager_queue:
            self._ready_manager_queue[manager]['active'] = False
            reply = True
        else:
            reply = False

    def _command_server(self, kill_event):
        """ Command server to run async command to the interchange
        """
        logger.debug("[COMMAND] Command Server Starting")

        while not kill_event.is_set():
            try:
                command_req = self.command_channel.recv_pyobj()
                logger.debug("[COMMAND] Received command request: {}".format(command_req))
                if command_req == "OUTSTANDING_C":
                    reply = self.get_total_tasks_outstanding()

                elif command_req == "MANAGERS":
                    reply = self.get_outstanding_breakdown()

                elif command_req.startswith("HOLD_WORKER"):
                    cmd, s_manager = command_req.split(';')
                    manager = s_manager.encode('utf-8')
                    logger.info("[CMD] Received HOLD_WORKER for {}".format(manager))
                    if manager in self._ready_manager_queue:
                        self._ready_manager_queue[manager]['active'] = False
                        reply = True
                    else:
                        reply = False

                elif command_req == "HEARTBEAT":
                    logger.info("[CMD] Received heartbeat message from hub")
                    reply = "HBT,{}".format(self.endpoint_id)

                elif command_req == "SHUTDOWN":
                    logger.info("[CMD] Received SHUTDOWN command")
                    kill_event.set()
                    reply = True

                else:
                    reply = None

                logger.debug("[COMMAND] Reply: {}".format(reply))
                self.command_channel.send_pyobj(reply)

            except zmq.Again:
                logger.debug("[COMMAND] is alive")
                continue

    def stop(self):
        """Prepare the interchange for shutdown"""
        self._kill_event.set()

        self._task_puller_thread.join()
        self._command_thread.join()

    def start(self, poll_period=None):
        """ Start the Interchange

        Parameters:
        ----------
        poll_period : int
           poll_period in milliseconds
        """
        logger.info("Incoming ports bound")

        if poll_period is None:
            poll_period = self.poll_period

        start = time.time()
        count = 0

        self._kill_event = threading.Event()
        self._status_request = threading.Event()
        self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal,
                                                    args=(self._kill_event, self._status_request, ))
        self._task_puller_thread.start()

        self._command_thread = threading.Thread(target=self._command_server,
                                                args=(self._kill_event, ))
        self._command_thread.start()

        try:
            logger.debug("Starting strategy.")
            self.strategy.start(self)
        except RuntimeError as e:
            # This is raised when re-registering an endpoint as strategy already exists
            logger.debug("Failed to start strategy.")
            logger.info(e)

        poller = zmq.Poller()
        # poller.register(self.task_incoming, zmq.POLLIN)
        poller.register(self.task_outgoing, zmq.POLLIN)
        poller.register(self.results_incoming, zmq.POLLIN)

        # These are managers which we should examine in an iteration
        # for scheduling a job (or maybe any other attention?).
        # Anything altering the state of the manager should add it
        # onto this list.
        interesting_managers = set()

        while not self._kill_event.is_set():
            self.socks = dict(poller.poll(timeout=poll_period))

            # Listen for requests for work
            if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN:
                logger.debug("[MAIN] starting task_outgoing section")
                message = self.task_outgoing.recv_multipart()
                manager = message[0]

                if manager not in self._ready_manager_queue:
                    reg_flag = False

                    try:
                        msg = json.loads(message[1].decode('utf-8'))
                        reg_flag = True
                    except Exception:
                        logger.warning("[MAIN] Got a non-json registration message from manager:{}".format(
                            manager))
                        logger.debug("[MAIN] Message :\n{}\n".format(message))

                    # By default we set up to ignore bad nodes/registration messages.
                    self._ready_manager_queue[manager] = {'last': time.time(),
                                                          'reg_time': time.time(),
                                                          'free_capacity': {'total_workers': 0},
                                                          'max_worker_count': 0,
                                                          'active': True,
                                                          'tasks': collections.defaultdict(set),
                                                          'total_tasks': 0}
                    if reg_flag is True:
                        interesting_managers.add(manager)
                        logger.info("[MAIN] Adding manager: {} to ready queue".format(manager))
                        self._ready_manager_queue[manager].update(msg)
                        logger.info("[MAIN] Registration info for manager {}: {}".format(manager, msg))

                        if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or
                            msg['parsl_v'] != self.current_platform['parsl_v']):
                            logger.warn("[MAIN] Manager {} has incompatible version info with the interchange".format(manager))

                            if self.suppress_failure is False:
                                logger.debug("Setting kill event")
                                self._kill_event.set()
                                e = ManagerLost(manager)
                                result_package = {'task_id': -1,
                                                  'exception': self.serializer.serialize(e)}
                                pkl_package = pickle.dumps(result_package)
                                self.results_outgoing.send(pkl_package)
                                logger.warning("[MAIN] Sent failure reports, unregistering manager")
                            else:
                                logger.debug("[MAIN] Suppressing shutdown due to version incompatibility")

                    else:
                        # Registration has failed.
                        if self.suppress_failure is False:
                            logger.debug("Setting kill event for bad manager")
                            self._kill_event.set()
                            e = BadRegistration(manager, critical=True)
                            result_package = {'task_id': -1,
                                              'exception': self.serializer.serialize(e)}
                            pkl_package = pickle.dumps(result_package)
                            self.results_outgoing.send(pkl_package)
                        else:
                            logger.debug("[MAIN] Suppressing bad registration from manager:{}".format(
                                manager))

                else:
                    self._ready_manager_queue[manager]['last'] = time.time()
                    if message[1] == b'HEARTBEAT':
                        logger.debug("[MAIN] Manager {} sends heartbeat".format(manager))
                        self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE])
                    else:
                        manager_adv = pickle.loads(message[1])
                        logger.debug("[MAIN] Manager {} requested {}".format(manager, manager_adv))
                        self._ready_manager_queue[manager]['free_capacity'].update(manager_adv)
                        self._ready_manager_queue[manager]['free_capacity']['total_workers'] = sum(manager_adv.values())
                        interesting_managers.add(manager)

            # If we had received any requests, check if there are tasks that could be passed

            logger.debug("[MAIN] Managers count (total/interesting): {}/{}".format(
                len(self._ready_manager_queue),
                len(interesting_managers)))

            task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers,
                                                                             self.pending_task_queue,
                                                                             self._ready_manager_queue,
                                                                             scheduler_mode=self.config.scheduler_mode)
            self.total_pending_task_count -= dispatched_task

            for manager in task_dispatch:
                tasks = task_dispatch[manager]
                if tasks:
                    logger.info("[MAIN] Sending task message {} to manager {}".format(tasks, manager))
                    self.task_outgoing.send_multipart([manager, b'', pickle.dumps(tasks)])

            # Receive any results and forward to client
            if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN:
                logger.debug("[MAIN] entering results_incoming section")
                manager, *b_messages = self.results_incoming.recv_multipart()
                if manager not in self._ready_manager_queue:
                    logger.warning("[MAIN] Received a result from a un-registered manager: {}".format(manager))
                else:
                    logger.info("[MAIN] Got {} result items in batch".format(len(b_messages)))
                    for b_message in b_messages:
                        r = pickle.loads(b_message)
                        # logger.debug("[MAIN] Received result for task {} from {}".format(r['task_id'], manager))
                        task_type = self.containers[r['task_id'].split(';')[1]]
                        self._ready_manager_queue[manager]['tasks'][task_type].remove(r['task_id'])
                    self._ready_manager_queue[manager]['total_tasks'] -= len(b_messages)
                    self.results_outgoing.send_multipart(b_messages)
                    logger.debug("[MAIN] Current tasks: {}".format(self._ready_manager_queue[manager]['tasks']))
                logger.debug("[MAIN] leaving results_incoming section")

            # logger.debug("[MAIN] entering bad_managers section")
            bad_managers = [manager for manager in self._ready_manager_queue if
                            time.time() - self._ready_manager_queue[manager]['last'] > self.heartbeat_threshold]
            for manager in bad_managers:
                logger.debug("[MAIN] Last: {} Current: {}".format(self._ready_manager_queue[manager]['last'], time.time()))
                logger.warning("[MAIN] Too many heartbeats missed for manager {}".format(manager))
                e = ManagerLost(manager)
                for task_type in self._ready_manager_queue[manager]['tasks']:
                    for tid in self._ready_manager_queue[manager]['tasks'][task_type]:
                        result_package = {'task_id': tid, 'exception': self.serializer.serialize(e)}
                        pkl_package = pickle.dumps(result_package)
                        self.results_outgoing.send(pkl_package)
                logger.warning("[MAIN] Sent failure reports, unregistering manager")
                self._ready_manager_queue.pop(manager, None)
                if manager in interesting_managers:
                    interesting_managers.remove(manager)
            logger.debug("[MAIN] ending one main loop iteration")

            if self._status_request.is_set():
                logger.info("status request response")
                result_package = self.get_status_report()
                pkl_package = pickle.dumps(result_package)
                self.results_outgoing.send(pkl_package)
                logger.info("[MAIN] Sent info response")
                self._status_request.clear()

        delta = time.time() - start
        logger.info("Processed {} tasks in {} seconds".format(count, delta))
        logger.warning("Exiting")

    def get_status_report(self):
        """ Get utilization numbers
        """
        total_cores = 0
        total_mem = 0
        core_hrs = 0
        active_managers = 0
        free_capacity = 0
        outstanding_tasks = self.get_total_tasks_outstanding()
        pending_tasks = self.total_pending_task_count
        num_managers = len(self._ready_manager_queue)
        live_workers = self.get_total_live_workers()
        
        for manager in self._ready_manager_queue:
            total_cores += self._ready_manager_queue[manager]['cores']
            total_mem += self._ready_manager_queue[manager]['mem']
            active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time'])
            core_hrs += (active_dur * self._ready_manager_queue[manager]['cores']) / 3600
            if self._ready_manager_queue[manager]['active']:
                active_managers += 1
            free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers']

        result_package = {'task_id': -2,
                          'info': {'total_cores': total_cores,
                                   'total_mem' : total_mem,
                                   'new_core_hrs': core_hrs - self.last_core_hr_counter,
                                   'total_core_hrs': round(core_hrs, 2),
                                   'managers': num_managers,
                                   'active_managers': active_managers,
                                   'total_workers': live_workers,
                                   'idle_workers': free_capacity,
                                   'pending_tasks': pending_tasks,
                                   'outstanding_tasks': outstanding_tasks,
                                   'worker_mode': self.config.worker_mode,
                                   'scheduler_mode': self.config.scheduler_mode,
                                   'scaling_enabled': self.config.scaling_enabled,
                                   'mem_per_worker': self.config.mem_per_worker,
                                   'cores_per_worker': self.config.cores_per_worker,
                                   'prefetch_capacity': self.config.prefetch_capacity,
                                   'max_blocks': self.config.provider.max_blocks,
                                   'min_blocks': self.config.provider.min_blocks,
                                   'max_workers_per_node': self.config.max_workers_per_node,
                                   'nodes_per_block': self.config.provider.nodes_per_block
        }}

        self.last_core_hr_counter = core_hrs
        return result_package

    def scale_out(self, blocks=1, task_type=None):
        """Scales out the number of blocks by "blocks"

        Raises:
             NotImplementedError
        """
        r = []
        for i in range(blocks):
            if self.config.provider:
                self._block_counter += 1
                external_block_id = str(self._block_counter)
                if not task_type and self.config.scheduler_mode == 'hard':
                    launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW')
                else:
                    launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type)
                if not task_type:
                    internal_block = self.config.provider.submit(launch_cmd, 1)
                else:
                    internal_block = self.config.provider.submit(launch_cmd, 1, task_type)
                logger.debug("Launched block {}->{}".format(external_block_id, internal_block))
                if not internal_block:
                    raise ScalingFailed(self.config.provider.label,
                                        "Attempt to provision nodes via provider failed")
                self.blocks[external_block_id] = internal_block
                self.block_id_map[internal_block] = external_block_id
            else:
                logger.error("No execution provider available")
                r = None
        return r

    def scale_in(self, blocks=None, block_ids=None, task_type=None):
        """Scale in the number of active blocks by the specified amount.

        Parameters
        ----------
        blocks : int
            Number of blocks to terminate

        block_ids : [str.. ]
            List of external block ids to terminate

        task_type : str
            Task type of the blocks to terminate; the provider decides which blocks to kill
        """
        if task_type:
            logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type))
            if self.config.scaling_enabled and self.config.provider:
                to_kill, r = self.config.provider.cancel(blocks, task_type)
                logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r))
                for job in to_kill:
                    logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job))
                    block_id = self.block_id_map[job]
                    logger.info("[scale_in] Holding block {}".format(block_id))
                    self._hold_block(block_id)
                    self.blocks.pop(block_id)
                return r

        if block_ids:
            block_ids_to_kill = block_ids
        else:
            block_ids_to_kill = list(self.blocks.keys())[:blocks]

        # Try a polite terminate
        # TODO : Missing logic to hold blocks
        for block_id in block_ids_to_kill:
            self._hold_block(block_id)

        # Now kill via provider
        to_kill = [self.blocks.pop(bid) for bid in block_ids_to_kill]

        r = None
        if self.config.scaling_enabled and self.config.provider:
            r = self.config.provider.cancel(to_kill)

        return r

    def provider_status(self):
        """ Get status of all blocks from the provider
        """
        status = []
        if self.config.provider:
            logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values())))
            status = self.config.provider.status(list(self.blocks.values()))
            logger.debug("[MAIN] The status is {}".format(status))

        return status
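
# A standalone, simplified sketch of the dispatch step driven by the main loop
# above (naive_interchange_task_dispatch). This is illustrative only, not the
# funcx_endpoint implementation: it assumes a single pending-task Queue and the
# 'free_capacity' bookkeeping shown above, and fills each interested manager up
# to its advertised number of free workers.
def naive_dispatch_sketch(interesting_managers, pending_task_queue, ready_manager_queue):
    task_dispatch = {}
    dispatched = 0
    for manager in list(interesting_managers):
        free = ready_manager_queue[manager]['free_capacity']['total_workers']
        tasks = []
        while free > 0 and not pending_task_queue.empty():
            tasks.append(pending_task_queue.get())
            free -= 1
        if tasks:
            task_dispatch[manager] = tasks
            dispatched += len(tasks)
        if free == 0:
            # A saturated manager stops being interesting until its next advert
            interesting_managers.discard(manager)
    return task_dispatch, dispatched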
Beispiel #20
0
    def start_endpoint(self, name, endpoint_uuid, endpoint_config):
        self.name = name

        endpoint_dir = os.path.join(self.funcx_dir, self.name)
        endpoint_json = os.path.join(endpoint_dir, "endpoint.json")

        # These certs need to be recreated for every registration
        keys_dir = os.path.join(endpoint_dir, "certificates")
        os.makedirs(keys_dir, exist_ok=True)
        client_public_file, client_secret_file = zmq.auth.create_certificates(
            keys_dir, "endpoint")
        client_public_key, _ = zmq.auth.load_certificate(client_public_file)
        client_public_key = client_public_key.decode("utf-8")

        # This is to ensure that at least 1 executor is defined
        if not endpoint_config.config.executors:
            raise Exception(
                f"Endpoint config file at {endpoint_dir} is missing "
                "executor definitions")

        funcx_client_options = {
            "funcx_service_address":
            endpoint_config.config.funcx_service_address,
            "check_endpoint_version": True,
        }
        funcx_client = FuncXClient(**funcx_client_options)

        endpoint_uuid = self.check_endpoint_json(endpoint_json, endpoint_uuid)

        log.info(f"Starting endpoint with uuid: {endpoint_uuid}")

        pid_file = os.path.join(endpoint_dir, "daemon.pid")
        pid_check = self.check_pidfile(pid_file)
        # if the pidfile exists, we should return early because we don't
        # want to attempt to create a new daemon when one is already
        # potentially running with the existing pidfile
        if pid_check["exists"]:
            if pid_check["active"]:
                log.info("Endpoint is already active")
                sys.exit(-1)
            else:
                log.info(
                    "A prior Endpoint instance appears to have been terminated without "
                    "proper cleanup. Cleaning up now.")
                self.pidfile_cleanup(pid_file)

        results_ack_handler = ResultsAckHandler(endpoint_dir=endpoint_dir)

        try:
            results_ack_handler.load()
            results_ack_handler.persist()
        except Exception:
            log.exception(
                "Caught exception while attempting load and persist of outstanding "
                "results")
            sys.exit(-1)

        # Create a daemon context
        # If we are running a fully detached daemon then we will send the output to
        # log files, otherwise we can piggyback on our own stdout
        if endpoint_config.config.detach_endpoint:
            stdout = open(
                os.path.join(endpoint_dir, endpoint_config.config.stdout),
                "a+")
            stderr = open(
                os.path.join(endpoint_dir, endpoint_config.config.stderr),
                "a+")
        else:
            stdout = sys.stdout
            stderr = sys.stderr

        try:
            context = daemon.DaemonContext(
                working_directory=endpoint_dir,
                umask=0o002,
                pidfile=daemon.pidfile.PIDLockFile(pid_file),
                stdout=stdout,
                stderr=stderr,
                detach_process=endpoint_config.config.detach_endpoint,
            )

        except Exception:
            log.exception(
                "Caught exception while trying to setup endpoint context dirs")
            sys.exit(-1)

        # place registration after everything else so that the endpoint will
        # only be registered if everything else has been set up successfully
        reg_info = None
        try:
            reg_info = register_endpoint(funcx_client, endpoint_uuid,
                                         endpoint_dir, self.name)
        # if the service sends back an error response, it will be a FuncxResponseError
        except FuncxResponseError as e:
            # an example of an error that could conceivably occur here would be
            # if the service could not register this endpoint with the forwarder
            # because the forwarder was unreachable
            if e.http_status_code >= 500:
                log.exception(
                    "Caught exception while attempting endpoint registration")
                log.critical(
                    "Endpoint registration will be retried in the new endpoint daemon "
                    "process. The endpoint will not work until it is successfully "
                    "registered.")
            else:
                raise e
        # if the service has an unexpected internal error and is unable to send
        # back a FuncxResponseError
        except GlobusAPIError as e:
            if e.http_status >= 500:
                log.exception(
                    "Caught exception while attempting endpoint registration")
                log.critical(
                    "Endpoint registration will be retried in the new endpoint daemon "
                    "process. The endpoint will not work until it is successfully "
                    "registered.")
            else:
                raise e
        # if the service is unreachable due to a timeout or connection error
        except NetworkError as e:
            # the output of a NetworkError exception is huge and unhelpful, so
            # it seems better to just stringify it here and get a concise error
            log.exception(
                f"Caught exception while attempting endpoint registration: {e}"
            )
            log.critical(
                "funcx-endpoint is unable to reach the funcX service due to a "
                "NetworkError \n"
                "Please make sure that the funcX service address you provided is "
                "reachable \n"
                "and then attempt restarting the endpoint")
            exit(-1)
        except Exception:
            raise

        if reg_info:
            log.info("Launching endpoint daemon process")
        else:
            log.critical(
                "Launching endpoint daemon process with errors noted above")

        # NOTE
        # It's important that this log is emitted before we enter the daemon context,
        # because daemonization closes everything down and a log message emitted inside
        # the context won't reach the currently configured loggers
        logfile = os.path.join(endpoint_dir, "endpoint.log")
        log.info(
            "Logging will be reconfigured for the daemon. logfile=%s , debug=%s",
            logfile,
            self.debug,
        )

        with context:
            setup_logging(logfile=logfile,
                          debug=self.debug,
                          console_enabled=False)
            self.daemon_launch(
                endpoint_uuid,
                endpoint_dir,
                keys_dir,
                endpoint_config,
                reg_info,
                funcx_client_options,
                results_ack_handler,
            )
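
# A small standalone sketch of the pidfile check used above. The real helper,
# check_pidfile, is not shown in this snippet; this version only illustrates
# the assumed contract: report whether the pidfile exists and whether the
# recorded process is still alive.
import os


def check_pidfile_sketch(pid_file):
    status = {"exists": False, "active": False}
    if not os.path.exists(pid_file):
        return status
    status["exists"] = True
    try:
        with open(pid_file) as f:
            pid = int(f.read().strip())
        os.kill(pid, 0)  # signal 0 probes for existence without signalling
        status["active"] = True
    except (ValueError, OSError):
        # Unparsable pidfile or no such process -> treat the pidfile as stale
        pass
    return status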
Beispiel #21
0
from funcx.sdk.client import FuncXClient
fxc = FuncXClient()


def funcx_test():
    while True:
        print("Viana")


func_uuid = fxc.register_function(funcx_test)
tutorial_endpoint = '70d29c21-66c3-4ba8-98fc-91490b522699'  # Public tutorial endpoint
res = fxc.run(endpoint_id=tutorial_endpoint, function_id=func_uuid)
funcx_test()
Beispiel #22
0
import argparse

from funcx.sdk.client import FuncXClient


def sum_yadu_new01(event):
    return sum(event)


"""
@funcx.register(description="...")
def sum_yadu_new01(event):
    return sum(event)
"""


def test(fxc, ep_id):

    fn_uuid = fxc.register_function(
        sum_yadu_new01,
        description="New sum function defined without string spec",
    )
    print("FN_UUID : ", fn_uuid)

    res = fxc.run([1, 2, 3, 99], endpoint_id=ep_id, function_id=fn_uuid)
    print(res)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--endpoint", required=True)
    args = parser.parse_args()

    fxc = FuncXClient()
    test(fxc, args.endpoint)
Beispiel #23
0
from funcx.sdk.client import FuncXClient

if __name__ == "__main__":

    fxc = FuncXClient()
    print(fxc)

    fxc.register_endpoint("foobar", None)
Beispiel #24
0
    def __init__(
        self,
        config,
        client_address="127.0.0.1",
        interchange_address="127.0.0.1",
        client_ports: Tuple[int, int, int] = (50055, 50056, 50057),
        launch_cmd=None,
        logdir=".",
        endpoint_id=None,
        keys_dir=".curve",
        suppress_failure=True,
        endpoint_dir=".",
        endpoint_name="default",
        reg_info=None,
        funcx_client_options=None,
        results_ack_handler=None,
    ):
        """
        Parameters
        ----------
        config : funcx.Config object
             Funcx config object that describes how compute should be provisioned

        client_address : str
             The ip address at which the parsl client can be reached.
             Default: "127.0.0.1"

        interchange_address : str
             The ip address at which the workers will be able to reach the Interchange.
             Default: "127.0.0.1"

        client_ports : Tuple[int, int, int]
             The ports at which the client can be reached

        launch_cmd : str
             TODO : update

        logdir : str
             Parsl log directory paths. Logs and temp files go here. Default: '.'

        keys_dir : str
             Directory from where keys used for communicating with the funcX
             service (forwarders) are stored

        endpoint_id : str
             Identity string that identifies the endpoint to the broker

        suppress_failure : Bool
             When set to True, the interchange will attempt to suppress failures.
             Default: True

        endpoint_dir : str
             Endpoint directory path to store registration info in

        endpoint_name : str
             Name of endpoint

        reg_info : Dict
             Registration info from initial registration on endpoint start, if it
             succeeded

        funcx_client_options : Dict
             FuncXClient initialization options
        """
        self.logdir = logdir
        log.info(
            "Initializing EndpointInterchange process with Endpoint ID: {}".
            format(endpoint_id))
        self.config = config
        log.info(f"Got config: {config}")

        self.client_address = client_address
        self.interchange_address = interchange_address
        self.client_ports = client_ports
        self.suppress_failure = suppress_failure

        self.endpoint_dir = endpoint_dir
        self.endpoint_name = endpoint_name

        if funcx_client_options is None:
            funcx_client_options = {}
        self.funcx_client = FuncXClient(**funcx_client_options)

        self.initial_registration_complete = False
        if reg_info:
            self.initial_registration_complete = True
            self.apply_reg_info(reg_info)

        self.heartbeat_period = self.config.heartbeat_period
        self.heartbeat_threshold = self.config.heartbeat_threshold
        # initialize the last heartbeat time to start the loop
        self.last_heartbeat = time.time()
        self.keys_dir = keys_dir
        self.serializer = FuncXSerializer()

        self.pending_task_queue = Queue()
        self.containers = {}
        self.total_pending_task_count = 0

        self._quiesce_event = threading.Event()
        self._kill_event = threading.Event()

        self.results_ack_handler = results_ack_handler

        log.info(f"Interchange address is {self.interchange_address}")

        self.endpoint_id = endpoint_id

        self.current_platform = {
            "parsl_v": PARSL_VERSION,
            "python_v": "{}.{}.{}".format(sys.version_info.major,
                                          sys.version_info.minor,
                                          sys.version_info.micro),
            "libzmq_v": zmq.zmq_version(),
            "pyzmq_v": zmq.pyzmq_version(),
            "os": platform.system(),
            "hname": platform.node(),
            "funcx_sdk_version": funcx_sdk_version,
            "funcx_endpoint_version": funcx_endpoint_version,
            "registration": self.endpoint_id,
            "dir": os.getcwd(),
        }

        log.info(f"Platform info: {self.current_platform}")
        try:
            self.load_config()
        except Exception:
            log.exception("Caught exception")
            raise

        self.tasks = set()
        self.task_status_deltas = {}

        self._test_start = False
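
# A hedged construction sketch for the interchange above. The import paths, the
# Config/HighThroughputExecutor defaults, and the start() entry point are
# assumptions based on the funcx-endpoint package layout and may differ across
# versions; the endpoint id below is purely illustrative.
from funcx_endpoint.endpoint.interchange import EndpointInterchange
from funcx_endpoint.endpoint.utils.config import Config
from funcx_endpoint.executors import HighThroughputExecutor

ic = EndpointInterchange(
    Config(executors=[HighThroughputExecutor()]),
    endpoint_id="11111111-2222-3333-4444-555555555555",   # illustrative id
    logdir="./interchange_logs",
    funcx_client_options={"check_endpoint_version": True},
)
ic.start()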
Beispiel #25
0
from funcx.sdk.client import FuncXClient
import time

fx = FuncXClient()


def test_batch1(a, b, c=2, d=2):
    return a + b + c + d


def test_batch2(a, b, c=2, d=2):
    return a * b * c * d


def test_batch3(a, b, c=2, d=2):
    return a + 2 * b + 3 * c + 4 * d


funcs = [test_batch1, test_batch2, test_batch3]
func_ids = []
for func in funcs:
    func_ids.append(fx.register_function(func, description='test'))

ep_id = '4b116d3c-1703-4f8f-9f6f-39921e5864df'
print("FN_UUID : ", func_ids)

start = time.time()
task_count = 5
batch = fx.create_batch()
for func_id in func_ids:
    for i in range(task_count):
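        # The original example is truncated here. A plausible completion,
        # assuming the funcX batch API (Batch.add, FuncXClient.batch_run and
        # FuncXClient.get_batch_result), would queue each invocation and then
        # submit the whole batch at once:
        batch.add(i, i + 1, endpoint_id=ep_id, function_id=func_id)

task_ids = fx.batch_run(batch)
print(f"Submitted {len(task_ids)} tasks in {time.time() - start:.3f} s")
results = fx.get_batch_result(task_ids)
print(results)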
Beispiel #26
0
    def __init__(self, dlh_authorizer: Optional[GlobusAuthorizer] = None,
                 search_authorizer: Optional[GlobusAuthorizer] = None,
                 fx_authorizer: Optional[GlobusAuthorizer] = None,
                 openid_authorizer: Optional[GlobusAuthorizer] = None,
                 http_timeout: Optional[int] = None,
                 force_login: bool = False, **kwargs):
        """Initialize the client

        Args:
            http_timeout (int): Timeout for any call to service in seconds. (default is no timeout)
            force_login (bool): Whether to force a login to get new credentials.
                A login will always occur if any of ``dlh_authorizer``, ``search_authorizer``,
                ``fx_authorizer``, or ``openid_authorizer`` are not provided.
            no_local_server (bool): Disable spinning up a local server to automatically
                copy-paste the auth code. THIS IS REQUIRED if you are on a remote server.
                When used locally with no_local_server=False, the domain is localhost with
                a randomly chosen open port number.
                **Default**: ``True``.
            no_browser (bool): Do not automatically open the browser for the Globus Auth URL.
                Display the URL instead and let the user navigate to that location manually.
                **Default**: ``True``.
            dlh_authorizer (:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with DLHub.
                If ``None``, will be created from your account's credentials.
            search_authorizer (:class:`GlobusAuthorizer <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authenticated SearchClient to communicate with Globus Search.
                If ``None``, will be created from your account's credentials.
            fx_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with funcX.
                If ``None``, will be created from your account's credentials.
            openid_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with OpenID.
                If ``None``, will be created from your account's credentials.
        Keyword arguments are the same as for :class:`BaseClient <globus_sdk.base.BaseClient>`.
        """

        authorizers = [dlh_authorizer, search_authorizer, openid_authorizer, fx_authorizer]
        # Get authorizers through Globus login if any are not provided
        if not all(a is not None for a in authorizers):
            # If some but not all were provided, warn the user they could be making a mistake
            if any(a is not None for a in authorizers):
                logger.warning('You have defined some of the authorizers but not all. DLHub is falling back to login. '
                               'You must provide authorizers for DLHub, Search, OpenID, FuncX.')

            fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
            auth_res = login(services=["search", "dlhub",
                                       fx_scope, "openid"],
                             app_name="DLHub_Client",
                             make_clients=False,
                             client_id=CLIENT_ID,
                             clear_old_tokens=force_login,
                             token_dir=_token_dir,
                             no_local_server=kwargs.get("no_local_server", True),
                             no_browser=kwargs.get("no_browser", True))

            # Unpack the authorizers
            dlh_authorizer = auth_res["dlhub"]
            fx_authorizer = auth_res[fx_scope]
            openid_authorizer = auth_res['openid']
            search_authorizer = auth_res['search']

        # Define the subclients needed by the service
        self._fx_client = FuncXClient(fx_authorizer=fx_authorizer,
                                      search_authorizer=search_authorizer,
                                      openid_authorizer=openid_authorizer,
                                      no_local_server=kwargs.get("no_local_server", True),
                                      no_browser=kwargs.get("no_browser", True))
        self._search_client = globus_sdk.SearchClient(authorizer=search_authorizer,
                                                      http_timeout=5 * 60)

        # funcX endpoint to use
        self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
        self.fx_cache = {}
        super(DLHubClient, self).__init__("DLHub", environment='dlhub',
                                          authorizer=dlh_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=DLHUB_SERVICE_ADDRESS,
                                          **kwargs)
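
# A minimal usage sketch for the client above (assuming it is exposed as
# dlhub_sdk.client.DLHubClient). On a remote or headless machine, the
# no_local_server/no_browser flags described in the docstring are required so
# that the Globus Auth flow prints a URL instead of opening a browser.
from dlhub_sdk.client import DLHubClient

dl = DLHubClient(no_local_server=True, no_browser=True)
print(dl.list_servables()[:5])  # first few servables visible to this account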
Beispiel #27
0
class DLHubClient(BaseClient):
    """Main class for interacting with the DLHub service

    Holds helper operations for performing common tasks with the DLHub service. For example,
    `get_servables` produces a list of all servables registered with DLHub.

    For most cases, we recommend creating a new DLHubClient by calling ``DLHubClient.login``.
    This operation will check if you have saved any credentials to disk before using the CLI or SDK
    and, if not, get new credentials and save them for later use.
    For cases where disk access is unacceptable, you can create the client by creating an authorizer
    following the
    `tutorial for the Globus SDK <https://globus-sdk-python.readthedocs.io/en/stable/tutorial/>`_
    and providing that authorizer to the initializer (e.g., ``DLHubClient(auth)``)"""
    def __init__(self,
                 dlh_authorizer=None,
                 search_client=None,
                 http_timeout=None,
                 force_login=False,
                 fx_authorizer=None,
                 **kwargs):
        """Initialize the client

        Args:
            dlh_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with DLHub.
                If ``None``, will be created.
            search_client (:class:`SearchClient <globus_sdk.SearchClient>`):
                An authenticated SearchClient to communicate with Globus Search.
                If ``None``, will be created.
            http_timeout (int): Timeout for any call to service in seconds. (default is no timeout)
            force_login (bool): Whether to force a login to get new credentials.
                A login will always occur if ``dlh_authorizer`` or ``search_client``
                are not provided.
            no_local_server (bool): Disable spinning up a local server to automatically
                copy-paste the auth code. THIS IS REQUIRED if you are on a remote server.
                When used locally with no_local_server=False, the domain is localhost with
                a randomly chosen open port number.
                **Default**: ``True``.
            fx_authorizer (:class:`GlobusAuthorizer
                            <globus_sdk.authorizers.base.GlobusAuthorizer>`):
                An authorizer instance used to communicate with funcX.
                If ``None``, will be created.
            no_browser (bool): Do not automatically open the browser for the Globus Auth URL.
                Display the URL instead and let the user navigate to that location manually.
                **Default**: ``True``.
        Keyword arguments are the same as for BaseClient.
        """
        if force_login or not dlh_authorizer or not search_client or not fx_authorizer:

            fx_scope = "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all"
            auth_res = login(services=["search", "dlhub", fx_scope, "openid"],
                             app_name="DLHub_Client",
                             client_id=CLIENT_ID,
                             clear_old_tokens=force_login,
                             token_dir=_token_dir,
                             no_local_server=kwargs.get(
                                 "no_local_server", True),
                             no_browser=kwargs.get("no_browser", True))
            # openid_authorizer = auth_res["openid"]
            dlh_authorizer = auth_res["dlhub"]
            fx_authorizer = auth_res[fx_scope]
            self._search_client = auth_res["search"]

        self._fx_client = FuncXClient(
            force_login=force_login,
            no_local_server=kwargs.get("no_local_server", True),
            no_browser=kwargs.get("no_browser", True),
            funcx_service_address='https://api.funcx.org/v1',
        )

        # funcX endpoint to use
        self.fx_endpoint = '86a47061-f3d9-44f0-90dc-56ddc642c000'
        # self.fx_endpoint = '2c92a06a-015d-4bfa-924c-b3d0c36bdad7'
        self.fx_serializer = FuncXSerializer()
        self.fx_cache = {}
        super(DLHubClient, self).__init__("DLHub",
                                          environment='dlhub',
                                          authorizer=dlh_authorizer,
                                          http_timeout=http_timeout,
                                          base_url=DLHUB_SERVICE_ADDRESS,
                                          **kwargs)

    def logout(self):
        """Remove credentials from your local system"""
        logout()

    @property
    def query(self):
        """Access a query of the DLHub Search repository"""
        return DLHubSearchHelper(search_client=self._search_client)

    def get_username(self):
        """Get the username associated with the current credentials"""

        res = self.get('/namespaces')
        return res.data['namespace']

    def get_servables(self, only_latest_version=True):
        """Get all of the servables available in the service

        Args:
            only_latest_version (bool): Whether to only return the latest version of each servable
        Returns:
            ([list]) Complete metadata for all servables found in DLHub
        """

        # Get all of the servables
        results, info = self.query.match_field('dlhub.type', 'servable')\
            .add_sort('dlhub.owner', ascending=True).add_sort('dlhub.name', ascending=False)\
            .add_sort('dlhub.publication_date', ascending=False).search(info=True)
        if info['total_query_matches'] > SEARCH_LIMIT:
            raise RuntimeError(
                'DLHub contains more servables than we can return in one entry. '
                'DLHub SDK needs to be updated.')

        if only_latest_version:
            # Sort out only the most recent versions (they come first in the sorted list)
            names = set()
            output = []
            for r in results:
                name = r['dlhub']['shorthand_name']
                if name not in names:
                    names.add(name)
                    output.append(r)
            results = output

        # Add these to the cache
        for r in results:
            self.fx_cache[r['dlhub']['shorthand_name']] = r['dlhub']['funcx_id']

        return results

    def list_servables(self):
        """Get a list of the servables available in the service

        Returns:
            [string]: List of all servable names in username/servable_name format
        """

        servables = self.get_servables(only_latest_version=True)
        return [x['dlhub']['shorthand_name'] for x in servables]

    def get_task_status(self, task_id):
        """Get the status of a DLHub task.

        Args:
            task_id (string): UUID of the task
        Returns:
            dict: status block containing "status" key.
        """

        r = self._fx_client.get_task(task_id)
        return r

    def describe_servable(self, name):
        """Get the description for a certain servable

        Args:
            name (string): DLHub name of the servable of the form <user>/<servable_name>
        Returns:
            dict: Summary of the servable
        """
        split_name = name.split('/')
        if len(split_name) < 2:
            raise AttributeError(
                'Please enter name in the form <user>/<servable_name>')

        # Create a query for a single servable
        query = self.query.match_servable('/'.join(split_name[1:]))\
            .match_owner(split_name[0]).add_sort("dlhub.publication_date", False)\
            .search(limit=1)

        # Raise error if servable is not found
        if len(query) == 0:
            raise AttributeError('No such servable: {}'.format(name))
        return query[0]

    def describe_methods(self, name, method=None):
        """Get the description for the method(s) of a certain servable

        Args:
            name (string): DLHub name of the servable of the form <user>/<servable_name>
            method (string): Optional: Name of the method
        Returns:
             dict: Description of a certain method if ``method`` provided, all methods
                if the method name was not provided.
        """

        metadata = self.describe_servable(name)
        return get_method_details(metadata, method)

    def run(self,
            name,
            inputs,
            asynchronous=False,
            async_wait=5,
            timeout: Optional[float] = None) -> Union[Any, DLHubFuture]:
        """Invoke a DLHub servable

        Args:
            name (string): DLHub name of the servable of the form <user>/<servable_name>
            inputs: Data to be used as input to the function. Can be a string of file paths or URLs
            asynchronous (bool): Whether to return from the function immediately or
                wait for the execution to finish.
            async_wait (float): How many seconds to wait between checking async status
            timeout (float): How long to wait for a result to return.
                Only used for synchronous calls
        Returns:
            Results of running the servable. If asynchronous, then a DLHubFuture holding the result
        """

        if name not in self.fx_cache:
            # Look it up and add it to the cache, this will raise an exception if not found.
            serv = self.describe_servable(name)
            self.fx_cache.update({name: serv['dlhub']['funcx_id']})

        funcx_id = self.fx_cache[name]
        payload = {'data': inputs}
        task_id = self._fx_client.run(payload,
                                      endpoint_id=self.fx_endpoint,
                                      function_id=funcx_id)

        # Return the result
        future = DLHubFuture(self, task_id, async_wait)
        return future.result(timeout=timeout) if not asynchronous else future

    def run_serial(self, servables, inputs, async_wait=5):
        """Invoke each servable in a serial pipeline.
        This function accepts a list of servables and will run each one,
        passing the output of one as the input to the next.

        Args:
             servables (list): A list of servable strings
             inputs: Data to pass to the first servable
             async_wait (float): Seconds to wait between status checks
        Returns:
            Results of running the servable
        """
        if not isinstance(servables, list):
            raise TypeError("run_serial requires a list of servables to invoke.")

        serv_data = inputs
        for serv in servables:
            serv_data = self.run(serv, serv_data, async_wait=async_wait)
        return serv_data

    def get_result(self, task_id, verbose=False):
        """Get the result of a task_id

        Args:
            task_id (str): The task's uuid
            verbose (bool): Whether or not to return the full response
        Returns:
            Result of running the servable
        """

        result = self._fx_client.get_result(task_id)
        if isinstance(result, tuple) and not verbose:
            result = result[0]
        return result

    def publish_servable(self, model):
        """Submit a servable to DLHub

        If this servable has not been published before, it will be assigned a unique identifier.

        If it has been published before (DLHub detects if it has an identifier), then DLHub
        will update the servable to the new version.

        Args:
            model (BaseMetadataModel): Servable to be submitted
        Returns:
            (string): Task ID of this submission, used for checking for success
        """

        # Get the metadata
        metadata = model.to_dict(simplify_paths=True)

        # Mark the method used to submit the model
        metadata['dlhub']['transfer_method'] = {'POST': 'file'}

        # Validate against the servable schema
        validate_against_dlhub_schema(metadata, 'servable')

        # Wipe the fx cache so we don't keep reusing an old servable
        self.clear_funcx_cache()

        # Get the data to be submitted as a ZIP file
        fp, zip_filename = mkstemp('.zip')
        os.close(fp)
        os.unlink(zip_filename)
        try:
            model.get_zip_file(zip_filename)

            # Get the authorization headers
            headers = {}
            self.authorizer.set_authorization_header(headers)

            # Submit data to DLHub service
            with open(zip_filename, 'rb') as zf:
                reply = requests.post(slash_join(self.base_url, 'publish'),
                                      headers=headers,
                                      files={
                                          'json':
                                          ('dlhub.json', json.dumps(metadata),
                                           'application/json'),
                                          'file': ('servable.zip', zf,
                                                   'application/octet-stream')
                                      })

            # Return the task id
            if reply.status_code != 200:
                raise Exception(reply.text)
            return reply.json()['task_id']
        finally:
            os.unlink(zip_filename)

    def publish_repository(self, repository):
        """Submit a repository to DLHub for publication

        Args:
            repository (string): Repository to publish
        Returns:
            (string): Task ID of this submission, used for checking for success
        """

        # Publish to DLHub
        metadata = {"repository": repository}

        # Wipe the fx cache so we don't keep reusing an old servable
        self.clear_funcx_cache()

        response = self.post('publish_repo', json_body=metadata)

        task_id = response.data['task_id']
        return task_id

    def search(self, query, advanced=False, limit=None, only_latest=True):
        """Query the DLHub servable library

        By default, the query is used as a simple plaintext search of all model metadata.
        Optionally, you can provide an advanced query on any of the indexed fields in
        the DLHub model metadata by setting :code:`advanced=True` and following the guide for
        constructing advanced queries found in the
        `Globus Search documentation <https://docs.globus.org/api/search/search/#query_syntax>`_.

        Args:
             query (string): Query to be performed
             advanced (bool): Whether to perform an advanced query
             limit (int): Maximum number of entries to return
             only_latest (bool): Whether to return only the latest version of the model
        Returns:
            ([dict]): All records matching the search query
        """

        results = self.query.search(query, advanced=advanced, limit=limit)
        return filter_latest(results) if only_latest else results

    def search_by_servable(self,
                           servable_name=None,
                           owner=None,
                           version=None,
                           only_latest=True,
                           limit=None,
                           get_info=False):
        """Search by the ownership, name, or version of a servable

        Args:
            servable_name (str): The name of the servable. **Default**: None, to match
                    all servable names.
            owner (str): The name of the owner of the servable. **Default**: ``None``,
                    to match all owners.
            version (int): Model version, which corresponds to the date when the
                servable was published. **Default**: ``None``, to match all versions.
            only_latest (bool): When ``True``, will only return the latest version
                    of each servable. When ``False``, will return all matching versions.
                    **Default**: ``True``.
            limit (int): The maximum number of results to return.
                    **Default:** ``None``, for no limit.
            get_info (bool): If ``False``, search will return a list of the results.
                    If ``True``, search will return a tuple containing the results list
                    and other information about the query.
                    **Default:** ``False``.

        Returns:
            If ``get_info`` is ``False``, *list*: The search results.
            If ``get_info`` is ``True``, *tuple*: The search results,
            and a dictionary of query information.
        """
        if not servable_name and not owner and not version:
            raise ValueError(
                "One of 'servable_name', 'owner', or 'publication_date' is required."
            )

        # Perform the query
        results, info = (self.query.match_servable(
            servable_name=servable_name, owner=owner,
            publication_date=version).search(limit=limit, info=True))

        # Filter out the latest models
        if only_latest:
            results = filter_latest(results)

        if get_info:
            return results, info
        return results

    def search_by_authors(self,
                          authors,
                          match_all=True,
                          limit=None,
                          only_latest=True):
        """Execute a search for servables from certain authors.

        Authors in DLHub may be different than the owners of the servable and generally are
        the people who developed functionality of a certain servable (e.g., the creators
        of the machine learning model used in a servable).

        If you want to search by ownership, see :meth:`search_by_servable`

        Args:
            authors (str or list of str): The authors to match. Names must be in
                "Family Name, Given Name" format
            match_all (bool): If ``True``, will require all authors be on any results.
                    If ``False``, will only require one author to be in results.
                    **Default**: ``True``.
            limit (int): The maximum number of results to return.
                    **Default:** ``None``, for no limit.
            only_latest (bool): When ``True``, will only return the latest version
                    of each servable. When ``False``, will return all matching versions.
                    **Default**: ``True``.

        Returns:
            [dict]: List of servables from the desired authors
        """
        results = self.query.match_authors(
            authors, match_all=match_all).search(limit=limit)
        return filter_latest(results) if only_latest else results

    def search_by_related_doi(self, doi, limit=None, only_latest=True):
        """Get all of the servables associated with a certain publication

        Args:
            doi (string): DOI of related paper
            limit (int): Maximum number of results to return
            only_latest (bool): Whether to return only the most recent version of the model
        Returns:
            [dict]: List of servables from the requested paper
        """

        results = self.query.match_doi(doi).search(limit=limit)
        return filter_latest(results) if only_latest else results

    def clear_funcx_cache(self, servable=None):
        """Remove functions from the cache. Either remove a specific servable or wipe the whole cache.

        Args:
            servable (str): The name of the servable to remove.
                **Default**: ``None``, which wipes the whole cache.
        Returns:
            dict: The remaining contents of the cache
        """

        if servable:
            del self.fx_cache[servable]
        else:
            self.fx_cache = {}

        return self.fx_cache
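
# A hedged usage sketch for the run()/run_serial() methods above, assuming a
# logged-in client. The servable names are purely illustrative; substitute
# <user>/<servable_name> entries that exist in your DLHub index.
dl = DLHubClient()

# Synchronous invocation: blocks until the servable returns (or `timeout` hits)
print(dl.run('someuser/example_servable', [[1.0, 2.0, 3.0]], timeout=120))

# Asynchronous invocation: returns a DLHubFuture immediately
future = dl.run('someuser/example_servable', [[1.0, 2.0, 3.0]], asynchronous=True)
print(future.result(timeout=120))

# Serial pipeline: the output of each servable feeds the next
print(dl.run_serial(['someuser/preprocess', 'someuser/predict'], [[1.0, 2.0, 3.0]]))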
Beispiel #28
0
with open("/home/zzli/.funcx/{}/config.py".format(args.endpoint_name),
          'w') as f:
    f.write(config)

# Start the endpoint
endpoint_name = args.endpoint_name
cmd = "funcx-endpoint start {}".format(endpoint_name)
try:
    subprocess.call(cmd, shell=True)
except Exception as e:
    print(e)
print("Started the endpoint {}".format(args.endpoint_id))
print("Wating 60 seconds for the endpoint to start")
time.sleep(60)

fxc = FuncXClient()
func_uuid = fxc.register_function(dlhub_test, description="A sum function")
print("The functoin uuid is {}".format(func_uuid))

fxs = FuncXSerializer()


def test(tasks=1,
         data=[1],
         timeout=float('inf'),
         endpoint_id=None,
         function_id=None,
         poll=0.1):

    start = time.time()
    res = fxc.run(data, endpoint_id=endpoint_id, function_id=function_id)
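    # The original example is truncated here. A plausible completion (a sketch,
    # not the original code) polls get_result until the task stops reporting a
    # "waiting-for-..." status or the timeout elapses:
    while time.time() - start < timeout:
        try:
            result = fxc.get_result(res)
            print("Result:", result)
            return result
        except Exception as e:
            if "waiting-for-" in str(e):
                time.sleep(poll)
            else:
                raise
    raise TimeoutError("Task {} did not complete within the timeout".format(res))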
Beispiel #29
0
import json
import sys
from time import sleep

import requests

from funcx.sdk.client import FuncXClient

NUM_RUNS = 70

pyhf_endpoint = 'a727e996-7836-4bec-9fa2-44ebf7ca5302'

fxc = FuncXClient()
fxc.max_requests = 200


def prepare_workspace(data):
    import pyhf
    w = pyhf.Workspace(data)
    return w


prepare_func = fxc.register_function(prepare_workspace)
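
# A hedged usage sketch for the registered function above: ship a background-only
# workspace spec to the pyhf endpoint and poll until the prepared Workspace comes
# back. The file name "workspace_bkg.json" is illustrative, not from the original.
with open("workspace_bkg.json") as f:
    bkgonly_workspace = json.load(f)

prepare_task = fxc.run(bkgonly_workspace,
                       endpoint_id=pyhf_endpoint,
                       function_id=prepare_func)

workspace = None
while workspace is None:
    try:
        workspace = fxc.get_result(prepare_task)
    except Exception as e:
        if "waiting-for-" in str(e):
            sleep(10)
        else:
            raise
print("Workspace prepared:", workspace)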


def infer_hypotest(w, metadata, doc):
    import pyhf
    import time

    tick = time.time()
    m = w.model(patches=[doc],
                modifier_settings={
                    "normsys": {
                        "interpcode": "code4"
Beispiel #30
0
    task_ids = fxc.map_run(list(range(task_count)),
                           endpoint_id=ep_id,
                           function_id=fn_uuid)
    delta = time.time() - start
    print(f"Time to launch {task_count} tasks: {delta:8.3f} s")
    print(f"Got {len(task_ids)} tasks_ids ")

    for _i in range(3):
        x = fxc.get_batch_result(task_ids)
        complete_count = sum(1 for t in task_ids
                             if t in x and not x[t].get("pending", True))
        print(f"Batch status : {complete_count}/{len(task_ids)} complete")
        if complete_count == len(task_ids):
            break
        time.sleep(2)

    delta = time.time() - start
    print(f"Time to complete {task_count} tasks: {delta:8.3f} s")
    print(f"Throughput : {task_count / delta:8.3f} Tasks/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--endpoint", required=True)
    parser.add_argument("-c", "--count", default="10")
    args = parser.parse_args()

    print("FuncX version : ", funcx.__version__)
    fxc = FuncXClient(funcx_service_address="https://dev.funcx.org/api/v1")
    test(fxc, args.endpoint, task_count=int(args.count))